diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6a965c6b21d287ef89db44e47028646008c8ed9b --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +.venv/ +__pycache__/ +*.pyc +*.pyo +.DS_Store +artifacts/ +outputs/ +lightning_logs/ +wandb/ +scenestreamer/outputs/ +scenestreamer/lightning_logs/ +scenestreamer/eval/outputs/ diff --git a/README.md b/README.md index 5e8b0fca9fffdf5c0a8e48a46151c580a1b8e32f..4b680882898b8ed598b4dd5ad154abfb75a90965 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ --- title: SceneStreamer -emoji: 😻 -colorFrom: green -colorTo: gray +emoji: 🚗 +colorFrom: blue +colorTo: indigo sdk: gradio sdk_version: 6.9.0 app_file: app.py @@ -10,4 +10,28 @@ pinned: false license: mit --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# SceneStreamer Space + +This Space hosts the interactive Gradio demo for `SceneStreamer`. + +What is included here: + +- the Gradio app entrypoint in `app.py` +- the SceneStreamer package code needed by the demo +- a tiny bundled ScenarioNet subset in `data/20scenarios` + +Default behavior: + +- the app loads the bundled demo subset automatically +- the model checkpoint is fetched from the Hugging Face Hub by default +- `SCENESTREAMER_DEVICE` defaults to `cpu` + +Optional Space variables: + +- `SCENESTREAMER_DATASET_DIR` +- `SCENESTREAMER_HF_REPO` +- `SCENESTREAMER_HF_FILE` +- `SCENESTREAMER_CKPT` +- `SCENESTREAMER_DEVICE` + +If the app shows the setup screen instead of the demo, the dataset path is missing or the demo subset was not uploaded with the repo. 
def _env_or_default(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* when it is unset or blank.

    Space variables can be defined with an empty value; treating a blank
    value like a missing one keeps the documented defaults in effect.
    """
    value = os.environ.get(name, "").strip()
    return value or default


def _build_space_demo() -> gr.Blocks:
    """Build the Gradio app served by this Space.

    Configuration is read from the ``SCENESTREAMER_*`` environment variables,
    falling back to the bundled demo subset, the default Hub checkpoint
    location, and the ``cpu`` device.

    Returns:
        The interactive demo produced by ``build_demo`` when the dataset
        directory is present; otherwise a static setup page that lists the
        variables to configure.
    """
    dataset_dir = _env_or_default("SCENESTREAMER_DATASET_DIR", "data/20scenarios")
    hf_repo = _env_or_default("SCENESTREAMER_HF_REPO", DEFAULT_HF_REPO)
    hf_file = _env_or_default("SCENESTREAMER_HF_FILE", DEFAULT_HF_FILE)
    # A blank checkpoint path means "no local checkpoint".
    ckpt = os.environ.get("SCENESTREAMER_CKPT", "").strip() or None
    device = _env_or_default("SCENESTREAMER_DEVICE", "cpu")

    # Guard clause: the happy path needs an actual directory. ``is_dir`` (rather
    # than ``exists``) also rejects a regular file placed at the dataset path.
    if Path(dataset_dir).is_dir():
        return build_demo(
            dataset_dir=dataset_dir,
            hf_repo=hf_repo,
            hf_file=hf_file,
            ckpt=ckpt,
            device=device,
        )

    # Fallback: render setup instructions instead of failing at import time,
    # so the Space still starts and tells the operator what to fix.
    with gr.Blocks(title="SceneStreamer Space Setup") as demo:
        gr.Markdown("## SceneStreamer Space Setup Required")
        gr.Markdown(
            "This Space needs a local ScenarioNet dataset directory before the interactive demo can start.\n\n"
            f"Current `SCENESTREAMER_DATASET_DIR`: `{dataset_dir}`"
        )
        gr.Markdown(
            "Set Space variables or attach storage, then restart the Space:\n"
            "- `SCENESTREAMER_DATASET_DIR`\n"
            "- `SCENESTREAMER_HF_REPO` (optional)\n"
            "- `SCENESTREAMER_HF_FILE` (optional)\n"
            "- `SCENESTREAMER_CKPT` (optional local checkpoint)\n"
            "- `SCENESTREAMER_DEVICE` (default `cpu`)"
        )
        gr.Markdown(
            "This repo is expected to include a tiny bundled demo subset under `data/20scenarios`. "
            "If you are seeing this page after pushing the repo, the demo data was likely not uploaded."
        )
    return demo
+GPT_STYLE: false +REMOVE_AGENT_FROM_SCENE_ENCODER: false +USE_DIFFUSION: false +USE_ADALN: null +BACKWARD_PREDICTION: false + + +USE_DESTINATION: false + +TF_DEST: True + + +ADD_CONTOUR_RELATION: false + +DELTA_TOKENIZER_FILE_NAME: "1030_argsort_less_256_128_128.pkl" + + +USE_TRAFFICGEN: false +TRAIN_TRAFFICGEN: null +USE_MOTION: true + +EVAL_MOTION: true +EVAL_TRAFFICGEN: false + +DELTA_POS_IS_VELOCITY: false + +SIMPLE_RELATION: false +SIMPLE_RELATION_FACTOR: 1 + +RECONSTRUCT_MAP: false + +UPDATE_RELATION: false + +REMOVE_REL_NORM: false + +DATA: + TRAINING_DATA_DIR: '/data/datasets/scenarionet/waymo/training' + TEST_DATA_DIR: '/data/datasets/scenarionet/waymo/validation' + ADV_INFO_PATH: 'data/all_adv.pkl' + SAMPLE_INTERVAL_TRAINING: 1 + SAMPLE_INTERVAL_TEST: 1 + SD_PASSTHROUGH: false + ALLOW_CACHE: false + RETURN_HALFWAY: false # Only used when generating LMDB dataset + USE_LMDB: false + USE_CACHE: false + +PREPROCESSING: + MAX_VECTORS: 128 + MAX_MAP_FEATURES: 512 + MAX_LENGTH_PER_MAP_FEATURE: 10000 # Useless + MAX_AGENTS: 128 + MAX_TRAFFIC_LIGHTS: 64 + PADDING_TO_MAX: false + keep_all_data: false # for debug + ADD_SDC_TO_OBJECT_OF_INTEREST: true # Should be True when WOSAC + REMOVE_TRAFFIC_LIGHT_STATE: false + TRUNCATE_TIME: -1 + +TRAINING: + PREDICT_ALL_AGENTS: true + +EVALUATION: + NAME: 'waymo_motion_prediction' + PREDICT_ALL_AGENTS: false + DELETE_EVAL_RESULT: true + NUM_MODES: 6 + MAXIMUM_BATCH_SIZE: 10000 + USE_CACHE: true + USE_TG_AS_GT: 1111 + TG_REJECT_SAMPLING: True + TG_SDC_DISTANCE_MASKING: False + +MODEL: + NAME: 'motionlm' + D_MODEL: 256 + NUM_ATTN_LAYERS: 4 + NUM_ATTN_HEAD: 8 +# DROPOUT_OF_ATTN: 0.0 + DROPOUT: 0.0 + NUM_DECODER_LAYERS: 6 + ADD_PE_FOR_TOKEN: true + RELATIVE_PE: true + RELATIVE_PE_DECODER: false + PRE_PROJECTION: false + KNN: 128 + S2S_DISTANCE: null + SELF_ATTN_KNN: 128 + CROSS_ATTN_KNN: 128 + RANDOMIZE_AGENT_ID: true + A2S_KNN: null + A2S_DISTANCE: null + A2A_KNN: null + A2A_DISTANCE: null + ADD_RELATION_TO_V: false + IS_V7: 
False + PER_CONTOUR_POINT_RELATION: null + +TOKENIZATION: + TOKENIZATION_METHOD: delta_delta + NUM_SKIPPED_STEPS: 5 + NUM_BINS: 13 + X_MAX: 3.5 # <<< Deprecated + X_MIN: -3.5 # <<< Deprecated + Y_MAX: 3.5 # <<< Deprecated + Y_MIN: -3.5 # <<< Deprecated + ADD_NOISE: false + NOISE_TOPK: 5 + ALLOW_SKIP_STEP: True + + MIN_DISPLACEMENT: 0.1 + MIN_DISPLACEMENT_INIT: null + MIN_SPEED: null + SMOOTH_FACTOR: null + MAX_HEADING_DIFF: null + + USE_CONTOUR_ERROR: True + + VEH_LIMIT: 3.5 + PED_LIMIT: 3.5 + CYC_LIMIT: 3.5 + + FLIP_WRONG_HEADING: false + SHOULD_STANDARDIZE: true + +# MIN_DISPLACEMENT: 0.3 +# MIN_DISPLACEMENT_INIT: 1.0 +# MIN_SPEED: 0.5 +# SMOOTH_FACTOR: null +# MAX_HEADING_DIFF: 0.3 + + +SAMPLING: + SAMPLING_METHOD: 'topp' + TEMPERATURE: 1.0 + TOPP: 0.95 + + +OPTIMIZATION: +# NUM_EPOCHS: 50 + OPTIMIZER: AdamW + LR: 0.0003 + WEIGHT_DECAY: 0.0 + GRAD_NORM_CLIP: 1.0 + SCHEDULER: cosine + WARMUP_STEPS: 2000 +# TRAINING_STEPS: 300000 + USE_FOCAL_LOSS: false + + +SUBMISSION: + GENERATE_SUBMISSION: false + PREFIX: "peng" + ACCOUNT: "dr.zhenghao.peng@gmail.com" + METHOD_NAME: "peng" + num_model_parameters: '10m' # TODO: Need to be changed accordingly! 
+ SAVE_EVAL_DATA: true + +TMP_DIR: "tmp" # Relative to repo root + +ACTION_LABEL: + USE_ACTION_LABEL: false # Only valid for turning + acceleration + USE_SAFETY_LABEL: false + MASK_PROBABILITY_ACTION_LABEL: 0.0 # Might turn it on + MASK_PROBABILITY_SAFETY_LABEL: 0.0 # Might turn it on + +LANGUAGE_CONDITION: false +FINE_TUNE_BERT: false + +MCTS: + USE_MCTS: False + MCTS_DEPTH: -1 + MCTS_WIDTH: -1 + +TOKEN_BUFFER_CACHE_LENGTH: 100 \ No newline at end of file diff --git a/cfgs/scenestreamer-base-large.yaml b/cfgs/scenestreamer-base-large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..636f47f0ae08203e99564acbcc3b117d8e6b6c39 --- /dev/null +++ b/cfgs/scenestreamer-base-large.yaml @@ -0,0 +1,96 @@ +defaults: + - motion_default + - _self_ + +exp_name: 'scenestreamer-base-large' +pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250506_scenestreamer_v18_notg_large_FIXETYPE_2025-05-06/checkpoints" + +num_workers: 8 +val_num_workers: 8 +num_sanity_val_steps: 10 + +batch_size: 4 +val_batch_size: 4 +limit_val_batches: -1 + +eval_backward_model: False + +epochs: 30 +wandb: True +log_dir: /bigdata/zhenghao/scenestreamer + +SCENESTREAMER_ATTENTION_KNN: 128 +SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50 +SCENESTREAMER_NO_TG: true + +REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<< + +BACKWARD_PREDICTION: False # <<< +ADD_CONTOUR_RELATION: True # <<< + +DELTA_POS_IS_VELOCITY: True +SIMPLE_RELATION: True + +RECONSTRUCT_MAP: False +UPDATE_RELATION: False +REMOVE_REL_NORM: False # <<< + +USE_TRAFFICGEN: True +USE_MOTION: True +EVAL_MOTION: True +EVAL_TRAFFICGEN: False + +GPT_STYLE: True # <<< +USE_ADALN: False + +SAMPLING: + TOPP: 0.95 + TEMPERATURE: 1.0 + +TOKENIZATION: + TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<< + USE_CONTOUR_ERROR: True # <<< + ALLOW_SKIP_STEP: True + ADD_NOISE: False + NUM_BINS: 33 + +PREPROCESSING: + REMOVE_TRAFFIC_LIGHT_STATE: False + MAX_LENGTH_PER_MAP_FEATURE: 10 + MAX_MAP_FEATURES: 3000 + MAX_VECTORS: 
30 + MAX_AGENTS: 128 + DEST_DROPOUT: 0.0 + ADD_SDC_TO_OBJECT_OF_INTEREST: False + +DATA: + TRAINING_DATA_DIR: '' + TEST_DATA_DIR: '' + +MODEL: + USE_MOTION_HEAD_PRENORM: True + ALL_TO_MAP_3D: False + D_MODEL: 128 + NAME: 'scenestreamer' + NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8 + # Encoder: + NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3 + RELATIVE_PE: true + # Decoder: + NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6 + RELATIVE_PE_DECODER: True + SIMPLE_RELATION_FACTOR: 1 + # New: + KNN: -100 + S2S_DISTANCE: -100 + A2S_KNN: -100 + A2S_DISTANCE: -100 + A2A_KNN: -100 + A2A_DISTANCE: -100 + ADD_RELATION_TO_V: False + PER_CONTOUR_POINT_RELATION: False + IS_V7: True + +SUBMISSION: + METHOD_NAME: "scenestreamer-base-large" + num_model_parameters: '3.3m' diff --git a/cfgs/scenestreamer-base-small.yaml b/cfgs/scenestreamer-base-small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17d82527b4f0c2c302d960eee1148c58b7fecefe --- /dev/null +++ b/cfgs/scenestreamer-base-small.yaml @@ -0,0 +1,96 @@ +defaults: + - motion_default + - _self_ + +exp_name: 'scenestreamer-base-small' +pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250507_scenestreamer_v17_notg_finetune_FIXTYPE_2025-05-07/checkpoints" + +num_workers: 8 +val_num_workers: 8 +num_sanity_val_steps: 10 + +batch_size: 4 +val_batch_size: 4 +limit_val_batches: -1 + +eval_backward_model: False + +epochs: 30 +wandb: True +log_dir: /bigdata/zhenghao/scenestreamer + +SCENESTREAMER_ATTENTION_KNN: 128 +SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50 +SCENESTREAMER_NO_TG: true + +REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<< + +BACKWARD_PREDICTION: False # <<< +ADD_CONTOUR_RELATION: True # <<< + +DELTA_POS_IS_VELOCITY: True +SIMPLE_RELATION: True + +RECONSTRUCT_MAP: False +UPDATE_RELATION: False +REMOVE_REL_NORM: False # <<< + +USE_TRAFFICGEN: True +USE_MOTION: True +EVAL_MOTION: True +EVAL_TRAFFICGEN: False + +GPT_STYLE: True # <<< +USE_ADALN: False + +SAMPLING: + 
TOPP: 0.95 + TEMPERATURE: 1.0 + +TOKENIZATION: + TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<< + USE_CONTOUR_ERROR: True # <<< + ALLOW_SKIP_STEP: True + ADD_NOISE: False + NUM_BINS: 33 + +PREPROCESSING: + REMOVE_TRAFFIC_LIGHT_STATE: False + MAX_LENGTH_PER_MAP_FEATURE: 10 + MAX_MAP_FEATURES: 3000 + MAX_VECTORS: 30 + MAX_AGENTS: 128 + DEST_DROPOUT: 0.0 + ADD_SDC_TO_OBJECT_OF_INTEREST: False + +DATA: + TRAINING_DATA_DIR: '' + TEST_DATA_DIR: '' + +MODEL: + USE_MOTION_HEAD_PRENORM: True + ALL_TO_MAP_3D: False + D_MODEL: 64 + NAME: 'scenestreamer' + NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8 + # Encoder: + NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3 + RELATIVE_PE: true + # Decoder: + NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6 + RELATIVE_PE_DECODER: True + SIMPLE_RELATION_FACTOR: 1 + # New: + KNN: -100 + S2S_DISTANCE: -100 + A2S_KNN: -100 + A2S_DISTANCE: -100 + A2A_KNN: -100 + A2A_DISTANCE: -100 + ADD_RELATION_TO_V: False + PER_CONTOUR_POINT_RELATION: False + IS_V7: True + +SUBMISSION: + METHOD_NAME: "scenestreamer-base-small" + num_model_parameters: '1.1m' diff --git a/cfgs/scenestreamer-base-xl.yaml b/cfgs/scenestreamer-base-xl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..021cd168adf4c4f8bd481728210b9ea61bdea79e --- /dev/null +++ b/cfgs/scenestreamer-base-xl.yaml @@ -0,0 +1,97 @@ +defaults: + - motion_default + - _self_ + +exp_name: 'scenestreamer-base-xl' +pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250512_scenestreamer-base-xl_2025-05-12/checkpoints" + +num_workers: 8 +val_num_workers: 8 +num_sanity_val_steps: 10 + +batch_size: 4 +val_batch_size: 4 +limit_val_batches: -1 + +eval_backward_model: False + +epochs: 30 +wandb: True +log_dir: /bigdata/zhenghao/scenestreamer + +SCENESTREAMER_ATTENTION_KNN: 128 +SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50 +SCENESTREAMER_NO_TG: true + +REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<< + +BACKWARD_PREDICTION: False # <<< 
+ADD_CONTOUR_RELATION: True # <<< + +DELTA_POS_IS_VELOCITY: True +SIMPLE_RELATION: True + +RECONSTRUCT_MAP: False +UPDATE_RELATION: False +REMOVE_REL_NORM: False # <<< + +USE_TRAFFICGEN: True +USE_MOTION: True +EVAL_MOTION: True +EVAL_TRAFFICGEN: False + +GPT_STYLE: True # <<< +USE_ADALN: False + +SAMPLING: + TOPP: 0.95 + TEMPERATURE: 1.0 + +TOKENIZATION: + TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<< + USE_CONTOUR_ERROR: True # <<< + ALLOW_SKIP_STEP: True + ADD_NOISE: False + NUM_BINS: 33 + +PREPROCESSING: + REMOVE_TRAFFIC_LIGHT_STATE: False + MAX_LENGTH_PER_MAP_FEATURE: 10 + MAX_MAP_FEATURES: 3000 + MAX_VECTORS: 30 + MAX_AGENTS: 128 + DEST_DROPOUT: 0.0 + ADD_SDC_TO_OBJECT_OF_INTEREST: False + +DATA: + TRAINING_DATA_DIR: '' + TEST_DATA_DIR: '' + +MODEL: + USE_MOTION_HEAD_PRENORM: True + ALL_TO_MAP_3D: False + D_MODEL: 128 + NAME: 'scenestreamer' + NUM_ATTN_HEAD: 8 + # Encoder: + NUM_ATTN_LAYERS: 3 + RELATIVE_PE: true + # Decoder: + NUM_DECODER_LAYERS: 6 + RELATIVE_PE_DECODER: True + SIMPLE_RELATION_FACTOR: 1 + # New: + KNN: -100 + S2S_DISTANCE: -100 + A2S_KNN: -100 + A2S_DISTANCE: -100 + A2A_KNN: -100 + A2A_DISTANCE: -100 + ADD_RELATION_TO_V: False + PER_CONTOUR_POINT_RELATION: False + IS_V7: True + +SUBMISSION: + METHOD_NAME: "scenestreamer-base-xl" + num_model_parameters: '4.2m' + ACCOUNT: "dr.zhenghao.peng@gmail.com" diff --git a/cfgs/scenestreamer-full-large-nors.yaml b/cfgs/scenestreamer-full-large-nors.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fee210a803876fdc38182edf655574137dd0f1a9 --- /dev/null +++ b/cfgs/scenestreamer-full-large-nors.yaml @@ -0,0 +1,99 @@ +defaults: + - motion_default + - _self_ + +exp_name: 'scenestreamer-full-large-nors' +pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250507_scenestreamer-full-large_2025-05-07/checkpoints/20250507_scenestreamer-full-large_2025-05-07_epoch=1-step=77031.ckpt" + +num_workers: 8 +val_num_workers: 8 +num_sanity_val_steps: 10 + 
+batch_size: 4 +val_batch_size: 4 +limit_val_batches: -1 + +eval_backward_model: False + +epochs: 30 +wandb: True +log_dir: /bigdata/zhenghao/scenestreamer + +SCENESTREAMER_ATTENTION_KNN: 128 +SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50 +SCENESTREAMER_NO_TG: false + +REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<< + +BACKWARD_PREDICTION: False # <<< +ADD_CONTOUR_RELATION: True # <<< + +DELTA_POS_IS_VELOCITY: True +SIMPLE_RELATION: True + +RECONSTRUCT_MAP: False +UPDATE_RELATION: False +REMOVE_REL_NORM: False # <<< + +USE_TRAFFICGEN: True +USE_MOTION: True +EVAL_MOTION: True +EVAL_TRAFFICGEN: False + +GPT_STYLE: True # <<< +USE_ADALN: False + +SAMPLING: + TOPP: 0.95 + TEMPERATURE: 1.0 + +TOKENIZATION: + TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<< + USE_CONTOUR_ERROR: True # <<< + ALLOW_SKIP_STEP: True + ADD_NOISE: False + NUM_BINS: 33 + +PREPROCESSING: + REMOVE_TRAFFIC_LIGHT_STATE: False + MAX_LENGTH_PER_MAP_FEATURE: 10 + MAX_MAP_FEATURES: 3000 + MAX_VECTORS: 30 + MAX_AGENTS: 128 + DEST_DROPOUT: 0.0 + ADD_SDC_TO_OBJECT_OF_INTEREST: False + +DATA: + TRAINING_DATA_DIR: '' + TEST_DATA_DIR: '' + +MODEL: + USE_MOTION_HEAD_PRENORM: True + ALL_TO_MAP_3D: False + D_MODEL: 128 + NAME: 'scenestreamer' + NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8 + # Encoder: + NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3 + RELATIVE_PE: true + # Decoder: + NUM_DECODER_LAYERS: 4 # TODO: Need to increase? 
was 6 + RELATIVE_PE_DECODER: True + SIMPLE_RELATION_FACTOR: 1 + # New: + KNN: -100 + S2S_DISTANCE: -100 + A2S_KNN: -100 + A2S_DISTANCE: -100 + A2A_KNN: -100 + A2A_DISTANCE: -100 + ADD_RELATION_TO_V: False + PER_CONTOUR_POINT_RELATION: False + IS_V7: True + +SUBMISSION: + METHOD_NAME: "scenestreamer-full-large" + num_model_parameters: '4.6m' + +EVALUATION: + TG_REJECT_SAMPLING: False \ No newline at end of file diff --git a/cfgs/scenestreamer-full-large.yaml b/cfgs/scenestreamer-full-large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..06c2776ee5a97df3f15e1e35ef73e687197cfec9 --- /dev/null +++ b/cfgs/scenestreamer-full-large.yaml @@ -0,0 +1,99 @@ +defaults: + - motion_default + - _self_ + +exp_name: 'scenestreamer-full-large' +pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250507_scenestreamer-full-large_2025-05-07/checkpoints/20250507_scenestreamer-full-large_2025-05-07_epoch=1-step=77031.ckpt" + +num_workers: 8 +val_num_workers: 8 +num_sanity_val_steps: 10 + +batch_size: 4 +val_batch_size: 4 +limit_val_batches: -1 + +eval_backward_model: False + +epochs: 30 +wandb: True +log_dir: /bigdata/zhenghao/scenestreamer + +SCENESTREAMER_ATTENTION_KNN: 128 +SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50 +SCENESTREAMER_NO_TG: false + +REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<< + +BACKWARD_PREDICTION: False # <<< +ADD_CONTOUR_RELATION: True # <<< + +DELTA_POS_IS_VELOCITY: True +SIMPLE_RELATION: True + +RECONSTRUCT_MAP: False +UPDATE_RELATION: False +REMOVE_REL_NORM: False # <<< + +USE_TRAFFICGEN: True +USE_MOTION: True +EVAL_MOTION: True +EVAL_TRAFFICGEN: False + +GPT_STYLE: True # <<< +USE_ADALN: False + +SAMPLING: + TOPP: 0.95 + TEMPERATURE: 1.0 + +TOKENIZATION: + TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<< + USE_CONTOUR_ERROR: True # <<< + ALLOW_SKIP_STEP: True + ADD_NOISE: False + NUM_BINS: 33 + +PREPROCESSING: + REMOVE_TRAFFIC_LIGHT_STATE: False + MAX_LENGTH_PER_MAP_FEATURE: 10 + MAX_MAP_FEATURES: 3000 + 
MAX_VECTORS: 30 + MAX_AGENTS: 128 + DEST_DROPOUT: 0.0 + ADD_SDC_TO_OBJECT_OF_INTEREST: False + +DATA: + TRAINING_DATA_DIR: '' + TEST_DATA_DIR: '' + +MODEL: + USE_MOTION_HEAD_PRENORM: True + ALL_TO_MAP_3D: False + D_MODEL: 128 + NAME: 'scenestreamer' + NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8 + # Encoder: + NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3 + RELATIVE_PE: true + # Decoder: + NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6 + RELATIVE_PE_DECODER: True + SIMPLE_RELATION_FACTOR: 1 + # New: + KNN: -100 + S2S_DISTANCE: -100 + A2S_KNN: -100 + A2S_DISTANCE: -100 + A2A_KNN: -100 + A2A_DISTANCE: -100 + ADD_RELATION_TO_V: False + PER_CONTOUR_POINT_RELATION: False + IS_V7: True + +SUBMISSION: + METHOD_NAME: "scenestreamer-full-large" + num_model_parameters: '4.6m' + +EVALUATION: + TG_REJECT_SAMPLING: True \ No newline at end of file diff --git a/cfgs/scenestreamer-full-small.yaml b/cfgs/scenestreamer-full-small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a9f2b76aa862411aef255758c2fe7e21face065 --- /dev/null +++ b/cfgs/scenestreamer-full-small.yaml @@ -0,0 +1,96 @@ +defaults: + - motion_default + - _self_ + +exp_name: 'scenestreamer-full-small' +pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250505_scenestreamer_v19_withtg_nodest_FIXEDAS_2025-05-05/checkpoints" + +num_workers: 8 +val_num_workers: 8 +num_sanity_val_steps: 10 + +batch_size: 4 +val_batch_size: 4 +limit_val_batches: -1 + +eval_backward_model: False + +epochs: 30 +wandb: True +log_dir: /bigdata/zhenghao/scenestreamer + +SCENESTREAMER_ATTENTION_KNN: 128 +SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50 +SCENESTREAMER_NO_TG: false + +REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<< + +BACKWARD_PREDICTION: False # <<< +ADD_CONTOUR_RELATION: True # <<< + +DELTA_POS_IS_VELOCITY: True +SIMPLE_RELATION: True + +RECONSTRUCT_MAP: False +UPDATE_RELATION: False +REMOVE_REL_NORM: False # <<< + +USE_TRAFFICGEN: True +USE_MOTION: True +EVAL_MOTION: True 
+EVAL_TRAFFICGEN: False + +GPT_STYLE: True # <<< +USE_ADALN: False + +SAMPLING: + TOPP: 0.95 + TEMPERATURE: 1.0 + +TOKENIZATION: + TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<< + USE_CONTOUR_ERROR: True # <<< + ALLOW_SKIP_STEP: True + ADD_NOISE: False + NUM_BINS: 33 + +PREPROCESSING: + REMOVE_TRAFFIC_LIGHT_STATE: False + MAX_LENGTH_PER_MAP_FEATURE: 10 + MAX_MAP_FEATURES: 3000 + MAX_VECTORS: 30 + MAX_AGENTS: 128 + DEST_DROPOUT: 0.0 + ADD_SDC_TO_OBJECT_OF_INTEREST: False + +DATA: + TRAINING_DATA_DIR: '' + TEST_DATA_DIR: '' + +MODEL: + USE_MOTION_HEAD_PRENORM: True + ALL_TO_MAP_3D: False + D_MODEL: 64 # TODO: Need to increase? was 128 + NAME: 'scenestreamer' + NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8 + # Encoder: + NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3 + RELATIVE_PE: true + # Decoder: + NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6 + RELATIVE_PE_DECODER: True + SIMPLE_RELATION_FACTOR: 1 + # New: + KNN: -100 + S2S_DISTANCE: -100 + A2S_KNN: -100 + A2S_DISTANCE: -100 + A2A_KNN: -100 + A2A_DISTANCE: -100 + ADD_RELATION_TO_V: False + PER_CONTOUR_POINT_RELATION: False + IS_V7: True + +SUBMISSION: + METHOD_NAME: "scenestreamer-full-small" + num_model_parameters: '1.5m' diff --git a/cfgs/scenestreamer-full-xl.yaml b/cfgs/scenestreamer-full-xl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dcba6525c728e65554c875010abdbc6809f836d9 --- /dev/null +++ b/cfgs/scenestreamer-full-xl.yaml @@ -0,0 +1,100 @@ +defaults: + - motion_default + - _self_ + +exp_name: 'scenestreamer-full-xl' +pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250518_scenestreamer-full-xs_2025-05-18/checkpoints/20250518_scenestreamer-full-xs_2025-05-18_epoch=1-step=62494.ckpt" + +num_workers: 8 +val_num_workers: 8 +num_sanity_val_steps: 10 + +batch_size: 4 +val_batch_size: 4 +limit_val_batches: -1 + +eval_backward_model: False + +epochs: 30 +wandb: True +log_dir: /bigdata/zhenghao/scenestreamer + 
+SCENESTREAMER_ATTENTION_KNN: 128 +SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50 +SCENESTREAMER_NO_TG: false + +REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<< + +BACKWARD_PREDICTION: False # <<< +ADD_CONTOUR_RELATION: True # <<< + +DELTA_POS_IS_VELOCITY: True +SIMPLE_RELATION: True + +RECONSTRUCT_MAP: False +UPDATE_RELATION: False +REMOVE_REL_NORM: False # <<< + +USE_TRAFFICGEN: True +USE_MOTION: True +EVAL_MOTION: True +EVAL_TRAFFICGEN: False + +GPT_STYLE: True # <<< +USE_ADALN: False + +SAMPLING: + TOPP: 0.95 + TEMPERATURE: 1.0 + +TOKENIZATION: + TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<< + USE_CONTOUR_ERROR: True # <<< + ALLOW_SKIP_STEP: True + ADD_NOISE: False + NUM_BINS: 33 + +PREPROCESSING: + REMOVE_TRAFFIC_LIGHT_STATE: False + MAX_LENGTH_PER_MAP_FEATURE: 10 + MAX_MAP_FEATURES: 3000 + MAX_VECTORS: 30 + MAX_AGENTS: 128 + DEST_DROPOUT: 0.0 + ADD_SDC_TO_OBJECT_OF_INTEREST: False + +DATA: + TRAINING_DATA_DIR: '' + TEST_DATA_DIR: '' + +MODEL: + USE_MOTION_HEAD_PRENORM: True + ALL_TO_MAP_3D: False + D_MODEL: 128 + NAME: 'scenestreamer' + NUM_ATTN_HEAD: 8 + # Encoder: + NUM_ATTN_LAYERS: 3 + RELATIVE_PE: true + # Decoder: + NUM_DECODER_LAYERS: 6 + RELATIVE_PE_DECODER: True + SIMPLE_RELATION_FACTOR: 1 + # New: + KNN: -100 + S2S_DISTANCE: -100 + A2S_KNN: -100 + A2S_DISTANCE: -100 + A2A_KNN: -100 + A2A_DISTANCE: -100 + ADD_RELATION_TO_V: False + PER_CONTOUR_POINT_RELATION: False + IS_V7: True + +SUBMISSION: + METHOD_NAME: "scenestreamer-full-xl" + num_model_parameters: '5.5m' + ACCOUNT: "dr.zhenghao.peng@gmail.com" + +EVALUATION: + TG_REJECT_SAMPLING: True \ No newline at end of file diff --git a/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl b/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a540fa76f645a58878b138eca0263cc544e67df2 --- /dev/null +++ b/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl 
@@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2f5a926ba159d4e9acec464c9d091c093d138a21796ce5c264fea7f4398a777 +size 3007314 diff --git a/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl b/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl new file mode 100644 index 0000000000000000000000000000000000000000..78a80f35e8129c583de92c066e4ba379c06d6921 --- /dev/null +++ b/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8a001efb1d464a0f9a3c06f26cf921f2308e81b26a65741491875201ede70b1 +size 6364095 diff --git a/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl b/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl new file mode 100644 index 0000000000000000000000000000000000000000..32e3fd2950df261b1406bb2d35a504432641b47f --- /dev/null +++ b/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:742aa8c350793d83e949396dcb055a17bcbdd4b2b728b41d5f1b5c6c5a897ce1 +size 4382994 diff --git a/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl b/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..1958506a3b1d79507a107a11e826ed36c8fa06ee --- /dev/null +++ b/data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:716fa8212d8d4dbed60703cf5a9a952129fab5ed9342da748a7a48f00478e6d9 +size 11279523 diff --git a/data/20scenarios/process.ipynb b/data/20scenarios/process.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..52630f06f78ade7b1e0b52317b9016be4ac3809c --- /dev/null +++ b/data/20scenarios/process.ipynb @@ -0,0 
+1,128 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "f003e6e4", + "metadata": {}, + "outputs": [], + "source": [ + "import os, pickle" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fdb6f299", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0m\u001b[01;32mdataset_summary.pkl\u001b[0m*\r\n", + "\u001b[01;32mprocess.ipynb\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_18840a098288507f.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl\u001b[0m*\r\n", + 
"\u001b[01;32msd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl\u001b[0m*\r\n", + "\u001b[01;32msd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl\u001b[0m*\r\n" + ] + } + ], + "source": [ + "ls" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f08caf3b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_18840a098288507f.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl',\n", + " 'sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl']" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[p for p in os.listdir(\".\") if p.endswith(\".pkl\") and p.startswith(\"sd\")]" + ] + }, + 
{ + "cell_type": "code", + "execution_count": 6, + "id": "d1d50b21", + "metadata": {}, + "outputs": [], + "source": [ + "d = {}\n", + "for p in [p for p in os.listdir(\".\") if p.endswith(\".pkl\") and p.startswith(\"sd\")]:\n", + " d[p] = {}\n", + "\n", + "pickle.dump(d, open(\"dataset_summary.pkl\", \"wb\"))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f0f2313e618f7a612a654af756b3a7d2e2cd0c71 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de3ceb564e944fda1d5b8e2f72428ada2b790d5d57195e6d9bca2cdf761a37f0 +size 338026 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7d416cd271320c67d0e945bc39e9bb2007783052 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72cae539b2993d4dd668e5029460e2cdbf621dc1734d69ce30b357a156bc8375 +size 535617 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl new file mode 100644 index 
0000000000000000000000000000000000000000..326149dd408ac74d17771d5b2dded566eaabcb60 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:383b8cdeaefc13c15751f3ff29d28426543c24b6de674e0c5434e39f5d7a8d1f +size 1060327 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl new file mode 100644 index 0000000000000000000000000000000000000000..2aace67c7b633e088fba39420ddeb277339c0adf --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77209cb776b340b849ff123c7c9019d17495df41bf69fe2cfc765d5c8235fc66 +size 680549 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c99281cefc41c3b44eb88b4df0dfac999b2849b2 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf6fc9426a5b209d4f0862ff8046844f40fbf822fcf54a24947f23aa0f140161 +size 426242 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl new file mode 100644 index 0000000000000000000000000000000000000000..fd193fd274f282ff10c10a9579ee4af5b80c2998 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddfd0775eedec97b30316beb8c537e3d768928abbd255030403c5bdff59a7f85 +size 837708 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl 
b/data/20scenarios/sd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl new file mode 100644 index 0000000000000000000000000000000000000000..fafe9548aa6d415559d090351b96c95a2e95bd62 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b2f40bfe73b50954987f15a462a4a86659615e416b577b1c1096c8c69ef0d05 +size 667369 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f627d1b0235779b54a4df59af6c60f4c0fcf4eb2 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:30f1f5e1d3a5863a88b919a79c8c4a75bff2774df99412a092706d5cebc7ffbb +size 319336 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl new file mode 100644 index 0000000000000000000000000000000000000000..283cf76d81405264d70ed2ec9db8aa0de6774008 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bacb875cdcd59d4c5ada7c1a2b546a31967de6f15f288d81136d2f3cc12f2413 +size 560466 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl new file mode 100644 index 0000000000000000000000000000000000000000..3b2d08f3544aa76998159bc6da6d63b2195b85b8 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5194b6697b1fa3d00113c8ba2f39d50ffb8066b004c8f2f5d15953a550027e7c +size 456308 diff --git 
a/data/20scenarios/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a64fc3cbde45351d07b4a31ccc0287572bc938de --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfce54f80da166a8cbaf5fe167770975dd4dfcc9bba3c29ad228ad112d995690 +size 396202 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl new file mode 100644 index 0000000000000000000000000000000000000000..aa230bfbbe42dff086768d47dc9b8c29d6eccce7 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a15fb32533b7472c38ec32c676a9425490de2c95d8dd39b9aa2a6b98f1050512 +size 253625 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl new file mode 100644 index 0000000000000000000000000000000000000000..1362443785c7ee0f28c016f9640e38aa87d97b60 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d126e0c7c1e4afa561739437ad7c7810830c766a45b725fb8cb8a33015029e61 +size 490897 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_18840a098288507f.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_18840a098288507f.pkl new file mode 100644 index 0000000000000000000000000000000000000000..920285ba0703e8e9be3236c00708812031d45b8a --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_18840a098288507f.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:ffbd89555db615f2d20516fedf2dd3689b4487221a82c9f48251e29c30027f31 +size 627272 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl new file mode 100644 index 0000000000000000000000000000000000000000..af778cee935e0fa5c399c3f0a29459b7609cdb61 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e93b1b9cbfd8c541c5c52523aee09f5abd3410d217bf6e402ec217935f5df08 +size 435189 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl new file mode 100644 index 0000000000000000000000000000000000000000..8e89960200611667c977b9f65f9251ee75bfaaa0 --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11ac7db4a7c917c8045c957ee93d2f761de089152fdb9b03554a889d4f1deca3 +size 376448 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7b04277fc22d5e4715e7728f573ec6f021099a2f --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cb753a3c0d7255b5d4333a95c5555b4b159dcc600f9da286787ab37893c70d8 +size 1558365 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl new file mode 100644 index 0000000000000000000000000000000000000000..34a4313565b3145c443b70c097756359784d7472 --- /dev/null +++ 
b/data/20scenarios/sd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ac83f07304e39a8c5e468399b12e3ed034d6855b2921ab094ee425e492bea6d +size 1319602 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl new file mode 100644 index 0000000000000000000000000000000000000000..25e3dc68df105be8c9544f65f075da811320989e --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c81210cec08ec7bc5813e59bf4a2f8287800c2d626929441a0e7398f7283c509 +size 1316507 diff --git a/data/20scenarios/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl b/data/20scenarios/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..959ecb867507ba6ba66543281866c8b7ab0ecb4f --- /dev/null +++ b/data/20scenarios/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26feb91313bee0c89ff1a9660fc352ea687daf04a7ca642f600d2ece17a6e301 +size 1079778 diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..74697efd20b1c1938c37efc63dd1a00d1f8870a3 --- /dev/null +++ b/packages.txt @@ -0,0 +1,7 @@ +ffmpeg +libgl1 +libglib2.0-0 +libsm6 +libxext6 +libxrender1 +libsdl2-2.0-0 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..f27486b73f20e11d8cee866ab46c1603765fbc0d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,80 @@ +[project] +name = "scenestreamer" +version = "1.0.0" +description = "SceneStreamer: Continuous Scenario Generation as Next Token Group Prediction" +readme = "README.md" +license = {text = "MIT"} +authors = [ + {name = "Zhenghao Peng", email = 
"pzh@berkeley.edu"} +] +requires-python = ">=3.10,<3.12" +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] + +dependencies = [ + "torch>=2.0.0", + "torchvision", + "lightning>=2.0.0", + "hydra-core", + "omegaconf", + "numpy", + "tqdm", + "matplotlib", + "seaborn", + "Pillow", + "easydict", + "wandb", + "torch_geometric", + "transformers", + "tokenizers", + "huggingface_hub", + "tensorboardX", + "pyyaml", + "scikit-image", + "chardet", + "charset-normalizer", + "tabulate", + "metadrive-simulator", + "gradio>=6.9.0", + "scenarionet @ git+https://github.com/metadriverse/scenarionet.git", +] + +[project.scripts] +scenestreamer = "scenestreamer.cli:main" + +[project.optional-dependencies] +dev = [ + "ruff", + "pytest", +] +rl = [ + "stable-baselines3>=2.0.0", + "gymnasium>=0.29.0", + "ipython", +] +# Note: waymo-open-dataset requires Python 3.10 and specific numpy versions. +# Install separately: pip install waymo-open-dataset-tf-2-12-0==1.6.4 + +[project.urls] +Homepage = "https://vail-ucla.github.io/scenestreamer/" +Repository = "https://github.com/pengzhenghao/scenestreamer" + +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +include = ["scenestreamer*"] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "W"] +ignore = ["E501"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c558e357c41674e39880abb6c3209e539de42e2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +. 
diff --git a/scenestreamer/__init__.py b/scenestreamer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenestreamer/cli.py b/scenestreamer/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..43c9edeebe482bd832a02b3125743e008e20570d --- /dev/null +++ b/scenestreamer/cli.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import argparse +import json +import os +import pathlib +import runpy +import sys +import time +from dataclasses import dataclass +from typing import Any + +import yaml + + +def _to_plain(obj: Any) -> Any: + if hasattr(obj, "items"): + return {k: _to_plain(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_to_plain(v) for v in obj] + return obj + + +def _to_easydict(obj: Any): + from easydict import EasyDict + + if isinstance(obj, dict): + return EasyDict({k: _to_easydict(v) for k, v in obj.items()}) + if isinstance(obj, list): + return [_to_easydict(v) for v in obj] + return obj + + +def load_yaml_config(path: str | os.PathLike[str]): + with open(path, "r") as f: + data = yaml.safe_load(f) + return _to_easydict(data) + + +def apply_overrides(cfg, overrides: list[str]) -> None: + """ + Apply overrides of form KEY=VALUE where KEY is dot-delimited. + VALUE is parsed using yaml.safe_load (so numbers/bools/lists work). 
+ """ + for item in overrides: + if "=" not in item: + raise ValueError(f"Invalid override (expected KEY=VALUE): {item}") + key, raw_val = item.split("=", 1) + value = yaml.safe_load(raw_val) + + cur = cfg + parts = key.split(".") + for p in parts[:-1]: + if not hasattr(cur, p): + setattr(cur, p, _to_easydict({})) + cur = getattr(cur, p) + setattr(cur, parts[-1], value) + + +@dataclass(frozen=True) +class RunPaths: + run_dir: pathlib.Path + config_path: pathlib.Path + metrics_path: pathlib.Path + + +def make_run_dir(base_dir: str | os.PathLike[str], run_id: str | None) -> RunPaths: + base = pathlib.Path(base_dir) + base.mkdir(parents=True, exist_ok=True) + if run_id is None: + run_id = time.strftime("%Y%m%d-%H%M%S") + run_dir = base / run_id + run_dir.mkdir(parents=True, exist_ok=False) + return RunPaths( + run_dir=run_dir, + config_path=run_dir / "config.yaml", + metrics_path=run_dir / "metrics.json", + ) + + +def cmd_preprocess(args: argparse.Namespace) -> None: + # Prefer failing fast on missing ScenarioNet without importing heavy deps. + try: + import scenarionet # noqa: F401 + except ModuleNotFoundError as e: + raise e + + from scenestreamer.dataset.dataset import SceneStreamerDataset + + cfg = load_yaml_config(args.config) + apply_overrides(cfg, args.set or []) + + # Paths: prefer CLI args, but allow config overrides. + if args.train_dir: + cfg.DATA.TRAINING_DATA_DIR = args.train_dir + if args.test_dir: + cfg.DATA.TEST_DATA_DIR = args.test_dir + + cfg.DATA.USE_CACHE = True + + run = make_run_dir(args.artifacts_dir, args.run_id) + with open(run.config_path, "w") as f: + yaml.safe_dump(_to_plain(cfg), f, sort_keys=False) + + mode = args.split + ds = SceneStreamerDataset(cfg, mode) + + # Iterate to materialize cache files. 
+ for i in range(len(ds)): + _ = ds[i] + if args.limit is not None and (i + 1) >= args.limit: + break + + metrics = { + "status": "ok", + "mode": mode, + "train_dir": getattr(cfg.DATA, "TRAINING_DATA_DIR", None), + "test_dir": getattr(cfg.DATA, "TEST_DATA_DIR", None), + "limit": args.limit, + } + with open(run.metrics_path, "w") as f: + json.dump(metrics, f, indent=2) + + print(str(run.run_dir)) + + +def _load_model_from_args(args: argparse.Namespace): + import torch + + from scenestreamer.utils import utils + + device = torch.device(args.device) + if args.hf_repo: + return utils.get_model(huggingface_repo=args.hf_repo, huggingface_file=args.hf_file, device=device) + if args.ckpt: + return utils.get_model(checkpoint_path=args.ckpt, device=device) + raise ValueError("Must provide either --hf-repo/--hf-file or --ckpt") + + +def cmd_table1(args: argparse.Namespace) -> None: + from scenestreamer.paper.table1_mmd import run_table1_mmd + + pl_model = _load_model_from_args(args) + run_dir = run_table1_mmd( + pl_model=pl_model, + dataset_dir=args.dataset_dir, + split=args.split, + limit=args.limit, + artifacts_dir=args.artifacts_dir, + run_id=args.run_id, + seed=args.seed, + ) + print(str(run_dir)) + + +def cmd_table2(args: argparse.Namespace) -> None: + from scenestreamer.paper.table2_motion import run_table2_motion + + pl_model = _load_model_from_args(args) + run_dir = run_table2_motion( + pl_model=pl_model, + dataset_dir=args.dataset_dir, + split=args.split, + mode=args.mode, + num_modes=args.num_modes, + limit=args.limit, + artifacts_dir=args.artifacts_dir, + run_id=args.run_id, + seed=args.seed, + ) + print(str(run_dir)) + + +def cmd_densify_demo(args: argparse.Namespace) -> None: + from scenestreamer.paper.densify_demo import run_densify_demo + + pl_model = _load_model_from_args(args) + run_dir = run_densify_demo( + pl_model=pl_model, + dataset_dir=args.dataset_dir, + split=args.split, + scenario_index=args.scenario_index, + max_agents=args.max_agents, + 
force_no_end=args.force_no_end, + artifacts_dir=args.artifacts_dir, + run_id=args.run_id, + seed=args.seed, + ) + print(str(run_dir)) + +def _run_module_as_main(module: str, argv: list[str]) -> None: + old_argv = sys.argv[:] + try: + sys.argv = [module] + argv + runpy.run_module(module, run_name="__main__") + finally: + sys.argv = old_argv + + +def cmd_table3_train(args: argparse.Namespace) -> None: + _run_module_as_main("scenestreamer.rl_train.train.train_td3", args.args) + + +def cmd_table3_eval(args: argparse.Namespace) -> None: + _run_module_as_main("scenestreamer.rl_train.train.eval_policy", args.args) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(prog="scenestreamer", description="SceneStreamer paper reproduction CLI") + sub = parser.add_subparsers(dest="cmd", required=True) + + def add_run_args(p: argparse.ArgumentParser) -> None: + p.add_argument("--artifacts-dir", default="artifacts", help="Directory to write run artifacts") + p.add_argument("--run-id", default=None, help="Run ID (default: timestamp)") + p.add_argument("--seed", type=int, default=0) + + def add_model_args(p: argparse.ArgumentParser) -> None: + p.add_argument("--device", default="cuda", help="torch device string, e.g. cuda or cpu") + p.add_argument("--ckpt", default=None, help="Path to a .ckpt checkpoint") + p.add_argument("--hf-repo", default=None, help="HuggingFace repo id, e.g. user/repo") + p.add_argument("--hf-file", default=None, help="HuggingFace filename, e.g. 
model.ckpt") + + # preprocess + p = sub.add_parser("preprocess", help="Preprocess ScenarioNet SD dataset and build cache") + add_run_args(p) + p.add_argument("--config", default="cfgs/motion_default.yaml") + p.add_argument("--set", action="append", default=[], help="Override config KEY=VALUE (repeatable)") + p.add_argument("--train-dir", default=None) + p.add_argument("--test-dir", default=None) + p.add_argument("--split", choices=["training", "test"], default="training") + p.add_argument("--limit", type=int, default=None) + p.set_defaults(func=cmd_preprocess) + + # table1 + p = sub.add_parser("table1", help="Table 1: initial state MMD (strict + relaxed)") + add_run_args(p) + add_model_args(p) + p.add_argument("--dataset-dir", required=True) + p.add_argument("--split", choices=["training", "test"], default="test") + p.add_argument("--limit", type=int, default=None) + p.set_defaults(func=cmd_table1) + + # table2 + p = sub.add_parser("table2", help="Table 2: motion prediction (ADE/FDE + ADD/FDD)") + add_run_args(p) + add_model_args(p) + p.add_argument("--dataset-dir", required=True) + p.add_argument("--split", choices=["training", "test"], default="test") + p.add_argument("--mode", choices=["motion", "full"], default="motion") + p.add_argument("--num-modes", type=int, default=6) + p.add_argument("--limit", type=int, default=None) + p.set_defaults(func=cmd_table2) + + # demo + p = sub.add_parser("densify-demo", help="Qualitative densification demo (generate to max agents)") + add_run_args(p) + add_model_args(p) + p.add_argument("--dataset-dir", required=True) + p.add_argument("--split", choices=["training", "test"], default="test") + p.add_argument("--scenario-index", type=int, default=0) + p.add_argument("--max-agents", type=int, default=128) + p.add_argument("--force-no-end", action="store_true", help="Disable end token so it keeps generating agents") + p.set_defaults(func=cmd_densify_demo) + + p = sub.add_parser("table3-train", help="Table 3: RL training 
(pass-through to train_td3.py)") + p.add_argument("args", nargs=argparse.REMAINDER, help="Arguments forwarded to train_td3.py") + p.set_defaults(func=cmd_table3_train) + + p = sub.add_parser("table3-eval", help="Table 3: RL evaluation (pass-through to eval_policy.py)") + p.add_argument("args", nargs=argparse.REMAINDER, help="Arguments forwarded to eval_policy.py") + p.set_defaults(func=cmd_table3_eval) + + return parser + + +def main(argv: list[str] | None = None) -> None: + parser = build_parser() + args = parser.parse_args(argv) + try: + args.func(args) + except ModuleNotFoundError as e: + # Most common in a fresh environment: scenarionet / waymo-open-dataset missing. + msg = str(e) + if "scenarionet" in msg: + raise SystemExit( + "Missing dependency 'scenarionet'. Install it via:\n" + " pip install git+https://github.com/metadriverse/scenarionet.git\n" + ) from e + raise + + +if __name__ == "__main__": + main() diff --git a/scenestreamer/clustering.sh b/scenestreamer/clustering.sh new file mode 100644 index 0000000000000000000000000000000000000000..789a30d54330eca01ff9defb710415b450aa8885 --- /dev/null +++ b/scenestreamer/clustering.sh @@ -0,0 +1,7 @@ +nohup python clustering.py --data 3 > clustering_obj3_all_nomin.log 2>&1 & +nohup python clustering.py --data 2 > clustering_obj2_all_nomin.log 2>&1 & +nohup python clustering.py --data 1 > clustering_obj1_all_nomin.log 2>&1 & + +#nohup python clustering.py --data 3 --min_scale 0.5 > clustering_obj3_all.log 2>&1 & +#nohup python clustering.py --data 2 --min_scale 0.5 > clustering_obj2_all.log 2>&1 & +#nohup python clustering.py --data 1 --min_scale 0.5 > clustering_obj1_all.log 2>&1 & \ No newline at end of file diff --git a/scenestreamer/dataset/__init__.py b/scenestreamer/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenestreamer/dataset/constants.py b/scenestreamer/dataset/constants.py new file mode 100644 index 
0000000000000000000000000000000000000000..a204c5f0cee74519891e5c003553a2b6797502a3 --- /dev/null +++ b/scenestreamer/dataset/constants.py @@ -0,0 +1,44 @@ +""" +Define a lot of constants. It should be totally removed as most of them should be defined by MetaDrive / ScenarioNet. +""" +from metadrive.scenario.scenario_description import MetaDriveType + +# NUM_TYPES = 3 +NUM_TYPES = 5 + +MAP_FEATURE_STATE_DIM = 27 +TRAFFIC_LIGHT_STATE_DIM = 7 + +AGENT_STATE_DIM = 16 + +# ACTOR_PREDICT_DIM = 6 + 2 + 4 + 5 # 3 for position, 1 for heading, 2 for velocity, 5 for types +TRAFFIC_LIGHT_PREDICT_DIM = 9 # 9 original possible state + +# TODO(pzh): Do we have to do the normalization? Shouldn't the layer norm solve this? +# POSITION_XY_RANGE = 100. +# LOCAL_POSITION_XY_RANGE = 5. +# HEADING_RANGE = np.pi +# VELOCITY_XY_RANGE = 10. +# SIZE_RANGE = 5. +# MAP_VECTOR_XY_RANGE = 50. + +# TODO(pzh): Consider remove this. +object_type_to_int = { + MetaDriveType.UNSET: 0, + MetaDriveType.VEHICLE: 1, + MetaDriveType.PEDESTRIAN: 2, + MetaDriveType.CYCLIST: 3, + MetaDriveType.OTHER: 4 +} + +# TODO(pzh): Consider remove this. +object_int_to_type = { + -1: MetaDriveType.UNSET, + 0: MetaDriveType.UNSET, + 1: MetaDriveType.VEHICLE, + 2: MetaDriveType.PEDESTRIAN, + 3: MetaDriveType.CYCLIST, + 4: MetaDriveType.OTHER +} + +HEADING_PLACEHOLDER = -100 # For the object that has no heading, set this. diff --git a/scenestreamer/dataset/datamodule.py b/scenestreamer/dataset/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5bed71ab56cac67a09d19c6417fc1838d290cf --- /dev/null +++ b/scenestreamer/dataset/datamodule.py @@ -0,0 +1,49 @@ +""" +This is a wrapper to wrap our dataset as a lightning datamodule. 
+""" +import lightning.pytorch as pl +from torch.utils.data import DataLoader + +from scenestreamer.dataset import dataset + + +class SceneStreamerDataModule(pl.LightningDataModule): + def __init__( + self, config, train_batch_size, train_num_workers, train_prefetch_factor, val_batch_size, val_num_workers, + val_prefetch_factor + ): + super().__init__() + self.config = config + self.train_batch_size = train_batch_size + self.train_num_workers = train_num_workers + self.train_prefetch_factor = train_prefetch_factor + self.val_batch_size = val_batch_size + self.val_num_workers = val_num_workers + self.val_prefetch_factor = val_prefetch_factor + + def setup(self, stage: str): + self.train_dataset = dataset.SceneStreamerDataset(config=self.config, mode="training") + self.val_dataset = dataset.SceneStreamerDataset(config=self.config, mode="test") + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + pin_memory=True, + num_workers=self.train_num_workers, + shuffle=True, + persistent_workers=True if self.train_num_workers > 0 else False, + collate_fn=self.train_dataset.collate_batch, + prefetch_factor=self.train_prefetch_factor if self.train_num_workers > 0 else None, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.val_batch_size, + pin_memory=True, + num_workers=self.val_num_workers, + shuffle=False, + collate_fn=self.val_dataset.collate_batch, + prefetch_factor=self.val_prefetch_factor if self.val_num_workers > 0 else None, + ) diff --git a/scenestreamer/dataset/dataset.py b/scenestreamer/dataset/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f82c0c22e15d0b4026c47e010ede99a4dff2eaf6 --- /dev/null +++ b/scenestreamer/dataset/dataset.py @@ -0,0 +1,630 @@ +""" +Create a pytorch dataset class for loading scenario files and padding data entries. 
+""" +import copy +import json +import os +import pathlib +import pickle + +try: + import hydra +except ModuleNotFoundError: # optional for core library usage + hydra = None +import numpy as np +from scenarionet import read_dataset_summary, read_scenario +from torch.utils.data import Dataset + +from scenestreamer.dataset.preprocessor import preprocess_scenario_description +from scenestreamer.utils import global_config +from scenestreamer.utils import utils + +# import lmdb + +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent +QA_DATASET_MAPPING = {} +ADV_INFO_DICT = {} + + +class NoMapFeatureError(Exception): + pass + + +class LMDBDatasetReader: + def __init__(self, base_path): + self.base_path = base_path + # Load the lookup table that maps sample keys to LMDB file names + # Search recursively all subfolder to find lookup.json + self.lookup = {} + for root, dirs, files in os.walk(self.base_path): + if "lookup.json" in files: + lookup_path = os.path.join(root, "lookup.json") + with open(lookup_path, "r") as f: + lookup = json.load(f) + self.lookup.update(lookup) + self.lmdb_cache = {} # Cache for open LMDB environments + + +# def _get_lmdb_env(self, lmdb_name): +# """Fetches or opens an LMDB environment for reading.""" +# if lmdb_name not in self.lmdb_cache: +# self.lmdb_cache[lmdb_name] = lmdb.open(lmdb_name, readonly=True) +# return self.lmdb_cache[lmdb_name] + +# def load_sample(self, key): +# """Loads a preprocessed sample by key.""" +# lmdb_name = self.lookup.get(key) +# if lmdb_name is None: +# raise KeyError(f"Sample {key} not found in lookup.") +# env = self._get_lmdb_env(lmdb_name) +# with env.begin() as txn: +# npz_bytes = txn.get(key.encode('ascii')) +# if npz_bytes: +# with io.BytesIO(npz_bytes) as buffer: +# data = np.load(buffer, allow_pickle=True) +# return {name: data[name] for name in data.files} # Return data as a dictionary +# return None + +# def close(self): +# """Closes all open LMDB environments.""" +# for env in 
self.lmdb_cache.values(): +# env.close() + + +def process_QA_text_label(QA_dict): + # TODO: do we need to form label for each individual agent? Rightnow it is just a single label + labels = {} + + env_a = QA_dict['env_a'] + labels['env'] = ' '.join(env_a) + + ego_a = QA_dict['ego_a'] + labels['ego'] = ' '.join(ego_a) + + int_a = QA_dict['int_a'] + labels['int'] = ' '.join(int_a) + + return labels + + +def get_file_paths(directory): + file_paths = [] + # Traverse the directory + for root, dirs, files in os.walk(directory): + for file in files: + # Get the full path and add it to the list + full_path = os.path.join(root, file) + file_paths.append(full_path) + return + + +def load_json_to_dict(file_path): + """ + Load a JSON file into a Python dictionary. + + :param file_path: Path to the JSON file + :return: Dictionary containing the JSON data + """ + try: + with open(file_path, 'r') as file: + data = json.load(file) + return data + except FileNotFoundError: + print(f"Error: The file at {file_path} was not found.") + except json.JSONDecodeError: + print(f"Error: The file at {file_path} is not a valid JSON file.") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + return None + + +class SceneStreamerDataset(Dataset): + """ + SceneStreamer dataset class. Returns data_dict for each scenario. + Init args: + mode: "training" or "test". + config: + - model: Details about the model architecture. + - data: Data directories, sample intervals, number of agents, etc. + - evaluation: predict_all_agents, delete_eval_result (TODO: Add ScenarioDescription passthrough as a flag in the config.) + - optimization: Training hyperparameters. + - preprocessing: Max number of agents, map features, traffic lights, padding, etc. + - root_dir: Self-explanatory. + - sampling: Inference sampling parameters. + - tokenization: The part of the config passed to the tokenizer. 
+ """ + def __init__(self, config, mode): + super().__init__() + self.mode = mode + self.config = config + dataset_cfg = self.config.DATA + + self.max_map_features = config.PREPROCESSING.MAX_MAP_FEATURES + self.max_vectors_per_map_feature = config.PREPROCESSING.MAX_VECTORS + self.max_agents = config.PREPROCESSING.MAX_AGENTS + self.max_traffic_lights = config.PREPROCESSING.MAX_TRAFFIC_LIGHTS + self.padding_to_max = config.PREPROCESSING.PADDING_TO_MAX + + # We are expecting the data_dir to be either an absolute path or a relative path w.r.t. the repo root. + if mode == "training": + self.data_dir = global_config.ROOT_DIR / dataset_cfg.TRAINING_DATA_DIR + elif mode == "test": + self.data_dir = global_config.ROOT_DIR / dataset_cfg.TEST_DATA_DIR + else: + raise ValueError(f"Unknown mode {mode}.") + + # summary_dict: A dictionary of .pkl filenames to ingest. Filenames (keys) are mapped to metadata objects. + # summary_list: Keys of summary_dict, in order of ingestion. + # mapping: A dict mapping scenario IDs to the folder that hosts their files. + summary_dict, summary_list, mapping = read_dataset_summary(self.data_dir) + + # We might want to use a subset of scenarios. + if self.mode == "training": + interval = dataset_cfg.SAMPLE_INTERVAL_TRAINING + elif self.mode == "test": + interval = dataset_cfg.SAMPLE_INTERVAL_TEST + else: + raise ValueError(f"Unknown mode {self.mode}.") + + if "SD_PASSTHROUGH" in config.DATA: + self.return_scenario_description = config.DATA["SD_PASSTHROUGH"] + else: # Default to False. 
def __getitem__(self, index):
    """Return the preprocessed sample stored at ``index``.

    When ``config.BACKWARD_PREDICTION`` is enabled in training mode the dataset
    reports twice as many samples as there are scenarios: indices at or above
    ``self.real_length`` address the same scenarios as the first half, but are
    preprocessed with ``backward_prediction=True``.

    Scenarios without map features (raised as ``NoMapFeatureError``; seen in the
    Waymo test set) are skipped by trying the next scenario. The search wraps
    around at the end of the scenario list and keeps the backward-prediction
    flag, so the last index no longer reads past the end of the dataset.

    Raises:
        IndexError: ``index`` is at or beyond ``len(self)``. This also lets the
            legacy ``for sample in dataset`` iteration protocol terminate.
        RuntimeError: every scenario in the dataset lacks map features.
    """
    if index >= self.length:
        raise IndexError(f"Index {index} is out of range for a dataset of length {self.length}.")

    use_backward_prediction = False
    num_scenarios = self.length
    if self.config.BACKWARD_PREDICTION and self.mode == "training":
        num_scenarios = self.real_length
        if index >= num_scenarios:
            index -= num_scenarios
            use_backward_prediction = True

    # Try each scenario at most once so a dataset without any usable scenario
    # fails loudly instead of recursing forever.
    for _ in range(num_scenarios):
        # Unpack the stored codepoints at the current index into a filename string.
        seq = utils.unpack_sequence(self.strings_v, self.strings_o, index)
        file_name = utils.sequence_to_string(seq)
        try:
            # If self.return_scenario_description is true, the returned dict has an extra key
            # [raw_scenario_description] that contains the ScenarioDescription object.
            return self.create_scene_level_data(file_name, index, use_backward_prediction)
        except NoMapFeatureError:
            # Workaround for the Waymo test set where some scenarios do not have map features.
            index = (index + 1) % num_scenarios

    raise RuntimeError(f"No scenario with map features found in the dataset: {self.data_dir}")


def create_scene_level_data(self, file_name, index, use_backward_prediction=False):
    """Read one scenario file and preprocess it into a flat dict of arrays.

    Args:
        file_name: Name of the scenario file, as listed in the dataset summary.
        index: Position of the scenario in the dataset; only used in error messages.
        use_backward_prediction: Forwarded to the preprocessor as ``backward_prediction``.

    Returns:
        The preprocessed dict, extended with ``file_name`` and
        ``metadata/scenario_id``; also ``raw_scenario_description`` when
        ``self.return_scenario_description`` is set. With ``config.DATA.USE_CACHE``
        a previously pickled result is returned as stored, so it reflects the
        settings that were active when the cache file was written.

    Raises:
        NoMapFeatureError: the scenario has no map features.
        RuntimeError: preprocessing failed; the original error is chained.
    """
    assert not self.config.DATA.USE_LMDB, "LMDB is not supported."

    cache_path = None
    if self.config.DATA.USE_CACHE:
        cache_folder = pathlib.Path(self.data_dir) / "cache"
        cache_folder.mkdir(exist_ok=True)

        # Forward and backward samples of one scenario differ, so they must not share a
        # cache entry. Forward entries keep the plain file name so existing caches stay valid.
        cache_name = f"{file_name}.backward" if use_backward_prediction else file_name
        cache_path = cache_folder / cache_name

        if cache_path.is_file():
            try:
                with open(cache_path, "rb") as f:
                    cache = pickle.load(f)
            except (EOFError, pickle.UnpicklingError):
                # A truncated or corrupt cache file: fall through, rebuild and overwrite it.
                print(f"Error in reading cache file: {cache_path=}")
            else:
                if self.use_cache_logged is False:
                    print("=====================================")
                    print("=====================================")
                    print("\t*** WARNING ***")
                    print("\tYou are using cache files!!!")
                    print("\tIn folder: ", cache_folder)
                    print("\tThere are ", len(list(cache_folder.glob("*"))), " cache files!!!")
                    print("=====================================")
                    print("=====================================")
                    self.use_cache_logged = True
                return cache

    try:
        # scenario: A ScenarioDescription instance.
        scenario = read_scenario(
            dataset_path=self.data_dir, mapping=self.data_mapping, scenario_file_name=file_name
        )
    except EOFError:
        print(f"{self.data_dir=}, {self.data_mapping=}, {file_name=}")
        raise

    assert self.mode in ["training", "test"], self.mode
    ret = {}

    if len(scenario["map_features"]) == 0:
        raise NoMapFeatureError

    if self.return_scenario_description:
        ret["raw_scenario_description"] = copy.deepcopy(scenario)

    # TODO: Remove error handling after debugging.
    try:
        preprocessed_scenario_description = preprocess_scenario_description(
            scenario=scenario,
            config=copy.deepcopy(self.config),
            in_evaluation=self.mode != "training",
            keep_all_data=self.config.PREPROCESSING.get("keep_all_data", False),
            backward_prediction=use_backward_prediction,
            tokenizer=self.tokenizer,
        )
        preprocessed_scenario_description["file_name"] = file_name
    except Exception as e:
        print(f"Error in preprocessing {file_name=}, {index=}, {scenario['id']=}")
        # Re-raise with the scenario identity attached so the failure is not swallowed.
        raise RuntimeError(
            f"{file_name=}, {index=}, {scenario['id']=}. Error in create_scene_level_data: {e}"
        ) from e

    ret.update(preprocessed_scenario_description)
    ret.update({"metadata/scenario_id": scenario['id']})

    if cache_path is not None:
        with open(cache_path, "wb") as f:
            pickle.dump(ret, f)

    return ret
+ + for k in set(data_dict_sample.keys()): + if k not in object_keys: + if not isinstance(data_dict_sample[k], np.ndarray): + assert isinstance(data_dict_sample[k], (int, float, bool, str)), (k, type(data_dict_sample[k])) + if isinstance(data_dict_sample[k], str): + data_dict[k] = np.array([b[k] for b in batch_list]) + else: + data_dict[k] = utils.numpy_to_torch(np.array([b[k] for b in batch_list])) + continue + # else: + # if batch_list[0][k].dtype == np.object: + # data_dict[k] = [b[k] for b in batch_list] + # continue + + val_list = [utils.numpy_to_torch(b[k]) for b in batch_list] + + # Map features that have vectors' information + if k in [ + "encoder/map_feature", + "vis/map_feature", + "raw/map_feature", + "encoder/map_feature_valid_mask", + ]: + data_dict[k] = utils.padding_1st_and_2nd_dim( + val_list, + max_1st_dim=self.max_map_features if self.padding_to_max else None, + max_2nd_dim=self.max_vectors_per_map_feature if self.padding_to_max else None + ) + + # Map features that have aggregated info from vectors + elif k in [ + "encoder/map_heading", + "encoder/map_position", + "encoder/map_valid_mask", + ]: + data_dict[k] = utils.padding_1st_dim( + val_list, max_1st_dim=self.max_map_features if self.padding_to_max else None + ) + + # Traffic light features that have temporal dim + elif k in [ + "encoder/traffic_light_feature", + "encoder/traffic_light_state", + "encoder/traffic_light_valid_mask", + ]: + + if self.config.PREPROCESSING.REMOVE_TRAFFIC_LIGHT_STATE: + data_dict[k] = utils.padding_1st_dim( + val_list, max_1st_dim=self.max_traffic_lights if self.padding_to_max else None + ) + else: + data_dict[k] = utils.padding_1st_and_2nd_dim( + val_list, max_2nd_dim=self.max_traffic_lights if self.padding_to_max else None + ) + + # Traffic light features that do not have temporal dim + elif k in [ + "encoder/traffic_light_position", + "encoder/traffic_light_heading", + "encoder/traffic_light_map_id", + ]: + data_dict[k] = utils.padding_1st_dim( + val_list, 
max_1st_dim=self.max_traffic_lights if self.padding_to_max else None + ) + + # Agent features + elif k in [ + "encoder/agent_feature", + "encoder/agent_position", + "encoder/agent_valid_mask", + "encoder/agent_heading", + "encoder/agent_velocity", + "decoder/modeled_agent_position", + "decoder/modeled_agent_heading", + "decoder/modeled_agent_velocity", + "decoder/modeled_agent_delta", + ]: + data_dict[k] = utils.padding_1st_and_2nd_dim( + val_list, max_2nd_dim=self.max_agents if self.padding_to_max else None + ) + + # Other data that does not pass the model or does not need regular shapes + elif k in [ + # "encoder/modeled_agent_id", + # "action_label/labeled_agent_id", + "metadata/map_center", # "decoder/input_step", + # "decoder/input_intra_step", + "encoder/current_agent_heading", + "decoder/current_agent_heading", + "encoder/current_agent_shape", + "decoder/current_agent_shape", + "eval/current_agent_heading", + "encoder/current_agent_valid_mask", + "decoder/current_agent_valid_mask", + "eval/current_agent_valid_mask", + # "decoder/current_agent_valid_mask", # + # "decoder/modeled_agent_indices", + # For gen model: + # "decoder/input_token_valid_mask", + # "decoder/should_predict", + # "decoder/is_gt", + # "eval/should_predict_motion", + ]: + data_dict[k] = utils.padding_1st_dim(val_list) + + elif k in [ + "decoder/input_action_valid_mask", + "encoder/current_agent_position", + "decoder/current_agent_position", + "encoder/current_agent_velocity", + "decoder/current_agent_velocity", + "decoder/target_action_valid_mask", + #"decoder/future_agent_position", + #"decoder/future_agent_heading", + #"decoder/future_agent_valid_mask", + #"decoder/future_agent_velocity", + #"encoder/future_agent_position", + #"encoder/future_agent_heading", + #"encoder/future_agent_valid_mask", + #"encoder/future_agent_velocity", + "decoder/agent_position", + "decoder/agent_heading", + "decoder/agent_velocity", + "decoder/agent_valid_mask", + "eval/agent_velocity", + 
"eval/agent_heading", + "eval/agent_position", + "eval/agent_valid_mask", + "encoder/agent_shape", + "decoder/agent_shape", + "eval/agent_shape", # "decoder/target_valid_mask", + "decoder/input_agent_motion", + "decoder/target_agent_motion", + "decoder/dest_map_index_valid_mask", + ]: + data_dict[k] = utils.padding_1st_and_2nd_dim(val_list) + + elif k in [ + "encoder/agent_type", + "decoder/agent_type", + "encoder/modeled_agent_type", + "eval/agent_type", # "eval/raw_agent_name", + "encoder/object_of_interest_name", + "decoder/object_of_interest_name", + "metadata/sdc_name", # "eval/modeled_agent_id", + "encoder/object_of_interest_id", + "decoder/object_of_interest_id", + "encoder/modeled_agent_id", # "decoder/modeled_agent_id", + "encoder/agent_id", + "decoder/agent_id", + "decoder/labeled_agent_id", + "decoder/label_turning", + "decoder/label_acceleration", + "decoder/label_safety", + # For gen model: + # "decoder/input_token_id", + # "decoder/causal_mask_offset", + ]: + data_dict[k] = utils.padding_1st_dim(val_list, fill=-1) + + elif k in [ + "decoder/dest_map_index", + "decoder/dest_map_index_gt", + ]: + data_dict[k] = utils.padding_1st_and_2nd_dim(val_list, fill=-1) + + elif k in [ + "decoder/input_action", + "decoder/target_action", + "decoder/input_action_for_trafficgen", + + "decoder/current_agent_shape_for_trafficgen", + "decoder/modeled_agent_heading_for_trafficgen", + "decoder/modeled_agent_position_for_trafficgen", + "decoder/modeled_agent_velocity_for_trafficgen", + "decoder/input_action_valid_mask_for_trafficgen", + "decoder/modeled_agent_delta_for_trafficgen", + "decoder/input_action_feature_for_trafficgen", + "decoder/target_offset_for_trafficgen", + "decoder/input_offset_for_trafficgen", + "decoder/agent_id_for_trafficgen", + "decoder/trafficgen_position", + "decoder/trafficgen_heading", + "decoder/agent_type_for_trafficgen", + ]: + data_dict[k] = utils.padding_all_dims(val_list, fill=-1) + + elif k in object_keys: + # Passthrough: Have the 
@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="1009_safety_action_debug.yaml")
def debug(config):
    """Inspect the safety labels of the first three training samples.

    A sample counts as "buggy" when the safety labels of its labeled agents sum
    to at most 1. Details of each such sample are printed, followed by a summary
    of how many samples were checked and how many were flagged.
    """
    test_dataset = SceneStreamerDataset(config, "training")
    samples = iter(test_dataset)
    max_samples = 3
    count = 0
    buggy_count = 0

    # Leave the loop with ``break`` (not ``return``) so the summary below is always printed.
    while count < max_samples:
        try:
            data = next(samples)
        except StopIteration:
            break
        count += 1

        # An explicit check instead of ``assert``: asserts are stripped under ``python -O``,
        # and catching AssertionError would also hide assertion failures raised inside the
        # dataset itself.
        labeled_agent_id = data["decoder/labeled_agent_id"]
        if data["decoder/label_safety"][labeled_agent_id].sum() > 1:
            continue

        print("ni collision")
        buggy_count += 1
        # The dataset stores the id under "metadata/scenario_id"; the bare key is kept as a fallback.
        print("scenario_id", data.get("metadata/scenario_id", data.get("scenario_id")))
        print("data['decoder/label_safety']", data["decoder/label_safety"])
        print("data['decoder/labeled_agent_id']", labeled_agent_id)
        print("track_name", data["decoder/track_name"][labeled_agent_id])

    print("buggy_count:", buggy_count)
    print("count", count)
    print("End")
class LMDBBulkWriter:
    """Write pickled samples into a series of size-capped LMDB files.

    Samples are buffered and committed in batches, one write transaction per
    batch. When the current LMDB file is full a new one (``data_<n>.lmdb``) is
    opened and the batch is retried there. ``lookup`` records which file holds
    each key and is dumped to ``lookup.json`` by ``close``.
    """

    def __init__(self, base_path, max_size=1e9, batch_size=100):
        """
        Args:
            base_path: Directory path to save LMDB files; created if missing.
            max_size: Maximum size (``map_size``) of each LMDB file in bytes.
            batch_size: Number of buffered samples that triggers a commit.
        """
        self.base_path = base_path
        # Create the cache directory if it doesn't exist
        os.makedirs(self.base_path, exist_ok=True)

        self.max_size = int(max_size)  # Max LMDB file size (e.g., 1 GB)
        self.batch_size = int(batch_size)
        self.current_db_index = 0
        self.lookup = {}  # Maps each sample key to the LMDB file that stores it
        self.current_db = self._open_new_lmdb(self.current_db_index)
        self.per_shard_size = 0  # Samples accepted for the current LMDB file

        self.sample_buffer = []

    def _open_new_lmdb(self, db_index):
        """Opens a new LMDB file for saving samples."""
        db_path = f"{self.base_path}/data_{db_index}.lmdb"
        return lmdb.open(db_path, map_size=self.max_size)

    def _commit_buffer(self):
        """Write the whole buffer in one transaction, then record it in ``lookup``."""
        print(f"Saving {len(self.sample_buffer)} samples to data_{self.current_db_index}.lmdb")
        with self.current_db.begin(write=True) as txn:
            for key, data in self.sample_buffer:
                txn.put(key.encode('ascii'), pickle.dumps(data))

        # Only reached when the transaction committed, so ``lookup`` never points at a
        # file that does not actually contain the sample.
        shard_name = f"data_{self.current_db_index}.lmdb"
        for key, _ in self.sample_buffer:
            self.lookup[key] = shard_name
        self.sample_buffer.clear()

    def _save_a_batch(self):
        """Commit the buffered samples, rolling over to a new LMDB file if the current one is full.

        Raises:
            RuntimeError: the batch does not fit even into an empty LMDB file,
                i.e. ``max_size`` is too small for ``batch_size`` samples.
        """
        if not self.sample_buffer:
            return

        try:
            self._commit_buffer()
        except lmdb.MapFullError:
            # The current LMDB file is full: open a new one and retry exactly once.
            self.current_db.close()
            self.current_db_index += 1
            print(f"Creating new LMDB file: data_{self.current_db_index}.lmdb (size: {self.per_shard_size})")
            self.current_db = self._open_new_lmdb(self.current_db_index)
            self.per_shard_size = len(self.sample_buffer)
            try:
                self._commit_buffer()
            except lmdb.MapFullError as err:
                raise RuntimeError(
                    f"A batch of {len(self.sample_buffer)} samples does not fit into an empty LMDB file "
                    f"of {self.max_size} bytes; increase max_size or reduce batch_size."
                ) from err

    def save_sample(self, key, data):
        """Buffers a sample and commits the buffer once it holds ``batch_size`` samples."""
        self.sample_buffer.append((key, data))
        self.per_shard_size += 1
        if len(self.sample_buffer) >= self.batch_size:
            self._save_a_batch()

    def close(self):
        """Flushes pending samples, closes the LMDB environment and saves the lookup table as JSON."""
        try:
            self._save_a_batch()
        finally:
            self.current_db.close()
        # Save the lookup table to track the LMDB file where each sample is stored
        with open(f"{self.base_path}/lookup.json", "w") as f:
            json.dump(self.lookup, f)
@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="1024_gpt.yaml")
def make_lmdb(config):
    """Preprocess the whole training set in parallel and store it in LMDB files.

    Worker processes preprocess disjoint chunks of dataset indices and push the
    samples into a queue; one writer process drains the queue into
    ``<TRAINING_DATA_DIR>/cache`` via ``LMDBBulkWriter``.

    Raises:
        FileExistsError: the cache folder already exists (it is never overwritten).
        Exception: any error raised inside a worker is re-raised here.
    """
    omegaconf.OmegaConf.set_struct(config, False)
    omegaconf.OmegaConf.set_struct(config, True)

    dataset = SceneStreamerDataset(config, "training")
    folder = pathlib.Path(dataset.data_dir) / "cache"
    folder.mkdir(parents=True, exist_ok=False)

    print("Saving data to LMDB folder:", folder.absolute())

    # num_workers = mp.cpu_count()
    num_workers = 2

    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    chunk_size = dataset_size // num_workers

    # Split the indices into chunks, one for each worker
    chunked_indices = [indices[i * chunk_size:(i + 1) * chunk_size] for i in range(num_workers)]
    # The first chunk takes the leftover samples when the dataset size is not divisible by
    # the number of workers.
    chunked_indices[0].extend(indices[num_workers * chunk_size:])

    # A plain mp.Queue cannot be pickled into Pool tasks ("Queue objects should only be shared
    # between processes through inheritance"); a manager queue proxy can.
    with mp.Manager() as manager:
        queue = manager.Queue()

        # Create and start the writer process
        writer_process = mp.Process(target=write_process, args=(queue, folder, 1e10))
        writer_process.start()

        try:
            with mp.Pool(num_workers) as pool:
                results = []
                # Start each worker process, passing its chunk of indices
                for worker_id, worker_indices in enumerate(chunked_indices):
                    print(f"Starting worker {worker_id} with {len(worker_indices)} samples.")
                    results.append(
                        pool.apply_async(
                            preprocess_and_queue_worker, args=(worker_id, config, worker_indices, queue)
                        )
                    )

                pool.close()
                print("Waiting for workers to finish...")
                pool.join()
                print("All workers finished.")

                # Retrieve every result so an exception raised inside a worker surfaces here
                # instead of being silently dropped.
                for result in results:
                    result.get()
        finally:
            # Always tell the writer to stop (100 is its stop sentinel), even when a worker
            # failed, so it flushes what it has and the process does not hang.
            queue.put(100)
            writer_process.join()
def cal_polygon_contour(x, y, theta, width, length):
    """Return the corner points of rectangles centred at ``(x, y)`` and rotated by ``theta``.

    Corners are ordered left-front, right-front, right-back, left-back, where
    "front" lies along the heading ``theta`` (radians) and ``length`` is the
    extent along that heading.

    Returns:
        Array of shape ``(N, 4, 2)`` holding the ``(x, y)`` of each corner.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    half_length = 0.5 * length
    half_width = 0.5 * width

    # Offsets of the half-length along the heading and of the half-width across it.
    long_x = half_length * cos_t
    long_y = half_length * sin_t
    lat_x = half_width * sin_t
    lat_y = half_width * cos_t

    corners = (
        (x + long_x - lat_x, y + long_y + lat_y),  # left front
        (x + long_x + lat_x, y + long_y - lat_y),  # right front
        (x - long_x + lat_x, y - long_y - lat_y),  # right back
        (x - long_x - lat_x, y - long_y + lat_y),  # left back
    )
    return np.stack([np.column_stack(corner) for corner in corners], axis=1)
def get_direction_action_from_trajectory_batch(traj, mask, dt=0.1, ooi=None):
    """Label each agent's trajectory with a ``TurnAction``.

    The heading of every step is taken from consecutive positions. Heading
    changes are summed over time, ignoring steps that are invalid or move
    0.1 or less, and the total (in degrees) is thresholded into straight /
    left / right / U-turn. Agents whose average speed is below the stop
    threshold are labeled ``STOP`` regardless of heading.

    Args:
        traj: 3-D position array, time first; the last axis holds x, y.
            Presumably shaped (T, A, 2) -- confirm against the caller.
        mask: Boolean validity mask matching the first two axes of ``traj``.
        dt: Time between consecutive steps, used to turn displacement into speed.
        ooi: Unused.

    Returns:
        Integer array with one ``TurnAction`` value per agent.
    """
    u_turn_deg = 115
    left_turn_deg = 25
    right_turn_deg = -25
    stop_speed = 0.06

    assert traj.ndim == 3
    step = np.diff(traj, axis=0)
    step_valid = mask[1:] & mask[:-1]

    step_length = np.linalg.norm(step, axis=-1)

    # It is meaningless to compute a heading for a stopped vehicle, so those steps are masked out.
    moving = step_valid & (step_length > 0.1)
    turn_valid = moving[1:] & moving[:-1]

    heading = np.arctan2(step[..., 1], step[..., 0])
    heading_change = utils.wrap_to_pi(np.diff(heading, axis=0))

    # The accumulated change is deliberately not wrapped to pi: its sign and magnitude matter.
    total_turn_deg = np.degrees((heading_change * turn_valid).sum(axis=0))

    avg_speed = utils.masked_average_numpy(step_length / dt, step_valid, dim=0)

    # Later assignments override earlier ones: U-turn beats left/right, stop beats everything.
    actions = np.full(total_turn_deg.shape, TurnAction.KEEP_STRAIGHT, dtype=int)
    actions[total_turn_deg > left_turn_deg] = TurnAction.TURN_LEFT
    actions[total_turn_deg < right_turn_deg] = TurnAction.TURN_RIGHT
    actions[np.abs(total_turn_deg) > u_turn_deg] = TurnAction.U_TURN
    actions[avg_speed < stop_speed] = TurnAction.STOP
    return actions
def get_safety_action_from_sdc_adv(data_dict, adv_id, sdc_id):
    """Detect, step by step, whether the adversary and the SDC bounding boxes overlap.

    Bounding boxes are built from the first 91 steps of the decoder agent
    position, shape (index 0 used as length, index 1 as width) and heading.

    Returns:
        The per-step result of ``detect_collision`` for the two agents.
    """
    def _contour(agent_id):
        # Footprint corners of one agent over the first 91 steps.
        position = data_dict["decoder/agent_position"][:91, agent_id, :]
        shape = data_dict["decoder/agent_shape"][:91, agent_id]
        heading = data_dict['decoder/agent_heading'][:91, agent_id]
        return cal_polygon_contour(position[:, 0], position[:, 1], heading, shape[:, 1], shape[:, 0])

    adv_contour = _contour(adv_id)
    sdc_contour = _contour(sdc_id)

    valid_mask = data_dict['decoder/agent_valid_mask']
    sdc_mask = valid_mask[:, sdc_id]
    adv_mask = valid_mask[:, adv_id]

    # Instead of loading a dict that lists all collision scenarios, collisions are detected directly.
    return detect_collision(adv_contour, adv_mask, sdc_contour, sdc_mask)
def prepare_action_label(*, data_dict, dt, mask_probability, config):
    """Add turning and acceleration labels for the labeled agents to ``data_dict``.

    Labels are computed from the ground-truth trajectories of the agents listed
    in ``decoder/labeled_agent_id`` and scattered into arrays with one entry per
    decoder agent; every other agent gets -1. Each labeled agent's labels are
    additionally replaced by -1 with probability ``mask_probability``.

    Args:
        data_dict: Preprocessed sample; updated in place and returned.
        dt: Time between trajectory steps.
        mask_probability: The probability of masking the label. Should be around
            0.05 or 0.1. Can't be too high.
        config: Must have ``TRAINING.PREDICT_ALL_AGENTS`` enabled.

    Returns:
        ``data_dict`` with ``decoder/label_turning`` and ``decoder/label_acceleration`` set.
    """
    valid_mask = data_dict["decoder/agent_valid_mask"]
    ooi_ind = data_dict["decoder/labeled_agent_id"]

    ooi_pos = utils.extract_data_by_agent_indices(data_dict["decoder/agent_position"], ooi_ind, agent_dim=1)
    ooi_pos = ooi_pos[..., :2]
    ooi_valid = utils.extract_data_by_agent_indices(valid_mask, ooi_ind, agent_dim=1)  # (T, A)

    # TODO: hardcoded here for now and we assume you can access GT trajectory.
    # This won't work with test dataset.
    assert ooi_pos.shape[0] == 91
    assert ooi_valid.shape[0] == 91

    labels = {
        "decoder/label_turning": get_direction_action_from_trajectory_batch(
            traj=ooi_pos, mask=ooi_valid, dt=dt, ooi=ooi_ind
        ),
        "decoder/label_acceleration": get_acce_action_from_trajectory_batch(
            ooi_pos, ooi_valid, dt=dt, ooi=ooi_ind
        ),
    }

    # Rescatter labels to decoder-agent indices
    assert config.TRAINING.PREDICT_ALL_AGENTS
    num_agents = valid_mask.shape[1]

    # A single random draw decides, per labeled agent, whether both of its labels are hidden.
    dropped = np.random.binomial(1, mask_probability, size=len(ooi_ind)) == 1

    for key, agent_labels in labels.items():
        agent_labels[dropped] = -1
        scattered = np.full((num_agents, ), -1, dtype=int)
        scattered[ooi_ind] = agent_labels
        data_dict[key] = scattered

    return data_dict
if __name__ == '__main__':
    # Ad-hoc manual check: load one reconstructed scenario and its companion dict from
    # hard-coded local paths and print the labels computed for them.
    scenario_dir = "/Users/claire_liu/validation_interactive_0/cat_reconstructed/sd_reconstructed_v0_ScenarioMap-21.pkl"
    cat_dir = "/Users/claire_liu/validation_interactive_0/save.pkl"

    import pickle

    # NOTE: unpickling can execute arbitrary code; only point these paths at trusted files.
    # The ``with`` blocks close the files, so no explicit close() is needed.
    with open(scenario_dir, 'rb') as f:
        scenario_data = pickle.load(f)

    with open(cat_dir, 'rb') as ff:
        cat_dict = pickle.load(ff)

    # NOTE(review): get_3d_action_label is neither defined nor imported in this module, so
    # this call raises NameError as written. Restore or import the helper before running.
    batch_labels = get_3d_action_label(scenario_data, cat_dict)
    print(batch_labels)
def extract_map_center_heading_locations(map_feature):
    """
    Compute the axis-aligned bounding box and the center of all map features.

    The raw geometry of every feature (``polyline``, ``position`` or ``polygon``, looked up in that
    order) is flattened to a ``[num_points, D]`` array and stored back into the feature dict under the
    ``"location"`` key, so ``map_feature`` is modified in place.

    Args:
        map_feature: Dict mapping a map feature id to a dict holding one of the geometry keys above.

    Returns:
        A dict with the keys
            - ``"map_center"``: midpoint of the bounding box, shape ``(3,)``.
            - ``"map_heading"``: always ``0.0``.
            - ``"map_boundary_max"`` / ``"map_boundary_min"``: bounding box corners, shape ``(3,)``.
            - ``"map_feature"``: the input dict, now carrying the ``"location"`` entries.

    Raises:
        ValueError: If a feature has none of the known geometry keys.
    """
    assert isinstance(map_feature, dict)

    boundary_max = np.full(3, -np.inf)
    boundary_min = np.full(3, np.inf)

    for map_feat_id, map_feat in map_feature.items():
        if "polyline" in map_feat:
            locations = map_feat['polyline']
        elif "position" in map_feat:
            locations = map_feat['position']
        elif "polygon" in map_feat:
            locations = map_feat["polygon"]
        else:
            raise ValueError("Unknown map feature: {}, {}".format(map_feat_id, map_feat.keys()))

        locations = locations.reshape(-1, locations.shape[-1])
        map_feat["location"] = locations

        # Only features that really carry a third column contribute to the z extent.
        num_axes = 3 if locations.shape[-1] == 3 else 2

        # Accumulate the extent over ALL features. Assigning the per-feature extent directly would make
        # the box (and hence the center) depend only on whichever feature happens to be visited last.
        boundary_max[:num_axes] = np.maximum(boundary_max[:num_axes], locations.max(axis=0)[:num_axes])
        boundary_min[:num_axes] = np.minimum(boundary_min[:num_axes], locations.min(axis=0)[:num_axes])

    # Axes that were never observed (2D-only maps for z, or an empty map for every axis) fall back to 0
    # so that the center stays finite instead of becoming NaN.
    boundary_max[np.isneginf(boundary_max)] = 0.0
    boundary_min[np.isposinf(boundary_min)] = 0.0

    map_center = (boundary_max + boundary_min) / 2
    map_heading = 0.0

    return {
        "map_center": map_center,
        "map_heading": map_heading,
        "map_boundary_max": boundary_max,
        "map_boundary_min": boundary_min,
        "map_feature": map_feature
    }
np.deg2rad(90) + valid_map_feat = valid_map_feat & valid_heading + + dist = np.linalg.norm(agent_pos[:, None] - map_pos[None], axis=-1) + + dist[~valid_map_feat] = np.inf + closest_map_feat = np.argmin(dist, axis=1) + closest_map_dist = np.min(dist, axis=1) + closest_map_feat[closest_map_dist > 10] = CANT_FIND_DESTINATION + closest_map_feat[np.isinf(dist).all(axis=1)] = CANT_FIND_DESTINATION + return closest_map_feat + + agent_positions = data_dict["decoder/modeled_agent_position"] + agent_valid_mask = data_dict["decoder/input_action_valid_mask"] + agent_headings = data_dict["decoder/modeled_agent_heading"] + + # Use the center of map feature to be the destination. + map_positions = data_dict["encoder/map_position"][..., :2] + + map_heading = data_dict["encoder/map_heading"] + valid_map_feat = data_dict["encoder/map_valid_mask"] + + only_lane = True + if only_lane: + map_feature = data_dict["encoder/map_feature"] + is_lane = map_feature[:, 0, 13] == 1 + # is_lane = is_lane[None].repeat(N, 0) + valid_map_feat = is_lane & valid_map_feat + + num_agents = agent_positions.shape[1] + num_steps = agent_positions.shape[0] + + closest_map_features = np.full((num_steps, num_agents), CANT_FIND_DESTINATION, dtype=int) + dest_valid_mask = np.full((num_steps, num_agents), False, dtype=bool) + # TODO: Here we don't start at t=10. This might be an issue. + # However, as we set all to CANT_FIND_DESTINATION, so it will be free generation for those agent that are + # invalid at t=30 but valid at t=10. 
+ for t in range(0, num_steps): + future_t = t + FUTURE_STEPS // skip_step + if future_t >= num_steps: + break + + future_positions = agent_positions[future_t, :, :2] + future_valid_mask = agent_valid_mask[future_t] + future_headings = agent_headings[future_t] + + closest_map_features_now = find_closest_map_feature( + agent_pos=future_positions, + agent_heading=future_headings, + map_pos=map_positions, + map_heading=map_heading, + valid_map_feat=valid_map_feat + ) + + # If the agent is static, don't set the destination. + current_positions = agent_positions[t, :, :2] + displacement = np.linalg.norm(future_positions - current_positions, axis=-1) + closest_map_features_now[displacement < 5] = CANT_FIND_DESTINATION + + is_this_dest_valid = (agent_valid_mask[t] & future_valid_mask) + + dest_valid_mask[t:t + 1] = is_this_dest_valid[None] + + # If future pos is invalid, don't set the destination. + closest_map_features_now[~is_this_dest_valid] = -1 + + closest_map_features[t:t+1] = closest_map_features_now[None] + + closest_map_features[~agent_valid_mask] = -1 + + closest_map_features[closest_map_features == CANT_FIND_DESTINATION] = -1 + + dest = closest_map_features + dest_valid_mask = dest_valid_mask # This control what dest to be learned. 
+ + data_dict["decoder/dest_map_index_gt"] = np.copy(dest) + data_dict["decoder/dest_map_index_valid_mask"] = dest_valid_mask + if dropout > 0: + # Randomly drop some destination + dropout_mask = np.random.rand(*dest.shape) < dropout + dest[dropout_mask] = -1 + data_dict["decoder/dest_map_index"] = dest + return data_dict + + +def process_map_and_traffic_light( + *, data_dict, scenario, map_feature, dynamic_map_states, track_length, max_vectors, max_map_features, + max_length_per_map_feature, max_traffic_lights, remove_traffic_light_state, limit_map_range, is_scenestreamer=False +): + # ========== Find the boundary of the map first ========== + map_center_info = extract_map_center_heading_locations(map_feature) + map_center = map_center_info["map_center"] + map_heading = map_center_info["map_heading"] + map_feature_augmented = map_center_info["map_feature"] + + # ========== Process Map Features ========== + # The output is a dict whose keys are the lane ID and key is a state array in shape [T, ???] + + # Get a compact representation of all points in the maps + map_feature_list = [] # Key: map_feat_id, Value: A dict of processed values + map_heading_list = [] + map_valid_mask_list = [] + map_position_list = [] + + for map_index, (map_feat_id, map_feat) in enumerate(map_feature_augmented.items()): + rotated_polyline = centralize_to_map_center( + position_array=map_feat["location"], # [num points, 2 or 3] + map_center=map_center, # [1, 1, 3] + map_heading=map_heading + ) + + if "polygon" in map_feat: + # For crosswalk, and other "polygon" based map features, we need to pad the last point to the first point. 
+ rotated_polyline = np.concatenate([rotated_polyline, rotated_polyline[:1]], axis=0) + + if rotated_polyline.shape[-1] == 2: + rotated_polyline = np.concatenate([rotated_polyline, np.zeros((rotated_polyline.shape[0], 1))], axis=-1) + + start_points = rotated_polyline[:-1].copy() # in shape [# map feats - 1, 2] + end_points = rotated_polyline[1:].copy() # in shape [# map feats - 1, 2] + if start_points.shape[0] == 0: + # A special case here is that the map feature contains only one points. + # In this case, we suppose the vector has the same point as start point and end point (its len=0) + start_points = rotated_polyline + end_points = rotated_polyline + num_vectors = 1 + + else: + num_vectors = start_points.shape[0] + + assert start_points.ndim == 2 # [num vectors, 3] + assert start_points.shape[-1] == 3 # [num vectors, 3] # for CAT, start_points.shape[-1] = 2 + + direction = end_points - start_points + heading = np.arctan2(direction[..., 1], direction[..., 0]) + + point_diff = np.linalg.norm(direction[..., :2], axis=-1) + + road_length = 0.0 + start_index = 0 + # Iterate over all "vectors" in a map feature. + # We will produce a map features, containing a set of vectors, in these conditions: + # (1) If the segment is a lane and has length > MAX_LENGTH_PER_MAP_FEATURE, or + # (2) The segment has max_vectors vectors, or + # (3) The segment contains the leftover vectors with less than max_vectors vectors. + for i in range(num_vectors): + road_length += point_diff[i] + + # 2025-04-21 Update: Only break the line if it's a lane. 
+ # map_feat_too_long = ( + # (road_length >= max_length_per_map_feature) & MetaDriveType.is_lane(map_feat['type']) + # ) + # # Exempt the crosswalk from the length limit + map_feat_too_long = ( + (road_length >= max_length_per_map_feature) & ~MetaDriveType.is_crosswalk(map_feat['type']) + ) + + num_valid_vectors = i - start_index + 1 + + too_many_vectors = num_valid_vectors >= max_vectors + last_set_of_vectors = (i == num_vectors - 1) and ((i - start_index) > 0) + if i - start_index == 0: + continue + if map_feat_too_long or too_many_vectors or last_set_of_vectors: + # The map feature is a 2D array with shape [#vectors, 27]. + # map_feature = np.zeros([i - start_index, constants.MAP_FEATURE_STATE_DIM], dtype=np.float32) + map_feature = np.zeros([max_vectors, constants.MAP_FEATURE_STATE_DIM], dtype=np.float32) + + end_index = i + 1 + map_feature[:num_valid_vectors, :3] = start_points[start_index:end_index] + map_feature[:num_valid_vectors, 3:6] = end_points[start_index:end_index] + map_feature[:num_valid_vectors, 6:9] = direction[start_index:end_index] + map_feature[:num_valid_vectors, 9] = utils.wrap_to_pi(heading[start_index:end_index]) + map_feature[:num_valid_vectors, 10] = np.sin(heading[start_index:end_index]) + map_feature[:num_valid_vectors, 11] = np.cos(heading[start_index:end_index]) + map_feature[:num_valid_vectors, 12] = point_diff[start_index:end_index] + + map_feature[:num_valid_vectors, 13] = MetaDriveType.is_lane(map_feat['type']) + map_feature[:num_valid_vectors, 14] = MetaDriveType.is_sidewalk(map_feat['type']) + map_feature[:num_valid_vectors, 15] = MetaDriveType.is_road_boundary_line(map_feat['type']) + map_feature[:num_valid_vectors, 16] = MetaDriveType.is_road_line(map_feat['type']) + map_feature[:num_valid_vectors, 17] = MetaDriveType.is_broken_line(map_feat['type']) + map_feature[:num_valid_vectors, 18] = MetaDriveType.is_solid_line(map_feat['type']) + map_feature[:num_valid_vectors, 19] = MetaDriveType.is_yellow_line(map_feat['type']) + 
map_feature[:num_valid_vectors, 20] = MetaDriveType.is_white_line(map_feat['type']) + map_feature[:num_valid_vectors, 21] = MetaDriveType.is_driveway(map_feat['type']) + map_feature[:num_valid_vectors, 22] = MetaDriveType.is_crosswalk(map_feat['type']) + map_feature[:num_valid_vectors, 23] = MetaDriveType.is_speed_bump(map_feat['type']) + map_feature[:num_valid_vectors, 24] = MetaDriveType.is_stop_sign(map_feat['type']) + map_feature[:num_valid_vectors, 25] = road_length + # valid_mask = np.ones_like(start_points[start_index:i, 0]) + map_feature[:num_valid_vectors, 26] = 1 + + assert map_feature.shape[0] > 0 + avg_position = ((map_feature[:num_valid_vectors, 0:3] + map_feature[:num_valid_vectors, 3:6]) / + 2).mean(axis=0) + avg_heading = utils.wrap_to_pi(utils.average_angles(map_feature[:num_valid_vectors, 9])) + + # if i - start_index < max_vectors: + # map_feature = np.pad(map_feature, pad_width=((0, max_vectors - (i - start_index)), (0, 0))) + # valid_mask = np.pad(valid_mask, pad_width=(0, max_vectors - (i - start_index))) + + valid_mask = map_feature[:, 26].copy() + + map_feature_list.append(map_feature) + map_valid_mask_list.append(valid_mask) + + map_heading_list.append(avg_heading) + map_position_list.append(avg_position) + + start_index = i + road_length = 0.0 + + # if MetaDriveType.is_lane(map_feat['type']): + # map_id_of_lanes.append(map_feat_id) + + if len(map_feature_list) == 0: + map_feature_position = np.zeros([0, 0, 3], dtype=np.float32) + map_feature_heading = np.zeros([0, 0], dtype=np.float32) + else: + map_feature_position = np.stack(map_position_list, axis=0).astype(np.float32) # [num map feat, 2] + map_feature_heading = np.stack(map_heading_list, axis=0).astype(np.float32) # [num map feat, 2] + # print(f"# MAP FEATURES: {len(map_position_list)}, # Avg Vectors: {np.mean(np.sum(map_valid_mask_list, axis=1), axis=0)}, # Max Vectors: {np.max(np.sum(map_valid_mask_list, axis=1), axis=0)}" ) + + # Filter out too many map features + if 
limit_map_range: + # Should follow TrafficGen / LCTGen's preprocessing and crop map to + # 50m range within SDC's position + sdc_id = scenario['metadata']['sdc_id'] + sdc_tracks = scenario['tracks'][sdc_id]['state']['position'] + current_step = scenario['metadata']['current_time_index'] + sdc_position = sdc_tracks[current_step][..., :2] - map_center[..., :2] + + map_feature_position = np.stack(map_position_list) + + valid_map_feat = ( + (abs(map_feature_position[..., 0] - sdc_position[0]) < 50) & + (abs(map_feature_position[..., 1] - sdc_position[1]) < 50) + ) + indices = valid_map_feat.nonzero()[0] + map_feature_position = map_feature_position[indices] + map_feature_heading = np.stack([map_feature_heading[i] for i in indices], + axis=0).astype(np.float32) # [num map feat, 2] + map_feature_list = [map_feature_list[i] for i in indices] + map_valid_mask_list = [map_valid_mask_list[i] for i in indices] + + # print("Num map features: ", len(map_feature_list), "Max vectors: ", np.max(np.sum(map_valid_mask_list, axis=1))) + + if len(map_feature_position) > max_map_features: + # Sorted based on the distance to the SDC + sdc_id = scenario['metadata']['sdc_id'] + sdc_tracks = scenario['tracks'][sdc_id]['state']['position'] + current_step = scenario['metadata']['current_time_index'] + sdc_position = sdc_tracks[current_step][..., :2] - map_center[..., :2] + + dist = np.linalg.norm(map_feature_position[:, :2] - sdc_position[:2], axis=1) + + indices = np.argsort(dist)[:max_map_features] + map_feature_position = map_feature_position[indices] + map_feature_heading = np.stack([map_feature_heading[i] for i in indices], + axis=0).astype(np.float32) # [num map feat, 2] + map_feature_list = [map_feature_list[i] for i in indices] + map_valid_mask_list = [map_valid_mask_list[i] for i in indices] + + if len(map_valid_mask_list) > 0: + map_feature = np.stack(map_feature_list, axis=0).astype(np.float32) # [num map feat, max vectors, 27] + assert map_feature.shape[-1] == 
constants.MAP_FEATURE_STATE_DIM + map_feature_mask = np.stack(map_valid_mask_list, axis=0).astype(bool) # [num map feat, max vectors] + map_feature_heading = np.stack(map_feature_heading, axis=0).astype(np.float32) # [num map feat, max vectors] + + else: + map_feature = np.zeros([0, max_vectors, constants.MAP_FEATURE_STATE_DIM], dtype=np.float32) + map_feature_mask = np.zeros([0, max_vectors], dtype=bool) + map_feature_position = np.zeros([0, 3], dtype=np.float32) + map_feature_heading = np.zeros([0], dtype=np.float32) + + num_map_feat = map_feature.shape[0] + utils.assert_shape(map_feature, (num_map_feat, max_vectors, constants.MAP_FEATURE_STATE_DIM)) + utils.assert_shape(map_feature_mask, ( + num_map_feat, + max_vectors, + )) + utils.assert_shape(map_feature_position, (num_map_feat, 3)) + utils.assert_shape(map_feature_heading, (num_map_feat, )) + + # num_lights = traffic_light_valid_mask.any(axis=0).sum() + # print("num_lights: ", num_lights) + + data_dict.update( + { + "encoder/map_feature": map_feature, + "encoder/map_position": map_feature_position, + "encoder/map_heading": map_feature_heading, + "encoder/map_valid_mask": map_feature_mask.any(-1), # Token valid mask + "encoder/map_feature_valid_mask": map_feature_mask, + # "encoder/traffic_light_feature": traffic_light_feature, + # "encoder/traffic_light_position": traffic_light_position, + # "encoder/traffic_light_heading": traffic_light_heading, + # "encoder/traffic_light_valid_mask": traffic_light_valid_mask, + "metadata/map_center": map_center, + "metadata/map_heading": map_heading, + } + ) + + if not is_scenestreamer: + data_dict = process_traffic_light( + data_dict, + map_feature, + dynamic_map_states, + track_length, + max_vectors, + max_map_features, + max_length_per_map_feature, + max_traffic_lights, + map_center, + map_heading, + remove_traffic_light_state=remove_traffic_light_state + ) + else: + data_dict = process_traffic_light_scenestreamer( + data_dict, + map_feature, + dynamic_map_states, + 
def process_traffic_light(
    data_dict, map_feature, dynamic_map_states, track_length, max_vectors, max_map_features, max_length_per_map_feature,
    max_traffic_lights, map_center, map_heading, remove_traffic_light_state
):
    """
    Encode the traffic lights into fixed-size arrays and store them in ``data_dict``.

    At most ``max_traffic_lights`` lights are encoded; any further entries of ``dynamic_map_states``
    are ignored after a warning is emitted.

    Args:
        data_dict: Sample dict that receives the ``encoder/traffic_light_*`` entries (updated in place).
        map_feature, max_vectors, max_map_features, max_length_per_map_feature: Not used here.
        dynamic_map_states: Dict of traffic lights. Each value provides ``"stop_point"``,
            ``"state"["object_state"]`` (one entry per step) and, in the per-step mode, ``"type"``.
        track_length: Number of time steps encoded in the per-step mode.
        max_traffic_lights: Size of the traffic light axis of every output array.
        map_center, map_heading: Forwarded to ``centralize_to_map_center`` for the stop point.
        remove_traffic_light_state: If true, one feature row per light is produced from the most
            frequent non-``None`` state. Otherwise a feature row per step and light is produced.

    Returns:
        ``data_dict`` with these entries added:
            - ``encoder/traffic_light_feature``: ``[max_traffic_lights, D]`` or
              ``[track_length, max_traffic_lights, D]`` with ``D = constants.TRAFFIC_LIGHT_STATE_DIM``;
              columns 0-2 hold the stop point, columns 3-6 the green / yellow / red / unknown flags.
            - ``encoder/traffic_light_position``: ``[max_traffic_lights, 3]``.
            - ``encoder/traffic_light_heading``: ``[max_traffic_lights]``.
            - ``encoder/traffic_light_valid_mask``: ``[max_traffic_lights]`` or
              ``[track_length, max_traffic_lights]``.
    """

    # ===== Extract traffic light features =====
    traffic_light_position = np.zeros([max_traffic_lights, 3], dtype=np.float32)

    if remove_traffic_light_state:
        traffic_light_heading = np.zeros([max_traffic_lights], dtype=np.float32)
        traffic_light_feature = np.zeros([max_traffic_lights, constants.TRAFFIC_LIGHT_STATE_DIM], dtype=np.float32)
        traffic_light_valid_mask = np.zeros([max_traffic_lights], dtype=bool)
    else:
        traffic_light_heading = np.zeros([
            max_traffic_lights,
        ], dtype=np.float32) + constants.HEADING_PLACEHOLDER
        traffic_light_feature = np.zeros(
            [track_length, max_traffic_lights, constants.TRAFFIC_LIGHT_STATE_DIM], dtype=np.float32
        )
        traffic_light_valid_mask = np.zeros([track_length, max_traffic_lights], dtype=bool)

    for tl_count, traffic_light in enumerate(dynamic_map_states.values()):
        # The capacity check has to come BEFORE any write: index ``max_traffic_lights`` is already out
        # of bounds for the arrays allocated above.
        if tl_count >= max_traffic_lights:
            logger.debug(f"WARNING: {len(dynamic_map_states)} exceeds {max_traffic_lights} traffic lights!")
            print(f"WARNING: {len(dynamic_map_states)} exceeds {max_traffic_lights} traffic lights!")
            break

        if remove_traffic_light_state:
            observed_states = [v for v in traffic_light["state"]["object_state"] if v is not None]
            if not observed_states:
                # The light was never observed, so there is no majority state to encode. Leave its slot
                # zeroed and masked out instead of failing on an empty argmax.
                continue
            tl_states, tl_counts = np.unique(observed_states, return_counts=True)
            tl_state = str(tl_states[np.argmax(tl_counts)])  # most frequent state over the episode

            stop_point = centralize_to_map_center(
                position_array=traffic_light["stop_point"], map_center=map_center, map_heading=map_heading
            )
            traffic_light_position[tl_count] = stop_point[..., :3]
            traffic_light_feature[tl_count, :3] = stop_point
            traffic_light_feature[tl_count, 3] = MetaDriveType.is_traffic_light_in_green(tl_state)
            traffic_light_feature[tl_count, 4] = MetaDriveType.is_traffic_light_in_yellow(tl_state)
            traffic_light_feature[tl_count, 5] = MetaDriveType.is_traffic_light_in_red(tl_state)
            traffic_light_feature[tl_count, 6] = MetaDriveType.is_traffic_light_unknown(tl_state)
            traffic_light_valid_mask[tl_count] = True
            continue

        # Per-step mode.
        stop_point = centralize_to_map_center(
            position_array=traffic_light["stop_point"], map_center=map_center, map_heading=map_heading
        )
        traffic_light_position[tl_count] = stop_point[..., :3]

        assert traffic_light['type'] == MetaDriveType.TRAFFIC_LIGHT
        object_states = traffic_light["state"]["object_state"]
        for step in range(track_length):
            step_state = object_states[step]
            traffic_light_feature[step, tl_count, :3] = stop_point
            traffic_light_feature[step, tl_count, 3] = MetaDriveType.is_traffic_light_in_green(step_state)
            traffic_light_feature[step, tl_count, 4] = MetaDriveType.is_traffic_light_in_yellow(step_state)
            traffic_light_feature[step, tl_count, 5] = MetaDriveType.is_traffic_light_in_red(step_state)
            traffic_light_feature[step, tl_count, 6] = MetaDriveType.is_traffic_light_unknown(step_state)
            traffic_light_valid_mask[step, tl_count] = True

    data_dict.update(
        {
            "encoder/traffic_light_feature": traffic_light_feature,
            "encoder/traffic_light_position": traffic_light_position,
            "encoder/traffic_light_heading": traffic_light_heading,
            "encoder/traffic_light_valid_mask": traffic_light_valid_mask,
        }
    )
    return data_dict
remove_traffic_light_state +): + assert remove_traffic_light_state is False + + L = len(dynamic_map_states) + + if L == 0: + L = 1 + + traffic_light_position = np.zeros([L, 3], dtype=np.float32) + # traffic_light_heading = np.zeros([L], dtype=np.float32) + + # TODO: hardcoded + count = len(range(0, track_length, 5)) + traffic_light_state_np = np.zeros([count, L,], dtype=int) + traffic_light_valid_mask = np.zeros([count, L], dtype=bool) + + # Find closest map feature + for tl_count, (traffic_light_index, traffic_light) in enumerate(dynamic_map_states.items()): + stop_point = centralize_to_map_center(position_array=traffic_light["stop_point"], map_center=map_center, map_heading=map_heading) + traffic_light_position[tl_count] = stop_point[..., :3] + + # Find the closest map feature + map_pos = data_dict["encoder/map_position"] + map_headings = data_dict["encoder/map_heading"] + valid_map_feat = data_dict["encoder/map_valid_mask"] + dist = np.linalg.norm((traffic_light_position[:, None, :2] - map_pos[None, :, :2]), axis=-1) + assert valid_map_feat.all() + closest_map_id = np.argmin(dist, axis=1) + # closest_dist = np.min(dist, axis=1) + traffic_light_heading = map_headings[closest_map_id] + + + for tl_count, (traffic_light_index, traffic_light) in enumerate(dynamic_map_states.items()): + step_compressed = 0 + for step in range(0, track_length, 5): + # TODO: Here we hardcode the step to 5. This might be an issue. 
def filter_and_reorder_agent(data_dict, max_agents=None):
    """
    Reorder the agent axis so that the SDC comes first, followed by the modeled agents.

    The remaining agents are ranked by how many steps they are valid (most valid first). When
    ``max_agents`` is given, only the first ``max_agents`` agents of the resulting order are kept.
    All per-agent ``encoder/*`` arrays are re-indexed accordingly, and ``encoder/sdc_index`` /
    ``encoder/object_of_interest_id`` are rewritten to refer to the new positions.

    Args:
        data_dict: Sample dict holding the ``encoder/*`` entries; updated in place.
        max_agents: Optional cap on the number of agents to keep.

    Returns:
        The same ``data_dict`` after reordering.
    """
    num_agents = data_dict["encoder/agent_feature"].shape[1]
    validity = data_dict["encoder/agent_valid_mask"]
    ooi_indices = data_dict["encoder/object_of_interest_id"]
    sdc_index = data_dict["encoder/sdc_index"]

    # Rank agents by their number of valid steps. The sort is stable, so ties keep index order.
    order = sorted(range(num_agents), key=lambda idx: validity[:, idx].sum(), reverse=True)

    # Move every modeled agent to the front (each one is pushed to slot 0 in turn).
    if ooi_indices is not None:
        for ooi in ooi_indices:
            order.remove(ooi)
            order.insert(0, int(ooi))

    # The SDC always ends up in slot 0.
    assert sdc_index in order
    order.remove(sdc_index)
    order.insert(0, sdc_index)

    if max_agents is not None:
        order = order[:max_agents]
    order = np.asarray(order, dtype=int)

    # ===== Reorder all data =====
    keys_by_agent_axis = (
        # Arrays whose first dim is the agent dim.
        (
            0, (
                "encoder/agent_type",
                "encoder/current_agent_shape",
                "encoder/current_agent_valid_mask",
                "encoder/current_agent_position",
                "encoder/current_agent_heading",
                "encoder/current_agent_velocity",
                "encoder/track_name",
            )
        ),
        # Arrays whose second dim is the agent dim.
        (
            1, (
                "encoder/agent_feature",
                "encoder/agent_valid_mask",
                "encoder/agent_position",
                "encoder/agent_velocity",
                "encoder/agent_heading",
                "encoder/agent_shape",
            )
        ),
    )
    for agent_axis, keys in keys_by_agent_axis:
        for key in keys:
            data_dict[key] = extract_data_by_agent_indices(data_dict[key], agent_indices=order, agent_dim=agent_axis)

    # ===== Translate the modeled agent indices into the new ordering =====
    if ooi_indices is not None:
        new_slot_of = {}
        for new_ind, old_ind in enumerate(order):
            new_slot_of.setdefault(old_ind, new_ind)
        translated = [new_slot_of[old] for old in ooi_indices if old in new_slot_of]
        # Fails if a modeled agent was cut off by ``max_agents``.
        assert len(translated) == len(ooi_indices)
        ooi_indices = translated

    new_sdc_index = 0
    # If the SDC is itself a modeled agent, list it first.
    if ooi_indices is not None and new_sdc_index in ooi_indices:
        ooi_indices.remove(new_sdc_index)
        ooi_indices.insert(0, new_sdc_index)

    data_dict["encoder/sdc_index"] = new_sdc_index
    data_dict["encoder/object_of_interest_id"] = np.asarray(ooi_indices)
    # Note that the new ooi ids keep their relative order, so ooi names need no change.

    assert bool(data_dict["encoder/current_agent_valid_mask"][new_sdc_index])

    return data_dict
True) + first_idx = valid_mask.argmax(axis=0) + + # Set result to -1 where there was no valid entry + first_idx[~has_valid] = -1 + + first = np.take_along_axis(pos, indices=first_idx.reshape(1, N, 1), axis=0) + first = np.squeeze(first, axis=0) + return first, last + + agent_types = data_dict["encoder/agent_type"] + agent_position = data_dict["encoder/agent_position"] + current_valid_mask = data_dict["encoder/current_agent_valid_mask"] + current_valid_agent_id = current_valid_mask.nonzero()[0] + + first_pos, last_pos = _get_first_last_pos(agent_position, agent_valid_mask) + moving_dist = np.linalg.norm((last_pos-first_pos)[:, :2], axis=-1) + moving_dist[~current_valid_mask] = -1000 + + # force to add all non-vehicle agent + selected_agents = np.argsort(moving_dist)[::-1] + + # Remove agent id that are not in current_valid_agent_id + selected_agents = np.array([i for i in selected_agents if i in current_valid_agent_id]) + + if max_agents is not None: + all_128_selected_agents = selected_agents[:default_max_agents] + else: + all_128_selected_agents = selected_agents + + all_128_selected_agents = all_128_selected_agents.tolist() + if modeled_agent_indices is not None: + for agent_index in modeled_agent_indices: + if agent_index in all_128_selected_agents: + all_128_selected_agents.remove(agent_index) + all_128_selected_agents.insert(0, int(agent_index)) + # Put SDC to first place. 
+ if sdc_index in all_128_selected_agents: + all_128_selected_agents.remove(sdc_index) + all_128_selected_agents.insert(0, sdc_index) + + if max_agents is not None: + all_128_selected_agents = all_128_selected_agents[:default_max_agents] + + # reorganize the order of the agents based on their types + tmpagent_types = agent_types[all_128_selected_agents] + all_128_selected_agents = np.asarray(all_128_selected_agents, dtype=int) + new_selected_agents = [] + for atype in [1, 2, 3]: + if atype in tmpagent_types: + atype_ids = np.where(tmpagent_types == atype)[0].astype(int) + atype_ids = all_128_selected_agents[atype_ids] + new_selected_agents += list(atype_ids) + all_128_selected_agents = np.asarray(new_selected_agents, dtype=int) + + # In those all 128 selected agents, we need to filter out + if max_agents is not None and len(all_128_selected_agents) > max_agents: + # Do second round of filtering, but this time we only set invalidity + new_moving_dist = moving_dist[all_128_selected_agents] + new_selected_agents = np.argsort(new_moving_dist)[::-1] + new_selected_agents = new_selected_agents[:max_agents] + new_selected_agents = all_128_selected_agents[new_selected_agents] + new_selected_agents = new_selected_agents.tolist() + if modeled_agent_indices is not None: + for agent_index in modeled_agent_indices: + assert agent_index in all_128_selected_agents + if agent_index in new_selected_agents: + new_selected_agents.remove(agent_index) + new_selected_agents.insert(0, int(agent_index)) + # Put SDC to first place. 
+ if sdc_index in new_selected_agents: + new_selected_agents.remove(sdc_index) + new_selected_agents.insert(0, sdc_index) + new_selected_agents = new_selected_agents[:max_agents] + new_selected_agents = np.asarray(new_selected_agents, dtype=int) + + valid_mask = np.zeros((agent_position.shape[1],), dtype=bool) + valid_mask[new_selected_agents] = True + data_dict["encoder/current_agent_valid_mask"] = np.logical_and( + data_dict["encoder/current_agent_valid_mask"], valid_mask + ) + data_dict["encoder/agent_valid_mask"] = np.logical_and( + data_dict["encoder/agent_valid_mask"], valid_mask[None] + ) + + assert all_128_selected_agents[0] == sdc_index + + # ===== Reorder all data ===== + # Those data whose first dim is the agent dim: + for key in [ + "encoder/agent_type", + "encoder/current_agent_shape", + "encoder/current_agent_valid_mask", + "encoder/current_agent_position", + "encoder/current_agent_heading", + "encoder/current_agent_velocity", + "encoder/track_name", + ]: + data_dict[key] = extract_data_by_agent_indices(data_dict[key], agent_indices=all_128_selected_agents, agent_dim=0) + # Those data whose second dim is the agent dim: + for key in [ + "encoder/agent_feature", + "encoder/agent_valid_mask", + "encoder/agent_position", + "encoder/agent_velocity", + "encoder/agent_heading", + # "encoder/future_agent_position", + # "encoder/future_agent_heading", + # "encoder/future_agent_valid_mask", + "encoder/agent_shape", + ]: + data_dict[key] = extract_data_by_agent_indices(data_dict[key], agent_indices=all_128_selected_agents, agent_dim=1) + + # ===== Reorder modeled agents and SDC, change modeled_agent_indices if necessary ===== + if modeled_agent_indices is not None: + # Need to translate track_index_to_predict + new_modeled_agent_indices = [] + for old_agent_index in modeled_agent_indices: + for new_ind, old_ind in enumerate(all_128_selected_agents): + if old_agent_index == old_ind: + new_modeled_agent_indices.append(new_ind) + break + assert 
len(new_modeled_agent_indices) == len(modeled_agent_indices) + modeled_agent_indices = new_modeled_agent_indices + new_sdc_index = 0 + if modeled_agent_indices is not None: + # Also update SDC index + if new_sdc_index in modeled_agent_indices: + modeled_agent_indices.remove(new_sdc_index) + modeled_agent_indices.insert(0, new_sdc_index) + + data_dict["encoder/sdc_index"] = new_sdc_index + data_dict["encoder/object_of_interest_id"] = np.asarray(modeled_agent_indices) + # Note that new ooi id doesn't change the order. So no need to change ooi name. + + assert data_dict["encoder/sdc_index"] in list(data_dict["encoder/current_agent_valid_mask"].nonzero()[0]) + + return data_dict + + +def process_track( + *, + data_dict, + tracks, + track_length, + sdc_name, # We need to translate sdc_name to sdc_index + max_agents, + exempt_max_agent_filtering=False, + is_scenestreamer=False, +): + map_center = data_dict["metadata/map_center"] + map_heading = data_dict["metadata/map_heading"] + current_t = data_dict["metadata/current_time_index"] + + agent_feature_dict = {} + agent_valid_mask_dict = {} + agent_velocity_dict = {} + agent_position_dict = {} + agent_heading_dict = {} + agent_type_dict = {} + agent_shape_dict = {} + sdc_index = None + sdc_name = str(sdc_name) + + valid_track_names = [] + track_count = 0 + + for _, (track_name, cur_data) in enumerate(tracks.items()): # number of objects + + # if not cur_data['type'] == 'VEHICLE': # CAT contains pedestrains which does not contain length, width, and height + # continue + + if not MetaDriveType.is_participant(cur_data["type"]): + # TODO(pzh): TrafficCone is in tracks for some reason. Looks very weird. Might be some bug. 
+ continue + track_name = str(track_name) + + if track_name == sdc_name: + sdc_index = track_count + + cur_state = cur_data[SD.STATE] + + rotated_positions = centralize_to_map_center( + position_array=cur_state["position"], # [T, 3] + map_center=map_center, + map_heading=map_heading + ) # [T, num agents, 3] + + rotated_heading = utils.wrap_to_pi(cur_state["heading"] - map_heading) # [T, num agents] + rotated_velocity = centralize_to_map_center( + position_array=cur_state["velocity"], map_center=None, map_heading=map_heading + )[..., :2] # [T, num agents, 2] + + agent_shape_dict[track_name] = np.stack( + [cur_state["length"].reshape(-1), cur_state["width"].reshape(-1), cur_state["height"].reshape(-1)], axis=1 + ) # (T, N, 3) + + speed = np.linalg.norm(cur_state["velocity"], axis=1) + + valid_mask = np.asarray(cur_state["valid"], dtype=bool) + + agent_state = np.zeros([track_length, constants.AGENT_STATE_DIM], dtype=np.float32) + + # print("shape of rotated_positions", rotated_positions.shape) + # for CAT: we need to pad position dimension + if rotated_positions.shape[1] != 3: + rotated_positions = np.concatenate([rotated_positions, np.zeros((rotated_positions.shape[0], 1))], axis=-1) + + agent_state[:, :3] = rotated_positions + agent_state[:, 3] = rotated_heading + agent_state[:, 4] = np.sin(rotated_heading) + agent_state[:, 5] = np.cos(rotated_heading) + agent_state[:, 6:8] = rotated_velocity + + agent_state[:, 8] = speed + + agent_state[:, 9] = cur_state["length"].reshape(-1) + agent_state[:, 10] = cur_state["width"].reshape(-1) + agent_state[:, 11] = cur_state["height"].reshape(-1) + + agent_state[~valid_mask] = 0 + agent_state[:, 12] = MetaDriveType.is_vehicle(cur_data["type"]) + agent_state[:, 13] = MetaDriveType.is_pedestrian(cur_data["type"]) + agent_state[:, 14] = MetaDriveType.is_cyclist(cur_data["type"]) + agent_state[:, 15] = valid_mask + + # TODO(pzh): Remove mapping + assert cur_data["type"] in constants.object_type_to_int + + 
agent_feature_dict[track_name] = agent_state + agent_valid_mask_dict[track_name] = valid_mask + + agent_position_dict[track_name] = rotated_positions * valid_mask.reshape(-1, 1) + agent_heading_dict[track_name] = rotated_heading * valid_mask + agent_velocity_dict[track_name] = rotated_velocity * valid_mask.reshape(-1, 1) + + # TODO(pzh): Remove mapping + agent_type_dict[track_name] = constants.object_type_to_int[cur_data["type"]] + + valid_track_names.append(str(track_name)) + + track_count += 1 + + assert sdc_index is not None + + # ===== Store all data into dict ===== + agent_feature = np.stack(list(agent_feature_dict.values()), axis=1) # [T, ] + num_agents = agent_feature.shape[1] + utils.assert_shape(agent_feature, (track_length, num_agents, constants.AGENT_STATE_DIM)) + + agent_valid_mask = np.stack(list(agent_valid_mask_dict.values()), axis=1).astype(bool) + utils.assert_shape(agent_valid_mask, ( + track_length, + num_agents, + )) + + agent_position = np.stack(list(agent_position_dict.values()), axis=1) + utils.assert_shape(agent_position, (track_length, num_agents, 3)) + + agent_velocity = np.stack(list(agent_velocity_dict.values()), axis=1) + utils.assert_shape(agent_velocity, (track_length, num_agents, 2)) + + agent_heading = np.stack(list(agent_heading_dict.values()), axis=1) + utils.assert_shape(agent_heading, (track_length, num_agents)) + + agent_type = np.stack(list(agent_type_dict.values()), axis=0).astype(int) + utils.assert_shape(agent_type, (num_agents, )) + + agent_shape = np.stack(list(agent_shape_dict.values()), axis=1) + utils.assert_shape(agent_shape, (track_length, num_agents, 3)) + + data_dict["encoder/agent_feature"] = agent_feature.astype(np.float32) # [T, num agent, D_agent] + data_dict["encoder/agent_valid_mask"] = agent_valid_mask.astype(bool) # [T, num agent] + data_dict["encoder/agent_position"] = agent_position.astype(np.float32) + data_dict["encoder/agent_velocity"] = agent_velocity.astype(np.float32) + 
data_dict["encoder/agent_heading"] = agent_heading.astype(np.float32) + data_dict["encoder/agent_type"] = agent_type + + # data_dict["encoder/future_agent_position"] = agent_position.astype(np.float32)[current_t + 1:] + # data_dict["encoder/future_agent_heading"] = agent_heading.astype(np.float32)[current_t + 1:] + # data_dict["encoder/future_agent_valid_mask"] = agent_valid_mask.astype(bool)[current_t + 1:] + # data_dict["encoder/future_agent_velocity"] = agent_velocity.astype(np.float32)[current_t + 1:] + data_dict["encoder/current_agent_valid_mask"] = agent_valid_mask.astype(bool)[current_t] + data_dict["encoder/current_agent_position"] = agent_position.astype(np.float32)[current_t] + data_dict["encoder/current_agent_heading"] = agent_heading.astype(np.float32)[current_t] + data_dict["encoder/current_agent_velocity"] = agent_velocity.astype(np.float32)[current_t] + + data_dict["encoder/track_name"] = np.array(valid_track_names, dtype=str) + data_dict["encoder/agent_shape"] = agent_shape.astype(np.float32) + data_dict["encoder/current_agent_shape"] = data_dict["encoder/agent_shape"][current_t] + data_dict["encoder/sdc_index"] = sdc_index + + # ===== Process the case where the number of agents exceeds max_agents ===== + if is_scenestreamer: + data_dict = filter_and_reorder_agent_for_scenestreamer(data_dict, max_agents=max_agents if not exempt_max_agent_filtering else None) + else: + data_dict = filter_and_reorder_agent(data_dict, max_agents=max_agents if not exempt_max_agent_filtering else None) + + # Add agent ID: + num_agents = data_dict["encoder/agent_feature"].shape[1] + data_dict["encoder/agent_id"] = np.arange(num_agents) + + # assert (data_dict["decoder/current_agent_valid_mask"] == data_dict["encoder/agent_valid_mask"][current_t]).all() + assert data_dict["encoder/sdc_index"] in list(data_dict["encoder/current_agent_valid_mask"].nonzero()[0]) + + return data_dict + + +def prepare_modeled_agent_and_eval_data( + data_dict, predict_all_agents, 
eval_all_agents, current_t, add_sdc_to_object_of_interest +): + # ===== Need to extract only the modeled agents for decoder and GT ===== + object_of_interest = data_dict["encoder/object_of_interest_id"] + + if predict_all_agents: + modeled_agent_indices = list(data_dict["encoder/current_agent_valid_mask"].nonzero()[0]) + + # In the following code, we will select only the valid agents at this step as modeled agents. + # After the selection, the order of agents will change (again ..). So the object_of_interests + # should also be changed. + + new_object_of_interests = [] + for old_agent_index in object_of_interest: + for new_ind, old_ind in enumerate(modeled_agent_indices): + if old_agent_index == old_ind: + new_object_of_interests.append(new_ind) + break + assert len(new_object_of_interests) == len(object_of_interest) + + assert data_dict["encoder/sdc_index"] in modeled_agent_indices + data_dict["decoder/sdc_index"] = modeled_agent_indices.index(data_dict["encoder/sdc_index"]) + data_dict["decoder/object_of_interest_id"] = np.asarray(new_object_of_interests) + # Note that new ooi id doesn't change the order. 
So no need to change ooi name + + else: + raise ValueError("Not sure what will happen...") + object_of_interest = data_dict["encoder/object_of_interest_id"] + modeled_agent_indices = object_of_interest + # object_of_interest don't change + assert eval_all_agents is False + data_dict["decoder/object_of_interest_id"] = np.arange(len(object_of_interest)) + + data_dict["decoder/agent_id"] = np.arange(len(modeled_agent_indices)) + + assert modeled_agent_indices is not None + + data_dict["decoder/agent_type"] = extract_data_by_agent_indices( + data_dict["encoder/agent_type"], agent_indices=modeled_agent_indices, agent_dim=0 + ) + data_dict["decoder/track_name"] = extract_data_by_agent_indices( + data_dict["encoder/track_name"], modeled_agent_indices, agent_dim=0, fill=-1 + ) + data_dict["encoder/modeled_agent_id"] = extract_data_by_agent_indices( + data_dict["encoder/agent_id"], agent_indices=modeled_agent_indices, agent_dim=0 + ) + data_dict["encoder/modeled_agent_type"] = extract_data_by_agent_indices( + data_dict["encoder/agent_type"], agent_indices=modeled_agent_indices, agent_dim=0 + ) + data_dict["decoder/current_agent_valid_mask"] = extract_data_by_agent_indices( + data_dict["encoder/current_agent_valid_mask"], agent_indices=modeled_agent_indices, agent_dim=0 + ) + data_dict["decoder/current_agent_position"] = extract_data_by_agent_indices( + data_dict["encoder/current_agent_position"], agent_indices=modeled_agent_indices, agent_dim=0 + ) + data_dict["decoder/current_agent_heading"] = extract_data_by_agent_indices( + data_dict["encoder/current_agent_heading"], agent_indices=modeled_agent_indices, agent_dim=0 + ) + data_dict["decoder/current_agent_shape"] = extract_data_by_agent_indices( + data_dict["encoder/current_agent_shape"], agent_indices=modeled_agent_indices, agent_dim=0 + ) + data_dict["decoder/current_agent_velocity"] = extract_data_by_agent_indices( + data_dict["encoder/current_agent_velocity"], agent_indices=modeled_agent_indices, agent_dim=0 + ) + + # 
agent_dim = 1 + # data_dict["decoder/future_agent_position"] = extract_data_by_agent_indices( + # data_dict["encoder/future_agent_position"], agent_indices=modeled_agent_indices, agent_dim=1 + # ) + # data_dict["decoder/future_agent_heading"] = extract_data_by_agent_indices( + # data_dict["encoder/future_agent_heading"], agent_indices=modeled_agent_indices, agent_dim=1 + # ) + # data_dict["decoder/future_agent_velocity"] = extract_data_by_agent_indices( + # data_dict["encoder/future_agent_velocity"], agent_indices=modeled_agent_indices, agent_dim=1 + # ) + # data_dict["decoder/future_agent_valid_mask"] = extract_data_by_agent_indices( + # data_dict["encoder/future_agent_valid_mask"], agent_indices=modeled_agent_indices, agent_dim=1 + # ) + + data_dict["decoder/agent_position"] = extract_data_by_agent_indices( + data_dict["encoder/agent_position"], modeled_agent_indices, agent_dim=1 + ) + data_dict["decoder/agent_velocity"] = extract_data_by_agent_indices( + data_dict["encoder/agent_velocity"], modeled_agent_indices, agent_dim=1 + ) + data_dict["decoder/agent_heading"] = extract_data_by_agent_indices( + data_dict["encoder/agent_heading"], modeled_agent_indices, agent_dim=1 + ) + data_dict["decoder/agent_valid_mask"] = extract_data_by_agent_indices( + data_dict["encoder/agent_valid_mask"], modeled_agent_indices, agent_dim=1 + ) + data_dict["decoder/agent_shape"] = extract_data_by_agent_indices( + data_dict["encoder/agent_shape"], modeled_agent_indices, agent_dim=1 + ) + data_dict["decoder/object_of_interest_name"] = data_dict["encoder/object_of_interest_name"] + + if add_sdc_to_object_of_interest: + + if data_dict["metadata/sdc_name"] not in data_dict["encoder/object_of_interest_name"]: + data_dict["encoder/object_of_interest_name"] = np.concatenate( + [[data_dict["metadata/sdc_name"]], data_dict["encoder/object_of_interest_name"]] + ) + else: + assert data_dict["metadata/sdc_name"] == data_dict["encoder/object_of_interest_name"][0] + + if 
data_dict["metadata/sdc_name"] not in data_dict["decoder/object_of_interest_name"]: + data_dict["decoder/object_of_interest_name"] = np.concatenate( + [[data_dict["metadata/sdc_name"]], data_dict["encoder/object_of_interest_name"]] + ) + else: + assert data_dict["metadata/sdc_name"] == data_dict["decoder/object_of_interest_name"][0] + + # Evaluation data: all with leading dimensions: (num of interested objects, T, ...) + # If not eval all agents, a new index system `eval/` is introduced. + if eval_all_agents: + pass + # data_dict["eval/track_name"] = data_dict["decoder/track_name"] + # data_dict["eval/agent_type"] = data_dict["decoder/agent_type"] + # data_dict["eval/agent_position"] = data_dict["decoder/agent_position"] + # data_dict["eval/agent_velocity"] = data_dict["decoder/agent_velocity"] + # data_dict["eval/agent_heading"] = data_dict["decoder/agent_heading"] + # data_dict["eval/agent_valid_mask"] = data_dict["decoder/agent_valid_mask"] + # data_dict["eval/agent_shape"] = data_dict["decoder/agent_shape"] + else: + assert new_object_of_interests is not None + decoder_ooi_id = new_object_of_interests + data_dict["eval/track_name"] = extract_data_by_agent_indices( + data_dict["decoder/track_name"], decoder_ooi_id, agent_dim=0 + ) + data_dict["eval/agent_type"] = extract_data_by_agent_indices( + data_dict["decoder/agent_type"], decoder_ooi_id, agent_dim=0 + ) + data_dict["eval/agent_position"] = extract_data_by_agent_indices( + data_dict["decoder/agent_position"], decoder_ooi_id, agent_dim=1 + ) + data_dict["eval/agent_velocity"] = extract_data_by_agent_indices( + data_dict["decoder/agent_velocity"], decoder_ooi_id, agent_dim=1 + ) + data_dict["eval/agent_heading"] = extract_data_by_agent_indices( + data_dict["decoder/agent_heading"], decoder_ooi_id, agent_dim=1 + ) + data_dict["eval/agent_valid_mask"] = extract_data_by_agent_indices( + data_dict["decoder/agent_valid_mask"], decoder_ooi_id, agent_dim=1 + ) + data_dict["eval/agent_shape"] = 
extract_data_by_agent_indices( + data_dict["decoder/agent_shape"], decoder_ooi_id, agent_dim=1 + ) + assert data_dict["eval/agent_valid_mask"][current_t].all() # not all object_of_interest is in CAT + + return data_dict + + +def preprocess_scenario_description(*args, config, **kwargs): + # if scenario['length'] < 5: # TODO: filter out CAT data that is not valid + # return None + # TODO: combine all cat info dictionary .pkl files into one and provide the paths in the config + if config.MODEL.NAME in ["motionlm", "language_motionlm", "gpt", "scenestreamer"]: + return preprocess_scenario_description_for_motionlm(*args, config=config, **kwargs) + elif config.MODEL.NAME == "gen": + return preprocess_scenario_description_for_gen(scenario, config, in_evaluation, keep_all_data) + # elif config.MODEL.NAME == "scenestreamer": + # return preprocess_scenario_description_for_scenestreamer(scenario, config, in_evaluation, keep_all_data) + else: + raise ValueError(f"Unknown model name: {config.MODEL.NAME}") + + +def prepare_trafficgen_data( + data_dict, + config, + scenario, + force_t=True, # Just disable this function... + only_lane=False, +): + + sdc_index = data_dict["decoder/sdc_index"] + T = data_dict["encoder/agent_feature"].shape[0] + if force_t: + # if 180 <= T <= 200: + # assert data_dict["metadata/current_time_index"] == 0 + # current_t = 0 + # elif T == 91: + current_t = data_dict["metadata/current_time_index"] + # else: + # raise ValueError(f"Unknown T: {T}") + + else: + current_t = np.random.randint(0, T) + + start_action_id = config.PREPROCESSING.MAX_MAP_FEATURES + end_action_id = config.PREPROCESSING.MAX_MAP_FEATURES + 1 + + # Note that here we reuse the "current_agent". 
+ pos = data_dict["decoder/agent_position"][current_t, ..., :2] + heading = data_dict["decoder/agent_heading"][current_t] + valid = data_dict["decoder/agent_valid_mask"][current_t] + vel = data_dict["decoder/agent_velocity"][current_t] + + agent_type = data_dict["decoder/agent_type"] # in 123 + agent_type = np.clip(agent_type - 1, 0, 2) # in 012 + + current_agent_shape = data_dict["decoder/current_agent_shape"] + # assert valid.all() + N = len(pos) + + # Randomize the agent order (but still put SDC in the first place) + agent_id = np.arange(len(pos)) + agent_id = agent_id[agent_id != sdc_index] + randomized_agent_id = np.random.permutation(agent_id) + randomized_agent_id = np.concatenate([np.array([sdc_index]), randomized_agent_id], axis=0) + agent_type = agent_type[randomized_agent_id] + pos = pos[randomized_agent_id] + heading = heading[randomized_agent_id] + vel = vel[randomized_agent_id] + current_agent_shape = current_agent_shape[randomized_agent_id] + + # Filter map feature and only keep lanes: + # map_feature = data_dict["encoder/map_feature"] + # is_lane = map_feature[:, 0, 13] == 1 + map_pos = data_dict["encoder/map_position"][..., :2] + map_heading = data_dict["encoder/map_heading"] + + # Get map feature valid mask + valid_map_feat = data_dict["encoder/map_valid_mask"] + heading_diff = utils.wrap_to_pi(heading[:, None] - map_heading[None]) + valid_heading = np.abs(heading_diff) < np.deg2rad(90) + valid_map_feat = valid_map_feat & valid_heading + + if only_lane: + map_feature = data_dict["encoder/map_feature"] + is_lane = map_feature[:, 0, 13] == 1 + is_lane = is_lane[None].repeat(N, 0) + valid_map_feat = is_lane & valid_map_feat + + # Find the closest map feature + dist = np.linalg.norm((pos[:, None] - map_pos[None])[..., :2], axis=-1) + dist[~valid_map_feat] = np.inf + closest_map_feat = np.argmin(dist, axis=1) + # closest_map_dist = dist[np.arange(N), closest_map_feat] + + # Set invalid if an agent is far away from the center of the map feat. 
+ # By saying far I mean exceeding the length of the map feat. + # map_feat_length = map_feature[..., 25].max(-1)[closest_map_feat] + # valid_mask = closest_map_dist < map_feat_length + # valid_mask = np.ones(N, dtype=bool) + valid_mask = valid + + # Get the selected map feature + selected_map_pos = map_pos[closest_map_feat] + selected_map_heading = map_heading[closest_map_feat] + + # Get relative information + relative_pos = pos - selected_map_pos + relative_pos = utils.rotate(x=relative_pos[:, 0], y=relative_pos[:, 1], angle=-selected_map_heading) + relative_heading = utils.wrap_to_pi(heading - selected_map_heading) + relative_vel = utils.rotate(x=vel[:, 0], y=vel[:, 1], angle=-selected_map_heading) + + # Filter out the agents that are out of the scope + valid_mask = ( + valid_mask & (relative_pos[:, 0] >= TrafficGenTokenizer.limit["position_x"][0]) & + (relative_pos[:, 0] <= TrafficGenTokenizer.limit["position_x"][1]) & + (relative_pos[:, 1] >= TrafficGenTokenizer.limit["position_y"][0]) & + (relative_pos[:, 1] <= TrafficGenTokenizer.limit["position_y"][1]) & + (relative_heading >= TrafficGenTokenizer.limit["heading"][0]) & + (relative_heading <= TrafficGenTokenizer.limit["heading"][1]) + ) + + if config.FOLLOW_TRAFFICGEN: + tg_select_index = _get_trafficgen_data( + raw_scenario_description=scenario, data_dict=data_dict, current_t=current_t + ) + new_valid_mask = np.zeros_like(valid_mask) + new_valid_mask[tg_select_index] = True + valid_mask = valid_mask & new_valid_mask + + if not valid_mask.any() and force_t is False: + return prepare_trafficgen_data(data_dict, config, force_t=True) + + pos = pos[valid_mask] + heading = heading[valid_mask] + vel = vel[valid_mask] + agent_type = agent_type[valid_mask] + current_agent_shape = current_agent_shape[valid_mask] + relative_pos = relative_pos[valid_mask] + relative_heading = relative_heading[valid_mask] + relative_vel = relative_vel[valid_mask] + selected_map_pos = selected_map_pos[valid_mask] + selected_map_heading = 
selected_map_heading[valid_mask] + closest_map_feat = closest_map_feat[valid_mask] + valid_mask = valid_mask[valid_mask] + + # Get the discretized relative position + gt_position_x = TrafficGenTokenizer.bucketize(relative_pos[:, 0], "position_x") + gt_position_y = TrafficGenTokenizer.bucketize(relative_pos[:, 1], "position_y") + relative_pos_x = TrafficGenTokenizer.de_bucketize(gt_position_x, "position_x") + relative_pos_y = TrafficGenTokenizer.de_bucketize(gt_position_y, "position_y") + + # Reconstruct the position with the bucketized value + relative_pos = np.stack([relative_pos_x, relative_pos_y], axis=1) + pos = utils.rotate(x=relative_pos_x, y=relative_pos_y, angle=selected_map_heading) + selected_map_pos + + # Reconstruct the heading and velocity + gt_heading = TrafficGenTokenizer.bucketize(relative_heading, "heading") + relative_heading = TrafficGenTokenizer.de_bucketize(gt_heading, "heading") + heading = utils.wrap_to_pi(relative_heading + selected_map_heading) + + # Reconstruct the velocity + gt_vel_x = TrafficGenTokenizer.bucketize(relative_vel[:, 0], "velocity_x") + gt_vel_y = TrafficGenTokenizer.bucketize(relative_vel[:, 1], "velocity_y") + relative_vel_x = TrafficGenTokenizer.de_bucketize(gt_vel_x, "velocity_x") + relative_vel_y = TrafficGenTokenizer.de_bucketize(gt_vel_y, "velocity_y") + relative_vel = np.stack([relative_vel_x, relative_vel_y], axis=1) + vel = utils.rotate(x=relative_vel_x, y=relative_vel_y, angle=selected_map_heading) + + # Reconstruct shape + gt_shape_l = TrafficGenTokenizer.bucketize(current_agent_shape[:, 0], "length") + gt_shape_w = TrafficGenTokenizer.bucketize(current_agent_shape[:, 1], "width") + gt_shape_h = TrafficGenTokenizer.bucketize(current_agent_shape[:, 2], "height") + current_agent_shape = np.stack( + [ + TrafficGenTokenizer.de_bucketize(gt_shape_l, "length"), + TrafficGenTokenizer.de_bucketize(gt_shape_w, "width"), + TrafficGenTokenizer.de_bucketize(gt_shape_h, "height") + ], + axis=1 + ) + + # ===== Fill in the data 
for trafficgen ===== + data_dict["decoder/input_action_for_trafficgen"] = np.concatenate( + [[start_action_id], closest_map_feat, [end_action_id]] + ).astype(int) + data_dict["decoder/input_action_valid_mask_for_trafficgen"] = np.concatenate([[1], valid_mask, [1]], + axis=0).astype(bool) + data_dict["decoder/modeled_agent_position_for_trafficgen"] = np.concatenate([[[0, 0]], pos, [[0, 0]]], + axis=0).astype(np.float32) + + data_dict["decoder/modeled_agent_velocity_for_trafficgen"] = np.concatenate([[[0, 0]], vel, [[0, 0]]], + axis=0).astype(np.float32) + + data_dict["decoder/modeled_agent_heading_for_trafficgen"] = np.concatenate([[0], heading, [0]], + axis=0).astype(np.float32) + data_dict["decoder/current_agent_shape_for_trafficgen"] = np.concatenate( + [[[0, 0, 0]], current_agent_shape, [[0, 0, 0]]], axis=0 + ).astype(np.float32) + data_dict["decoder/agent_type_for_trafficgen"] = np.concatenate([[0], agent_type, [0]], axis=0).astype(int) + + feat = np.zeros((len(pos) + 2, 5), dtype=np.float32) + # import matplotlib.pyplot as plt;plt.scatter(relative_pos[:, 0], relative_pos[:, 1]);plt.show() + feat[1:-1, :2] = relative_pos + feat[1:-1, 2] = relative_heading + feat[1:-1, 3:5] = relative_vel + # print("MAX: x={:.3f}, y={:.3f}, h={:.3f}, vx={:.3f}, vy={:.3f}".format( + # feat[:, 0].max(), feat[:, 1].max(), feat[:, 2].max(), feat[:, 3].max(), feat[:, 4].max() + # )) + # print("MIN: x={:.3f}, y={:.3f}, h={:.3f}, vx={:.3f}, vy={:.3f}".format( + # feat[:, 0].min(), feat[:, 1].min(), feat[:, 2].min(), feat[:, 3].min(), feat[:, 4].min() + # )) + # print("MAX: l={:.3f}, w={:.3f}, h={:.3f}".format( + # current_agent_shape[:, 0].max(), current_agent_shape[:, 1].max(), current_agent_shape[:, 2].max() + # )) + # print("MIN: l={:.3f}, w={:.3f}, h={:.3f}".format( + # current_agent_shape[:, 0].min(), current_agent_shape[:, 1].min(), current_agent_shape[:, 2].min() + # )) + data_dict["decoder/input_action_feature_for_trafficgen"] = feat + + 
data_dict["decoder/target_offset_for_trafficgen"] = np.stack( + [gt_position_x, gt_position_y, gt_heading, gt_vel_x, gt_vel_y, gt_shape_l, gt_shape_w, gt_shape_h, agent_type], + axis=1 + ).astype(int) + + data_dict["decoder/input_offset_for_trafficgen"] = np.concatenate( + [ + np.full((1, 9), -1, dtype=int), + data_dict["decoder/target_offset_for_trafficgen"] + ], axis=0 + ) + + # Pad one more step for "end action" + data_dict["decoder/target_offset_for_trafficgen"] = np.concatenate( + [data_dict["decoder/target_offset_for_trafficgen"], + np.zeros((1, 9), dtype=int)], axis=0 + ) + + return data_dict + + +NUM_TG_MULTI = 4 + +TG_SKIP_STEP = 2 + + +def slice_trafficgen_data(tensor, dim): + # We planned to slice TG data every 1s = 2 steps. + num_skip = TG_SKIP_STEP + if dim == 0: + return tensor[::num_skip] + elif dim == 1: + return tensor[:, ::num_skip] + elif dim == 2: + return tensor[:, :, ::num_skip] + else: + raise ValueError(f"Unknown dimension: {dim}") + + +def prepare_trafficgen_data_for_scenestreamer( + data_dict, + config, + scenario, + force_t=True, # Just disable this function... 
+ only_lane=False, + dest_dropout=0.0 +): + + data_dict = prepare_destination(data_dict, config, FUTURE_STEPS=30, skip_step=5, dropout=dest_dropout) + + sdc_index = data_dict["decoder/sdc_index"] + T = data_dict["encoder/agent_feature"].shape[0] + + # start_action_id = config.PREPROCESSING.MAX_MAP_FEATURES + # start_sequence_id = config.PREPROCESSING.MAX_MAP_FEATURES + 1 + # end_sequence_id = config.PREPROCESSING.MAX_MAP_FEATURES + 2 + # dest_pad_id = config.PREPROCESSING.MAX_MAP_FEATURES + 3 + + trafficgen_sequence_sos_id = config.PREPROCESSING.MAX_MAP_FEATURES + trafficgen_sequence_eos_id = config.PREPROCESSING.MAX_MAP_FEATURES + 1 + trafficgen_sequence_pad_id = config.PREPROCESSING.MAX_MAP_FEATURES + 2 + veh_id = config.PREPROCESSING.MAX_MAP_FEATURES + 3 + ped_id = config.PREPROCESSING.MAX_MAP_FEATURES + 4 + cyc_id = config.PREPROCESSING.MAX_MAP_FEATURES + 5 + trafficgen_agent_sos_id = config.PREPROCESSING.MAX_MAP_FEATURES + 6 + + data_dict["decoder/dest_map_index_gt"][data_dict["decoder/dest_map_index_gt"] == -1] = trafficgen_sequence_pad_id + + agent_type = data_dict["decoder/agent_type"] # in 123 (that's why we using 5 possible actions) + assert agent_type.max() < 4 + + current_agent_shape = data_dict["decoder/current_agent_shape"] + + # Note that here we reuse the "current_agent". 
+ map_pos = data_dict["encoder/map_position"][..., :2] + map_heading = data_dict["encoder/map_heading"] + + # Get map feature valid mask + valid_map_feat = data_dict["encoder/map_valid_mask"] + + tg_action_list = [] + tg_valid_list = [] + tg_feat_list = [] + tg_target_offset_list = [] + tg_pos_list = [] + tg_head_list = [] + + # TODO: Hardcoded + for sparse_t, current_t in enumerate(range(0, T, 5)): + + pos = data_dict["decoder/modeled_agent_position"][sparse_t, ..., :2] + heading = data_dict["decoder/modeled_agent_heading"][sparse_t] + valid = data_dict["decoder/input_action_valid_mask"][sparse_t] + vel = data_dict["decoder/modeled_agent_velocity"][sparse_t] + dest = data_dict["decoder/dest_map_index"][sparse_t] + + tg_map_id, tg_valid, tg_feat, tg_target_offset, tg_pos, tg_head = prepare_trafficgen_data_for_scenestreamer_a_step( + pos=pos, + heading=heading, + agent_valid_mask=valid, + vel=vel, + dest=dest, + map_pos=map_pos, + map_heading=map_heading, + agent_type=agent_type, + map_valid_mask=valid_map_feat, + current_agent_shape=current_agent_shape, + # start_action_id=start_action_id, + # end_action_id=end_action_id, + start_sequence_id=trafficgen_sequence_sos_id, + end_sequence_id=trafficgen_sequence_eos_id, + dest_pad_id=trafficgen_sequence_pad_id, + veh_id=veh_id, + cyc_id=cyc_id, + ped_id=ped_id, + start_agent_id=trafficgen_agent_sos_id, + ) + + tg_action_list.append(tg_map_id) + tg_valid_list.append(tg_valid) + tg_feat_list.append(tg_feat) + tg_target_offset_list.append(tg_target_offset) + tg_pos_list.append(tg_pos) + tg_head_list.append(tg_head) + + tg_action_list = np.stack(tg_action_list, axis=0) + tg_valid_list = np.stack(tg_valid_list, axis=0) + tg_feat_list = np.stack(tg_feat_list, axis=0).astype(np.float32) + tg_target_offset_list = np.stack(tg_target_offset_list, axis=0) + tg_pos_list = np.stack(tg_pos_list, axis=0).astype(np.float32) + tg_head_list = np.stack(tg_head_list, axis=0).astype(np.float32) + + data_dict["decoder/trafficgen_position"] = 
tg_pos_list + data_dict["decoder/trafficgen_heading"] = tg_head_list + + data_dict["decoder/input_action_for_trafficgen"] = tg_action_list + data_dict["decoder/input_action_valid_mask_for_trafficgen"] = tg_valid_list + assert tg_action_list[tg_valid_list].min() >= 0 + + data_dict["decoder/input_action_feature_for_trafficgen"] = tg_feat_list.astype(np.float32) + + data_dict["decoder/target_offset_for_trafficgen"] = ( + tg_target_offset_list * data_dict["decoder/agent_valid_mask"][::5][..., None] + ) + tg_input_offset_list = np.concatenate([ + np.full((tg_target_offset_list.shape[0], tg_target_offset_list.shape[1], 1), -1, dtype=int), + tg_target_offset_list], axis=-1 + ) + data_dict["decoder/input_offset_for_trafficgen"] = ( + tg_input_offset_list * data_dict["decoder/agent_valid_mask"][::5][..., None] + ) + + G = tg_action_list.shape[1] + N = agent_type.shape[0] + sparse_T = tg_action_list.shape[0] + + # agent_type = agent_type[None].repeat(tg_action_list.shape[0], axis=0) + new_agent_type = np.full((sparse_T, N, NUM_TG_MULTI), -1) + tmp_agent_type = np.full((N,), -1) + tmp_agent_type[agent_type == 1] = veh_id + tmp_agent_type[agent_type == 2] = ped_id + tmp_agent_type[agent_type == 3] = cyc_id + new_agent_type[:, :, 2:] = tmp_agent_type.reshape(1, N, 1) + # new_new_agent_type = np.full((sparse_T, G), -1) + # new_new_agent_type[:, 1:-1] = new_agent_type.reshape(sparse_T, -1) + new_agent_type = np.concatenate([ + np.full((sparse_T, 1), -1), + new_agent_type.reshape(sparse_T, -1), + np.full((sparse_T, 1), -1), + ], axis=1) + assert new_agent_type.shape == (sparse_T, G) + data_dict["decoder/agent_type_for_trafficgen"] = new_agent_type.astype(int) + + # data_dict["decoder/current_agent_shape_for_trafficgen"] = np.concatenate([[[0, 0, 0]], current_agent_shape, [[0, 0, 0]]], axis=0).astype(np.float32) + agent_id = np.concatenate([ + np.full((sparse_T, 1), -1), + data_dict["encoder/modeled_agent_id"].repeat(NUM_TG_MULTI)[None].repeat(sparse_T, axis=0), + 
def prepare_trafficgen_data_for_scenestreamer_a_step(
    *, pos, heading, agent_valid_mask, vel, map_pos, map_heading, map_valid_mask, agent_type, current_agent_shape,
    start_sequence_id, end_sequence_id, dest, dest_pad_id,
    veh_id, cyc_id, ped_id, start_agent_id,
):
    """Build the TrafficGen token sequence for one time step.

    Every agent is anchored to its nearest valid map feature whose heading is
    within 90 degrees of the agent heading. The agent state is then expressed
    in that feature's frame and discretized into integer targets.

    Each agent contributes four consecutive tokens:
    ``[start_agent_id, agent-type id, anchor map index, anchor map index]``.
    The whole sequence is wrapped in ``start_sequence_id`` / ``end_sequence_id``.
    Continuous inputs (positions, headings, features) always carry the
    ground-truth values, never the values reconstructed from the buckets.

    Args:
        pos: (N, 2) agent positions.
        heading: (N,) agent headings.
        agent_valid_mask: (N,) boolean validity per agent.
        vel: (N, 2) agent velocities.
        map_pos: (M, 2) map feature positions.
        map_heading: (M,) map feature headings.
        map_valid_mask: (M,) boolean validity per map feature.
        agent_type: (N,) agent types, either raw object-type ints or already
            remapped token ids.
        current_agent_shape: (N, 3) length / width / height.
        start_sequence_id, end_sequence_id: sequence delimiter token ids.
        dest, dest_pad_id: accepted for interface compatibility; not used here.
        veh_id, cyc_id, ped_id: token ids for the three agent classes.
        start_agent_id: token id opening each agent's token group.

    Returns:
        Tuple ``(map_id, tg_valid, tg_feat, tg_target_offset, tg_pos, tg_head)``
        where the first three and the last two have leading size
        ``get_num_tg(N)`` and ``tg_target_offset`` is ``(N, 8)`` ordered as
        length, width, height, x, y, heading, vx, vy buckets.
    """
    num_agents = len(pos)

    # Imported here rather than at module level, as in the original code.
    from scenestreamer.models.scenestreamer_model import get_num_tg
    num_tokens = get_num_tg(num_agents)

    bucketize = TrafficGenTokenizerAutoregressive.bucketize

    # ----- Pick the anchor map feature for every agent -----
    heading_gap = utils.wrap_to_pi(heading[:, None] - map_heading[None])
    heading_ok = np.abs(heading_gap) < np.deg2rad(90)
    candidate = map_valid_mask & heading_ok

    gap = np.linalg.norm((pos[:, None] - map_pos[None])[..., :2], axis=-1)
    gap[~candidate] = np.inf  # non-candidates can only win if nothing else is available
    anchor_idx = np.argmin(gap, axis=1)

    anchor_pos = map_pos[anchor_idx]
    anchor_heading = map_heading[anchor_idx]

    # ----- Agent state expressed in the anchor's frame -----
    offset = pos - anchor_pos
    local_pos = utils.rotate(x=offset[:, 0], y=offset[:, 1], angle=-anchor_heading)
    local_heading = utils.wrap_to_pi(heading - anchor_heading)
    local_vel = utils.rotate(x=vel[:, 0], y=vel[:, 1], angle=-anchor_heading)

    # ----- Discretized targets -----
    bucket_x = bucketize(local_pos[:, 0], "position_x")
    bucket_y = bucketize(local_pos[:, 1], "position_y")
    bucket_heading = bucketize(local_heading, "heading")
    bucket_vx = bucketize(local_vel[:, 0], "velocity_x")
    bucket_vy = bucketize(local_vel[:, 1], "velocity_y")
    bucket_length = bucketize(current_agent_shape[:, 0], "length")
    bucket_width = bucketize(current_agent_shape[:, 1], "width")
    bucket_height = bucketize(current_agent_shape[:, 2], "height")

    tg_target_offset = np.stack(
        [bucket_length, bucket_width, bucket_height, bucket_x, bucket_y, bucket_heading, bucket_vx, bucket_vy],
        axis=1,
    ).astype(int)

    # ----- Agent-type token ids -----
    # Both raw object types and already-remapped ids are accepted; the
    # assignment order is kept so later rules override earlier ones.
    type_token = np.full((num_agents, ), -1)
    for source_value, token in (
        (constants.object_type_to_int["VEHICLE"], veh_id),
        (veh_id, veh_id),
        (constants.object_type_to_int["PEDESTRIAN"], ped_id),
        (ped_id, ped_id),
        (constants.object_type_to_int["CYCLIST"], cyc_id),
        (cyc_id, cyc_id),
    ):
        type_token[agent_type == source_value] = token

    # ----- Discrete token ids and their validity -----
    per_agent_tokens = np.stack(
        [np.full(num_agents, start_agent_id), type_token, anchor_idx, anchor_idx], axis=1
    ).reshape(-1)
    map_id = np.concatenate([[start_sequence_id], per_agent_tokens, [end_sequence_id]]).astype(int)

    tg_valid = np.concatenate(
        [[1], np.repeat(agent_valid_mask, NUM_TG_MULTI), [1]], axis=0
    ).astype(bool)

    # ----- Continuous features: only the last slot of each agent is filled -----
    agent_feat = np.zeros((num_agents, NUM_TG_MULTI, 8), dtype=np.float32)
    agent_feat[:, 3, :2] = local_pos
    agent_feat[:, 3, 2] = local_heading
    agent_feat[:, 3, 3:5] = local_vel
    agent_feat[:, 3, 5:] = current_agent_shape
    tg_feat = np.concatenate(
        [np.full((1, 8), 0), agent_feat.reshape(-1, 8), np.full((1, 8), 0)], axis=0
    )
    assert tg_feat.shape == (num_tokens, 8)

    assert map_id.shape[0] == tg_valid.shape[0] == tg_feat.shape[0] == num_tokens
    assert tg_target_offset.shape[0] == num_agents

    # ----- Per-token position / heading: zeros, zeros, anchor, agent -----
    pos_slots = np.concatenate(
        [np.zeros((num_agents, 2, 2)), anchor_pos[:, None], pos[:, None]], axis=1
    ).reshape(-1, 2)
    tg_pos = np.concatenate([np.zeros((1, 2)), pos_slots, np.zeros((1, 2))], axis=0)
    assert tg_pos.shape == (num_tokens, 2)

    head_slots = np.concatenate(
        [np.zeros((num_agents, 2)), anchor_heading[:, None], heading[:, None]], axis=1
    ).reshape(-1)
    tg_head = np.concatenate([np.zeros((1)), head_slots, np.zeros((1))], axis=0)
    assert tg_head.shape == (num_tokens,)

    return map_id, tg_valid, tg_feat, tg_target_offset, tg_pos, tg_head
def translate_abs_info_to_ego_centric(data_dict, current_t, retain_raw=False):
    """Rewrite agent and map features relative to their own reference poses.

    Agent features are expressed in the frame of each agent's last valid pose
    at or before ``current_t`` (time step 0 is used when an agent has no valid
    step in that window). Map features are expressed in the frame of their own
    map position / heading. The first three traffic light feature channels are
    zeroed so that no absolute position is left.

    Args:
        data_dict: Sample dictionary; the ``encoder/*`` feature arrays are
            modified in place.
        current_t: Last history step considered for the agent reference pose.
        retain_raw: When true, keep untouched copies of the map feature under
            ``vis/map_feature`` and ``raw/map_feature``.

    Returns:
        The same ``data_dict`` object.
    """
    if retain_raw:
        data_dict["vis/map_feature"] = data_dict["encoder/map_feature"].copy()
        data_dict["raw/map_feature"] = data_dict["encoder/map_feature"].copy()

    # ----- Reference pose of every agent: its last valid step in the history -----
    history = slice(None, current_t + 1)
    hist_pos = data_dict["encoder/agent_position"][history]
    hist_head = data_dict["encoder/agent_heading"][history]
    hist_valid = data_dict["encoder/agent_valid_mask"][history]

    num_steps, num_agents = hist_valid.shape
    last_valid_step = np.where(hist_valid, np.arange(num_steps)[:, None], 0).max(axis=0)
    agent_cols = np.arange(num_agents)
    ref_pos = hist_pos[last_valid_step, agent_cols]
    ref_head = hist_head[last_valid_step, agent_cols]

    # ----- Agent features -----
    agent_feat = data_dict["encoder/agent_feature"]

    shifted = agent_feat[..., :3] - ref_pos[None]
    agent_feat[..., :3] = utils.rotate(
        x=shifted[..., 0],
        y=shifted[..., 1],
        angle=-np.repeat(ref_head.reshape(1, -1), shifted.shape[0], axis=0),
        z=shifted[..., 2],
    )

    rel_head = utils.wrap_to_pi(agent_feat[..., 3] - ref_head[None])
    agent_feat[..., 3] = rel_head
    agent_feat[..., 4] = np.sin(rel_head)
    agent_feat[..., 5] = np.cos(rel_head)

    velocity = agent_feat[..., 6:8]
    agent_feat[..., 6:8] = utils.rotate(
        x=velocity[..., 0],
        y=velocity[..., 1],
        angle=-np.repeat(ref_head.reshape(1, -1), velocity.shape[0], axis=0),
    )

    # Invalid entries carry no information.
    agent_feat[~data_dict["encoder/agent_valid_mask"]] = 0

    # ----- Map features -----
    map_feat = data_dict["encoder/map_feature"]
    map_origin = data_dict["encoder/map_position"][:, None]
    map_yaw = data_dict["encoder/map_heading"][:, None]

    # Channels 0:3 and 3:6 are points (translate, then rotate);
    # channels 6:9 are a direction (rotate only).
    for channels, is_point in ((slice(0, 3), True), (slice(3, 6), True), (slice(6, 9), False)):
        vec = map_feat[..., channels]
        if is_point:
            vec = vec - map_origin
        map_feat[..., channels] = utils.rotate(
            x=vec[..., 0],
            y=vec[..., 1],
            angle=-np.repeat(map_yaw.reshape(-1, 1), vec.shape[1], axis=1),
            z=vec[..., 2],
        )

    rel_map_head = utils.wrap_to_pi(map_feat[..., 9] - map_yaw)
    map_feat[..., 9] = rel_map_head
    map_feat[..., 10] = np.sin(rel_map_head)
    map_feat[..., 11] = np.cos(rel_map_head)

    # ----- Traffic light features -----
    # Drop the absolute position channels entirely.
    if "encoder/traffic_light_feature" in data_dict:
        data_dict["encoder/traffic_light_feature"][..., :3] = 0

    return data_dict
def limit_map_range(data_dict, limit_range=50):
    """Invalidate map features and current-step agents far from the SDC.

    An axis-aligned square with half-width ``limit_range`` is centred on the
    SDC position at the current time index. Map features outside it are marked
    invalid at every vector, and agents outside it are marked invalid at the
    current step only, for both the encoder and the decoder masks.

    Args:
        data_dict: Sample dictionary; its validity masks are edited in place.
        limit_range: Half-width of the square kept around the SDC.

    Returns:
        The same ``data_dict`` object, with ``encoder/current_agent_valid_mask``
        and ``decoder/current_agent_valid_mask`` refreshed.
    """
    sdc_index = data_dict["decoder/sdc_index"]
    t_now = data_dict["metadata/current_time_index"]
    centre = data_dict["decoder/agent_position"][t_now, sdc_index]  # (3,)

    def _within(points, half_width):
        """Return True where both x and y are strictly inside the square."""
        near_x = abs(points[..., 0] - centre[0]) < half_width
        near_y = abs(points[..., 1] - centre[1]) < half_width
        return near_x & near_y

    # ----- Map features -----
    margin = 0
    keep_map = _within(data_dict["encoder/map_position"], limit_range + margin)
    keep_map = keep_map & data_dict["encoder/map_valid_mask"]
    drop_map = ~keep_map
    data_dict["encoder/map_feature_valid_mask"][drop_map] = False
    data_dict["encoder/map_valid_mask"][drop_map] = False

    # ----- Agents: only the current step is filtered -----
    for position_key, valid_key, current_key in (
        ("encoder/agent_position", "encoder/agent_valid_mask", "encoder/current_agent_valid_mask"),
        ("decoder/agent_position", "decoder/agent_valid_mask", "decoder/current_agent_valid_mask"),
    ):
        in_range = _within(data_dict[position_key][t_now], limit_range)
        data_dict[valid_key][t_now] = data_dict[valid_key][t_now] & in_range
        data_dict[current_key] = data_dict[valid_key][t_now].copy()

    # TODO: eval/agent_valid_mask is not touched yet.
    return data_dict
def preprocess_scenario_description_for_motionlm(
    scenario, config, in_evaluation, keep_all_data=False, backward_prediction=None, tokenizer=None
):
    """Turn a raw scenario description into the model's flat data dictionary.

    The pipeline is: map / traffic light extraction, track extraction,
    selection of modeled agents, optional map range limiting, optional
    ego-centric re-expression, tokenization, optional action / safety labels,
    optional TrafficGen inputs, optional time truncation, optional destination
    preparation, and finally key filtering.

    NOTE(review): the block nesting below was reconstructed from a
    whitespace-collapsed copy of this function; please confirm the spots
    marked with NOTE(review) against the original indentation.

    Args:
        scenario: Scenario description, indexed with the ``SD`` keys.
        config: Experiment configuration node.
        in_evaluation: Whether the sample is prepared for evaluation. This
            raises the agent cap to 128 and disables random backward
            prediction.
        keep_all_data: When true, skip the final key filtering.
        backward_prediction: When not None, overrides the backward prediction
            decision.
        tokenizer: Object providing ``tokenize_numpy_array``; required whenever
            tokenization is enabled by the config.

    Returns:
        A dictionary with keys in sorted order.
    """
    metadata = scenario[SD.METADATA]

    if in_evaluation:
        max_agents = 128  # TODO: hardcoded
    else:
        max_agents = config.PREPROCESSING.MAX_AGENTS

    tracks_to_predict_dict = metadata.get('tracks_to_predict', {})
    track_index_to_predict = np.array([int(v['track_index']) for v in tracks_to_predict_dict.values()])
    track_name_to_predict = [int(k) for k in tracks_to_predict_dict.keys()]

    # Put SDC name to the first place.
    sdc_name = metadata["sdc_id"]
    try:
        sdc_name = int(sdc_name)
    except (TypeError, ValueError):
        # Non-numeric ids are kept as they are. Only conversion failures are
        # swallowed, so unrelated errors (e.g. KeyboardInterrupt) still propagate.
        pass
    # NOTE(review): the insert is assumed to sit inside this `if` (the SDC is
    # only moved to the front when already listed) — confirm.
    if sdc_name in track_name_to_predict:
        track_name_to_predict.remove(sdc_name)
        track_name_to_predict.insert(0, sdc_name)
    track_name_to_predict = np.array(track_name_to_predict)

    data_dict = {
        "in_evaluation": in_evaluation,
        "metadata/sdc_name": sdc_name,
        "encoder/object_of_interest_name": track_name_to_predict,
        "encoder/object_of_interest_id": track_index_to_predict,
        "scenario_id": scenario[SD.ID],
    }
    if "current_time_index" in metadata:
        data_dict["metadata/current_time_index"] = metadata['current_time_index']
    else:
        # TODO: Not sure in nuscenes if there is no current_time_index. Might need to check.
        data_dict["metadata/current_time_index"] = 0
        metadata['current_time_index'] = 0

    is_scenestreamer = config.MODEL.NAME == "scenestreamer"

    # ===== Extract map and traffic light features =====
    data_dict = process_map_and_traffic_light(
        data_dict=data_dict,
        scenario=scenario,
        map_feature=scenario[SD.MAP_FEATURES],
        dynamic_map_states=scenario[SD.DYNAMIC_MAP_STATES],
        track_length=scenario[SD.LENGTH],
        max_vectors=config.PREPROCESSING.MAX_VECTORS,
        max_map_features=config.PREPROCESSING.MAX_MAP_FEATURES,
        limit_map_range=config.LIMIT_MAP_RANGE,
        max_length_per_map_feature=config.PREPROCESSING.MAX_LENGTH_PER_MAP_FEATURE,
        max_traffic_lights=config.PREPROCESSING.MAX_TRAFFIC_LIGHTS,
        remove_traffic_light_state=config.PREPROCESSING.REMOVE_TRAFFIC_LIGHT_STATE,
        is_scenestreamer=is_scenestreamer,
    )

    # ===== Extract agent features =====
    data_dict = process_track(
        data_dict=data_dict,
        tracks=scenario[SD.TRACKS],
        track_length=scenario[SD.LENGTH],
        sdc_name=metadata["sdc_id"],
        max_agents=max_agents,
        exempt_max_agent_filtering=in_evaluation,
        is_scenestreamer=is_scenestreamer,
    )
    data_dict = prepare_modeled_agent_and_eval_data(
        data_dict=data_dict,
        predict_all_agents=config.TRAINING.PREDICT_ALL_AGENTS,
        eval_all_agents=config.EVALUATION.PREDICT_ALL_AGENTS,
        current_t=metadata['current_time_index'],
        add_sdc_to_object_of_interest=config.PREPROCESSING.ADD_SDC_TO_OBJECT_OF_INTEREST
    )

    if config.LIMIT_MAP_RANGE:
        data_dict = limit_map_range(data_dict)

    if config.MODEL.RELATIVE_PE:
        data_dict = translate_abs_info_to_ego_centric(
            data_dict, current_t=data_dict["metadata/current_time_index"], retain_raw=keep_all_data
        )

    # The SDC is always part of the labeled agents.
    sdc_ind = data_dict["decoder/sdc_index"]
    object_of_interest = data_dict["decoder/object_of_interest_id"]
    if sdc_ind not in object_of_interest:
        object_of_interest = np.concatenate([[sdc_ind], object_of_interest])
    data_dict["decoder/labeled_agent_id"] = np.asarray(object_of_interest).astype(int)

    # ===== Call the tokenizer and generate target discretized actions =====
    # Error stats is removed from here. It's used in independent test script.
    use_backward_prediction = config.BACKWARD_PREDICTION
    if in_evaluation:
        use_backward_prediction = False
    if use_backward_prediction:
        # Use 50% probability to set backward_prediction to True
        use_backward_prediction = np.random.rand() < 0.5
    if backward_prediction is not None:  # Overwrite the value
        use_backward_prediction = backward_prediction

    # Keys produced by every tokenizer; the first entry of each list below is
    # what differs between the diffusion and the discrete-token setup.
    shared_token_keys = [
        "decoder/target_action_valid_mask", "decoder/input_action", "decoder/input_action_valid_mask",
        "decoder/modeled_agent_position", "decoder/modeled_agent_heading", "decoder/modeled_agent_velocity",
        "decoder/modeled_agent_delta", "in_backward_prediction",
    ]
    if config.USE_DIFFUSION:
        token_keys = ["decoder/target_agent_motion", "decoder/input_agent_motion"] + shared_token_keys
    elif config.TOKENIZATION.TOKENIZATION_METHOD is not None:
        token_keys = ["decoder/target_action"] + shared_token_keys
    else:
        token_keys = None

    if token_keys is not None:
        detok, _ = tokenizer.tokenize_numpy_array(data_dict, backward_prediction=use_backward_prediction)
        for k in token_keys:
            if k in detok:
                data_dict[k] = detok[k]

    # Both label branches tolerate a config without an ACTION_LABEL node.
    # NOTE(review): masking is only applied when `in_evaluation` is true, as in
    # the original code — confirm this is intended.
    if config.get("ACTION_LABEL") and config.ACTION_LABEL.USE_ACTION_LABEL:
        data_dict = prepare_action_label(
            data_dict=data_dict,
            dt=0.1,  # TODO(PZH): Hardcoded here.
            config=config,
            mask_probability=config.ACTION_LABEL.MASK_PROBABILITY_ACTION_LABEL if in_evaluation else 0.0
        )

    if config.get("ACTION_LABEL") and config.ACTION_LABEL.USE_SAFETY_LABEL:
        data_dict = prepare_safety_label(
            data_dict=data_dict,
            dt=0.1,  # TODO(PZH): Hardcoded here.
            config=config,
            mask_probability=config.ACTION_LABEL.MASK_PROBABILITY_SAFETY_LABEL if in_evaluation else 0.0
        )

    if config["USE_TRAFFICGEN"]:
        if is_scenestreamer:
            assert not config.USE_DESTINATION
            data_dict = prepare_trafficgen_data_for_scenestreamer(
                data_dict=data_dict, config=config, scenario=scenario, force_t=True,
                dest_dropout=config.PREPROCESSING.DEST_DROPOUT,
            )
        else:
            data_dict = prepare_trafficgen_data(
                data_dict=data_dict, config=config, scenario=scenario, only_lane=config.ONLY_LANE_FOR_TRAFFICGEN,
            )

    truncate_time = config.PREPROCESSING.TRUNCATE_TIME
    if truncate_time >= 0:
        for k in [
            "encoder/traffic_light_state",
            "encoder/traffic_light_valid_mask",
            "decoder/target_action",
            "decoder/input_action",
            "decoder/target_action_valid_mask",
            "decoder/input_action_valid_mask",
            "decoder/modeled_agent_position",
            "decoder/modeled_agent_heading",
            "decoder/modeled_agent_velocity",
            "decoder/modeled_agent_delta",
            "decoder/trafficgen_position",
            "decoder/trafficgen_heading",
            "decoder/input_action_for_trafficgen",
            "decoder/input_action_valid_mask_for_trafficgen",
            "decoder/input_action_feature_for_trafficgen",
            "decoder/target_offset_for_trafficgen",
            "decoder/input_offset_for_trafficgen",
            "decoder/agent_type_for_trafficgen",
            "decoder/agent_id_for_trafficgen",
        ]:
            # Some of these keys only exist for certain configurations
            # (e.g. the trafficgen ones); truncate what is present instead of
            # failing with a KeyError.
            if k in data_dict:
                data_dict[k] = data_dict[k][:truncate_time]

    if config.USE_DESTINATION:
        data_dict = prepare_destination(data_dict, config)

    # TODO: A little hack here...
    if config.EVALUATION.NAME == "lmdb":
        keep_all_data = False
        in_evaluation = False

    # In evaluation everything is kept; in training the dictionary is trimmed.
    if not keep_all_data and not in_evaluation:
        # Discard these data
        discard_prefixes = ("eval/", "encoder/current_", "encoder/future_")
        data_dict = {k: v for k, v in data_dict.items() if not k.startswith(discard_prefixes)}
        if config.GPT_STYLE and config.REMOVE_AGENT_FROM_SCENE_ENCODER:
            data_dict = {k: v for k, v in data_dict.items() if not k.startswith("encoder/agent_")}

        # Keep these data
        keep_prefixes = (
            "scenario_id", "decoder/label_", "decoder/agent_id", "decoder/agent_type",
            "decoder/current_", "decoder/modeled_agent_", "decoder/input_", "decoder/target_",
            "encoder/", "in_evaluation", "in_backward_prediction", "decoder/dest_map_index",
            "decoder/trafficgen_", "decoder/sdc_", "metadata/",
        )
        data_dict = {k: v for k, v in data_dict.items() if k.startswith(keep_prefixes)}

    return {k: data_dict[k] for k in sorted(data_dict)}
def _get_trafficgen_data(raw_scenario_description, data_dict, current_t):
    """
    Return the indices of the agents that TrafficGen-style filtering keeps.

    The body is an adapted copy of the TrafficGen / LCTGen initial-state
    preprocessing (author's note: copied and modified to fit this code base
    rather than re-read upstream). It converts the scenario to the TrafficGen
    layout, expresses agents and lanes in the frame of agent 0 at the selected
    frame, matches every agent to its closest center-lane vector, and keeps
    agents that are valid, in range, close to a lane and aligned with it.

    Args:
        raw_scenario_description: Raw scenario handed to
            ``metadrive_scenario_to_init_data``.
        data_dict: Only ``data_dict["decoder/track_name"]`` is read.
        current_t: Index of the frame to extract.

    Returns:
        List of positions in ``data_dict["decoder/track_name"]`` for the kept
        agents; names missing from that list are silently skipped.

    NOTE(review): nesting was reconstructed from a whitespace-collapsed copy of
    this function — confirm the spot marked below against the original.
    """
    def rotate(x, y, angle):
        # Plain 2D rotation of (x, y) by `angle`; result is stacked on the last axis.

        other_x_trans = np.cos(angle) * x - np.sin(angle) * y
        other_y_trans = np.cos(angle) * y + np.sin(angle) * x
        output_coords = np.stack((other_x_trans, other_y_trans), axis=-1)
        return output_coords

    def cal_rel_dir(dir1, dir2):
        # Angle difference dir1 - dir2 wrapped to (-pi, pi].
        dist = dir1 - dir2

        while not np.all(dist >= 0):
            dist[dist < 0] += np.pi * 2
        while not np.all(dist < np.pi * 2):
            dist[dist >= np.pi * 2] -= np.pi * 2

        dist[dist > np.pi] -= np.pi * 2
        return dist

    # Not called anywhere in this function; kept from the upstream code.
    def normalize_angle(angle):
        """
        From: https://github.com/metadriverse/trafficgen/blob/28b109e8e640d820192d5485bf9a28128b38ca21/trafficgen/utils/utils.py#L20

        Wrap ``angle`` into [0, 2*pi) in place and return it.
        """
        # if isinstance(angle, torch.Tensor):
        #     while not torch.all(angle >= 0):
        #         angle[angle < 0] += np.pi * 2
        #     while not torch.all(angle < np.pi * 2):
        #         angle[angle >= np.pi * 2] -= np.pi * 2
        #     return angle
        #
        # else:
        while not np.all(angle >= 0):
            angle[angle < 0] += np.pi * 2
        while not np.all(angle < np.pi * 2):
            angle[angle >= np.pi * 2] -= np.pi * 2

        return angle

    from scenestreamer.eval.scenarionet_to_trafficgen import metadrive_scenario_to_init_data

    data = metadrive_scenario_to_init_data(raw_scenario_description)
    PZH_TRACK_NAMES = data["PZH_TRACK_NAMES"]
    case_info = {}
    # `other` is filled below but never read again in this function.
    other = {}

    # agent = copy.deepcopy(data['all_agent'])
    other['traf'] = copy.deepcopy(data['traffic_light'])

    max_time_step = 190
    gap = 190
    index = -1
    RANGE = 50

    if index == -1:
        # The step equals the stop bound, so this slice yields at most one
        # frame: the one at `current_t`.
        data['all_agent'] = data['all_agent'][current_t:max_time_step:gap]
        data['traffic_light'] = data['traffic_light'][current_t:max_time_step:gap]
    else:
        raise ValueError
        # index = min(index, len(data['all_agent']) - 1)
        # data['all_agent'] = data['all_agent'][index:index + self.data_cfg.MAX_TIME_STEP:gap]
        # data['traffic_light'] = data['traffic_light'][index:index + self.data_cfg.MAX_TIME_STEP:gap]

    def _transform_coordinate_map(data):
        """
        Express lanes in the frame of agent 0, separately for every frame.
        """
        timestep = data['all_agent'].shape[0]

        # Agent 0 is used as the ego; columns 0-1 are read as x/y and column 4 as heading.
        ego = data['all_agent'][:, 0]
        pos = ego[:, [0, 1]][:, np.newaxis]

        lane = data['lane'][np.newaxis]
        lane = np.repeat(lane, timestep, axis=0)
        lane[..., :2] -= pos

        x = lane[..., 0]
        y = lane[..., 1]
        ego_heading = ego[:, [4]]
        lane[..., :2] = rotate(x, y, -ego_heading)

        unsampled_lane = data['unsampled_lane'][np.newaxis]
        unsampled_lane = np.repeat(unsampled_lane, timestep, axis=0)
        unsampled_lane[..., :2] -= pos

        x = unsampled_lane[..., 0]
        y = unsampled_lane[..., 1]
        ego_heading = ego[:, [4]]
        unsampled_lane[..., :2] = rotate(x, y, -ego_heading)
        # Only the first frame of the unsampled lanes is returned.
        return lane, unsampled_lane[0]

    data['lane'], other['unsampled_lane'] = _transform_coordinate_map(data)
    other['lane'] = data['lane']

    def _process_agent(agent, sort_agent):
        # `sort_agent` is accepted but not used.
        # `agent` is modified in place.

        ego = agent[:, 0]

        # transform every frame into ego coordinate in the first frame
        ego_pos = copy.deepcopy(ego[[0], :2])[:, np.newaxis]
        ego_heading = ego[[0], [4]]

        agent[..., :2] -= ego_pos
        agent[..., :2] = rotate(agent[..., 0], agent[..., 1], -ego_heading)
        agent[..., 2:4] = rotate(agent[..., 2], agent[..., 3], -ego_heading)
        agent[..., 4] -= ego_heading

        # The last two columns are multiplied together as validity flags.
        agent_mask = agent[..., -1]
        agent_type_mask = agent[..., -2]
        agent_range_mask = (abs(agent[..., 0]) < RANGE) * (abs(agent[..., 1]) < RANGE)

        mask = agent_mask * agent_type_mask
        # use agent range mask only for the first frame
        # allow agent to be out of range in the future frames
        mask[0, :] *= agent_range_mask[0, :]

        return agent, mask.astype(bool)

    def process_lane(lane, max_vec, lane_range, offset=-40):
        # Turn lane points into vectors between consecutive points of the same
        # lane id, keep the `max_vec` vectors closest to the origin per frame,
        # and zero-pad the rest.
        # dist = lane[..., 0]**2+lane[..., 1]**2
        # idx = np.argsort(dist)
        # lane = lane[idx]

        vec_dim = 6

        lane_point_mask = (abs(lane[..., 0] + offset) < lane_range) * (abs(lane[..., 1]) < lane_range)

        lane_id = np.unique(lane[..., -2]).astype(int)

        vec_list = []
        vec_mask_list = []
        vec_id_list = []
        b_s, _, lane_dim = lane.shape

        for id in lane_id:
            id_set = lane[..., -2] == id
            points = lane[id_set].reshape(b_s, -1, lane_dim)
            masks = lane_point_mask[id_set].reshape(b_s, -1)

            vec_ids = np.ones([b_s, points.shape[1] - 1, 1]) * id
            vector = np.zeros([b_s, points.shape[1] - 1, vec_dim])
            # Columns 0:2 hold the start point, 2:4 the end point.
            vector[..., 0:2] = points[:, :-1, :2]
            vector[..., 2:4] = points[:, 1:, :2]
            # id
            # vector[..., 4] = points[:,1:, 3]
            # type
            vector[..., 4] = points[:, 1:, 2]
            # traffic light
            vector[..., 5] = points[:, 1:, 4]
            # A vector is valid only when both of its end points are in range.
            vec_mask = masks[:, :-1] * masks[:, 1:]
            vector[vec_mask == 0] = 0
            vec_list.append(vector)
            vec_mask_list.append(vec_mask)
            vec_id_list.append(vec_ids)

        vector = np.concatenate(vec_list, axis=1) if vec_list else np.zeros([b_s, 0, vec_dim])
        vector_mask = np.concatenate(vec_mask_list, axis=1) if vec_mask_list else np.zeros([b_s, 0], dtype=bool)
        vec_id = np.concatenate(vec_id_list, axis=1) if vec_id_list else np.zeros([b_s, 0, 1])

        all_vec = np.zeros([b_s, max_vec, vec_dim])
        all_mask = np.zeros([b_s, max_vec])
        all_id = np.zeros([b_s, max_vec, 1])

        for t in range(b_s):
            mask_t = vector_mask[t]
            vector_t = vector[t][mask_t]
            vec_id_t = vec_id[t][mask_t]

            # Sort by squared distance of the vector start point to the origin.
            dist = vector_t[..., 0]**2 + vector_t[..., 1]**2
            idx = np.argsort(dist)
            vector_t = vector_t[idx]
            mask_t = np.ones(vector_t.shape[0])
            vec_id_t = vec_id_t[idx]

            vector_t = vector_t[:max_vec]
            mask_t = mask_t[:max_vec]
            vec_id_t = vec_id_t[:max_vec]

            vector_t = np.pad(vector_t, ([0, max_vec - vector_t.shape[0]], [0, 0]))
            mask_t = np.pad(mask_t, ([0, max_vec - mask_t.shape[0]]))
            vec_id_t = np.pad(vec_id_t, ([0, max_vec - vec_id_t.shape[0]], [0, 0]))

            all_vec[t] = vector_t
            all_mask[t] = mask_t
            all_id[t] = vec_id_t

        return all_vec, all_mask.astype(bool), all_id.astype(int)

    def process_map(lane, traf=None, center_num=384, edge_num=128, lane_range=60, offest=-40, rest_num=192):
        # Append a traffic-light state column to the lane points, then split
        # them by lane type into center / boundary / crossing / rest groups and
        # vectorize each group with `process_lane`.
        # (`offest` is the parameter's actual spelling; callers pass it by keyword.)
        lane_with_traf = np.zeros([*lane.shape[:-1], 5])
        lane_with_traf[..., :4] = lane

        lane_id = lane[..., -1]
        b_s = lane_id.shape[0]

        # print(traf)
        if traf is not None:
            for i in range(b_s):
                traf_t = traf[i]
                lane_id_t = lane_id[i]
                # print(traf_t)
                for a_traf in traf_t:
                    # print(a_traf)
                    control_lane_id = a_traf[0]
                    state = a_traf[-2]
                    lane_idx = np.where(lane_id_t == control_lane_id)
                    lane_with_traf[i, lane_idx, -1] = state
            # NOTE(review): assumed to sit inside the `if` (after the loop) —
            # confirm against the original indentation.
            lane = lane_with_traf

        # lane = np.delete(lane_with_traf,-2,axis=-1)
        # Lane types are taken from the first frame only.
        lane_type = lane[0, :, 2]
        center_1 = lane_type == 1
        center_2 = lane_type == 2
        center_3 = lane_type == 3
        center_ind = center_1 + center_2 + center_3

        boundary_1 = lane_type == 15
        boundary_2 = lane_type == 16
        bound_ind = boundary_1 + boundary_2

        cross_walk = lane_type == 18
        speed_bump = lane_type == 19
        cross_ind = cross_walk + speed_bump

        rest = ~(center_ind + bound_ind + cross_walk + speed_bump + cross_ind)

        cent, cent_mask, cent_id = process_lane(lane[:, center_ind], center_num, lane_range, offest)
        bound, bound_mask, _ = process_lane(lane[:, bound_ind], edge_num, lane_range, offest)
        cross, cross_mask, _ = process_lane(lane[:, cross_ind], 32, lane_range, offest)
        rest, rest_mask, _ = process_lane(lane[:, rest], rest_num, lane_range, offest)

        return cent, cent_mask, cent_id, bound, bound_mask, cross, cross_mask, rest, rest_mask

    case_info["agent"], case_info["agent_mask"] = _process_agent(data['all_agent'], False)
    case_info['center'], case_info['center_mask'], case_info['center_id'], case_info['bound'], case_info[
        'bound_mask'], \
        case_info['cross'], case_info['cross_mask'], case_info['rest'], case_info['rest_mask'] = process_map(
        data['lane'], data['traffic_light'], lane_range=RANGE, offest=0)

    # get vector-based representation
    def _get_vec_based_rep(case_info, PZH_TRACK_NAMES):
        # Match each agent to its closest valid center vector and describe the
        # agent relative to that vector; also builds the final agent mask.
        THRES = 5
        thres = THRES
        # max_agent_num = 32
        # _process future agent

        agent = case_info['agent']
        vectors = case_info["center"]

        agent_mask = case_info['agent_mask']

        # Midpoint of every center vector.
        vec_x = ((vectors[..., 0] + vectors[..., 2]) / 2)
        vec_y = ((vectors[..., 1] + vectors[..., 3]) / 2)

        agent_x = agent[..., 0]
        agent_y = agent[..., 1]

        b, vec_num = vec_y.shape
        _, agent_num = agent_x.shape

        vec_x = np.repeat(vec_x[:, np.newaxis], axis=1, repeats=agent_num)
        vec_y = np.repeat(vec_y[:, np.newaxis], axis=1, repeats=agent_num)

        agent_x = np.repeat(agent_x[:, :, np.newaxis], axis=-1, repeats=vec_num)
        agent_y = np.repeat(agent_y[:, :, np.newaxis], axis=-1, repeats=vec_num)

        dist = np.sqrt((vec_x - agent_x)**2 + (vec_y - agent_y)**2)

        cent_mask = np.repeat(case_info['center_mask'][:, np.newaxis], axis=1, repeats=agent_num)
        # Padded vectors get a huge distance so they are never the closest one.
        dist[cent_mask == 0] = 10e5
        vec_index = np.argmin(dist, -1)
        min_dist_to_lane = np.min(dist, -1)
        min_dist_mask = min_dist_to_lane < thres

        selected_vec = np.take_along_axis(vectors, vec_index[..., np.newaxis], axis=1)

        vx, vy = agent[..., 2], agent[..., 3]
        v_value = np.sqrt(vx**2 + vy**2)
        low_vel = v_value < 0.1

        dir_v = np.arctan2(vy, vx)
        x1, y1, x2, y2 = selected_vec[..., 0], selected_vec[..., 1], selected_vec[..., 2], selected_vec[..., 3]
        dir = np.arctan2(y2 - y1, x2 - x1)
        agent_dir = agent[..., 4]

        v_relative_dir = cal_rel_dir(dir_v, agent_dir)
        relative_dir = cal_rel_dir(agent_dir, dir)

        # The velocity direction is meaningless for (almost) stationary agents.
        v_relative_dir[low_vel] = 0

        v_dir_mask = abs(v_relative_dir) < np.pi / 6
        dir_mask = abs(relative_dir) < np.pi / 4

        agent_x = agent[..., 0]
        agent_y = agent[..., 1]
        vec_x = (x1 + x2) / 2
        vec_y = (y1 + y2) / 2

        cent_to_agent_x = agent_x - vec_x
        cent_to_agent_y = agent_y - vec_y

        coord = rotate(cent_to_agent_x, cent_to_agent_y, np.pi / 2 - dir)

        vec_len = np.clip(np.sqrt(np.square(y2 - y1) + np.square(x1 - x2)), a_min=4.5, a_max=5.5)

        lat_perc = np.clip(coord[..., 0], a_min=-vec_len / 2, a_max=vec_len / 2) / vec_len
        long_perc = np.clip(coord[..., 1], a_min=-vec_len / 2, a_max=vec_len / 2) / vec_len

        # ignore other masks for future agents (to support out-of-range agent prediction)
        # Note: this is an alias, so the writes below also modify `agent_mask`.
        total_mask = agent_mask
        # for the first frame, use all masks to filter out off-road agents
        total_mask[0, :] = (min_dist_mask * agent_mask * v_dir_mask * dir_mask)[0, :]

        # Agent 0 (used as the ego above) is always kept.
        total_mask[:, 0] = 1
        total_mask = total_mask.astype(bool)

        b_s, agent_num, agent_dim = agent.shape
        agent_ = np.zeros([b_s, agent_num, agent_dim])
        agent_mask_ = np.zeros([b_s, agent_num]).astype(bool)

        the_vec = np.take_along_axis(vectors, vec_index[..., np.newaxis], 1)
        # 0: vec_index
        # 1-2 long and lat percent
        # 3-5 velocity and direction
        # 6-9 lane vector
        # 10-11 lane type and traff state
        info = np.concatenate(
            [
                vec_index[..., np.newaxis], long_perc[..., np.newaxis], lat_perc[..., np.newaxis],
                v_value[..., np.newaxis], v_relative_dir[..., np.newaxis], relative_dir[..., np.newaxis], the_vec
            ], -1
        )

        info_ = np.zeros([b_s, agent_num, info.shape[-1]])

        # Agents selected in the first frame are compacted to the front of
        # every frame; the remaining slots stay zero / invalid.
        start_mask = total_mask[0]
        for i in range(agent.shape[0]):
            agent_i = agent[i][start_mask]
            info_i = info[i][start_mask]

            step_mask = total_mask[i]
            valid_mask = step_mask[start_mask]

            agent_i = agent_i[:agent_num]
            info_i = info_i[:agent_num]

            valid_num = agent_i.shape[0]
            agent_i = np.pad(agent_i, [[0, agent_num - agent_i.shape[0]], [0, 0]])
            info_i = np.pad(info_i, [[0, agent_num - info_i.shape[0]], [0, 0]])

            agent_[i] = agent_i
            info_[i] = info_i
            agent_mask_[i, :valid_num] = valid_mask[:valid_num]

        # Track names are compacted the same way and padded with None.
        PZH_TRACK_NAMES_new = np.array(list(PZH_TRACK_NAMES[start_mask]) + [None] * (agent_num - start_mask.sum()))

        case_info['vec_based_rep'] = info_[..., 1:]
        case_info['agent_vec_index'] = info_[..., 0].astype(int)
        case_info['agent_mask'] = agent_mask_
        case_info["agent"] = agent_

        return case_info, PZH_TRACK_NAMES_new

    case_info, PZH_TRACK_NAMES = _get_vec_based_rep(case_info, PZH_TRACK_NAMES)

    # Split the batched arrays into one dictionary per frame.
    case_num = case_info['agent'].shape[0]
    case_list = []
    for i in range(case_num):
        dic = {}
        for k, v in case_info.items():
            dic[k] = v[i]
        case_list.append(dic)

    # PZH: Obviously, you only pick T=0 from the data.
    ret = case_list[0]
    ret["PZH_TRACK_NAMES"] = PZH_TRACK_NAMES

    # Map the kept track names back to their positions in the decoder's track list.
    trafficgen_select_track_names = ret["PZH_TRACK_NAMES"][ret['agent_mask']]
    decoder_track_name = list(data_dict["decoder/track_name"])
    trafficgen_select_index = []
    for name in trafficgen_select_track_names:
        if name in decoder_track_name:
            trafficgen_select_index.append(decoder_track_name.index(name))
        else:
            # print(11)
            pass

    return trafficgen_select_index
def overwrite_gt_to_pred_field(data_dict, pred_horizon=96, gt_horizon=91):
    """Copy ground-truth agent states into the ``decoder/reconstructed_*`` fields.

    A deep copy of ``data_dict`` is returned with four new zero-initialised
    arrays of ``pred_horizon`` steps. Their leading steps are filled from the
    ground-truth position (x, y only), valid mask, heading and velocity.

    Args:
        data_dict: mapping holding ``decoder/agent_position`` (T, N, >=2),
            ``decoder/agent_valid_mask`` (T, N), ``decoder/agent_heading``
            (T, N) and ``decoder/agent_velocity`` (T, N, 2).
        pred_horizon: number of steps allocated for the reconstructed fields.
        gt_horizon: maximum number of ground-truth steps to copy.

    Returns:
        A new dict; the input is left untouched.
    """
    import copy

    new_data_dict = copy.deepcopy(data_dict)
    num_steps, num_agents = data_dict["decoder/agent_position"].shape[:2]

    # Never copy more steps than exist: with fewer than ``gt_horizon`` steps the
    # previous fixed ``[:91]`` slices either failed to broadcast or silently
    # repeated a single step across the whole horizon.
    copied = min(num_steps, gt_horizon, pred_horizon)

    position = np.zeros((pred_horizon, num_agents, 2), dtype=np.float32)
    valid_mask = np.zeros((pred_horizon, num_agents), dtype=bool)
    heading = np.zeros((pred_horizon, num_agents), dtype=np.float32)
    velocity = np.zeros((pred_horizon, num_agents, 2), dtype=np.float32)

    # Assignment casts to the destination dtype, so no explicit astype is needed.
    position[:copied] = new_data_dict["decoder/agent_position"][:copied, :, :2]
    valid_mask[:copied] = new_data_dict["decoder/agent_valid_mask"][:copied]
    heading[:copied] = new_data_dict["decoder/agent_heading"][:copied]
    velocity[:copied] = new_data_dict["decoder/agent_velocity"][:copied]

    new_data_dict["decoder/reconstructed_position"] = position
    new_data_dict["decoder/reconstructed_valid_mask"] = valid_mask
    new_data_dict["decoder/reconstructed_heading"] = heading
    new_data_dict["decoder/reconstructed_velocity"] = velocity

    return new_data_dict
def overwrite_to_scenario_description(output_dict_mode, original_SD, ooi=None, adv_id=None):
    """Write predicted agent states back into a scenario description, in place.

    For every selected agent the first 91 steps of position, velocity, heading
    and valid mask replace the track's state. Length, width and height are
    replaced by constant arrays of 91 entries taken from the track's value at
    index 10.

    Args:
        output_dict_mode: prediction dict with ``decoder/agent_id``,
            ``decoder/track_name``, ``decoder/agent_position``,
            ``decoder/agent_heading``, ``decoder/agent_velocity`` and
            ``decoder/agent_valid_mask``.
        original_SD: scenario description with ``tracks`` and ``metadata``
            entries; it is mutated and returned.
        ooi: indices of the agents to overwrite. ``None`` or an empty
            sequence means every entry of ``decoder/agent_id``.
        adv_id: optional agent index whose track name is recorded in
            ``metadata['selected_adv_id']``.

    Returns:
        ``original_SD`` after modification.
    """
    # ``not ooi`` is ambiguous for numpy arrays: it raises for several elements
    # and is True for ``np.array([0])``, so test for "nothing selected" explicitly.
    if ooi is None or len(ooi) == 0:
        ooi = output_dict_mode['decoder/agent_id']  # overwrite all agents

    track_names = output_dict_mode['decoder/track_name']

    for agent_index in ooi:
        agent_track_name = str(track_names[agent_index].item())
        state = original_SD['tracks'][agent_track_name]['state']

        state['position'] = output_dict_mode["decoder/agent_position"][:91, agent_index]
        state['velocity'] = output_dict_mode["decoder/agent_velocity"][:91, agent_index]
        state['heading'] = output_dict_mode["decoder/agent_heading"][:91, agent_index]
        state['valid'] = output_dict_mode["decoder/agent_valid_mask"][:91, agent_index]

        # Freeze the bounding box to the size stored at index 10.
        for size_key in ('length', 'width', 'height'):
            state[size_key] = np.full((91, ), state[size_key][10])

    # ``adv_id`` defaults to None; only record an adversary when one was given
    # instead of crashing in ``int(None)``.
    if adv_id is not None:
        original_SD['metadata']['selected_adv_id'] = str(track_names[int(adv_id)].item())

    return original_SD
agent_vel = output_dict_mode["decoder/agent_velocity"][:, id] + agent_traj_mask = output_dict_mode["decoder/agent_valid_mask"][:, id] + + # modify adv info + # agent_z = original_SD['tracks'][agent_track_name]['state']['position'][10, 2] # fill the z-axis + # agent_traj_z = np.full((91, 1), agent_z) + # agent_new_traj = np.concatenate([agent_traj, agent_traj_z], axis=1) + # print("new_traj:", agent_new_traj.shape) + original_SD['tracks'][agent_track_name]['state']['position'] = agent_traj + + original_SD['tracks'][agent_track_name]['state']['velocity'] = agent_vel + original_SD['tracks'][agent_track_name]['state']['heading'] = agent_heading + original_SD['tracks'][agent_track_name]['state']['valid'] = agent_traj_mask + + length = original_SD['tracks'][sdc_track_name]['state']['length'][10] + width = original_SD['tracks'][sdc_track_name]['state']['width'][10] + height = original_SD['tracks'][sdc_track_name]['state']['height'][10] + original_SD['tracks'][agent_track_name]['state']['length'] = np.full((91, ), length) + original_SD['tracks'][agent_track_name]['state']['width'] = np.full((91, ), width) + original_SD['tracks'][agent_track_name]['state']['height'] = np.full((91, ), height) + + original_SD['tracks'][adv_track_name]['metadata']['dataset'] = 'waymo' + original_SD['tracks'][adv_track_name]['metadata']['object_id'] = 'new_adv_agent' + original_SD['tracks'][adv_track_name]['metadata']['track_length'] = 91 + original_SD['tracks'][adv_track_name]['metadata']['type'] = 'VEHICLE' + original_SD['metadata']['new_adv_id'] = 'new_adv_agent' + original_SD['metadata']['objects_of_interest'].append('new_adv_agent') + tracks_length = len(list(original_SD['tracks'].keys())) + original_SD['metadata']['tracks_to_predict']['new_adv_agent'] = { + 'difficulty': 0, + 'object_type': 'VEHICLE', + 'track_id': 'new_adv_agent', + 'track_index': tracks_length - 1 + } + + return original_SD + + +def transform_to_global_coordinate(data_dict): + map_center = 
def _overwrite_datadict_all_agents(data_dict):
    """Return a deep copy whose agent states are replaced by the reconstructed ones.

    The first 91 steps of the ``decoder/reconstructed_*`` arrays overwrite
    ``decoder/agent_position`` (x and y, with z forced to 0),
    ``decoder/agent_valid_mask``, ``decoder/agent_heading`` and
    ``decoder/agent_velocity`` for every reconstructed agent.

    Args:
        data_dict: mapping holding both the ``decoder/agent_*`` and the
            ``decoder/reconstructed_*`` arrays.

    Returns:
        A new dict; the input is left untouched.
    """
    import copy

    new_data_dict = copy.deepcopy(data_dict)

    num_agents = data_dict["decoder/reconstructed_position"].shape[1]
    if num_agents == 0:
        return new_data_dict

    # One vectorised assignment per field replaces the per-agent Python loop.
    # Only the first ``num_agents`` columns are touched, as the loop did.
    agents = slice(0, num_agents)
    new_data_dict["decoder/agent_position"][:, agents, :2] = data_dict["decoder/reconstructed_position"][:91]
    new_data_dict["decoder/agent_position"][:, agents, 2] = 0.0
    new_data_dict["decoder/agent_valid_mask"][:, agents] = data_dict["decoder/reconstructed_valid_mask"][:91]
    new_data_dict["decoder/agent_heading"][:, agents] = data_dict["decoder/reconstructed_heading"][:91]
    new_data_dict["decoder/agent_velocity"][:, agents] = data_dict["decoder/reconstructed_velocity"][:91]

    return new_data_dict
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps for
                        "ddimN", or a section has fewer steps than requested.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for stride in range(1, num_timesteps):
                if len(range(0, num_timesteps, stride)) == desired_count:
                    return set(range(0, num_timesteps, stride))
            # Report the count that was asked for, not the total number of steps.
            raise ValueError(f"cannot create exactly {desired_count} steps with an integer stride")
        section_counts = [int(x) for x in section_counts.split(",")]

    # Split the process into len(section_counts) contiguous sections; the
    # first ``extra`` sections get one additional step.
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(f"cannot divide section of {size} steps into {section_count}")
        # Fractional stride so the first and last step of the section are both kept.
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
+ """ + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. 
class _WrappedModel:
    """Callable that remaps respaced timestep indices before calling ``model``.

    ``ts`` holds indices into ``timestep_map``; the wrapped model receives the
    mapped values instead.
    """
    def __init__(self, model, timestep_map, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.original_num_steps = original_num_steps
        # Lookup tensors keyed by (device, dtype), so the map is not rebuilt
        # and copied to the device on every call.
        self._map_cache = {}

    def __call__(self, x, ts, **kwargs):
        key = (ts.device, ts.dtype)
        map_tensor = self._map_cache.get(key)
        # Rebuild if the map has grown since it was cached.
        # NOTE(review): assumes timestep_map is not edited in place after wrapping.
        if map_tensor is None or map_tensor.numel() != len(self.timestep_map):
            map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
            self._map_cache[key] = map_tensor
        new_ts = map_tensor[ts]
        return self.model(x, new_ts, **kwargs)
torch.utils.checkpoint import checkpoint + +from scenestreamer.diffusion.create_diffusion import create_diffusion + + +class DiffLoss(nn.Module): + """Diffusion Loss""" + def __init__( + self, + target_channels, + z_channels, + depth, + width, + num_sampling_steps, + predict_xstart, + grad_checkpointing=False, + diffusion_steps=1000, + use_vlb_loss=True, + ): + + use_vlb_loss = False + sigma_small = True + learn_sigma = False + + super(DiffLoss, self).__init__() + self.in_channels = target_channels + self.net = SimpleMLPAdaLN( + in_channels=target_channels, + model_channels=width, + out_channels=target_channels * 2 if use_vlb_loss else target_channels, # for vlb loss + z_channels=z_channels, + num_res_blocks=depth, + grad_checkpointing=grad_checkpointing + ) + + self.train_diffusion = create_diffusion( + timestep_respacing="", + noise_schedule="cosine", + predict_xstart=predict_xstart, + diffusion_steps=diffusion_steps, + sigma_small=sigma_small, + learn_sigma=learn_sigma + ) + self.gen_diffusion = create_diffusion( + timestep_respacing=num_sampling_steps, + noise_schedule="cosine", + predict_xstart=predict_xstart, + diffusion_steps=diffusion_steps, + sigma_small=sigma_small, + learn_sigma=learn_sigma + ) + # self.train_diffusion = create_diffusion( + # timestep_respacing="", noise_schedule="cosine", predict_xstart=predict_xstart + # ) + # self.gen_diffusion = create_diffusion( + # timestep_respacing=num_sampling_steps, noise_schedule="cosine", predict_xstart=predict_xstart + # ) + + def forward(self, target, z, mask=None): + t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0], ), device=target.device) + model_kwargs = dict(c=z) + loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs) + # loss = loss_dict["loss"] + # if mask is not None: + # loss = (loss * mask).sum() / mask.sum() + # return loss.mean() + # PZH: Modification + assert mask is None + return loss_dict + + def sample(self, z, temperature=1.0, cfg=1.0): 
class TimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by a
    two-layer MLP with a SiLU in between.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(exponents).to(device=t.device)
        phase = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
        if dim % 2 != 0:
            # Odd output size: append a single zero column.
            zero_column = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, zero_column], dim=-1)
        return embedding

    def forward(self, t):
        sinusoid = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(sinusoid)
+ :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param z_channels: channels in the condition. + :param num_res_blocks: number of residual blocks per downsample. + """ + def __init__(self, in_channels, model_channels, out_channels, z_channels, num_res_blocks, grad_checkpointing=False): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.grad_checkpointing = grad_checkpointing + + self.time_embed = TimestepEmbedder(model_channels) + self.cond_embed = nn.Linear(z_channels, model_channels) + + self.input_proj = nn.Linear(in_channels, model_channels) + + res_blocks = [] + for i in range(num_res_blocks): + res_blocks.append(ResBlock(model_channels, )) + + self.res_blocks = nn.ModuleList(res_blocks) + self.final_layer = FinalLayer(model_channels, out_channels) + + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP + nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers + for block in self.res_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, c): + """ + Apply the model to an input batch. 
if __name__ == '__main__':
    # PZH: Play around:
    # Smoke test: only checks that a DiffLoss module can be constructed.
    diffloss_w = 1024
    diffloss_d = 3
    num_sampling_steps = '100'
    grad_checkpointing = False
    decoder_embed_dim = 1024
    token_embed_dim = 777
    diffloss = DiffLoss(
        target_channels=token_embed_dim,
        z_channels=decoder_embed_dim,
        width=diffloss_w,
        depth=diffloss_d,
        num_sampling_steps=num_sampling_steps,
        # predict_xstart has no default in DiffLoss.__init__; omitting it
        # raised TypeError before anything was built.
        predict_xstart=False,
        grad_checkpointing=grad_checkpointing
    )

    print(11111)
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    KL divergence between two diagonal Gaussians given as means and
    log-variances. Arguments broadcast against each other, so a batch can be
    compared with scalars; at least one argument has to be a Tensor.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, th.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # th.exp() needs Tensor inputs, so promote scalar log-variances to the
    # reference Tensor's dtype and device.
    if not isinstance(logvar1, th.Tensor):
        logvar1 = th.tensor(logvar1).to(reference)
    if not isinstance(logvar2, th.Tensor):
        logvar2 = th.tensor(logvar2).to(reference)

    squared_mean_gap = (mean1 - mean2)**2
    return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + squared_mean_gap * th.exp(-logvar2))
+ """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/scenestreamer/diffusion/gaussian_diffusion.py b/scenestreamer/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..10e30623f54199bdbafea00dc24ede402964cd41 --- /dev/null +++ b/scenestreamer/diffusion/gaussian_diffusion.py @@ -0,0 +1,851 @@ +""" +PZH: From https://github.com/LTH14/mar/blob/fe470ac24afbee924668d8c5c83e9fec60af3a73/diffusion/gaussian_diffusion.py +""" + +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import enum +import math + +import numpy as np +import torch as th + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. 
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
    """
    Beta schedule that ramps linearly from beta_start to beta_end over the
    first ``int(num_diffusion_timesteps * warmup_frac)`` steps and stays at
    beta_end afterwards.
    """
    warmup_steps = int(num_diffusion_timesteps * warmup_frac)
    schedule = np.full(num_diffusion_timesteps, beta_end, dtype=np.float64)
    ramp = np.linspace(beta_start, beta_end, warmup_steps, dtype=np.float64)
    schedule[:warmup_steps] = ramp
    return schedule
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.

    :param schedule_name: "linear" or "cosine". "lcsim" is recognised but
                          deliberately disabled (see below).
    :param num_diffusion_timesteps: the number of betas to produce.
    :return: a 1-D numpy array of betas, as produced by get_beta_schedule()
             ("linear") or betas_for_alpha_bar() ("cosine").
    :raises ValueError: if schedule_name is "lcsim".
    :raises NotImplementedError: if schedule_name is not a known schedule.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )

    if schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2,
        )

    if schedule_name == "lcsim":
        # The power-law interpolation helper is known to be buggy, so this
        # schedule is rejected up front. The exception type is kept as
        # ValueError for callers that already catch it; the previously
        # unreachable call to the helper after the raise has been removed.
        raise ValueError(
            "beta schedule 'lcsim' is disabled: its power-law implementation is known to be incorrect"
        )

    raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+ Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + def __init__(self, *, betas, model_mean_type, model_var_type, loss_type): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all(), (betas.min(), betas.max()) + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps, ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = (betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)) + 
self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + assert ( + posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == + x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. 
+ """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B, ) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + 
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.
    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

    :param cond_fn: callable invoked as cond_fn(x, t, **model_kwargs); its
                    result is used as the gradient.
    :param p_mean_var: dict providing at least the "mean" and "variance"
                       tensors.
    :param x: the current tensor, forwarded to cond_fn.
    :param t: the timestep tensor, forwarded to cond_fn.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to cond_fn.
    :return: mean + variance * gradient, with mean and gradient cast to float.
    """
    # The default is None, which cannot be unpacked with ** - fall back to an
    # empty dict so that calling without model_kwargs works.
    if model_kwargs is None:
        model_kwargs = {}
    gradient = cond_fn(x, t, **model_kwargs)
    new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    return new_mean
+ """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, temperature=1.0 + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param temperature: temperature scaling during Diff Loss sampling. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. 
+ """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + # scale the noise by temperature + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + temperature=1.0, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :param temperature: temperature scaling during Diff Loss sampling. + :return: a non-differentiable batch of samples. 
def p_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    temperature=1.0,
):
    """
    Generate samples from the model and yield intermediate samples from
    each timestep of diffusion.
    Arguments are the same as p_sample_loop().
    Returns a generator over dicts, where each dict is the return value of
    p_sample().

    Device selection (previously hard-coded to CUDA, which ignored `device`
    and failed on CPU-only hosts):
      1. `device`, if given;
      2. otherwise the device of `noise`, if given;
      3. otherwise the device of the model's first parameter, if the model
         exposes `parameters()`;
      4. otherwise CUDA when available, else CPU.
    """
    assert isinstance(shape, (tuple, list))

    if device is None:
        if noise is not None:
            device = noise.device
        else:
            get_params = getattr(model, "parameters", None)
            first_param = next(iter(get_params()), None) if callable(get_params) else None
            if first_param is not None:
                device = first_param.device
            else:
                # Keeps the historical behaviour (CUDA) whenever a GPU exists.
                device = th.device("cuda" if th.cuda.is_available() else "cpu")

    if noise is not None:
        # No-op when the noise already lives on the selected device.
        img = noise.to(device)
    else:
        img = th.randn(*shape, device=device)

    # Walk the chain backwards: T-1, ..., 0.
    indices = list(range(self.num_timesteps))[::-1]

    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        # One identical timestep index per batch element.
        t = th.tensor([i] * shape[0], device=device)
        with th.no_grad():
            out = self.p_sample(
                model,
                img,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                temperature=temperature,
            )
            yield out
            img = out["sample"]
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = (eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)) + # Equation 12. + noise = th.randn_like(x) + mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - + out["pred_xstart"]) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. 
def ddim_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
):
    """
    Use DDIM to sample from the model and yield intermediate samples from
    each timestep of DDIM.
    Same usage as p_sample_loop_progressive().

    Device selection (previously hard-coded to CUDA, which ignored `device`
    and failed on CPU-only hosts):
      1. `device`, if given;
      2. otherwise the device of `noise`, if given;
      3. otherwise the device of the model's first parameter, if the model
         exposes `parameters()`;
      4. otherwise CUDA when available, else CPU.
    """
    assert isinstance(shape, (tuple, list))

    if device is None:
        if noise is not None:
            device = noise.device
        else:
            get_params = getattr(model, "parameters", None)
            first_param = next(iter(get_params()), None) if callable(get_params) else None
            if first_param is not None:
                device = first_param.device
            else:
                # Keeps the historical behaviour (CUDA) whenever a GPU exists.
                device = th.device("cuda" if th.cuda.is_available() else "cpu")

    if noise is not None:
        # No-op when the noise already lives on the selected device.
        img = noise.to(device)
    else:
        img = th.randn(*shape, device=device)

    # Walk the chain backwards: T-1, ..., 0.
    indices = list(range(self.num_timesteps))[::-1]

    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        # One identical timestep index per batch element.
        t = th.tensor([i] * shape[0], device=device)
        with th.no_grad():
            out = self.ddim_sample(
                model,
                img,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                eta=eta,
            )
            yield out
            img = out["sample"]
+ :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
+ """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. 
+ terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output)**2) + terms["model_output"] = model_output.mean(0) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + terms["total_loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + terms["total_loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. 
+ - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start)**2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise)**2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
+ """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/scenestreamer/eval/__init__.py b/scenestreamer/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenestreamer/eval/challenge_2023_config.textproto b/scenestreamer/eval/challenge_2023_config.textproto new file mode 100644 index 0000000000000000000000000000000000000000..0a8b1e9ec63d931397af66713a7b2c81653ca27f --- /dev/null +++ b/scenestreamer/eval/challenge_2023_config.textproto @@ -0,0 +1,89 @@ +# proto-file: protos/sim_agents_metrics.proto +# proto-message: car.open_dataset.SimAgentMetricsConfig + +linear_speed: { + histogram: { + min_val: 0.0 + max_val: 35.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.09 +} + +linear_acceleration: { + histogram: { + min_val: -15.0 + max_val: 15.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.09 +} + +angular_speed: { + histogram: { + min_val: -31.5 + max_val: 31.5 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.09 +} + +angular_acceleration: { + histogram: { + min_val: -31.5 + max_val: 31.5 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.09 +} + +distance_to_nearest_object: { + histogram: { + min_val: -5.0 + max_val: 40.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.09 +} + +collision_indication: { + bernoulli: {} + metametric_weight: 0.18 +} + +distance_to_road_edge: { + histogram: { + min_val: -20.0 + max_val: 40 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + 
metametric_weight: 0.09 +} + +offroad_indication: { + bernoulli: {} + metametric_weight: 0.18 +} + +time_to_collision: { + histogram: { + min_val: 0.0 + max_val: 5.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.09 +} \ No newline at end of file diff --git a/scenestreamer/eval/challenge_2024_config.textproto b/scenestreamer/eval/challenge_2024_config.textproto new file mode 100644 index 0000000000000000000000000000000000000000..1ff3b72a8cc717f658d1d47c00f74b067b81f132 --- /dev/null +++ b/scenestreamer/eval/challenge_2024_config.textproto @@ -0,0 +1,89 @@ +# proto-file: protos/sim_agents_metrics.proto +# proto-message: car.open_dataset.SimAgentMetricsConfig + +linear_speed: { + histogram: { + min_val: 0.0 + max_val: 25.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 +} + +linear_acceleration: { + histogram: { + min_val: -12.0 + max_val: 12.0 + num_bins: 11 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 +} + +angular_speed: { + histogram: { + min_val: -0.628 + max_val: 0.628 + num_bins: 11 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 +} + +angular_acceleration: { + histogram: { + min_val: -3.14 + max_val: 3.14 + num_bins: 11 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 +} + +distance_to_nearest_object: { + histogram: { + min_val: -5.0 + max_val: 40.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.1 +} + +collision_indication: { + bernoulli: {} + metametric_weight: 0.25 +} + +distance_to_road_edge: { + histogram: { + min_val: -20.0 + max_val: 40 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.1 +} + +offroad_indication: { + bernoulli: {} + metametric_weight: 0.25 
+} + +time_to_collision: { + histogram: { + min_val: 0.0 + max_val: 5.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.1 +} \ No newline at end of file diff --git a/scenestreamer/eval/challenge_2025_scenario_gen_config.textproto b/scenestreamer/eval/challenge_2025_scenario_gen_config.textproto new file mode 100644 index 0000000000000000000000000000000000000000..dc239383c7aa68fdb1df989e36b8e1933987ea1a --- /dev/null +++ b/scenestreamer/eval/challenge_2025_scenario_gen_config.textproto @@ -0,0 +1,103 @@ +# proto-file: protos/sim_agents_metrics.proto +# proto-message: car.open_dataset.SimAgentMetricsConfig + +linear_speed: { + histogram: { + min_val: 0.0 + max_val: 25.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 + aggregate_objects: true +} + +linear_acceleration: { + histogram: { + min_val: -12.0 + max_val: 12.0 + num_bins: 11 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 + aggregate_objects: true +} + +angular_speed: { + histogram: { + min_val: -0.628 + max_val: 0.628 + num_bins: 11 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 + aggregate_objects: true +} + +angular_acceleration: { + histogram: { + min_val: -3.14 + max_val: 3.14 + num_bins: 11 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 + aggregate_objects: true +} + +distance_to_nearest_object: { + histogram: { + min_val: -5.0 + max_val: 40.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.1 + aggregate_objects: true +} + +collision_indication: { + bernoulli: {} + metametric_weight: 0.25 + aggregate_objects: true +} + +distance_to_road_edge: { + histogram: { + min_val: -20.0 + max_val: 40 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + 
independent_timesteps: true + metametric_weight: 0.05 + aggregate_objects: true +} + +offroad_indication: { + bernoulli: {} + metametric_weight: 0.25 + aggregate_objects: true +} + +time_to_collision: { + histogram: { + min_val: 0.0 + max_val: 5.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.1 + aggregate_objects: true +} + +traffic_light_violation: { + bernoulli: {} + metametric_weight: 0.05 +} diff --git a/scenestreamer/eval/challenge_2025_sim_agents_config.textproto b/scenestreamer/eval/challenge_2025_sim_agents_config.textproto new file mode 100644 index 0000000000000000000000000000000000000000..bd2aa680c4f4e2405b9074a65029bddce8cc4d2a --- /dev/null +++ b/scenestreamer/eval/challenge_2025_sim_agents_config.textproto @@ -0,0 +1,94 @@ +# proto-file: protos/sim_agents_metrics.proto +# proto-message: car.open_dataset.SimAgentMetricsConfig + +linear_speed: { + histogram: { + min_val: 0.0 + max_val: 25.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 +} + +linear_acceleration: { + histogram: { + min_val: -12.0 + max_val: 12.0 + num_bins: 11 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 +} + +angular_speed: { + histogram: { + min_val: -0.628 + max_val: 0.628 + num_bins: 11 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 +} + +angular_acceleration: { + histogram: { + min_val: -3.14 + max_val: 3.14 + num_bins: 11 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 +} + +distance_to_nearest_object: { + histogram: { + min_val: -5.0 + max_val: 40.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.1 +} + +collision_indication: { + bernoulli: {} + metametric_weight: 0.25 +} + +distance_to_road_edge: { + histogram: { + min_val: -20.0 + 
max_val: 40 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.05 +} + +offroad_indication: { + bernoulli: {} + metametric_weight: 0.25 +} + +time_to_collision: { + histogram: { + min_val: 0.0 + max_val: 5.0 + num_bins: 10 + additive_smoothing_pseudocount: 0.1 + } + independent_timesteps: true + metametric_weight: 0.1 +} + +traffic_light_violation: { + bernoulli: {} + metametric_weight: 0.05 +} diff --git a/scenestreamer/eval/debug_scenario_metrics.py b/scenestreamer/eval/debug_scenario_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..75c5eeb644fce896427c9b1fb905065af6b50b04 --- /dev/null +++ b/scenestreamer/eval/debug_scenario_metrics.py @@ -0,0 +1,895 @@ +import copy + +import PIL +import hydra +import matplotlib.pyplot as plt +import numpy as np +import omegaconf +import seaborn as sns +from matplotlib.animation import FFMpegWriter +from matplotlib.patches import Polygon, Circle, Rectangle +import tqdm +# Load model +from scenestreamer.utils import utils +from scenestreamer.dataset.dataset import SceneStreamerDataset +from scenestreamer.utils import REPO_ROOT +import torch + +from scenestreamer.gradio_ui.plot import plot_pred, create_animation_from_pred +from scenestreamer.gradio_ui.plot import plot_pred +import pathlib +from waymo_open_dataset.protos import sim_agents_metrics_pb2 +from google.protobuf import text_format +import tensorflow as tf +from waymo_open_dataset.wdl_limited.sim_agents_metrics import interaction_features +from waymo_open_dataset.wdl_limited.sim_agents_metrics import map_metric_features +import torch.nn.functional as F +import itertools +from waymo_open_dataset.protos import map_pb2 +from scenestreamer.eval.waymo_motion_prediction_evaluator import _repeat_for_modes +from collections.abc import Iterable +import pdb +from scenestreamer.dataset.preprocessor import preprocess_scenario_description_for_motionlm +# @hydra.main(version_base=None, 
def jsd(gt_hist, pred_hist, epsilon=1e-10):
    """Return the Jensen-Shannon divergence between two histograms.

    Both histograms are normalised to probability vectors, smoothed by adding
    ``epsilon`` to every bin (so empty bins do not produce ``log(0)``) and
    normalised again so they remain valid distributions.

    Args:
        gt_hist: 1-D tensor of bin counts for the reference data.
        pred_hist: 1-D tensor of bin counts for the compared data; must have
            the same number of bins as ``gt_hist``.
        epsilon: additive smoothing applied to every bin probability.

    Returns:
        A scalar tensor ``0.5 * KL(gt || m) + 0.5 * KL(pred || m)`` with
        ``m = 0.5 * (gt + pred)``, in nats (bounded by ``ln 2``). The result
        is NaN when either histogram sums to zero.
    """
    gt_prob = gt_hist / gt_hist.sum() + epsilon
    pred_prob = pred_hist / pred_hist.sum() + epsilon
    # Re-normalise after smoothing so both vectors sum to one again.
    gt_prob = gt_prob / gt_prob.sum()
    pred_prob = pred_prob / pred_prob.sum()

    log_mixture = (0.5 * (gt_prob + pred_prob)).log()
    # F.kl_div(input, target) evaluates sum(target * (log(target) - input)),
    # i.e. KL(target || exp(input)). The mixture therefore has to be the
    # *input* (in log space) and each distribution the *target*.
    kl_gt = F.kl_div(log_mixture, gt_prob, reduction="sum")
    kl_pred = F.kl_div(log_mixture, pred_prob, reduction="sum")
    return 0.5 * (kl_gt + kl_pred)
+ # Unit given in AdvDiffuser for FDD is m not m^2, so I am using L2 norm here + + # Distribution Realism + vel_jsd = 0 # avg over scenarios: build histogram across agents, modes, timestamps: velocity JS divergence + acc_jsd = 0 # avg over scenarios: build histogram across agents, modes, timestamps: acceleration JS divergence + ttc_jsd = 0 # avg over scenarios: build histogram across agents, modes, timestamps: time to collision JS divergence + + # Common Sense + env_coll = 0 # offroad + veh_coll = 0 # collision rate + sdc_coll_adv = 0 + sdc_coll_adv_active = False + sdc_coll_bv = 0 + sdc_coll_bv_active = False + adv_coll_bv = 0 + adv_coll_bv_active = False + coll_vel = 0 # collision velocity + # no clue what collision JSD means so not calculating it for now + + # AV comfortable + acc = 0 # Xuanhao: avg over scenarios: min over modes: max over time steps: acceleration of ego vehicle + jerk = 0 # Xuanhao: avg over scenarios: min over modes: max over time steps: jerk of ego vehicle + + # Output Metrics + metrics = {} + metric_units = {} + + # Constants + SECONDS_PER_STEP = 0.1 + + def __init__(self): + self.metric_units = { + "# Scenarios": "", + "minSFDE": "m", + "FDD": "m", + "Vel. JSD": "", + "Acc. JSD": "", + "TTC JSD": "", + "Env CR": "", + "Veh CR": "", + "ADV+SDC CR": "", + "ADV+BV CR": "", + "SDC+BV CR": "", + "Avg. Max Coll Vel.": "m/s", + "Avg. Max Acc.": "m/s^2", + "Avg. Max Jerk": "m/s^3", + } + self.display_keys = [ + "# Scenarios", "minSFDE", "FDD", "Vel. JSD", "Acc. JSD", "TTC JSD", "Env CR", "Veh CR", "ADV+SDC CR", + "ADV+BV CR", "SDC+BV CR", "Avg. Max Coll Vel.", "Avg. Max Acc.", "Avg. 
Max Jerk" + ] + self.jsd_config = { + "vel": { + "min_val": 0.0, + "max_val": 50.0, + "num_bins": 100 + }, + "acc": { + "min_val": -10.0, + "max_val": 10.0, + "num_bins": 200 + }, + "ttc": { + "min_val": 0.0, + "max_val": 5.0, + "num_bins": 50 + } + } + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def add(self, gt_data_dict, pred_data_dict, **kwargs): + global SCENE_IDX + self.scenario_count += 1 + + T_gt = gt_data_dict["decoder/agent_position"].shape[0] + T_context = 0 + + T_pred = pred_data_dict["decoder/reconstructed_position"].shape[1] + K = pred_data_dict["decoder/reconstructed_position"].shape[0] + N = gt_data_dict["decoder/agent_position"].shape[1] + vehicle_mask = numpy_to_torch(gt_data_dict["decoder/agent_type"] == 1, device=self.device) # (num agents) + ooi_mask = torch.zeros_like(vehicle_mask, dtype=torch.bool, device=self.device) + ooi_mask[(gt_data_dict["decoder/labeled_agent_id"])] = 1 # (num agents) + gt_valid_mask = numpy_to_torch( + gt_data_dict["decoder/agent_valid_mask"], device=self.device + ).T # (num agents, num steps) + pred_valid_mask = pred_data_dict["decoder/reconstructed_valid_mask"].transpose( + 1, 2 + ) # (K, num agents, num steps) + # joint_mask = vehicle_mask.unsqueeze(-1) & valid_mask # (num agents, num steps) + gt_ooi_joint = ooi_mask.unsqueeze(-1) & gt_valid_mask[..., T_context:T_gt] # (num agents, num steps) + pred_ooi_joint = ooi_mask[None, ..., None] & pred_valid_mask[..., T_context:T_gt] # (K, num agents, num steps) + gt_shape = numpy_to_torch( + gt_data_dict["decoder/current_agent_shape"][None], device=self.device + ).expand(T_gt, -1, -1).transpose(0, 1) + pred_shape = pred_data_dict["decoder/current_agent_shape"][:, None].expand(-1, T_pred, -1, -1).transpose(1, 2) + + # minSFDE + with timer("minSFDE"): + import pdb + pdb.set_trace() + self.minSFDE += torch.min( + torch.sum( + torch.where( + gt_ooi_joint[None, ..., -1] & pred_ooi_joint[..., -1], + torch.linalg.norm( + 
numpy_to_torch(gt_data_dict["decoder/agent_position"], + device=self.device)[None, T_gt - 1, :, :2] - + pred_data_dict["decoder/reconstructed_position"][:, T_gt - 1], + dim=-1 + ), 0 + ), + dim=-1 + ) / (gt_ooi_joint[None, ..., -1] & pred_ooi_joint[..., -1]).sum(dim=-1) + ) + # FDD + with timer("FDD"): + # there doesn't appear to be an easy way to do this with cartesian product + # import pdb; pdb.set_trace() + cur_FDD = None + for i, j in itertools.product(range(K), range(K)): + final_dist = torch.where( + pred_ooi_joint[..., -1], + torch.linalg.norm( + pred_data_dict["decoder/reconstructed_position"][i, T_gt - 1] - + pred_data_dict["decoder/reconstructed_position"][j, T_gt - 1], + dim=-1 + ), 0 + ) + if cur_FDD == None: + cur_FDD = final_dist + else: + cur_FDD = torch.maximum(cur_FDD, final_dist) + self.FDD += cur_FDD.sum() / torch.pow(pred_ooi_joint[..., -1].sum(), 2) + + with timer("Kinematic Metrics"): + gt_speed, gt_accel, gt_jerk = self._compute_kinematic_metrics( + gt_data_dict["decoder/agent_velocity"].swapaxes(1, 0) + ) # (N, T) + pred_speed, pred_accel, pred_jerk = self._compute_kinematic_metrics( + pred_data_dict["decoder/reconstructed_velocity"].transpose(1, 2) + ) # (K, N, T) + gt_speed = gt_speed[..., T_context:T_gt] + gt_accel = gt_accel[..., T_context:T_gt] + gt_jerk = gt_jerk[..., T_context:T_gt] + pred_speed = pred_speed[..., T_context:T_gt] + pred_accel = pred_accel[..., T_context:T_gt] + pred_jerk = pred_jerk[..., T_context:T_gt] + + with timer("Collision Metrics"): + # Following wosac_eval, fill in z with GT t = 10 data + z_values = pred_data_dict["decoder/current_agent_position"][..., 2].unsqueeze(-1).expand(-1, -1, T_pred) + + def build_collision_data(candidate_agents, evaluate_agents): + if isinstance(candidate_agents, torch.Tensor): + candidate_agents = candidate_agents.cpu().numpy() + if isinstance(evaluate_agents, torch.Tensor): + evaluate_agents = evaluate_agents.cpu().numpy() + if isinstance(candidate_agents, np.ndarray): + 
candidate_agents = candidate_agents.tolist() + if isinstance(evaluate_agents, np.ndarray): + evaluate_agents = evaluate_agents.tolist() + if not isinstance(candidate_agents, Iterable): + candidate_agents = [candidate_agents] + if not isinstance(evaluate_agents, Iterable): + evaluate_agents = [evaluate_agents] + candidate_agents = sorted(set(candidate_agents + evaluate_agents)) + candidate_agents_map = {v: k for k, v in enumerate(candidate_agents)} + evaluate_agents_mask = np.zeros(len(candidate_agents), dtype=bool) + evaluate_agents_mask[[candidate_agents_map[k] for k in evaluate_agents]] = 1 + return [ + dict( + center_x=conv(pred_data_dict["decoder/reconstructed_position"][k, :, candidate_agents, 0].T), + center_y=conv(pred_data_dict["decoder/reconstructed_position"][k, :, candidate_agents, 1].T), + center_z=conv(z_values[k, candidate_agents]), + length=conv(pred_shape[k, candidate_agents, :, 0]), + width=conv(pred_shape[k, candidate_agents, :, 1]), + height=conv(pred_shape[k, candidate_agents, :, 2]), + heading=conv(pred_data_dict["decoder/reconstructed_heading"][k, :, candidate_agents].T), + valid=conv(pred_valid_mask[k, candidate_agents], dtype=tf.bool), + evaluated_object_mask=conv(evaluate_agents_mask, dtype=tf.bool) + ) for k in range(K) + ] + + def get_dists(args_list): + return torch.stack( + [ + tf_to_torch( + interaction_features.compute_distance_to_nearest_object(**args_list[k]), device=self.device + ) for k in range(K) + ] + ) + + def calc_collision(dists, valid_masks): + if type(valid_masks) == list: + valid_masks = torch.stack(valid_masks) + collisions = torch.le(dists, interaction_features.COLLISION_DISTANCE_THRESHOLD) + collisions = collisions[..., T_context:T_gt] + valid_masks = valid_masks[..., T_context:T_gt] + collision_rate = torch.min( + torch.any(collisions & valid_masks, dim=-1).double().sum(dim=-2) / + torch.any(valid_masks, dim=-1).double().sum(dim=-2) + ) + return collisions, collision_rate + + def calc_collision_rate(candidate_agents, 
evaluate_agents): + args = build_collision_data(candidate_agents, evaluate_agents) + collisions, collision_rate = calc_collision( + get_dists(args), [ + tf_to_torch( + tf.boolean_mask(args[k]["valid"], args[k]["evaluated_object_mask"], axis=0), + device=self.device + ) for k in range(K) + ] + ) + return collisions, collision_rate + + # import pdb; pdb.set_trace() + pred_veh_collisions, veh_cr = calc_collision_rate(list(range(N)), gt_data_dict["decoder/labeled_agent_id"]) + self.veh_coll += veh_cr + + if "adv" in kwargs: # assumes list of adv agents + self.sdc_coll_adv_active = True + self.sdc_coll_adv += calc_collision_rate(kwargs["adv"], pred_data_dict["decoder/sdc_index"])[-1] + + if "bv" in kwargs: # assumes list of bv agents + self.sdc_coll_bv_active = True + self.sdc_coll_bv += calc_collision_rate(kwargs["bv"], pred_data_dict["decoder/sdc_index"])[-1] + + if "adv" in kwargs and "bv" in kwargs: + self.adv_coll_bv_active = True + self.adv_coll_bv += calc_collision_rate(kwargs["bv"], kwargs["adv"])[-1] + + map_feature = gt_data_dict["encoder/map_feature"] + road_edges = [ + [map_pb2.MapPoint(x=map_feature[i, 0, 0], y=map_feature[i, 0, 1], z=map_feature[i, 0, 2])] + [ + map_pb2.MapPoint(x=map_feature[i, j, 3], y=map_feature[i, j, 4], z=map_feature[i, j, 5]) + for j in range(map_feature.shape[1]) if gt_data_dict['encoder/map_feature_valid_mask'][i, j] + ] # start point + end points + for i in range(map_feature.shape[0]) if map_feature[i, 0, 15] == 1 + ] # is boundary + env_nearest_distances = torch.stack( + [ + tf_to_torch( + map_metric_features.compute_distance_to_road_edge( + center_x=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 0].T), + center_y=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 1].T), + center_z=conv(z_values[k]), + length=conv(pred_shape[k, ..., 0]), + width=conv(pred_shape[k, ..., 1]), + height=conv(pred_shape[k, ..., 2]), + heading=conv(pred_data_dict["decoder/reconstructed_heading"][k].T), + 
valid=conv(pred_valid_mask[k], dtype=tf.bool), + evaluated_object_mask=conv(ooi_mask, dtype=tf.bool), + road_edge_polylines=road_edges, + ), + device=self.device + ) for k in range(K) + ] + ) + + pred_env_collisions = torch.greater(env_nearest_distances, map_metric_features.OFFROAD_DISTANCE_THRESHOLD) + pred_env_collisions = pred_env_collisions[..., T_context:T_gt] + env_collision_rate = torch.min( + pred_env_collisions.sum(dim=(-1, -2), dtype=torch.double) / + pred_valid_mask[:, ooi_mask, T_context:T_gt].sum(dtype=torch.double, dim=(-1, -2)) + ) + + self.env_coll += env_collision_rate + + # Debug: ground truth env collision rate + # debug_env_nearest_distances = rconv(map_metric_features.compute_distance_to_road_edge( + # center_x=conv(gt_data_dict["decoder/agent_position"][..., 0].T), + # center_y=conv(gt_data_dict["decoder/agent_position"][..., 1].T), + # center_z=conv(gt_data_dict["decoder/agent_position"][..., 2].T), + # length=conv(gt_data_dict["decoder/agent_shape"][..., 0].T), + # width=conv(gt_data_dict["decoder/agent_shape"][..., 1].T), + # height=conv(gt_data_dict["decoder/agent_shape"][..., 2].T), + # heading=conv(gt_data_dict["decoder/agent_heading"].T), + # valid=conv(gt_data_dict["decoder/agent_valid_mask"].T), + # evaluated_object_mask=conv(vehicle_mask), + # road_edge_polylines=road_edges, + # )) + + # debug_env_collisions = torch.le(debug_env_nearest_distances, map_metric_features.OFFROAD_DISTANCE_THRESHOLD) + # debug_env_collisions = debug_env_collisions[..., 11:91] + + # debug_env_collision_rate = torch.min(debug_env_collisions.double().mean(dim=(-1, -2))) + # print(debug_env_collision_rate) + + self.coll_vel += torch.min( + torch.nan_to_num(torch.where(pred_env_collisions | pred_veh_collisions, pred_speed[:, ooi_mask], + 0)).amax(dim=(-1, -2)) + ) + self.acc += torch.min(torch.abs(torch.nan_to_num(pred_accel[:, 0])).amax(dim=-1)) + self.jerk += torch.min(torch.abs(torch.nan_to_num(pred_jerk[:, 0])).amax(dim=-1)) + + with timer("Time to 
Collision"): + gt_ttc = tf_to_torch( + interaction_features.compute_time_to_collision_with_object_in_front( + center_x=conv(gt_data_dict["decoder/agent_position"][..., 0].T), + center_y=conv(gt_data_dict["decoder/agent_position"][..., 1].T), + length=conv(gt_shape[..., 0]), + width=conv(gt_shape[..., 1]), + heading=conv(gt_data_dict["decoder/agent_heading"].T), + valid=conv(gt_valid_mask, dtype=tf.bool), + evaluated_object_mask=conv(ooi_mask, dtype=tf.bool), + seconds_per_step=self.SECONDS_PER_STEP + ), + device=self.device + ) + pred_ttc = torch.stack( + [ + tf_to_torch( + interaction_features.compute_time_to_collision_with_object_in_front( + center_x=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 0].T), + center_y=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 1].T), + length=conv(pred_shape[k, ..., 0]), + width=conv(pred_shape[k, ..., 1]), + heading=conv(pred_data_dict["decoder/reconstructed_heading"][k].T), + valid=conv(pred_data_dict["decoder/reconstructed_valid_mask"][k].T, dtype=tf.bool), + evaluated_object_mask=conv(ooi_mask, dtype=tf.bool), + seconds_per_step=self.SECONDS_PER_STEP + ), + device=self.device + ) for k in range(K) + ] + ) + gt_ttc = gt_ttc[..., T_context:T_gt] + pred_ttc = pred_ttc[..., T_context:T_gt] + + with timer("Histograms"): + gt_speed_hist, gt_speed_bins = torch.histogram( + torch.clip( + gt_speed[gt_ooi_joint & ~gt_speed.isnan()], self.jsd_config["vel"]["min_val"], + self.jsd_config["vel"]["max_val"] + ).cpu(), + self.jsd_config["vel"]["num_bins"], + density=False + ) + # .cpu() since histogram doesn't support cuda backend + pred_speed_hist, pred_speed_bins = torch.histogram( + torch.clip( + pred_speed[pred_ooi_joint & ~pred_speed.isnan()], self.jsd_config["vel"]["min_val"], + self.jsd_config["vel"]["max_val"] + ).cpu(), + self.jsd_config["vel"]["num_bins"], + density=False + ) + gt_accel_hist, gt_accel_bins = torch.histogram( + torch.clip( + gt_accel[gt_ooi_joint & ~gt_accel.isnan()], 
self.jsd_config["acc"]["min_val"], + self.jsd_config["acc"]["max_val"] + ).cpu(), + self.jsd_config["acc"]["num_bins"], + density=False + ) + pred_accel_hist, pred_accel_bins = torch.histogram( + torch.clip( + pred_accel[pred_ooi_joint & ~pred_accel.isnan()], self.jsd_config["acc"]["min_val"], + self.jsd_config["acc"]["max_val"] + ).cpu(), + self.jsd_config["acc"]["num_bins"], + density=False + ) + gt_ttc_hist, gt_ttc_bins = torch.histogram( + torch.clip( + gt_ttc[gt_valid_mask[ooi_mask, T_context:T_gt] & ~gt_ttc.isnan()], + self.jsd_config["ttc"]["min_val"], self.jsd_config["ttc"]["max_val"] + ).cpu(), + self.jsd_config["ttc"]["num_bins"], + density=False + ) + pred_ttc_hist, pred_ttc_bins = torch.histogram( + torch.clip( + pred_ttc[pred_valid_mask[:, ooi_mask, T_context:T_gt] & ~pred_ttc.isnan()], + self.jsd_config["ttc"]["min_val"], self.jsd_config["ttc"]["max_val"] + ).cpu(), + self.jsd_config["ttc"]["num_bins"], + density=False + ) + # visualize histograms for debug + plt.clf() + plt.hist(gt_speed_bins[:-1], bins=gt_speed_bins, weights=gt_speed_hist, density=False) + plt.savefig(f"gt_speed{SCENE_IDX}.png", bbox_inches='tight') + plt.clf() + plt.hist(pred_speed_bins[:-1], bins=pred_speed_bins, weights=pred_speed_hist, density=False) + plt.savefig(f"pred_speed{SCENE_IDX}.png", bbox_inches='tight') + plt.clf() + plt.hist(gt_accel_bins[:-1], bins=gt_accel_bins, weights=gt_accel_hist, density=False) + plt.savefig(f"gt_accel{SCENE_IDX}.png", bbox_inches='tight') + plt.clf() + plt.hist(pred_accel_bins[:-1], bins=pred_accel_bins, weights=pred_accel_hist, density=False) + plt.savefig(f"pred_accel{SCENE_IDX}.png", bbox_inches='tight') + plt.clf() + plt.hist(gt_ttc_bins[:-1], bins=gt_ttc_bins, weights=gt_ttc_hist, density=False) + plt.savefig(f"gt_ttc{SCENE_IDX}.png", bbox_inches='tight') + plt.clf() + plt.hist(pred_ttc_bins[:-1], bins=pred_ttc_bins, weights=pred_ttc_hist, density=False) + plt.savefig(f"pred_ttc{SCENE_IDX}.png", bbox_inches='tight') + plt.clf() + + with 
timer("JSD"): + speed_jsd = jsd(gt_speed_hist, pred_speed_hist) + acc_jsd = jsd(gt_accel_hist, pred_accel_hist) + ttc_jsd = jsd(gt_ttc_hist, pred_ttc_hist) + self.vel_jsd += speed_jsd + self.acc_jsd += acc_jsd + self.ttc_jsd += ttc_jsd + + def _compute_kinematic_metrics(self, vel): + if type(vel) == np.ndarray: + vel = numpy_to_torch(vel, device=self.device) + speed = torch.linalg.norm(vel, axis=-1) + accel = self._central_diff(speed, pad_value=torch.nan) / self.SECONDS_PER_STEP + jerk = self._central_diff(accel, pad_value=torch.nan) / self.SECONDS_PER_STEP + return speed, accel, jerk + + def _central_diff(self, tensor, pad_value=torch.nan): + pad_shape = (*tensor.shape[:-1], 1) + pad_tensor = torch.ones(pad_shape, device=self.device) * pad_value + diff_t = (tensor[..., 2:] - tensor[..., :-2]) / 2 + return torch.cat([pad_tensor, diff_t, pad_tensor], dim=-1) + + def aggregate(self): + # TODO: write some "aggregate" function to compute the metrics + self.metrics["# Scenarios"] = self.scenario_count + self.metrics["minSFDE"] = self.minSFDE / self.scenario_count + self.metrics["FDD"] = self.FDD / self.scenario_count + self.metrics["Vel. JSD"] = self.vel_jsd / self.scenario_count + self.metrics["Acc. JSD"] = self.acc_jsd / self.scenario_count + self.metrics["TTC JSD"] = self.ttc_jsd / self.scenario_count + self.metrics["Env CR"] = self.env_coll / self.scenario_count + self.metrics["Veh CR"] = self.veh_coll / self.scenario_count + self.metrics["ADV+SDC CR"] = self.sdc_coll_adv / self.scenario_count + self.metrics["ADV+BV CR"] = self.adv_coll_bv / self.scenario_count + self.metrics["SDC+BV CR"] = self.sdc_coll_bv / self.scenario_count + self.metrics["Avg. Max Coll Vel."] = self.coll_vel / self.scenario_count + self.metrics["Avg. Max Acc."] = self.acc / self.scenario_count + self.metrics["Avg. 
Max Jerk"] = self.jerk / self.scenario_count + + def print(self): + # TODO(xuanhao): Maybe implement a handy function to print output + pass + # self.precision_safety_0 = self.TP_safety_0 / (self.TP_safety_0 + self.FP_safety_0) + # self.recall_safety_0 = self.TP_safety_0 / (self.TP_safety_0 + self.FN_safety_0) + # + # print("=====================================") + # print( + # "precision_safety_0: {:.5f} = {} / {}".format( + # self.precision_safety_0, self.TP_safety_0, self.TP_safety_0 + self.FP_safety_0 + # ) + # ) + # print( + # "recall_safety_0: {:.5f} = {} / {}".format( + # self.recall_safety_0, self.TP_safety_0, self.TP_safety_0 + self.FN_safety_0 + # ) + # ) + # print("=====================================") + # + # self.precision_safety_1 = self.TP_safety_1 / (self.TP_safety_1 + self.FP_safety_1) + # self.recall_safety_1 = self.TP_safety_1 / (self.TP_safety_1 + self.FN_safety_1) + # print( + # "precision_safety_1: {:.5f} = {} / {}".format( + # self.precision_safety_1, self.TP_safety_1, self.TP_safety_1 + self.FP_safety_1 + # ) + # ) + # print( + # "recall_safety_1: {:.5f} = {} / {}".format( + # self.recall_safety_1, self.TP_safety_1, self.TP_safety_1 + self.FN_safety_1 + # ) + # ) + # + # print("=====================================") + # self.precision_macro = (self.precision_safety_0 + self.precision_safety_1) / 2 + # self.recall_macro = (self.recall_safety_0 + self.recall_safety_1) / 2 + # self.f1_macro = 2 * self.precision_macro * self.recall_macro / (self.precision_macro + self.recall_macro) + # print("precision_macro:", self.precision_macro) + # print("recall_macro:", self.recall_macro) + # print("f1_macro:", self.f1_macro) + # print("=====================================") + + self.aggregate() + for k in self.display_keys: + if k == "ADV+SDC CR" and not self.sdc_coll_adv_active: + continue + if k == "ADV+BV CR" and not self.adv_coll_bv_active: + continue + if k == "SDC+BV CR" and not self.sdc_coll_bv_active: + continue + print(f"{k}: 
{self.metrics[k]:.5f} {self.metric_units[k]}") + + +@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="1031_midgpt.yaml") +def debug_run_model(config): + import os + global SCENE_IDX + path = "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/1104_MidGPT_NoAgnt_WTLSgl_WContRel_WBackward_FixedStepAgentID_2024-11-04_2208/checkpoints/last.ckpt" + omegaconf.OmegaConf.set_struct(config, False) + config.PREPROCESSING.keep_all_data = True + config.pretrain = "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/1104_MidGPT_NoAgnt_WTLSgl_WContRel_WBackward_FixedStepAgentID_2024-11-04_2208/checkpoints/last.ckpt" + config.BACKWARD_PREDICTION = True # <<< + config.ADD_CONTOUR_RELATION = True + config.DATA.TRAINING_DATA_DIR = "/bigdata/yuxin/waymo_validation_interactive_500" #"data/20scenarios" + config.DATA.TEST_DATA_DIR = "/bigdata/yuxin/waymo_validation_interactive_500" #"data/20scenarios" + + omegaconf.OmegaConf.set_struct(config, True) + + model = utils.get_model(config, device="cuda") + device = model.device + + test_dataset = SceneStreamerDataset(config, "test") + from scenestreamer.tokenization import get_tokenizer + tokenizer = get_tokenizer(config) + + evaluator = Evaluator() + + num_scenario = 100 + count = 0 + num_modes = 1 + for raw_data_dict in tqdm.tqdm(test_dataset): + data_dict = copy.deepcopy(raw_data_dict) + + # Get the torch version of the data. 
def _get_mode(data, mode, num_modes=None):
    """Return a shallow copy of ``data`` reduced to a single prediction mode.

    The previous body ignored ``data`` and read the undefined names
    ``output_dict`` and ``num_modes``, so every call raised ``NameError``.

    Args:
        data: mapping from field name to value.
        mode: index of the mode to select along the leading axis.
        num_modes: number of modes stored in ``data``. Only arrays whose
            leading dimension equals this value are indexed. When ``None``,
            every array with at least one dimension is assumed to carry a
            leading mode axis.

    Returns:
        A new dict with the same keys. Per-mode ``np.ndarray`` values are
        replaced by ``value[mode]``; all other values are passed through
        unchanged. ``data`` itself is not modified.
    """
    ret = {}
    for k, v in data.items():
        # 0-d arrays have no leading axis (len() would raise), so they are
        # never treated as per-mode values.
        is_per_mode = (
            isinstance(v, np.ndarray) and v.ndim > 0 and (num_modes is None or len(v) == num_modes)
        )
        ret[k] = v[mode] if is_per_mode else v
    return ret
"/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/1104_MidGPT_NoAgnt_WTLSgl_WContRel_WBackward_FixedStepAgentID_2024-11-04_2208/checkpoints/last.ckpt" + omegaconf.OmegaConf.set_struct(config, False) + config.PREPROCESSING.keep_all_data = True + config.pretrain = "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/1104_MidGPT_NoAgnt_WTLSgl_WContRel_WBackward_FixedStepAgentID_2024-11-04_2208/checkpoints/last.ckpt" + config.BACKWARD_PREDICTION = True # <<< + config.ADD_CONTOUR_RELATION = True + config.DATA.TRAINING_DATA_DIR = "/bigdata/yuxin/waymo_validation_interactive_500" #"data/20scenarios" + config.DATA.TEST_DATA_DIR = "/bigdata/yuxin/waymo_validation_interactive_500" #"data/20scenarios" + omegaconf.OmegaConf.set_struct(config, True) + model = utils.get_model(checkpoint_path=path) + import torch + model = model.to("cuda") + device = model.device + from scenestreamer.tokenization import get_tokenizer + tokenizer = get_tokenizer(config) + + evaluator = Evaluator() + num_modes = 1 + count = 0 + num_scenario = 100 + + pbar = tqdm.tqdm(total=500, desc="Scenario") + # for count, raw_data in enumerate(datamodule.val_dataloader()): + dataset = SceneStreamerDataset(config, "test") + for raw_data_dict in tqdm.tqdm(dataset): + # if count >= num_scenario: + # break + flip_heading_accordingly = True + backward_prediction = True + + data_dict = raw_data_dict + raw_data_dict = copy.deepcopy(data_dict) + + # Create a new ADV in the data so backward prediction will help us generate it. + # TODO: If we also want to TF ego, then we should not overwrite ego data. 
def convert_tensors_to_double(data_dict, double_keys):
    """Cast the selected array/tensor entries of ``data_dict`` to float32.

    NOTE: despite the function name, the target dtype is single precision
    (``np.float32`` / ``torch.float32``), not float64.

    Args:
        data_dict: mapping that is updated in place.
        double_keys: container of keys whose values should be cast.

    Returns:
        The same ``data_dict`` object. NumPy arrays under a selected key are
        replaced by ``astype(np.float32)`` copies, torch tensors by a float32
        tensor on their current device; every other entry is left untouched.
    """
    # Only existing keys are reassigned, so the dict size never changes while
    # it is being iterated.
    for name, entry in data_dict.items():
        if name not in double_keys:
            continue
        if isinstance(entry, np.ndarray):
            data_dict[name] = entry.astype(np.float32)
        elif isinstance(entry, torch.Tensor):
            data_dict[name] = entry.to(device=entry.device, dtype=torch.float32)
    return data_dict
device=device) + double_keys = [ + "decoder/agent_position", 'decoder/agent_heading', 'decoder/agent_velocity', + "decoder/reconstructed_position", "decoder/reconstructed_heading", "decoder/reconstructed_velocity", + "decoder/agent_shape", "decoder/current_agent_shape", "decoder/current_agent_position" + ] + input_data_dict = convert_tensors_to_double(input_data_dict, double_keys) + + with open(os.path.join(CAT_DIR, cat_file_name), 'rb') as f: + cat_data = pickle.load(f) + f.close() + + cat_data_dict = preprocess_scenario_description_for_motionlm( + scenario=cat_data, config=config, in_evaluation=True, keep_all_data=True, cache=None + ) + + output_dict = overwrite_gt_to_pred_field(cat_data_dict) + output_data_dict = numpy_to_torch(output_dict, device=device) + output_data_dict = convert_tensors_to_double(output_data_dict, double_keys) + output_data_dict = { + k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v + for k, v in output_data_dict.items() + } + + evaluator.add(input_data_dict, output_data_dict) + evaluator.print() + count += 1 + + evaluator.print() + print("End of evaluation of CAT generation") + + +if __name__ == '__main__': + # debug_eval_CAT() + evaluate_scgen() + # debug_run_model() diff --git a/scenestreamer/eval/eval_open_loop.py b/scenestreamer/eval/eval_open_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..a239732352ef81cfbe0b4950a30eb850c07aeeeb --- /dev/null +++ b/scenestreamer/eval/eval_open_loop.py @@ -0,0 +1,1266 @@ +import copy +import dataclasses +import itertools +from collections.abc import Iterable +from contextlib import contextmanager +from time import perf_counter +from typing import Any + +import numpy as np +import pytorch_lightning as pl +import tensorflow as tf +import torch +import torch.nn.functional as F +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch.utils.data import DataLoader +from waymo_open_dataset.protos import map_pb2 +from 
waymo_open_dataset.wdl_limited.sim_agents_metrics import interaction_features +from waymo_open_dataset.wdl_limited.sim_agents_metrics import map_metric_features + +from scenestreamer.utils.utils import numpy_to_torch + + +def _overwrite_datadict_all_agents(source_data_dict, dest_data_dict, ooi=None): + import copy + new_data_dict = copy.deepcopy(dest_data_dict) + B, T, N, _ = source_data_dict["decoder/reconstructed_position"].shape + + if ooi is None: + ooi = np.arange(N) + + for id in ooi: # overwrite all agents + traj = source_data_dict["decoder/reconstructed_position"][:, :91, id, ] + traj_mask = source_data_dict["decoder/reconstructed_valid_mask"][:, :91, id] + theta = source_data_dict['decoder/reconstructed_heading'][:, :91, id] + vel = source_data_dict['decoder/reconstructed_velocity'][:, :91, id] + + new_data_dict["decoder/agent_position"][:, :, id, :2] = traj + new_data_dict["decoder/agent_position"][:, :, id, 2] = 0.0 + new_data_dict["decoder/agent_valid_mask"][:, :, id] = traj_mask + new_data_dict["decoder/agent_heading"][:, :, id] = theta + new_data_dict["decoder/agent_velocity"][:, :, id] = vel + + return new_data_dict + + +def detect_env_collision(contour_list1, mask1, lineString): + collision_detected = [] + + for i in range(len(contour_list1)): + if mask1[i]: + agent_poly = Polygon(contour_list1[i]) + + if agent_poly.intersects(lineString): + collision_detected.append(True) + else: + collision_detected.append(False) + else: + collision_detected.append(False) + + return collision_detected + + +from scenestreamer.dataset.preprocess_action_label import cal_polygon_contour +from shapely.geometry import Polygon + + +def get_dists(args_list, device): + return torch.stack( + [ + tf_to_torch(interaction_features.compute_distance_to_nearest_object(**args_list[k]), device=device) + for k in range(len(args_list)) + ] + ) + + +def build_collision_data(*, pred_data_dict, pred_shape, candidate_agents, evaluate_agents, z_values): + candidate_agents = 
def calc_collision(*, dists, valid_masks, T_context, T_gt):
    """Turn nearest-object distances into per-step collisions and a per-mode rate.

    A step counts as a collision when the distance is less than or equal to
    ``interaction_features.COLLISION_DISTANCE_THRESHOLD`` and the agent is
    valid at that step. Only the time window ``[T_context, T_gt)`` along the
    last axis is considered.

    Args:
        dists: Tensor of distances; the original comments give the layout as
            (B, N, T) with time on the last axis.
        valid_masks: Boolean tensor broadcastable against ``dists`` after the
            time slice, or a list/tuple of such tensors to be stacked.
        T_context: First time index (inclusive) of the evaluated window.
        T_gt: Last time index (exclusive) of the evaluated window.

    Returns:
        Tuple ``(collisions, mode_cr)``: the boolean per-step collision tensor
        for the window, and a 1-D tensor with, per mode, the fraction of valid
        agents that collide at least once.
    """
    # NOTE(review): the mask is sliced here with [T_context:T_gt]; callers that
    # already sliced it are only safe while T_context == 0 - confirm.
    if isinstance(valid_masks, (list, tuple)):
        valid_masks = torch.stack(valid_masks)

    collisions = torch.le(dists, interaction_features.COLLISION_DISTANCE_THRESHOLD)
    collisions = collisions[..., T_context:T_gt]
    valid_masks = valid_masks[..., T_context:T_gt]

    collisions = collisions & valid_masks  # Shape: (B, N, T)

    # Number of agents with at least one collision in the window.
    collisions_count = torch.any(collisions, dim=-1).double().sum(dim=-1)  # Shape: (B,)
    valid_agent_for_collision = torch.any(valid_masks, dim=-1)  # Shape: (B, N)

    # Ratio of colliding agents. The denominator is clamped so that a mode
    # without any valid agent yields 0 rather than NaN (0 / 0), which would
    # otherwise contaminate every metric accumulated afterwards.
    mode_cr = collisions_count / valid_agent_for_collision.sum(dim=-1).clamp(min=1)

    assert mode_cr.ndim == 1
    return collisions, mode_cr
def jsd(gt_hist, pred_hist, epsilon=1e-10):
    """Return the Jensen-Shannon divergence between two histograms.

    Both histograms are normalised to sum to one and smoothed by ``epsilon``
    so that empty bins do not produce ``log(0)``. With ``M = (P + Q) / 2`` the
    result is ``0.5 * KL(P || M) + 0.5 * KL(Q || M)`` in nats, i.e. it lies
    (up to the smoothing) in ``[0, ln 2]``.

    Args:
        gt_hist: Tensor of ground-truth bin counts.
        pred_hist: Tensor of predicted bin counts, same shape as ``gt_hist``.
        epsilon: Additive smoothing applied to each normalised bin.

    Returns:
        Scalar tensor holding the divergence. If either histogram sums to
        zero the normalisation divides by zero and the result is NaN.
    """
    gt_prob = gt_hist / gt_hist.sum() + epsilon
    pred_prob = pred_hist / pred_hist.sum() + epsilon
    mixture = 0.5 * (gt_prob + pred_prob)
    log_mixture = mixture.log()

    # F.kl_div(input, target) evaluates sum(target * (log(target) - input)),
    # i.e. KL(target || exp(input)). The mixture therefore has to be the
    # (log-space) ``input`` and the distribution the ``target``; passing them
    # the other way round computes KL(M || P), which is not the JSD.
    kl_gt = F.kl_div(log_mixture, gt_prob, reduction="sum")
    kl_pred = F.kl_div(log_mixture, pred_prob, reduction="sum")
    return 0.5 * (kl_gt + kl_pred)
agent between generated modes + # Xuanhao: In MixSim paper, they used squared norm of distance, but maybe they meant L2 norm not squared norm? + # Unit given in AdvDiffuser for FDD is m not m^2, so I am using L2 norm here + + # Distribution Realism + vel_jsd: float = 0.0 # avg over scenarios: build histogram across agents, modes, timestamps: velocity JS divergence + acc_jsd: float = 0.0 # avg over scenarios: build histogram across agents, modes, timestamps: acceleration JS divergence + ttc_jsd: float = 0.0 # avg over scenarios: build histogram across agents, modes, timestamps: time to collision JS divergence + + # Common Sense + env_coll_max: float = 0.0 # offroad + env_coll_min: float = 0.0 # offroad + env_coll_avg: float = 0.0 # offroad + + veh_coll_max: float = 0.0 # collision rate + veh_coll_min: float = 0.0 # collision rate + veh_coll_avg: float = 0.0 # collision rate + + # SDC-ADV coll + sdc_adv_coll_max: float = 0.0 # collision rate + sdc_adv_coll_min: float = 0.0 # collision rate + sdc_adv_coll_avg: float = 0.0 # collision rate + + sdc_bv_coll_max: float = 0.0 # collision rate + sdc_bv_coll_min: float = 0.0 # collision rate + sdc_bv_coll_avg: float = 0.0 # collision rate + + adv_bv_coll_max: float = 0.0 # collision rate + adv_bv_coll_min: float = 0.0 # collision rate + adv_bv_coll_avg: float = 0.0 # collision rate + + coll_vel_maxagent_avg: float = 0.0 # collision velocity max over agents, avg over modes + coll_vel_maxagent_max: float = 0.0 # collision velocity max over agents, max over modes + coll_vel_maxagent_min: float = 0.0 # collision velocity max over agents, min over modes + coll_vel_sdc_avg: float = 0.0 # collision velocity only for SDC + coll_vel_sdc_max: float = 0.0 # collision velocity only for SDC, max over modes + coll_vel_sdc_min: float = 0.0 # collision velocity only for SDC, min over modes + + # no clue what collision JSD means so not calculating it for now + + # AV comfortable + sdc_acc_maxtime_avg: float = 0.0 + sdc_acc_maxtime_min: float 
= 0.0 + sdc_acc_maxtime_max: float = 0.0 + sdc_acc_avgtime_avg: float = 0.0 + sdc_acc_avgtime_min: float = 0.0 + sdc_acc_avgtime_max: float = 0.0 + + sdc_jerk_maxtime_avg: float = 0.0 + sdc_jerk_maxtime_min: float = 0.0 + sdc_jerk_maxtime_max: float = 0.0 + sdc_jerk_avgtime_avg: float = 0.0 + sdc_jerk_avgtime_min: float = 0.0 + sdc_jerk_avgtime_max: float = 0.0 + + customized_max_sdc_adv_coll: float = 0.0 + customized_max_sdc_bv_coll: float = 0.0 + customized_max_adv_bv_coll: float = 0.0 + + customized_min_sdc_adv_coll: float = 0.0 + customized_min_sdc_bv_coll: float = 0.0 + customized_min_adv_bv_coll: float = 0.0 + + customized_avg_sdc_adv_coll: float = 0.0 + customized_avg_sdc_bv_coll: float = 0.0 + customized_avg_adv_bv_coll: float = 0.0 + + customized_avg_overall_coll: float = 0.0 + + customized_all_agent_coll: float = 0.0 + + def clean(self): + # If the entry is tensor, drop it to float. + for k, v in dataclasses.asdict(self).items(): + if isinstance(v, torch.Tensor): + setattr(self, k, v.item()) + + def aggregate(self): + self.clean() + + # Get all metrics + all_metrics = dataclasses.asdict(self) + for k, v in all_metrics.items(): + if k.startswith("coll_vel_sdc"): + if self.sdc_coll_scenario_count > 0: + all_metrics[k] = v / self.sdc_coll_scenario_count + else: + all_metrics[k] = torch.nan + + elif k.startswith("coll_vel_maxagent"): + if self.veh_coll_scenario_count > 0: + all_metrics[k] = v / self.veh_coll_scenario_count + else: + all_metrics[k] = torch.nan + + elif k != "scenario_count": + all_metrics[k] = v / self.scenario_count + return all_metrics + + +class Evaluator: + SECONDS_PER_STEP = 0.1 + + def __init__(self, CR_mode="mean", key_metrics_only=True, use_waymo=False): + assert CR_mode in ["min", "max", "mean"] + self.CR_mode = CR_mode + self.jsd_config = { + "vel": { + "min_val": 0.0, + "max_val": 50.0, + "num_bins": 100 + }, + "acc": { + "min_val": -10.0, + "max_val": 10.0, + "num_bins": 200 + }, + + # From WOSAC: 
def filter_static_agents(self, gt_data_dict, min_displacement=5):
    """Return a boolean mask marking agents that barely move in the ground truth.

    An agent is flagged as static when the Euclidean (x, y) distance between
    its first and its last valid ground-truth position is below
    ``min_displacement``. Agents without a single valid step have no
    measurable displacement and are flagged as static as well (previously
    they raised ``IndexError``).

    Args:
        gt_data_dict: Dict providing ``decoder/agent_id`` (agent indices),
            ``decoder/agent_position`` (indexed as ``[time, agent, xyz]``) and
            ``decoder/agent_valid_mask`` (indexed as ``[time, agent]``).
        min_displacement: Displacement threshold, in the units of the
            position field. Defaults to the previously hard-coded ``5``.

    Returns:
        Boolean tensor shaped like ``decoder/agent_id``; ``True`` = static.
    """
    agent_ids = gt_data_dict["decoder/agent_id"]
    positions = gt_data_dict["decoder/agent_position"]
    valid = gt_data_dict["decoder/agent_valid_mask"]

    mask = torch.zeros_like(agent_ids, dtype=torch.bool)

    for agent_id in agent_ids:
        idx = int(agent_id)
        traj = positions[:, idx][valid[:, idx], :2]

        if traj.shape[0] == 0:
            # Never observed: nothing to measure, treat as static.
            mask[idx] = True
            continue

        # Displacement between the first and the last valid position.
        dist = torch.norm(traj[0] - traj[-1], dim=-1)
        if dist < min_displacement:
            mask[idx] = True

    return mask
= vehicle_mask.unsqueeze(-1) & valid_mask # (num agents, num steps) + # gt_ooi_joint = ooi_mask.unsqueeze(-1) & gt_valid_mask[..., T_context:T_gt] # (num agents, num steps) + # pred_ooi_joint = ooi_mask[None, ..., None] & pred_valid_mask[..., T_context:T_gt] # (K, num agents, num steps) + + gt_ooi_joint = gt_valid_mask[..., T_context:T_gt] # (num agents, num steps) + pred_ooi_joint = pred_valid_mask[..., T_context:T_gt] # (K, num agents, num steps) + + gt_shape = numpy_to_torch( + gt_data_dict["decoder/current_agent_shape"][None], device=device + ).expand(T_gt, -1, -1).transpose(0, 1) + pred_shape = pred_data_dict["decoder/current_agent_shape"][:, None].expand(-1, T_pred, -1, -1).transpose(1, 2) + + sdc_index = int(gt_data_dict["decoder/sdc_index"]) + sdc_index_in_ooi = list(gt_data_dict["decoder/agent_id"]).index(sdc_index) + + # minSFDE + with timer("minSFDE"): + gt_pos = numpy_to_torch(gt_data_dict["decoder/agent_position"], device=device)[None, ..., :2] + pred_pos = pred_data_dict["decoder/reconstructed_position"][:, :T_gt] + + gt_valid = gt_ooi_joint[None] + gt_valid_skipped = gt_valid[:, :, ::5] + + last_valid_ind = gt_valid.cumsum(dim=-1).argmax(dim=-1) + last_valid_ind_skipped = gt_valid_skipped.cumsum(dim=-1).argmax(dim=-1) + + error = torch.linalg.norm(gt_pos - pred_pos, dim=-1) + assert error.ndim == 3 + assert error.shape[0] == B + + last_valid_ind = last_valid_ind.unsqueeze(0).expand(B, 1, N) + last_valid_ind_skipped = last_valid_ind_skipped.unsqueeze(0).expand(B, 1, N) + + assert last_valid_ind.shape == (B, 1, N) + fde = torch.gather(error, 1, last_valid_ind).squeeze(1) # shape: B, N + fde_skipped = torch.gather(error, 1, last_valid_ind_skipped * 5).squeeze(1) # shape: B, N + + assert fde.shape[0] == B + agent_valid = gt_valid.any(-1).expand(B, N) # shape: B, N + agent_valid_skipped = gt_valid_skipped.any(-1).expand(B, N) + + sfde = (fde * agent_valid).sum(-1) / agent_valid.sum(-1) + sfde_skipped = (fde_skipped * agent_valid_skipped).sum(-1) / 
agent_valid_skipped.sum(-1) + + gt_valid = gt_valid.permute(0, 2, 1) + gt_valid_skipped = gt_valid_skipped.permute(0, 2, 1) + + gt_valid_expand = gt_valid.expand(B, T_gt, N) + pred_ooi_joint_expand = pred_ooi_joint.permute(0, 2, 1) + sade_per_agent = (error * gt_valid).sum(1) / gt_valid.sum(1).clamp(1) + assert sade_per_agent.ndim == 2 + sade = (sade_per_agent * gt_valid.any(1)).sum(1) / gt_valid.any(1).sum(1) + assert sade.ndim == 1 + sade_skipped_per_agent = (error[:, ::5] * gt_valid_skipped).sum(1) / gt_valid_skipped.sum(1).clamp(1) + sade_skipped = sade_skipped_per_agent.sum(1) / gt_valid_skipped.any(1).sum(1) + + assert sfde.ndim == 1 + assert sfde.shape[0] == B + self.metrics.sfde_min += sfde.min() + self.metrics.sade_min += sade.min() + self.metrics.sfde_avg += sfde.mean() + self.metrics.sade_avg += sade.mean() + + self.metrics.skipped_sfde_min += sfde_skipped.min() + self.metrics.skipped_sade_min += sade_skipped.min() + self.metrics.skipped_sfde_avg += sfde_skipped.mean() + self.metrics.skipped_sade_avg += sade_skipped.mean() + + # Following wosac_eval, fill in z with GT t = 10 data + z_values = pred_data_dict["decoder/current_agent_position"][..., 2].unsqueeze(-1).expand(-1, -1, T_pred) + + # FDD + with timer("FDD"): + # there doesn't appear to be an easy way to do this with cartesian product + cur_FDD = None + pred_ooi_valid_mask = pred_valid_mask[:, ooi_mask] + single_mode_ooi_valid_mask = pred_ooi_valid_mask[0] + + # assert torch.all(torch.any(pred_ooi_valid_mask, dim=-1)) + last_valid_ind = pred_ooi_valid_mask.cumsum(dim=-1).argmax(dim=-1) # (K, N) + ooi_reconstructed_pos = pred_data_dict["decoder/reconstructed_position"][:, :, + ooi_mask] # (K, T_pred, N, 2) + last_valid_ind_reshaped = last_valid_ind[:, None, :, None].expand(-1, -1, -1, 2) + final_pos = torch.gather(ooi_reconstructed_pos, dim=1, index=last_valid_ind_reshaped).squeeze(1) + for i, j in itertools.product(range(K), range(K)): + final_dist = torch.linalg.norm(final_pos[i] - final_pos[j], 
dim=-1) + assert final_dist.ndim == 1 + if cur_FDD == None: + cur_FDD = final_dist + else: + cur_FDD = torch.maximum(cur_FDD, final_dist) + self.metrics.fdd += (cur_FDD * + single_mode_ooi_valid_mask.any(-1)).sum() / single_mode_ooi_valid_mask.any(-1).sum() + + def _add(pos, mask): + assert pos.ndim == 4 + assert pos.shape[0] == B + assert pos.shape[2] == N + assert pos.shape[3] == 2 + T = pos.shape[1] + pos = pos.reshape(B, -1, 2) + pos_NB = pos.swapaxes(0, 1) + dist_NBB = torch.cdist(pos_NB, pos_NB) + max_dist_N = dist_NBB.amax((1, 2)) + assert max_dist_N.shape == (N * T, ), max_dist_N.shape + max_dist_TN = max_dist_N.reshape(T, N) + assert mask.shape == max_dist_TN.shape, (mask.shape, max_dist_TN.shape) + avg_t = utils.masked_average(max_dist_TN, mask, dim=0) + assert avg_t.shape == (N, ) + return avg_t + + # add_full_all = utils.masked_average(_add(pred_pos, gt_valid[0]), agent_valid[0], dim=0) + add_skipped_all = utils.masked_average( + _add(ooi_reconstructed_pos, single_mode_ooi_valid_mask.swapaxes(0, 1)), + single_mode_ooi_valid_mask.any(-1), + dim=0 + ) + self.metrics.add += add_skipped_all + + with timer("Kinematic Metrics"): + gt_speed, gt_accel, gt_jerk = self._compute_kinematic_metrics( + gt_data_dict["decoder/agent_velocity"].swapaxes(1, 0), device + ) # (N, T) + pred_speed, pred_accel, pred_jerk = self._compute_kinematic_metrics( + pred_data_dict["decoder/reconstructed_velocity"].transpose(1, 2), device + ) # (K, N, T) + gt_speed = gt_speed[..., T_context:T_gt] + gt_accel = gt_accel[..., T_context:T_gt] + gt_jerk = gt_jerk[..., T_context:T_gt] + pred_speed = pred_speed[..., T_context:T_gt] + pred_accel = pred_accel[..., T_context:T_gt] + pred_jerk = pred_jerk[..., T_context:T_gt] + + if self.use_waymo: + candidate_agents = gt_data_dict["decoder/current_agent_valid_mask"] + if isinstance(candidate_agents, torch.Tensor): + candidate_agents = candidate_agents.cpu().numpy() + candidate_agents = candidate_agents.nonzero()[0] + + pred_veh_collisions, 
veh_cr_mode = calc_collision_rate( + candidate_agents=candidate_agents, + evaluate_agents=gt_data_dict["decoder/agent_id"], + pred_data_dict=pred_data_dict, + pred_shape=pred_shape, + device=device, + T_gt=T_gt, + T_context=T_context, + z_values=z_values + ) + + assert veh_cr_mode.shape[0] == B + self.metrics.veh_coll_avg += veh_cr_mode.mean() + self.metrics.veh_coll_min += veh_cr_mode.min() + self.metrics.veh_coll_max += veh_cr_mode.max() + + if adv_list is not None: + assert len(adv_list) == 1 + assert adv_list[0] not in bv_list + assert sdc_index not in adv_list + + # self.sdc_coll_adv_active = True + for kk in adv_list: + assert int(kk.item()) in candidate_agents + + adv_sdc_coll, adv_sdc_coll_rate = calc_collision_rate( + candidate_agents=adv_list, + evaluate_agents=pred_data_dict["decoder/sdc_index"], + pred_data_dict=pred_data_dict, + pred_shape=pred_shape, + device=device, + T_gt=T_gt, + T_context=T_context, + z_values=z_values + ) + + assert adv_sdc_coll_rate.ndim == 1 + self.metrics.sdc_adv_coll_avg += adv_sdc_coll_rate.mean() + self.metrics.sdc_adv_coll_min += adv_sdc_coll_rate.min() + self.metrics.sdc_adv_coll_max += adv_sdc_coll_rate.max() + + adv_bv_coll, adv_bv_coll_rate = calc_collision_rate( + candidate_agents=bv_list, + evaluate_agents=adv_list, + pred_data_dict=pred_data_dict, + pred_shape=pred_shape, + device=device, + T_gt=T_gt, + T_context=T_context, + z_values=z_values + ) + assert adv_bv_coll_rate.ndim == 1 + + assert adv_bv_coll_rate.shape[0] == B + self.metrics.adv_bv_coll_avg += adv_bv_coll_rate.mean() + self.metrics.adv_bv_coll_min += adv_bv_coll_rate.min() + self.metrics.adv_bv_coll_max += adv_bv_coll_rate.max() + + assert sdc_index not in bv_list + assert bv_list is not None + for kk in bv_list: + assert int(kk.item()) in candidate_agents + sdc_bv_coll, sdc_bv_coll_rate = calc_collision_rate( + candidate_agents=bv_list, + evaluate_agents=pred_data_dict["decoder/sdc_index"], + pred_data_dict=pred_data_dict, + pred_shape=pred_shape, + 
device=device, + T_gt=T_gt, + T_context=T_context, + z_values=z_values + ) + assert sdc_bv_coll_rate.ndim == 1 + + assert sdc_bv_coll_rate.shape[0] == B + self.metrics.sdc_bv_coll_avg += sdc_bv_coll_rate.mean() + self.metrics.sdc_bv_coll_min += sdc_bv_coll_rate.min() + self.metrics.sdc_bv_coll_max += sdc_bv_coll_rate.max() + + # map_feature = gt_data_dict["encoder/map_feature"] + map_feature = gt_data_dict["vis/map_feature"] + assert map_feature.ndim == 3 # This is unbatched. + + road_edges = [] + for i in range(map_feature.shape[0]): + # For each map feature + if map_feature[i, 0, 15] == 1: + map_feat = [] + for j in range(map_feature.shape[1]): + if gt_data_dict['encoder/map_feature_valid_mask'][i, j]: + map_feat.append( + map_pb2.MapPoint( + x=map_feature[i, j, 0], + y=map_feature[i, j, 1], + z=0 + # map_feature[i, j, 2] # let's say there is no z axis any more + ) + ) + road_edges.append(map_feat) + + eval_mask = ooi_mask & (~static_agent_mask) + env_nearest_distances = torch.stack( + [ + tf_to_torch( + map_metric_features.compute_distance_to_road_edge( + center_x=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 0].T), + center_y=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 1].T), + center_z=conv(z_values[k]), + length=conv(pred_shape[k, ..., 0]), + width=conv(pred_shape[k, ..., 1]), + height=conv(pred_shape[k, ..., 2]), + heading=conv(pred_data_dict["decoder/reconstructed_heading"][k].T), + valid=conv(pred_valid_mask[k], dtype=tf.bool), + evaluated_object_mask=conv(eval_mask, dtype=tf.bool), + road_edge_polylines=road_edges, + ), + device=device + ) for k in range(K) + ] + ) + + pred_valid_mask = pred_valid_mask[..., T_context:T_gt] + + pred_env_collisions = torch.greater(env_nearest_distances, map_metric_features.OFFROAD_DISTANCE_THRESHOLD) + pred_env_collisions = pred_env_collisions[..., T_context:T_gt] + pred_env_collisions_traj_level = pred_env_collisions.any(dim=-1) # (B, num_ooi) + # Avg over agent dim. 
Here we assume all evaluated agents are valid so don't do the masked_avg + env_collision_rate = pred_env_collisions_traj_level.float().mean(-1) + + # ==================================== customized env collision rate ==================================== + # env_collision_rate = np.array(pred_env_collisions_traj_level).mean(-1) + assert env_collision_rate.ndim == 1 + + assert env_collision_rate.shape[0] == B + self.metrics.env_coll_avg += env_collision_rate.mean() + self.metrics.env_coll_min += env_collision_rate.min() + self.metrics.env_coll_max += env_collision_rate.max() + + step_wise_collision = pred_veh_collisions + scenario_has_collision = torch.any(step_wise_collision).item() + + if scenario_has_collision: + speed_when_collision = torch.where(step_wise_collision, pred_speed[:, ooi_mask], 0) + + coll_vel_max_agent = speed_when_collision.amax(dim=(-1, -2)) + coll_valid_mask = speed_when_collision.any(-1).any(-1) + coll_vel_max_agent = coll_vel_max_agent[coll_valid_mask] + + assert coll_vel_max_agent.ndim == 1 + + if coll_vel_max_agent.numel() != 0: + self.metrics.coll_vel_maxagent_avg += (coll_vel_max_agent).sum() / coll_valid_mask.sum().clamp(1) + self.metrics.coll_vel_maxagent_min += coll_vel_max_agent.min() + self.metrics.coll_vel_maxagent_max += coll_vel_max_agent.max() + + self.metrics.veh_coll_scenario_count += 1 + + sdc_speed = pred_speed[:, sdc_index] + sdc_coll = step_wise_collision[:, sdc_index_in_ooi] + + # sdc_speed_when_coll = torch.where(sdc_coll, sdc_speed, torch.nan).amax(-1) + sdc_speed_when_coll = torch.where(sdc_coll, sdc_speed, torch.nan) + valid_mask = ~torch.isnan(sdc_speed_when_coll) + + if torch.any(sdc_coll).item(): # if there is valid collision + self.metrics.coll_vel_sdc_avg += (sdc_speed_when_coll[valid_mask]).sum() / valid_mask.sum() + self.metrics.coll_vel_sdc_max += (sdc_speed_when_coll[valid_mask]).max() + self.metrics.coll_vel_sdc_min += (sdc_speed_when_coll[valid_mask]).min() + self.metrics.sdc_coll_scenario_count += 1 + assert 
sdc_speed_when_coll.shape[0] == B + + gt_ttc = tf_to_torch( + interaction_features.compute_time_to_collision_with_object_in_front( + center_x=conv(gt_data_dict["decoder/agent_position"][..., 0].T), + center_y=conv(gt_data_dict["decoder/agent_position"][..., 1].T), + length=conv(gt_shape[..., 0]), + width=conv(gt_shape[..., 1]), + heading=conv(gt_data_dict["decoder/agent_heading"].T), + valid=conv(gt_valid_mask, dtype=tf.bool), + evaluated_object_mask=conv(ooi_mask, dtype=tf.bool), + seconds_per_step=self.SECONDS_PER_STEP + ), + device=device + ) + pred_ttc = torch.stack( + [ + tf_to_torch( + interaction_features.compute_time_to_collision_with_object_in_front( + center_x=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 0].T), + center_y=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 1].T), + length=conv(pred_shape[k, ..., 0]), + width=conv(pred_shape[k, ..., 1]), + heading=conv(pred_data_dict["decoder/reconstructed_heading"][k].T), + valid=conv(pred_data_dict["decoder/reconstructed_valid_mask"][k].T, dtype=tf.bool), + evaluated_object_mask=conv(ooi_mask, dtype=tf.bool), + seconds_per_step=self.SECONDS_PER_STEP + ), + device=device + ) for k in range(K) + ] + ) + gt_ttc = gt_ttc[..., T_context:T_gt] + pred_ttc = pred_ttc[..., T_context:T_gt] + + sdc_acc = torch.abs(torch.nan_to_num(pred_accel[:, sdc_index])) # Shape: (K, T) + sdc_mask = pred_valid_mask[:, sdc_index][:, T_context:T_gt] + sdc_acc_avgt = (sdc_acc * sdc_mask).sum(-1) / sdc_mask.sum(-1).clamp(1) + assert sdc_acc_avgt.ndim == 1 + assert sdc_acc_avgt.shape[0] == B + self.metrics.sdc_acc_avgtime_max += sdc_acc_avgt.max() + self.metrics.sdc_acc_avgtime_avg += sdc_acc_avgt.mean() + self.metrics.sdc_acc_avgtime_min += sdc_acc_avgt.min() + + sdc_acc_maxt = sdc_acc.amax(-1) + assert sdc_acc_maxt.ndim == 1 + assert sdc_acc_maxt.shape[0] == B + self.metrics.sdc_acc_maxtime_max += sdc_acc_maxt.max() + self.metrics.sdc_acc_maxtime_avg += sdc_acc_maxt.mean() + 
self.metrics.sdc_acc_maxtime_min += sdc_acc_maxt.min() + + sdc_jerk = torch.abs(torch.nan_to_num(pred_jerk[:, sdc_index])) # Shape: (K, T) + sdc_mask = pred_valid_mask[:, sdc_index][:, T_context:T_gt] + sdc_jerk_avgt = (sdc_jerk * sdc_mask).sum(-1) / sdc_mask.sum(-1).clamp(1) + assert sdc_jerk_avgt.ndim == 1 + assert sdc_jerk_avgt.shape[0] == B + self.metrics.sdc_jerk_avgtime_max += sdc_jerk_avgt.max() + self.metrics.sdc_jerk_avgtime_avg += sdc_jerk_avgt.mean() + self.metrics.sdc_jerk_avgtime_min += sdc_jerk_avgt.min() + + sdc_jerk_maxt = sdc_jerk.amax(-1) + assert sdc_jerk_maxt.ndim == 1 + assert sdc_jerk_maxt.shape[0] == B + self.metrics.sdc_jerk_maxtime_max += sdc_jerk_maxt.max() + self.metrics.sdc_jerk_maxtime_avg += sdc_jerk_maxt.mean() + self.metrics.sdc_jerk_maxtime_min += sdc_jerk_maxt.min() + + with timer("Histograms"): + gt_speed_hist, gt_speed_bins = torch.histogram( + torch.clip( + gt_speed[gt_ooi_joint & ~gt_speed.isnan()], self.jsd_config["vel"]["min_val"], + self.jsd_config["vel"]["max_val"] + ).cpu(), + self.jsd_config["vel"]["num_bins"], + density=False + ) + # .cpu() since histogram doesn't support cuda backend + pred_speed_hist, pred_speed_bins = torch.histogram( + torch.clip( + pred_speed[pred_ooi_joint & ~pred_speed.isnan()], self.jsd_config["vel"]["min_val"], + self.jsd_config["vel"]["max_val"] + ).cpu(), + self.jsd_config["vel"]["num_bins"], + density=False + ) + gt_accel_hist, gt_accel_bins = torch.histogram( + torch.clip( + gt_accel[gt_ooi_joint & ~gt_accel.isnan()], self.jsd_config["acc"]["min_val"], + self.jsd_config["acc"]["max_val"] + ).cpu(), + self.jsd_config["acc"]["num_bins"], + density=False + ) + pred_accel_hist, pred_accel_bins = torch.histogram( + torch.clip( + pred_accel[pred_ooi_joint & ~pred_accel.isnan()], self.jsd_config["acc"]["min_val"], + self.jsd_config["acc"]["max_val"] + ).cpu(), + self.jsd_config["acc"]["num_bins"], + density=False + ) + + if self.use_waymo: + gt_ttc_hist, gt_ttc_bins = torch.histogram( + torch.clip( 
+ gt_ttc[gt_valid_mask[ooi_mask, T_context:T_gt] & ~gt_ttc.isnan()], + self.jsd_config["ttc"]["min_val"], self.jsd_config["ttc"]["max_val"] + ).cpu(), + self.jsd_config["ttc"]["num_bins"], + density=False + ) + pred_ttc_hist, pred_ttc_bins = torch.histogram( + torch.clip( + pred_ttc[pred_valid_mask[:, ooi_mask, T_context:T_gt] & ~pred_ttc.isnan()], + self.jsd_config["ttc"]["min_val"], self.jsd_config["ttc"]["max_val"] + ).cpu(), + self.jsd_config["ttc"]["num_bins"], + density=False + ) + + with timer("JSD"): + speed_jsd = jsd(gt_speed_hist, pred_speed_hist) + acc_jsd = jsd(gt_accel_hist, pred_accel_hist) + self.metrics.vel_jsd += speed_jsd + self.metrics.acc_jsd += acc_jsd + if self.use_waymo: + ttc_jsd = jsd(gt_ttc_hist, pred_ttc_hist) + self.metrics.ttc_jsd += ttc_jsd + + def _compute_kinematic_metrics(self, vel, device): + if type(vel) == np.ndarray: + vel = numpy_to_torch(vel, device=device) + speed = torch.linalg.norm(vel, axis=-1) + accel = self._central_diff(speed, device, pad_value=torch.nan) / self.SECONDS_PER_STEP + jerk = self._central_diff(accel, device, pad_value=torch.nan) / self.SECONDS_PER_STEP + return speed, accel, jerk + + def _central_diff(self, tensor, device, pad_value=torch.nan): + pad_shape = (*tensor.shape[:-1], 1) + pad_tensor = torch.ones(pad_shape, device=device) * pad_value + diff_t = (tensor[..., 2:] - tensor[..., :-2]) / 2 + return torch.cat([pad_tensor, diff_t, pad_tensor], dim=-1) + + def add_customized_CR( + self, + max_sdc_adv_cr=None, + max_sdc_bv_cr=None, + max_adv_bv_cr=None, + min_sdc_adv_cr=None, + min_sdc_bv_cr=None, + min_adv_bv_cr=None, + avg_sdc_adv_cr=None, + avg_sdc_bv_cr=None, + avg_adv_bv_cr=None, + all_agent_cr=None + ): + if max_sdc_adv_cr is not None: + self.metrics.customized_max_sdc_adv_coll += max_sdc_adv_cr + if max_sdc_bv_cr is not None: + self.metrics.customized_max_sdc_bv_coll += max_sdc_bv_cr + + if max_adv_bv_cr is not None: + self.metrics.customized_max_adv_bv_coll += max_adv_bv_cr + + if min_sdc_adv_cr 
is not None: + self.metrics.customized_min_sdc_adv_coll += min_sdc_adv_cr + if min_sdc_bv_cr is not None: + self.metrics.customized_min_sdc_bv_coll += min_sdc_bv_cr + if min_adv_bv_cr is not None: + self.metrics.customized_min_adv_bv_coll += min_adv_bv_cr + + if avg_sdc_adv_cr is not None: + self.metrics.customized_avg_sdc_adv_coll += avg_sdc_adv_cr + if avg_sdc_bv_cr is not None: + self.metrics.customized_avg_sdc_bv_coll += avg_sdc_bv_cr + if avg_adv_bv_cr is not None: + self.metrics.customized_avg_adv_bv_coll += avg_adv_bv_cr + + if all_agent_cr is not None: + self.metrics.customized_all_agent_coll += all_agent_cr + + def aggregate(self): + return self.metrics.aggregate() + + def print(self): + metrics = self.metrics.aggregate() + print("\n=====================================") + print("Evaluation Metrics:") + print(utils.pretty_print(metrics)) + print("=====================================") + return metrics + + def save(self, save_path=None): + if save_path is None: + save_path = "evaluation_results" + + metrics = self.metrics.aggregate() + metrics["save_path"] = save_path + + # Save a json: + import json + json_file = save_path + ".json" + with open(json_file, "w") as f: + json.dump(metrics, f, indent=4) + print(f"Saved metrics to {json_file}") + + # Save a csv: + import pandas as pd + df = pd.DataFrame([metrics]) + csv_file = save_path + ".csv" + df.to_csv(csv_file, index=False) + print(f"Saved metrics to {csv_file}") + + return metrics + + +class TurnAction: + STOP = 0 + KEEP_STRAIGHT = 1 + TURN_LEFT = 2 + TURN_RIGHT = 3 + U_TURN = 4 + + num_actions = 5 + + +class AccelerationAction: + STOP = 0 + KEEP_SPEED = 1 + SPEED_UP = 2 + SLOW_DOWN = 3 + + num_actions = 4 + + +class SafetyAction: + SAFE = 0 + COLLISION = 1 + num_actions = 2 + + +def detect_collision(contour_list1, mask1, contour_list2, mask2): + collision_detected = [] + + contour_list1, len1 = contour_list1 + contour_list2, len2 = contour_list2 + + assert len(contour_list1) == len(contour_list2) + + 
for i in range(len(contour_list1)): + if mask1[i] and mask2[i]: + pos1 = contour_list1[i].mean(0) + pos2 = contour_list2[i].mean(0) + dist = np.linalg.norm(pos1 - pos2) + + # PZH: Actually the largest possible distance is sqrt(2)/2*(len1 + len2) + # We relax it to (len1+len2) + if dist > (len1 + len2): + collision_detected.append(False) + continue + + poly1 = Polygon(contour_list1[i]) + poly2 = Polygon(contour_list2[i]) + + if poly1.intersects(poly2): + collision_detected.append(True) + else: + collision_detected.append(False) + else: + collision_detected.append(False) + + return collision_detected + + +def get_2D_collision_labels(data_dict, track_agent_indicies): + # Now, instead of getting 1d-array of collision labels, let's do 2-d array to detect whether there is collision between given two agents. + + safety_actions = torch.zeros((track_agent_indicies.shape[0], track_agent_indicies.shape[0]), dtype=int) # plus sdc + + contours = [] + for agent1_id in track_agent_indicies: + traj = data_dict["decoder/agent_position"][:91, agent1_id, :] # (91, 3) + length = data_dict["decoder/agent_shape"][10, agent1_id, 0] + width = data_dict["decoder/agent_shape"][10, agent1_id, 1] + theta = data_dict['decoder/agent_heading'][:91, agent1_id] # (91, ) # in pi + mask = data_dict['decoder/agent_valid_mask'][:91, agent1_id] # (91,) + poly = cal_polygon_contour(traj[:, 0], traj[:, 1], theta, width, length) + contours.append((poly, length)) + + for i in range(track_agent_indicies.shape[0] - 1): + for j in range(i + 1, track_agent_indicies.shape[0]): + mask_1 = data_dict['decoder/agent_valid_mask'][:91, track_agent_indicies[i]] # (91,) + mask_2 = data_dict['decoder/agent_valid_mask'][:91, track_agent_indicies[j]] + collision_detected = detect_collision(contours[i], mask_1, contours[j], mask_2) + + if any(collision_detected): + # print(f"Collision between {i} and {j} happen at step: {np.array(collision_detected).nonzero()}") + safety_actions[i][j] = 1 # Label collisions for OOIs now. 
Later we will build a larger dict. + safety_actions[j][i] = 1 # Label collisions for OOIs now. Later we will build a larger dict. + + assert np.array_equal(safety_actions, safety_actions.T), "The 2D label is not symmetrical" + return safety_actions + + +def _get_mode(output_dict, mode, num_modes): + ret = {} + for k, v in output_dict.items(): + if isinstance(v, np.ndarray) and len(v) == num_modes: + ret[k] = v[mode] + else: + ret[k] = v + return ret + + +class EvaluationLightningModule(pl.LightningModule): + def __init__( + self, + model, + evaluator: Evaluator, + tokenizer, + config, + # dataset, + autoregressive_start_step, + num_modes=1, + save_path=None, + use_waymo=False + ): + super().__init__() + self.model = model.to("cuda" if torch.cuda.is_available() else "cpu") + self.evaluator = evaluator + self.tokenizer = tokenizer + self.config = config + # self.dataset = dataset + self.num_modes = num_modes + self.cat_summary = None + self.baseline_summary = None + self.adv_index = None + self.sid = None + self.save_path = save_path + self.autoregressive_start_step = autoregressive_start_step + assert save_path is not None, "Please specify the save path for the evaluation results." 
+ + def GPT_AR(self, input_data, backward_prediction=False, teacher_forcing=False): + assert not teacher_forcing + assert not backward_prediction + autoregressive_start_step = self.autoregressive_start_step + from scenestreamer.infer.motion import generate_motion + return generate_motion( + data_dict=input_data, + model=self.model.model, + autoregressive_start_step=autoregressive_start_step, + allow_newly_added_agent_step=2, + teacher_forcing_sdc=False + ) + + def preprocess_GPTmodel(self, raw_data, backward_prediction=False): + input_data = utils.numpy_to_torch(raw_data, device=self.model.device) + input_data["in_evaluation"] = torch.tensor([1], dtype=bool).to(self.model.device) + + input_data = { + # k: utils.expand_for_modes(v, num_modes=self.num_modes) if isinstance(v, torch.Tensor) else v + k: utils.expand_for_modes(v.unsqueeze(0), num_modes=self.num_modes) if isinstance(v, torch.Tensor) else v + for k, v in input_data.items() + } + + # Force to run backward prediction first to make sure the data is tokenized correctly!!! 
+ tok_data_dict, _ = self.tokenizer.tokenize(input_data, backward_prediction=backward_prediction) + input_data.update(tok_data_dict) + + if not backward_prediction: # handle backward flag + if self.config.BACKWARD_PREDICTION: + input_data["in_backward_prediction"] = torch.tensor( + [False] * self.num_modes, dtype=bool + ).to(self.model.device) + else: + input_data["in_backward_prediction"] = torch.tensor( + [True] * self.num_modes, dtype=bool + ).to(self.model.device) + + return input_data + + def validation_step(self, batch, batch_idx): + + data_dict = copy.deepcopy(batch) + input_data = numpy_to_torch(data_dict, device=self.model.device) + original_data_dict_tensor = copy.deepcopy(input_data) + + input_data = self.preprocess_GPTmodel(batch) + + with torch.no_grad(): + output_data = self.GPT_AR(input_data) + + gathered_output = output_data + + avg_sdc_adv_cr, avg_sdc_bv_cr, avg_adv_bv_cr, all_agent_cr = self.calculate_collision_statistics( + output_data, + is_CAT_data=False, + ) + self.evaluator.add_customized_CR( + avg_sdc_adv_cr=avg_sdc_adv_cr, + avg_adv_bv_cr=avg_adv_bv_cr, + avg_sdc_bv_cr=avg_sdc_bv_cr, + all_agent_cr=all_agent_cr + ) + + all_agents = batch["decoder/agent_id"] # prepare parameters for differet CR metrics + sdc_id = batch["decoder/sdc_index"] + all_agents_except_sdc = all_agents[all_agents != sdc_id] + self.evaluator.add( + original_data_dict_tensor, + gathered_output, + adv_list=None, + bv_list=all_agents_except_sdc, + device=self.device + ) + + return gathered_output + + def on_test_epoch_end(self): + self.trainer.strategy.barrier() # ensure all processes are done with evaluation + if self.trainer.is_global_zero: + self.evaluator.print() + self.evaluator.save(self.save_path) + + def on_validation_epoch_end(self): + self.trainer.strategy.barrier() # ensure all processes are done with evaluation + if self.trainer.is_global_zero: + self.evaluator.print() + self.evaluator.save(self.save_path) + + def configure_optimizers(self): + # No optimizer 
required for evaluation + return None + + def calculate_collision_statistics(self, output_data, cr_mode="avg", is_CAT_data=False): + + ooi_ind = output_data["decoder/agent_id"][0] # ooi is all agent + # from scenestreamer.dataset.preprocess_action_label import get_2D_collision_labels + + output_data_all_modes = { + k: (v.cpu().numpy() if isinstance(v, torch.Tensor) else v) + for k, v in output_data.items() + } + + output_data_all_modes = _overwrite_datadict_all_agents( + source_data_dict=output_data_all_modes, dest_data_dict=output_data_all_modes + ) # overwrite pred to GT + + num_modes = self.num_modes + + sdc_adv_col = 0 + sdc_bv_col = 0 + adv_bv_col = 0 + + avg_sdc_adv_col = 0 + avg_sdc_bv_col = 0 + avg_adv_bv_col = 0 + + avg_all_agent_cr = 0 + num_bv_agent = 0 + + for i in range(num_modes): + output_dict_mode = _get_mode(output_data_all_modes, i, num_modes=num_modes) + + col_label = get_2D_collision_labels(data_dict=output_dict_mode, track_agent_indicies=ooi_ind) + + sdc_index = 0 + adv_index = self.adv_index # value is None for eval_mode = GPTmodel + + if adv_index is not None and col_label[sdc_index][adv_index]: + sdc_adv_col += 1 + + for agent_id in ooi_ind: + if agent_id == adv_index or agent_id == sdc_index: + continue + + if col_label[sdc_index][agent_id]: + sdc_bv_col += 1 + + if adv_index is not None and col_label[adv_index][agent_id]: + adv_bv_col += 1 + + avg_sdc_adv_col += sdc_adv_col + avg_sdc_bv_col += sdc_bv_col + avg_adv_bv_col += adv_bv_col + avg_all_agent_cr += np.sum(np.triu(col_label, k=1)) / ooi_ind.shape[0] + + num_bv_agent += ooi_ind.shape[0] - 1 # only sdc no adv + + if num_bv_agent > 0: + avg_sdc_bv_cr = avg_sdc_bv_col / (num_modes * num_bv_agent) + avg_adv_bv_cr = avg_adv_bv_col / (num_modes * num_bv_agent) + + else: + avg_sdc_bv_cr = None + avg_adv_bv_cr = None + + avg_sdc_adv_cr = avg_sdc_adv_col / num_modes + avg_all_agent_cr = avg_all_agent_cr / num_modes + + return avg_sdc_adv_cr, avg_sdc_bv_cr, avg_adv_bv_cr, avg_all_agent_cr + + 
+if __name__ == '__main__': + from pytorch_lightning import Trainer + from scenestreamer.utils import utils + from scenestreamer.dataset.dataset import SceneStreamerDataset + + pl_model = utils.get_model( + huggingface_repo="pengzhenghao97/scenestreamer_0301", + huggingface_file="0228_MidGPT_V19_WTG_addstep_2025-02-28_epoch=14-step=426133.ckpt" + ) + device = pl_model.device + config = pl_model.config + config.DATA.TRAINING_DATA_DIR = "data/20scenarios" + config.PREPROCESSING.keep_all_data = True + + exp_name = "0307_arstep2_yuxin500" + autoregressive_start_step = 2 + limit_test_batches = 5000000 + use_waymo = False + # config.DATA.TEST_DATA_DIR = "data/20scenarios" + # config.DATA.TEST_DATA_DIR = "/data/datasets/scenarionet/waymo/validation" + config.DATA.TEST_DATA_DIR = "/bigdata/yuxin/scenarionet_waymo_training_500" + + num_modes = 6 + save_path = "{}_open_loop_results".format(exp_name) + test_bs = 1 + + tokenizer = pl_model.model.tokenizer + evaluator = Evaluator(key_metrics_only=False, use_waymo=use_waymo) + + from scenestreamer.dataset.datamodule import SceneStreamerDataModule + dataset = SceneStreamerDataset(config, "test") + dataloader = DataLoader(dataset, batch_size=test_bs, collate_fn=lambda x: x[0]) + + evaluation_module = EvaluationLightningModule( + pl_model, + evaluator, + tokenizer, + config, + # dataset, + num_modes=num_modes, + save_path=save_path, + autoregressive_start_step=autoregressive_start_step, + ) + trainer = Trainer(limit_test_batches=limit_test_batches) + + # datamodule = SceneStreamerDataModule( + # config, + # train_batch_size=1, + # train_num_workers=0, + # train_prefetch_factor=0, + # val_batch_size=1, + # val_num_workers=0, + # val_prefetch_factor=0, + # ) + # datamodule.setup("") + # dataloader = datamodule.val_dataloader() + + trainer.validate(evaluation_module, dataloaders=dataloader) diff --git a/scenestreamer/eval/lmdb_evaluator.py b/scenestreamer/eval/lmdb_evaluator.py new file mode 100644 index 
0000000000000000000000000000000000000000..9644c8557f1b4db56c6d7141d74ac034f0c82f09 --- /dev/null +++ b/scenestreamer/eval/lmdb_evaluator.py @@ -0,0 +1,203 @@ +"""This is a fake evaluator which save the data to LMDB dataset.""" +""" +Script to generate submission files for Waymo SimAgent Challenge. +Please check out the end of this file where we provide a script to merge submission files. +""" +import copy +import os +import pathlib +import uuid +import io +import numpy as np +import torch + +from scenestreamer.dataset.preprocessor import centralize_to_map_center +from scenestreamer.eval.waymo_motion_prediction_evaluator import _repeat_for_modes +from scenestreamer.eval.wosac_eval import wosac_evaluation +from scenestreamer.tokenization import get_tokenizer +from scenestreamer.utils import wrap_to_pi, rotate + +import json +import os +import pathlib +import pickle +import multiprocessing as mp +from functools import partial +import tqdm +import hydra +import lmdb +import omegaconf +import tqdm + +from scenestreamer.dataset.dataset import SceneStreamerDataset + +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent + + +class LMDBBulkWriter: + def __init__(self, base_path, max_size=1e9): + """ + Initializes the LMDBBulkWriter to save all data in batches, with map_size for each LMDB file. + Args: + base_path: Directory path to save LMDB files. + max_size: Maximum size of each LMDB file in bytes. 
+ """ + self.base_path = base_path + # Create the cache directory if it doesn't exist + os.makedirs(self.base_path, exist_ok=True) + + self.max_size = int(max_size) # Set the max LMDB file size (e.g., 1 GB) + self.current_db_index = 0 + self.lookup = {} # Lookup table to track which LMDB file stores which sample + self.current_db = self._open_new_lmdb(self.current_db_index) + self.per_shard_size = 0 + + self.sample_buffer = [] + + def _data_to_npz_bytes(self, data): + """Convert data to .npz compressed bytes.""" + with io.BytesIO() as buffer: + np.savez_compressed(buffer, **data) # Save dictionary elements as separate arrays in .npz + return buffer.getvalue() # Retrieve the bytes + + def _open_new_lmdb(self, db_index): + """Opens a new LMDB file for saving samples.""" + db_path = f"{self.base_path}/data_{db_index}.lmdb" + return lmdb.open(db_path, map_size=self.max_size) + + def _save_a_batch(self): + + try: + print(f"Saving {len(self.sample_buffer)} samples to data_{self.current_db_index}.lmdb") + with self.current_db.begin(write=True) as txn: + for key, data in self.sample_buffer: + npz_bytes = self._data_to_npz_bytes(data) + txn.put(key.encode('ascii'), npz_bytes) + self.lookup[key] = str((self.base_path / f"data_{self.current_db_index}.lmdb").absolute().resolve()) + + self.sample_buffer.clear() + + except lmdb.MapFullError: + + # If current LMDB file is full, create a new one and retry saving + self.current_db.close() + self.current_db_index += 1 + print(f"Creating new LMDB file: data_{self.current_db_index}.lmdb (size: {self.per_shard_size})") + self.current_db = self._open_new_lmdb(self.current_db_index) + self._save_a_batch() + self.per_shard_size = 0 + + def save_sample(self, key, data): + """Saves a sample to the current LMDB file, switching to a new file if necessary.""" + # Batch writes into a single transaction + if self.per_shard_size % 100 == 0: + self._save_a_batch() + self.sample_buffer.append((key, data)) + self.per_shard_size += 1 + + def 
close(self): + self._save_a_batch() + """Closes the LMDB environment and saves the lookup table as a JSON file.""" + self.current_db.close() + # Save the lookup table to track the LMDB file where each sample is stored + with open(f"{self.base_path}/lookup.json", "w") as f: + json.dump(self.lookup, f) + + +def transform_to_global_coordinate(data_dict): + map_center = data_dict["metadata/map_center"].reshape(-1, 1, 1, 3) + map_heading = data_dict["metadata/map_heading"].reshape(-1, 1, 1) + B, T, N, _ = data_dict["decoder/agent_position"].shape + map_heading = map_heading.repeat(T, axis=1).repeat(N, axis=2) + assert map_heading.shape == (B, T, N) + data_dict["decoder/agent_position"] = rotate( + x=data_dict["decoder/agent_position"][..., 0], + y=data_dict["decoder/agent_position"][..., 1], + angle=map_heading, + z=data_dict["decoder/agent_position"][..., 2] + ) + assert data_dict["decoder/agent_position"].ndim == 4 + data_dict["decoder/agent_position"] += map_center + + data_dict["decoder/agent_heading"] = wrap_to_pi(data_dict["decoder/agent_heading"] + map_heading) + + data_dict["decoder/agent_velocity"] = rotate( + x=data_dict["decoder/agent_velocity"][..., 0], + y=data_dict["decoder/agent_velocity"][..., 1], + angle=map_heading, + ) + + data_dict["pred_trajs"] = [ + centralize_to_map_center( + traj, map_center=-data_dict["expanded_map_center"][b], map_heading=-data_dict["expanded_map_heading"][b] + ) for b, traj in enumerate(data_dict["pred_trajs"]) + ] + + return data_dict + + +scenario_metrics_keys = [ + # 'scenario_id', + 'metametric', + 'average_displacement_error', + 'min_average_displacement_error', + 'linear_speed_likelihood', + 'linear_acceleration_likelihood', + 'angular_speed_likelihood', + 'angular_acceleration_likelihood', + 'distance_to_nearest_object_likelihood', + 'collision_indication_likelihood', + 'time_to_collision_likelihood', + 'distance_to_road_edge_likelihood', + 'offroad_indication_likelihood' +] + +aggregate_metrics_keys = [ + 
'realism_meta_metric', 'kinematic_metrics', 'interactive_metrics', 'map_based_metrics', 'min_ade' +] + + +def scenario_metrics_to_dict(scenario_metrics): + return {k: getattr(scenario_metrics, k) for k in scenario_metrics_keys} + + +def aggregate_metrics_to_dict(aggregate_metrics): + return {k: getattr(aggregate_metrics, k) for k in aggregate_metrics_keys} + + +import msgpack + + +class LMDBEvaluator: + def __init__(self, config): + self.config = config + + self.writer = None + + def validation_step( + self, data_dict, batch_idx, model, log_dict_func, global_rank, logger, lightning_model, **kwargs + ): + + if self.writer is None: + cache_folder = REPO_ROOT / self.config.DATA.TEST_DATA_DIR / "cache" + cache_folder.mkdir(parents=True, exist_ok=True) + cache_folder = cache_folder / "rank_{}".format(global_rank) + self.writer = LMDBBulkWriter(base_path=cache_folder, max_size=1e10) + + B = data_dict["decoder/input_action_valid_mask"].shape[0] + for b in range(B): + data_dict_b = {k: v[b] for k, v in data_dict.items()} + new_data_dict = {} + for k, v in data_dict_b.items(): + if isinstance(v, torch.Tensor): + new_data_dict[k] = v.cpu().numpy() + elif isinstance(v, np.ndarray) and v.dtype == np.str_: + assert v.shape == () + new_data_dict[k] = v.item() + else: + new_data_dict[k] = v + self.writer.save_sample(data_dict_b["file_name"], new_data_dict) + + def on_validation_epoch_end(self, *args, global_rank, logger, trainer, **kwargs): + pass + self.writer.close() diff --git a/scenestreamer/eval/nms.py b/scenestreamer/eval/nms.py new file mode 100644 index 0000000000000000000000000000000000000000..fb67c81f2ac7ee731cb3d7843c32e09ae120f601 --- /dev/null +++ b/scenestreamer/eval/nms.py @@ -0,0 +1,72 @@ +import numpy as np +import torch + + +def batch_nms( + predicted_trajectories, + predicted_scores, + pred_to_scenario_id, + dist_thresh, + num_ret_modes=6, + num_original_modes=6, +): + """ + Copy from MTR. Modified to support our data. 
+ """ + ret_predicted_trajectories = [] + ret_predicted_scores = [] + B = len(predicted_trajectories) + num_scenarios = B // num_original_modes + assert num_scenarios * num_original_modes == B + + for sid in range(num_scenarios): + assert len(np.unique(pred_to_scenario_id[sid * num_original_modes:(sid + 1) * num_original_modes])) == 1 + batch_traj = predicted_trajectories[sid * num_original_modes:(sid + 1) * num_original_modes] + batch_scores = predicted_scores[sid * num_original_modes:(sid + 1) * num_original_modes] + batch_traj = torch.stack(batch_traj, dim=1) + batch_scores = torch.stack(batch_scores, dim=1) + pred_scores = batch_scores + + # batch_traj shape is (T, num_modes, N, 2) + batch_traj = batch_traj.permute(2, 1, 0, 3) + # Now becomes (N, num_modes, T, 2) + pred_trajs = batch_traj + + batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape + + sorted_idxs = pred_scores.argsort(dim=-1, descending=True) + bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) + sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] + sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) + sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7) + + dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) + point_cover_mask = (dist < dist_thresh) + + point_val = sorted_pred_scores.clone() # (batch_size, N) + point_val_selected = torch.zeros_like(point_val) # (batch_size, N) + + ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() + ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim) + ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes) + bs_idxs = torch.arange(batch_size).type_as(ret_idxs) + + for k in range(num_ret_modes): + cur_idx = point_val.argmax(dim=-1) # (batch_size) + ret_idxs[:, k] = cur_idx + + new_cover_mask = 
point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) + point_val = point_val * (~new_cover_mask).float() # (batch_size, N) + point_val_selected[bs_idxs, cur_idx] = -1 + point_val += point_val_selected + + ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx] + ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] + + ret_trajs = ret_trajs.permute(1, 2, 0, 3) # (N, num_modes, T, 2) -> (num_modes, T, N, 2) + ret_scores = ret_scores.permute(1, 0) # (N, num_modes) -> (num_modes, N) + + ret_predicted_trajectories.extend(list(ret_trajs)) + ret_predicted_scores.extend(list(ret_scores)) + + return ret_predicted_trajectories, ret_predicted_scores diff --git a/scenestreamer/eval/peng_evaluator.py b/scenestreamer/eval/peng_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..78b2f86d4bb602077b8f45b15abf10812e4db7c1 --- /dev/null +++ b/scenestreamer/eval/peng_evaluator.py @@ -0,0 +1,1163 @@ +# Referenced from https://github.com/Tsinghua-MARS-Lab/InterSim/blob/main/simulator/proto.py + +import copy +import dataclasses +import itertools + +import hydra +import numpy as np +import omegaconf +import tensorflow as tf +import torch +import tqdm +from shapely.geometry import Polygon +from tqdm import tqdm +from waymo_open_dataset.protos import map_pb2 +from waymo_open_dataset.wdl_limited.sim_agents_metrics import interaction_features +from waymo_open_dataset.wdl_limited.sim_agents_metrics import map_metric_features + +from scenestreamer.dataset.dataset import SceneStreamerDataset +from scenestreamer.dataset.preprocess_action_label import cal_polygon_contour +from scenestreamer.infer.motion import generate_motion +from scenestreamer.utils import REPO_ROOT +from scenestreamer.utils import utils +from scenestreamer.utils.utils import numpy_to_torch + +try: + from waymo_open_dataset.protos import motion_submission_pb2 +except ModuleNotFoundError: + motion_submission_pb2 = None + + +def _overwrite_datadict_all_agents(source_data_dict, dest_data_dict, 
ooi=None): + import copy + new_data_dict = copy.deepcopy(dest_data_dict) + B, T, N, _ = source_data_dict["decoder/reconstructed_position"].shape + + if ooi is None: + ooi = np.arange(N) + + for id in ooi: # overwrite all agents + traj = source_data_dict["decoder/reconstructed_position"][:, :91, id, ] + traj_mask = source_data_dict["decoder/reconstructed_valid_mask"][:, :91, id] + theta = source_data_dict['decoder/reconstructed_heading'][:, :91, id] + vel = source_data_dict['decoder/reconstructed_velocity'][:, :91, id] + + new_data_dict["decoder/agent_position"][:, :, id, :2] = traj + new_data_dict["decoder/agent_position"][:, :, id, 2] = 0.0 + new_data_dict["decoder/agent_valid_mask"][:, :, id] = traj_mask + new_data_dict["decoder/agent_heading"][:, :, id] = theta + new_data_dict["decoder/agent_velocity"][:, :, id] = vel + + return new_data_dict + + +def _get_mode(output_dict, mode, num_modes): + ret = {} + for k, v in output_dict.items(): + if isinstance(v, np.ndarray) and len(v) == num_modes: + ret[k] = v[mode] + else: + ret[k] = v + return ret + + +def get_2D_collision_labels(data_dict, track_agent_indicies): + # Now, instead of getting 1d-array of collision labels, let's do 2-d array to detect whether there is collision between given two agents. 
+ + assert data_dict["decoder/agent_position"].ndim == 3 + + N = data_dict["decoder/agent_position"].shape[1] + + safety_actions = np.zeros((N, N), dtype=bool) + + contours = [] + for agent1_id in track_agent_indicies: + traj = data_dict["decoder/agent_position"][:91, agent1_id, :] # (91, 3) + length = data_dict["decoder/agent_shape"][10, agent1_id, 0] + width = data_dict["decoder/agent_shape"][10, agent1_id, 1] + theta = data_dict['decoder/agent_heading'][:91, agent1_id] # (91, ) # in pi + mask = data_dict['decoder/agent_valid_mask'][:91, agent1_id] # (91,) + poly = cal_polygon_contour(traj[:, 0], traj[:, 1], theta, width, length) + contours.append((poly, length)) + + for i in range(track_agent_indicies.shape[0] - 1): + for j in range(i + 1, track_agent_indicies.shape[0]): + mask_1 = data_dict['decoder/agent_valid_mask'][:91, track_agent_indicies[i]] # (91,) + mask_2 = data_dict['decoder/agent_valid_mask'][:91, track_agent_indicies[j]] + collision_detected = detect_collision(contours[i], mask_1, contours[j], mask_2) + + if any(collision_detected): + safety_actions[track_agent_indicies[i]][ + track_agent_indicies[j]] = 1 # Label collisions for OOIs now. Later we will build a larger dict. + safety_actions[track_agent_indicies[j]][ + track_agent_indicies[i]] = 1 # Label collisions for OOIs now. Later we will build a larger dict. 
+ return safety_actions + + +def detect_collision(contour_list1, mask1, contour_list2, mask2): + collision_detected = [] + + contour_list1, len1 = contour_list1 + contour_list2, len2 = contour_list2 + + assert len(contour_list1) == len(contour_list2) + + for i in range(len(contour_list1)): + if mask1[i] and mask2[i]: + pos1 = contour_list1[i].mean(0) + pos2 = contour_list2[i].mean(0) + dist = np.linalg.norm(pos1 - pos2) + + # PZH: Actually the largest possible distance is sqrt(2)/2*(len1 + len2) + # We relax it to (len1+len2) + if dist > (len1 + len2): + collision_detected.append(False) + continue + + poly1 = Polygon(contour_list1[i]) + poly2 = Polygon(contour_list2[i]) + + if poly1.intersects(poly2): + collision_detected.append(True) + else: + collision_detected.append(False) + else: + collision_detected.append(False) + + return collision_detected + + +@dataclasses.dataclass +class Metrics: + scenario_count: int = 0 + sdc_coll_scenario_count: int = 0 + veh_coll_scenario_count: int = 0 + + # Diversity + sfde_avg: float = 0.0 + sade_avg: float = 0.0 + sfde_min: float = 0.0 # (supervised) avg over scenarios: minimum over all modes: average of L2 error of final positions of all agents + sade_min: float = 0.0 + + skipped_sfde_avg: float = 0.0 + skipped_sade_avg: float = 0.0 + skipped_sfde_min: float = 0.0 # (supervised) avg over scenarios: minimum over all modes: average of L2 error of final positions of all agents + skipped_sade_min: float = 0.0 + + fdd: float = 0.0 # (unsupervised) avg over scenarios: average over all agents: maximum L2 distance in final position of that agent between generated modes + # Xuanhao: In MixSim paper, they used squared norm of distance, but maybe they meant L2 norm not squared norm? 
+ # Unit given in AdvDiffuser for FDD is m not m^2, so I am using L2 norm here + + # Distribution Realism + vel_jsd: float = 0.0 # avg over scenarios: build histogram across agents, modes, timestamps: velocity JS divergence + acc_jsd: float = 0.0 # avg over scenarios: build histogram across agents, modes, timestamps: acceleration JS divergence + ttc_jsd: float = 0.0 # avg over scenarios: build histogram across agents, modes, timestamps: time to collision JS divergence + + # Common Sense + env_coll_max: float = 0.0 # offroad + env_coll_min: float = 0.0 # offroad + env_coll_avg: float = 0.0 # offroad + + veh_coll_max: float = 0.0 # collision rate + veh_coll_min: float = 0.0 # collision rate + veh_coll_avg: float = 0.0 # collision rate + + # SDC-ADV coll + sdc_adv_coll_max: float = 0.0 # collision rate + sdc_adv_coll_min: float = 0.0 # collision rate + sdc_adv_coll_avg: float = 0.0 # collision rate + + sdc_bv_coll_max: float = 0.0 # collision rate + sdc_bv_coll_min: float = 0.0 # collision rate + sdc_bv_coll_avg: float = 0.0 # collision rate + + adv_bv_coll_max: float = 0.0 # collision rate + adv_bv_coll_min: float = 0.0 # collision rate + adv_bv_coll_avg: float = 0.0 # collision rate + + coll_vel_maxagent_avg: float = 0.0 # collision velocity max over agents, avg over modes + coll_vel_maxagent_max: float = 0.0 # collision velocity max over agents, max over modes + coll_vel_maxagent_min: float = 0.0 # collision velocity max over agents, min over modes + coll_vel_sdc_avg: float = 0.0 # collision velocity only for SDC + coll_vel_sdc_max: float = 0.0 # collision velocity only for SDC, max over modes + coll_vel_sdc_min: float = 0.0 # collision velocity only for SDC, min over modes + + # no clue what collision JSD means so not calculating it for now + + # AV comfortable + sdc_acc_maxtime_avg: float = 0.0 + sdc_acc_maxtime_min: float = 0.0 + sdc_acc_maxtime_max: float = 0.0 + sdc_acc_avgtime_avg: float = 0.0 + sdc_acc_avgtime_min: float = 0.0 + sdc_acc_avgtime_max: float = 
0.0 + + sdc_jerk_maxtime_avg: float = 0.0 + sdc_jerk_maxtime_min: float = 0.0 + sdc_jerk_maxtime_max: float = 0.0 + sdc_jerk_avgtime_avg: float = 0.0 + sdc_jerk_avgtime_min: float = 0.0 + sdc_jerk_avgtime_max: float = 0.0 + + customized_max_sdc_adv_coll: float = 0.0 + customized_max_sdc_bv_coll: float = 0.0 + customized_max_adv_bv_coll: float = 0.0 + + customized_min_sdc_adv_coll: float = 0.0 + customized_min_sdc_bv_coll: float = 0.0 + customized_min_adv_bv_coll: float = 0.0 + + customized_avg_sdc_adv_coll: float = 0.0 + customized_avg_sdc_bv_coll: float = 0.0 + customized_avg_adv_bv_coll: float = 0.0 + + customized_avg_overall_coll: float = 0.0 + + customized_all_agent_coll: float = 0.0 + + def clean(self): + # If the entry is tensor, drop it to float. + for k, v in dataclasses.asdict(self).items(): + if isinstance(v, torch.Tensor): + setattr(self, k, v.item()) + + def aggregate(self): + self.clean() + + # Get all metrics + all_metrics = dataclasses.asdict(self) + for k, v in all_metrics.items(): + if k.startswith("coll_vel_sdc"): + if self.sdc_coll_scenario_count > 0: + all_metrics[k] = v / self.sdc_coll_scenario_count + else: + all_metrics[k] = torch.nan + + elif k.startswith("coll_vel_maxagent"): + if self.veh_coll_scenario_count > 0: + all_metrics[k] = v / self.veh_coll_scenario_count + else: + all_metrics[k] = torch.nan + + elif k != "scenario_count": + all_metrics[k] = v / self.scenario_count + return all_metrics + + +class PengEvaluator: + def __init__(self, config): + self.config = config + self.jsd_config = { + "vel": { + "min_val": 0.0, + "max_val": 50.0, + "num_bins": 100 + }, + "acc": { + "min_val": -10.0, + "max_val": 10.0, + "num_bins": 200 + }, + + # From WOSAC: https://github.com/waymo-research/waymo-open-dataset/blob/5f8a1cd42491210e7de629b6f8fc09b65e0cbe99/src/waymo_open_dataset/wdl_limited/sim_agents_metrics/challenge_2024_config.textproto#L80C1-L89C2 + "ttc": { + "min_val": 0.0, + "max_val": 5.0, + "num_bins": 10 + } + } + + self.metrics = 
def validation_step(self, data_dict, batch_idx, model, log_dict_func, **kwargs):
    """Generate several futures for one scenario and log ADE/FDE/FDD metrics.

    The single-scenario batch is repeated ``EVALUATION.NUM_MODES`` times and
    run through the model. When that exceeds ``EVALUATION.MAXIMUM_BATCH_SIZE``
    the modes are generated over several model calls and concatenated along
    dim 0. If more than ``MAX_MODES`` modes were generated, only the
    ``MAX_MODES`` highest-scoring ones (by summed ``decoder/output_score``)
    are kept before the metrics are computed.

    Args:
        data_dict: Batched scenario data with batch size 1. It is mutated:
            a ``"batch_idx"`` entry is written into it.
        batch_idx: Unused here.
        model: Model forwarded to ``self._call_model``.
        log_dict_func: Callable receiving the dict returned by
            ``self.compute_ade_fde_fdd``.
        **kwargs: Ignored.
    """
    MAX_MODES = 6

    num_modes_for_eval = self.config.EVALUATION.NUM_MODES
    maximum_batch_size = self.config.EVALUATION.MAXIMUM_BATCH_SIZE

    if num_modes_for_eval <= maximum_batch_size:
        num_repeat_calls = 1
    else:
        assert num_modes_for_eval % maximum_batch_size == 0
        num_repeat_calls = num_modes_for_eval // maximum_batch_size
    num_modes_per_call = num_modes_for_eval // num_repeat_calls
    assert num_modes_per_call * num_repeat_calls == num_modes_for_eval

    B = data_dict["encoder/agent_feature"].shape[0]
    assert B == 1, B
    data_dict["batch_idx"] = torch.arange(B)

    # Only these entries are forwarded to the model; everything else in
    # data_dict is left out of the per-mode expansion.
    model_input_prefixes = ("encoder/", "decoder/", "metadata/", "eval/")
    model_input_keys = ("batch_idx", "in_evaluation", "in_backward_prediction")
    expanded_data_dict = {
        k: utils.repeat_for_modes(v, num_modes=num_modes_per_call)
        for k, v in data_dict.items()
        if k.startswith(model_input_prefixes) or k in model_input_keys
    }

    if num_repeat_calls == 1:
        output_dict = self._call_model(model, expanded_data_dict)
    else:
        per_call_outputs = []
        for i in range(num_repeat_calls):
            expanded_data_dict["batch_idx"] = torch.arange(B) + i * maximum_batch_size
            # Each call gets its own copy so one call cannot affect the next
            # through in-place changes to the shared input dict.
            per_call_outputs.append(self._call_model(model, copy.deepcopy(expanded_data_dict)))
        first_output = per_call_outputs[0]
        # Tensors are concatenated over the mode dimension; non-tensor
        # entries cannot be merged and are replaced by None.
        output_dict = {
            k: torch.cat([out[k] for out in per_call_outputs], dim=0)
            if isinstance(first_output[k], torch.Tensor) else None
            for k in first_output.keys()
        }
        output_dict.pop("batch_idx", None)

    if num_modes_for_eval > MAX_MODES:
        sort_scores = output_dict["decoder/output_score"].sum(-1).sort(descending=True)
        selected_indices = sort_scores.indices[:MAX_MODES]
        output_dict = {
            k: v[selected_indices.to(v.device)] if isinstance(v, torch.Tensor) else v
            for k, v in output_dict.items()
        }

    # The ground truth must be expanded to the number of modes actually
    # present in output_dict. Always expanding to MAX_MODES mismatched the
    # prediction batch whenever fewer than MAX_MODES modes were generated.
    num_eval_modes = min(num_modes_for_eval, MAX_MODES)
    gt_expanded = copy.deepcopy(utils.expand_for_modes(data_dict, num_modes=num_eval_modes))

    log_dict_func(self.compute_ade_fde_fdd(gt_data_dict=gt_expanded, pred_data_dict=output_dict))
def compute_ade_fde_fdd(self, gt_data_dict, pred_data_dict):
    """Compute scenario-level displacement errors and mode-diversity metrics.

    All metrics are evaluated on every 5th step of the first 91 steps.
    Each one is reported twice: over all valid agents (``*_all*``) and over
    the objects of interest plus the SDC (``*_ooisdc*``).

    * sFDE / sADE: final / average L2 error per mode, averaged over agents,
      then reduced over modes with mean (``*_avg``) and min (``*_min``).
    * FDD / ADD: for each agent, the largest pairwise L2 distance between
      modes at the final valid step (FDD) or at each step averaged over
      valid steps (ADD), then averaged over agents.

    Args:
        gt_data_dict: Ground truth already expanded over modes. The code
            indexes ``decoder/agent_valid_mask`` as (B, T, N) and
            ``decoder/agent_position`` as (B, T, N, >=2).
        pred_data_dict: Model output; ``decoder/reconstructed_position`` is
            indexed as (B, T, N, >=2) with the same B as the ground truth.

    Returns:
        dict: metric name -> 0-dim tensor.
    """
    gt_valid = gt_data_dict["decoder/agent_valid_mask"]
    assert gt_valid.ndim == 3
    gt_valid_skipped = gt_valid[:, ::5]

    gt_pos = gt_data_dict["decoder/agent_position"][:, :91, :, :2]
    pred_pos = pred_data_dict["decoder/reconstructed_position"][:, :91, :, :2]
    pred_pos_skipped = pred_pos[:, ::5]

    ooi_ind = gt_data_dict["decoder/object_of_interest_id"]
    sdc_ind = gt_data_dict["decoder/sdc_index"]
    # NOTE(review): if object_of_interest_id is padded (e.g. with -1), the
    # indexing below would mark the wrong agent - confirm it is unpadded.
    ooi_and_sdc_ind = torch.cat([ooi_ind, sdc_ind[:, None]], dim=1)

    B, _, N, _ = gt_pos.shape

    # argmax of the running count of valid steps returns the first index at
    # which the count peaks, i.e. the last valid step, in 5-step units.
    last_valid_ind_skipped = gt_valid_skipped.cumsum(dim=1).argmax(dim=1).reshape(B, 1, N)

    error = torch.linalg.norm(gt_pos - pred_pos, dim=-1)
    assert error.ndim == 3
    assert error.shape[0] == B

    # `error` is at full resolution, so the subsampled index is scaled back.
    fde_skipped = torch.gather(error, 1, last_valid_ind_skipped * 5).squeeze(1)  # shape: B, N

    agent_valid_skipped = gt_valid_skipped.any(1).expand(B, N)
    assert fde_skipped.shape == agent_valid_skipped.shape == (B, N)

    # Restrict the agent mask to OOI + SDC agents that are also valid.
    batch_indices = torch.arange(B, device=ooi_and_sdc_ind.device).unsqueeze(1)
    agent_valid_ooi_skipped = torch.zeros_like(agent_valid_skipped, dtype=torch.bool)
    agent_valid_ooi_skipped[batch_indices, ooi_and_sdc_ind] = True
    agent_valid_ooi_skipped = agent_valid_ooi_skipped & agent_valid_skipped

    sfde_skipped_all = utils.masked_average(fde_skipped, agent_valid_skipped, dim=1)
    sfde_skipped_ooisdc = utils.masked_average(fde_skipped, agent_valid_ooi_skipped, dim=1)

    sade_per_agent_skipped = utils.masked_average(error[:, ::5], gt_valid_skipped, dim=1)
    sade_skipped_all = utils.masked_average(sade_per_agent_skipped, agent_valid_skipped, dim=1)
    sade_skipped_ooisdc = utils.masked_average(sade_per_agent_skipped, agent_valid_ooi_skipped, dim=1)

    # The index is in 5-step units, so it must address the subsampled
    # trajectory. Gathering from the full-resolution `pred_pos` picked
    # positions far before the final valid step.
    final_pos_skipped = torch.gather(
        pred_pos_skipped, 1, last_valid_ind_skipped[..., None].expand(-1, -1, -1, 2)
    ).squeeze(1)

    def _fdd(pos):
        # Per agent: largest distance between any two modes' final positions.
        assert pos.shape == (B, N, 2), pos.shape
        pos_NB = pos.swapaxes(0, 1)
        dist_NBB = torch.cdist(pos_NB, pos_NB)
        max_dist_N = dist_NBB.amax((1, 2))
        assert max_dist_N.shape == (N, )
        return max_dist_N

    fdd_skipped_all = utils.masked_average(_fdd(final_pos_skipped), agent_valid_skipped[0], dim=0)
    fdd_skipped_ooisdc = utils.masked_average(_fdd(final_pos_skipped), agent_valid_ooi_skipped[0], dim=0)

    def _add(pos, mask):
        # Per agent: largest between-mode distance at every step, averaged
        # over the steps selected by `mask`.
        assert pos.ndim == 4
        assert pos.shape[0] == B
        assert pos.shape[2] == N
        assert pos.shape[3] == 2
        num_steps = pos.shape[1]
        pos_NB = pos.reshape(B, -1, 2).swapaxes(0, 1)
        dist_NBB = torch.cdist(pos_NB, pos_NB)
        max_dist_N = dist_NBB.amax((1, 2))
        assert max_dist_N.shape == (N * num_steps, ), max_dist_N.shape
        max_dist_TN = max_dist_N.reshape(num_steps, N)
        assert mask.shape == max_dist_TN.shape, (mask.shape, max_dist_TN.shape)
        avg_t = utils.masked_average(max_dist_TN, mask, dim=0)
        assert avg_t.shape == (N, )
        return avg_t

    # The masks of mode 0 are used for FDD/ADD: the ground truth was expanded
    # over modes, so presumably every mode carries the same mask.
    add_per_agent = _add(pred_pos_skipped, gt_valid_skipped[0])
    add_skipped_all = utils.masked_average(add_per_agent, agent_valid_skipped[0], dim=0)
    add_skipped_ooisdc = utils.masked_average(add_per_agent, agent_valid_ooi_skipped[0], dim=0)

    return {
        # FDE
        "sfde_all_avg": sfde_skipped_all.mean(),
        "sfde_all_min": sfde_skipped_all.min(),
        "sfde_ooisdc_avg": sfde_skipped_ooisdc.mean(),
        "sfde_ooisdc_min": sfde_skipped_ooisdc.min(),
        # ADE
        "sade_all_avg": sade_skipped_all.mean(),
        "sade_all_min": sade_skipped_all.min(),
        "sade_ooisdc_avg": sade_skipped_ooisdc.mean(),
        "sade_ooisdc_min": sade_skipped_ooisdc.min(),
        # FDD
        "fdd_all": fdd_skipped_all,
        "fdd_ooisdc": fdd_skipped_ooisdc,
        # ADD
        "add_all": add_skipped_all,
        "add_ooisdc": add_skipped_ooisdc,
        "num_all_agents": agent_valid_skipped.sum(-1).float().mean(),
        "num_ooisdc_agents": agent_valid_ooi_skipped.sum(-1).float().mean(),
    }
num steps) + + gt_ooi_joint = gt_valid_mask[..., T_context:T_gt] # (num agents, num steps) + pred_ooi_joint = pred_valid_mask[..., T_context:T_gt] # (K, num agents, num steps) + + gt_shape = numpy_to_torch( + gt_data_dict["decoder/current_agent_shape"][None], device=device + ).expand(T_gt, -1, -1).transpose(0, 1) + pred_shape = pred_data_dict["decoder/current_agent_shape"][:, None].expand(-1, T_pred, -1, -1).transpose(1, 2) + + sdc_index = int(gt_data_dict["decoder/sdc_index"]) + sdc_index_in_ooi = list(gt_data_dict["decoder/agent_id"]).index(sdc_index) + + # minSFDE + gt_pos = numpy_to_torch(gt_data_dict["decoder/agent_position"], device=device)[None, ..., :2] + pred_pos = pred_data_dict["decoder/reconstructed_position"][:, :T_gt] + + gt_valid = gt_ooi_joint[None] + gt_valid_skipped = gt_valid[:, :, ::5] + + last_valid_ind = gt_valid.cumsum(dim=-1).argmax(dim=-1) + last_valid_ind_skipped = gt_valid_skipped.cumsum(dim=-1).argmax(dim=-1) + + error = torch.linalg.norm(gt_pos - pred_pos, dim=-1) + assert error.ndim == 3 + assert error.shape[0] == B + + last_valid_ind = last_valid_ind.unsqueeze(0).expand(B, 1, N) + last_valid_ind_skipped = last_valid_ind_skipped.unsqueeze(0).expand(B, 1, N) + + assert last_valid_ind.shape == (B, 1, N) + fde = torch.gather(error, 1, last_valid_ind).squeeze(1) # shape: B, N + fde_skipped = torch.gather(error, 1, last_valid_ind_skipped * 5).squeeze(1) # shape: B, N + + assert fde.shape[0] == B + agent_valid = gt_valid.any(-1).expand(B, N) # shape: B, N + agent_valid_skipped = gt_valid_skipped.any(-1).expand(B, N) + + sfde = (fde * agent_valid).sum(-1) / agent_valid.sum(-1) + sfde_skipped = (fde_skipped * agent_valid_skipped).sum(-1) / agent_valid_skipped.sum(-1) + + gt_valid = gt_valid.permute(0, 2, 1) + gt_valid_skipped = gt_valid_skipped.permute(0, 2, 1) + + gt_valid_expand = gt_valid.expand(B, T_gt, N) + pred_ooi_joint_expand = pred_ooi_joint.permute(0, 2, 1) + sade_per_agent = (error * gt_valid).sum(1) / gt_valid.sum(1).clamp(1) + 
assert sade_per_agent.ndim == 2 + sade = (sade_per_agent * gt_valid.any(1)).sum(1) / gt_valid.any(1).sum(1) + assert sade.ndim == 1 + sade_skipped_per_agent = (error[:, ::5] * gt_valid_skipped).sum(1) / gt_valid_skipped.sum(1).clamp(1) + sade_skipped = sade_skipped_per_agent.sum(1) / gt_valid_skipped.any(1).sum(1) + + assert sfde.ndim == 1 + assert sfde.shape[0] == B + self.metrics.sfde_min += sfde.min() + self.metrics.sade_min += sade.min() + self.metrics.sfde_avg += sfde.mean() + self.metrics.sade_avg += sade.mean() + + self.metrics.skipped_sfde_min += sfde_skipped.min() + self.metrics.skipped_sade_min += sade_skipped.min() + self.metrics.skipped_sfde_avg += sfde_skipped.mean() + self.metrics.skipped_sade_avg += sade_skipped.mean() + + # Following wosac_eval, fill in z with GT t = 10 data + z_values = pred_data_dict["decoder/current_agent_position"][..., 2].unsqueeze(-1).expand(-1, -1, T_pred) + + # FDD + # there doesn't appear to be an easy way to do this with cartesian product + cur_FDD = None + pred_ooi_valid_mask = pred_valid_mask[:, ooi_mask] + single_mode_ooi_valid_mask = pred_ooi_valid_mask[0] + + # assert torch.all(torch.any(pred_ooi_valid_mask, dim=-1)) + last_valid_ind = pred_ooi_valid_mask.cumsum(dim=-1).argmax(dim=-1) # (K, N) + ooi_reconstructed_pos = pred_data_dict["decoder/reconstructed_position"][:, :, ooi_mask] # (K, T_pred, N, 2) + last_valid_ind_reshaped = last_valid_ind[:, None, :, None].expand(-1, -1, -1, 2) + final_pos = torch.gather(ooi_reconstructed_pos, dim=1, index=last_valid_ind_reshaped).squeeze(1) + for i, j in itertools.product(range(K), range(K)): + final_dist = torch.linalg.norm(final_pos[i] - final_pos[j], dim=-1) + assert final_dist.ndim == 1 + if cur_FDD == None: + cur_FDD = final_dist + else: + cur_FDD = torch.maximum(cur_FDD, final_dist) + self.metrics.fdd += (cur_FDD * + single_mode_ooi_valid_mask.any(-1)).sum() / single_mode_ooi_valid_mask.any(-1).sum() + + gt_speed, gt_accel, gt_jerk = self._compute_kinematic_metrics( + 
gt_data_dict["decoder/agent_velocity"].swapaxes(1, 0), device + ) # (N, T) + pred_speed, pred_accel, pred_jerk = self._compute_kinematic_metrics( + pred_data_dict["decoder/reconstructed_velocity"].transpose(1, 2), device + ) # (K, N, T) + gt_speed = gt_speed[..., T_context:T_gt] + gt_accel = gt_accel[..., T_context:T_gt] + gt_jerk = gt_jerk[..., T_context:T_gt] + pred_speed = pred_speed[..., T_context:T_gt] + pred_accel = pred_accel[..., T_context:T_gt] + pred_jerk = pred_jerk[..., T_context:T_gt] + + if self.use_waymo: + candidate_agents = gt_data_dict["decoder/current_agent_valid_mask"] + if isinstance(candidate_agents, torch.Tensor): + candidate_agents = candidate_agents.cpu().numpy() + candidate_agents = candidate_agents.nonzero()[0] + + pred_veh_collisions, veh_cr_mode = calc_collision_rate( + candidate_agents=candidate_agents, + evaluate_agents=gt_data_dict["decoder/agent_id"], + pred_data_dict=pred_data_dict, + pred_shape=pred_shape, + device=device, + T_gt=T_gt, + T_context=T_context, + z_values=z_values + ) + + assert veh_cr_mode.shape[0] == B + self.metrics.veh_coll_avg += veh_cr_mode.mean() + self.metrics.veh_coll_min += veh_cr_mode.min() + self.metrics.veh_coll_max += veh_cr_mode.max() + + if adv_list is not None: + assert len(adv_list) == 1 + assert adv_list[0] not in bv_list + assert sdc_index not in adv_list + + # self.sdc_coll_adv_active = True + for kk in adv_list: + assert int(kk.item()) in candidate_agents + + adv_sdc_coll, adv_sdc_coll_rate = calc_collision_rate( + candidate_agents=adv_list, + evaluate_agents=pred_data_dict["decoder/sdc_index"], + pred_data_dict=pred_data_dict, + pred_shape=pred_shape, + device=device, + T_gt=T_gt, + T_context=T_context, + z_values=z_values + ) + + assert adv_sdc_coll_rate.ndim == 1 + self.metrics.sdc_adv_coll_avg += adv_sdc_coll_rate.mean() + self.metrics.sdc_adv_coll_min += adv_sdc_coll_rate.min() + self.metrics.sdc_adv_coll_max += adv_sdc_coll_rate.max() + + adv_bv_coll, adv_bv_coll_rate = 
calc_collision_rate( + candidate_agents=bv_list, + evaluate_agents=adv_list, + pred_data_dict=pred_data_dict, + pred_shape=pred_shape, + device=device, + T_gt=T_gt, + T_context=T_context, + z_values=z_values + ) + assert adv_bv_coll_rate.ndim == 1 + + assert adv_bv_coll_rate.shape[0] == B + self.metrics.adv_bv_coll_avg += adv_bv_coll_rate.mean() + self.metrics.adv_bv_coll_min += adv_bv_coll_rate.min() + self.metrics.adv_bv_coll_max += adv_bv_coll_rate.max() + + assert sdc_index not in bv_list + assert bv_list is not None + for kk in bv_list: + assert int(kk.item()) in candidate_agents + sdc_bv_coll, sdc_bv_coll_rate = calc_collision_rate( + candidate_agents=bv_list, + evaluate_agents=pred_data_dict["decoder/sdc_index"], + pred_data_dict=pred_data_dict, + pred_shape=pred_shape, + device=device, + T_gt=T_gt, + T_context=T_context, + z_values=z_values + ) + assert sdc_bv_coll_rate.ndim == 1 + + assert sdc_bv_coll_rate.shape[0] == B + self.metrics.sdc_bv_coll_avg += sdc_bv_coll_rate.mean() + self.metrics.sdc_bv_coll_min += sdc_bv_coll_rate.min() + self.metrics.sdc_bv_coll_max += sdc_bv_coll_rate.max() + + # map_feature = gt_data_dict["encoder/map_feature"] + map_feature = gt_data_dict["vis/map_feature"] + assert map_feature.ndim == 3 # This is unbatched. 
+ + road_edges = [] + for i in range(map_feature.shape[0]): + # For each map feature + if map_feature[i, 0, 15] == 1: + map_feat = [] + for j in range(map_feature.shape[1]): + if gt_data_dict['encoder/map_feature_valid_mask'][i, j]: + map_feat.append( + map_pb2.MapPoint( + x=map_feature[i, j, 0], + y=map_feature[i, j, 1], + z=0 + # map_feature[i, j, 2] # let's say there is no z axis any more + ) + ) + road_edges.append(map_feat) + + eval_mask = ooi_mask & (~static_agent_mask) + env_nearest_distances = torch.stack( + [ + tf_to_torch( + map_metric_features.compute_distance_to_road_edge( + center_x=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 0].T), + center_y=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 1].T), + center_z=conv(z_values[k]), + length=conv(pred_shape[k, ..., 0]), + width=conv(pred_shape[k, ..., 1]), + height=conv(pred_shape[k, ..., 2]), + heading=conv(pred_data_dict["decoder/reconstructed_heading"][k].T), + valid=conv(pred_valid_mask[k], dtype=tf.bool), + evaluated_object_mask=conv(eval_mask, dtype=tf.bool), + road_edge_polylines=road_edges, + ), + device=device + ) for k in range(K) + ] + ) + + pred_valid_mask = pred_valid_mask[..., T_context:T_gt] + + pred_env_collisions = torch.greater(env_nearest_distances, map_metric_features.OFFROAD_DISTANCE_THRESHOLD) + pred_env_collisions = pred_env_collisions[..., T_context:T_gt] + pred_env_collisions_traj_level = pred_env_collisions.any(dim=-1) # (B, num_ooi) + # Avg over agent dim. 
Here we assume all evaluated agents are valid so don't do the masked_avg + env_collision_rate = pred_env_collisions_traj_level.float().mean(-1) + + # ==================================== customized env collision rate ==================================== + # env_collision_rate = np.array(pred_env_collisions_traj_level).mean(-1) + assert env_collision_rate.ndim == 1 + + assert env_collision_rate.shape[0] == B + self.metrics.env_coll_avg += env_collision_rate.mean() + self.metrics.env_coll_min += env_collision_rate.min() + self.metrics.env_coll_max += env_collision_rate.max() + + step_wise_collision = pred_veh_collisions + scenario_has_collision = torch.any(step_wise_collision).item() + + if scenario_has_collision: + speed_when_collision = torch.where(step_wise_collision, pred_speed[:, ooi_mask], 0) + + coll_vel_max_agent = speed_when_collision.amax(dim=(-1, -2)) + coll_valid_mask = speed_when_collision.any(-1).any(-1) + coll_vel_max_agent = coll_vel_max_agent[coll_valid_mask] + + assert coll_vel_max_agent.ndim == 1 + + if coll_vel_max_agent.numel() != 0: + self.metrics.coll_vel_maxagent_avg += (coll_vel_max_agent).sum() / coll_valid_mask.sum().clamp(1) + self.metrics.coll_vel_maxagent_min += coll_vel_max_agent.min() + self.metrics.coll_vel_maxagent_max += coll_vel_max_agent.max() + + self.metrics.veh_coll_scenario_count += 1 + + sdc_speed = pred_speed[:, sdc_index] + sdc_coll = step_wise_collision[:, sdc_index_in_ooi] + + # sdc_speed_when_coll = torch.where(sdc_coll, sdc_speed, torch.nan).amax(-1) + sdc_speed_when_coll = torch.where(sdc_coll, sdc_speed, torch.nan) + valid_mask = ~torch.isnan(sdc_speed_when_coll) + + if torch.any(sdc_coll).item(): # if there is valid collision + self.metrics.coll_vel_sdc_avg += (sdc_speed_when_coll[valid_mask]).sum() / valid_mask.sum() + self.metrics.coll_vel_sdc_max += (sdc_speed_when_coll[valid_mask]).max() + self.metrics.coll_vel_sdc_min += (sdc_speed_when_coll[valid_mask]).min() + self.metrics.sdc_coll_scenario_count += 1 + assert 
sdc_speed_when_coll.shape[0] == B + + gt_ttc = tf_to_torch( + interaction_features.compute_time_to_collision_with_object_in_front( + center_x=conv(gt_data_dict["decoder/agent_position"][..., 0].T), + center_y=conv(gt_data_dict["decoder/agent_position"][..., 1].T), + length=conv(gt_shape[..., 0]), + width=conv(gt_shape[..., 1]), + heading=conv(gt_data_dict["decoder/agent_heading"].T), + valid=conv(gt_valid_mask, dtype=tf.bool), + evaluated_object_mask=conv(ooi_mask, dtype=tf.bool), + seconds_per_step=self.SECONDS_PER_STEP + ), + device=device + ) + pred_ttc = torch.stack( + [ + tf_to_torch( + interaction_features.compute_time_to_collision_with_object_in_front( + center_x=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 0].T), + center_y=conv(pred_data_dict["decoder/reconstructed_position"][k, ..., 1].T), + length=conv(pred_shape[k, ..., 0]), + width=conv(pred_shape[k, ..., 1]), + heading=conv(pred_data_dict["decoder/reconstructed_heading"][k].T), + valid=conv(pred_data_dict["decoder/reconstructed_valid_mask"][k].T, dtype=tf.bool), + evaluated_object_mask=conv(ooi_mask, dtype=tf.bool), + seconds_per_step=self.SECONDS_PER_STEP + ), + device=device + ) for k in range(K) + ] + ) + gt_ttc = gt_ttc[..., T_context:T_gt] + pred_ttc = pred_ttc[..., T_context:T_gt] + + sdc_acc = torch.abs(torch.nan_to_num(pred_accel[:, sdc_index])) # Shape: (K, T) + sdc_mask = pred_valid_mask[:, sdc_index][:, T_context:T_gt] + sdc_acc_avgt = (sdc_acc * sdc_mask).sum(-1) / sdc_mask.sum(-1).clamp(1) + assert sdc_acc_avgt.ndim == 1 + assert sdc_acc_avgt.shape[0] == B + self.metrics.sdc_acc_avgtime_max += sdc_acc_avgt.max() + self.metrics.sdc_acc_avgtime_avg += sdc_acc_avgt.mean() + self.metrics.sdc_acc_avgtime_min += sdc_acc_avgt.min() + + sdc_acc_maxt = sdc_acc.amax(-1) + assert sdc_acc_maxt.ndim == 1 + assert sdc_acc_maxt.shape[0] == B + self.metrics.sdc_acc_maxtime_max += sdc_acc_maxt.max() + self.metrics.sdc_acc_maxtime_avg += sdc_acc_maxt.mean() + 
self.metrics.sdc_acc_maxtime_min += sdc_acc_maxt.min() + + sdc_jerk = torch.abs(torch.nan_to_num(pred_jerk[:, sdc_index])) # Shape: (K, T) + sdc_mask = pred_valid_mask[:, sdc_index][:, T_context:T_gt] + sdc_jerk_avgt = (sdc_jerk * sdc_mask).sum(-1) / sdc_mask.sum(-1).clamp(1) + assert sdc_jerk_avgt.ndim == 1 + assert sdc_jerk_avgt.shape[0] == B + self.metrics.sdc_jerk_avgtime_max += sdc_jerk_avgt.max() + self.metrics.sdc_jerk_avgtime_avg += sdc_jerk_avgt.mean() + self.metrics.sdc_jerk_avgtime_min += sdc_jerk_avgt.min() + + sdc_jerk_maxt = sdc_jerk.amax(-1) + assert sdc_jerk_maxt.ndim == 1 + assert sdc_jerk_maxt.shape[0] == B + self.metrics.sdc_jerk_maxtime_max += sdc_jerk_maxt.max() + self.metrics.sdc_jerk_maxtime_avg += sdc_jerk_maxt.mean() + self.metrics.sdc_jerk_maxtime_min += sdc_jerk_maxt.min() + + gt_speed_hist, gt_speed_bins = torch.histogram( + torch.clip( + gt_speed[gt_ooi_joint & ~gt_speed.isnan()], self.jsd_config["vel"]["min_val"], + self.jsd_config["vel"]["max_val"] + ).cpu(), + self.jsd_config["vel"]["num_bins"], + density=False + ) + # .cpu() since histogram doesn't support cuda backend + pred_speed_hist, pred_speed_bins = torch.histogram( + torch.clip( + pred_speed[pred_ooi_joint & ~pred_speed.isnan()], self.jsd_config["vel"]["min_val"], + self.jsd_config["vel"]["max_val"] + ).cpu(), + self.jsd_config["vel"]["num_bins"], + density=False + ) + gt_accel_hist, gt_accel_bins = torch.histogram( + torch.clip( + gt_accel[gt_ooi_joint & ~gt_accel.isnan()], self.jsd_config["acc"]["min_val"], + self.jsd_config["acc"]["max_val"] + ).cpu(), + self.jsd_config["acc"]["num_bins"], + density=False + ) + pred_accel_hist, pred_accel_bins = torch.histogram( + torch.clip( + pred_accel[pred_ooi_joint & ~pred_accel.isnan()], self.jsd_config["acc"]["min_val"], + self.jsd_config["acc"]["max_val"] + ).cpu(), + self.jsd_config["acc"]["num_bins"], + density=False + ) + + if self.use_waymo: + gt_ttc_hist, gt_ttc_bins = torch.histogram( + torch.clip( + 
def on_validation_epoch_end(self, *args, **kwargs):
    """Do nothing at the end of a validation epoch.

    Accepts and ignores any positional or keyword arguments, and returns
    ``None``.
    """
    return None
+ # tmpdir = self.config.ROOT_DIR / self.config.TMP_DIR / "validation_tmpdir_{}".format(exp_name) + # print(f"Rank {global_rank} saving validation results to {tmpdir}.") + # + # os.makedirs(tmpdir, exist_ok=True) + # with open(os.path.join(tmpdir, 'result_part_{}.pkl'.format(global_rank)), 'wb') as f: + # pickle.dump(self.validation_outputs, f) + # self.validation_outputs.clear() + # + # # print(debug_tools.using(f"val epoch saved file.")) + # + # # If this is the main process (rank0), read all results in local filesystem and call evaluation pipeline. + # torch.cuda.empty_cache() + # trainer.strategy.barrier() + # if trainer.is_global_zero: + # print_func(f"===== Start evaluation: {time.time() - st:.3f} =====") + # + # # Gather results from different ranks + # validation_list = [] + # for i in range(trainer.world_size): + # file = os.path.join(tmpdir, 'result_part_{}.pkl'.format(i)) + # success = False + # for sleep in range(10): + # if not os.path.isfile(file): + # time.sleep(1) + # print(f"Can't find file: {file}. Sleep {sleep}/{10} seconds.") + # else: + # success = True + # break + # if not success: + # print(f"[WARNING] Can't find file: {file}. Skip this rank.") + # continue + # with open(file, "rb") as f: + # val_outputs = pickle.load(f) + # validation_list.extend(val_outputs) + # if self.config.EVALUATION.DELETE_EVAL_RESULT: + # shutil.rmtree(tmpdir) + # + # if not validation_list: + # print_func("No evaluation results found. 
Skip evaluation.") + # return + # + # # print(debug_tools.using(f"going to eval")) + # + # # Call evaluation pipeline + # torch.cuda.empty_cache() + # result_dict, result_str, submission_dict = waymo_evaluation_optimized( + # validation_list, + # + # # TODO: This flag + # generate_submission=self.config.SUBMISSION.GENERATE_SUBMISSION, + # predict_all_agents=self.config.EVALUATION.PREDICT_ALL_AGENTS, + # ) + # torch.cuda.empty_cache() + # validation_list.clear() + # + # # Log result + # result_dict = {f"eval/{k}": float(v) for k, v in result_dict.items()} + # log_dict_func(result_dict, rank_zero_only=True) + # for k in ['eval/minADE', 'eval/minFDE', 'eval/MissRate', 'eval/mAP', "eval/mJADE", "eval/avgJADE", + # "eval/mJFDE", "eval/avgJFDE"]: + # if k not in result_dict: + # continue + # log_func(name=k.split("/")[1], value=result_dict[k], rank_zero_only=True) + # print_func(result_str) + # print_func(f"===== Finish evaluation: {time.time() - st:.3f} =====") + # + # print_func(f"Rank {global_rank} finished evaluation!") + # torch.cuda.empty_cache() + # trainer.strategy.barrier() + # + # # TODO This flag + # if trainer.is_global_zero and self.config.SUBMISSION.GENERATE_SUBMISSION: + # account_name = self.config.SUBMISSION.ACCOUNT + # unique_method_name = self.config.SUBMISSION.METHOD_NAME + # output_dir = logger.log_dir + # submission_prefix = logger.name + # path, duplicated_scenarios, done_scenarios = generate_submission( + # prefix=submission_prefix, + # account_name=account_name, + # unique_method_name=unique_method_name, + # output_dir=output_dir, + # **submission_dict + # ) + # print_func( + # "Submission created at: {}. Finished {} scenarios. 
@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="0220_midgpt.yaml")
def main(config):
    """Run motion generation on every scenario of ``data/20scenarios`` and plot it.

    The Hydra-supplied ``config`` is replaced by the configuration stored on the
    loaded checkpoint, which is then patched to keep all data and to read both
    the training and the test split from ``data/20scenarios``.
    """
    # Kept function-local as in the original, but hoisted out of the per-scenario loop.
    from scenestreamer.gradio_ui.plot import plot_pred

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): absolute, machine-specific checkpoint path -- confirm it exists
    # on the target machine or make it configurable.
    path = "/home/zhenghao/scenestreamer/lightning_logs/scenestreamer/0220_MidGPT_V19_2025-02-20/"
    pl_model = utils.get_model(checkpoint_path=path, device=device)

    config = pl_model.config
    # Struct mode is switched off so the overrides below can be written.
    omegaconf.OmegaConf.set_struct(config, False)
    config.PREPROCESSING["keep_all_data"] = True
    config.DATA.TRAINING_DATA_DIR = "data/20scenarios"
    config.DATA.TEST_DATA_DIR = "data/20scenarios"

    test_dataset = SceneStreamerDataset(config, "training")

    for raw_data_dict in tqdm.tqdm(test_dataset):
        data_dict = utils.numpy_to_torch(raw_data_dict, device=device)
        # Add a leading batch dimension of size 1 to every tensor entry.
        batched_data_dict = {k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}

        data_dict = generate_motion(
            data_dict=batched_data_dict,
            model=pl_model.model,
            autoregressive_start_step=2,
            remove_out_of_map_agent=True,
            remove_static_agent=True,
            teacher_forcing_sdc=True,
        )

        data_dict = utils.unbatch_data(utils.torch_to_numpy(data_dict))
        plot_pred(data_dict, show=True)

    print("End")
def _down_sampling(line, sample_num):
    """Return every ``sample_num``-th point of a polyline as a list of rows.

    Args:
        line: Array whose first axis indexes the points of the polyline.
        sample_num: Sampling stride. A polyline with fewer than ``sample_num``
            points is returned unchanged (as a list of its rows). Float values
            such as ``10e9`` are accepted and truncated to an integer stride.

    Returns:
        A list of the selected rows of ``line``.

    Raises:
        ValueError: If the stride is needed and truncates to less than 1.
    """
    point_num = line.shape[0]

    if point_num < sample_num:
        return [line[i] for i in range(point_num)]

    # ``range`` rejects float steps, so a float stride used to raise TypeError here.
    step = int(sample_num)
    if step < 1:
        raise ValueError(f"sample_num must be >= 1, got {sample_num!r}")
    return [line[i] for i in range(0, point_num, step)]
(SPEED_BUMP) + continue + map_feat['polyline'] = map_feat['position'][np.newaxis] + + poly_unsampled = map_feat['polyline'][:, :2] + + # TODO(PZH): Revisit the down sampling function. It seems quite werid to me. + poly = _down_sampling(poly_unsampled, sample_num=sample_num) + + a_lane = np.zeros([len(poly), 4], dtype='float32') + + a_lane[:, :2] = np.array(poly) + a_lane[:, 2] = ALL_TYPE[map_feat['type']] + a_lane[:, 3] = str(map_feat_id) + + lanes.append(a_lane) + + lanes = np.concatenate(lanes, axis=0) + return lanes + + +def metadrive_scenario_to_init_data(scenario): + ret = {} + + ret['id'] = scenario[SD.ID] + + tracks = scenario[SD.TRACKS] + traffic_lights = scenario[SD.DYNAMIC_MAP_STATES] + map_feat = scenario[SD.MAP_FEATURES] + sdc_id = scenario[SD.METADATA][SD.SDC_ID] + + track_len = scenario[SD.LENGTH] + + # all_agent in shape [Time steps, Num agents, Num state dim] + all_agent = np.zeros([track_len, len(tracks), 9], dtype='float32') + + PZH_TRACK_NAMES = [None] * len(tracks) + + sdc_index = None + for indx, (id, track) in enumerate(tracks.items()): + # if track[SD.TYPE] != MetaDriveType.VEHICLE: + # continue + if id == sdc_id: + sdc_index = indx + all_agent[:, indx, :2] = track[SD.STATE]['position'][:, :2] + all_agent[:, indx, 2:4] = track[SD.STATE]['velocity'] + all_agent[:, indx, 4] = track[SD.STATE]['heading'].reshape(track_len) + + # TODO: Width length or length width I don't XXXXing care. 
+ all_agent[:, indx, 5] = track[SD.STATE]['width'] + all_agent[:, indx, 6] = track[SD.STATE]['length'] + + all_agent[:, indx, 7] = 1 + all_agent[:, indx, 8] = track[SD.STATE]['valid'].reshape(track_len) + + PZH_TRACK_NAMES[indx] = id + + assert sdc_index is not None + + # Make ego agent to the first place + all_agent[:, [sdc_index, 0]] = all_agent[:, [0, sdc_index]] + PZH_TRACK_NAMES = np.array(PZH_TRACK_NAMES) + PZH_TRACK_NAMES[[sdc_index, 0]] = PZH_TRACK_NAMES[[0, sdc_index]] + + ret['all_agent'] = all_agent + + traffic_light_data = [] + for step in range(track_len): + tl_states_in_one_step = [] + + for traffic_light_index, traffic_light in traffic_lights.items(): + traffic_light_state = {k: v[step] for k, v in traffic_light["state"].items()} + + traffic_light_step_data = np.zeros(6, dtype='float32') + + # The range of this data is int [0, 253]. Will use to filter lanes. It is very useful. + if "lane" in traffic_light_state: + traffic_light_step_data[0] = str(traffic_light_state["lane"]) + + # TODO: The range of this data is float with shape [200, 3] in range [-352, 169]. 
def extend_batch_dim(data):
    """Prepend a batch axis of size 1 to every array in ``data``.

    Top-level entries other than ``"other"`` are expanded with
    ``np.expand_dims(value, 0)``. The nested ``data["other"]`` dict is handled
    the same way, except that its ``"traf"`` entry is passed through untouched.

    Returns:
        A new dict; ``"other"`` is always its last key.
    """
    batched = {name: np.expand_dims(value, 0) for name, value in data.items() if name != "other"}
    batched["other"] = {
        # "traf" is copied as-is; everything else gains the batch axis.
        name: value if name == "traf" else np.expand_dims(value, 0)
        for name, value in data["other"].items()
    }
    return batched
def toy_test():
    """Iterate the validation dataloader built from the ``data/nuscenes_debug`` split.

    For each batch the target actions are copied into the output-action slot and
    the first two coordinates of the future agent positions are extracted; no
    metric is computed yet (see the TODO).
    """
    debug_data_dir = 'data/nuscenes_debug'

    config = debug_tools.get_debug_config(cfg_file="cfgs/motion_debug.yaml")
    config.PREPROCESSING["keep_all_data"] = True
    config.EVALUATION.PREDICT_ALL_AGENTS = False
    config.DATA.TRAINING_DATA_DIR = debug_data_dir
    config.DATA.TEST_DATA_DIR = debug_data_dir

    loader_options = dict(
        train_batch_size=10,
        train_num_workers=0,
        val_batch_size=8,
        val_num_workers=2,
        train_prefetch_factor=2,
        val_prefetch_factor=1,
    )
    datamodule = SceneStreamerDataModule(config, **loader_options)
    datamodule.setup("fit")

    for batch in tqdm(datamodule.val_dataloader()):
        # Pretend the model reproduced the target actions exactly.
        batch["decoder/output_action"] = batch["decoder/target_action"]
        ground_truth_trajectory = batch["decoder/future_agent_position"][..., :2]

        # TODO: Call the eval function from nuscenes?
def time_me(fn):
    """Decorate ``fn`` so that it returns ``(result, elapsed_seconds)``.

    The elapsed time is measured with ``time.perf_counter``. The original used
    ``time.clock``, which no longer exists since Python 3.8, so every decorated
    call raised ``AttributeError``.
    """
    import functools  # local import: the file's import block is left untouched

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        start = time.perf_counter()
        ret = fn(*args, **kwargs)
        return ret, time.perf_counter() - start

    return _wrapper
def get_polygon(center, yaw, L, W):
    """Build a shapely ``Polygon`` for a box of length ``L`` and width ``W``.

    Args:
        center: Indexable ``(x, y)`` centre of the box.
        yaw: Heading tensor. It is no longer modified; the original added
            ``pi / 2`` to it in place, changing the caller's tensor on every call.
        L: Box length (tensor; it is passed through ``torch.atan``/``torch.sqrt``).
        W: Box width (tensor).

    Returns:
        ``Polygon([p1, p3, p2, p4])`` built from the four corner offsets.
    """
    half_l, half_w = L / 2, W / 2
    # Out-of-place addition keeps the caller's ``yaw`` intact.
    yaw = yaw + torch.pi / 2
    theta = torch.atan(half_w / half_l)
    half_diag = torch.sqrt(half_l**2 + half_w**2)
    x1 = abs(torch.cos(theta + yaw) * half_diag)
    y1 = abs(torch.sin(theta + yaw) * half_diag)
    x2 = abs(torch.cos(theta - yaw) * half_diag)
    y2 = abs(torch.sin(theta - yaw) * half_diag)

    p1 = [center[0] + x1, center[1] + y1]
    p2 = [center[0] + x2, center[1] - y2]
    p3 = [center[0] - x1, center[1] - y1]
    p4 = [center[0] - x2, center[1] + y2]
    # NOTE(review): p1 and p3 are opposite corners, so the order p1, p3, p2, p4
    # looks self-intersecting; kept as in the original -- confirm before relying
    # on area/intersection results.
    return Polygon([p1, p3, p2, p4])
def process_lane(lane, max_vec, lane_range, offset=-40):
    """Convert lane points into at most ``max_vec`` line-segment vectors per batch item.

    Points are grouped by the id stored in column ``-2``; consecutive points of
    a group form one segment ``[x_start, y_start, x_end, y_end, type, traffic_light]``
    (type is column 2 and traffic light column 4 of the segment's end point).
    A segment is valid only if both end points satisfy
    ``|x + offset| < lane_range`` and ``|y| < lane_range``. For every batch item
    the valid segments are sorted by the squared distance of their start point
    to the origin, truncated to ``max_vec`` and zero-padded.

    Returns:
        ``(vectors, mask, ids)`` with shapes ``[B, max_vec, 6]`` (float),
        ``[B, max_vec]`` (bool) and ``[B, max_vec, 1]`` (int).
    """
    vec_dim = 6
    batch, _, point_dim = lane.shape
    id_column = lane[..., -2]
    inside = (abs(lane[..., 0] + offset) < lane_range) & (abs(lane[..., 1]) < lane_range)

    seg_chunks, ok_chunks, id_chunks = [], [], []
    for lane_index in np.unique(id_column).astype(int):
        member = id_column == lane_index
        pts = lane[member].reshape(batch, -1, point_dim)
        pts_inside = inside[member].reshape(batch, -1)
        seg_count = pts.shape[1] - 1

        seg_ids = np.full([batch, seg_count, 1], float(lane_index))
        seg = np.zeros([batch, seg_count, vec_dim])
        head, tail = pts[:, :-1], pts[:, 1:]
        seg[..., 0:2] = head[..., :2]
        seg[..., 2:4] = tail[..., :2]
        seg[..., 4] = tail[..., 2]  # type
        seg[..., 5] = tail[..., 4]  # traffic light
        seg_ok = pts_inside[:, :-1] & pts_inside[:, 1:]
        seg[~seg_ok] = 0

        seg_chunks.append(seg)
        ok_chunks.append(seg_ok)
        id_chunks.append(seg_ids)

    if seg_chunks:
        seg_all = np.concatenate(seg_chunks, axis=1)
        ok_all = np.concatenate(ok_chunks, axis=1)
        id_all = np.concatenate(id_chunks, axis=1)
    else:
        seg_all = np.zeros([batch, 0, vec_dim])
        ok_all = np.zeros([batch, 0], dtype=bool)
        id_all = np.zeros([batch, 0, 1])

    all_vec = np.zeros([batch, max_vec, vec_dim])
    all_mask = np.zeros([batch, max_vec], dtype=bool)
    all_id = np.zeros([batch, max_vec, 1])

    for t in range(batch):
        keep = ok_all[t]
        seg_t = seg_all[t][keep]
        ids_t = id_all[t][keep]

        # Nearest segments (by start point) first, capped at max_vec.
        order = np.argsort(seg_t[..., 0]**2 + seg_t[..., 1]**2)[:max_vec]
        count = order.shape[0]

        all_vec[t, :count] = seg_t[order]
        all_mask[t, :count] = True
        all_id[t, :count] = ids_t[order]

    return all_vec, all_mask, all_id.astype(int)
def get_time_str():
    """Return the current local time formatted as ``yy_mm_dd-HH_MM_SS``."""
    now = datetime.datetime.now()
    return f"{now:%y_%m_%d-%H_%M_%S}"
def get_type_class(line_type):
    """Translate a numeric map-feature code into its type object.

    Codes 1-3 give the string ``'center_lane'``, 6-13 the matching
    ``RoadLineType`` member, 15/16 ``RoadEdgeType.BOUNDARY`` / ``MEDIAN``;
    any other value gives ``'other'``.
    """
    if line_type in range(1, 4):
        return 'center_lane'

    # (code, resolver) pairs checked in order with ``==``; the enum member is
    # only looked up for the code that actually matches.
    table = (
        (6, lambda: RoadLineType.BROKEN_SINGLE_WHITE),
        (7, lambda: RoadLineType.SOLID_SINGLE_WHITE),
        (8, lambda: RoadLineType.SOLID_DOUBLE_WHITE),
        (9, lambda: RoadLineType.BROKEN_SINGLE_YELLOW),
        (10, lambda: RoadLineType.BROKEN_DOUBLE_YELLOW),
        (11, lambda: RoadLineType.SOLID_SINGLE_YELLOW),
        (12, lambda: RoadLineType.SOLID_DOUBLE_YELLOW),
        (13, lambda: RoadLineType.PASSING_DOUBLE_YELLOW),
        (15, lambda: RoadEdgeType.BOUNDARY),
        (16, lambda: RoadEdgeType.MEDIAN),
    )
    for code, resolve in table:
        if line_type == code:
            return resolve()
    return 'other'
def cal_rel_dir(dir1, dir2):
    """Return ``dir1 - dir2`` wrapped into the interval ``(-pi, pi]``.

    The difference is first reduced modulo ``2*pi`` and values above ``pi`` are
    then shifted down by ``2*pi``. The reduction uses a single modulo instead of
    the original ``while not np.all(...)`` loops, which never terminated when the
    difference contained NaN or infinity and needed one pass per full turn for
    large values. NaN inputs now simply yield NaN.

    Args:
        dir1: Angles in radians (NumPy array or torch tensor).
        dir2: Angles in radians, broadcastable against ``dir1``.

    Returns:
        A new array/tensor; the inputs are not modified.
    """
    two_pi = np.pi * 2
    dist = dir1 - dir2

    if isinstance(dist, torch.Tensor):
        dist = torch.remainder(dist, two_pi)
        return torch.where(dist > np.pi, dist - two_pi, dist)

    dist = np.mod(dist, two_pi)
    # Rounding may leave exactly 2*pi after the modulo; it is > pi and is
    # therefore mapped to 0 by the same shift.
    return np.where(dist > np.pi, dist - two_pi, dist)
@staticmethod
def from_list_to_array(inp_list):
    """Stack per-agent features into a fixed-size array plus validity mask.

    Each item's ``get_inp(act=True)`` output is concatenated along axis 0, cut
    to at most 32 rows and zero-padded back up to 32 rows.

    Returns:
        ``(agent, agent_mask)`` where ``agent`` has 32 rows and ``agent_mask``
        is a boolean vector that is True for the rows holding real agents.
    """
    capacity = 32
    stacked = np.concatenate([item.get_inp(act=True) for item in inp_list], axis=0)[:capacity]
    valid = stacked.shape[0]
    agent = np.pad(stacked, ([0, capacity - valid], [0, 0]))
    agent_mask = np.arange(capacity) < valid
    return agent, agent_mask
def get_inp(self, act=False, act_inp=False):
    """Assemble the agent feature matrix in one of three layouts.

    * ``act=True``: raw ``[position, velocity, heading, length_width]``.
    * ``act_inp=True``: position divided by ``RANGE``, velocity divided by
      ``MAX_SPEED``, then ``cos``/``sin`` of the heading and ``length_width``.
    * otherwise: the ``act_inp`` layout followed by a scaled copy of
      ``vec_based_info`` (columns 5:9 divided by ``RANGE``, column 2 by
      ``MAX_SPEED``); ``vec_based_info`` itself is left untouched.

    ``act`` takes precedence over ``act_inp``.
    """
    if act:
        raw_parts = [self.position, self.velocity, self.heading, self.length_width]
        return np.concatenate(raw_parts, axis=-1)

    parts = [
        self.position / self.RANGE,
        self.velocity / self.MAX_SPEED,
        np.cos(self.heading),
        np.sin(self.heading),
        self.length_width,
    ]
    if not act_inp:
        # Work on a copy so the stored vector-based info keeps its original scale.
        scaled_vec = copy.deepcopy(self.vec_based_info)
        scaled_vec[..., 5:9] /= self.RANGE
        scaled_vec[..., 2] /= self.MAX_SPEED
        parts.append(scaled_vec)
    return np.concatenate(parts, axis=-1)
def normalize_angle(angle):
    """
    Wrap angles into ``[0, 2*pi)`` in place and return the same object.

    From: https://github.com/metadriverse/trafficgen/blob/28b109e8e640d820192d5485bf9a28128b38ca21/trafficgen/utils/utils.py#L20

    Only the elements that are outside the target interval are rewritten, using
    a single modulo. The original ``while not all(...)`` loops never terminated
    when ``angle`` contained NaN or infinity; such entries now become/stay NaN.

    Args:
        angle: NumPy array or torch tensor of angles in radians; modified in place.

    Returns:
        The same ``angle`` object.
    """
    two_pi = np.pi * 2
    mod = torch.remainder if isinstance(angle, torch.Tensor) else np.mod

    # Negated "in range" test so that NaN entries are selected as well.
    out_of_range = ~((angle >= 0) & (angle < two_pi))
    wrapped = mod(angle[out_of_range], two_pi)
    # The modulo of a tiny negative value can round up to exactly 2*pi.
    wrapped[wrapped >= two_pi] -= two_pi
    angle[out_of_range] = wrapped
    return angle
    def _get_trafficgen_data(self, data_dict, current_t):
        """
        Build the TrafficGen-style "case info" dictionary for the frame at ``current_t``.

        The raw scenario stored in ``data_dict["raw_scenario_description"][0]`` is converted with
        ``metadrive_scenario_to_init_data``; agents and map are moved into the ego frame, agents
        are matched to their nearest lane-center vector, and the first (only) sampled frame is returned.

        Original author's note (PZH):
        I don't want to waste time to read through the LCTGen code,
        which essentially is from the TrafficGen code base.
        I've read the TrafficGen code base and I really really don't want
        to look into it for the second time.
        Just copy the code here and modify it to fit the current code base.

        Args:
            data_dict: Batch dictionary; only ``"raw_scenario_description"`` is read here.
            current_t: Index of the first frame to keep from the converted scenario.

        Returns:
            A dict holding element 0 of every ``case_info`` entry (``agent``, ``agent_mask``,
            ``center``/``bound``/``cross``/``rest`` and their masks, ``center_id``,
            ``vec_based_rep``, ``agent_vec_index``) plus ``"PZH_TRACK_NAMES"``, the track names
            re-ordered to match the compacted agent slots (padded with ``None``).
        """

        from scenestreamer.eval.scenarionet_to_trafficgen import metadrive_scenario_to_init_data

        data = metadrive_scenario_to_init_data(data_dict["raw_scenario_description"][0])
        PZH_TRACK_NAMES = data["PZH_TRACK_NAMES"]
        case_info = {}
        # NOTE(review): `other` is filled below but never read or returned in this method.
        other = {}

        # agent = copy.deepcopy(data['all_agent'])
        other['traf'] = copy.deepcopy(data['traffic_light'])

        max_time_step = 190
        gap = 190
        index = -1
        RANGE = 50

        if index == -1:
            # The step equals the stop bound, so for 0 <= current_t < 190 this slice keeps
            # exactly one frame: the one at current_t.
            data['all_agent'] = data['all_agent'][current_t:max_time_step:gap]
            data['traffic_light'] = data['traffic_light'][current_t:max_time_step:gap]
        else:
            # Unreachable while `index` is hard-coded to -1 above.
            raise ValueError
        data['lane'], other['unsampled_lane'] = self._transform_coordinate_map(data)
        other['lane'] = data['lane']

        def _process_agent(agent, sort_agent):
            # Moves all agents into the first frame's ego coordinate system (in place) and
            # builds the validity mask. `sort_agent` is accepted but not used.
            # presumably `agent` is (frames, agents, features) with x, y, vx, vy, heading in
            # columns 0-4 and type / validity flags in the last two columns — TODO confirm.

            ego = agent[:, 0]

            # transform every frame into ego coordinate in the first frame
            ego_pos = copy.deepcopy(ego[[0], :2])[:, np.newaxis]
            ego_heading = ego[[0], [4]]

            agent[..., :2] -= ego_pos
            agent[..., :2] = rotate(agent[..., 0], agent[..., 1], -ego_heading)
            agent[..., 2:4] = rotate(agent[..., 2], agent[..., 3], -ego_heading)
            agent[..., 4] -= ego_heading

            agent_mask = agent[..., -1]
            agent_type_mask = agent[..., -2]
            agent_range_mask = (abs(agent[..., 0]) < RANGE) * (abs(agent[..., 1]) < RANGE)

            mask = agent_mask * agent_type_mask
            # use agent range mask only for the first frame
            # allow agent to be out of range in the future frames
            mask[0, :] *= agent_range_mask[0, :]

            return agent, mask.astype(bool)

        case_info["agent"], case_info["agent_mask"] = _process_agent(data['all_agent'], False)
        # `offest` (sic) is the keyword name this call passes to process_map; do not "fix" it here.
        case_info['center'], case_info['center_mask'], case_info['center_id'], case_info['bound'], case_info[
            'bound_mask'], \
            case_info['cross'], case_info['cross_mask'], case_info['rest'], case_info['rest_mask'] = process_map(
            data['lane'], data['traffic_light'], lane_range=RANGE, offest=0)

        # get vector-based representation
        def _get_vec_based_rep(case_info, PZH_TRACK_NAMES):
            # Matches every agent to its nearest valid lane-center vector, derives the
            # lane-relative features, and compacts the agents kept in the first frame
            # to the front of the agent axis.
            THRES = 5
            thres = THRES
            # max_agent_num = 32
            # _process future agent

            agent = case_info['agent']
            vectors = case_info["center"]

            agent_mask = case_info['agent_mask']

            # Midpoint of each center vector (columns 0-3 are used as x1, y1, x2, y2 below).
            vec_x = ((vectors[..., 0] + vectors[..., 2]) / 2)
            vec_y = ((vectors[..., 1] + vectors[..., 3]) / 2)

            agent_x = agent[..., 0]
            agent_y = agent[..., 1]

            b, vec_num = vec_y.shape
            _, agent_num = agent_x.shape

            # Expand both sides to (frames, agents, vectors) for a pairwise distance matrix.
            vec_x = np.repeat(vec_x[:, np.newaxis], axis=1, repeats=agent_num)
            vec_y = np.repeat(vec_y[:, np.newaxis], axis=1, repeats=agent_num)

            agent_x = np.repeat(agent_x[:, :, np.newaxis], axis=-1, repeats=vec_num)
            agent_y = np.repeat(agent_y[:, :, np.newaxis], axis=-1, repeats=vec_num)

            dist = np.sqrt((vec_x - agent_x)**2 + (vec_y - agent_y)**2)

            cent_mask = np.repeat(case_info['center_mask'][:, np.newaxis], axis=1, repeats=agent_num)
            # 10e5 == 1e6: masked-out vectors get a huge distance so argmin prefers valid ones.
            dist[cent_mask == 0] = 10e5
            vec_index = np.argmin(dist, -1)
            min_dist_to_lane = np.min(dist, -1)
            min_dist_mask = min_dist_to_lane < thres

            selected_vec = np.take_along_axis(vectors, vec_index[..., np.newaxis], axis=1)

            vx, vy = agent[..., 2], agent[..., 3]
            v_value = np.sqrt(vx**2 + vy**2)
            low_vel = v_value < 0.1

            dir_v = np.arctan2(vy, vx)
            x1, y1, x2, y2 = selected_vec[..., 0], selected_vec[..., 1], selected_vec[..., 2], selected_vec[..., 3]
            dir = np.arctan2(y2 - y1, x2 - x1)
            agent_dir = agent[..., 4]

            v_relative_dir = cal_rel_dir(dir_v, agent_dir)
            relative_dir = cal_rel_dir(agent_dir, dir)

            # The velocity direction of a (nearly) stationary agent is treated as aligned.
            v_relative_dir[low_vel] = 0

            v_dir_mask = abs(v_relative_dir) < np.pi / 6
            dir_mask = abs(relative_dir) < np.pi / 4

            agent_x = agent[..., 0]
            agent_y = agent[..., 1]
            vec_x = (x1 + x2) / 2
            vec_y = (y1 + y2) / 2

            cent_to_agent_x = agent_x - vec_x
            cent_to_agent_y = agent_y - vec_y

            coord = rotate(cent_to_agent_x, cent_to_agent_y, np.pi / 2 - dir)

            vec_len = np.clip(np.sqrt(np.square(y2 - y1) + np.square(x1 - x2)), a_min=4.5, a_max=5.5)

            # Offsets relative to the matched vector, clipped to +/- half its (clamped) length
            # and expressed as a fraction of that length.
            lat_perc = np.clip(coord[..., 0], a_min=-vec_len / 2, a_max=vec_len / 2) / vec_len
            long_perc = np.clip(coord[..., 1], a_min=-vec_len / 2, a_max=vec_len / 2) / vec_len

            # ignore other masks for future agents (to support out-of-range agent prediction)
            # NOTE(review): no copy is made, so `total_mask` aliases case_info['agent_mask'] and
            # the writes below modify that array in place (it is replaced by agent_mask_ later).
            total_mask = agent_mask
            # for the first frame, use all masks to filter out off-road agents
            total_mask[0, :] = (min_dist_mask * agent_mask * v_dir_mask * dir_mask)[0, :]

            # Agent 0 (used as the ego in _process_agent) is always kept.
            total_mask[:, 0] = 1
            total_mask = total_mask.astype(bool)

            b_s, agent_num, agent_dim = agent.shape
            agent_ = np.zeros([b_s, agent_num, agent_dim])
            agent_mask_ = np.zeros([b_s, agent_num]).astype(bool)

            the_vec = np.take_along_axis(vectors, vec_index[..., np.newaxis], 1)
            # 0: vec_index
            # 1-2 long and lat percent
            # 3-5 velocity and direction
            # 6-9 lane vector
            # 10-11 lane type and traff state
            info = np.concatenate(
                [
                    vec_index[..., np.newaxis], long_perc[..., np.newaxis], lat_perc[..., np.newaxis],
                    v_value[..., np.newaxis], v_relative_dir[..., np.newaxis], relative_dir[..., np.newaxis], the_vec
                ], -1
            )

            info_ = np.zeros([b_s, agent_num, info.shape[-1]])

            # Agents kept in the first frame define the slot order for every frame.
            start_mask = total_mask[0]
            for i in range(agent.shape[0]):
                agent_i = agent[i][start_mask]
                info_i = info[i][start_mask]

                step_mask = total_mask[i]
                valid_mask = step_mask[start_mask]

                agent_i = agent_i[:agent_num]
                info_i = info_i[:agent_num]

                # Selected agents go to the front; the remaining slots are zero-padded.
                valid_num = agent_i.shape[0]
                agent_i = np.pad(agent_i, [[0, agent_num - agent_i.shape[0]], [0, 0]])
                info_i = np.pad(info_i, [[0, agent_num - info_i.shape[0]], [0, 0]])

                agent_[i] = agent_i
                info_[i] = info_i
                agent_mask_[i, :valid_num] = valid_mask[:valid_num]

            # Track names follow the same compaction; padded slots are None.
            PZH_TRACK_NAMES_new = np.array(list(PZH_TRACK_NAMES[start_mask]) + [None] * (agent_num - start_mask.sum()))

            case_info['vec_based_rep'] = info_[..., 1:]
            case_info['agent_vec_index'] = info_[..., 0].astype(int)
            case_info['agent_mask'] = agent_mask_
            case_info["agent"] = agent_

            return case_info, PZH_TRACK_NAMES_new

        case_info, PZH_TRACK_NAMES = _get_vec_based_rep(case_info, PZH_TRACK_NAMES)

        # Split the frame-major arrays into one dict per frame.
        case_num = case_info['agent'].shape[0]
        case_list = []
        for i in range(case_num):
            dic = {}
            for k, v in case_info.items():
                dic[k] = v[i]
            case_list.append(dic)

        # PZH: Obviously, you only pick T=0 from the data.

        ret = case_list[0]
        ret["PZH_TRACK_NAMES"] = PZH_TRACK_NAMES
        return ret
agent_mask[i]] + type_pred = agent_type[i, agent_mask[i]] + num_pred = len(pos_pred) + + from scenestreamer.dataset.preprocess_action_label import cal_polygon_contour, detect_collision + poly = cal_polygon_contour( + x=pos_pred[:, 0].cpu().numpy(), + y=pos_pred[:, 1].cpu().numpy(), + theta=head_pred.cpu().numpy(), + width=size_pred[:, 1].cpu().numpy(), + length=size_pred[:, 0].cpu().numpy() + ) + collision_detected = np.zeros(len(poly), dtype=bool) + for i in range(len(poly) - 1): + agent_collision_detected = [] + for j in range(i + 1, len(poly)): + poly1 = Polygon(poly[i]) + poly2 = Polygon(poly[j]) + if poly1.intersects(poly2): + collision_detected[i] = True + collision_detected[j] = True + static_collision_rate = collision_detected.mean() + log_func("static_collision_rate", static_collision_rate) + + # Compute the position matching here + + for suffix, (gt_mask, pred_mask) in { + "_all": (pos_pred.new_ones(num_target, dtype=bool), pos_pred.new_ones(num_pred, dtype=bool)), + # "_vehicle": (actor_type == 1, type_pred == 0), + # "_pedestrian": (actor_type == 2, type_pred == 1), + # "_cyclist": (actor_type == 3, type_pred == 2), + }.items(): + if not gt_mask.any(): + continue + if not pred_mask.any(): + continue + log_func(f"num_gt_samples{suffix}", gt_mask.sum().item()) + log_func(f"num_pred_samples{suffix}", pred_mask.sum().item()) + + # Follow: https://github.com/metadriverse/trafficgen/blob/28b109e8e640d820192d5485bf9a28128b38ca21/trafficgen/test_init.py#L44C5-L44C16 + kernel_mul = 1.0 + kernel_num = 1 + + center_head = head_target[0] + mmd_pos = compute_mmd_different_sizes( + x=pos_pred[pred_mask], + y=pos_target[gt_mask], + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + log_func(f"mmd_pos{suffix}", mmd_pos) + + log_func( + f"mmd_vel{suffix}", + compute_mmd_different_sizes( + x=vel_pred[pred_mask], + y=vel_target[gt_mask], + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + ) + + after_sdc_center = compute_mmd_different_sizes( + 
x=normalize_angle(head_pred[pred_mask] - center_head)[:, None], + y=normalize_angle(head_target[gt_mask] - center_head)[:, None], + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + no_sdc_center = compute_mmd_different_sizes( + x=normalize_angle(head_pred[pred_mask])[:, None], + y=normalize_angle(head_target[gt_mask])[:, None], + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + + log_func(f"mmd_head{suffix}", no_sdc_center) + log_func(f"mmd_head_center{suffix}", after_sdc_center) + + transformed_pred = angle_to_vector(normalize_angle(head_pred[pred_mask])) + transformed_target = angle_to_vector(normalize_angle(head_target[gt_mask])) + log_func( + f"mmd_head_transformed{suffix}", + compute_mmd_different_sizes( + x=transformed_pred, + y=transformed_target, + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + ) + log_func( + f"mmd_size{suffix}", + compute_mmd_different_sizes( + x=size_pred[pred_mask][..., :2], + y=size_target[gt_mask][..., :2], + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + ) + + if "num_collisions" in stat: + log_func("num_collisions", stat["num_collisions"]) + if "num_violations" in stat: + log_func("num_violations", stat["num_violations"]) + + for i in range(B): + pos_target = data_dict["decoder/agent_position"][i, current_t, all_select_index, :2] + vel_target = data_dict["decoder/agent_velocity"][i, current_t, all_select_index, :2] + head_target = data_dict["decoder/agent_heading"][i, current_t, all_select_index] + size_target = data_dict["decoder/agent_shape"][i, current_t, all_select_index] + actor_type = data_dict["decoder/agent_type"][i][all_select_index] + num_target = len(all_select_index) + + pos_pred = agent_pos[i, agent_mask[i]] + head_pred = agent_heading[i, agent_mask[i]] + vel_pred = agent_velocity[i, agent_mask[i]] + size_pred = agent_shape[i, agent_mask[i]] + type_pred = agent_type[i, agent_mask[i]] + num_pred = len(pos_pred) + + # Compute the position matching here + for suffix, (gt_mask, pred_mask) in { + 
"_all": (pos_pred.new_ones(num_target, dtype=bool), pos_pred.new_ones(num_pred, dtype=bool)), + # "_vehicle": (actor_type == 1, type_pred == 0), + # "_pedestrian": (actor_type == 2, type_pred == 1), + # "_cyclist": (actor_type == 3, type_pred == 2), + }.items(): + if not gt_mask.any(): + continue + if not pred_mask.any(): + continue + log_func(f"ALL_AGENT_num_gt_samples{suffix}", gt_mask.sum().item()) + log_func(f"ALL_AGENT_num_pred_samples{suffix}", pred_mask.sum().item()) + + # Follow: https://github.com/metadriverse/trafficgen/blob/28b109e8e640d820192d5485bf9a28128b38ca21/trafficgen/test_init.py#L44C5-L44C16 + kernel_mul = 1.0 + kernel_num = 1 + + center_head = head_target[0] + mmd_pos = compute_mmd_different_sizes( + x=pos_pred[pred_mask], + y=pos_target[gt_mask], + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + log_func(f"ALL_AGENT_mmd_pos{suffix}", mmd_pos) + + log_func( + f"ALL_AGENT_mmd_vel{suffix}", + compute_mmd_different_sizes( + x=vel_pred[pred_mask], + y=vel_target[gt_mask], + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + ) + + after_sdc_center = compute_mmd_different_sizes( + x=normalize_angle(head_pred[pred_mask] - center_head)[:, None], + y=normalize_angle(head_target[gt_mask] - center_head)[:, None], + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + no_sdc_center = compute_mmd_different_sizes( + x=normalize_angle(head_pred[pred_mask])[:, None], + y=normalize_angle(head_target[gt_mask])[:, None], + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + + log_func(f"ALL_AGENT_mmd_head{suffix}", no_sdc_center) + log_func(f"ALL_AGENT_mmd_head_center{suffix}", after_sdc_center) + + transformed_pred = angle_to_vector(normalize_angle(head_pred[pred_mask])) + transformed_target = angle_to_vector(normalize_angle(head_target[gt_mask])) + log_func( + f"ALL_AGENT_mmd_head_transformed{suffix}", + compute_mmd_different_sizes( + x=transformed_pred, + y=transformed_target, + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + ) + log_func( + 
f"ALL_AGENT_mmd_size{suffix}", + compute_mmd_different_sizes( + x=size_pred[pred_mask][..., :2], + y=size_target[gt_mask][..., :2], + kernel_mul=kernel_mul, + kernel_num=kernel_num, + ) + ) + + def on_validation_epoch_end(self, *args, trainer, global_rank, **kwargs): + print("Rank", global_rank) + trainer.strategy.barrier() + + # for k in self.mmd_metrics: + # print_func(f'eval/{k}', self.mmd_metrics[k].compute()) + # log_func(f'eval/{k}', self.mmd_metrics[k].compute()) + + +@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="1214_midgpt_v14.yaml") +def main(config): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + path = "/Users/pengzhenghao/PycharmProjects/scenestreamer/lightning_logs/1231_MidGPT_V17_Bicy_WTrafficGen_2024-12-31/" + model = utils.get_model(checkpoint_path=path, device=device) + + # model = utils.get_model(config, device=device) + + evaluator = TrafficGenEvaluator(config) + + config = model.config + omegaconf.OmegaConf.set_struct(config, False) + config.PREPROCESSING["keep_all_data"] = True + config.DATA.TRAINING_DATA_DIR = "data/20scenarios" + config.DATA.TEST_DATA_DIR = "data/20scenarios" + + assert config.USE_TRAFFICGEN is True + + test_dataset = SceneStreamerDataset(config, "training") + # ddd = iter(test_dataset) + + START_ACTION = config.PREPROCESSING.MAX_MAP_FEATURES + END_ACTION = config.PREPROCESSING.MAX_MAP_FEATURES + 1 + + for count, raw_data_dict in enumerate(tqdm.tqdm(test_dataset)): + data_dict = raw_data_dict + data_dict = utils.numpy_to_torch(data_dict, device=device) + data_dict = {k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()} + data_dict = copy.deepcopy(data_dict) + + data_dict = model.model.encode_scene(data_dict) + output_dict, stat = model.model.trafficgen_decoder.autoregressive_rollout_trafficgen(data_dict) + + # output_dict = data_dict + # output_dict = {k: v[0] if isinstance(v, torch.Tensor) else v for k, v in 
output_dict.items()} + # output_dict = utils.torch_to_numpy(output_dict) + + # suffix = "gt" if use_gt else "pred" + # save_path = pathlib.Path("0107_trafficgen") / f"trafficgen_{sid}_{suffix}.png" + # save_path.parent.mkdir(exist_ok=True) + # print(f"Saving to {save_path}") + # plot_trafficgen(output_dict, show=False, save_path=save_path) + + # Call evaluator + evaluator.validation_step(output_dict, stat, log_func=print) + + if count > 3: + break + + evaluator.on_validation_epoch_end(print_func=print) + print("End") + + +if __name__ == '__main__': + main() diff --git a/scenestreamer/eval/test_waymo_eval.py b/scenestreamer/eval/test_waymo_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..3acd85ca270f9d6facb6cf2d1403c1ea32752e2d --- /dev/null +++ b/scenestreamer/eval/test_waymo_eval.py @@ -0,0 +1,179 @@ +import numpy as np +import torch +from tqdm import tqdm + +from scenestreamer.dataset.datamodule import SceneStreamerDataModule +from scenestreamer.eval.waymo_eval import waymo_evaluation_optimized +from scenestreamer.eval.waymo_motion_prediction_evaluator import generate_submission, transform_to_global_coordinate +from scenestreamer.utils import debug_tools + + +def _unbatch_to_numpy(tensor_dict, index=0): + ret = {} + for k, v in tensor_dict.items(): + ret[k] = v[index].numpy() + return ret + + +def _batch_to_tensor(array_list): + return torch.from_numpy(np.array(array_list)) + + +def run_waymo_eval(data_dict, pred_trajs): + # scores = data_dict["decoder/output_score"] + # pred_trajs = data_dict["decoder/reconstructed_position"] + + num_modes_nms = 32 + num_modes_for_eval = 6 + + # Let's test the GT trajectory first. 
def toy_test():
    """
    Feed the dataset's own future agent positions through the Waymo evaluation pipeline.

    Loads the debug config, iterates the validation dataloader, uses the xy part of
    ``decoder/future_agent_position`` as the "prediction" for ``run_waymo_eval``, and
    prints the summary string returned by ``waymo_evaluation_optimized``.
    """
    config = debug_tools.get_debug_config(cfg_file="cfgs/motion_debug_2_local_train.yaml")
    config.PREPROCESSING["keep_all_data"] = True
    config.EVALUATION.PREDICT_ALL_AGENTS = True

    loader_options = dict(
        train_batch_size=10,
        train_num_workers=0,
        val_batch_size=8,
        val_num_workers=2,
        train_prefetch_factor=2,
        val_prefetch_factor=1,
    )
    datamodule = SceneStreamerDataModule(config, **loader_options)
    datamodule.setup("fit")

    validation_list = [
        run_waymo_eval(batch, pred_trajs=batch["decoder/future_agent_position"][..., :2])
        for batch in tqdm(datamodule.val_dataloader())
    ]
    _, result_str = waymo_evaluation_optimized(validation_list)
    print(result_str)
def data_dict_to_sim_agents_prediction(data_dict, pred_trajs, pred_headings):
    """
    Convert a batched ``data_dict`` into the format used by the sim agents evaluation pipeline.

    Every scenario in the batch is replicated ``num_modes_for_eval`` (32) times. The
    per-scenario arrays do not depend on the mode index, so they are built once per scenario
    and an independent copy is appended for each mode (the previous version redid the
    stacking and device-to-host transfers 32 times per scenario).

    Args:
        data_dict: the global dictionary that contains all information. Keys read here:
            ``eval/modeled_agent_id``: per-scenario agent indices; ``-1`` entries are skipped.
            ``scenario_id``: one ID per scenario.
            ``metadata/map_center``, ``metadata/map_heading``: tensors, repeated per mode.
            Every key starting with ``eval/`` or ``metadata/`` is copied to the output.
        pred_trajs: the predicted trajectories. Shape: (B, T, N, 2).
        pred_headings: the predicted headings. Shape: (B, T, N).

    Returns:
        prediction_dict with torch tensors converted to numpy:
            pred_trajs: list of B * num_modes arrays, each (T, K, 2), where K is the number
                of non ``-1`` agent indices of that scenario.
            pred_scores: list of B * num_modes arrays, each (K,). Scores are all ones.
            pred_headings: list of B * num_modes arrays, each (T, K).
            pred_to_scenario_id: scenario ID of each prediction. Shape: (B * num_modes,).
            expanded_map_center / expanded_map_heading: map metadata repeated per mode.
            Anything with the prefix ``eval/`` or ``metadata/`` from data_dict.

    Raises:
        ValueError: from ``np.stack`` if a scenario has no agent index other than ``-1``.
    """
    # scores = data_dict["decoder/output_score"]
    # pred_trajs = data_dict["decoder/reconstructed_position"]

    num_modes_for_eval = 32

    # Let's test the GT trajectory first.

    B, _, N, _ = pred_trajs.shape
    scores = pred_trajs.new_ones(size=(B, N))

    scores_of_interested_agents = []
    pred_trajs_of_interested_agents = []
    pred_headings_of_interested_agents = []
    for batch_index, track_indices in enumerate(data_dict["eval/modeled_agent_id"]):
        valid_agents = [agent_index for agent_index in track_indices if agent_index != -1]

        # Loop-invariant with respect to the mode index: build once per scenario.
        sc = np.stack(
            [scores[batch_index][agent_index].detach().cpu().numpy() for agent_index in valid_agents], axis=0
        )
        traj = np.stack(
            [pred_trajs[batch_index][:, agent_index].detach().cpu().numpy() for agent_index in valid_agents], axis=1
        )
        heading = np.stack(
            [pred_headings[batch_index][:, agent_index].detach().cpu().numpy() for agent_index in valid_agents],
            axis=1
        )

        for _ in range(num_modes_for_eval):
            # Copies keep every mode's arrays independent, as before, in case a later
            # stage modifies them in place.
            pred_trajs_of_interested_agents.append(traj.copy())
            scores_of_interested_agents.append(sc.copy())
            pred_headings_of_interested_agents.append(heading.copy())

    prediction_dict = {
        "pred_trajs": pred_trajs_of_interested_agents,
        "pred_scores": scores_of_interested_agents,
        "pred_headings": pred_headings_of_interested_agents,
        "pred_to_scenario_id": np.repeat(data_dict["scenario_id"], num_modes_for_eval, axis=0),
        "expanded_map_center": data_dict["metadata/map_center"][:, None].repeat(1, num_modes_for_eval, 1).flatten(0, 1),
        "expanded_map_heading": data_dict["metadata/map_heading"][:, None].repeat(1, num_modes_for_eval,
                                                                                  1).flatten(0, 1),
    }

    # Copy over all eval/ and metadata/ keys from data_dict.
    for k, v in data_dict.items():
        if k.startswith("eval/") or k.startswith("metadata/"):
            prediction_dict[k] = v

    # Convert all torch.Tensor to numpy.
    prediction_dict = {
        k: (v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v)
        for k, v in prediction_dict.items()
    }
    return prediction_dict
+ """ + cfg_file = "cfgs/motion_debug_2_local_train.yaml" + config = debug_tools.get_debug_config(cfg_file=cfg_file) + + config.PREPROCESSING["keep_all_data"] = True + config.EVALUATION.PREDICT_ALL_AGENTS = True + config.DATA.TRAINING_DATA_DIR = "data/waymo_8s_debug" + config.DATA.TEST_DATA_DIR = "data/waymo_8s_debug" + + config.DATA.SAMPLE_INTERVAL = {'training': 1, 'test': 1} + + config.TOKENIZATION.TOKENIZATION_METHOD = "delta_delta" + config.TOKENIZATION.X_MAX = 4 + config.TOKENIZATION.X_MIN = -4 + config.TOKENIZATION.Y_MAX = 3 + config.TOKENIZATION.Y_MIN = -3 + + # config.TOKENIZATION.TOKENIZATION_METHOD = "delta" + + config.DATA["SD_PASSTHROUGH"] = True + # Note: The datamodules, when iterated over, return dictionaries (data_dicts). + datamodule = SceneStreamerDataModule( + config, + train_batch_size=10, + train_num_workers=0, + val_batch_size=8, + val_num_workers=0, + train_prefetch_factor=2, + val_prefetch_factor=1 + ) + datamodule.setup("fit") # "fit" here doesn't mean anything (yet!) + dataloader = datamodule.val_dataloader() + + from scenestreamer.tokenization import motion_tokenizers + tokenizer = motion_tokenizers.get_tokenizer(config) + + validation_list = [] + for data_dict in tqdm(dataloader): + # We can check our discretization error by detokenizing the tokenized ground truth. + data_dict["decoder/output_action"] = data_dict["decoder/target_action"] + with torch.no_grad(): + data_dict = tokenizer.detokenize(data_dict) + # pred_trajs = data_dict["decoder/reconstructed_position"] + pred_headings = data_dict["decoder/reconstructed_heading"] + pred_trajs = data_dict["decoder/future_agent_position"][..., :2] + + new_prediction_dict = data_dict_to_sim_agents_prediction( + data_dict, pred_trajs=pred_trajs, pred_headings=pred_headings + ) + + new_prediction_dict = transform_to_global_coordinate(new_prediction_dict) + + validation_list.append(new_prediction_dict) + # Validation list: A list of pred_dicts. 
+ print("Evaluating...") + scenario_metrics, aggregate_metrics = wosac_evaluation(validation_list) + print(scenario_metrics) + print("\n\n\n") + print(aggregate_metrics) + + +# def run_wosac_submission( +# test_file='/scratch/metadrive/data/uncompressed_scenario_validation_validation.tfrecord-00000-of-00150'): +# # Read the dataset from the .tfrecord file. +# filename = tf.io.matching_files(test_file) +# +# dataset = tf.data.TFRecordDataset(filename) +# dataset_iterator = dataset.as_numpy_iterator() +# +# bytes_example = next(dataset_iterator) +# scenario = scenario_pb2.Scenario.FromString(bytes_example) +# print(f'Checking type: {type(scenario)}') +# print(f'Loaded scenario with ID: {scenario.scenario_id}') +# +# print(f'Simulation length, in steps: {submission_specs.N_SIMULATION_STEPS}') +# print( +# f'Duration of a step, in seconds: {submission_specs.STEP_DURATION_SECONDS}s (frequency: {1 / submission_specs.STEP_DURATION_SECONDS}Hz)') +# print(f'Number of parallel simulations per Scenario: {submission_specs.N_ROLLOUTS}') +# +# logged_trajectories, simulated_states = simulate_with_extrapolation( +# scenario, print_verbose_comments=True) +# generate_submission([scenario], [simulated_states], [logged_trajectories]) + +if __name__ == "__main__": + test_tokenizer_with_wosac_eval() + # run_wosac_submission() diff --git a/scenestreamer/eval/waymo_eval.py b/scenestreamer/eval/waymo_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..16b5a1baec883dcbb9a59f278223c0d43fe29240 --- /dev/null +++ b/scenestreamer/eval/waymo_eval.py @@ -0,0 +1,674 @@ +""" +This file provides some utility functions for preparing waymo submission, and evaluate on waymo validation dataset. 
def _default_metrics_config(eval_second, num_modes_for_eval=6):
    """
    Build the ``MotionMetricsConfig`` proto for a 3, 5 or 8 second evaluation horizon.

    The text-format config is assembled from a shared header, the ``max_predictions``
    entry, and a horizon-specific tail, then parsed into the proto.

    Args:
        eval_second: Evaluation horizon in seconds; must be 3, 5 or 8.
        num_modes_for_eval: Value written to ``max_predictions``.

    Returns:
        The populated ``motion_metrics_pb2.MotionMetricsConfig``.
    """
    assert eval_second in [3, 5, 8]

    shared_header = """
    track_steps_per_second: 10
    prediction_steps_per_second: 2
    track_history_samples: 10
    speed_lower_bound: 1.4
    speed_upper_bound: 11.0
    speed_scale_lower: 0.5
    speed_scale_upper: 1.0
    step_configurations {
    measurement_step: 5
    lateral_miss_threshold: 1.0
    longitudinal_miss_threshold: 2.0
    }
    """
    max_predictions = f"""
    max_predictions: {num_modes_for_eval}
    """
    tail_3s = """
    track_future_samples: 30
    """
    tail_5s = """
    track_future_samples: 50
    step_configurations {
    measurement_step: 9
    lateral_miss_threshold: 1.8
    longitudinal_miss_threshold: 3.6
    }
    """
    tail_8s = """
    track_future_samples: 80
    step_configurations {
    measurement_step: 9
    lateral_miss_threshold: 1.8
    longitudinal_miss_threshold: 3.6
    }
    step_configurations {
    measurement_step: 15
    lateral_miss_threshold: 3.0
    longitudinal_miss_threshold: 6.0
    }
    """
    # Anything other than 3 or 5 gets the 8 second tail, matching the original `else` branch.
    horizon_tail = {3: tail_3s, 5: tail_5s}.get(eval_second, tail_8s)

    config = motion_metrics_pb2.MotionMetricsConfig()
    text_format.Parse(shared_header + max_predictions + horizon_tail, config)
    return config
num_frame_to_eval = 6 +# elif eval_second == 5: +# num_frames_in_total = 61 +# num_frame_to_eval = 10 +# else: +# num_frames_in_total = 91 +# num_frame_to_eval = 16 +# +# batch_pred_trajs = np.zeros((num_scenario, num_max_objs_per_scene, topK, 1, num_frame_to_eval, 2)) +# batch_pred_scores = np.zeros((num_scenario, num_max_objs_per_scene, topK)) +# gt_trajs = np.zeros((num_scenario, num_max_objs_per_scene, num_frames_in_total, 7)) +# gt_is_valid = np.zeros((num_scenario, num_max_objs_per_scene, num_frames_in_total), dtype=int) +# pred_gt_idxs = np.zeros((num_scenario, num_max_objs_per_scene, 1)) +# pred_gt_idx_valid_mask = np.zeros((num_scenario, num_max_objs_per_scene, 1), dtype=int) +# object_type = np.zeros((num_scenario, num_max_objs_per_scene), dtype=object) +# object_id = np.zeros((num_scenario, num_max_objs_per_scene), dtype=int) +# scenario_id = np.zeros((num_scenario), dtype=object) +# +# object_type_cnt_dict = {} +# for key in object_type_to_id.keys(): +# object_type_cnt_dict[key] = 0 +# +# for scene_idx, val in enumerate(scene2preds.items()): +# cur_scenario_id, preds_per_scene = val +# scenario_id[scene_idx] = cur_scenario_id +# for obj_idx, cur_pred in enumerate(preds_per_scene): +# sort_idxs = cur_pred['pred_scores'].argsort()[::-1] +# cur_pred['pred_scores'] = cur_pred['pred_scores'][sort_idxs] +# cur_pred['pred_trajs'] = cur_pred['pred_trajs'][sort_idxs] +# +# cur_pred['pred_scores'] = cur_pred['pred_scores'] / cur_pred['pred_scores'].sum() +# +# batch_pred_trajs[scene_idx, +# obj_idx] = cur_pred['pred_trajs'][:topK, np.newaxis, +# 4::sampled_interval, :][:, :, :num_frame_to_eval, :] +# batch_pred_scores[scene_idx, obj_idx] = cur_pred['pred_scores'][:topK] +# gt_trajs[scene_idx, obj_idx] = cur_pred['gt_trajs'][:num_frames_in_total, [ +# 0, 1, 3, 4, 6, 7, 8 +# ]] # (num_timestamps_in_total, 10), [cx, cy, cz, dx, dy, dz, heading, vel_x, vel_y, valid] +# gt_is_valid[scene_idx, obj_idx] = cur_pred['gt_trajs'][:num_frames_in_total, -1] +# 
def waymo_evaluation_optimized(
    pred_dicts, eval_second=8, num_modes_for_eval=6, verbose=True, generate_submission=False, predict_all_agents=False
):
    """Evaluate motion predictions with the Waymo motion-metrics op.

    The per-batch prediction dicts are regrouped by scenario, packed into dense
    zero-padded arrays (predictions, scores, ground truth, validity masks) and
    passed to ``py_metrics_ops.motion_metrics``. A second set of "joint"
    ADE/FDE numbers (mJADE, avgJADE, mJFDE, avgJFDE) is computed directly in
    NumPy from the same arrays.

    Args:
        pred_dicts: Iterable of dicts. Each dict is read for
            ``"pred_to_scenario_id"``, ``"scenario_id"``, ``"pred_trajs"``,
            ``"pred_scores"`` and the ``"eval/..."``, ``"encoder/..."`` and
            ``"decoder/..."`` keys listed in the splitting loop below.
        eval_second: Evaluation horizon. 3 and 5 select 41/61 ground-truth
            frames and 6/10 evaluated steps; any other value selects 91 frames
            and 16 steps. It is also forwarded to ``_default_metrics_config``.
        num_modes_for_eval: Number of modes per scenario. Asserted against the
            leading dimension of each scenario's ``pred_trajs`` and forwarded
            to ``_default_metrics_config``.
        verbose: When true, separator and note keys (value 0) are inserted
            into ``result_dict``.
        generate_submission: When true, the third return value carries the
            prediction tensors, object ids and scenario ids.
        predict_all_agents: Selects how the per-scenario object count is
            derived (``eval/agent_type`` vs ``decoder/object_of_interest_id``)
            and which name array fills ``object_id``.

    Returns:
        ``(result_dict, result_format_str, submission_data)``. When the first
        scenario's ``eval/agent_position`` has 11 entries along axis 0 the
        function returns ``({}, "", submission_data)`` without computing any
        metric.

    Raises:
        ValueError: If the number of predicted future frames is not 30, 50 or 80.
    """
    # Split all data based on scenario
    split_data = defaultdict(list)

    scenario_id_list = []
    # Split the prediction for each scenario, also flatten the data
    for d in pred_dicts:
        for sid in np.unique(d["pred_to_scenario_id"]):
            for k, v in d.items():
                if k in ["pred_trajs", "pred_scores"]:
                    # Prediction entries are matched through "pred_to_scenario_id";
                    # every entry belonging to this scenario is stacked on a new axis 0.
                    assert len(d["pred_to_scenario_id"]) == len(v), (len(d["pred_to_scenario_id"]), len(v), k)
                    entry_in_same_scenario = [v[idx] for idx in range(len(v)) if d["pred_to_scenario_id"][idx] == sid]
                    assert entry_in_same_scenario
                    entry_in_same_scenario = np.stack(entry_in_same_scenario, axis=0)
                    split_data[k].append(entry_in_same_scenario)
                elif k in ["eval/agent_position", "eval/agent_velocity", "eval/agent_heading", "eval/agent_valid_mask",
                           "eval/agent_shape", "eval/agent_type", "decoder/object_of_interest_id",
                           "encoder/object_of_interest_name", "decoder/object_of_interest_name", "decoder/track_name"]:
                    # Ground-truth / metadata entries are matched through "scenario_id"
                    # and must appear exactly once per scenario.
                    entry_in_same_scenario = [v[idx] for idx in range(len(v)) if d["scenario_id"][idx] == sid]
                    assert entry_in_same_scenario
                    assert len(entry_in_same_scenario) == 1
                    entry_in_same_scenario = entry_in_same_scenario[0]
                    split_data[k].append(entry_in_same_scenario)
            scenario_id_list.append(sid)
    split_data = dict(split_data)

    num_scenario = len(split_data["pred_trajs"])

    # Array sizes are taken from the first scenario; the per-scenario assert
    # below fixes each pred_trajs entry to (#modes, #frames, N, 2).
    trajectory_in_single_scenario = split_data["pred_trajs"][0]
    num_modes = len(trajectory_in_single_scenario)
    num_future_frames = trajectory_in_single_scenario[0].shape[0]
    num_max_objs_per_scene = max([v[0].shape[1] for v in split_data["pred_trajs"]])

    if num_future_frames in [30, 50, 80]:
        sampled_interval = 5
    else:
        raise ValueError("Unknown prediction with future steps: {}".format(num_future_frames))

    if eval_second == 3:
        num_frames_in_total = 41
        num_frame_to_eval = 6
    elif eval_second == 5:
        num_frames_in_total = 61
        num_frame_to_eval = 10
    else:
        num_frames_in_total = 91
        num_frame_to_eval = 16

    # ===== Process each scenario's prediction =====
    batch_pred_trajs = np.zeros((num_scenario, num_max_objs_per_scene, num_modes, 1, num_frame_to_eval, 2))
    batch_pred_scores = np.zeros((num_scenario, num_max_objs_per_scene, num_modes))
    pred_gt_indices = np.zeros((num_scenario, num_max_objs_per_scene, 1), dtype=int)
    pred_gt_indices_mask = np.zeros((num_scenario, num_max_objs_per_scene, 1), dtype=int)
    for scenario_count in range(num_scenario):

        # Number of objects = count of non-negative entries; the later asserts
        # on eval/agent_type require the remaining slots to be -1.
        if predict_all_agents:
            num_objs = (split_data["eval/agent_type"][scenario_count] >= 0).sum()
        else:
            num_objs = (split_data["decoder/object_of_interest_id"][scenario_count] >= 0).sum()

        pred_trajs = split_data["pred_trajs"][scenario_count]
        pred_trajs = pred_trajs[:, :, :num_objs]
        assert pred_trajs.shape == (num_modes_for_eval, num_future_frames, num_objs, 2)

        # prev: scores in shape [num objects, num modes]
        scores = split_data["pred_scores"][scenario_count]  # (#modes, N)
        # assert scores.ndim == 2, scores.shape
        # scores = scores[:, :num_objs]

        # assert scores.min() >= 0.0

        # top_k_index = np.argsort(scores, axis=0)[::-1][:num_modes_for_eval]  # (#modes, N)
        # top_k_index = top_k_index[:, :num_objs]
        # assert top_k_index.shape == (num_modes_for_eval, num_objs)
        # top_k_scores = np.take_along_axis(scores, top_k_index, axis=0)  # (#modes, N)

        # pred_trajs = np.take_along_axis(pred_trajs, top_k_index.reshape(num_modes_for_eval, 1, num_objs, 1), axis=0)
        # Keep every `sampled_interval`-th frame starting at index
        # `sampled_interval - 1`, then truncate to the evaluated horizon.
        pred_trajs_processed = pred_trajs[:, (sampled_interval - 1)::sampled_interval]
        pred_trajs_processed = pred_trajs_processed[:, :num_frame_to_eval, :num_objs]

        # Till now, pred_trajs_processed.shape == (#modes, #steps, N, 2), need to change shape to
        # (N, #modes, 1, #steps, 2)
        assert pred_trajs_processed.ndim == 4
        pred_trajs_processed = pred_trajs_processed.swapaxes(0, 2)  # (#modes, #steps, N, 2) -> (N, #steps, #modes, 2)
        pred_trajs_processed = pred_trajs_processed.swapaxes(1, 2)  # -> (N, #modes, #steps, 2)
        pred_trajs_processed = pred_trajs_processed.reshape(num_objs, num_modes_for_eval, 1, num_frame_to_eval, 2)

        # batch_pred_trajs.shape == (#scenarios, #maxobjs, #modes, 1, #steps, 2)
        batch_pred_trajs[scenario_count, :num_objs] = pred_trajs_processed

        # scores in shape [num objects, num modes] = [7, 6]
        # normalize the scores for all modes of one object in this scene.
        # assert top_k_scores.ndim == 2
        # top_k_scores = top_k_scores.swapaxes(0, 1)  # (#modes, N) -> (N, #modes)
        # top_k_scores = top_k_scores / (top_k_scores.sum(axis=-1, keepdims=True) + 1e-6)

        if scores.ndim == 1:
            # One score per mode: the (1, #modes) row is broadcast to every object.
            scores = scores.reshape(1, -1)
            batch_pred_scores[scenario_count, :num_objs] = scores
        else:
            batch_pred_scores[scenario_count, :num_objs] = scores.swapaxes(0, 1)

        # Prediction i is matched to ground-truth agent i.
        pred_gt_indices[scenario_count, :num_objs, 0] = np.arange(num_objs, dtype=int)
        pred_gt_indices_mask[scenario_count, :num_objs, 0] = 1

    # ===== Process GT data directly =====
    # NOTE(review): 11 entries along axis 0 is treated as "testing set" --
    # presumably history-only tracks with no future ground truth; confirm.
    in_testing_set = split_data["eval/agent_position"][0].shape[0] == 11

    if in_testing_set:
        # No metrics are computed on this path; only submission data is returned.
        if generate_submission:
            object_id = np.zeros((num_scenario, num_max_objs_per_scene), dtype=int)
            object_id.fill(-1)
            for scenario_count in range(num_scenario):
                if predict_all_agents:
                    num_objs = (split_data["eval/agent_type"][scenario_count] >= 0).sum()
                else:
                    num_objs = (split_data["decoder/object_of_interest_id"][scenario_count] >= 0).sum()

                assert (split_data["eval/agent_type"][scenario_count][num_objs:] == -1).all()
                if not predict_all_agents:
                    object_id[scenario_count, :num_objs] = split_data["decoder/object_of_interest_name"][scenario_count][:num_objs]
                else:
                    object_id[scenario_count, :num_objs] = split_data["decoder/track_name"][scenario_count][:num_objs]
            batch_pred_trajs = tf.convert_to_tensor(batch_pred_trajs, tf.float32)
            batch_pred_scores = tf.convert_to_tensor(batch_pred_scores, tf.float32)
            submission_data = dict(
                prediction_trajectory_list=batch_pred_trajs,
                prediction_score_list=batch_pred_scores,
                object_id_list=object_id,
                scenario_id_list=scenario_id_list,
            )
        else:
            submission_data = dict()
        return {}, "", submission_data

    # -1 marks padded object slots in object_id.
    object_id = np.zeros((num_scenario, num_max_objs_per_scene), dtype=int)
    object_id.fill(-1)

    object_type = np.zeros((num_scenario, num_max_objs_per_scene), dtype=int)
    gt_trajs = np.zeros((num_scenario, num_max_objs_per_scene, num_frames_in_total, 7))
    gt_is_valid = np.zeros((num_scenario, num_max_objs_per_scene, num_frames_in_total), dtype=int)
    for scenario_count in range(num_scenario):

        if predict_all_agents:
            num_objs = (split_data["eval/agent_type"][scenario_count] >= 0).sum()
        else:
            num_objs = (split_data["decoder/object_of_interest_id"][scenario_count] >= 0).sum()

        assert (split_data["eval/agent_type"][scenario_count][num_objs:] == -1).all()

        object_type[scenario_count, :num_objs] = split_data["eval/agent_type"][scenario_count][:num_objs]

        if not predict_all_agents:
            object_id[scenario_count, :num_objs] = split_data["decoder/object_of_interest_name"][scenario_count][:num_objs]
        else:
            object_id[scenario_count, :num_objs] = split_data["decoder/track_name"][scenario_count][:num_objs]

        # Heading is wrapped into [0, 2*pi) before being packed.
        heading_in_0_2pi = (split_data["eval/agent_heading"][scenario_count][..., None]) % (2 * np.pi)

        # Last axis: position[:2], shape[:2], heading, velocity -- 7 values in total
        # given the (…, 7) target (presumably x, y, length, width, heading, vx, vy; confirm).
        gt_trajs_per_scenario = np.concatenate(
            [
                split_data["eval/agent_position"][scenario_count][..., :2],
                split_data["eval/agent_shape"][scenario_count][..., :2],
                heading_in_0_2pi,
                split_data["eval/agent_velocity"][scenario_count],
            ],
            axis=-1
        )
        # The first two axes are swapped so that objects lead (the target is (N, frames, 7)).
        gt_trajs_per_scenario = gt_trajs_per_scenario.swapaxes(0, 1)
        gt_trajs[scenario_count, :num_objs] = gt_trajs_per_scenario[:num_objs]

        gt_is_valid_per_scenario = split_data["eval/agent_valid_mask"][scenario_count]
        gt_is_valid_per_scenario = gt_is_valid_per_scenario.swapaxes(0, 1)
        gt_is_valid[scenario_count, :num_objs] = gt_is_valid_per_scenario[:num_objs]

    eval_config = _default_metrics_config(eval_second=eval_second, num_modes_for_eval=num_modes_for_eval)

    # DEBUG:
    # Joint ADE/FDE computed in NumPy, independent of the Waymo op. Ground truth
    # is sampled at frames 15, 20, ... which yields `num_frame_to_eval` samples
    # for each of the three supported horizons.
    pred = batch_pred_trajs[:, :, :, 0]  # (B, N, M, 16, 2)
    gt = gt_trajs[:, :, :, :2]  # (B, N, 91, 2)
    vl = gt_is_valid[:, :, 15::5]
    gt = gt[:, :, 15::5]
    gt = gt.reshape(gt.shape[0], gt.shape[1], 1, -1, 2).repeat(num_modes_for_eval, axis=2)
    vl = vl.reshape(vl.shape[0], vl.shape[1], 1, -1).repeat(num_modes_for_eval, axis=2)
    # gt shape = (B, N, M, 16, 2)
    diff = (np.linalg.norm(pred - gt, axis=-1)) * vl

    # Avg over time dim
    valid = vl.sum(-1)
    error = diff.sum(-1) / np.maximum(valid, 1)

    # argmax of the running count of valid steps returns the first index where
    # the count peaks, i.e. the last valid step.
    last_valid_ind = (vl != 0).cumsum(axis=-1).argmax(axis=-1)
    fde = np.take_along_axis(diff, last_valid_ind[..., None], axis=-1).squeeze(-1)

    # Avg over agent dim
    valid = (valid > 0).sum(1)  # Num valid agents
    error = error.sum(1) / np.maximum(valid, 1)
    fde = fde.sum(1) / np.maximum(valid, 1)
    # Now error.shape = (B, M=6)
    avg_error = error.mean(-1).mean(0)
    min_error = error.min(-1).mean(0)

    avg_fde = fde.mean(-1).mean(0)
    min_fde = fde.min(-1).mean(0)

    batch_pred_scores = tf.convert_to_tensor(batch_pred_scores, tf.float32)
    batch_pred_trajs = tf.convert_to_tensor(batch_pred_trajs, tf.float32)
    gt_trajs = tf.convert_to_tensor(gt_trajs, tf.float32)
    gt_is_valid = tf.convert_to_tensor(gt_is_valid, bool)
    pred_gt_indices = tf.convert_to_tensor(pred_gt_indices, tf.int64)
    pred_gt_indices_mask = tf.convert_to_tensor(pred_gt_indices_mask, bool)
    object_type = tf.convert_to_tensor(object_type, tf.int64)

    input_dict = dict(
        prediction_trajectory=batch_pred_trajs,
        # (batch_size, num_pred_groups, top_k, num_agents_per_group, num_pred_steps, 2)
        prediction_score=batch_pred_scores,  # (batch_size, num_pred_groups, top_k)
        ground_truth_trajectory=gt_trajs,  # (batch_size, num_total_agents, num_gt_steps, 7)
        ground_truth_is_valid=gt_is_valid,  # (batch_size, num_total_agents, num_gt_steps)
        prediction_ground_truth_indices=pred_gt_indices,  # (batch_size, num_pred_groups, num_agents_per_group)
        prediction_ground_truth_indices_mask=pred_gt_indices_mask,
        # (batch_size, num_pred_groups, num_agents_per_group)
        object_type=object_type  # (batch_size, num_total_agents)
    )

    metric_results = py_metrics_ops.motion_metrics(config=eval_config.SerializeToString(), **input_dict)

    metric_names = config_util.get_breakdown_names_from_motion_config(eval_config)

    result_dict = {}
    avg_results = {}
    # metric_results is indexed as [metric][breakdown]; the second
    # '_'-separated token of each breakdown name is used as the object type.
    for i, m in enumerate(['minADE', 'minFDE', 'MissRate', 'mAP', 'OverlapRate']):
        avg_results.update({f'{m}-VEHICLE': [0.0, 0], f'{m}-PEDESTRIAN': [0.0, 0], f'{m}-CYCLIST': [0.0, 0]})
        for j, n in enumerate(metric_names):
            cur_name = n.split('_')[1]
            avg_results[f'{m}-{cur_name}'][0] += float(metric_results[i][j])
            avg_results[f'{m}-{cur_name}'][1] += 1
            result_dict[f'{m}-{n}'] = float(metric_results[i][j])

    # Collapse each [sum, count] pair into its mean.
    for key in avg_results:
        avg_results[key] = avg_results[key][0] / avg_results[key][1]

    if verbose:
        result_dict['-------------------------------------------------------------'] = 0

    result_dict.update(avg_results)

    # Count agents per type over every scenario's eval/agent_type array.
    object_type_cnt_dict = {k: 0 for k in object_int_to_type.values()}
    for type_int_list in split_data["eval/agent_type"]:
        for type_int in np.unique(type_int_list):
            object_type_cnt_dict[object_int_to_type[type_int]] += (type_int_list == type_int).sum()
    result_dict.update(object_type_cnt_dict)

    final_avg_results = {}
    # Table rendered into result_format_str; cells left as None print as N/A.
    result_format_list = [
        [
            'Waymo', 'Count', 'mAP', 'minADE', 'minFDE', 'MissRate', 'OverlapR', 'mJADE', 'avgJADE', 'mJFDE', 'avgJFDE',
            '\n'
        ],
        ['VEH', None, None, None, None, None, None, None, None, None, None, '\n'],
        ['PED', None, None, None, None, None, None, None, None, None, None, '\n'],
        ['CYC', None, None, None, None, None, None, None, None, None, None, '\n'],
        ['Avg', None, None, None, None, None, None, None, None, None, None, '\n'],
    ]
    name_to_row = {'VEHICLE': 1, 'PEDESTRIAN': 2, 'CYCLIST': 3, 'Avg': 4}
    name_to_col = {
        'Count': 1,
        'mAP': 2,
        'minADE': 3,
        'minFDE': 4,
        'MissRate': 5,
        'OverlapRate': 6,
        'mJADE': 7,
        'avgJADE': 8,
        'mJFDE': 9,
        'avgJFDE': 10,
    }

    for cur_metric_name in ['minADE', 'minFDE', 'MissRate', 'mAP', 'OverlapRate']:
        final_avg_results[cur_metric_name] = 0
        for cur_name in ['VEHICLE', 'PEDESTRIAN', 'CYCLIST']:
            final_avg_results[cur_metric_name] += avg_results[f'{cur_metric_name}-{cur_name}']

            result_format_list[name_to_row[cur_name]][name_to_col[cur_metric_name]] = \
                '%.4f,' % avg_results[f'{cur_metric_name}-{cur_name}']

        # Unweighted mean over the three object types.
        final_avg_results[cur_metric_name] /= 3
        result_format_list[4][name_to_col[cur_metric_name]] = '%.4f,' % final_avg_results[cur_metric_name]

    # NOTE: this loop variable rebinds `object_type`, which held the tensor passed
    # to the metrics op; that tensor is not used again below.
    for object_type in ['VEHICLE', 'PEDESTRIAN', 'CYCLIST']:
        result_format_list[name_to_row[object_type]][name_to_col["Count"]] = str(object_type_cnt_dict[object_type])
    object_count_sum = (
        object_type_cnt_dict["VEHICLE"] + object_type_cnt_dict["PEDESTRIAN"] + object_type_cnt_dict["CYCLIST"]
    )
    result_format_list[name_to_row['Avg']][name_to_col["Count"]] = "{}".format(object_count_sum)
    final_avg_results["Count"] = object_count_sum

    # The joint metrics are only reported on the 'Avg' row.
    result_format_list[name_to_row['Avg']][name_to_col["mJADE"]] = '%.4f,' % min_error.mean()
    result_format_list[name_to_row['Avg']][name_to_col["avgJADE"]] = '%.4f,' % avg_error.mean()
    result_format_list[name_to_row['Avg']][name_to_col["mJFDE"]] = '%.4f,' % min_fde.mean()
    result_format_list[name_to_row['Avg']][name_to_col["avgJFDE"]] = '%.4f,' % avg_fde.mean()
    final_avg_results["mJADE"] = min_error.mean()
    final_avg_results["avgJADE"] = avg_error.mean()
    final_avg_results["mJFDE"] = min_fde.mean()
    final_avg_results["avgJFDE"] = avg_fde.mean()

    result_format_str = ' '.join(
        [x.rjust(9) if x is not None else " N/A" for items in result_format_list for x in items]
    )

    if verbose:
        result_dict['--------------------------------------------------------------'] = 0

    result_dict.update(final_avg_results)

    if verbose:
        result_dict['---------------------------------------------------------------'] = 0

    if verbose:
        result_dict[
            '-----Note that this evaluation may have marginal differences with the official Waymo evaluation server-----'
        ] = 0

    if generate_submission:
        submission_data = dict(
            prediction_trajectory_list=batch_pred_trajs,
            prediction_score_list=batch_pred_scores,
            object_id_list=object_id,
            scenario_id_list=scenario_id_list,
        )
    else:
        submission_data = dict()

    return result_dict, result_format_str, submission_data
generate_submissition( +# pred_trajs, pred_score, gt_trajs, gt_is_valid, object_type, gt_infos['scenario_id'], gt_infos['object_id'] +# ) +# metric_names = config_util.get_breakdown_names_from_motion_config(eval_config) +# +# result_dict = {} +# avg_results = {} +# for i, m in enumerate(['minADE', 'minFDE', 'MissRate', 'mAP', 'OverlapRate']): +# avg_results.update({f'{m}-VEHICLE': [0.0, 0], f'{m}-PEDESTRIAN': [0.0, 0], f'{m}-CYCLIST': [0.0, 0]}) +# for j, n in enumerate(metric_names): +# cur_name = n.split('_')[1] +# avg_results[f'{m}-{cur_name}'][0] += float(metric_results[i][j]) +# avg_results[f'{m}-{cur_name}'][1] += 1 +# result_dict[f'{m}-{n}'] = float(metric_results[i][j]) +# +# for key in avg_results: +# avg_results[key] = avg_results[key][0] / avg_results[key][1] +# +# if verbose: +# result_dict['-------------------------------------------------------------'] = 0 +# +# result_dict.update(avg_results) +# +# final_avg_results = {} +# result_format_list = [ +# ['Waymo', 'mAP', 'minADE', 'minFDE', 'MissRate', 'OverlapRate', '\n'], +# ['VEHICLE', None, None, None, None, None, '\n'], +# ['PEDESTRIAN', None, None, None, None, None, '\n'], +# ['CYCLIST', None, None, None, None, None, '\n'], +# ['Avg', None, None, None, None, None, '\n'], +# ] +# name_to_row = {'VEHICLE': 1, 'PEDESTRIAN': 2, 'CYCLIST': 3, 'Avg': 4} +# name_to_col = {'mAP': 1, 'minADE': 2, 'minFDE': 3, 'MissRate': 4, 'OverlapRate': 5} +# +# for cur_metric_name in ['minADE', 'minFDE', 'MissRate', 'mAP', 'OverlapRate']: +# final_avg_results[cur_metric_name] = 0 +# for cur_name in ['VEHICLE', 'PEDESTRIAN', 'CYCLIST']: +# final_avg_results[cur_metric_name] += avg_results[f'{cur_metric_name}-{cur_name}'] +# +# result_format_list[name_to_row[cur_name]][name_to_col[cur_metric_name] +# ] = '%.4f,' % avg_results[f'{cur_metric_name}-{cur_name}'] +# +# final_avg_results[cur_metric_name] /= 3 +# result_format_list[4][name_to_col[cur_metric_name]] = '%.4f,' % final_avg_results[cur_metric_name] +# +# 
result_format_str = ' '.join([x.rjust(12) for items in result_format_list for x in items]) +# +# if verbose: +# result_dict['--------------------------------------------------------------'] = 0 +# +# result_dict.update(final_avg_results) +# +# if verbose: +# result_dict['---------------------------------------------------------------'] = 0 +# +# result_dict.update(object_type_cnt_dict) +# +# if verbose: +# result_dict[ +# '-----Note that this evaluation may have marginal differences with the official Waymo evaluation server-----' +# ] = 0 +# +# return result_dict, result_format_str + +# def main(): +# parser = argparse.ArgumentParser(description='arg parser') +# parser.add_argument('--pred_infos', type=str, default=None, help='pickle file') +# parser.add_argument('--top_k', type=int, default=-1, help='') +# parser.add_argument('--eval_second', type=int, default=8, help='') +# parser.add_argument('--num_modes_for_eval', type=int, default=6, help='') +# parser.add_argument('--generate_proto', type=bool, default=False, help='Generate proto file for Waymo challenge') +# +# args = parser.parse_args() +# print(args) +# +# assert args.eval_second in [3, 5, 8] +# with open(args.pred_infos, 'rb') as f: +# pred_infos = pickle.load(f) +# +# print('Start to evaluate the waymo format results...') +# +# metric_results, result_format_str = waymo_evaluation( +# pred_dicts=pred_infos, +# top_k=args.top_k, +# eval_second=args.eval_second, +# num_modes_for_eval=args.num_modes_for_eval, +# generate_proto=args.generate_proto, +# ) +# +# print(metric_results) +# metric_result_str = '\n' +# for key in metric_results: +# metric_results[key] = metric_results[key] +# metric_result_str += '%s: %.4f \n' % (key, metric_results[key]) +# print(metric_result_str) +# print(result_format_str) +# +# +# if __name__ == '__main__': +# main() diff --git a/scenestreamer/eval/waymo_motion_prediction_evaluator.py b/scenestreamer/eval/waymo_motion_prediction_evaluator.py new file mode 100644 index 
def joint_trajectory_nms(
    pred_trajs,  # [NUM_MODES, T, N, >=2]
    mode_scores,  # [NUM_MODES]
    threshold=2.5,
    num_ret_modes=6,
    global_rank=None,
):
    """Select ``num_ret_modes`` joint modes with greedy endpoint-distance NMS.

    Modes are visited from highest to lowest score. A visited mode is kept and
    every mode whose final-step positions lie, on average over the agents,
    closer than ``threshold`` to it is suppressed.

    * If more than ``num_ret_modes`` modes survive, a uniformly random subset
      of the survivors is returned (this uses the global torch RNG).
    * Otherwise the survivors are topped up with the highest-scoring
      suppressed modes until ``num_ret_modes`` are selected.

    Args:
        pred_trajs: Tensor indexed as ``[mode, time, agent, xy]``. Exactly two
            agents are required (asserted below).
        mode_scores: One score per mode. It is sorted and the resulting order
            is used to index ``pred_trajs`` along its first axis, so it must
            be 1-D.
        threshold: Suppression radius on the agent-averaged endpoint distance.
        num_ret_modes: Number of modes to return.
        global_rank: Only used in the progress message.

    Returns:
        ``(kept_preds, kept_scores, None)``: ``pred_trajs`` and ``mode_scores``
        indexed by the selected modes.

    Raises:
        AssertionError: If there are not exactly two agents, or if fewer than
            ``num_ret_modes`` modes are available in total.
    """
    _, sorted_indices = torch.sort(mode_scores, descending=True)
    sorted_trajs = pred_trajs[sorted_indices]
    num_modes = sorted_trajs.shape[0]

    # Final-step xy of every agent in every mode, in score order.
    goal_points = sorted_trajs[:, -1, :, :2]  # [NUM_MODES, N, 2]
    assert goal_points.shape[1] == 2

    # Pairwise endpoint distance between modes, averaged over the agents.
    goal_distances = torch.stack(
        [
            (goal_points[:, agent][:, None] - goal_points[:, agent][None, :]).norm(dim=-1)
            for agent in range(goal_points.shape[1])
        ],
        dim=0,
    ).mean(dim=0)  # [NUM_MODES, NUM_MODES]

    # Allocate the mask directly on the working device.
    suppressed = torch.zeros(num_modes, dtype=torch.bool, device=sorted_trajs.device)
    keep = []
    # Walking the ranks in order visits exactly the "next best unsuppressed"
    # mode each time, because suppression only ever turns entries on.
    for rank in range(num_modes):
        if suppressed[rank]:
            continue
        keep.append(sorted_indices[rank])
        suppressed[rank] = True
        suppressed |= goal_distances[rank] < threshold

    print("RANK {}, Keep modes: {} out of {}.".format(global_rank, len(keep), num_modes))
    if len(keep) > num_ret_modes:
        keep = torch.stack(keep)
        # Too many distinct modes: return a random subset of the survivors.
        random_indices = torch.randperm(len(keep))
        keep = keep[random_indices[:num_ret_modes]]
    else:
        # Top up with the best suppressed modes. Membership is tracked with
        # plain ints to avoid a tensor-by-tensor scan of `keep` per candidate.
        kept_ids = {int(idx) for idx in keep}
        for idx in sorted_indices:
            if int(idx) not in kept_ids:
                keep.append(idx)
                kept_ids.add(int(idx))
            if len(keep) == num_ret_modes:
                break
        keep = torch.stack(keep)
    assert len(keep) == num_ret_modes, (len(keep), num_ret_modes, num_modes)
    kept_preds = pred_trajs[keep]
    kept_scores = mode_scores[keep]
    return kept_preds, kept_scores, None
def generate_submission(
    prediction_trajectory_list,
    prediction_score_list,
    scenario_id_list,
    object_id_list,
    num_model_parameters,
    prefix="submission",
    account_name="peng",
    unique_method_name="peng",
    output_dir=".",
):
    """Serialize joint predictions into a Waymo interaction-prediction submission.

    One ``scenario_predictions`` entry is written per distinct scenario id;
    later entries with an id that was already written are skipped and reported
    as duplicates. The serialized message is written under ``output_dir`` and
    then packed with ``tar -zcvf``.

    Args:
        prediction_trajectory_list: Per-scenario trajectories, indexed as
            ``[object, mode, 0, step, xy]``; each slice must expose
            ``.numpy()``.
        prediction_score_list: Per-scenario scores. A 2-D entry is summed over
            its first axis to get one confidence per mode; otherwise entry
            ``k`` is used directly for mode ``k``.
        scenario_id_list: Per-scenario ids (converted with ``str``).
        object_id_list: Per-scenario object-id arrays padded with ``-1`` at the
            end. Exactly two real ids per scenario are required.
        num_model_parameters: Forwarded to the submission message.
        prefix: File-name prefix.
        account_name: Forwarded to the submission message.
        unique_method_name: Forwarded to the submission message.
        output_dir: Directory for the output files; created if missing.

    Returns:
        ``(archive_path, duplicated_scenarios, done_scenarios)`` where
        ``archive_path`` is the ``.tar.gz`` path as a string and the other two
        are sets of scenario-id strings.

    Raises:
        ModuleNotFoundError: If ``waymo_open_dataset`` could not be imported.
        subprocess.CalledProcessError: If ``tar`` exits with a non-zero status.
    """
    # Imported here so the block does not depend on a new module-level import.
    import subprocess

    if motion_submission_pb2 is None:
        # The module-level import is guarded; fail with an actionable message
        # instead of an AttributeError on None.
        raise ModuleNotFoundError("waymo_open_dataset is required to generate a motion submission.")

    submission = motion_submission_pb2.MotionChallengeSubmission(
        submission_type=motion_submission_pb2.MotionChallengeSubmission.SubmissionType.INTERACTION_PREDICTION,
        account_name=account_name,
        unique_method_name=unique_method_name,
        uses_lidar_data=False,
        uses_camera_data=False,
        uses_public_model_pretraining=False,
        num_model_parameters=num_model_parameters,
    )
    MODE_NUM = 6

    done_scenarios = set()
    duplicated_scenarios = set()

    for prediction_trajectory, prediction_score, scenario_id, object_id in \
            tqdm(
                zip(prediction_trajectory_list, prediction_score_list, scenario_id_list, object_id_list),
                total=len(prediction_trajectory_list),
                desc="Generating submission"
            ):
        scenario_id = str(scenario_id)

        # Only the first occurrence of a scenario is written.
        if scenario_id in done_scenarios:
            duplicated_scenarios.add(scenario_id)
            continue
        done_scenarios.add(scenario_id)

        # Real ids come first; -1 padding must only appear at the tail.
        predict_num = int((object_id != -1).sum())
        assert (object_id[:predict_num] != -1).all()
        assert (object_id[predict_num:] == -1).all()

        scenario_prediction = submission.scenario_predictions.add()

        # NOTE: This is for joint_predictions
        joint_pred = scenario_prediction.joint_prediction
        scenario_prediction.scenario_id = scenario_id
        for k in range(MODE_NUM):
            joint_traj = joint_pred.joint_trajectories.add()

            if prediction_score.ndim == 2:
                joint_conf = sum(prediction_score[:, k])  # (2, 6) -> (6,)
            else:
                joint_conf = prediction_score[k]
            # Protobuf scalar fields are given plain Python numbers.
            joint_traj.confidence = float(joint_conf)
            assert predict_num == 2
            for i in range(predict_num):
                obj_traj = joint_traj.trajectories.add()
                obj_traj.object_id = int(object_id[i])
                trajectory = obj_traj.trajectory
                traj = prediction_trajectory[i, k, :, :]
                assert traj.shape[0] == 1
                trajectory.center_x[:] = traj[0, :, 0].numpy().tolist()
                trajectory.center_y[:] = traj[0, :, 1].numpy().tolist()

    file_name = '{}_motion_val_submission_{:%Y_%m_%d_%H_%M_%S}'.format(prefix, datetime.now())
    path = pathlib.Path(output_dir) / file_name
    path = path.resolve()
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        f.write(submission.SerializeToString())

    archive_path = f"{path}.tar.gz"
    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed through verbatim, and check=True surfaces a
    # failed tar instead of reporting success.
    subprocess.run(["tar", "-zcvf", archive_path, str(path)], check=True)
    print("Submission is saved at: {}.tar.gz".format(path))
    return archive_path, duplicated_scenarios, done_scenarios
implemented.") + + # ===== Postprocessing to extract predictions for the modeled agents ===== + scores = expanded_data_dict["decoder/output_score"] + pred_trajs = expanded_data_dict["decoder/reconstructed_position"] + + if self.config.eval_backward_model: + pred_trajs = pred_trajs[:, 16:] # For Backward evaluation + assert pred_trajs.shape[1] == 80 + else: + if pred_trajs.shape[1] == 96: + pred_trajs = pred_trajs[:, :-5] + assert pred_trajs.shape[1] == 91, pred_trajs.shape + pred_trajs = pred_trajs[:, 11:] + + # If training to predict all agents, but asking for eval on modeled agents, + # need to pick the prediction for the modeled agents only. + if self.config.TRAINING.PREDICT_ALL_AGENTS: + scores_of_interested_agents = [] + pred_trajs_of_interested_agents = [] + + if self.config.EVALUATION.PREDICT_ALL_AGENTS: + pred_ids = expanded_data_dict["decoder/agent_id"] + else: + pred_ids = expanded_data_dict["decoder/object_of_interest_id"] + for batch_index, track_indices in enumerate(pred_ids): + scores_of_interested_agents.append( + torch.stack( + [ + scores[batch_index][agent_index] # .detach().cpu().numpy() + for agent_index in track_indices if agent_index != -1 + ], + dim=0 + ) + ) + pred_trajs_of_interested_agents.append( + torch.stack( + [ + pred_trajs[batch_index][:, agent_index] # .detach().cpu().numpy() + for agent_index in track_indices if agent_index != -1 + ], + dim=1 + ) + ) + else: + assert self.config.EVALUATION.PREDICT_ALL_AGENTS is False + scores_of_interested_agents = [] + pred_trajs_of_interested_agents = [] + for batch_index in range(expanded_data_dict["decoder/object_of_interest_id"].shape[0]): + num_eval_objs = (expanded_data_dict["decoder/object_of_interest_id"][batch_index] != -1).sum() + scores_of_interested_agents.append(scores[batch_index][:num_eval_objs].detach().cpu().numpy()) + pred_trajs_of_interested_agents.append( + pred_trajs[batch_index][:, :num_eval_objs].detach().cpu().numpy() + ) + + return pred_trajs_of_interested_agents, 
scores_of_interested_agents, expanded_data_dict + + def validation_step(self, data_dict, batch_idx, model, global_rank, **kwargs): + # TODO: Pass this from config. + num_decode_steps = 16 + + num_modes_for_eval = self.config.EVALUATION.NUM_MODES + maximum_batch_size = self.config.EVALUATION.MAXIMUM_BATCH_SIZE + + if num_modes_for_eval <= maximum_batch_size: + num_repeat_calls = 1 + else: + assert num_modes_for_eval % maximum_batch_size == 0 + num_repeat_calls = num_modes_for_eval // maximum_batch_size + + NUM_MODES_WAYMO_MOTION_PREDICTION = 6 + + B = data_dict["encoder/agent_feature"].shape[0] + data_dict["batch_idx"] = torch.arange(B) + + if num_repeat_calls == 1: + expanded_data_dict = { + k: _repeat_for_modes(data_dict[k], num_modes=num_modes_for_eval) + for k in data_dict.keys() if ( + k.startswith("encoder/") or k.startswith("decoder/") or k.startswith("metadata/") + or k.startswith("eval/") or k.startswith("decoder/") or k == "batch_idx" or k == "in_evaluation" + or k == "in_backward_prediction" + + # DEBUG: + # or k.startswith("decoder/") + or k.startswith("raw") + ) + } + pred_trajs_of_interested_agents, scores_of_interested_agents, output_data_dict = self._call_model( + expanded_data_dict, model + ) + + else: + assert B == 1, B + num_modes_per_call = num_modes_for_eval // num_repeat_calls + assert num_modes_per_call * num_repeat_calls == num_modes_for_eval + expanded_data_dict = { + k: _repeat_for_modes(data_dict[k], num_modes=num_modes_per_call) + for k in data_dict.keys() if ( + k.startswith("encoder/") or k.startswith("decoder/") or k.startswith("metadata/") + or k.startswith("eval/") or k.startswith("decoder/") or k == "batch_idx" or k == "in_evaluation" + or k == "in_backward_prediction" + ) + } + + pred_trajs_of_interested_agents = [] + scores_of_interested_agents = [] + for call in range(num_repeat_calls): + traj, score, output_data_dict = self._call_model(copy.deepcopy(expanded_data_dict), model) + pred_trajs_of_interested_agents.append(traj) + 
scores_of_interested_agents.append(score) + pred_trajs_of_interested_agents = [vv for v in pred_trajs_of_interested_agents for vv in v] + scores_of_interested_agents = [vv for v in scores_of_interested_agents for vv in v] + + pred_to_scenario_id = _repeat_for_modes( + np.asarray(data_dict["scenario_id"]), num_modes=NUM_MODES_WAYMO_MOTION_PREDICTION + ) + expanded_map_center = _repeat_for_modes( + data_dict["metadata/map_center"], num_modes=NUM_MODES_WAYMO_MOTION_PREDICTION + ) + expanded_map_heading = _repeat_for_modes( + data_dict["metadata/map_heading"], num_modes=NUM_MODES_WAYMO_MOTION_PREDICTION + ) + + # Conduct non-maximum suppression (NMS) to reduce the number of modes + if num_modes_for_eval > NUM_MODES_WAYMO_MOTION_PREDICTION: + # from scenestreamer.eval.nms import batch_nms + # pred_trajs_of_interested_agents, scores_of_interested_agents = batch_nms( + # pred_trajs_of_interested_agents, + # scores_of_interested_agents, + # pred_to_scenario_id=np.repeat(data_dict["scenario_id"], num_modes_for_eval, axis=0), + # dist_thresh=2.5, # Follow MTR + # num_ret_modes=NUM_MODES_WAYMO_MOTION_PREDICTION, + # num_original_modes=num_modes_for_eval, + # ) + + + # Assume: + # pred_trajs = [100 modes, 80 steps, 2 agents, 2D] + # scores = [100, 80, 2] + pred_trajs_of_interested_agents = torch.stack(pred_trajs_of_interested_agents, 0) + mode_scores = torch.stack(scores_of_interested_agents, 0).sum(-1) # [NUM_MODES, 80] + + pred_trajs_of_interested_agents, scores_of_interested_agents, _ = joint_trajectory_nms( + pred_trajs_of_interested_agents, + mode_scores, + global_rank=global_rank + ) + pred_trajs_of_interested_agents = list(pred_trajs_of_interested_agents) + scores_of_interested_agents = list(scores_of_interested_agents) + + # print(f"Kept {len(kept_preds)} predictions after NMS.") + + # sort_scores = torch.stack(scores_of_interested_agents, 0).sum(-1).sort(descending=True) + # selected_indices = sort_scores.indices[:NUM_MODES_WAYMO_MOTION_PREDICTION] + # selected_scores 
= sort_scores.values[:NUM_MODES_WAYMO_MOTION_PREDICTION] + # + # pred_trajs_of_interested_agents = [ + # pred_trajs_of_interested_agents[i] for i in selected_indices + # ] + # scores_of_interested_agents = [ + # scores_of_interested_agents[i] for i in selected_indices + # ] + # print(f"Selected scores: {selected_scores}") + + # ===== Cache the prediction results ===== + prediction_dict = { + "pred_trajs": pred_trajs_of_interested_agents, + "pred_scores": scores_of_interested_agents, + "pred_to_scenario_id": pred_to_scenario_id, + "expanded_map_center": expanded_map_center, + "expanded_map_heading": expanded_map_heading, + } + for k, v in data_dict.items(): + if k.startswith("decoder/") or k.startswith("eval/") or k.startswith("metadata/") or k in ["scenario_id"]: + prediction_dict[k] = v + + new_prediction_dict = {} + for k, v in prediction_dict.items(): + if isinstance(v, torch.Tensor): + new_prediction_dict[k] = v.detach().cpu().numpy() + elif isinstance(v, list): + new_list = [] + for vv in v: + if isinstance(vv, torch.Tensor): + new_list.append(vv.detach().cpu().numpy()) + else: + new_list.append(vv) + new_prediction_dict[k] = new_list + else: + new_prediction_dict[k] = v + # prediction_dict = copy.deepcopy(new_prediction_dict) # Avoid memory issue + + # Transform back to global coordinate + new_prediction_dict = transform_to_global_coordinate(new_prediction_dict) + self.validation_outputs.append(new_prediction_dict) + + # print(debug_tools.using(f"val step start {batch_idx} DONE")) + + # DEBUG: + # waymo_evaluation_optimized( + # [new_prediction_dict], + # generate_submission=False, + # ) + + def on_validation_epoch_end( + self, trainer, logger, global_rank, log_dict_func, log_func, print_func, exp_name, **kwargs + ): + """ + This function gathers intermediate evaluation result and pass them to the Waymo + evaluation pipeline together and log the final results. 
+ """ + st = time.time() + + # print(debug_tools.using(f"val epoch end start")) + + # https://lightning.ai/docs/pytorch/latest/accelerators/accelerator_prepare.html?highlight=hardware + # torch.cuda.empty_cache() + # PZH NOTE: Hack to implement our own all_gather across ranks. + trainer.strategy.barrier() + + # Collect the intermediate evaluation results from each call to on_validation_step in this particular rank. + self.validation_outputs = [ + {k: (v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v) + for k, v in final_pred_dicts.items()} for final_pred_dicts in self.validation_outputs + ] + + # Dump all results in this rank to a local file so that later the rank0 process can read them. + tmpdir = self.config.ROOT_DIR / self.config.TMP_DIR / "validation_tmpdir_{}".format(exp_name) + print(f"Rank {global_rank} saving validation results to {tmpdir}.") + + os.makedirs(tmpdir, exist_ok=True) + with open(os.path.join(tmpdir, 'result_part_{}.pkl'.format(global_rank)), 'wb') as f: + pickle.dump(self.validation_outputs, f) + self.validation_outputs.clear() + + # print(debug_tools.using(f"val epoch saved file.")) + + # If this is the main process (rank0), read all results in local filesystem and call evaluation pipeline. + torch.cuda.empty_cache() + trainer.strategy.barrier() + if trainer.is_global_zero: + print_func(f"===== Start evaluation: {time.time() - st:.3f} =====") + + # Gather results from different ranks + validation_list = [] + for i in range(trainer.world_size): + file = os.path.join(tmpdir, 'result_part_{}.pkl'.format(i)) + success = False + for sleep in range(10): + if not os.path.isfile(file): + time.sleep(1) + print(f"Can't find file: {file}. Sleep {sleep}/{10} seconds.") + else: + success = True + break + if not success: + print(f"[WARNING] Can't find file: {file}. 
Skip this rank.") + continue + with open(file, "rb") as f: + val_outputs = pickle.load(f) + validation_list.extend(val_outputs) + if self.config.EVALUATION.DELETE_EVAL_RESULT: + shutil.rmtree(tmpdir) + + if not validation_list: + print_func("No evaluation results found. Skip evaluation.") + return + + # print(debug_tools.using(f"going to eval")) + + # Call evaluation pipeline + torch.cuda.empty_cache() + result_dict, result_str, submission_dict = waymo_evaluation_optimized( + validation_list, + + # TODO: This flag + generate_submission=self.config.SUBMISSION.GENERATE_SUBMISSION, + predict_all_agents=self.config.EVALUATION.PREDICT_ALL_AGENTS, + ) + torch.cuda.empty_cache() + validation_list.clear() + + # Log result + result_dict = {f"eval/{k}": float(v) for k, v in result_dict.items()} + log_dict_func(result_dict, rank_zero_only=True) + for k in ['eval/minADE', 'eval/minFDE', 'eval/MissRate', 'eval/mAP', "eval/mJADE", "eval/avgJADE", + "eval/mJFDE", "eval/avgJFDE"]: + if k not in result_dict: + continue + log_func(name=k.split("/")[1], value=result_dict[k], rank_zero_only=True) + print_func(result_str) + print_func(f"===== Finish evaluation: {time.time() - st:.3f} =====") + + print_func(f"Rank {global_rank} finished evaluation!") + torch.cuda.empty_cache() + trainer.strategy.barrier() + + # TODO This flag + if trainer.is_global_zero and self.config.SUBMISSION.GENERATE_SUBMISSION: + account_name = self.config.SUBMISSION.ACCOUNT + unique_method_name = self.config.SUBMISSION.METHOD_NAME + num_model_parameters = self.config.SUBMISSION.num_model_parameters + + output_dir = logger.log_dir + submission_prefix = logger.name + + path, duplicated_scenarios, done_scenarios = generate_submission( + prefix=submission_prefix, + account_name=account_name, + unique_method_name=unique_method_name, + output_dir=output_dir, + num_model_parameters=num_model_parameters, + **submission_dict + ) + print_func( + "Submission created at: {}. Finished {} scenarios. 
Duplicated scenarios: {}.".format( + path, len(done_scenarios), duplicated_scenarios + ) + ) diff --git a/scenestreamer/eval/waymo_sim_agent_evaluator.py b/scenestreamer/eval/waymo_sim_agent_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..450c381c9de95946a09c26522c087797075667af --- /dev/null +++ b/scenestreamer/eval/waymo_sim_agent_evaluator.py @@ -0,0 +1,590 @@ +""" +Script to generate submission files for Waymo SimAgent Challenge. +Please check out the end of this file where we provide a script to merge submission files. +""" +import copy +import os +import pathlib +import uuid + +import numpy as np +import torch + +from scenestreamer.dataset.preprocessor import centralize_to_map_center +from scenestreamer.eval.waymo_motion_prediction_evaluator import _repeat_for_modes +from scenestreamer.eval.wosac_eval import wosac_evaluation, load_metrics_config_from_file_name +from scenestreamer.tokenization import get_tokenizer +from scenestreamer.utils import wrap_to_pi, rotate +from scenestreamer.infer.motion import generate_motion + + +def transform_to_global_coordinate(data_dict): + map_center = data_dict["metadata/map_center"].reshape(-1, 1, 1, 3) + map_heading = data_dict["metadata/map_heading"].reshape(-1, 1, 1) + B, T, N, _ = data_dict["decoder/agent_position"].shape + map_heading = map_heading.repeat(T, axis=1).repeat(N, axis=2) + assert map_heading.shape == (B, T, N) + data_dict["decoder/agent_position"] = rotate( + x=data_dict["decoder/agent_position"][..., 0], + y=data_dict["decoder/agent_position"][..., 1], + angle=map_heading, + z=data_dict["decoder/agent_position"][..., 2] + ) + assert data_dict["decoder/agent_position"].ndim == 4 + data_dict["decoder/agent_position"] += map_center + + data_dict["decoder/agent_heading"] = wrap_to_pi(data_dict["decoder/agent_heading"] + map_heading) + + data_dict["decoder/agent_velocity"] = rotate( + x=data_dict["decoder/agent_velocity"][..., 0], + y=data_dict["decoder/agent_velocity"][..., 
1], + angle=map_heading, + ) + + data_dict["decoder/agent_position"][~data_dict["decoder/agent_valid_mask"]] = 0 + data_dict["decoder/agent_heading"][~data_dict["decoder/agent_valid_mask"]] = 0 + data_dict["decoder/agent_velocity"][~data_dict["decoder/agent_valid_mask"]] = 0 + + data_dict["pred_trajs"] = [ + centralize_to_map_center( + traj, map_center=-data_dict["expanded_map_center"][b], map_heading=-data_dict["expanded_map_heading"][b] + ) for b, traj in enumerate(data_dict["pred_trajs"]) + ] + + return data_dict + + +scenario_metrics_keys = [ + # 'scenario_id', + 'metametric', + 'average_displacement_error', + 'min_average_displacement_error', + 'linear_speed_likelihood', + 'linear_acceleration_likelihood', + 'angular_speed_likelihood', + 'angular_acceleration_likelihood', + 'distance_to_nearest_object_likelihood', + 'collision_indication_likelihood', + 'time_to_collision_likelihood', + 'distance_to_road_edge_likelihood', + 'offroad_indication_likelihood' +] + +aggregate_metrics_keys = [ + 'realism_meta_metric', 'kinematic_metrics', 'interactive_metrics', 'map_based_metrics', 'min_ade' +] + + +def scenario_metrics_to_dict(scenario_metrics): + return {k: getattr(scenario_metrics, k) for k in scenario_metrics_keys} + + +def aggregate_metrics_to_dict(aggregate_metrics): + return {k: getattr(aggregate_metrics, k) for k in aggregate_metrics_keys} + +def joint_trajectory_nms( + pred_trajs, # [NUM_MODES, 80, N, 2] + pred_headings, # [NUM_MODES, 80, N] + mode_scores, # [NUM_MODES, 80, N] + ooisdc, + threshold=2.5, + num_ret_modes=32, + global_rank=None, +): + sorted_scores, sorted_indices = torch.sort(mode_scores, descending=True) + sorted_trajs = pred_trajs[sorted_indices] + + num_modes = sorted_trajs.shape[0] + suppressed = torch.zeros(num_modes, dtype=torch.bool).to(sorted_trajs.device) + keep = [] + + # Precompute pairwise similarities (e.g., goal distance or ADE) + goal_points = sorted_trajs[:, -1, :, :2] # [NUM_MODES, N, 2] + # assert goal_points.shape[1] == 2 + 
+ assert ooisdc.shape[0] == 1, "ooisdc should be 1" + ooisdc = ooisdc[0].tolist() + + goal_distances = [] + for i in ooisdc: + goal_distances.append((goal_points[:, i][:, None] - goal_points[:, i][None, :]).norm(dim=-1)) + goal_distances = torch.stack(goal_distances, dim=0).mean(dim=0) # [NUM_MODES, NUM_MODES] + + while (not suppressed.all()): + # Find the next highest-score unsuppressed mode + active_modes = (~suppressed).nonzero(as_tuple=True)[0] + if len(active_modes) == 0: + break # Edge case: no modes left (unlikely with num_ret_modes=6) + best_idx = active_modes[0] + keep.append(sorted_indices[best_idx]) + suppressed[best_idx] = True + + # Suppress overlapping modes + overlapping = (goal_distances[best_idx] < threshold) & (~suppressed) + suppressed[overlapping] = True + + print("RANK {}, Keep modes: {} out of {}.".format(global_rank, len(keep), num_modes)) + if len(keep) > num_ret_modes: + keep = torch.stack(keep) + # Just randomized keep + random_indices = torch.randperm(len(keep)) + keep = keep[random_indices[:num_ret_modes]] + else: + randomized = torch.randperm(len(sorted_indices)).tolist() + for i in randomized: + if len(keep) == num_ret_modes: + break + v = sorted_indices[i] + if v not in keep: + keep.append(v) + keep = torch.stack(keep) + assert len(keep) == num_ret_modes + kept_preds = pred_trajs[keep[:num_ret_modes]] + kept_headings = pred_headings[keep[:num_ret_modes]] + kept_scores = mode_scores[keep[:num_ret_modes]] + return kept_preds, kept_headings, kept_scores, None + + + +class WaymoSimAgentEvaluator: + def __init__(self, config): + self.config = config + + self.metrics = [] + self.scenario_rollouts_list = [] + self.scenario_rollouts_list_91steps = [] + self.scenario_pb_list = [] + self.shard_count = 0 + self.scenario_count = 0 + + self.shard_count_91steps = 0 + self.scenario_count_91steps = 0 + + self.num_scenarios_per_shard = 10 + + self.scenario_generation_challenge = (self.config.EVALUATION.NAME == "sgen") + + if 
self.config.EVALUATION.NAME in ["wosac2024", "sgen"]: + self.use_2024 = True + elif self.config.EVALUATION.NAME == "wosac2023": + self.use_2024 = False + else: + raise ValueError() + + print("[Sim Agent] SAMPLING CONFIG IS: ", self.config.SAMPLING) + + def _call_model(self, data_dict, model): + """We might want to create mini batches to call model in case the of OOM...""" + + # ===== Autoregressive Decoding ===== + if model.config.MODEL.NAME == "scenestreamer": + if not hasattr(self, "scenestreamer_generator"): + from scenestreamer.infer.scenestreamer_generator import SceneStreamerGenerator + self.scenestreamer_generator = SceneStreamerGenerator( + model=model, + device=data_dict["decoder/agent_position"].device, + ) + with torch.no_grad(): + self.scenestreamer_generator.reset(new_data_dict=data_dict) + + if self.scenario_generation_challenge: + expanded_data_dict = self.scenestreamer_generator.generate_scenestreamer_initial_state_and_motion() + + else: + expanded_data_dict = self.scenestreamer_generator.generate_scenestreamer_motion() + + elif model.config.MODEL.NAME == "gpt": + from scenestreamer.infer.motion import generate_motion + with torch.no_grad(): + expanded_data_dict = generate_motion( + model=model, + data_dict=data_dict, + autoregressive_start_step=2, + # num_decode_steps=num_decode_steps, + ) + + # ===== Postprocessing to extract predictions for the modeled agents ===== + scores = expanded_data_dict["decoder/output_score"] + pred_trajs = expanded_data_dict["decoder/reconstructed_position"] + pred_heading = expanded_data_dict["decoder/reconstructed_heading"] + + if "decoder/reconstructed_shape" in expanded_data_dict: + pred_shapes = expanded_data_dict["decoder/reconstructed_shape"] + else: + pred_shapes = None + + # If training to predict all agents, but asking for eval on modeled agents, + # need to pick the prediction for the modeled agents only. 
+ assert self.config.TRAINING.PREDICT_ALL_AGENTS + assert self.config.EVALUATION.PREDICT_ALL_AGENTS + + return pred_trajs, pred_heading, scores, expanded_data_dict, pred_shapes + + def validation_step(self, data_dict, batch_idx, model, log_dict_func, global_rank, logger, **kwargs): + + save_91steps_together = self.scenario_generation_challenge + save_80steps_together = not self.scenario_generation_challenge + + disable_eval = True + + num_modes_for_eval = self.config.EVALUATION.NUM_MODES + maximum_batch_size = self.config.EVALUATION.MAXIMUM_BATCH_SIZE + + if num_modes_for_eval <= maximum_batch_size: + num_repeat_calls = 1 + else: + assert num_modes_for_eval % maximum_batch_size == 0 + num_repeat_calls = num_modes_for_eval // maximum_batch_size + + NUM_MODES_WAYMO_SIM_AGENTS = 32 + + B = data_dict["encoder/agent_feature"].shape[0] + data_dict["batch_idx"] = torch.arange(B) + + # DEBUG: + # print("RAW SCENARIO DESCPTION: (BEFORE) ", data_dict["raw_scenario_description"][0]['id']) + + if num_repeat_calls == 1: + expanded_data_dict = { + k: _repeat_for_modes(data_dict[k], num_modes=num_modes_for_eval) + for k in data_dict.keys() if ( + k.startswith("encoder/") or k.startswith("decoder/") or k.startswith("metadata/") + or k.startswith("decoder/") or k in ["batch_idx", "in_evaluation", "scenario_id"] + ) + } + pred_trajs_of_interested_agents, pred_heading_of_interested_agents, scores_of_interested_agents, output_data_dict, pred_shapes = self._call_model( + expanded_data_dict, model + ) + + else: + assert B == 1 + num_modes_per_call = num_modes_for_eval // num_repeat_calls + assert num_modes_per_call * num_repeat_calls == num_modes_for_eval + expanded_data_dict = { + k: _repeat_for_modes(data_dict[k], num_modes=num_modes_per_call) + for k in data_dict.keys() if ( + k.startswith("encoder/") or k.startswith("decoder/") or k.startswith("metadata/") + or k.startswith("decoder/") + or k in ["batch_idx", "in_evaluation", "scenario_id", "in_backward_prediction"] + ) + } + + 
pred_trajs_of_interested_agents = [] + scores_of_interested_agents = [] + pred_heading_of_interested_agents = [] + pred_shapes = [] + for call in range(num_repeat_calls): + traj, head, score, output_data_dict, pred_s = self._call_model(copy.deepcopy(expanded_data_dict), model) + pred_trajs_of_interested_agents.append(traj) + scores_of_interested_agents.append(score) + pred_heading_of_interested_agents.append(head) + pred_shapes.append(pred_s) + pred_trajs_of_interested_agents = [vv for v in pred_trajs_of_interested_agents for vv in v] + scores_of_interested_agents = [vv for v in scores_of_interested_agents for vv in v] + pred_heading_of_interested_agents = [vv for v in pred_heading_of_interested_agents for vv in v] + if pred_shapes[0] is not None: + pred_shapes = [vv for v in pred_shapes for vv in v] + else: + pred_shapes = None + + # ===== Postprocessing to extract predictions for the modeled agents ===== + if num_modes_for_eval > NUM_MODES_WAYMO_SIM_AGENTS: + if self.scenario_generation_challenge: + pred_trajs_of_interested_agents = pred_trajs_of_interested_agents[:NUM_MODES_WAYMO_SIM_AGENTS] + pred_heading_of_interested_agents = pred_heading_of_interested_agents[:NUM_MODES_WAYMO_SIM_AGENTS] + scores_of_interested_agents = scores_of_interested_agents[:NUM_MODES_WAYMO_SIM_AGENTS] + pred_shapes = pred_shapes[:NUM_MODES_WAYMO_SIM_AGENTS] + + else: + pred_trajs_of_interested_agents = torch.stack(pred_trajs_of_interested_agents, 0) + pred_heading_of_interested_agents = torch.stack(pred_heading_of_interested_agents, 0) + scores_of_interested_agents = torch.stack(scores_of_interested_agents, 0).sum(-1) + ooisdc = data_dict["decoder/labeled_agent_id"].clone() + pred_trajs_of_interested_agents, pred_heading_of_interested_agents, scores_of_interested_agents, _ = \ + joint_trajectory_nms( + pred_trajs_of_interested_agents, + pred_heading_of_interested_agents, + scores_of_interested_agents, + num_ret_modes=NUM_MODES_WAYMO_SIM_AGENTS, + ooisdc=ooisdc, + 
global_rank=global_rank, + ) + pred_trajs_of_interested_agents = list(pred_trajs_of_interested_agents) + pred_heading_of_interested_agents = list(pred_heading_of_interested_agents) + scores_of_interested_agents = list(scores_of_interested_agents) + + pred_to_scenario_id = _repeat_for_modes(data_dict["scenario_id"], num_modes=NUM_MODES_WAYMO_SIM_AGENTS) + expanded_map_center = _repeat_for_modes(data_dict["metadata/map_center"], num_modes=NUM_MODES_WAYMO_SIM_AGENTS) + expanded_map_heading = _repeat_for_modes( + data_dict["metadata/map_heading"], num_modes=NUM_MODES_WAYMO_SIM_AGENTS + ) + + # ===== Cache the prediction results ===== + prediction_dict = { + "pred_trajs": pred_trajs_of_interested_agents, + "pred_headings": pred_heading_of_interested_agents, + "pred_scores": scores_of_interested_agents, + + "pred_shape": pred_shapes, + + "pred_to_scenario_id": pred_to_scenario_id, + "expanded_map_center": expanded_map_center, + "expanded_map_heading": expanded_map_heading, + } + for k, v in data_dict.items(): + if k.startswith("decoder/") or k.startswith("decoder/") or k.startswith("metadata/") or k in [ + "raw_scenario_description", "scenario_id" + ]: + prediction_dict[k] = v + + new_prediction_dict = {} + for k, v in prediction_dict.items(): + if isinstance(v, torch.Tensor): + new_prediction_dict[k] = v.detach().cpu().numpy() + elif isinstance(v, list): + new_prediction_dict[k] = [vv.detach().cpu().numpy() if isinstance(vv, torch.Tensor) else v for vv in v] + else: + new_prediction_dict[k] = v + # prediction_dict = copy.deepcopy(new_prediction_dict) # Avoid memory issue + + # Transform back to global coordinate + new_prediction_dict = transform_to_global_coordinate(new_prediction_dict) + # self.validation_outputs.append(new_prediction_dict) + scenario_metrics, aggregate_metrics, scenario_rollouts_list_80steps, scenario_rollouts_list_91steps, scenario_pb_list = wosac_evaluation( + [new_prediction_dict], disable_eval=disable_eval, use_2024=self.use_2024, + 
save_91steps_together=save_91steps_together, + save_80steps_together=save_80steps_together, + ) + + # import matplotlib.pyplot as plt + # plt.figure() + # plt.gca().set_aspect('equal', adjustable='box') + # from scenestreamer.gradio_ui.plot import _plot_map + # from scenestreamer.utils import utils + # np_d = utils.torch_to_numpy(data_dict) + # np_d = {k: v[0] for k, v in np_d.items()} + # _plot_map(np_d, ax=plt.gca()) + # AID = 2 + # for mode in pred_trajs_of_interested_agents: + # mode = utils.torch_to_numpy(mode) + # plt.plot(mode[:, AID, 0], mode[:, AID, 1]) + # plt.show() + + # TODO: Some assertions here to avoid WOSAC error ...... + # https://github.com/waymo-research/waymo-open-dataset/issues/807 + # Scenario 891805f154b4f0dd: Sim agents {1178} are missing from the simulation. + # Scenario de8c427e65487b93: Sim agents {432, 569, 554} are missing from the simulation. + # Scenario 386f0b2faebe74af: Sim agents {3361, 3332, 3399, 3371, 3375, 3311, 3345, 3346, 3378, 3319, 3352, 3388} are missing from the simulation. + # Scenario cd861218ceb2dc1e: Sim agents {2145, 4805, 2150, 2123, 2162, 2131, 4730, 2139} are missing from the simulation. + # Scenario 46a12cf2da1fdda8: Sim agents {1540, 1544, 4507, 4514, 4545, 1608, 1616, 1623, 1626, 1627, 1500, 1630, 1631, 1632, 1634, 4450, 1636, 4599, 1638, 1639, 4455, 4583, 4585, 4586, 1515, 4461, 1649, 4469, 4597, 4598, 4477} are missing from the simulation. 
+ watching = { + "891805f154b4f0dd": [1178], + "de8c427e65487b93": [432, 569, 554], + "386f0b2faebe74af": [3361, 3332, 3399, 3371, 3375, 3311, 3345, 3346, 3378, 3319, 3352, 3388], + "cd861218ceb2dc1e": [2145, 4805, 2150, 2123, 2162, 2131, 4730, 2139], + "46a12cf2da1fdda8": [ + 1540, 1544, 4507, 4514, 4545, 1608, 1616, 1623, 1626, 1627, 1500, 1630, 1631, 1632, 1634, 4450, 1636, + 4599, 1638, 1639, 4455, 4583, 4585, 4586, 1515, 4461, 1649, 4469, 4597, 4598, 4477 + ], + } + for r in scenario_rollouts_list_80steps: + sid = r.scenario_id + if sid in watching: + obj_ids = {j.object_id for j in r.joint_scenes[10].simulated_trajectories} + for oid in watching[sid]: + assert oid in obj_ids + # # TODO: Some assertions here to avoid WOSAC error ...... + + if len(scenario_rollouts_list_80steps) > 0: + assert data_dict["raw_scenario_description"][0]['id'] == scenario_rollouts_list_80steps[0].scenario_id + if len(scenario_rollouts_list_91steps) > 0: + assert data_dict["raw_scenario_description"][0]['id'] == scenario_rollouts_list_91steps[0].scenario_id + + if not disable_eval: + scenario_id = list(scenario_metrics.keys()) + + scenario_metrics = {k: scenario_metrics_to_dict(scenario_metrics[k]) for k in scenario_metrics} + aggregate_metrics = {k: aggregate_metrics_to_dict(aggregate_metrics[k]) for k in aggregate_metrics} + + stat = {} + for k in scenario_metrics_keys: + stat[f"scenario_metrics/{k}"] = np.mean([d[k] for d in scenario_metrics.values()]) + for k in aggregate_metrics_keys: + stat[f"aggregate_metrics/{k}"] = np.mean([d[k] for d in aggregate_metrics.values()]) + + log_dict_func( + stat, + batch_size=data_dict["encoder/agent_feature"].shape[0], + on_epoch=True, + prog_bar=True, + ) + + self.metrics.append(stat) + + print( + "\n=============== RANK {} FINISHED {} SCENARIOS =============".format(global_rank, len(self.metrics)) + ) + print("Latest scenario ID: ", scenario_id) + for k in self.metrics[0].keys(): + print(f"{k}: {np.mean([m[k] for m in self.metrics]):.4f}") + 
print("===========================================================".format(len(self.metrics))) + + if save_80steps_together: + self.scenario_rollouts_list.extend(scenario_rollouts_list_80steps) + self.scenario_pb_list.extend(scenario_pb_list) + if len(self.scenario_rollouts_list) >= self.num_scenarios_per_shard: + output_dir = pathlib.Path(logger.log_dir) / "80steps" + self.generate_submission_shard(output_dir, global_rank) + + if save_91steps_together: + self.scenario_rollouts_list_91steps.extend(scenario_rollouts_list_91steps) + if len(self.scenario_rollouts_list_91steps) >= self.num_scenarios_per_shard: + output_dir = pathlib.Path(logger.log_dir) / "91steps" + self.generate_submission_shard_91steps(output_dir, global_rank) + + def on_validation_epoch_end(self, *args, global_rank, logger, trainer, **kwargs): + if self.metrics: + print("======== FINAL RESULT RANK {} WITH {} SCENARIOS ==========".format(global_rank, len(self.metrics))) + for k in self.metrics[0].keys(): + print(f"{k}: {np.mean([m[k] for m in self.metrics]):.4f}") + print("===========================================================".format(len(self.metrics))) + + output_dir = pathlib.Path(logger.log_dir) / "80steps" + print( + f"RANK {global_rank} Storing the final submission files with {len(self.scenario_rollouts_list)} rollouts..." + ) + import time + sleep = np.random.randint(1, 5) + print(f"RANK {global_rank} sleep {sleep} seconds.") + time.sleep(sleep) + + if not self.scenario_generation_challenge: + self.generate_submission_shard(output_dir, global_rank) + + if self.scenario_generation_challenge: + output_dir = pathlib.Path(logger.log_dir) / "91steps" + self.generate_submission_shard_91steps(output_dir, global_rank) + + print(f"RANK {global_rank} finished. 
Entering barrier...") + # trainer.strategy.barrier() + print(f"RANK {global_rank} left barrier...") + # if global_rank == 0: + print("RANK {} Generated {} shards total.".format(global_rank, self.shard_count)) + print("RANK {} Generated {} scenarios total.".format(global_rank, self.scenario_count)) + print("RANK {} ========== Please manually merge the submission files!!! ==========".format(global_rank)) + output_dir = pathlib.Path(output_dir).resolve() + print("===============================================================================================\n") + print("RANK {} Shard submission is saved at: {}".format(global_rank, output_dir)) + print("\n===============================================================================================") + print("RANK {} Exit.".format(global_rank)) + + def generate_submission_shard(self, output_dir, this_rank): + from waymo_open_dataset.protos import sim_agents_submission_pb2 + account_name = self.config.SUBMISSION.ACCOUNT + unique_method_name = self.config.SUBMISSION.METHOD_NAME + num_model_parameters = self.config.SUBMISSION.num_model_parameters + shard_submission = sim_agents_submission_pb2.SimAgentsChallengeSubmission( + scenario_rollouts=self.scenario_rollouts_list, + submission_type=sim_agents_submission_pb2.SimAgentsChallengeSubmission.SIM_AGENTS_SUBMISSION, + account_name=account_name, + unique_method_name=unique_method_name, + + authors=['scenestreamer_authors'], + + # New fields, need changed. 
+ uses_lidar_data=False, + uses_camera_data=False, + uses_public_model_pretraining=False, + num_model_parameters=num_model_parameters, + acknowledge_complies_with_closed_loop_requirement=True + ) + + # output_filename = f'submission.binproto-{global_rank:05d}-of-{total_ranks:05d}' + output_filename = f'submission.binproto-tmp{uuid.uuid4()}' + + scenario_id_list = [s.scenario_id for s in self.scenario_rollouts_list] + print("Scenario ID to be saved in shard: ", scenario_id_list, output_filename) + + output_dir = pathlib.Path(output_dir).absolute() + + output_dir.mkdir(parents=True, exist_ok=True) + + file_path = pathlib.Path(output_dir) / output_filename + file_path = file_path.resolve() + with open(file_path, 'wb') as f: + f.write(shard_submission.SerializeToString()) + + if self.config.SUBMISSION.SAVE_EVAL_DATA and (not self.config.SUBMISSION.GENERATE_SUBMISSION): + for s in self.scenario_pb_list: + print("Scenario ID to be saved together apart from shard: ", s.scenario_id) + file_path = pathlib.Path(output_dir) / "scenario_pb" + file_path.mkdir(parents=True, exist_ok=True) + file_path = file_path / f"{s.scenario_id}.binproto" + file_path = file_path.resolve() + with open(file_path, 'wb') as f: + f.write(s.SerializeToString()) + + print("=====================================================================================================\n") + print("RANK {} Shard submission is saved at: {}".format(this_rank, file_path)) + print("To generate final submission, please manually run:") + print("\npython -m scenestreamer.merge_shards --output_dir={}".format(output_dir)) + print("\n\nTo see evaluation results, please manually run: (please make sure SUBMISSION.SAVE_EVAL_DATA=True)") + print("\npython -m scenestreamer.wosac_eval_async --output_dir={}".format(output_dir)) + print("\npython -m scenestreamer.wosac_eval --output_dir={}".format(output_dir)) + print("\n=====================================================================================================") + # 
self.output_filenames.append(output_filename) + self.scenario_rollouts_list = [] + self.scenario_pb_list = [] + self.shard_count += 1 + self.scenario_count += len(scenario_id_list) + def generate_submission_shard_91steps(self, output_dir, this_rank): + from waymo_open_dataset.protos import sim_agents_submission_pb2 + account_name = self.config.SUBMISSION.ACCOUNT + unique_method_name = self.config.SUBMISSION.METHOD_NAME + num_model_parameters = self.config.SUBMISSION.num_model_parameters + shard_submission = sim_agents_submission_pb2.SimAgentsChallengeSubmission( + scenario_rollouts=self.scenario_rollouts_list_91steps, + submission_type=sim_agents_submission_pb2.SimAgentsChallengeSubmission.SIM_AGENTS_SUBMISSION, + account_name=account_name, + unique_method_name=unique_method_name, + + # New fields, need changed. + uses_lidar_data=False, + uses_camera_data=False, + uses_public_model_pretraining=False, + num_model_parameters=num_model_parameters, + acknowledge_complies_with_closed_loop_requirement=True + ) + + # output_filename = f'submission.binproto-{global_rank:05d}-of-{total_ranks:05d}' + output_filename = f'submission.binproto-tmp{uuid.uuid4()}' + + scenario_id_list = [s.scenario_id for s in self.scenario_rollouts_list_91steps] + print("Scenario ID to be saved in shard: ", scenario_id_list, output_filename) + + output_dir = pathlib.Path(output_dir).absolute() + + output_dir.mkdir(parents=True, exist_ok=True) + + file_path = pathlib.Path(output_dir) / output_filename + file_path = file_path.resolve() + with open(file_path, 'wb') as f: + f.write(shard_submission.SerializeToString()) + + if self.config.SUBMISSION.SAVE_EVAL_DATA and (not self.config.SUBMISSION.GENERATE_SUBMISSION): + raise ValueError + + print("=====================================================================================================\n") + print("RANK {} Shard submission is saved at: {}".format(this_rank, file_path)) + print("To generate final submission, please manually run:") + 
print("\npython -m scenestreamer.merge_shards --output_dir={}".format(output_dir)) + print("\n\nTo see evaluation results, please manually run: (please make sure SUBMISSION.SAVE_EVAL_DATA=True)") + print("\npython -m scenestreamer.wosac_eval_async --output_dir={}".format(output_dir)) + print("\npython -m scenestreamer.wosac_eval --output_dir={}".format(output_dir)) + print("\n=====================================================================================================") + # self.output_filenames.append(output_filename) + self.scenario_rollouts_list_91steps = [] + assert not self.scenario_pb_list + self.scenario_pb_list = [] + self.shard_count_91steps += 1 + self.scenario_count_91steps += len(scenario_id_list) diff --git a/scenestreamer/eval/waymo_submission.py b/scenestreamer/eval/waymo_submission.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenestreamer/eval/wosac_eval.py b/scenestreamer/eval/wosac_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..3f198066956aa6d9eaa3689c949a55a84d15a671 --- /dev/null +++ b/scenestreamer/eval/wosac_eval.py @@ -0,0 +1,475 @@ +""" +This file provides some utility functions for locally validating on the Waymo Open Sim Agents Challenge metrics. 
+ +Installation: + +conda install python=3.10 +pip install waymo-open-dataset-tf-2-12-0==1.6.4 + +https://github.com/waymo-research/waymo-open-dataset.git +""" + +# os.chdir("waymo-open-dataset/src") +# Load Scenario Description for passthrough +import pathlib +from collections import defaultdict +from scenestreamer.utils import utils +import numpy as np +import tensorflow as tf +from google.protobuf import text_format + +try: + from waymo_open_dataset.protos import scenario_pb2 + from waymo_open_dataset.protos import sim_agents_metrics_pb2 + from waymo_open_dataset.protos import sim_agents_submission_pb2 + from waymo_open_dataset.utils.sim_agents import submission_specs + from waymo_open_dataset.wdl_limited.sim_agents_metrics import metrics +except ModuleNotFoundError: + scenario_pb2 = None + sim_agents_metrics_pb2 = None + sim_agents_submission_pb2 = None + submission_specs = None + metrics = None + +from scenestreamer.utils import wrap_to_pi + +# Set memory growth on all gpus. +# all_gpus = tf.config.experimental.list_physical_devices('GPU') +# if all_gpus: +# try: +# for cur_gpu in all_gpus: +# tf.config.experimental.set_memory_growth(cur_gpu, True) +# except RuntimeError as e: +# print(e) + +FOLDER = pathlib.Path(__file__).resolve().parent + +# Disable all GPUS +tf.config.set_visible_devices([], 'GPU') +visible_devices = tf.config.get_visible_devices() +for device in visible_devices: + assert device.device_type != 'GPU', f"Expected device type to be CPU, got {device.device_type}." + + +def joint_scene_from_states(states: np.ndarray, object_ids: tf.Tensor, + sgen_challenge) -> sim_agents_submission_pb2.JointScene: + # States shape: (num_objects, num_steps, 4). + # Objects IDs shape: (num_objects,). 
+ # states = states.numpy() + simulated_trajectories = [] + for i_object in range(object_ids.shape[0]): + + if sgen_challenge: + assert states.shape[-1] == 7 + traj = sim_agents_submission_pb2.SimulatedTrajectory( + center_x=states[i_object, :, 0], + center_y=states[i_object, :, 1], + center_z=states[i_object, :, 2], + heading=states[i_object, :, 3], + object_id=object_ids[i_object], + length=states[i_object, :, 4], + width=states[i_object, :, 5], + height=states[i_object, :, 6], + ) + + else: + assert states.shape[-1] == 4 + traj = sim_agents_submission_pb2.SimulatedTrajectory( + center_x=states[i_object, :, 0], + center_y=states[i_object, :, 1], + center_z=states[i_object, :, 2], + heading=states[i_object, :, 3], + object_id=object_ids[i_object] + ) + simulated_trajectories.append(traj) + return sim_agents_submission_pb2.JointScene(simulated_trajectories=simulated_trajectories) + + +def scenario_rollouts_from_states( + scenario_id, states: tf.Tensor, object_ids: tf.Tensor, sgen_challenge=False +) -> sim_agents_submission_pb2.ScenarioRollouts: + """ + Aggregate agent states into a ScenarioRollouts proto message. + """ + # States shape: (num_rollouts, num_objects, num_steps, 4). + # Objects IDs shape: (num_objects,). + joint_scenes = [] + for i_rollout in range(states.shape[0]): + joint_scenes.append(joint_scene_from_states(states[i_rollout], object_ids, sgen_challenge=sgen_challenge)) + return sim_agents_submission_pb2.ScenarioRollouts( + # Note: remember to include the Scenario ID in the proto message. + joint_scenes=joint_scenes, + scenario_id=scenario_id + ) + + +from google.protobuf import json_format + + +def load_protobuf_from_dict(scenario_dict, scenario_type=scenario_pb2.Scenario): + """ + Load a Scenario protobuf message from a dictionary. + + :param scenario_dict: A dictionary representing the Scenario. + :return: A Scenario protobuf message. 
+ """ + scenario = scenario_type() + json_format.ParseDict(scenario_dict, scenario) + return scenario + + +def scenario_description_to_scenario_pb2(sd: dict) -> scenario_pb2.Scenario: + """ + Converts a scenario description dict to a scenario_pb2.Scenario proto. + """ + + # 1. Parse tracks into the format expected by the scenario proto. + tracks_to_predict = [] + scenario_tracks = [] + for track_count, (track_name, track) in enumerate(sd["tracks"].items()): + type_mapping = { + "VEHICLE": scenario_pb2.Track.TYPE_VEHICLE, + "PEDESTRIAN": scenario_pb2.Track.TYPE_PEDESTRIAN, + "CYCLIST": scenario_pb2.Track.TYPE_CYCLIST, + "UNSET": scenario_pb2.Track.TYPE_UNSET, + "OTHER": scenario_pb2.Track.TYPE_OTHER + } + + if str(track_name) in sd["metadata"]["tracks_to_predict"]: + tracks_to_predict.append( + { + "track_index": track_count, + "difficulty": sd["metadata"]["tracks_to_predict"][str(track_name)]["difficulty"], + } + ) + + # track["state"] is formatted as a dict of arrays of shape (timesteps, dim). We want to convert it into a list of dictionaries sharing the same keys. + timesteps = track["state"]["position"].shape[0] + center_x, center_y, center_z = track["state"]["position"].T + + one_d_keys = ["length", "width", "height", "heading", "valid"] + for key in one_d_keys: + assert track["state"][key].shape == ( + timesteps, + ), f"Expected shape (timesteps,), got {track['state'][key].shape}." 
+ length, width, height, heading, valid = [track["state"][key] for key in one_d_keys] + velocity_x, velocity_y = track["state"]["velocity"].T + scenario_tracks.append( + { + "id": int(track_name), + "object_type": type_mapping[track["type"]], + "states": [ + { + "center_x": center_x[i].tolist(), + "center_y": center_y[i].tolist(), + "center_z": center_z[i].tolist(), + "length": length[i].tolist(), + "width": width[i].tolist(), + "height": height[i].tolist(), + "heading": heading[i].tolist(), + "velocity_x": velocity_x[i].tolist(), + "velocity_y": velocity_y[i].tolist(), + "valid": valid[i].tolist() + } for i in range(timesteps) + ] + } + ) + + # 2. Build up scenario map features. The protobuf expects a list of map_pb2.MapFeature, each with a type and built up with polylines. + scenario_map_features = [] + for map_feature_id, map_feature in sd["map_features"].items(): + map_feature_mapping = { + "LANE": "lane", + "ROAD_LINE": "road_line", + "ROAD_EDGE": "road_edge", + "STOP_SIGN": "stop_sign", + "CROSSWALK": "crosswalk", + "SPEED_BUMP": "speed_bump", + "DRIVEWAY": "driveway", + "UNKNOWN": "unknown" + } + for key in map_feature_mapping: # E.g if "LANE" in map_feature["type"], then map_feature_type = map_pb2.MapFeature.lane + + if key == "UNKNOWN": + continue # TODO: Deal with this in future + + if key in map_feature["type"]: + map_feature_type = map_feature_mapping[key] + break + else: + map_feature_type = -1 + + if map_feature_type == -1: + continue # TODO: Deal with this in future + assert map_feature_type != -1, f"Map feature type {map_feature['type']} not recognized." + + # MAIN DICT FOR MAP FEATURE + map_feature_dict = {"id": int(map_feature_id), map_feature_type: {}} + + if map_feature_type in ["road_line", "road_edge"]: + # Polyline features only exist for road_line and road_edge map features. + # The polylines in the dict are of shape (N, 3). We want to convert them into a list of MapFeature.Polyline, each with an x, y, and z field. 
+ if map_feature_type == "road_line": + map_feature_dict[map_feature_type]["type"] = "TYPE_" + '_'.join(map_feature["type"].split("_")[2:]) + else: + map_feature_dict[map_feature_type]["type"] = "TYPE_" + map_feature["type"] + polylines = map_feature["polyline"] + formatted_polylines = [{"x": polyline[0], "y": polyline[1], "z": polyline[2]} for polyline in polylines] + map_feature_dict[map_feature_type]["polyline"] = formatted_polylines + + elif map_feature_type == "lane": + # Add lane-specific fields. These can be pass-through fields from the ScenarioDescription. + passthrough_keys = [ + "speed_limit_mph", "interpolating", "entry_lanes", "exit_lanes", "left_boundaries", "right_boundaries" + ] + for key in passthrough_keys: + try: + if type(map_feature_dict[map_feature_type][key]) == list: + if len(map_feature_dict[map_feature_type][key]) == 0: + continue + map_feature_dict[map_feature_type][key] = map_feature[key] + except: + pass + + elif map_feature_type == "stop_sign": + map_feature_dict[map_feature_type]["lane"] = map_feature["lane"] + map_feature_dict[map_feature_type]["position"] = { + "x": map_feature["position"][0], + "y": map_feature["position"][1], + "z": map_feature["position"][2] + } + + elif map_feature_type in ["driveway", "crosswalk", "speed_bump"]: + # Driveways, crosswalks, and speedbumps are represented by a singular polygon instead of multiple polylines. 
Each polygon is an array with shape (4, 3) + polygon = map_feature["polygon"] + map_feature_dict[map_feature_type]["polygon"] = [ + { + "x": vertex[0], + "y": vertex[1], + "z": vertex[2] + } for vertex in polygon + ] + + elif map_feature_type in ["unknown"]: + # polylines = map_feature["polyline"] + # map_feature_dict[map_feature_type]["polyline"] = [{"x": polyline[0], "y": polyline[1], "z": polyline[2]} for polyline in polylines] + pass + # TODO: Deal with this in future + + else: + raise ValueError(f"Map feature type {map_feature_type} not recognized.") + + scenario_map_features.append(map_feature_dict) + + assert len(tracks_to_predict) > 0 + + # print(f"In scenario {sd['metadata']['scenario_id']}, number of tracks to predict: {len(tracks_to_predict)}") + + scenario_parsedict = { + "compressed_frame_laser_data": [], # Not filled + "current_time_index": sd["metadata"]["current_time_index"], + "dynamic_map_states": [], # TODO: Traffic light is not filled + "map_features": scenario_map_features, + "objects_of_interest": [], # Not filled + "scenario_id": sd['metadata']["scenario_id"], + "sdc_track_index": sd["metadata"]["sdc_track_index"], + "timestamps_seconds": sd["metadata"]["ts"].tolist(), + "tracks": scenario_tracks, + "tracks_to_predict": tracks_to_predict, + } + + scenario = load_protobuf_from_dict(scenario_parsedict) + return scenario + + +def load_metrics_config(use_2024) -> sim_agents_metrics_pb2.SimAgentMetricsConfig: + """Loads the `SimAgentMetricsConfig` used for the challenge.""" + # pylint: disable=line-too-long + # pyformat: disable + + # As noted in: https://github.com/waymo-research/waymo-open-dataset/issues/817 + # The config have changed. So we need to switch between them. 
+ if use_2024: + config_path = FOLDER / 'challenge_2024_config.textproto' + else: + config_path = FOLDER / 'challenge_2023_config.textproto' + + with open(config_path, 'r') as f: + config = sim_agents_metrics_pb2.SimAgentMetricsConfig() + text_format.Parse(f.read(), config) + return config + + +def load_metrics_config_from_file_name(file_name): + """Loads the `SimAgentMetricsConfig` used for the challenge.""" + # pylint: disable=line-too-long + # pyformat: disable + + # As noted in: https://github.com/waymo-research/waymo-open-dataset/issues/817 + # The config have changed. So we need to switch between them. + config_path = FOLDER / file_name + + with open(config_path, 'r') as f: + config = sim_agents_metrics_pb2.SimAgentMetricsConfig() + text_format.Parse(f.read(), config) + return config + + +def wosac_evaluation(pred_dicts: list, disable_eval, use_2024, save_91steps_together, save_80steps_together): + """ + pred_dicts: A list of dictionaries with the data for evaluation. For more, see data_dict_to_motion_prediction in test_waymo_eval.py. + + Returns: + scenario_metrics: sim_agents_submission_pb2.SimAgentMetrics -> The metrics for the scenario. + aggregate_metrics: sim_agents_submission_pb2.SimAgentsBucketedMetrics -> The aggregated metrics for the scenario. + """ + """ + scenario: scenario_pb2.Scenario -> The scenario to evaluate. WOSAC uses it for the map data. + simulated_states: tf.Tensor -> The simulated states of the agents. Shape: (num_rollouts (by default, 32), num_agents, num_steps (80), 4). 
+ logged_trajectories: waymo_open_dataset.utils.trajectory_utils.ObjectTrajectories + """ + # Split all data based on scenario + split_data = defaultdict(list) + scenario_id_list = [] + # Split the prediction for each scenario, also flatten the data + for d in pred_dicts: + for sid in np.unique(d["pred_to_scenario_id"]): + # For every unique scenario id: + for k, v in d.items(): + if k in ["pred_trajs", "pred_headings", "pred_shape"]: + if v is None: + continue + # Filter out the data that corresponds to the particular scenario id. + entry_in_same_scenario = [v[idx] for idx in range(len(v)) if d["pred_to_scenario_id"][idx] == sid] + assert entry_in_same_scenario # is not None + entry_in_same_scenario = np.stack(entry_in_same_scenario, axis=0) + split_data[k].append(entry_in_same_scenario) + elif k in ["decoder/agent_position", "decoder/agent_velocity", "decoder/agent_heading", + "decoder/agent_valid_mask", "decoder/agent_shape", "decoder/agent_type", + "raw_scenario_description", "decoder/track_name"]: + entry_in_same_scenario = [v[idx] for idx in range(len(v)) if d["scenario_id"][idx] == sid] + assert entry_in_same_scenario + assert len( + entry_in_same_scenario + ) == 1 # Assert there's only one scenario data instance for each object in the list. + entry_in_same_scenario = entry_in_same_scenario[0] + + # little workaround + if k == "raw_scenario_description" and isinstance(entry_in_same_scenario, list): + assert len(entry_in_same_scenario) == 1 + entry_in_same_scenario = entry_in_same_scenario[0] + + split_data[k].append(entry_in_same_scenario) + scenario_id_list.append(sid) + + assert len(split_data["raw_scenario_description"]) > 0, "No scenario description found in the data." 
+ + scenario_metrics_result = {} + aggregate_metrics_result = {} + + scenario_rollouts_list_80steps = [] + scenario_rollouts_list_91steps = [] + scenario_pb_list = [] + + current_time = pred_dicts[0]["metadata/current_time_index"].item() + + # print("Creating scenario rollouts...") + for scenario_index, scenario_dict in enumerate(split_data["raw_scenario_description"]): + # scenario: scenario_pb2.Scenario + scenario_id = scenario_dict["metadata"]["scenario_id"] + # simulated states: tf.Tensor with shape (num_modes, num_agents, num_steps, 4) + # The 4 dimensions are: center_x, center_y, center_z, heading + + states = split_data["pred_trajs"][ + scenario_index] # torch.Tensor with shape: (num_modes, num_steps, num_agents, 2 -> (center_x, center_y)) + num_modes = states.shape[0] + headings = wrap_to_pi(split_data["pred_headings"][scenario_index]) # shape: (num_modes, num_steps, num_agents) + + # Change headings to (-pi, pi) + headings = utils.wrap_to_pi(headings) + + + # PZH: Fill in Z here at current step. + z_values_for_simagent = split_data["decoder/agent_position"][scenario_index][None, + current_time:current_time + 1, :, 2] + z_values_for_simagent = np.repeat(z_values_for_simagent, num_modes, axis=0) + Modes, T, N, _ = states.shape + z_values_for_simagent = np.repeat(z_values_for_simagent, T, axis=1) + + # NOTE: Write these if you want to see GT's WOSAC scores. + # headings[:, :91] = split_data["decoder/agent_heading"][scenario_index][None].repeat(32, axis=0) + # z_values[:, :91] = split_data["decoder/agent_position"][scenario_index][..., 2][None].repeat(32, axis=0) + # states[:, :91] = split_data["decoder/agent_position"][scenario_index][..., :2][None].repeat(32, axis=0) + + # Assume sdc z is at step=0, sdc is agent=0. 
+ sdc_z = split_data["decoder/agent_position"][scenario_index][0][0][-1] + z_values_for_sgen = np.full_like(z_values_for_simagent, sdc_z) # shape: (num_modes, num_steps, num_agents, 1) + + states_for_simagent = np.concatenate([states, z_values_for_simagent[..., None], headings[..., None]], axis=-1) + states_for_simagent = states_for_simagent.transpose(0, 2, 1, + 3) # shape: (num_modes, num_agents, num_steps, 4 -> (center_x, center_y)) + + if "pred_shape" in split_data: + shape = split_data['pred_shape'][scenario_index] + states_for_scenariogen = np.concatenate([states, z_values_for_sgen[..., None], headings[..., None], shape], + axis=-1) + states_for_scenariogen = states_for_scenariogen.transpose(0, 2, 1, 3) + + if states_for_scenariogen.shape[2] == 96: + states_for_scenariogen = states_for_scenariogen[:, :, :-5, :] + assert states_for_scenariogen.shape[2] == 91 + + else: + assert save_91steps_together is False + + # states = tf.convert_to_tensor(states, dtype=tf.float32) + + # Get the trajectory ids for each prediction. 
+ trajectory_ids = split_data["decoder/track_name"][scenario_index] + + trajectory_ids = np.array([[int(vv) for vv in v] for v in trajectory_ids]) + assert trajectory_ids.shape[0] == 1 + trajectory_ids = trajectory_ids[0] + trajectory_ids = tf.convert_to_tensor(trajectory_ids, dtype=tf.int32) + + if states_for_simagent.shape[2] == 96: + states_for_simagent = states_for_simagent[:, :, :-5, :] + + + assert states_for_simagent.shape[2] == 91 + + if save_80steps_together: + states_80 = states_for_simagent[:, :, 11:, :] + scenario_rollouts_80 = scenario_rollouts_from_states(scenario_id, states_80, trajectory_ids) + scenario_rollouts_list_80steps.append(scenario_rollouts_80) + + + if save_91steps_together: + scenario_rollouts_91 = scenario_rollouts_from_states(scenario_id, states_for_scenariogen, trajectory_ids, + sgen_challenge=True) + scenario_rollouts_list_91steps.append(scenario_rollouts_91) + + + scenario = scenario_description_to_scenario_pb2(scenario_dict) + + scenario_pb_list.append(scenario) + + if disable_eval: + continue + + if save_91steps_together: + # Load the test configuration. + config = load_metrics_config_from_file_name("challenge_2025_scenario_gen_config.textproto") + # Compute the metrics for the scenario. + scenario_metrics = metrics.compute_scenario_metrics_for_bundle( + config, scenario, scenario_rollouts_91, submission_specs.ChallengeType.SCENARIO_GEN + ) + print(scenario_metrics) + + # Load the test configuration. 
+ # scenario_metrics = metrics.compute_scenario_metrics_for_bundle(config, scenario, scenario_rollouts_list_80steps) + # aggregate_metrics = metrics.aggregate_metrics_to_buckets(config, scenario_metrics) + scenario_metrics_result[scenario_dict["metadata"]["scenario_id"]] = scenario_metrics + # aggregate_metrics_result[scenario_dict["metadata"]["scenario_id"]] = aggregate_metrics + + return scenario_metrics_result, aggregate_metrics_result, scenario_rollouts_list_80steps, scenario_rollouts_list_91steps, scenario_pb_list diff --git a/scenestreamer/gradio_ui/__init__.py b/scenestreamer/gradio_ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenestreamer/gradio_ui/artifact.py b/scenestreamer/gradio_ui/artifact.py new file mode 100644 index 0000000000000000000000000000000000000000..12348f04fa383903e9a601282dfc89541a11c4c8 --- /dev/null +++ b/scenestreamer/gradio_ui/artifact.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import contextlib +import io +from pathlib import Path + +import numpy as np + + +def npz_to_dict(npz: np.lib.npyio.NpzFile) -> dict: + out: dict = {} + for k in npz.files: + v = npz[k] + if isinstance(v, np.ndarray) and v.shape == (): + v = v.item() + out[k] = v + return out + + +def load_asset(asset_path: str | Path) -> dict: + with np.load(asset_path, allow_pickle=False) as npz: + return npz_to_dict(npz) + + +def render_asset(asset_path: str | Path, out_path: str | Path, verbose: bool = False) -> Path: + from scenestreamer.gradio_ui.plot import plot_pred + + asset_path = Path(asset_path) + out_path = Path(out_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + data_dict = load_asset(asset_path) + + if verbose: + plot_pred(data_dict, show=False, save_path=str(out_path)) + else: + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + plot_pred(data_dict, show=False, save_path=str(out_path)) + return out_path + diff --git 
a/scenestreamer/gradio_ui/demo_app.py b/scenestreamer/gradio_ui/demo_app.py new file mode 100644 index 0000000000000000000000000000000000000000..ee1433039e40930db21e672052ddb04dfe1ebc3f --- /dev/null +++ b/scenestreamer/gradio_ui/demo_app.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from pathlib import Path + +import gradio as gr +import torch + +from scenestreamer.dataset.dataset import SceneStreamerDataset +from scenestreamer.gradio_ui.plot import plot_gt, plot_pred +from scenestreamer.infer.initial_state import convert_initial_states_as_motion_data, generate_initial_state +from scenestreamer.infer.motion import generate_motion +from scenestreamer.utils import utils + +DEFAULT_HF_REPO = "pengzhenghao97/scenestreamer_0301" +DEFAULT_HF_FILE = "0228_MidGPT_V19_WTG_addstep_2025-02-28_epoch=14-step=426133.ckpt" + + +def choose_device(device_arg: str) -> torch.device: + if device_arg != "auto": + return torch.device(device_arg) + if torch.cuda.is_available(): + return torch.device("cuda") + if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available(): + return torch.device("mps") + return torch.device("cpu") + + +def build_demo( + *, + dataset_dir: str = "data/20scenarios", + hf_repo: str = DEFAULT_HF_REPO, + hf_file: str = DEFAULT_HF_FILE, + ckpt: str | None = None, + device: str = "auto", +) -> gr.Blocks: + device_obj = choose_device(device) + if ckpt: + pl_model = utils.get_model(checkpoint_path=ckpt, device=device_obj) + else: + pl_model = utils.get_model(huggingface_repo=hf_repo, huggingface_file=hf_file, device=device_obj) + + config = pl_model.config + config.DATA.TRAINING_DATA_DIR = dataset_dir + config.DATA.TEST_DATA_DIR = dataset_dir + config.DATA.USE_CACHE = True + + dataset = SceneStreamerDataset(config, "test") + + def load_ground_truth(scenario_index: int): + raw = dataset[int(scenario_index)] + gt_img = plot_gt(raw) + scenario_id = raw.get("scenario_id", scenario_index) + status = ( + f"Loaded scenario 
`{scenario_id}` on `{pl_model.device}`.\n\n" + "Ground truth is fixed. Click `Generate` to refresh only the prediction panel." + ) + return status, gt_img, None + + def run_demo(scenario_index: int, mode: str): + raw = dataset[int(scenario_index)] + + batched = utils.batch_data(utils.numpy_to_torch(raw, device=pl_model.device)) + if mode == "motion_only": + output = generate_motion( + data_dict=batched, + model=pl_model.model, + autoregressive_start_step=0, + teacher_forcing_sdc=True, + num_decode_steps=19, + ) + else: + densified, _ = generate_initial_state(data_dict=batched, model=pl_model.model) + densified_motion_input = convert_initial_states_as_motion_data(densified) + output = generate_motion( + data_dict=densified_motion_input, + model=pl_model.model, + autoregressive_start_step=0, + teacher_forcing_sdc=False, + num_decode_steps=19, + ) + + pred = utils.unbatch_data(utils.torch_to_numpy(output)) + pred_img = plot_pred(pred) + scenario_id = raw.get("scenario_id", scenario_index) + status = ( + f"Generated prediction for scenario `{scenario_id}` on `{pl_model.device}`.\n\n" + f"Mode: `{mode}`\n" + f"Dataset size: `{len(dataset)}`\n" + "Ground truth stays fixed while the prediction is regenerated." 
+ ) + return status, pred_img + + max_index = max(0, len(dataset) - 1) + + with gr.Blocks(title="SceneStreamer Interactive Demo") as demo: + gr.Markdown("## SceneStreamer Interactive Demo") + gr.Markdown("Pick a scenario, choose a generation mode, and inspect ground-truth vs generated results in the browser.") + gr.Markdown(f"Loaded dataset: `{dataset_dir}` on device `{pl_model.device}`") + + with gr.Row(): + scenario_index = gr.Slider( + minimum=0, + maximum=max_index, + value=0, + step=1, + label="Scenario Index", + ) + mode = gr.Radio( + choices=[ + ("Motion Only", "motion_only"), + ("Densified Agents", "densified_agents"), + ], + value="motion_only", + label="Generation Mode", + ) + run_button = gr.Button("Generate") + + status = gr.Markdown() + with gr.Row(): + gt_image = gr.Image(label="Ground Truth") + pred_image = gr.Image(label="Generated Prediction") + + scenario_index.change(load_ground_truth, inputs=[scenario_index], outputs=[status, gt_image, pred_image]) + run_button.click(run_demo, inputs=[scenario_index, mode], outputs=[status, pred_image]) + demo.load(load_ground_truth, inputs=[scenario_index], outputs=[status, gt_image, pred_image]) + + return demo + diff --git a/scenestreamer/gradio_ui/gradio_ui.py b/scenestreamer/gradio_ui/gradio_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf79ed4c60a249a77a8f2d7ec09c885580afdaf --- /dev/null +++ b/scenestreamer/gradio_ui/gradio_ui.py @@ -0,0 +1,83 @@ +import pathlib + +import gradio as gr +import requests + +agent_1_id, agent_1_turn, agent_2_id, agent_2_turn = None, None, None, None + + +def fetch_next_scene(): + resp = requests.get("http://127.0.0.1:5001/next_scene") + if resp.status_code == 200: + print(resp.json()) + return pathlib.Path(f"./{resp.json()['original_image']}") + # return pathlib.Path(f"./{resp.json()['original_gif']}") + else: + raise LookupError() + + +# Mock function to simulate plotting the new modes based on the form input. 
+def plot_modes(agent_1_id, agent_1_turn, agent_2_id, agent_2_turn): + # Replace this with actual logic to generate images. + value_map = {"STOP": 0, "Go Straight": 1, "Turn Left": 2, "Turn Right": 3, "U-TURN": 4} + resp = requests.post( + "http://127.0.0.1:5001/plot", + json={ + "agent_1_id": agent_1_id, + "agent_1_turn": value_map[agent_1_turn], + "agent_2_id": agent_2_id, + "agent_2_turn": value_map[agent_2_turn] + } + ) + if resp.status_code == 200: + print(resp.json()) + return resp.json() + raise LookupError() + + +def update_and_visualize(agent_1_id, agent_1_turn, agent_2_id, agent_2_turn): + mode_images = plot_modes(agent_1_id, agent_1_turn, agent_2_id, agent_2_turn) + return [pathlib.Path(f"./{image}") for image in mode_images] + + +# Creating the Gradio interface +with gr.Blocks() as demo: + + gr.Markdown("# Action-conditioned MotionLM Visualization with Collision and Turn Injection") + + with gr.Row(): + # Button to fetch next scene + next_scene_btn = gr.Button("Next Scene") + + with gr.Row(): + # Display original scenario image + original_image = gr.Image(label="Original Scenario") + + with gr.Row(): + # Input fields for agent 1 and agent 2 IDs and turn actions + agent_1_id = gr.Textbox(label="Enter Agent 1 ID", placeholder="Agent 1 ID") + agent_1_turn = gr.Dropdown( + choices=["STOP", "Go Straight", "Turn Left", "Turn Right", "U-TURN"], label="Agent 1 Turn Action" + ) + agent_2_id = gr.Textbox(label="Enter Agent 2 ID", placeholder="Agent 2 ID") + agent_2_turn = gr.Dropdown( + choices=["STOP", "Go Straight", "Turn Left", "Turn Right", "U-TURN"], label="Agent 2 Turn Action" + ) + + with gr.Row(): + # Button to update and visualize decoded modes + update_btn = gr.Button("Update and Visualize Decoded Modes") + + with gr.Column(): + # Display decoded mode images + mode_images = [gr.Image(label=f"Mode {i}") for i in range(6)] + + # Define button actions + next_scene_btn.click(fn=fetch_next_scene, outputs=[original_image]) + + update_btn.click( + 
def render(input_dir, output_path):
    """Replay one scenario in MetaDrive and write a top-down video of it.

    Args:
        input_dir: Scenario data directory, relative to ``REPO_ROOT``.
        output_path: Destination video file, relative to ``REPO_ROOT``.

    The environment is stepped until it reports termination or truncation
    (at most 99999 steps). One top-down frame is collected per step and the
    frames are written as a 20 fps video.
    """
    # Build the environment outside the ``try`` block: if construction fails
    # there is nothing to close, and the original error must not be masked by
    # an unbound ``env`` in ``finally``.
    env = ScenarioEnv(
        {
            "manual_control": False,
            "reactive_traffic": False,
            "use_render": False,
            "agent_policy": ReplayEgoCarPolicy,
            "data_directory": REPO_ROOT / input_dir,
            "num_scenarios": 1
        }
    )
    frames = []
    try:
        env.reset()
        for _ in range(1, 100000):
            _, _, tm, tc, _ = env.step([1.0, 0.])
            frames.append(env.render(mode="top_down", **extra_args))
            if tm or tc:
                break
    finally:
        # Always release the simulator, even when stepping or rendering fails.
        env.close()

    imgs = np.stack(frames, axis=0)
    mediapy.write_video(REPO_ROOT / output_path, imgs, fps=20)
+import pathlib +import pickle +import uuid +from pathlib import Path + +import gradio as gr +import numpy as np +import torch +from hydra import compose +from hydra import initialize_config_dir +from omegaconf import OmegaConf + +from scenestreamer.dataset.preprocessor import preprocess_scenario_description_for_motionlm +from scenestreamer.gradio_ui.plot import plot_gt, plot_pred, create_animation_from_pred, create_animation_from_gt +from scenestreamer.tokenization import get_tokenizer +from scenestreamer.utils import REPO_ROOT +from scenestreamer.utils import utils + +os.environ["GRADIO_TEMP_DIR"] = str(REPO_ROOT / "gradio_tmp") + +dpi = 100 +fps = 20 + +parser = argparse.ArgumentParser() +parser.add_argument("--share", action="store_true", help="Enable sharing") +parser.add_argument("--default_ckpt", "--default_model", "--ckpt", type=str, default="debug") +parser.add_argument("--default_data", "--data", type=str, default="data/20scenarios") +parser.add_argument("--port", type=int, default=7860) +parser.add_argument("--title", type=str, default="Motion Generation") +parser.add_argument("--default_config_path", type=str, default="1214_midgpt_v14.yaml") +parser.add_argument( + "--display_video", "--video", action="store_true", help="by default display static images, can display mp4 video" +) +args = parser.parse_args() + +default_config_path = args.default_config_path +# default_config = OmegaConf.load(REPO_ROOT / "cfgs" / default_config_path) + +config_dir = Path("/absolute/path/to/conf").resolve() +with initialize_config_dir(config_dir=str(REPO_ROOT / "cfgs")): + default_config = compose(config_name=default_config_path) + +OmegaConf.set_struct(default_config, False) +default_config.PREPROCESSING.keep_all_data = True +default_config.TOKENIZATION.TOKENIZATION_METHOD = "bicycle_interpolated" +default_config.ROOT_DIR = str(REPO_ROOT.resolve()) + +DEFAULT_DATA_PATH = args.default_data or 
"/bigdata/datasets/scenarionet/waymo/validation_interactive/validation_interactive_0" +DEFAULT_MODEL = args.default_ckpt or None + +NUM_OF_MODES = 6 +# os.environ["CUDA_VISIBLE_DEVICES"] = "3" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +LENGTH = 1000 + + +class State: + model = None + model_path = None + config: dict = default_config + dataset_path: pathlib.Path = REPO_ROOT / DEFAULT_DATA_PATH + + scenario = None + + raw_data_files = None + data_files = None + + raw_data_dict = None + data_dict = None + + default_config: dict = default_config + + # click support + sel_point = None + xlim, ylim = None, None + fig_width, fig_height = None, None # in inches + fig_dpi = None + bbox_x0, bbox_y0, bbox_w, bbox_h = None, None, None, None + original_img = None + modified_img = None + + +state = State() + + +def ckpt_callback(ckpt_path): + from scenestreamer.models.motionlm_lightning import MotionLMLightning + + msg = "Failed!" + temperature = 1.0 + safe_agents = "" + turn_agents = "" + main_vis = None + sampling_method = "topp" + + if ckpt_path.lower() == "debug": + try: + config = copy.deepcopy(state.default_config) + OmegaConf.set_struct(config, False) + config.MODEL.D_MODEL = 32 + config.MODEL.NUM_DECODER_LAYERS = 1 + config.MODEL.NUM_ATTN_LAYERS = 1 + # config.ACTION_LABEL.USE_SAFETY_LABEL = True + # config.ACTION_LABEL.USE_ACTION_LABEL = True + OmegaConf.set_struct(config, True) + model = MotionLMLightning(config) + model = model.to(device) + msg = "DEBUG MODEL LOADED!" + config = model.config + temperature = config.SAMPLING.TEMPERATURE + state.model = model + state.config = config + sampling_method = config.SAMPLING.SAMPLING_METHOD + except Exception as e: + print("Error: ", e) + msg = "Failed to load DEBUG model!" 
def on_dataset_path_submit(path):
    """Handle submission of a dataset folder and refresh the file explorer.

    Args:
        path: Dataset folder as typed by the user, resolved against
            ``REPO_ROOT``.

    Returns:
        list: Two values, matching the ``[status, file_explorer]`` outputs the
        handler is wired to. On success: a status message and a new
        ``gr.FileExplorer`` rooted at the dataset folder. On failure:
        ``"Failed!"`` and a no-op update that leaves the explorer untouched.
    """
    print(state, type(state))

    FAILED_MSG = "Failed!"

    path = REPO_ROOT / pathlib.Path(path)

    # ``is_dir()`` is False for missing paths too, so one check covers both
    # failure cases. The failure result must still carry one value per output
    # component; a bare string is rejected by Gradio as too few outputs.
    if not path.is_dir():
        return [FAILED_MSG, gr.update()]

    state.dataset_path = path
    print(state.dataset_path)

    files = [f for f in os.listdir(path) if f.endswith(".pkl")]
    print("Files: ", files)

    # Number of successful submissions so far; created lazily on first use.
    if not hasattr(state, "count"):
        state.count = 0
    state.count += 1

    return [
        "Dataset with {} Scenarios Listed!".format(len(files)),
        gr.FileExplorer(
            file_count="single",
            root_dir=state.dataset_path,
            glob="**/*.pkl",
            scale=1
        )
    ]
state.fig_dpi = info_dict["fig_dpi"] + state.bbox_x0 = info_dict["bbox_x0"] + state.bbox_y0 = info_dict["bbox_y0"] + state.bbox_w = info_dict["bbox_w"] + state.bbox_h = info_dict["bbox_h"] + return (gr.Image(value=img, label=data["id"]), ) + ("", ) * 2 + + +def on_generate_button_click( + # MIN_SPEED, MAX_HEADING_DIFF, MIN_DISPLACEMENT_INIT, MIN_DISPLACEMENT, smooth_factor, + sampling_method, + temperature, + seed, + agents_safe_0, + agents_safe_1, + # agents_turn_stop, agents_turn_straight, agents_turn_left, agents_turn_right, + # agents_turn_uturn, + only_draw_gt, + only_draw_detokenized, + draw_backward_prediction=False +): + + # TODO: Seed is not respect! + # TODO: Seed is not respect! + # TODO: Seed is not respect! + + assert sampling_method in ["softmax", "topp"], "Invalid sampling method! {}".format(sampling_method) + + if state.scenario is None: + return ( + None, + "Data is not loaded!", + ) + (None, ) * 2 + + if (not state.config.get("BACKWARD_PREDICTION", False)) and draw_backward_prediction: + print("BACKWARD_PREDICTION is not enabled in the config!") + return ( + None, + "BACKWARD_PREDICTION is not enabled in the config!", + ) + (None, ) * 2 + + model = state.model + if model is None and not only_draw_gt: + return ( + None, + "Model is not loaded!", + ) + (None, ) * 2 + + # if smooth_factor: + config = copy.deepcopy(state.config) + OmegaConf.set_struct(config, False) + # state.config.TOKENIZATION.SMOOTH_FACTOR = None + # state.config.TOKENIZATION.MIN_DISPLACEMENT = MIN_DISPLACEMENT + # state.config.TOKENIZATION.MIN_DISPLACEMENT_INIT = MIN_DISPLACEMENT_INIT + # state.config.TOKENIZATION.MAX_HEADING_DIFF = MAX_HEADING_DIFF + # state.config.TOKENIZATION.MIN_SPEED = MIN_SPEED + OmegaConf.set_struct(config, True) + + data_dict = preprocess_scenario_description_for_motionlm( + scenario=copy.deepcopy(state.scenario), + config=config, + in_evaluation=True, + keep_all_data=True, + # cache=None, + backward_prediction=draw_backward_prediction, + 
tokenizer=get_tokenizer(config=config) + ) + + # ===== Overwrite the labels ===== + def _parse_agents(agents_str): + is_raw = False + if agents_str: + if agents_str.startswith("[RAW]"): + agents_str = agents_str[5:] + is_raw = True + if not agents_str: + return [], is_raw + return [int(agent.strip()) for agent in agents_str.split(",")], is_raw + else: + is_raw = True + return [], is_raw + + def _fill_label(data_dict, label_name, agents, label_value): + data_dict[label_name] = -1 + if agents: + label = data_dict[label_name] + for aid in agents: + assert 0 <= aid < label.shape[0], (aid, label.shape) + data_dict[label_name][aid] = label_value + + if 'decoder/label_safety' in data_dict: + agents_safe_0, agents_safe_0_is_raw = _parse_agents(agents_safe_0) + _fill_label(data_dict, 'decoder/label_safety', agents_safe_0, 0) + agents_safe_1, agents_safe_1_is_raw = _parse_agents(agents_safe_1) + _fill_label(data_dict, 'decoder/label_safety', agents_safe_1, 1) + print(f"the input safety label for inference:", data_dict['decoder/label_safety']) + else: + print("decoder/label_safety is not in the data_dict!") + + if only_draw_gt: + output_dict = data_dict + + if args.display_video: + video_path = str(REPO_ROOT / "gradio_tmp" / "gt_animation_{}.mp4".format(uuid.uuid4())) + video_path = create_animation_from_gt(output_dict, save_path=video_path, dpi=dpi, fps=fps) + print("gt_mp4_path", video_path) + else: + img = plot_gt(output_dict) + + else: + input_data_dict = utils.numpy_to_torch(data_dict, device) + + # Extend the batch dim: + input_data_dict = {k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in input_data_dict.items()} + input_data_dict["in_evaluation"] = torch.tensor([1], dtype=bool).to(device) + input_data_dict["in_backward_prediction"] = torch.tensor([draw_backward_prediction], dtype=bool).to(device) + + if not only_draw_detokenized: + with torch.no_grad(): + output_dict = model.model.autoregressive_rollout( + input_data_dict, + num_decode_steps=None, + 
sampling_method=sampling_method, + temperature=temperature, + backward_prediction=draw_backward_prediction + ) + + output_dict = get_tokenizer(config).detokenize( + output_dict, detokenizing_gt=False, backward_prediction=draw_backward_prediction + ) + else: + # ===== DEBUG ===== + output_dict = input_data_dict + output_dict["decoder/output_action"] = output_dict["decoder/target_action"] + fill_zero = ~output_dict["decoder/target_action_valid_mask"] + output_dict["decoder/input_action_valid_mask"][fill_zero] = False + + output_dict = get_tokenizer(config).detokenize(output_dict, detokenizing_gt=True) + + output_dict = { + k: (v.squeeze(0).cpu().numpy() if isinstance(v, torch.Tensor) else v) + for k, v in output_dict.items() + } + + if args.display_video: + video_path = str(REPO_ROOT / "gradio_tmp" / "pred_animation_{}.mp4".format(uuid.uuid4())) + video_path = create_animation_from_pred(output_dict, save_path=video_path, dpi=dpi, fps=fps) + print("predict gif path:", video_path) + + else: + img = plot_pred(output_dict) + + # Postprocess + if "decoder/label_safety" in output_dict: + safety_label = output_dict["decoder/label_safety"] + if agents_safe_0_is_raw: + aid = [str(v) for v in (safety_label == 0).nonzero()[0]] + agents_safe_0 = "[RAW]" + ",".join(aid) + if agents_safe_1_is_raw: + aid = [str(v) for v in (safety_label == 1).nonzero()[0]] + agents_safe_1 = "[RAW]" + ",".join(aid) + + if args.display_video: + return ( + # gr.Image(value=img, label=state.scenario["id"]), + gr.Video( + value=video_path, + label="Generated Trajectory Prediction", + show_download_button=True, + width=LENGTH, + height=LENGTH, + interactive=False, + format="mp4", + autoplay=True, + ), # Use gr.Video instead of gr.Image for GIF + "Scenario Generated!", + ", ".join([str(v) for v in agents_safe_0]) if isinstance(agents_safe_0, list) else agents_safe_0, + ", ".join([str(v) for v in agents_safe_1]) if isinstance(agents_safe_1, list) else agents_safe_1, + ) + else: + return ( + 
def on_scenario_vis_click(event: gr.SelectData):
    """Convert a click on the scenario image into map coordinates.

    Uses the figure geometry stored on ``state`` (figure size, dpi, axes
    bounding box, axis limits) to map the clicked pixel to data coordinates,
    stores the result in ``state.sel_point`` and marks the click on a copy of
    the original image.

    Args:
        event: Gradio select event; ``event.index`` holds the clicked pixel.

    Returns:
        tuple: ``(image, selected_point)``. When no figure geometry has been
        recorded yet, two no-op updates are returned instead.
    """
    if state.fig_dpi is None:
        # Check before any arithmetic: the geometry fields are still ``None``
        # and multiplying them would raise ``TypeError``.
        print("Select a scenario first")
        # One update per output so the arity matches the success branch.
        return gr.update(), gr.update()

    width = state.fig_width * state.fig_dpi
    height = state.fig_height * state.fig_dpi

    # Position of the click relative to the axes bounding box.
    x_rel_fig = (event.index[0] - state.bbox_x0 * width) / (state.bbox_w * width)
    y_rel_fig = (event.index[1] - state.bbox_y0 * height) / (state.bbox_h * height)  # note that y-axis is inverted
    print(
        f"Data: ({event.index[0]:.2f}, {event.index[1]:.2f}), Relative to Figure: ({x_rel_fig:.2f}, {y_rel_fig:.2f})"
    )
    print(f"xlim: {state.xlim}, ylim: {state.ylim}")
    state.sel_point = (
        x_rel_fig * (state.xlim[1] - state.xlim[0]) + state.xlim[0],
        y_rel_fig * (state.ylim[0] - state.ylim[1]) + state.ylim[1]
    )
    print(state.sel_point)

    # Swap channels before drawing so that (0, 0, 255) is red, then swap back.
    # NOTE(review): assumes ``state.original_img`` is an RGB image - confirm.
    canvas = cv2.cvtColor(np.array(state.original_img), cv2.COLOR_RGB2BGR)
    cv2.circle(canvas, tuple(event.index), 10, (0, 0, 255), -1)
    state.modified_img = Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
    return state.modified_img, state.sel_point


def on_clear_click():
    """Drop the current selection and restore the unmarked scenario image.

    Returns:
        tuple: ``(state.original_img, None)``.
    """
    state.sel_point = None
    state.modified_img = None
    return state.original_img, state.sel_point
gr.Column(scale=3): + inp = gr.Textbox(label="Path to Dataset Folder", value=DEFAULT_DATA_PATH) + + with gr.Column(scale=1): + out = gr.Textbox(label="Status", placeholder="Enter to submit...") + + # gr.Markdown("## Visualization") + with gr.Row(equal_height=True): # Future release fix: https://github.com/gradio-app/gradio/pull/9577 + with gr.Column(scale=1): + with gr.Group(): + file_name_input = gr.Textbox(label="Search Scenario ID", max_lines=1) + file_explorer = gr.FileExplorer( + root_dir=state.dataset_path, + glob="**/*.pkl", + file_count="single", + interactive=True, + container=True, + max_height=900 + ) + with gr.Column(scale=2): + with gr.Row(): + sel_point = gr.Textbox( + label="Selected Point", + interactive=False, + placeholder="Click on scenario map to select a point" + ) + clear = gr.Button(value="Clear Selection", interactive=True, visible=True) + + if args.display_video: + gt_vis = gr.Video( + label="Ground Truth Trajectories", + show_download_button=True, + width=LENGTH, + height=LENGTH, + interactive=False, + format="mp4" + ) + else: + gt_vis = gr.Image( + label="Original Scenario", + show_download_button=True, + width=LENGTH, + height=LENGTH, + interactive=False + ) + + with gr.Group(): + gr.Markdown("## Model") + with gr.Row(): + with gr.Column(scale=3): + if DEFAULT_MODEL: + ckpt_input = gr.Textbox( + label="Path to model checkpoint", value=DEFAULT_MODEL, placeholder="/home/.../last.ckpt" + ) + else: + ckpt_input = gr.Textbox( + label="Path to model checkpoint", + placeholder="/home/.../last.ckpt (Type 'debug' for debug model!)" + ) + with gr.Column(scale=1): + ckpt_output = gr.Textbox(label="Status", placeholder="Enter to load...") + + gr.Markdown("## Visualization") + with gr.Row(): + with gr.Column(scale=1): + sampling_method = gr.Radio(label="Sampling Method", choices=["softmax", "topp"], value="topp") + temperature = gr.Slider( + label="Sampling Temperature", minimum=0.0, maximum=2.0, step=0.1, value=1.0, interactive=True + ) + seed = 
gr.Number(label="Seed (TODO: not used)", value=42, precision=0) + + gr.Markdown( + "### Agents ID to assign labels:\n1. Split by comma ','\n2. If empty, original labels are used and printed as `[RAW]`" + ) + agents_safe_0 = gr.Textbox(label="label_safety = 0 (NO COLL)", interactive=True, placeholder="0, 1") + agents_safe_1 = gr.Textbox(label="label_safety = 1 (W/ COLL)", interactive=True, placeholder="0, 1") + + generate_button = gr.Button(value="Generate") + generate_backward_prediction_button = gr.Button(value="Generate Backward Prediction") + draw_gt_button = gr.Button(value="Draw Original Scenario") + draw_detok_button = gr.Button(value="Draw Raw Detokenized Scenario") + + with gr.Column(scale=2): + main_vis_text = gr.Textbox(label="Status", placeholder="", interactive=False) + if args.display_video: + main_vis = gr.Video( + label="Generated Scenario", + show_download_button=True, + width=LENGTH, + height=LENGTH, + format="mp4" + ) + else: + main_vis = gr.Image( + label="Generated Scenario", show_download_button=True, width=LENGTH, height=LENGTH + ) + + inp.submit(on_dataset_path_submit, inputs=inp, outputs=[out, file_explorer]) + file_name_input.change(on_data_file_name_search, inputs=file_name_input, outputs=file_explorer) + file_explorer.change( + on_data_file_select, + inputs=file_explorer, + outputs=[ + gt_vis, + agents_safe_0, + agents_safe_1, + ], + ) + + ckpt_input.submit( + ckpt_callback, + inputs=ckpt_input, + outputs=[ + ckpt_output, + sampling_method, + temperature, + main_vis, + agents_safe_0, + agents_safe_1, + ], + ) + + # TODO: Interesting feature to allow user click the scenario map to select a point + # gt_vis.select(on_scenario_vis_click, outputs=[gt_vis, sel_point]) + # clear.click(on_clear_click, outputs=[gt_vis, sel_point]) + + generate_button.click( + functools.partial(on_generate_button_click, only_draw_gt=False, only_draw_detokenized=False), + inputs=[ + sampling_method, + temperature, + seed, + agents_safe_0, + agents_safe_1, + ], + 
outputs=[ + main_vis, + main_vis_text, + agents_safe_0, + agents_safe_1, + ], + ) + generate_backward_prediction_button.click( + functools.partial( + on_generate_button_click, only_draw_gt=False, only_draw_detokenized=False, draw_backward_prediction=True + ), + inputs=[ + sampling_method, + temperature, + seed, + agents_safe_0, + agents_safe_1, + ], + outputs=[ + main_vis, + main_vis_text, + agents_safe_0, + agents_safe_1, + ], + ) + draw_gt_button.click( + functools.partial(on_generate_button_click, only_draw_gt=True, only_draw_detokenized=False), + inputs=[ + sampling_method, + temperature, + seed, + agents_safe_0, + agents_safe_1, + ], + outputs=[ + main_vis, + main_vis_text, + agents_safe_0, + agents_safe_1, + ], + ) + draw_detok_button.click( + functools.partial(on_generate_button_click, only_draw_gt=False, only_draw_detokenized=True), + inputs=[ + sampling_method, + temperature, + seed, + agents_safe_0, + agents_safe_1, + ], + outputs=[ + main_vis, + main_vis_text, + agents_safe_0, + agents_safe_1, + ], + ) + +if DEFAULT_MODEL: + print("Loading default model from: ", DEFAULT_MODEL) + ckpt_callback(DEFAULT_MODEL) + +demo.queue().launch(server_port=args.port, share=args.share) diff --git a/scenestreamer/gradio_ui/plot.py b/scenestreamer/gradio_ui/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..f208d23e98bac350ec9993ffef055aafa9b412f5 --- /dev/null +++ b/scenestreamer/gradio_ui/plot.py @@ -0,0 +1,1076 @@ +import PIL +import hydra +import matplotlib.pyplot as plt +import numpy as np +import omegaconf +import seaborn as sns +from matplotlib.animation import FFMpegWriter +from matplotlib.patches import Polygon, Circle, Rectangle + +from scenestreamer.dataset.dataset import SceneStreamerDataset +from scenestreamer.utils import REPO_ROOT +import torch + +BOUNDARY = 10 + +EGO_FONT_SIZE = 15 +MODELED_FONT_SIZE = 15 +NON_EGO_FONT_SIZE = 12 + +from scenestreamer.utils.safety_critical_generation_utils import ( + 
def get_limit(agent_pos, map_pos):
    """Return the plot limits shared by the agent and map extents.

    Both inputs are flattened to ``(-1, 2)``. On each axis the larger of the
    two minima and the smaller of the two maxima is kept, and the result is
    widened by ``BOUNDARY`` on both sides.

    Returns:
        dict: keys ``"xmin"``, ``"xmax"``, ``"ymin"``, ``"ymax"``.
    """
    assert agent_pos.shape[-1] == 2
    assert map_pos.shape[-1] == 2

    agent_xy = agent_pos.reshape(-1, 2)
    map_xy = map_pos.reshape(-1, 2)

    agent_x_lo, agent_y_lo = agent_xy.min(0)
    agent_x_hi, agent_y_hi = agent_xy.max(0)
    map_x_lo, map_y_lo = map_xy.min(0)
    map_x_hi, map_y_hi = map_xy.max(0)

    return {
        "xmin": max(agent_x_lo, map_x_lo) - BOUNDARY,
        "xmax": min(agent_x_hi, map_x_hi) + BOUNDARY,
        "ymin": max(agent_y_lo, map_y_lo) - BOUNDARY,
        "ymax": min(agent_y_hi, map_y_hi) + BOUNDARY,
    }


def cal_polygon_contour(x, y, theta, width, length):
    """Compute the four corners of oriented rectangles.

    Args:
        x, y: Centre coordinates.
        theta: Heading in radians.
        width, length: Rectangle extents across and along the heading.

    Returns:
        Array whose second axis lists the corners in the order left-front,
        right-front, right-back, left-back, each as ``(x, y)``.
    """
    # Half-extent offsets along the heading (fwd_*) and across it (side_*).
    fwd_x = 0.5 * length * np.cos(theta)
    fwd_y = 0.5 * length * np.sin(theta)
    side_x = 0.5 * width * np.sin(theta)
    side_y = 0.5 * width * np.cos(theta)

    corners = [
        np.column_stack((x + fwd_x - side_x, y + fwd_y + side_y)),  # left front
        np.column_stack((x + fwd_x + side_x, y + fwd_y - side_y)),  # right front
        np.column_stack((x - fwd_x + side_x, y - fwd_y - side_y)),  # right back
        np.column_stack((x - fwd_x - side_x, y - fwd_y + side_y)),  # left back
    ]
    # New corner axis in position 1: (N, 2) x 4 -> (N, 4, 2).
    return np.stack(corners, axis=1)
def draw_position(
    ax, pos, heading, width, length, fill_color, text=None, fontsize=20, position_kwargs=None, contour_kwargs=None
):
    """Draw one oriented, filled rectangle with a black outline.

    Args:
        ax: Matplotlib axes to draw on.
        pos: Centre of the rectangle; ``pos[0]`` is x and ``pos[1]`` is y.
        heading: Orientation in radians.
        width, length: Rectangle extents across and along the heading.
        fill_color: Fill colour, also used for the optional label.
        text: Optional label placed at ``pos``.
        fontsize: Font size of the label.
        position_kwargs: Extra keyword arguments for ``ax.fill``; its
            ``color`` entry is always replaced by ``fill_color``.
        contour_kwargs: Extra keyword arguments for the outline. They may
            override the default black colour and line width of 1.
    """
    # Work on copies so the caller's dictionaries are never modified.
    fill_style = dict(position_kwargs or {})
    fill_style["color"] = fill_color
    # Merge instead of passing the defaults as explicit keywords, so a caller
    # supplied ``color``/``linewidth`` overrides them rather than raising a
    # duplicate-keyword TypeError.
    outline_style = {"color": "black", "linewidth": 1, **(contour_kwargs or {})}

    contour = cal_polygon_contour(
        x=np.array([pos[0]]), y=np.array([pos[1]]), theta=np.array([heading]), width=width, length=length
    )[0]

    ax.fill(contour[:, 0], contour[:, 1], **fill_style)

    # Repeat the first corner to close the outline. The closed path already
    # contains every edge, so it is drawn only once.
    contour_closed = np.concatenate([contour, contour[:1]], axis=0)
    ax.plot(contour_closed[:, 0], contour_closed[:, 1], **outline_style)

    if text is not None:
        ax.text(pos[0], pos[1], text, color=fill_color, fontsize=fontsize)
def draw_crosswalk(ax, polygon, mask, alpha=0.1):
    """Add a translucent gray patch for the valid vertices of a crosswalk."""
    valid_vertices = polygon[mask]
    patch = Polygon(valid_vertices, closed=True, facecolor='gray', alpha=alpha, linewidth=0.5)
    ax.add_patch(patch)


def draw_stop_sign(ax, x, y, r=3):
    """Add a red octagon with a black edge, centred at ``(x, y)``, radius ``r``."""
    print("Draw stop sign triggered")

    num_sides = 8
    step = 2 * np.pi / num_sides

    # Offset by half a step so the octagon rests on a flat edge.
    corners = []
    for k in range(num_sides):
        phi = (k + 0.5) * step
        corners.append((x + r * np.cos(phi), y + r * np.sin(phi)))

    ax.add_patch(Polygon(corners, closed=True, edgecolor='black', facecolor='red'))


def draw_traffic_light(ax, center, fill, radius=1.5, alpha=0.4):
    """Add a black-edged circle filled with ``fill`` and return the patch."""
    light = Circle(center, radius, edgecolor='black', facecolor=fill, alpha=alpha)
    ax.add_patch(light)
    return light
data_dict["vis/map_feature"][:, :, :2] # (num map, num vec, 2) + map_feat = data_dict["vis/map_feature"] + else: + map_pos = data_dict["raw/map_feature"][:, :, :2] # (num map, num vec, 2) + map_feat = data_dict["raw/map_feature"] + + map_mask = data_dict["encoder/map_feature_valid_mask"] # (num map. num vec) + is_lane = map_feat[..., 0, 13] + is_crosswalk = map_feat[..., 0, 22] + is_stop_sign = map_feat[..., 0, 24] + for map_feat_ind in range(map_pos.shape[0]): + use_special_color = False + # if 'decoder/dest_map_index' in data_dict: + # # decoder/dest_map_index is temporal. So only take the dest at t=0. + # for i in [0, 30, 60, 90]: + # if map_feat_ind in data_dict['decoder/dest_map_index'][i]: + # use_special_color = True + + if is_crosswalk[map_feat_ind]: + draw_crosswalk(ax, polygon=map_pos[map_feat_ind], mask=map_mask[map_feat_ind]) + elif is_stop_sign[map_feat_ind]: + draw_stop_sign(ax, map_pos[map_feat_ind][0][0], map_pos[map_feat_ind][0][1]) + elif is_lane[map_feat_ind] and (not dont_draw_lane): + if use_special_color: + draw_2d(map_pos[map_feat_ind], map_mask[map_feat_ind], c=(1, 0, 0, 0.2), linewidth=0.5) + else: + draw_2d(map_pos[map_feat_ind], map_mask[map_feat_ind], c=(0.5, 0.5, 0.5, 0.2), linewidth=0.5) + else: + if use_special_color: + draw_2d(map_pos[map_feat_ind], map_mask[map_feat_ind], c=(1, 0, 0, 0.5), label="map", linewidth=1.0) + else: + draw_2d(map_pos[map_feat_ind], map_mask[map_feat_ind], c=(0.5, 0.5, 0.5, 0.5), label="map", linewidth=1.0) + + +def _plot_traffic_light(data_dict, ax, step=None): + if "encoder/traffic_light_feature" not in data_dict: + if "model/traffic_light_state" in data_dict: + tl_state = data_dict["model/traffic_light_state"] + tl_pos = data_dict["encoder/traffic_light_position"][:, :2] # NT, 3 + tl_mask = data_dict["encoder/traffic_light_valid_mask"] # T, NT + patches = [] + if step is None: + step = 0 + for tl_ind in range(tl_state.shape[1]): + if not tl_mask[step, tl_ind]: + continue + tl_pos_t = tl_pos[tl_ind] + 
tl_state_t = tl_state[step, tl_ind].item() + if tl_state_t == 1: + color = 'green' + elif tl_state_t == 2: + color = 'yellow' + elif tl_state_t == 3: + color = 'red' + elif tl_state_t == 0: + color = 'gray' + else: + raise ValueError + patches.append(draw_traffic_light(ax, tl_pos_t, color)) + return patches + elif "encoder/traffic_light_state" in data_dict: + tl_state = data_dict["encoder/traffic_light_state"] + tl_pos = data_dict["encoder/traffic_light_position"][:, :2] # NT, 3 + tl_mask = data_dict["encoder/traffic_light_valid_mask"] # T, NT + patches = [] + if step is None: + step = 0 + for tl_ind in range(tl_state.shape[1]): + if not tl_mask[step, tl_ind]: + continue + tl_pos_t = tl_pos[tl_ind] + tl_state_t = tl_state[step, tl_ind].item() + if tl_state_t == 1: + color = 'green' + elif tl_state_t == 2: + color = 'yellow' + elif tl_state_t == 3: + color = 'red' + elif tl_state_t == 0: + color = 'gray' + else: + raise ValueError + patches.append(draw_traffic_light(ax, tl_pos_t, color)) + return patches + else: + raise ValueError("No traffic light data found in the data_dict") + + + tl_state = data_dict["encoder/traffic_light_feature"] # T, NT, 7 + tl_pos = data_dict["encoder/traffic_light_position"][:, :2] # NT, 3 + tl_mask = data_dict["encoder/traffic_light_valid_mask"] # T, NT + if tl_mask.ndim == 1: + tl_mask = tl_mask[None] + step = 0 + if tl_state.ndim == 2: + tl_state = tl_state[None] + step = 0 + if step is None: + step = 0 + patches = [] + for tl_ind in range(tl_state.shape[1]): + if not tl_mask[0, tl_ind]: + continue + tl_pos_t = tl_pos[tl_ind] + tl_state_t = tl_state[step, tl_ind] + if tl_state_t[3] == 1: + color = 'green' + elif tl_state_t[4] == 1: + color = 'yellow' + elif tl_state_t[5] == 1: + color = 'red' + elif tl_state_t[6] == 1: + color = 'gray' + else: + continue + patches.append(draw_traffic_light(ax, tl_pos_t, color)) + return patches + + +def _plot_gt(data_dict, ax, draw_line=False, draw_text=True): + agent_pos = 
data_dict["decoder/agent_position"][:, :, :2] # (91, N, 2) + agent_heading = data_dict["decoder/agent_heading"] # (91, N, 2) + agent_velocity = data_dict["decoder/agent_velocity"] # (91, N, 2) + agent_shape = data_dict["decoder/agent_shape"] # (91, N, 2) + agent_mask = data_dict["decoder/agent_valid_mask"] + ego_agent_id = data_dict['decoder/sdc_index'] + + _plot_map(data_dict, ax) + + T, N, _ = agent_pos.shape + + modeled_agents_indicies = np.concatenate([data_dict["decoder/object_of_interest_id"], np.atleast_1d(ego_agent_id)]) + + cmap = sns.color_palette("colorblind", n_colors=N) + plotted_count = 0 + draw_trajectory( + ax=ax, + pos=agent_pos[:, ego_agent_id], + heading=agent_heading[:, ego_agent_id], + width=agent_shape[:, ego_agent_id, 1], + length=agent_shape[:, ego_agent_id, 0], + mask=agent_mask[:, ego_agent_id], + fill_color=cmap[0], + traj_kwargs=dict(), + contour_kwargs=dict( + edgecolor="k", + linewidth=0.1, + fill=False, + ), + text="{}-SDC".format(str(ego_agent_id)), + fontsize=EGO_FONT_SIZE, + draw_line=draw_line, + draw_text=draw_text, + ) + plotted_count += 1 + + for agent_ind in range(N): + if agent_ind == ego_agent_id: + continue + if agent_ind in modeled_agents_indicies: + text = "{}-OOI".format(str(agent_ind)) + fontsize = MODELED_FONT_SIZE + else: + text = str(agent_ind) + fontsize = NON_EGO_FONT_SIZE + draw_trajectory( + ax=ax, + pos=agent_pos[:, agent_ind], + heading=agent_heading[:, agent_ind], + width=agent_shape[:, agent_ind, 1], + length=agent_shape[:, agent_ind, 0], + mask=agent_mask[:, agent_ind], + fill_color=cmap[plotted_count], + traj_kwargs=dict(), + contour_kwargs=dict( + edgecolor="k", + linewidth=0.1, + fill=False, + ), + text=text, + fontsize=fontsize, + draw_line=draw_line, + draw_text=draw_text + ) + plotted_count += 1 + + _plot_traffic_light(data_dict, ax) + + if "vis/map_feature" in data_dict: + map_pos = data_dict["vis/map_feature"][:, :, :2][data_dict["encoder/map_feature_valid_mask"]] + else: + map_pos = 
data_dict["encoder/map_position"][..., :2][data_dict["encoder/map_valid_mask"]] + return get_limit(agent_pos=agent_pos[agent_mask], map_pos=map_pos) + + +def plot_gt(data_dict, get_info=False, save_path=None): + fig = plt.figure(figsize=(10, 10), dpi=300) + ax = fig.add_subplot(111) + ax.set_aspect(1) + + ret = _plot_gt(data_dict, ax) + + xmin, xmax, ymin, ymax = ret["xmin"], ret["xmax"], ret["ymin"], ret["ymax"] + + ax.set_xlim(xmin - BOUNDARY, xmax + BOUNDARY) + ax.set_ylim(ymin - BOUNDARY, ymax + BOUNDARY) + ax.set_aspect(1) + fig.tight_layout(pad=0.05) + fig.canvas.draw() + + # plt.show() + canvas = fig.canvas + w, h = canvas.get_width_height() + if hasattr(canvas, "tostring_rgb"): + ret = PIL.Image.frombytes("RGB", (w, h), canvas.tostring_rgb()) + else: + # Matplotlib>=3.10 removed tostring_rgb on Agg; use the RGBA buffer. + rgba = np.asarray(canvas.buffer_rgba()) + ret = PIL.Image.fromarray(rgba[..., :3], mode="RGB") + + if save_path: + fig.savefig(save_path) + + plt.close(fig) + + if get_info: + # TODO: What is info dict??? 
+ return ret, {} + return ret + + +def plot_pred(data_dict, show=False, save_path=None): + + fig = plt.figure(figsize=(10, 10), dpi=300) + ax = fig.add_subplot(111) + ax.set_aspect(1) + + _plot_map(data_dict, ax) + + agent_pos = data_dict["decoder/reconstructed_position"][:, :, :2] # (91, N, 2) + agent_heading = data_dict["decoder/reconstructed_heading"] # (91, N, 2) + # agent_velocity = data_dict["decoder/agent_velocity"] # (91, N, 2) + # agent_shape = data_dict["decoder/agent_shape"][10] # TODO hardcoded + + agent_mask = data_dict["decoder/reconstructed_valid_mask"] + if 'decoder/sdc_index' in data_dict: + ego_agent_id = data_dict['decoder/sdc_index'] + else: + ego_agent_id = 0 + + T, N, _ = agent_pos.shape + + agent_shape = data_dict["decoder/current_agent_shape"] + agent_shape = np.tile(agent_shape[None], (T, 1, 1)) + + if "decoder/object_of_interest_id" in data_dict: + modeled_agents_indicies = np.concatenate( + [data_dict["decoder/object_of_interest_id"], + np.atleast_1d(ego_agent_id)] + ) + else: + modeled_agents_indicies = [] + + cmap = sns.color_palette("colorblind", n_colors=N) + plotted_count = 0 + draw_trajectory( + ax=ax, + pos=agent_pos[:, ego_agent_id], + heading=agent_heading[:, ego_agent_id], + width=agent_shape[:, ego_agent_id, 1], + length=agent_shape[:, ego_agent_id, 0], + mask=agent_mask[:, ego_agent_id], + fill_color=cmap[0], + traj_kwargs=dict(), + contour_kwargs=dict( + edgecolor="k", + linewidth=0.1, + fill=False, + ), + text="{}-SDC".format(str(ego_agent_id)), + fontsize=EGO_FONT_SIZE, + ) + plotted_count += 1 + + for agent_ind in range(N): + if agent_ind == ego_agent_id: + continue + if agent_ind in modeled_agents_indicies: + text = "{}-OOI".format(str(agent_ind)) + fontsize = MODELED_FONT_SIZE + else: + text = str(agent_ind) + fontsize = NON_EGO_FONT_SIZE + draw_trajectory( + ax=ax, + pos=agent_pos[:, agent_ind], + heading=agent_heading[:, agent_ind], + width=agent_shape[:, agent_ind, 1], + length=agent_shape[:, agent_ind, 0], + 
mask=agent_mask[:, agent_ind], + fill_color=cmap[plotted_count], + traj_kwargs=dict(), + contour_kwargs=dict( + edgecolor="k", + linewidth=0.1, + fill=False, + ), + text=text, + fontsize=fontsize, + ) + plotted_count += 1 + + _plot_traffic_light(data_dict, ax) + + p = agent_pos[agent_mask] + xmax, ymax = p.max(0) + xmin, ymin = p.min(0) + ax.set_xlim(xmin - BOUNDARY, xmax + BOUNDARY) + ax.set_ylim(ymin - BOUNDARY, ymax + BOUNDARY) + ax.set_aspect(1) + fig.tight_layout(pad=0.05) + fig.canvas.draw() + + if show: + plt.show() + canvas = fig.canvas + w, h = canvas.get_width_height() + if hasattr(canvas, "tostring_rgb"): + ret = PIL.Image.frombytes("RGB", (w, h), canvas.tostring_rgb()) + else: + # Matplotlib>=3.10 removed tostring_rgb on Agg; use the RGBA buffer. + rgba = np.asarray(canvas.buffer_rgba()) + ret = PIL.Image.fromarray(rgba[..., :3], mode="RGB") + if not show: + plt.close(fig) + + if save_path: + # save figure + fig.savefig(save_path) + return ret + + +def _animate( + save_path, agent_pos, agent_heading, agent_mask, agent_shape, data_dict, fps=10, dpi=300, draw_traffic=True +): + # TODO: Agent mask is not considered. 
+ + # all_agent_pos = data_dict["decoder/agent_position"][:91, :, :2] + # all_agent_heading = data_dict["decoder/agent_heading"] + # all_agent_shape = data_dict["decoder/agent_shape"][10] + if "decoder/labeled_agent_id" in data_dict: + ooi = data_dict["decoder/labeled_agent_id"] + else: + ooi = [] + + if 'decoder/sdc_index' in data_dict: + ego_agent_id = int(data_dict['decoder/sdc_index']) + else: + ego_agent_id = 0 + + assert agent_pos.ndim == 3 + T = agent_pos.shape[0] # Number of timesteps + N = agent_pos.shape[1] # Number of agents + + cmap = sns.color_palette("colorblind", n_colors=N) # Color for each agent + + all_agent_positions = agent_pos[:, :, ...].reshape(-1, 2) + xmin, ymin = all_agent_positions.min(axis=0) + xmax, ymax = all_agent_positions.max(axis=0) + xlim, ylim = (xmin - 10, xmax + 10), (ymin - 10, ymax + 10) # Adjust `BOUNDARY` as needed + + writer = FFMpegWriter(fps=fps, codec='libx264', extra_args=['-preset', 'ultrafast', '-crf', '23', '-threads', '4']) + fig, ax = plt.subplots(figsize=(10, 10), dpi=dpi) + ax.set_aspect(1) + ax.set_xlim(xlim) + ax.set_ylim(ylim) + + _plot_map(data_dict, ax, dont_draw_lane=True) + _plot_traffic_light(data_dict, ax) + + agent_patches = [] + agent_texts = [] + for agent_ind in range(N): + if not draw_traffic and agent_ind not in ooi: + continue + face_color = cmap[0] if agent_ind == ego_agent_id else cmap[agent_ind] + label = "{}-SDC".format(ego_agent_id) if agent_ind == ego_agent_id else \ + "{}-OOI".format(agent_ind) if agent_ind in ooi else str(agent_ind) + + # Create a rectangular patch for each agent with black edge + rect = Rectangle( + (0, 0), + agent_shape[agent_ind, 0], + agent_shape[agent_ind, 1], + facecolor=face_color, + edgecolor='black', + linewidth=0.6, + zorder=10 + ) + agent_patches.append(rect) + ax.add_patch(rect) + + text = ax.text(0, 0, label, color=face_color, fontsize=11, ha='center', va='center', zorder=15) + agent_texts.append(text) + + with writer.saving(fig, save_path, dpi=dpi): + for t 
in range(T): + pos = agent_pos[t] # update agent positions and labels for each frame + heading = agent_heading[t] + + for agent_ind, (rect, text) in enumerate(zip(agent_patches, agent_texts)): + x, y = pos[agent_ind] + + if not agent_mask[t, agent_ind]: + rect.set_visible(False) + text.set_visible(False) + x = -10000 + y = -10000 + else: + rect.set_visible(True) + text.set_visible(True) + + raise ValueError("wrong animation") + rect.set_xy((x - agent_shape[agent_ind, 0] / 2, y - agent_shape[agent_ind, 1] / 2)) + rect.angle = np.degrees(heading[agent_ind]) + + rect.set_edgecolor('black') + rect.set_linewidth(0.8) + + text.set_position((x, y)) + text.set_text(text.get_text()) # forces the text to render + + writer.grab_frame() + + +def _animate_trafficgen( + save_path, + agent_pos, + agent_heading, + agent_mask, + agent_shape, + labeled_agent_id, + ego_agent_id, + data_dict, + fps=10, + dpi=300, + draw_traffic=True +): + # TODO: Agent mask is not considered. + + # all_agent_pos = data_dict["decoder/agent_position"][:91, :, :2] + # all_agent_heading = data_dict["decoder/agent_heading"] + # all_agent_shape = data_dict["decoder/agent_shape"][10] + # ooi = data_dict["decoder/labeled_agent_id"] + # ego_agent_id = int(data_dict['decoder/sdc_index']) + + assert agent_pos.ndim == 2 + # T = agent_pos.shape[0] # Number of timesteps + N = agent_pos.shape[0] # Number of agents + + cmap = sns.color_palette("colorblind", n_colors=N) # Color for each agent + + all_agent_positions = agent_pos.reshape(-1, 2) + xmin, ymin = all_agent_positions.min(axis=0) + xmax, ymax = all_agent_positions.max(axis=0) + xlim, ylim = (xmin - 10, xmax + 10), (ymin - 10, ymax + 10) # Adjust `BOUNDARY` as needed + + writer = FFMpegWriter(fps=fps, codec='libx264', extra_args=['-preset', 'ultrafast', '-crf', '23', '-threads', '4']) + fig, ax = plt.subplots(figsize=(10, 10), dpi=dpi) + ax.set_aspect(1) + ax.set_xlim(xlim) + ax.set_ylim(ylim) + + _plot_map(data_dict, ax, dont_draw_lane=True) + 
_plot_traffic_light(data_dict, ax) + + agent_patches = [] + agent_texts = [] + + with writer.saving(fig, save_path, dpi=dpi): + + for agent_ind in range(N): + if not draw_traffic and agent_ind not in labeled_agent_id: + continue + face_color = cmap[0] if agent_ind == ego_agent_id else cmap[agent_ind] + label = "{}-SDC".format(ego_agent_id) if agent_ind == ego_agent_id else \ + "{}-OOI".format(agent_ind) if agent_ind in labeled_agent_id else str(agent_ind) + + # Create a rectangular patch for each agent with black edge + rect = Rectangle( + (0, 0), + agent_shape[agent_ind, 0], + agent_shape[agent_ind, 1], + facecolor=face_color, + edgecolor='black', + linewidth=0.6, + zorder=10 + ) + agent_patches.append(rect) + ax.add_patch(rect) + + text = ax.text(0, 0, label, color=face_color, fontsize=11, ha='center', va='center', zorder=15) + agent_texts.append(text) + + x, y = agent_pos[agent_ind] + + raise ValueError("wrong animation") + rect.set_xy((x - agent_shape[agent_ind, 0] / 2, y - agent_shape[agent_ind, 1] / 2)) + rect.angle = np.degrees(agent_heading[agent_ind]) + + rect.set_edgecolor('black') + rect.set_linewidth(0.8) + + text.set_position((x, y)) + text.set_text(text.get_text()) # forces the text to render + + for _ in range(5): + writer.grab_frame() + + # with writer.saving(fig, save_path, dpi=dpi): + # for t in range(T): + # pos = agent_pos[t] # update agent positions and labels for each frame + # heading = agent_heading[t] + # + # for agent_ind, (rect, text) in enumerate(zip(agent_patches, agent_texts)): + # x, y = pos[agent_ind] + # + # if not agent_mask[t, agent_ind]: + # rect.set_visible(False) + # text.set_visible(False) + # x = -10000 + # y = -10000 + # else: + # rect.set_visible(True) + # text.set_visible(True) + #raise ValueError("wrong animation") + # rect.set_xy((x - agent_shape[agent_ind, 0] / 2, y - agent_shape[agent_ind, 1] / 2)) + # rect.angle = np.degrees(heading[agent_ind]) + # + # rect.set_edgecolor('black') + # rect.set_linewidth(0.8) + # + # 
text.set_position((x, y)) + # text.set_text(text.get_text()) # forces the text to render + # + # writer.grab_frame() + + +def create_animation_from_gt(data_dict, save_path='gt_animation.mp4', fps=10, dpi=300, draw_traffic=True): + _animate( + save_path=save_path, + agent_pos=data_dict["decoder/agent_position"][:91, :, :2], + agent_mask=data_dict["decoder/agent_valid_mask"], + agent_heading=data_dict["decoder/agent_heading"], + agent_shape=data_dict["decoder/current_agent_shape"], + data_dict=data_dict, + dpi=dpi, + draw_traffic=draw_traffic, + fps=fps, + ) + print(f"MP4 video saved at {save_path}") + plt.close() + return save_path + + +def create_animation_from_pred(data_dict, save_path='pred_animation.mp4', fps=10, dpi=300, draw_traffic=True): + all_agent_shape = data_dict["decoder/current_agent_shape"] + _animate( + save_path=save_path, + agent_pos=data_dict["decoder/reconstructed_position"], + agent_mask=data_dict["decoder/reconstructed_valid_mask"], + agent_heading=data_dict["decoder/reconstructed_heading"], + agent_shape=all_agent_shape, + data_dict=data_dict, + dpi=dpi, + draw_traffic=draw_traffic, + fps=fps, + ) + print(f"MP4 video saved at {save_path}") + plt.close() + return save_path + + +def create_animation_from_trafficgen(data_dict, save_path='pred_animation.mp4', fps=10, dpi=300, draw_traffic=True): + all_agent_shape = data_dict["decoder/current_agent_shape"] + + # The first agent is in TrafficGen so we delete it. + num_tg_agents = data_dict["decoder/modeled_agent_position_for_trafficgen"].shape[0] - 1 + + _animate_trafficgen( + save_path=save_path, + agent_pos=data_dict["decoder/modeled_agent_position_for_trafficgen"][1:], + agent_mask=np.ones(num_tg_agents, dtype=bool), + agent_heading=data_dict["decoder/modeled_agent_heading_for_trafficgen"][1:], + labeled_agent_id=[], # TODO: ???? + ego_agent_id=0, # TODO: ???? 
+ agent_shape=all_agent_shape, + data_dict=data_dict, + dpi=dpi, + draw_traffic=draw_traffic, + fps=fps, + ) + print(f"MP4 video saved at {save_path}") + plt.close() + return save_path + + +@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="motion_default.yaml") +def debug(config): + omegaconf.OmegaConf.set_struct(config, False) + config.PREPROCESSING.keep_all_data = True + omegaconf.OmegaConf.set_struct(config, True) + test_dataset = SceneStreamerDataset(config, "test") + ddd = iter(test_dataset) + while True: + try: + raw_data = data = next(ddd) + + from scenestreamer.tokenization import get_tokenizer + tokenizer = get_tokenizer(config) + data, _ = tokenizer.tokenize_numpy_array(data) + data["decoder/output_action"] = data["decoder/target_action"] + fill_zero = ~data["decoder/target_action_valid_mask"] + data["decoder/input_action_valid_mask"][fill_zero] = False + + data = tokenizer.detokenize_numpy_array(data, detokenizing_gt=True) + raw_data.update(data) + # plot_pred(raw_data) + plot_pred(raw_data, show=True) + + # break + except StopIteration: + break + print("End") + + +@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="motion_default.yaml") +def debug_backward_prediction(config): + omegaconf.OmegaConf.set_struct(config, False) + config.PREPROCESSING.keep_all_data = True + omegaconf.OmegaConf.set_struct(config, True) + test_dataset = SceneStreamerDataset(config, "test") + ddd = iter(test_dataset) + while True: + try: + raw_data = data = next(ddd) + + from scenestreamer.tokenization import get_tokenizer + tokenizer = get_tokenizer(config) + data, _ = tokenizer.tokenize_numpy_array(data, backward_prediction=True) + data["decoder/output_action"] = data["decoder/target_action"] + fill_zero = ~data["decoder/target_action_valid_mask"] + data["decoder/input_action_valid_mask"][fill_zero] = False + + data = tokenizer.detokenize_numpy_array(data, detokenizing_gt=True, backward_prediction=True) + 
raw_data.update(data) + # plot_pred(raw_data) + plot_pred(raw_data, show=True) + + # break + except StopIteration: + break + print("End") + + +def run_backward_prediction_with_teacher_forcing( + model, config, backward_input_dict, tokenizer, not_teacher_forcing_ids +): + device = backward_input_dict["decoder/agent_position"].device + + # Force to run backward prediction first to make sure the data is tokenized correctly. + tok_data_dict, _ = tokenizer.tokenize(backward_input_dict, backward_prediction=True) + backward_input_dict.update(tok_data_dict) + + backward_input_dict["in_evaluation"] = torch.tensor([1], dtype=bool).to(device) + backward_input_dict["in_backward_prediction"] = torch.tensor([1], dtype=bool).to(device) + with torch.no_grad(): + ar_func = model.model.autoregressive_rollout_backward_prediction_with_replay + # ar_func = model.model.autoregressive_rollout_backward_prediction + backward_output_dict = ar_func( + backward_input_dict, + num_decode_steps=None, + sampling_method=config.SAMPLING.SAMPLING_METHOD, + temperature=config.SAMPLING.TEMPERATURE, + not_teacher_forcing_ids=not_teacher_forcing_ids, + ) + backward_output_dict = tokenizer.detokenize( + backward_output_dict, + detokenizing_gt=False, + backward_prediction=True, + flip_wrong_heading=True, + ) + return backward_output_dict + + +def run_forward_prediction_with_teacher_forcing(model, config, forward_input_dict, tokenizer, teacher_forcing_ids): + device = forward_input_dict["decoder/agent_position"].device + + # Force to run backward prediction first to make sure the data is tokenized correctly. 
+ f_tok_data_dict, _ = tokenizer.tokenize(forward_input_dict, backward_prediction=False) + forward_input_dict.update(f_tok_data_dict) + + forward_input_dict["in_evaluation"] = torch.tensor([1], dtype=bool).to(device) + forward_input_dict["in_backward_prediction"] = torch.tensor([0], dtype=bool).to(device) + with torch.no_grad(): + forward_output_dict = model.model.autoregressive_rollout_with_replay( + forward_input_dict, + sampling_method=config.SAMPLING.SAMPLING_METHOD, + temperature=config.SAMPLING.TEMPERATURE, + backward_prediction=False, + teacher_forcing_ids=teacher_forcing_ids, + ) + forward_output_dict = tokenizer.detokenize( + forward_output_dict, detokenizing_gt=False, backward_prediction=False, flip_wrong_heading=True + ) + return forward_output_dict + + +@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="1031_midgpt.yaml") +def debug_run_model(config): + # Load model + from scenestreamer.utils import utils + import copy + path = "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/1104_MidGPT_NoAgnt_WTLSgl_WContRel_WBackward_FixedStepAgentID_2024-11-04_2208/checkpoints/last.ckpt" + # "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/1107_MidGPT_Bicycle_V3_WBackward_2024-11-09_2215/checkpoints/last.ckpt" + + omegaconf.OmegaConf.set_struct(config, False) + config.PREPROCESSING.keep_all_data = True + config.pretrain = "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/1104_MidGPT_NoAgnt_WTLSgl_WContRel_WBackward_FixedStepAgentID_2024-11-04_2208/checkpoints/last.ckpt" + # "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/1107_MidGPT_Bicycle_V3_WBackward_2024-11-09_2215/checkpoints/last.ckpt" + # "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/1104_MidGPT_NoAgnt_WTLSgl_WContRel_WBackward_FixedStepAgentID_2024-11-04_2208/checkpoints/last.ckpt" + config.BACKWARD_PREDICTION = True # <<< + config.ADD_CONTOUR_RELATION = True + config.DATA.TRAINING_DATA_DIR = 
"/bigdata/yuxin/waymo_validation_interactive_500" #"data/20scenarios" + + # config.TOKENIZATION.TOKENIZATION_METHOD="bicycle" + # config.TOKENIZATION.NUM_BINS=33 + # config.DELTA_POS_IS_VELOCITY=True + # config.TOKENIZATION.ADD_NOISE=False + omegaconf.OmegaConf.set_struct(config, True) + + # model = utils.get_model(config=config, ) + model = utils.get_model(checkpoint_path=path) + + import torch + model = model.to("cuda") + + model = utils.get_model(config, device="cuda") + device = model.device + + test_dataset = SceneStreamerDataset(config, "training") + ddd = iter(test_dataset) + + flip_heading_accordingly = True + backward_prediction = True + + while True: + try: + raw_data_dict = data_dict = next(ddd) + + # Create a new ADV in the data so backward prediction will help us generate it. + + # TODO: If we also want to TF ego, then we should not overwrite ego data. + sdc_id = data_dict["decoder/sdc_index"] + + data_dict, adv_id = set_adv(data_dict) + # data_dict = create_new_adv(data_dict) + + data_dict = utils.numpy_to_torch(data_dict, device=device) + # data_dict = { + # k: torch.from_numpy(v).to(device) if isinstance(v, np.ndarray) else v + # for k, v in data_dict.items() + # } + data_dict = {k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()} + data_dict = copy.deepcopy(data_dict) + + all_agents = data_dict["decoder/agent_id"][0] + not_tf_ids = all_agents[all_agents != 0] + + for iteration in range(1): + print("====================================") + print("Iteration: ", iteration) + print("====================================") + + if iteration > 0: + backward_input_dict = _overwrite_data_given_agents_not_ooi( + data_dict, forward_output_dict, [0, int(adv_id)] + #int(backward_output_dict_numpy['decoder/agent_id'][-1].item())] + ) + + else: + backward_input_dict = copy.deepcopy(data_dict) + + # import pdb; pdb.set_trace() + + backward_output_dict = run_backward_prediction_with_teacher_forcing( + model=model, + config=config, 
+ backward_input_dict=backward_input_dict, + tokenizer=tokenizer, + + # TODO: Which to TF? + not_teacher_forcing_ids=not_tf_ids + ) + + scenario_id = backward_output_dict['metadata/scenario_id'] + + # ===== Only used for vis ===== + backward_output_dict_numpy = { + k: (v.squeeze(0).cpu().numpy() if isinstance(v, torch.Tensor) else v) + for k, v in backward_output_dict.items() + } + + # backward_output_dict_numpy, valid = post_process_adv_traj(backward_output_dict_numpy, adv_id, sdc_id) + # if not valid: + # continue + + backward_video_path = '/bigdata/yuxin/backward_validation_interactive_500/backward_TF_no_ego_{}.mp4'.format( + scenario_id + ) + create_animation_from_pred(backward_output_dict_numpy, save_path=backward_video_path, dpi=100, fps=20) + + continue + # plot_pred(backward_output_dict_numpy, show=True) + # ===== Only used for vis ===== + + forward_input_dict = _overwrite_data_given_agents_ooi( + data_dict, + backward_output_dict, + + # TODO: Which to TF? + [0, int(backward_output_dict_numpy['decoder/agent_id'][-1].item())] + # [int(backward_output_dict_numpy['decoder/agent_id'][-1].item())] + ) + forward_output_dict = run_forward_prediction_with_teacher_forcing( + model=model, + config=config, + forward_input_dict=forward_input_dict, + tokenizer=tokenizer, + + # TODO: Which to TF? 
+ teacher_forcing_ids=[adv_id] + # teacher_forcing_ids=[-1] + ) + + # ===== Only used for vis ===== + forward_output_dict_numpy = { + k: (v.squeeze(0).cpu().numpy() if isinstance(v, torch.Tensor) else v) + for k, v in forward_output_dict.items() + } + forward_video_path = 'data/vis_20scenarios_nearestADV/forward_adv_animation_{}.mp4'.format(scenario_id) + create_animation_from_pred(forward_output_dict_numpy, save_path=forward_video_path, dpi=100, fps=20) + # plot_pred(forward_output_dict_numpy, show=True) + print("====================================") + print("Please launch:") + print("google-chrome {}".format(pathlib.Path(backward_video_path).absolute())) + print("google-chrome {}".format(pathlib.Path(forward_video_path).absolute())) + print("====================================") + # ===== Only used for vis ===== + + # break + except StopIteration: + break + print("End") + + +if __name__ == '__main__': + debug() + # debug_backward_prediction() + # debug_run_model() diff --git a/scenestreamer/gradio_ui/plot_video.py b/scenestreamer/gradio_ui/plot_video.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4c799e55b8262655428e20168d0f04dd48ed8a --- /dev/null +++ b/scenestreamer/gradio_ui/plot_video.py @@ -0,0 +1,62 @@ +import copy +import tempfile +import pickle +import os +from scenestreamer.utils import REPO_ROOT +from scenestreamer.gradio_ui.metadrive_render import render +import subprocess + + +def plot_gt_video(scenario): + video_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False, dir=os.environ['GRADIO_TEMP_DIR']) + + with tempfile.TemporaryDirectory(dir=os.environ['GRADIO_TEMP_DIR']) as in_dir: + in_pickle = tempfile.NamedTemporaryFile(suffix='.pkl', prefix='sd_', delete=True, dir=in_dir) + pickle.dump(scenario, in_pickle) + subprocess.run( + [ + 'python', REPO_ROOT / 'scenestreamer/gradio_ui/metadrive_render.py', '--input_dir', in_dir, '--output_path', + video_file.name + ] + ) + # render(in_dir, video_file.name) + return 
video_file + + +def plot_pred_video( + scenario, output_dict, agents_safe_0, agents_safe_1, agents_turn_stop, agents_turn_straight, agents_turn_left, + agents_turn_right, agents_turn_uturn +): + scenario_description_copy = copy.deepcopy(scenario) + agent_id_union = agents_safe_0 + agents_safe_1 + agents_turn_stop + agents_turn_straight + agents_turn_left + agents_turn_right + agents_turn_uturn + agent_id_union = set(agent_id_union) + id_track_map = { + agent_id: track_id + for agent_id, track_id in + zip(output_dict["encoder/agent_id"].tolist(), output_dict["encoder/track_name"].tolist()) + } + agent_track_map = {k: v for k, v in id_track_map.items() if k in agent_id_union} + + for agent_id, track in agent_track_map.items(): + # reconstructed position only has dim 2 + track_state = scenario_description_copy["tracks"][str(track)]['state'] + track_state['position'][11:, :2] = output_dict["decoder/reconstructed_position"][:, agent_id, :] + output_dict[ + "metadata/map_center"][:2] + track_state['velocity'][11:] = output_dict["decoder/reconstructed_velocity"][:, agent_id, :] + track_state['heading'][11:] = output_dict["decoder/reconstructed_heading"][:, agent_id] + track_state['valid'][11:] = output_dict["decoder/interpolated_target_action_valid_mask"][:, agent_id] + + video_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False, dir=os.environ['GRADIO_TEMP_DIR']) + + with tempfile.TemporaryDirectory(dir=os.environ['GRADIO_TEMP_DIR']) as in_dir: + in_pickle = tempfile.NamedTemporaryFile(suffix='.pkl', prefix='sd_', delete=True, dir=in_dir) + pickle.dump(scenario_description_copy, in_pickle) + subprocess.run( + [ + 'python', REPO_ROOT / 'scenestreamer/gradio_ui/metadrive_render.py', '--input_dir', in_dir, '--output_path', + video_file.name + ] + ) + # render(in_dir, video_file.name) + + return video_file diff --git a/scenestreamer/gradio_ui/scenestreamer_ui.py b/scenestreamer/gradio_ui/scenestreamer_ui.py new file mode 100644 index 
0000000000000000000000000000000000000000..e807021f21eec2680bf3bb334c4f4b543946b45e --- /dev/null +++ b/scenestreamer/gradio_ui/scenestreamer_ui.py @@ -0,0 +1,309 @@ +import os +import pathlib +import pickle +import gradio as gr + +from scenestreamer.dataset.preprocessor import preprocess_scenario_description_for_motionlm +from scenestreamer.tokenization import get_tokenizer + +from hydra import initialize_config_dir, compose +from omegaconf import OmegaConf +from pathlib import Path +import uuid +from scenestreamer.utils import REPO_ROOT + +# --- Configuration --- +DEFAULT_DATA_PATH = "data/20scenarios" # Adjust as needed +DEFAULT_CONFIG_NAME = "0220_midgpt.yaml" +DEFAULT_MODEL = "/home/zhenghao/scenestreamer/lightning_logs/scenestreamer/0226_MidGPT_V19_WTG_2025-02-26/" + +# Load config with Hydra +config_path = REPO_ROOT / "cfgs" +with initialize_config_dir(config_dir=str(config_path), version_base=None): + config_dict = compose(config_name=DEFAULT_CONFIG_NAME) +DEFAULT_CONFIG = config_dict + + +# --- Utility Functions --- +def list_map_files(dataset_path: str): + """List all .pkl files in the dataset folder.""" + p = pathlib.Path(dataset_path) + return sorted([str(f.relative_to(p)) for f in p.glob("*.pkl")]) + + +def load_first_map(dataset_path: str): + """Return the first map file in the folder, if available.""" + files = list_map_files(dataset_path) + return files[0] if files else None + + +def load_and_visualize_map(selected_file, dataset_path, state): + """ + Loads a selected map file (or auto-selects the first one if none is selected), + updates the state, and visualizes the map. + """ + # If no file is selected, auto-select the first one. + if not selected_file: + selected_file = load_first_map(dataset_path) + if not selected_file: + return "No map files found", state, None + + # Build the absolute file path. 
+ ds_path = pathlib.Path(dataset_path).resolve() + file_path = ds_path / selected_file + if not file_path.exists(): + return "File not found", state, None + + # Load the map data (assumed to be a pickle file). + try: + with open(file_path, "rb") as f: + scenario = pickle.load(f) + except Exception as e: + return f"Failed to load file: {e}", state, None + + data_dict = preprocess_scenario_description_for_motionlm( + scenario=scenario, + config=DEFAULT_CONFIG, + in_evaluation=True, + keep_all_data=True, + tokenizer=get_tokenizer(config=DEFAULT_CONFIG) + ) + + # Update the state with the loaded map. + state["selected_map"] = {"path": str(file_path), "scenario": scenario} + + # Visualize the map using your existing plotting routine. + from scenestreamer.gradio_ui.plot import plot_gt + result = plot_gt(data_dict) + img = result[0] if isinstance(result, tuple) else result + + return f"Loaded map: {file_path.name}", state, img + + +# --- Generation Functions --- +def generate_initial_states(state): + """ + Generate initial states and return a status message along with the + visualization image obtained via plot_gt (with get_info=True). + """ + if "selected_map" not in state or not state["selected_map"]: + return "No map loaded. 
Please select a map first.", state, None + + from scenestreamer.infer.initial_state import generate_initial_state, convert_initial_states_as_motion_data + from scenestreamer.utils import utils + + force_add = False + scenario = state["selected_map"]["scenario"] + config = state["config"] + pl_model = state["model"] + + data_dict = preprocess_scenario_description_for_motionlm( + scenario=scenario, + config=config, + in_evaluation=True, + keep_all_data=False, + tokenizer=get_tokenizer(config=config) + ) + data_dict = utils.batch_data(utils.numpy_to_torch(data_dict, device=pl_model.device)) + + data_dict, _ = generate_initial_state(data_dict=data_dict, model=pl_model.model, force_add=force_add) + + data_dict = convert_initial_states_as_motion_data(data_dict) + + state["initial_state_output_data_dict"] = data_dict + + unbatched_data = utils.unbatch_data(utils.torch_to_numpy(data_dict)) + + # Draw the image using your plot_gt function with get_info=True. + from scenestreamer.gradio_ui.plot import plot_gt + img, info_dict = plot_gt(unbatched_data, get_info=True) + return f"Initial states generated for map {state['selected_map']['path']}", state, img + + +def generate_motions(state, num_decode_steps): + """ + Generate motion predictions and return a status message along with a video. + Here we use create_animation_from_pred as a placeholder for video generation. + """ + if "selected_map" not in state or not state["selected_map"]: + return "No map loaded. 
Please select a map first.", state, None + + from scenestreamer.infer.motion import generate_motion + from scenestreamer.utils import utils + + pl_model = state["model"] + config = state["config"] + data_dict = state["initial_state_output_data_dict"] + + generated_data_dict = generate_motion( + data_dict=data_dict, + model=pl_model.model, + autoregressive_start_step=0, + num_decode_steps=num_decode_steps, + remove_out_of_map_agent=True + ) + + unbatched_data = utils.unbatch_data(utils.torch_to_numpy(generated_data_dict)) + + from scenestreamer.gradio_ui.plot import create_animation_from_pred + video_path = str(REPO_ROOT / "gradio_tmp" / "gt_animation_{}.mp4".format(uuid.uuid4())) + video_path = create_animation_from_pred(unbatched_data, save_path=video_path, dpi=100, fps=10) + # TODO: 0.5 hardcoded. + return f"Motions generated for map {state['selected_map']['path']} with {num_decode_steps * 0.5}s", state, video_path + + +def generate_scenestreamer_motions(state, num_decode_steps): + if "selected_map" not in state or not state["selected_map"]: + return "No map loaded. 
Please select a map first.", state, None + + from scenestreamer.infer.infinite import generate_scenestreamer_motion + from scenestreamer.utils import utils + + pl_model = state["model"] + config = state["config"] + data_dict = state["initial_state_output_data_dict"] + + generated_data_dict = generate_scenestreamer_motion( + data_dict=data_dict, + model=pl_model.model, + autoregressive_start_step=0, + num_decode_steps=num_decode_steps, + remove_out_of_map_agent=True + ) + + unbatched_data = utils.unbatch_data(utils.torch_to_numpy(generated_data_dict)) + + from scenestreamer.gradio_ui.plot import create_animation_from_pred + video_path = str(REPO_ROOT / "gradio_tmp" / "scenestreamer_animation_{}.mp4".format(uuid.uuid4())) + video_path = create_animation_from_pred(unbatched_data, save_path=video_path, dpi=100, fps=10) + return state, video_path + + +# --- Model Checkpoint Loading Function --- +def load_checkpoint(ckpt_path, state): + from scenestreamer.models.motionlm_lightning import MotionLMLightning + from scenestreamer.utils import utils + import torch + import copy + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + ckpt_path = ckpt_path.replace("\\", "") + path = pathlib.Path(ckpt_path) + path = REPO_ROOT / path + + if path.is_dir(): + path = path / "last.ckpt" + + print("Loading model from: ", path.absolute()) + if not path.exists(): + msg = f"{path} does not exist!" + return msg + + try: + model = utils.get_model(config=None, checkpoint_path=path, device=device).eval() + msg = "Model loaded successfully!" + config = model.config + state["model"] = model + state["config"] = config + except Exception as e: + print("Error: ", e) + msg = "Failed to load model!" + return msg + + +# --- Build the Gradio UI --- +with gr.Blocks(title="Map & Motion Generator") as demo: + # Use Gradio state (a dictionary) to store our data. 
+ state = gr.State(value={}) + + with gr.Group("Map Selection") as map_group: + gr.Markdown("### Map Selection and Visualization") + with gr.Row(): + with gr.Column(): + dataset_path_input = gr.Textbox(label="Dataset Folder", value=DEFAULT_DATA_PATH, interactive=True) + with gr.Column(): + file_status = gr.Textbox(label="Status", interactive=False) + with gr.Row(): + with gr.Column(): + file_explorer = gr.FileExplorer( + label="Choose a map (.pkl)", + file_count="single", + root_dir=DEFAULT_DATA_PATH, + glob="**/*.pkl", + interactive=True + ) + with gr.Column(): + map_image = gr.Image(label="Map Visualization") + # Callback to load and visualize map when a file is selected. + file_explorer.change( + load_and_visualize_map, + inputs=[file_explorer, dataset_path_input, state], + outputs=[file_status, state, map_image] + ) + # Automatically load the first map on startup. + demo.load( + load_and_visualize_map, + inputs=[file_explorer, dataset_path_input, state], + outputs=[file_status, state, map_image] + ) + + with gr.Group("Model Checkpoint") as ckpt_group: + gr.Markdown("### Model Checkpoint Loading") + with gr.Row(): + with gr.Column(): + ckpt_input = gr.Textbox(label="Checkpoint Path", value=DEFAULT_MODEL, interactive=True) + with gr.Column(): + ckpt_status = gr.Textbox(label="Checkpoint Status", interactive=False) + with gr.Row(): + ckpt_button = gr.Button("Load Checkpoint") + ckpt_button.click(load_checkpoint, inputs=[ckpt_input, state], outputs=ckpt_status) + # Automatically load the default checkpoint on startup. + demo.load(load_checkpoint, inputs=[ckpt_input, state], outputs=ckpt_status) + + with gr.Group("Generation Controls") as gen_group: + gr.Markdown("### Generation Controls") + with gr.Row(): + gen_initial_button = gr.Button("Generate Initial States") + gen_motions_button = gr.Button("Generate Motions") + with gr.Row(): + # New numeric input for number of decoded steps. 
+ num_decode_steps = gr.Slider(label="Number of Decoded Steps", value=19, interactive=True) + + gen_status = gr.Textbox(label="Generation Status", interactive=False) + + with gr.Group("Output Visualization Canvas"): + gr.Markdown("### Output Visualization") + with gr.Row(): + canvas_image = gr.Image(label="Initial State Image", interactive=False, height=400, width=400) + canvas_video = gr.Video( + label="Motion Video", interactive=False, height=400, width=400, autoplay=True, loop=True + ) + + # Connect generation outputs to the visualization canvases. + # For initial states, update the image canvas. + gen_initial_button.click(generate_initial_states, inputs=state, outputs=[gen_status, state, canvas_image]) + # For motions, update the video canvas. + gen_motions_button.click( + generate_motions, inputs=[state, num_decode_steps], outputs=[gen_status, state, canvas_video] + ) + + with gr.Group("SceneStreamer") as gen_group: + gr.Markdown("### SceneStreamer") + with gr.Row(): + scenestreamer_button = gr.Button("Kickoff Continuous Generation") + with gr.Row(): + # New numeric input for number of decoded steps. 
+ scenestreamer_num_decode_steps = gr.Slider(label="Number of Decoded Steps", value=100, interactive=True) + + gr.Markdown("### SceneStreamer Visualization") + with gr.Row(): + scenestreamer_canvas_video = gr.Video( + label="SceneStreamer Video", interactive=False, height=800, width=800, autoplay=True, loop=True + ) + scenestreamer_button.click( + generate_scenestreamer_motions, + inputs=[state, scenestreamer_num_decode_steps], + outputs=[state, scenestreamer_canvas_video], + ) + +demo.launch() diff --git a/scenestreamer/gradio_ui/tmp.py b/scenestreamer/gradio_ui/tmp.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ea49736173ce15eeba570e95fec9f28f7f0ecd --- /dev/null +++ b/scenestreamer/gradio_ui/tmp.py @@ -0,0 +1,115 @@ +import argparse +import functools +import os +import pathlib +import pickle + +import gradio as gr +import numpy as np +import torch +from omegaconf import OmegaConf + +from scenestreamer.dataset.preprocess_action_label import TurnAction +from scenestreamer.dataset.preprocessor import preprocess_scenario_description_for_motionlm +from scenestreamer.gradio_ui.plot import plot_gt, plot_pred +from scenestreamer.tokenization import get_tokenizer +from scenestreamer.utils import REPO_ROOT +from scenestreamer.utils import utils + +os.environ['GRADIO_TEMP_DIR'] = str(REPO_ROOT / "gradio_tmp") + +default_config = OmegaConf.load(REPO_ROOT / "cfgs/motion_default.yaml") + +OmegaConf.set_struct(default_config, False) +default_config.MODEL.D_MODEL = 32 +default_config.MODEL.NUM_DECODER_LAYERS = 1 +default_config.MODEL.NUM_ATTN_LAYERS = 1 +default_config.ACTION_LABEL.USE_SAFETY_LABEL = True +default_config.ACTION_LABEL.USE_ACTION_LABEL = True +default_config.ROOT_DIR = REPO_ROOT +OmegaConf.set_struct(default_config, True) + +DEFAULT_DATA_PATH = "data/20scenarios" + +NUM_OF_MODES = 6 +# os.environ["CUDA_VISIBLE_DEVICES"] = "3" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +LENGTH = 1000 + + +class State: + 
model = None + model_path = None + config: dict = default_config + dataset_path: pathlib.Path = REPO_ROOT / DEFAULT_DATA_PATH + + scenario = None + + raw_data_files = None + data_files = None + + raw_data_dict = None + data_dict = None + + default_config: dict = default_config + + +state = State() + +from scenestreamer.models.motionlm_lightning import MotionLMLightning + +ckpt_path = "/home/zhenghao/scenestreamer/lightning_logs/scenestreamer/1012_motionlm_joint_condition_2024-10-12_1149/last.ckpt" + +msg = "Failed!" +temperature = 1.0 +safe_agents = "" +turn_agents = "" +main_vis = None +sampling_method = "topp" + +if ckpt_path.lower() == "debug": + try: + config = state.default_config + OmegaConf.set_struct(config, False) + config.MODEL.D_MODEL = 32 + config.MODEL.NUM_DECODER_LAYERS = 1 + config.MODEL.NUM_ATTN_LAYERS = 1 + config.ACTION_LABEL.USE_SAFETY_LABEL = True + config.ACTION_LABEL.USE_ACTION_LABEL = True + OmegaConf.set_struct(config, True) + model = MotionLMLightning(config) + model = model.to(device) + msg = "DEBUG MODEL LOADED!" + config = model.config + temperature = config.SAMPLING.TEMPERATURE + state.model = model + state.config = config + sampling_method = config.SAMPLING.SAMPLING_METHOD + except Exception as e: + # print("Error: ", e) + raise e + msg = "Failed to load DEBUG model!" + +path = pathlib.Path(ckpt_path) +path = REPO_ROOT / path + +print("Loading model from: ", path.absolute()) +if not path.exists(): + msg = "{} does not exist!".format(path) + +try: + model = utils.load_from_checkpoint( + checkpoint_path=path, cls=MotionLMLightning, config=None, default_config=default_config + ) + model = model.to(device) + msg = "Model loaded successfully!" + config = model.config + temperature = config.SAMPLING.TEMPERATURE + state.model = model + state.config = config + sampling_method = config.SAMPLING.SAMPLING_METHOD +except Exception as e: + print("Error: ", e) + raise e + msg = "Failed to load model!" 
diff --git a/scenestreamer/gradio_ui/video_ui.py b/scenestreamer/gradio_ui/video_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..42dd71a1a9b8d94930e4edf7ad777471d8b4a97a --- /dev/null +++ b/scenestreamer/gradio_ui/video_ui.py @@ -0,0 +1,447 @@ +import argparse +import functools +import os +import pathlib +import pickle + +import gradio as gr +import numpy as np +import torch +from omegaconf import OmegaConf + +from scenestreamer.dataset.preprocess_action_label import TurnAction +from scenestreamer.dataset.preprocessor import preprocess_scenario_description_for_motionlm +from scenestreamer.tokenization import get_tokenizer +from scenestreamer.utils import REPO_ROOT +from scenestreamer.utils import utils +from scenestreamer.gradio_ui.plot_video import plot_gt_video, plot_pred_video + +os.environ['GRADIO_TEMP_DIR'] = str(REPO_ROOT / "gradio_tmp") + +parser = argparse.ArgumentParser() +parser.add_argument("--share", action="store_true", help="Enable sharing") +parser.add_argument("--default_ckpt", "--default_model", type=str, default="data/20scenarios") +parser.add_argument("--default_data", type=str, default="") +args = parser.parse_args() + +default_config = OmegaConf.load(REPO_ROOT / "cfgs/motion_default.yaml") + +OmegaConf.set_struct(default_config, False) +default_config.MODEL.D_MODEL = 32 +default_config.MODEL.NUM_DECODER_LAYERS = 1 +default_config.MODEL.NUM_ATTN_LAYERS = 1 +default_config.ACTION_LABEL.USE_SAFETY_LABEL = True +default_config.ACTION_LABEL.USE_ACTION_LABEL = True +default_config.ROOT_DIR = REPO_ROOT +OmegaConf.set_struct(default_config, True) + +DEFAULT_DATA_PATH = args.default_data or "data/20scenarios" +DEFAULT_MODEL = args.default_ckpt or None + +NUM_OF_MODES = 6 +# os.environ["CUDA_VISIBLE_DEVICES"] = "3" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +LENGTH = 1000 + + +class State: + model = None + model_path = None + config: dict = default_config + dataset_path: pathlib.Path = REPO_ROOT / 
DEFAULT_DATA_PATH + + scenario = None + + raw_data_files = None + data_files = None + + raw_data_dict = None + data_dict = None + + default_config: dict = default_config + + +state = State() + + +def ckpt_callback(ckpt_path): + from scenestreamer.models.motionlm_lightning import MotionLMLightning + + msg = "Failed!" + temperature = 1.0 + safe_agents = "" + turn_agents = "" + main_vis = None + sampling_method = "topp" + + if ckpt_path.lower() == "debug": + try: + config = state.default_config + OmegaConf.set_struct(config, False) + config.MODEL.D_MODEL = 32 + config.MODEL.NUM_DECODER_LAYERS = 1 + config.MODEL.NUM_ATTN_LAYERS = 1 + config.ACTION_LABEL.USE_SAFETY_LABEL = True + config.ACTION_LABEL.USE_ACTION_LABEL = True + OmegaConf.set_struct(config, True) + model = MotionLMLightning(config) + model = model.to(device) + msg = "DEBUG MODEL LOADED!" + config = model.config + temperature = config.SAMPLING.TEMPERATURE + state.model = model + state.config = config + sampling_method = config.SAMPLING.SAMPLING_METHOD + except Exception as e: + print("Error: ", e) + msg = "Failed to load DEBUG model!" + + return [msg, sampling_method, temperature, main_vis] + [""] * 7 + + path = pathlib.Path(ckpt_path) + path = REPO_ROOT / path + + print("Loading model from: ", path.absolute()) + if not path.exists(): + msg = "{} does not exist!".format(path) + return [msg, sampling_method, temperature, main_vis] + [""] * 7 + + try: + model = utils.load_from_checkpoint( + checkpoint_path=path, cls=MotionLMLightning, config=None, default_config=default_config, strict=False + ) + model = model.to(device) + msg = "Model loaded successfully!" + config = model.config + temperature = config.SAMPLING.TEMPERATURE + state.model = model + state.config = config + sampling_method = config.SAMPLING.SAMPLING_METHOD + except Exception as e: + print("Error: ", e) + msg = "Failed to load model!" 
+ + return [msg, sampling_method, temperature, main_vis] + [""] * 7 + + +def on_dataset_path_submit(path): + + print(state, type(state)) + + FAILED_MSG = "Failed!" + + path = pathlib.Path(path) + path = REPO_ROOT / path + + if not path.exists(): + return FAILED_MSG + + if not path.is_dir(): + return FAILED_MSG + + state.dataset_path = path + print(state.dataset_path) + + files = os.listdir(path) + files = [f for f in files if f.endswith(".pkl")] + print("Files: ", files) + + if not hasattr(state, "count"): + state.count = 0 + state.count += 1 + + return [ + "Dataset with {} Scenarios Listed!".format(len(files)), + gr.FileExplorer( + file_count="single", + root_dir=state.dataset_path, + # root_dir=path, + glob="**/*.pkl", + scale=1 + # label="UPDATED={}".format(state.count), + # interactive=True + ) + ] + + +def on_data_file_name_search(search): + return gr.FileExplorer(file_count="single", root_dir=state.dataset_path, glob=f"**/*{search}*.pkl", scale=1) + + +def on_data_file_select(file_path): + if not file_path: + return (None, ) + ("", ) * 7 + + file_path = pathlib.Path(file_path) + assert state.dataset_path is not None + file_path = state.dataset_path / file_path + + with open(file_path, "rb") as f: + data = pickle.load(f) + + state.scenario = data + scenario_data_dict = preprocess_scenario_description_for_motionlm( + scenario=data, config=state.config, in_evaluation=True, keep_all_data=True + ) + state.raw_data_dict = scenario_data_dict + video_file = plot_gt_video(data) + return (gr.Video(value=video_file.name, label=data["id"]), ) + ("", ) * 7 + + +def on_generate_button_click( + sampling_method, temperature, seed, agents_safe_0, agents_safe_1, agents_turn_stop, agents_turn_straight, + agents_turn_left, agents_turn_right, agents_turn_uturn, only_draw_gt +): + + # TODO: Seed is not respect! + # TODO: Seed is not respect! + # TODO: Seed is not respect! 
+ + assert sampling_method in ["softmax", "topp"] + + if state.scenario is None: + return ( + None, + "Data is not loaded!", + ) + (None, ) * 7 + + model = state.model + if model is None and not only_draw_gt: + return ( + None, + "Model is not loaded!", + ) + (None, ) * 7 + + data_dict = preprocess_scenario_description_for_motionlm( + scenario=state.scenario, config=state.config, in_evaluation=True, keep_all_data=True + ) + + # ===== Overwrite the labels ===== + def _parse_agents(agents_str): + is_raw = False + if agents_str: + if agents_str.startswith("[RAW]"): + agents_str = agents_str[5:] + is_raw = True + if not agents_str: + return [], is_raw + return [int(agent.strip()) for agent in agents_str.split(",")], is_raw + else: + is_raw = True + return [], is_raw + + def _fill_label(data_dict, label_name, agents, label_value): + if agents: + label = data_dict['decoder/' + label_name] + for aid in agents: + assert 0 <= aid < label.shape[0], (aid, label.shape) + label[aid] = label_value + + agents_safe_0, agents_safe_0_is_raw = _parse_agents(agents_safe_0) + _fill_label(data_dict, 'label_safety', agents_safe_0, 0) + agents_safe_1, agents_safe_1_is_raw = _parse_agents(agents_safe_1) + _fill_label(data_dict, 'label_safety', agents_safe_1, 1) + agents_turn_stop, agents_turn_stop_is_raw = _parse_agents(agents_turn_stop) + _fill_label(data_dict, 'label_turning', agents_turn_stop, TurnAction.STOP) + agents_turn_straight, agents_turn_straight_is_raw = _parse_agents(agents_turn_straight) + _fill_label(data_dict, 'label_turning', agents_turn_straight, TurnAction.KEEP_STRAIGHT) + agents_turn_left, agents_turn_left_is_raw = _parse_agents(agents_turn_left) + _fill_label(data_dict, 'label_turning', agents_turn_left, TurnAction.TURN_LEFT) + agents_turn_right, agents_turn_right_is_raw = _parse_agents(agents_turn_right) + _fill_label(data_dict, 'label_turning', agents_turn_right, TurnAction.TURN_RIGHT) + agents_turn_uturn, agents_turn_uturn_is_raw = _parse_agents(agents_turn_uturn) + 
_fill_label(data_dict, 'label_turning', agents_turn_uturn, TurnAction.U_TURN) + + if only_draw_gt: + output_dict = data_dict + video_file = plot_gt_video(state.scenario) + + else: + input_data_dict = { + k: torch.from_numpy(v).to(device) if isinstance(v, np.ndarray) else v + for k, v in data_dict.items() + } + # Extend the batch dim: + input_data_dict = {k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in input_data_dict.items()} + input_data_dict["in_evaluation"] = torch.tensor([1], dtype=bool).to(device) + + with torch.no_grad(): + output_dict = model.model.autoregressive_rollout( + input_data_dict, num_decode_steps=16, sampling_method=sampling_method, temperature=temperature + ) + output_dict = get_tokenizer(model.config).detokenize(output_dict) + output_dict = { + k: (v.squeeze(0).cpu().numpy() if isinstance(v, torch.Tensor) else v) + for k, v in output_dict.items() + } + + video_file = plot_pred_video( + state.scenario, output_dict, agents_safe_0, agents_safe_1, agents_turn_stop, agents_turn_straight, + agents_turn_left, agents_turn_right, agents_turn_uturn + ) + + # Postprocess + if "decoder/label_safety" in output_dict: + safety_label = output_dict["decoder/label_safety"] + if agents_safe_0_is_raw: + aid = [str(v) for v in (safety_label == 0).nonzero()[0]] + agents_safe_0 = "[RAW]" + ",".join(aid) + if agents_safe_1_is_raw: + aid = [str(v) for v in (safety_label == 1).nonzero()[0]] + agents_safe_1 = "[RAW]" + ",".join(aid) + if "decoder/label_turning" in output_dict: + turning_label = output_dict["decoder/label_turning"] + if agents_turn_stop_is_raw: + aid = [str(v) for v in (turning_label == TurnAction.STOP).nonzero()[0]] + agents_turn_stop = "[RAW]" + ",".join(aid) + if agents_turn_straight_is_raw: + aid = [str(v) for v in (turning_label == TurnAction.KEEP_STRAIGHT).nonzero()[0]] + agents_turn_straight = "[RAW]" + ",".join(aid) + if agents_turn_left_is_raw: + aid = [str(v) for v in (turning_label == TurnAction.TURN_LEFT).nonzero()[0]] + 
agents_turn_left = "[RAW]" + ",".join(aid) + if agents_turn_right_is_raw: + aid = [str(v) for v in (turning_label == TurnAction.TURN_RIGHT).nonzero()[0]] + agents_turn_right = "[RAW]" + ",".join(aid) + if agents_turn_uturn_is_raw: + aid = [str(v) for v in (turning_label == TurnAction.U_TURN).nonzero()[0]] + agents_turn_uturn = "[RAW]" + ",".join(aid) + + return ( + gr.Video(value=video_file.name, label=state.scenario["id"]), "Scenario Generated!", + ", ".join([str(v) for v in agents_safe_0]) if isinstance(agents_safe_0, list) else agents_safe_0, + ", ".join([str(v) for v in agents_safe_1]) if isinstance(agents_safe_1, list) else agents_safe_1, + ", ".join([str(v) for v in agents_turn_stop]) if isinstance(agents_turn_stop, list) else agents_turn_stop, + ", ".join([str(v) + for v in agents_turn_straight]) if isinstance(agents_turn_straight, list) else agents_turn_straight, + ", ".join([str(v) for v in agents_turn_left]) if isinstance(agents_turn_left, list) else agents_turn_left, + ", ".join([str(v) for v in agents_turn_right]) if isinstance(agents_turn_right, list) else agents_turn_right, + ", ".join([str(v) for v in agents_turn_uturn]) if isinstance(agents_turn_uturn, list) else agents_turn_uturn + ) + + +# ============================================================ +# ======================== GRADIO UI ========================= +# ============================================================ +with gr.Blocks(theme=gr.themes.Soft(text_size="lg")) as demo: + with gr.Group(): + gr.Markdown(" ## Data") + with gr.Row(): + with gr.Column(scale=3): + inp = gr.Textbox(label="Path to Dataset Folder", value=DEFAULT_DATA_PATH) + + with gr.Column(scale=1): + out = gr.Textbox(label="Status", placeholder="Enter to submit...") + + # gr.Markdown("## Visualization") + with gr.Row(equal_height=True): # Future release fix: https://github.com/gradio-app/gradio/pull/9577 + with gr.Column(scale=1): + with gr.Group(): + file_name_input = gr.Textbox(label="Search Scenario ID", 
max_lines=1) + file_explorer = gr.FileExplorer( + root_dir=state.dataset_path, + glob="**/*.pkl", + file_count="single", + interactive=True, + container=True, + max_height=900 + ) + + with gr.Column(scale=2): + gt_vis = gr.Video( + label="Original Scenario", + show_download_button=False, + width=LENGTH, + height=LENGTH, + interactive=False + ) + + with gr.Group(): + gr.Markdown("## Model") + with gr.Row(): + with gr.Column(scale=3): + if DEFAULT_MODEL: + ckpt_input = gr.Textbox( + label="Path to model checkpoint", value=DEFAULT_MODEL, placeholder="/home/.../last.ckpt" + ) + else: + ckpt_input = gr.Textbox( + label="Path to model checkpoint", + placeholder="/home/.../last.ckpt (Type 'debug' for debug model!)" + ) + with gr.Column(scale=1): + ckpt_output = gr.Textbox(label="Status", placeholder="Enter to load...") + + gr.Markdown("## Visualization") + with gr.Row(): + with gr.Column(scale=1): + sampling_method = gr.Radio(label="Sampling Method", choices=["softmax", "topp"], value="topp") + temperature = gr.Slider( + label="Sampling Temperature", minimum=0.0, maximum=2.0, step=0.1, value=1.0, interactive=True + ) + seed = gr.Number(label="Seed", value=42, precision=0) + gr.Markdown( + "### Agents ID to assign labels:\n1. Split by comma ','\n2. 
If empty, original labels are used and printed as `[RAW]`" + ) + agents_safe_0 = gr.Textbox(label="label_safety = 0 (NO COLL)", interactive=True, placeholder="0, 1") + agents_safe_1 = gr.Textbox(label="label_safety = 1 (W/ COLL)", interactive=True, placeholder="0, 1") + agents_turn_stop = gr.Textbox(label="label_turning = STOP", interactive=True) + agents_turn_straight = gr.Textbox(label="label_turning = KEEP_STRAIGHT", interactive=True) + agents_turn_left = gr.Textbox(label="label_turning = LEFT", interactive=True) + agents_turn_right = gr.Textbox(label="label_turning = RIGHT", interactive=True) + agents_turn_uturn = gr.Textbox(label="label_turning = U_TURN", interactive=True) + generate_button = gr.Button(value="Generate") + draw_gt_button = gr.Button(value="Draw Original Scenario") + + with gr.Column(scale=2): + main_vis_text = gr.Textbox(label="Status", placeholder="", interactive=False) + main_vis = gr.Video(label="Generated Scenario", show_download_button=False, width=LENGTH, height=LENGTH) + + inp.submit(on_dataset_path_submit, inputs=inp, outputs=[out, file_explorer]) + file_name_input.change(on_data_file_name_search, inputs=file_name_input, outputs=file_explorer) + file_explorer.change( + on_data_file_select, + inputs=file_explorer, + outputs=[ + gt_vis, agents_safe_0, agents_safe_1, agents_turn_stop, agents_turn_straight, agents_turn_left, + agents_turn_right, agents_turn_uturn + ], + ) + + ckpt_input.submit( + ckpt_callback, + inputs=ckpt_input, + outputs=[ + ckpt_output, sampling_method, temperature, main_vis, agents_safe_0, agents_safe_1, agents_turn_stop, + agents_turn_straight, agents_turn_left, agents_turn_right, agents_turn_uturn + ], + ) + generate_button.click( + functools.partial(on_generate_button_click, only_draw_gt=False), + inputs=[ + sampling_method, temperature, seed, agents_safe_0, agents_safe_1, agents_turn_stop, agents_turn_straight, + agents_turn_left, agents_turn_right, agents_turn_uturn + ], + outputs=[ + main_vis, main_vis_text, 
agents_safe_0, agents_safe_1, agents_turn_stop, agents_turn_straight, + agents_turn_left, agents_turn_right, agents_turn_uturn + ], + ) + draw_gt_button.click( + functools.partial(on_generate_button_click, only_draw_gt=True), + inputs=[ + sampling_method, temperature, seed, agents_safe_0, agents_safe_1, agents_turn_stop, agents_turn_straight, + agents_turn_left, agents_turn_right, agents_turn_uturn + ], + outputs=[ + main_vis, main_vis_text, agents_safe_0, agents_safe_1, agents_turn_stop, agents_turn_straight, + agents_turn_left, agents_turn_right, agents_turn_uturn + ], + ) + +if DEFAULT_MODEL: + print("Loading default model from: ", DEFAULT_MODEL) + ckpt_callback(DEFAULT_MODEL) + +demo.queue().launch(server_port=7860, share=args.share) diff --git a/scenestreamer/infer/infinite.py b/scenestreamer/infer/infinite.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b35ce98fdcbcf89fb56f7a9d989870f9dc1521 --- /dev/null +++ b/scenestreamer/infer/infinite.py @@ -0,0 +1,791 @@ +""" +This module reimplements the autoregressive motion generation process. 
+""" + +import copy +import numpy as np + +import torch + +from scenestreamer.tokenization.motion_tokenizers import interpolate, interpolate_heading +from scenestreamer.utils import REPO_ROOT +from scenestreamer.utils import utils +from scenestreamer.infer.motion import encode_scene +from scenestreamer.tokenization.trafficgen_tokenizers import TrafficGenTokenizer +from scenestreamer.tokenization.motion_tokenizers import START_ACTION as MOTION_START_ACTION +from scenestreamer.tokenization.biycle_tokenizer import get_relative_velocity + +from scenestreamer.dataset.preprocess_action_label import cal_polygon_contour, detect_collision +import torch + +from scenestreamer.infer.initial_state import ( + generate_initial_state, + convert_initial_states_as_motion_data, + decode_one_step_initial_state, +) +from scenestreamer.infer.motion import ( + decode_one_step, + randomize_agent_id, + encode_scene, + interpolate_autoregressive_output, +) + + +@torch.no_grad() +def generate_densified_scenario( + *, + data_dict, + model, + force_add=False, + num_decode_steps=None, + discard_low_speed_agent=False, + remove_static_agent=False, + exclude_sdc_neighborhood=False +): + # Motion generation + agent_unique_id = data_dict['decoder/agent_id'][0] # (N,) + interpolation = True + remove_out_of_map_agent = True + tokenizer = model.tokenizer + temperature = model.config.SAMPLING.TEMPERATURE + topp = model.config.SAMPLING.TOPP + sampling_method = model.config.SAMPLING.SAMPLING_METHOD + B, T_input, N = data_dict["decoder/input_action"].shape + agent_pos = data_dict["decoder/agent_position"][:, ::tokenizer.num_skipped_steps] + agent_heading = data_dict["decoder/agent_heading"][:, ::tokenizer.num_skipped_steps] + agent_valid_mask = data_dict["decoder/agent_valid_mask"][:, ::tokenizer.num_skipped_steps] + agent_velocity = data_dict["decoder/agent_velocity"][:, ::tokenizer.num_skipped_steps] + B, T_full, N, _ = agent_pos.shape + gt_agent_delta = data_dict["decoder/modeled_agent_delta"].clone() + 
assert agent_pos.ndim == 4 + gt_input_action = data_dict["decoder/input_action"].clone() + data_dict, _ = randomize_agent_id(data_dict=data_dict, model=model) + step_data_dict = dict( + input_step=torch.arange(1).to(gt_input_action.device), + input_action=gt_input_action[:, :1].clone(), + input_action_valid_mask=data_dict["decoder/input_action_valid_mask"][:, :1].clone(), + agent_position=data_dict["decoder/modeled_agent_position"][:, :1].clone(), + agent_heading=data_dict["decoder/modeled_agent_heading"][:, :1].clone(), + agent_velocity=data_dict["decoder/modeled_agent_velocity"][:, :1].clone(), # TODO: Remove this? + agent_valid_mask=data_dict["decoder/input_action_valid_mask"][:, :1].clone(), + agent_delta=data_dict["decoder/modeled_agent_delta"][:, :1].clone(), + cache=None, + agent_id=data_dict["decoder/randomized_modeled_agent_id"], + agent_type=data_dict["decoder/agent_type"], + agent_shape=data_dict["decoder/current_agent_shape"], + ) + max_unique_id = agent_unique_id.max().item() + data_dict, _ = encode_scene(data_dict=data_dict, model=model) + + # Densify the scenario + num_agents_being_added = 128 - (max_unique_id + 1) + step_data_dict, _, agent_unique_id, max_unique_id = scenestreamer_step( + data_dict=data_dict, + motion_step_data_dict=step_data_dict, + motion_decode_one_step_info={}, + evicted_agent_mask=None, + model=model, + agent_unique_id=agent_unique_id, + num_agents_being_added=num_agents_being_added, + discard_low_speed_agent=discard_low_speed_agent, + exclude_sdc_neighborhood=exclude_sdc_neighborhood, + max_unique_id=max_unique_id, + ) + + output_logit_list = [] + output_action_list = [] + agent_unique_id_list = [agent_unique_id.clone()] + input_action_valid_mask_list = [step_data_dict["input_action_valid_mask"]] + pos = [step_data_dict["agent_position"]] + head = [step_data_dict["agent_heading"]] + vel = [step_data_dict["agent_velocity"]] + agent_type = [step_data_dict["agent_type"]] + agent_shape = [step_data_dict["agent_shape"]] + for 
decode_step in range(num_decode_steps): + teacher_forcing_valid_mask = None + teacher_forcing_action = None + next_state_data_dict, decode_one_step_info = decode_one_step( + data_dict=data_dict, + model=model, + sampling_method=sampling_method, + temperature=temperature, + topp=topp, + teacher_forcing_valid_mask=teacher_forcing_valid_mask, + teacher_forcing_action=teacher_forcing_action, + remove_out_of_map_agent=remove_out_of_map_agent, + **step_data_dict, + remove_static_agent=remove_static_agent, + ) + + # if "evicted_agent_mask" in decode_one_step_info: + if force_add: + num_agents_being_added = 128 - next_state_data_dict["agent_valid_mask"].sum().item() + assert num_agents_being_added >= 0 + else: + num_agents_being_added = None + + next_state_data_dict, decode_one_step_info, agent_unique_id, max_unique_id = scenestreamer_step( + data_dict=data_dict, + motion_step_data_dict=next_state_data_dict, + motion_decode_one_step_info=decode_one_step_info, + evicted_agent_mask=decode_one_step_info["evicted_agent_mask"], + model=model, + agent_unique_id=agent_unique_id, + num_agents_being_added=num_agents_being_added, + discard_low_speed_agent=discard_low_speed_agent, + exclude_sdc_neighborhood=exclude_sdc_neighborhood, + max_unique_id=max_unique_id, + ) + # There is a very tricky bug. At step T, say agent A is evicted. + # We know the agent A is valid at T (input_valid_mask is valid for A at T), + # so we will generate the position at T+1. + # However, in scenestreamer_step, all evicted agents are removed from the motion_step_data_dict. + # Therefore the later interpolation process can't access to A's position at T+1. + # So there will be an issue in the interpolation process. + # To solve this, a workaround below is to remove agent A at T. 
(we might lost some information though) + if "evicted_agent_mask" in decode_one_step_info and decode_one_step_info["evicted_agent_mask"] is not None: + input_action_valid_mask_list[ + -1] = input_action_valid_mask_list[-1] * (~decode_one_step_info["evicted_agent_mask"].unsqueeze(1)) + + pos.append(next_state_data_dict["agent_position"].clone()) + head.append(next_state_data_dict["agent_heading"].clone()) + vel.append(next_state_data_dict["agent_velocity"].clone()) + agent_type.append(next_state_data_dict["agent_type"].clone()) + agent_shape.append(next_state_data_dict["agent_shape"].clone()) + output_logit_list.append(decode_one_step_info["output_token"].clone()) + output_action_list.append(next_state_data_dict["input_action"].clone()) + agent_unique_id_list.append(agent_unique_id.clone()) + input_action_valid_mask_list.append(next_state_data_dict["input_action_valid_mask"].clone()) + step_data_dict = next_state_data_dict + + # ===== Post-process the data ===== + num_total_agents = agent_unique_id_list[-1].max().item() + 1 + + def _scatter(data_list): + ret = [] + for i, (unique_id, data) in enumerate(zip(agent_unique_id_list[1:], data_list)): + s = list(data.shape) + s[2] = num_total_agents + new_data = data.new_zeros(s) + assert len(unique_id) == data.shape[2] + assert unique_id.max() < num_total_agents, (unique_id.max(), num_total_agents) + new_data[:, :, unique_id] = data + ret.append(new_data) + return torch.cat(ret, dim=1) + + def _scatter_with_first_step(data_list): + """This function is used for the data where the first step is the initial states.""" + ret = [] + for i, (unique_id, data) in enumerate(zip(agent_unique_id_list, data_list)): + s = list(data.shape) + s[2] = num_total_agents + new_data = data.new_zeros(s) + assert len(unique_id) == data.shape[2], (len(unique_id), data.shape[2]) + assert unique_id.max() < num_total_agents, (unique_id.max(), num_total_agents) + new_data[:, :, unique_id] = data + ret.append(new_data) + return torch.cat(ret, dim=1) 
+ + def _scatter_for_non_temporal(data_list): + s = list(data_list[0].shape)[2:] + ret = data_list[0].new_zeros([ + B, + num_total_agents, + ] + s) + for i, (unique_id, data) in enumerate(zip(agent_unique_id_list, data_list)): + assert len(unique_id) == data.shape[1], (len(unique_id), data.shape[1]) + assert unique_id.max() < num_total_agents, (unique_id.max(), num_total_agents) + ret[:, unique_id] = data + return ret + + output_action_list = _scatter(output_action_list) + assert output_action_list.shape == (B, num_decode_steps, num_total_agents) + assert len(input_action_valid_mask_list) == num_decode_steps + 1 + input_action_valid_mask = _scatter_with_first_step(input_action_valid_mask_list) + # Evict the last step's input_action_valid_mask_list as it is not used. + input_action_valid_mask = input_action_valid_mask[:, :-1] + + output_logit_list = _scatter(output_logit_list) + traj_log_prob, traj_prob = utils.calculate_trajectory_probabilities_new( + output_logit_list, output_action_list, mask=input_action_valid_mask + ) # (B, N) + pos = _scatter_with_first_step(pos) + head = _scatter_with_first_step(head) + vel = _scatter_with_first_step(vel) + agent_type = _scatter_for_non_temporal(agent_type) + agent_shape = _scatter_for_non_temporal(agent_shape) + + # ===== Interpolate the output ===== + if interpolation: + data_dict, _ = interpolate_autoregressive_output( + data_dict=data_dict, + agent_heading=head, + agent_position=pos, + agent_velocity=vel, + input_valid_mask=input_action_valid_mask, + num_skipped_steps=tokenizer.num_skipped_steps, + num_decoded_steps=num_decode_steps, + ) + + # ===== Save the data ===== + data_dict["decoder/output_logit"] = output_logit_list + data_dict["decoder/output_action"] = output_action_list + data_dict["decoder/output_score"] = traj_log_prob + data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask + # data_dict["decoder/debug_ar_pos"] = pos + # data_dict["decoder/debug_ar_head"] = head + # 
@torch.no_grad()
def generate_scenestreamer_motion(
    *,
    data_dict,
    model,
    force_add=False,
    num_decode_steps=None,
    discard_low_speed_agent=False,
    remove_static_agent=False,
    exclude_sdc_neighborhood=False
):
    """Generate initial states, then roll out motion while agents are evicted and added.

    Each decode step calls ``decode_one_step`` followed by ``scenestreamer_step``. Because
    the set of agents changes between steps, every per-step tensor is recorded together with
    the agents' unique ids and scattered into a dense layout indexed by unique id at the end.

    Args:
        data_dict: Scenario dictionary; it is updated and returned.
        model: Model exposing ``tokenizer`` and ``config.SAMPLING``.
        force_add: Forwarded to initial-state generation. When true, each step also asks
            ``scenestreamer_step`` to add agents up to a total of 128.
        num_decode_steps: Number of motion decode steps. Required, must be >= 1.
        discard_low_speed_agent: Forwarded to the initial-state helpers.
        remove_static_agent: Forwarded to ``decode_one_step``.
        exclude_sdc_neighborhood: Forwarded to ``scenestreamer_step`` only.

    Returns:
        ``data_dict`` with the ``decoder/output_*`` entries, the scattered valid mask, agent
        type and agent shape written back.

    Raises:
        ValueError: If ``num_decode_steps`` is ``None`` or smaller than 1.
    """
    # Fail before the expensive initial-state pass instead of crashing later in range(None).
    if num_decode_steps is None or num_decode_steps < 1:
        raise ValueError("num_decode_steps must be a positive integer, got {!r}".format(num_decode_steps))

    # Upper bound on the agent count used by the force-add path.
    max_scene_agents = 128

    # Initial state generation
    data_dict, _ = generate_initial_state(
        data_dict=data_dict, model=model, force_add=force_add, discard_low_speed_agent=discard_low_speed_agent
    )
    data_dict = convert_initial_states_as_motion_data(data_dict)

    # Motion generation
    agent_unique_id = data_dict['decoder/agent_id'][0]  # ids of the first batch element
    interpolation = True
    remove_out_of_map_agent = True
    tokenizer = model.tokenizer
    temperature = model.config.SAMPLING.TEMPERATURE
    topp = model.config.SAMPLING.TOPP
    sampling_method = model.config.SAMPLING.SAMPLING_METHOD

    agent_pos = data_dict["decoder/agent_position"][:, ::tokenizer.num_skipped_steps]
    assert agent_pos.ndim == 4
    B = agent_pos.shape[0]

    # Cloned before the agent ids are randomized, as in the original ordering.
    gt_input_action = data_dict["decoder/input_action"].clone()
    data_dict, _ = randomize_agent_id(data_dict=data_dict, model=model)
    step_data_dict = dict(
        input_step=torch.arange(1).to(gt_input_action.device),
        input_action=gt_input_action[:, :1].clone(),
        input_action_valid_mask=data_dict["decoder/input_action_valid_mask"][:, :1].clone(),
        agent_position=data_dict["decoder/modeled_agent_position"][:, :1].clone(),
        agent_heading=data_dict["decoder/modeled_agent_heading"][:, :1].clone(),
        agent_velocity=data_dict["decoder/modeled_agent_velocity"][:, :1].clone(),  # TODO: Remove this?
        agent_valid_mask=data_dict["decoder/input_action_valid_mask"][:, :1].clone(),
        agent_delta=data_dict["decoder/modeled_agent_delta"][:, :1].clone(),
        cache=None,
        agent_id=data_dict["decoder/randomized_modeled_agent_id"],
        agent_type=data_dict["decoder/agent_type"],
        agent_shape=data_dict["decoder/current_agent_shape"],
    )
    max_unique_id = agent_unique_id.max().item()
    agent_unique_id_list = [agent_unique_id.clone()]
    output_logit_list = []
    output_action_list = []
    input_action_valid_mask_list = [step_data_dict["input_action_valid_mask"]]
    pos = [step_data_dict["agent_position"]]
    head = [step_data_dict["agent_heading"]]
    vel = [step_data_dict["agent_velocity"]]
    agent_type = [step_data_dict["agent_type"]]
    agent_shape = [step_data_dict["agent_shape"]]
    data_dict, _ = encode_scene(data_dict=data_dict, model=model)

    for _ in range(num_decode_steps):
        next_state_data_dict, decode_one_step_info = decode_one_step(
            data_dict=data_dict,
            model=model,
            sampling_method=sampling_method,
            temperature=temperature,
            topp=topp,
            teacher_forcing_valid_mask=None,
            teacher_forcing_action=None,
            remove_out_of_map_agent=remove_out_of_map_agent,
            **step_data_dict,
            remove_static_agent=remove_static_agent,
        )

        if force_add:
            num_agents_being_added = max_scene_agents - next_state_data_dict["agent_valid_mask"].sum().item()
            assert num_agents_being_added >= 0
        else:
            num_agents_being_added = None

        # The key is treated as optional further down, so do not index it unguarded here.
        if "evicted_agent_mask" in decode_one_step_info:
            evicted_agent_mask = decode_one_step_info["evicted_agent_mask"]
        else:
            evicted_agent_mask = None

        next_state_data_dict, decode_one_step_info, agent_unique_id, max_unique_id = scenestreamer_step(
            data_dict=data_dict,
            motion_step_data_dict=next_state_data_dict,
            motion_decode_one_step_info=decode_one_step_info,
            evicted_agent_mask=evicted_agent_mask,
            model=model,
            agent_unique_id=agent_unique_id,
            num_agents_being_added=num_agents_being_added,
            discard_low_speed_agent=discard_low_speed_agent,
            exclude_sdc_neighborhood=exclude_sdc_neighborhood,
            max_unique_id=max_unique_id,
        )
        # Tricky case: say agent A is evicted at step T. A is still valid at T, so a position
        # for T+1 is generated, but scenestreamer_step removes evicted agents from the step
        # data, so the interpolation below cannot access A's position at T+1.
        # Workaround: mark A invalid at T as well (we might lose some information though).
        if evicted_agent_mask is not None:
            input_action_valid_mask_list[-1] = (
                input_action_valid_mask_list[-1] * (~evicted_agent_mask.unsqueeze(1))
            )

        pos.append(next_state_data_dict["agent_position"].clone())
        head.append(next_state_data_dict["agent_heading"].clone())
        vel.append(next_state_data_dict["agent_velocity"].clone())
        agent_type.append(next_state_data_dict["agent_type"].clone())
        agent_shape.append(next_state_data_dict["agent_shape"].clone())
        output_logit_list.append(decode_one_step_info["output_token"].clone())
        output_action_list.append(next_state_data_dict["input_action"].clone())
        agent_unique_id_list.append(agent_unique_id.clone())
        input_action_valid_mask_list.append(next_state_data_dict["input_action_valid_mask"].clone())
        step_data_dict = next_state_data_dict

    # ===== Post-process the data =====
    # Use the largest id seen at ANY step: the highest id may belong to an agent that was
    # evicted before the final step, in which case the last id list alone is too small.
    num_total_agents = max(int(ids.max().item()) for ids in agent_unique_id_list if ids.numel() > 0) + 1

    def _scatter_temporal(data_list, id_list):
        """Scatter per-step tensors along dim 2 by unique id and concatenate over time."""
        ret = []
        for unique_id, data in zip(id_list, data_list):
            s = list(data.shape)
            s[2] = num_total_agents
            new_data = data.new_zeros(s)
            assert len(unique_id) == data.shape[2], (len(unique_id), data.shape[2])
            new_data[:, :, unique_id] = data
            ret.append(new_data)
        return torch.cat(ret, dim=1)

    def _scatter(data_list):
        """For per-step outputs, which have no entry for the initial state."""
        return _scatter_temporal(data_list, agent_unique_id_list[1:])

    def _scatter_with_first_step(data_list):
        """For data whose first entry is the initial state."""
        return _scatter_temporal(data_list, agent_unique_id_list)

    def _scatter_for_non_temporal(data_list):
        """Scatter tensors without a time axis; later steps overwrite earlier ones."""
        s = list(data_list[0].shape)[2:]
        ret = data_list[0].new_zeros([B, num_total_agents] + s)
        for unique_id, data in zip(agent_unique_id_list, data_list):
            assert len(unique_id) == data.shape[1], (len(unique_id), data.shape[1])
            ret[:, unique_id] = data
        return ret

    output_action_list = _scatter(output_action_list)
    assert output_action_list.shape == (B, num_decode_steps, num_total_agents)
    assert len(input_action_valid_mask_list) == num_decode_steps + 1
    input_action_valid_mask = _scatter_with_first_step(input_action_valid_mask_list)
    # Drop the last step's valid mask as it is not used.
    input_action_valid_mask = input_action_valid_mask[:, :-1]

    output_logit_list = _scatter(output_logit_list)
    traj_log_prob, _ = utils.calculate_trajectory_probabilities_new(
        output_logit_list, output_action_list, mask=input_action_valid_mask
    )
    pos = _scatter_with_first_step(pos)
    head = _scatter_with_first_step(head)
    vel = _scatter_with_first_step(vel)
    agent_type = _scatter_for_non_temporal(agent_type)
    agent_shape = _scatter_for_non_temporal(agent_shape)

    # ===== Interpolate the output =====
    if interpolation:
        data_dict, _ = interpolate_autoregressive_output(
            data_dict=data_dict,
            agent_heading=head,
            agent_position=pos,
            agent_velocity=vel,
            input_valid_mask=input_action_valid_mask,
            num_skipped_steps=tokenizer.num_skipped_steps,
            num_decoded_steps=num_decode_steps,
        )

    # ===== Save the data =====
    data_dict["decoder/output_logit"] = output_logit_list
    data_dict["decoder/output_action"] = output_action_list
    data_dict["decoder/output_score"] = traj_log_prob
    data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask

    data_dict["decoder/agent_type"] = agent_type
    data_dict["decoder/current_agent_shape"] = agent_shape
    data_dict.pop("decoder/agent_shape")
    data_dict.pop("decoder/object_of_interest_id")

    valid_output_action = output_action_list[input_action_valid_mask]
    # max()/min() raise on an empty selection, so only check when something is valid.
    if valid_output_action.numel() > 0:
        assert valid_output_action.max() < tokenizer.num_actions
        assert valid_output_action.min() >= 0

    return data_dict
def scenestreamer_step(
    *,
    data_dict,
    motion_step_data_dict,
    motion_decode_one_step_info,
    evicted_agent_mask,
    model,
    agent_unique_id,
    max_unique_id,
    num_agents_being_added=None,
    discard_low_speed_agent=False,
    exclude_sdc_neighborhood=False
):
    """Drop evicted agents from the motion step data and append newly generated agents.

    The current motion state is converted into an initial-state token sequence with
    ``build_initial_data_dict``; ``decode_one_step_initial_state`` is then called repeatedly
    to propose new agents, and the new agents are concatenated after the surviving ones in
    every per-agent tensor (step data, output tokens, history and KV cache).

    Args:
        data_dict: Scenario dictionary holding the ``encoder/*`` entries.
        motion_step_data_dict: Output of the motion decoder for this step; updated in place.
        motion_decode_one_step_info: Info dict of the motion decoder; ``output_token`` is
            re-indexed in place when present.
        evicted_agent_mask: Boolean mask of agents to remove, or ``None``.
        model: Model exposing ``trafficgen_decoder`` and ``config.PREPROCESSING.MAX_AGENTS``.
        agent_unique_id: 1-D tensor with the unique id of each current agent.
        max_unique_id: Largest unique id handed out so far.
        num_agents_being_added: When given, exactly this many initial-state decode steps are
            run with ``force_add=True``; otherwise up to ``128 - <current agents>`` steps.
        discard_low_speed_agent: Forwarded to ``decode_one_step_initial_state``.
        exclude_sdc_neighborhood: Forwarded to ``decode_one_step_initial_state``.

    Returns:
        ``(motion_step_data_dict, motion_decode_one_step_info, agent_unique_id, max_unique_id)``.
        If there are not enough free agent ids, the inputs are returned unchanged (evicted
        agents are then NOT removed).
    """
    B = data_dict["encoder/scenario_token"].shape[0]

    print("Entering scenestreamer_step...")

    # Step 1: Remove the evicted agent info from motion_data_dict.
    # Nothing is removed yet so that the KV cache is not disturbed; removal happens together
    # with the concatenation of the new agents below.
    if evicted_agent_mask is not None:
        print("{} agents are evicted.".format(evicted_agent_mask.sum().item()))

    # Step 2: Run initial state generation.
    init_step_data_dict = build_initial_data_dict(
        data_dict=data_dict,
        agent_position=motion_step_data_dict["agent_position"][:, -1],
        agent_heading=motion_step_data_dict["agent_heading"][:, -1],
        agent_velocity=motion_step_data_dict["agent_velocity"][:, -1],
        agent_type=motion_step_data_dict["agent_type"],
        agent_shape=motion_step_data_dict["agent_shape"],
        agent_valid_mask=motion_step_data_dict["agent_valid_mask"][:, -1],
        evicted_agent_mask=evicted_agent_mask,
        start_action_id=model.trafficgen_decoder.trafficgen_tokenizer.start_action_id,
    )
    # "- 1" discounts the leading start-token slot of the initial-state sequence.
    num_agents_before_adding = init_step_data_dict["input_action"].shape[-1] - 1
    num_agents_after_adding = num_agents_before_adding

    if num_agents_being_added is None:
        num_decode_steps = 128 - num_agents_before_adding
        force_add = False
    else:
        print("We will force add {} agents.".format(num_agents_being_added))
        num_decode_steps = num_agents_being_added
        force_add = True
    for _ in range(num_decode_steps):
        next_init_step_data_dict, step_info = decode_one_step_initial_state(
            data_dict=data_dict,
            model=model,
            force_add=force_add,
            **init_step_data_dict,
            discard_low_speed_agent=discard_low_speed_agent,
            exclude_sdc_neighborhood=exclude_sdc_neighborhood,
            # Keyword-only and without a default in the visible signature, so it must be
            # passed. None never equals 0, so the forced-SDC seeding branch stays off.
            decode_step=None,
        )
        if step_info["terminated"]:
            break
        init_step_data_dict = next_init_step_data_dict
        num_agents_after_adding = init_step_data_dict["input_action"].shape[-1] - 1

        if num_agents_after_adding > model.config.PREPROCESSING.MAX_AGENTS:
            break

    # Step 4: Translate the initial state to motion_data_dict.
    # The new agents start with MOTION_START_ACTION and are concatenated after the
    # surviving agents.
    num_newly_added_agents = num_agents_after_adding - num_agents_before_adding
    print("{} agents are newly added.".format(num_newly_added_agents))

    # Everything after the pre-existing slots (plus the start token) is a new agent.
    first_new = num_agents_before_adding + 1
    new_pos = init_step_data_dict['agent_position'][:, first_new:]
    new_head = init_step_data_dict['agent_heading'][:, first_new:]
    new_vel = init_step_data_dict['agent_velocity'][:, first_new:]
    new_shape = init_step_data_dict['agent_shape'][:, first_new:]
    new_type = init_step_data_dict['agent_type'][:, first_new:]
    new_valid_mask = init_step_data_dict['input_action_valid_mask'][:, first_new:]
    assert new_valid_mask.shape == (B, num_newly_added_agents)

    if evicted_agent_mask is not None:
        leftover_indices = (~evicted_agent_mask)[0].nonzero()[:, 0]
    else:
        # NOTE(review): this keeps only the first `num_agents_before_adding` agents, which
        # equals "all agents" only if every current agent is valid - confirm.
        leftover_indices = torch.arange(num_agents_before_adding).to(new_pos.device)

    old_agent_id = motion_step_data_dict["agent_id"][:, leftover_indices]
    assert motion_step_data_dict["agent_id"].shape[0] == 1
    # Build the lookup once instead of re-materialising the id list for every candidate.
    taken_ids = set(old_agent_id[0].tolist())
    agent_id_candidates = [x for x in range(model.config.PREPROCESSING.MAX_AGENTS) if x not in taken_ids]
    if len(agent_id_candidates) < num_newly_added_agents:
        print(
            "Not enough agent ids to assign! We have {} agents already and {} agents are newly added.".format(
                num_agents_before_adding, num_newly_added_agents
            )
        )
        return motion_step_data_dict, motion_decode_one_step_info, agent_unique_id, max_unique_id

    agent_id_candidates = np.random.choice(agent_id_candidates, num_newly_added_agents, replace=False)
    new_id = torch.tensor(agent_id_candidates).to(old_agent_id.device).long()
    motion_step_data_dict["agent_id"] = torch.cat([old_agent_id, new_id.unsqueeze(0)], dim=-1).long()

    motion_step_data_dict["input_action"] = torch.cat(
        [
            motion_step_data_dict["input_action"][:, :, leftover_indices],
            motion_step_data_dict["input_action"].new_full([B, 1, num_newly_added_agents], MOTION_START_ACTION)
        ],
        dim=-1
    )
    motion_step_data_dict["input_action_valid_mask"] = torch.cat(
        [motion_step_data_dict["input_action_valid_mask"][:, :, leftover_indices],
         new_valid_mask.unsqueeze(1)], dim=-1
    )
    motion_step_data_dict["agent_position"] = torch.cat(
        [motion_step_data_dict["agent_position"][:, :, leftover_indices],
         new_pos.unsqueeze(1)], dim=-2
    )
    motion_step_data_dict["agent_heading"] = torch.cat(
        [motion_step_data_dict["agent_heading"][:, :, leftover_indices],
         new_head.unsqueeze(1)], dim=-1
    )
    motion_step_data_dict["agent_velocity"] = torch.cat(
        [motion_step_data_dict["agent_velocity"][:, :, leftover_indices],
         new_vel.unsqueeze(1)], dim=-2
    )
    motion_step_data_dict["agent_shape"] = torch.cat(
        [motion_step_data_dict["agent_shape"][:, leftover_indices], new_shape], dim=-2
    )
    motion_step_data_dict["agent_type"] = torch.cat(
        [motion_step_data_dict["agent_type"][:, leftover_indices], new_type], dim=-1
    )
    motion_step_data_dict["agent_valid_mask"] = torch.cat(
        [motion_step_data_dict["agent_valid_mask"][:, :, leftover_indices],
         new_valid_mask.unsqueeze(1)], dim=-1
    )

    assert motion_step_data_dict["input_action"].shape == motion_step_data_dict["input_action_valid_mask"].shape

    new_delta = get_relative_velocity(new_vel.unsqueeze(1), new_head.unsqueeze(1))
    motion_step_data_dict["agent_delta"] = torch.cat(
        [motion_step_data_dict["agent_delta"][:, :, leftover_indices], new_delta], dim=-2
    )

    print("Totally {} agents exist.".format(motion_step_data_dict["agent_id"].shape[-1]))

    # The decoder outputs, agent history and KV cache must follow the same re-indexing.
    if "output_token" in motion_decode_one_step_info:
        old_output_token = motion_decode_one_step_info["output_token"]
        motion_decode_one_step_info["output_token"] = torch.cat(
            [
                old_output_token[:, :, leftover_indices],
                old_output_token.new_zeros([B, 1, num_newly_added_agents, old_output_token.shape[-1]])
            ],
            dim=-2
        )

    # Stay None when no history is present so the cache branch below cannot hit an
    # unbound name.
    T_history = None
    num_existing_agents = None
    if "agent_position_history" in motion_step_data_dict:
        _, T_history, num_existing_agents, _ = motion_step_data_dict["agent_position_history"].shape
        motion_step_data_dict["agent_position_history"] = torch.cat(
            [
                motion_step_data_dict["agent_position_history"][:, :, leftover_indices],
                motion_step_data_dict["agent_position_history"].new_zeros([B, T_history, num_newly_added_agents, 2])
            ],
            dim=-2
        )
        motion_step_data_dict["agent_heading_history"] = torch.cat(
            [
                motion_step_data_dict["agent_heading_history"][:, :, leftover_indices],
                motion_step_data_dict["agent_heading_history"].new_zeros([B, T_history, num_newly_added_agents])
            ],
            dim=-1
        )
        motion_step_data_dict["agent_valid_mask_history"] = torch.cat(
            [
                motion_step_data_dict["agent_valid_mask_history"][:, :, leftover_indices],
                motion_step_data_dict["agent_valid_mask_history"].new_zeros([B, T_history, num_newly_added_agents])
            ],
            dim=-1
        )
        # "agent_step_history" is left untouched here, as in the original code.

    if "cache" in motion_step_data_dict and motion_step_data_dict["cache"] is not None:
        # Each layer entry is (k, v, (batch_size, seq_len)); rows are selected with
        # `leftover_indices`, which indexes agents of a single scene (B == 1 asserted above).
        num_agents_after_update = len(leftover_indices) + num_newly_added_agents
        new_cache = []
        for k, v, (batch_size, seq_len) in motion_step_data_dict["cache"]:
            if T_history is not None:
                assert seq_len == T_history
                assert batch_size == B * num_existing_agents
            k = k.reshape(batch_size, seq_len, -1)[leftover_indices]
            v = v.reshape(batch_size, seq_len, -1)[leftover_indices]
            # New agents have no past, so their cache rows start as zeros.
            k = torch.cat([k, k.new_zeros([B * num_newly_added_agents, seq_len, k.shape[-1]])], dim=0)
            v = torch.cat([v, v.new_zeros([B * num_newly_added_agents, seq_len, v.shape[-1]])], dim=0)
            k = k.reshape(B * num_agents_after_update, seq_len, -1)
            v = v.reshape(B * num_agents_after_update, seq_len, -1)
            new_cache.append((k, v, (B * num_agents_after_update, seq_len)))
        motion_step_data_dict["cache"] = new_cache

    agent_unique_id = torch.cat(
        [
            agent_unique_id[leftover_indices],
            max_unique_id + 1 + torch.arange(num_newly_added_agents).to(agent_unique_id.device)
        ]
    ).long()
    # Keep the counter monotonic. Recomputing it from the surviving ids would let it shrink
    # when the newest agent is evicted, and a later agent would then reuse that unique id.
    max_unique_id = max_unique_id + num_newly_added_agents

    print("agent_unique_id: ", list(agent_unique_id.cpu().numpy()))

    return motion_step_data_dict, motion_decode_one_step_info, agent_unique_id, max_unique_id
def build_initial_data_dict(
    *, data_dict, start_action_id, agent_position, agent_heading, agent_velocity, agent_type, agent_shape,
    agent_valid_mask, evicted_agent_mask
):
    """Re-encode the current agents as an initial-state (trafficgen) input sequence.

    Agents that are invalid or evicted (judged on the first batch element) are dropped. Each
    remaining agent is attached to its closest admissible map feature, its state is
    expressed relative to that feature, passed through ``TrafficGenTokenizer`` bucketize /
    de_bucketize, and mapped back to absolute coordinates. Every returned tensor gets one
    extra leading slot along dim 1 for the start action.

    Returns:
        dict with ``input_step``, ``input_action``, ``input_action_valid_mask``,
        ``agent_position``, ``agent_heading``, ``agent_velocity``, ``agent_type``,
        ``agent_shape`` and ``agent_feature`` for the whole sequence (not just one step).
    """
    # TODO: currently hard-coded; only lane features are admissible anchors.
    restrict_to_lanes = True

    keep_mask = agent_valid_mask
    if evicted_agent_mask is not None:
        keep_mask = keep_mask & (~evicted_agent_mask)
    keep = keep_mask[0].nonzero()[:, 0]

    agent_position = agent_position[:, keep]
    agent_heading = agent_heading[:, keep]
    agent_velocity = agent_velocity[:, keep]
    agent_type = agent_type[:, keep]
    agent_shape = agent_shape[:, keep]
    agent_valid_mask = keep_mask[:, keep]

    B, N, _ = agent_position.shape

    def _quantize(values, field):
        """Round-trip values through the tokenizer buckets of `field`."""
        return TrafficGenTokenizer.de_bucketize(TrafficGenTokenizer.bucketize(values, field), field)

    def _prepend_zero_slot(tensor):
        """Add a zero entry at the front of dim 1 (the start-action slot)."""
        return torch.cat([tensor.new_zeros([B, 1] + list(tensor.shape[2:])), tensor], dim=1)

    map_pos = data_dict["encoder/map_position"][..., :2]
    map_heading = data_dict["encoder/map_heading"]

    # A map feature is admissible when it is valid and its heading differs from the
    # agent's heading by less than 90 degrees.
    heading_gap = utils.wrap_to_pi(agent_heading[:, :, None] - map_heading[:, None])
    admissible = data_dict["encoder/map_valid_mask"] & (torch.abs(heading_gap) < np.deg2rad(90))
    if restrict_to_lanes:
        lane_flag = data_dict["encoder/map_feature"][:, :, 0, 13] == 1
        admissible = lane_flag[:, None] & admissible

    # Closest admissible map feature per agent.
    dist = torch.cdist(agent_position, map_pos)
    dist[~admissible] = torch.inf
    nearest = torch.argmin(dist, dim=-1)

    anchor_pos = torch.gather(map_pos, dim=1, index=nearest[:, :, None].expand(-1, -1, 2))
    anchor_heading = torch.gather(map_heading, dim=1, index=nearest)

    # Agent state relative to its anchor feature.
    offset = agent_position - anchor_pos
    local_pos = utils.rotate(x=offset[..., 0], y=offset[..., 1], angle=-anchor_heading)
    local_heading = utils.wrap_to_pi(agent_heading - anchor_heading)
    local_vel = utils.rotate(x=agent_velocity[..., 0], y=agent_velocity[..., 1], angle=-anchor_heading)

    # Position: quantize in the local frame, then map back to absolute coordinates.
    q_pos_x = _quantize(local_pos[..., 0], "position_x")
    q_pos_y = _quantize(local_pos[..., 1], "position_y")
    q_pos = torch.stack([q_pos_x, q_pos_y], dim=-1)
    abs_pos = _prepend_zero_slot(utils.rotate(x=q_pos_x, y=q_pos_y, angle=anchor_heading) + anchor_pos)

    # Heading.
    q_heading = _quantize(local_heading, "heading")
    abs_heading = _prepend_zero_slot(utils.wrap_to_pi(q_heading + anchor_heading))

    # Velocity.
    q_vel_x = _quantize(local_vel[..., 0], "velocity_x")
    q_vel_y = _quantize(local_vel[..., 1], "velocity_y")
    q_vel = torch.stack([q_vel_x, q_vel_y], dim=-1)
    abs_vel = _prepend_zero_slot(utils.rotate(x=q_vel_x, y=q_vel_y, angle=anchor_heading))

    # Shape (length, width, height).
    q_shape = torch.stack(
        [
            _quantize(agent_shape[..., 0], "length"),
            _quantize(agent_shape[..., 1], "width"),
            _quantize(agent_shape[..., 2], "height"),
        ],
        dim=-1
    )
    padded_shape = _prepend_zero_slot(q_shape)

    # Per-token feature in the local frame: (x, y, heading, vx, vy); slot 0 stays zero.
    num_slots = q_pos.shape[1] + 1
    feat = padded_shape.new_zeros((B, num_slots, 5))
    feat[:, 1:, :2] = q_pos
    feat[:, 1:, 2] = q_heading
    feat[:, 1:, 3:5] = q_vel

    start_column = nearest.new_full([B, 1], start_action_id)
    input_action = torch.cat([start_column, nearest], dim=1)
    input_action_valid_mask = torch.cat([agent_valid_mask.new_ones([B, 1]), agent_valid_mask], dim=1)

    # This is the whole input sequence, not only the first input step.
    return dict(
        input_step=torch.zeros([B, num_slots], dtype=torch.long, device=data_dict["encoder/scenario_token"].device),
        input_action=input_action,
        input_action_valid_mask=input_action_valid_mask,
        agent_position=abs_pos,
        agent_heading=abs_heading,
        agent_velocity=abs_vel,
        agent_type=_prepend_zero_slot(agent_type),
        agent_shape=padded_shape,
        agent_feature=feat,
    )
@torch.no_grad()
def generate_initial_state(
    *, data_dict, model, force_add=False, discard_low_speed_agent=False, exclude_sdc_neighborhood=False
):
    """Autoregressively generate the initial agent states and write them into ``data_dict``.

    Starting from the first slot of the ``*_for_trafficgen`` inputs, one agent is decoded per
    step with ``decode_one_step_initial_state`` until the step reports termination or
    ``min(N, 128)`` steps have run, where N is dim 2 of ``decoder/modeled_agent_position``.

    Returns:
        ``(data_dict, info)`` where ``info`` is currently an empty dict.
    """
    # step-data key -> data_dict key used to seed the decoding state
    seed_keys = {
        "input_action": "decoder/input_action_for_trafficgen",
        "input_action_valid_mask": "decoder/input_action_valid_mask_for_trafficgen",
        "agent_position": "decoder/modeled_agent_position_for_trafficgen",
        "agent_heading": "decoder/modeled_agent_heading_for_trafficgen",
        "agent_velocity": "decoder/modeled_agent_velocity_for_trafficgen",
        "agent_type": "decoder/agent_type_for_trafficgen",
        "agent_shape": "decoder/current_agent_shape_for_trafficgen",
        "agent_feature": "decoder/input_action_feature_for_trafficgen",
    }
    # data_dict key -> step-data key written back after decoding
    writeback_keys = {
        "decoder/modeled_agent_position_for_trafficgen": "agent_position",
        "decoder/modeled_agent_heading_for_trafficgen": "agent_heading",
        "decoder/modeled_agent_velocity_for_trafficgen": "agent_velocity",
        "decoder/current_agent_shape_for_trafficgen": "agent_shape",
        "decoder/agent_type_for_trafficgen": "agent_type",
        "decoder/input_action_valid_mask_for_trafficgen": "input_action_valid_mask",
    }

    max_steps = min(data_dict["decoder/modeled_agent_position"].shape[2], 128)

    # ===== Call Model =====
    data_dict, _ = encode_scene(data_dict=data_dict, model=model)
    scenario_token = data_dict["encoder/scenario_token"]
    batch_size = scenario_token.shape[0]

    state = {"input_step": torch.zeros([batch_size, 1], dtype=torch.long, device=scenario_token.device)}
    for step_key, source_key in seed_keys.items():
        state[step_key] = data_dict[source_key][:, :1].clone()

    for decode_step in range(max_steps):
        candidate_state, step_info = decode_one_step_initial_state(
            data_dict=data_dict,
            model=model,
            force_add=force_add,
            discard_low_speed_agent=discard_low_speed_agent,
            **state,
            exclude_sdc_neighborhood=exclude_sdc_neighborhood,
            decode_step=decode_step,
        )
        if step_info["terminated"]:
            # The terminating step is not adopted.
            break
        state = candidate_state

    # Slot 0 is the start action, so it is stripped from everything written back.
    generated = {target_key: state[step_key].clone()[:, 1:] for target_key, step_key in writeback_keys.items()}
    data_dict.update(generated)
    return data_dict, {}
def _sdc_neighborhood_token_mask(data_dict, agent_position, vocab_size, radius=50.0):
    """Return a (1, 1, vocab_size) bool mask of map tokens closer than `radius` to the SDC.

    The SDC is assumed to be the first agent, i.e. slot 1 after the start token. Only the
    first batch element is used.
    """
    sdc_pos = agent_position[0, 1]
    map_pos = data_dict["encoder/map_position"][0, :, :2]
    near_sdc = torch.cdist(map_pos, sdc_pos[None]) < radius
    token_mask = near_sdc.new_zeros((1, 1, vocab_size))
    token_mask[0, 0, :near_sdc.shape[0]] = near_sdc[:, 0]
    return token_mask


def decode_one_step_initial_state(
    *,
    data_dict,
    model,
    input_action,
    input_action_valid_mask,
    input_step,
    agent_position,
    agent_heading,
    agent_velocity,
    agent_type,
    agent_shape,
    agent_feature,
    decode_step=None,
    num_collisions=None,
    force_add=False,
    discard_low_speed_agent=False,
    num_low_speed=None,
    exclude_sdc_neighborhood=False,
):
    """Sample one more agent for the initial-state sequence and append it.

    The trafficgen decoder is run on the current sequence, the logits of the last position
    are masked by distance to the SDC, and a map token, agent type and offset are sampled
    and detokenized. A candidate that collides with an existing agent is rejected and the
    whole step is re-sampled.

    Args:
        data_dict: Scenario dictionary with the ``encoder/*`` entries.
        model: Model exposing ``trafficgen_decoder`` and ``config``.
        input_action, input_action_valid_mask, input_step, agent_position, agent_heading,
            agent_velocity, agent_type, agent_shape, agent_feature: Current sequence state;
            slot 0 along dim 1 is the start action.
        decode_step: Index of this step in the caller's loop. Only ``0`` has an effect: with
            ``model.config.FORCE_SDC_FOR_TRAFFICGEN`` the SDC's recorded state replaces the
            sample. Defaults to ``None`` (never force) so other callers need not pass it.
        num_collisions, num_low_speed, discard_low_speed_agent: Accepted for interface
            compatibility; they do not influence the result in this implementation.
        force_add: Passed to ``sample_action`` as ``force_no_end``.
        exclude_sdc_neighborhood: If true, map tokens within 50.0 of the SDC are masked out;
            otherwise only those tokens are allowed once the SDC exists.

    Returns:
        ``(step_data_dict, step_info)``: the sequence with the new agent appended, and
        ``step_info["terminated"]``, which is true when the appended slot is invalid.
    """
    B = data_dict["encoder/scenario_token"].shape[0]
    decoder = model.trafficgen_decoder
    tg_tokenizer = decoder.trafficgen_tokenizer
    logit_key = 'decoder/output_logit_for_trafficgen'

    raw_input_dict = {
        # Static features
        "encoder/scenario_token": data_dict["encoder/scenario_token"],
        "encoder/scenario_heading": data_dict["encoder/scenario_heading"],
        "encoder/scenario_position": data_dict["encoder/scenario_position"],
        "encoder/scenario_valid_mask": data_dict["encoder/scenario_valid_mask"],
        "encoder/map_position": data_dict["encoder/map_position"],
        "encoder/map_feature": data_dict["encoder/map_feature"],
        "encoder/map_valid_mask": data_dict["encoder/map_valid_mask"],
        "in_evaluation": torch.ones([B], dtype=torch.bool, device=data_dict["encoder/scenario_token"].device),

        # Actions
        "decoder/input_step_for_trafficgen": input_step,
        "decoder/input_action_for_trafficgen": input_action,
        "decoder/input_action_valid_mask_for_trafficgen": input_action_valid_mask,

        # Agent features
        "decoder/modeled_agent_position_for_trafficgen": agent_position,
        "decoder/modeled_agent_heading_for_trafficgen": agent_heading,
        "decoder/modeled_agent_velocity_for_trafficgen": agent_velocity,
        "decoder/agent_type_for_trafficgen": agent_type,
        "decoder/current_agent_shape_for_trafficgen": agent_shape,
        "decoder/input_action_feature_for_trafficgen": agent_feature,
    }

    # NOTE(review): rejection sampling without an upper bound - if every sample collides
    # this loop never ends. Consider a retry limit once a fallback has been agreed on.
    while True:
        raw_output_dict = decoder(copy.deepcopy(raw_input_dict))
        # Keep encoder entries as they are and only the last position of decoder outputs.
        output_dict = {}
        for key, value in raw_output_dict.items():
            if "encoder" in key or key == "in_evaluation":
                output_dict[key] = value
            elif "for_trafficgen" in key and "input_step" not in key:
                output_dict[key] = value[:, -1:]

        output_logit = output_dict[logit_key]
        vocab_size = output_logit.shape[-1]
        if exclude_sdc_neighborhood:
            # Forbid map tokens close to the SDC.
            assert agent_position.shape[1] >= 2
            near_sdc = _sdc_neighborhood_token_mask(data_dict, agent_position, vocab_size)
            output_dict[logit_key] = output_logit.masked_fill(near_sdc, -1e9)
        else:
            assert data_dict["encoder/map_position"].shape[0] == 1, "Batch size should be 1"
            # Allow only map tokens close to the SDC (once the SDC has been placed).
            if agent_position.shape[1] >= 2:
                near_sdc = _sdc_neighborhood_token_mask(data_dict, agent_position, vocab_size)
                output_dict[logit_key] = output_logit.masked_fill(~near_sdc, -1e9)

        temperature = 1.0
        sampled_action = decoder.sample_action(output_dict, force_no_end=force_add, temperature=temperature)
        is_end = sampled_action == tg_tokenizer.INIT_END_ACTION
        new_agent_type_output = decoder.forward_agent_type(output_dict, action=sampled_action)
        new_agent_type = decoder.sample_agent_type(new_agent_type_output, temperature=temperature)
        new_offset_output = decoder.forward_offset(output_dict, action=sampled_action, agent_type=new_agent_type)
        new_offset_action = decoder.sample_offset(offset_output=new_offset_output, temperature=temperature)
        predicted_values = tg_tokenizer.detokenize(
            data_dict, sampled_action, agent_type=new_agent_type, offset_action=new_offset_action
        )
        new_pos = predicted_values["position"]
        new_head = predicted_values["heading"]
        new_vel = predicted_values["velocity"]
        new_type = predicted_values["agent_type"]  # in 0,1,2
        new_shape = predicted_values["shape"]
        new_feature = predicted_values["feature"]

        if decode_step == 0 and model.config.FORCE_SDC_FOR_TRAFFICGEN:
            # Replace the sample by the recorded SDC state, anchored at its closest valid lane.
            if data_dict["decoder/agent_position"].shape[1] > 150:
                current_t = 0
            else:
                current_t = 10
            assert B == 1
            sdc_index = data_dict["decoder/sdc_index"][0].item()
            sdc_center = data_dict["decoder/agent_position"][:, current_t, sdc_index]
            map_to_sdc_dist = (data_dict["encoder/map_position"][0][..., :2] - sdc_center[0, :2]).norm(dim=-1)
            lane_valid_mask = (
                data_dict["encoder/map_valid_mask"].clone() & (data_dict["encoder/map_feature"][:, :, 0, 13] == 1)
            )
            map_to_sdc_dist[~lane_valid_mask[0]] = 1e6
            sampled_action = map_to_sdc_dist.argmin().unsqueeze(0).unsqueeze(-1)
            # The end flag above belongs to the discarded sample; recompute it so a sampled
            # END cannot mark the forced SDC as invalid and terminate generation.
            is_end = sampled_action == tg_tokenizer.INIT_END_ACTION

            new_pos = sdc_center[:, :2].reshape(B, 1, 2)
            new_head = data_dict["decoder/agent_heading"][:, current_t, sdc_index].unsqueeze(1)
            new_vel = data_dict["decoder/agent_velocity"][:, current_t, sdc_index].unsqueeze(1)
            new_type = data_dict["decoder/agent_type"][:, sdc_index].unsqueeze(1)
            new_shape = data_dict["decoder/current_agent_shape"][:, sdc_index].unsqueeze(1)
            # NOTE(review): indexed with sdc_index although trafficgen sequences carry a
            # leading start slot - confirm no "+ 1" is needed here.
            new_feature = data_dict["decoder/input_action_feature_for_trafficgen"][:, sdc_index].unsqueeze(1)

        no_coll = detect_collision_for_new_agent(
            agent_position=agent_position,
            agent_shape=agent_shape,
            agent_heading=agent_heading,
            new_pos=new_pos,
            new_head=new_head,
            new_shape=new_shape,
            input_action_valid_mask=input_action_valid_mask,
            is_end=is_end,
        )
        if not no_coll:
            # Candidate rejected: sample this step again.
            continue

        step_data_dict = dict(
            input_action=torch.cat([input_action, sampled_action], dim=1),
            input_action_valid_mask=torch.cat([input_action_valid_mask, ~is_end], dim=1),
            agent_position=torch.cat([agent_position, new_pos], dim=1),
            agent_heading=torch.cat([agent_heading, new_head], dim=1),
            agent_velocity=torch.cat([agent_velocity, new_vel], dim=1),
            agent_type=torch.cat([agent_type, new_type], dim=1),
            agent_shape=torch.cat([agent_shape, new_shape], dim=1),
            agent_feature=torch.cat([agent_feature, new_feature], dim=1),
            input_step=torch.cat([input_step, input_step[:, -1:] + 1], dim=1),
        )
        step_info = dict(
            terminated=not step_data_dict["input_action_valid_mask"][0, -1].item(),
        )
        return step_data_dict, step_info
def detect_collision_for_new_agent(
    *, agent_position, agent_heading, agent_shape, new_pos, new_head, new_shape, input_action_valid_mask, is_end
):
    """Check whether a candidate agent can be placed without hitting existing agents.

    Only batch size 1 is supported (asserted below). Contours are built with
    ``cal_polygon_contour`` and every existing agent except slot 0 is tested
    against the candidate with ``detect_collision``.

    Args:
        agent_position: Tensor with 3 dims and a leading batch dim of 1; columns 0/1 are read as x/y.
        agent_heading: Headings of the existing agents, indexed as ``[0, :]``.
        agent_shape: Shapes of the existing agents; column 0 is used as length, column 1 as width.
        new_pos, new_head, new_shape: Same layout as above, for the candidate agent.
        input_action_valid_mask: Per-slot validity, indexed as ``[0][slot]``.
        is_end: Flag tensor; ``~is_end[0]`` is forwarded as the candidate's validity.

    Returns:
        True when no tested slot reports a collision, False otherwise.
    """
    assert agent_position.ndim == 3
    assert agent_position.shape[0] == 1

    def _contours(position, heading, shape):
        # Box corners for every agent in the (single) batch element.
        return cal_polygon_contour(
            x=position[0, :, 0].cpu().numpy(),
            y=position[0, :, 1].cpu().numpy(),
            theta=heading[0, :].cpu().numpy(),
            width=shape[0, :, 1].cpu().numpy(),
            length=shape[0, :, 0].cpu().numpy(),
        )

    placed_contours = _contours(agent_position, agent_heading, agent_shape)  # (N, 4, 2)
    candidate_contour = _contours(new_pos, new_head, new_shape)

    # Slot 0 is skipped (it's the START_ACTION); with a single slot the loop is empty.
    for slot in range(1, placed_contours.shape[0]):
        hit = detect_collision(
            [placed_contours[slot]],
            [input_action_valid_mask[0][slot]],
            candidate_contour,
            ~is_end[0],
        )
        if hit[0]:
            return False
    return True


def convert_initial_states_as_motion_data(data_dict):
    """Re-publish the traffic-generation initial states under the motion-decoder keys.

    Every ``*_for_trafficgen`` tensor has its first slot along dim 1 dropped and
    is reshaped so that a singleton time axis is inserted where the motion
    decoder keys carry one. ``data_dict`` is updated in place and returned.

    Args:
        data_dict: Dict holding the ``decoder/*_for_trafficgen`` tensors plus
            ``decoder/agent_position`` (read only for batch size and device).

    Returns:
        The same ``data_dict`` object with the motion keys overwritten.
    """
    num_tg_agents = data_dict["decoder/modeled_agent_position_for_trafficgen"].shape[1] - 1
    from scenestreamer.tokenization.motion_tokenizers import START_ACTION as MOTION_START_ACTION, get_relative_velocity
    B = data_dict["decoder/agent_position"].shape[0]
    device = data_dict["decoder/agent_position"].device

    def _tg(name, *shape):
        # Drop slot 0 of the traffic-generation sequence and reshape the rest.
        return data_dict[name][:, 1:].reshape(*shape)

    position_key = "decoder/modeled_agent_position_for_trafficgen"
    heading_key = "decoder/modeled_agent_heading_for_trafficgen"
    velocity_key = "decoder/modeled_agent_velocity_for_trafficgen"
    shape_key = "decoder/current_agent_shape_for_trafficgen"

    updates = {}
    # Agent features
    updates["decoder/agent_position"] = _tg(position_key, B, 1, num_tg_agents, 2)
    updates["decoder/modeled_agent_position"] = _tg(position_key, B, 1, num_tg_agents, 2)
    updates["decoder/agent_heading"] = _tg(heading_key, B, 1, num_tg_agents)
    updates["decoder/modeled_agent_heading"] = _tg(heading_key, B, 1, num_tg_agents)
    updates["decoder/agent_velocity"] = _tg(velocity_key, B, 1, num_tg_agents, 2)
    updates["decoder/modeled_agent_velocity"] = _tg(velocity_key, B, 1, num_tg_agents, 2)
    updates["decoder/agent_valid_mask"] = _tg(
        "decoder/input_action_valid_mask_for_trafficgen", B, 1, num_tg_agents
    )
    updates["decoder/current_agent_shape"] = _tg(shape_key, B, num_tg_agents, 3)
    updates["decoder/modeled_agent_delta"] = get_relative_velocity(
        vel=_tg(velocity_key, B, 1, num_tg_agents, 2),
        heading=_tg(heading_key, B, 1, num_tg_agents),
    )
    updates["decoder/agent_type"] = _tg("decoder/agent_type_for_trafficgen", B, num_tg_agents)
    updates["decoder/agent_id"] = (
        torch.arange(num_tg_agents, dtype=torch.long).unsqueeze(0).repeat(B, 1).to(device)
    )
    # This data has temporal information
    updates["decoder/agent_shape"] = _tg(shape_key, B, 1, num_tg_agents, 3)

    # Action
    updates["decoder/input_action"] = torch.full(
        [B, 1, num_tg_agents], MOTION_START_ACTION, dtype=torch.long
    ).to(device)
    updates["decoder/current_agent_valid_mask"] = torch.full(
        [B, num_tg_agents], True, dtype=torch.bool
    ).to(device)
    updates["decoder/input_action_valid_mask"] = torch.full(
        [B, 1, num_tg_agents], True, dtype=torch.bool
    ).to(device)

    data_dict.update(updates)
    return data_dict
+ assert num_decode_steps == 19 + assert data_dict["decoder/input_action_valid_mask"].shape == (B, T_input, N) + else: + print("WARNING: You are freely generating future trajectory! num_decode_steps (was 19) =", num_decode_steps) + + # ===== Get initial data ===== + agent_pos = data_dict["decoder/agent_position"][:, ::tokenizer.num_skipped_steps] + agent_heading = data_dict["decoder/agent_heading"][:, ::tokenizer.num_skipped_steps] + agent_valid_mask = data_dict["decoder/agent_valid_mask"][:, ::tokenizer.num_skipped_steps] + agent_velocity = data_dict["decoder/agent_velocity"][:, ::tokenizer.num_skipped_steps] + B, T_full, N, _ = agent_pos.shape + gt_agent_delta = data_dict["decoder/modeled_agent_delta"].clone() + assert agent_pos.ndim == 4 + gt_input_action = data_dict["decoder/input_action"].clone() + if autoregressive_start_step > 0 or teacher_forcing_sdc: + gt_target_action = data_dict["decoder/target_action"].clone() + gt_target_valid_mask = data_dict["decoder/target_action_valid_mask"].clone() + + # ===== Initialize the state ===== + data_dict, _ = randomize_agent_id(data_dict=data_dict, model=model) + step_data_dict = dict( + input_step=torch.arange(1).to(gt_input_action.device), + input_action=gt_input_action[:, :1].clone(), + input_action_valid_mask=data_dict["decoder/input_action_valid_mask"][:, :1].clone(), + agent_position=data_dict["decoder/modeled_agent_position"][:, :1].clone(), + agent_heading=data_dict["decoder/modeled_agent_heading"][:, :1].clone(), + agent_velocity=data_dict["decoder/modeled_agent_velocity"][:, :1].clone(), # TODO: Remove this? 
+ agent_valid_mask=data_dict["decoder/input_action_valid_mask"][:, :1].clone(), + agent_delta=data_dict["decoder/modeled_agent_delta"][:, :1].clone(), + cache=None, + agent_id=data_dict["decoder/randomized_modeled_agent_id"], + agent_type=data_dict["decoder/agent_type"], + agent_shape=data_dict["decoder/current_agent_shape"], + + decode_step=0, + ) + if model.config.USE_DESTINATION: + step_data_dict["agent_destination"] = data_dict["decoder/dest_map_index"][:, :1].clone() + output_logit_list = [] + output_action_list = [] + input_action_valid_mask_list = [step_data_dict["input_action_valid_mask"]] + pos = [step_data_dict["agent_position"]] + head = [step_data_dict["agent_heading"]] + vel = [step_data_dict["agent_velocity"]] + decode_error_rate = [] + + # ===== Run motion generation ===== + data_dict, _ = encode_scene(data_dict=data_dict, model=model) + for decode_step in range(num_decode_steps): + if decode_step < autoregressive_start_step: + # Overwrite the action by GT action + # teacher_forcing_valid_mask = torch.ones_like(step_data_dict["input_action_valid_mask"]) + teacher_forcing_valid_mask = ( + step_data_dict["input_action_valid_mask"].clone() & gt_target_valid_mask[:, decode_step:decode_step + 1] + ) + teacher_forcing_action = gt_target_action[:, decode_step:decode_step + 1] + + assert gt_target_valid_mask[:, decode_step:decode_step + 1][teacher_forcing_valid_mask].all() + + else: + teacher_forcing_valid_mask = None + teacher_forcing_action = None + + if teacher_forcing_sdc: + assert data_dict["decoder/sdc_index"][0] == 0 + if teacher_forcing_valid_mask is None: + teacher_forcing_valid_mask = torch.zeros_like(step_data_dict["input_action_valid_mask"]) + assert teacher_forcing_valid_mask.shape == (B, 1, N) + teacher_forcing_valid_mask[:, :, 0] = 1 + teacher_forcing_valid_mask = teacher_forcing_valid_mask & step_data_dict["input_action_valid_mask"] + teacher_forcing_valid_mask = teacher_forcing_valid_mask & gt_target_valid_mask[:, + decode_step:decode_step + 
1] + step_data_dict["agent_valid_mask"] = torch.where( + teacher_forcing_valid_mask, + step_data_dict["agent_valid_mask"] & gt_target_valid_mask[:, decode_step:decode_step + 1], + step_data_dict["agent_valid_mask"] + ) + step_data_dict["input_action_valid_mask"] = step_data_dict["agent_valid_mask"] + if teacher_forcing_action is None: + teacher_forcing_action = gt_target_action[:, decode_step:decode_step + 1] + + + assert step_data_dict["decode_step"] == decode_step + + next_state_data_dict, decode_one_step_info = decode_one_step( + data_dict=data_dict, + model=model, + sampling_method=sampling_method, + temperature=temperature, + topp=topp, + teacher_forcing_valid_mask=teacher_forcing_valid_mask, + teacher_forcing_action=teacher_forcing_action, + remove_out_of_map_agent=remove_out_of_map_agent, + **step_data_dict, + remove_static_agent=remove_static_agent, + ) + + if decode_step < allow_newly_added_agent_step: + new_agent_valid_mask = agent_valid_mask[:, decode_step + 1:decode_step + + 2] & (~step_data_dict["agent_valid_mask"]) + next_state_data_dict, decode_one_step_info = add_new_agent( + step_data_dict=next_state_data_dict, + step_info=decode_one_step_info, + new_agent_valid_mask=new_agent_valid_mask, + new_agent_pos=agent_pos[:, decode_step + 1:decode_step + 2, ..., :2], + new_agent_heading=agent_heading[:, decode_step + 1:decode_step + 2], + new_agent_velocity=agent_velocity[:, decode_step + 1:decode_step + 2], + new_agent_delta=gt_agent_delta[:, decode_step + 1:decode_step + 2], + new_action=gt_input_action[:, decode_step + 1:decode_step + 2], + ) + + pos.append(next_state_data_dict["agent_position"].clone()) + head.append(next_state_data_dict["agent_heading"].clone()) + vel.append(next_state_data_dict["agent_velocity"].clone()) + if decode_one_step_info["output_token"] is not None: + output_logit_list.append(decode_one_step_info["output_token"].clone()) + output_action_list.append(next_state_data_dict["input_action"].clone()) + 
input_action_valid_mask_list.append(next_state_data_dict["input_action_valid_mask"].clone()) + step_data_dict = next_state_data_dict + if "error_rate" in decode_one_step_info: + decode_error_rate.append(decode_one_step_info["error_rate"]) + + # ===== Post-process the data ===== + + if output_action_list[0].ndim == 4: + max_seq_len = max([x.shape[-1] for x in output_action_list]) + output_action_list = [ + torch.nn.functional.pad(output_action_list[i], (0, max_seq_len - output_action_list[i].shape[-1]), value=-1) + for i in range(len(output_action_list)) + ] + output_action_list = torch.concatenate(output_action_list, dim=1) + + elif output_action_list[0].ndim == 3: + output_action_list = torch.concatenate(output_action_list, dim=1) + + else: + raise ValueError("Invalid output_action_list shape: {}".format(output_action_list[0].shape)) + + assert output_action_list.shape[:3] == (B, num_decode_steps, N) + assert len(input_action_valid_mask_list) == num_decode_steps + 1 + # Evict the last step's input_action_valid_mask_list as it is not used. 
def encode_scene(*, data_dict, model):
    """Run the scene encoder unless its output is already present.

    Args:
        data_dict: Batch dict; treated as already encoded when it contains
            ``"encoder/scenario_token"``.
        model: Object exposing ``encode_scene(data_dict)``; only used when the
            token is missing.

    Returns:
        ``(data_dict, {})`` where ``data_dict`` is either the input object or
        whatever ``model.encode_scene`` returned.
    """
    if "encoder/scenario_token" in data_dict:
        return data_dict, {}
    return model.encode_scene(data_dict), {}


def randomize_agent_id(*, data_dict, model, clip_agent_id=True):
    """Attach randomized agent ids to ``data_dict`` unless they already exist.

    Args:
        data_dict: Batch dict; left untouched when it already has
            ``"decoder/randomized_modeled_agent_id"``.
        model: Object whose ``motion_decoder.randomize_modeled_agent_id`` produces the ids.
        clip_agent_id: Forwarded unchanged to ``randomize_modeled_agent_id``.

    Returns:
        ``(data_dict, {})``; the dict is modified in place when ids are added.
    """
    key = "decoder/randomized_modeled_agent_id"
    if key in data_dict:
        return data_dict, {}
    randomized_ids = model.motion_decoder.randomize_modeled_agent_id(data_dict, clip_agent_id=clip_agent_id)
    data_dict[key] = randomized_ids
    return data_dict, {}
agent_id, + } + + if agent_destination is not None: + assert decode_step is not None + # TODO: This is a workaround to update the destination following GT data. + agent_destination = data_dict["decoder/dest_map_index"][:, decode_step:decode_step + 1] + + input_dict["decoder/dest_map_index"] = agent_destination + + assert (agent_valid_mask == input_action_valid_mask).all() + + if cache is not None: + input_dict.update( + { + "decoder/cache": cache, + "decoder/modeled_agent_position_history": agent_position_history, + "decoder/modeled_agent_heading_history": agent_heading_history, + "decoder/modeled_agent_valid_mask_history": agent_valid_mask_history, + "decoder/modeled_agent_step_history": agent_step_history, + } + ) + assert not (input_action == END_ACTION).any() + + # Decode motion tokens + output_dict = model.decode_motion(input_dict, use_cache=True) + + if model.config.TOKENIZATION.TOKENIZATION_METHOD == "fast": + selected_action = output_dict["decoder/output_token"] + output_token = None + + selected_action = selected_action.masked_fill(selected_action >= model.tokenizer.fast_tokenizer.vocab_size, -1) + + selected_action = torch.where(input_action_valid_mask.unsqueeze(-1), selected_action, -1) + + else: + output_token = output_dict["decoder/output_logit"] + selected_action, sampling_info = sample_action( + logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp + ) + + # Remove invalid actions + # assert selected_action.shape == input_action.shape + # correct_selected_action = torch.where(input_action_valid_mask, selected_action, -1) + selected_action = torch.where(input_action_valid_mask, selected_action, -1) + + if teacher_forcing_valid_mask is not None: + + if model.config.TOKENIZATION.TOKENIZATION_METHOD == "fast": + + assert teacher_forcing_action.shape[:3] == selected_action.shape[:3] + + teacher_forcing_action, selected_action = pad_sequences( + teacher_forcing_action, selected_action, x_value=-1, y_value=-1 + ) + 
selected_action = torch.where( + teacher_forcing_valid_mask[..., None], teacher_forcing_action, selected_action + ) + + else: + assert teacher_forcing_action.shape == selected_action.shape + selected_action = torch.where(teacher_forcing_valid_mask, teacher_forcing_action, selected_action) + # correct_selected_action = torch.where(teacher_forcing_valid_mask, teacher_forcing_action, correct_selected_action) + output_token[teacher_forcing_valid_mask] = 0 + + tokenizer = model.tokenizer + res = tokenizer.detokenize_step( + current_pos=agent_position, + current_heading=agent_heading, + current_valid_mask=agent_valid_mask, + current_vel=agent_velocity, + action=selected_action, + agent_type=agent_type, + ) + + # debug_err = (tokenizer.detokenize_step( + # current_pos=agent_position, + # current_heading=agent_heading, + # current_valid_mask=agent_valid_mask, + # current_vel=agent_velocity, + # action=selected_action, + # )['pos'] - tokenizer.detokenize_step( + # current_pos=agent_position, + # current_heading=agent_heading, + # current_valid_mask=agent_valid_mask, + # current_vel=agent_velocity, + # action=correct_selected_action, + # )['pos']).norm(dim=-1) + # + # assert (debug_err==0).all() + + B, _, N = input_action.shape[:3] + current_pos = res["pos"].reshape(B, 1, N, 2) + current_heading = res["heading"].reshape(B, 1, N) + current_vel = res["vel"].reshape(B, 1, N, 2) + current_delta = res["delta_pos"].reshape(B, 1, N, 2) + current_model_step = input_step + 1 + current_input_action = selected_action + + current_valid_mask = agent_valid_mask + + next_step_data_dict = dict( + input_step=current_model_step, + input_action=current_input_action, + input_action_valid_mask=current_valid_mask, + agent_position=current_pos, + agent_heading=current_heading, + agent_velocity=current_vel, + agent_valid_mask=current_valid_mask, + agent_delta=current_delta, + agent_id=agent_id, + agent_type=agent_type, + agent_shape=agent_shape, + cache=output_dict["decoder/cache"], + 
def sample_action(logits, sampling_method, temperature, topp):
    """Pick one action index along the last dimension of ``logits``.

    Args:
        logits: Unnormalised scores; the last dimension enumerates the actions.
        sampling_method: One of ``"argmax"``, ``"softmax"``, ``"topp"`` or ``"topk"``.
        temperature: Divisor applied to the logits for ``"softmax"`` and ``"topp"``.
        topp: Cumulative-probability threshold forwarded to ``nucleus_sampling``.

    Returns:
        ``(selected_action, info)``; ``info`` is empty except for ``"topp"``,
        where it is the info dict returned by ``nucleus_sampling``.

    Raises:
        ValueError: If ``sampling_method`` is not one of the supported names.
    """
    info = {}
    if sampling_method == "argmax":
        selected_action = logits.argmax(-1)
    elif sampling_method == "softmax":
        selected_action = torch.distributions.Categorical(logits=logits / temperature).sample()
    elif sampling_method == "topp":
        selected_action, info = nucleus_sampling(logits=logits / temperature, p=topp)
    elif sampling_method == "topk":
        # Uniform choice among the k best actions. k is clamped to the number of
        # actions so that a vocabulary smaller than 5 no longer makes topk raise.
        k = min(5, logits.shape[-1])
        candidates = logits.topk(k, dim=-1).indices
        choice = torch.randint(0, k, size=candidates.shape[:-1])[..., None].to(candidates)
        selected_action = torch.gather(candidates, index=choice, dim=-1).squeeze(-1)
    else:
        raise ValueError("Unknown sampling method: {}".format(sampling_method))
    return selected_action, info


def nucleus_sampling(logits, p=None, epsilon=1e-8):
    """Sample from the smallest set of top tokens whose cumulative probability exceeds ``p``.

    Args:
        logits: Unnormalised scores; the last dimension enumerates the tokens.
        p: Cumulative-probability threshold. ``None`` selects the default 0.9;
            ``0.0`` is honoured and keeps only the single most likely token.
        epsilon: Unused; kept for interface compatibility.

    Returns:
        ``(sampled_token_index, {"cutoff_index": mask})`` where ``mask`` is in
        sorted (descending probability) order and is True for dropped tokens.
    """
    # `p or 0.9` would silently turn an explicit p=0.0 into 0.9.
    if p is None:
        p = 0.9

    # Sanitise non-finite logits. NaN and -inf become a very low score; +inf
    # becomes a very high one (mapping it to -1e9 would invert the preference).
    logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

    # Convert logits to probabilities
    probs = torch.softmax(logits, dim=-1)

    # Sort the probabilities to identify the top-p cutoff
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Mark tokens whose cumulative probability is above the threshold p ...
    cutoff_index = cumulative_probs > p
    # ... then shift the mask right so the first token crossing the threshold is
    # kept, and make sure the most likely token is never dropped.
    cutoff_index[..., 1:] = cutoff_index[..., :-1].clone()
    cutoff_index[..., 0] = False

    # Zero out the probabilities for tokens not in the top-p set
    sorted_probs.masked_fill_(cutoff_index, 0)

    # Recover the original order of the probabilities
    original_probs = torch.zeros_like(probs)
    original_probs.scatter_(dim=-1, index=sorted_indices, src=sorted_probs)
    sampled_token_index = torch.distributions.Categorical(probs=original_probs).sample()
    return sampled_token_index, {"cutoff_index": cutoff_index}


def add_new_agent(
    *, step_data_dict, step_info, new_agent_valid_mask, new_agent_pos, new_agent_heading, new_agent_velocity,
    new_agent_delta, new_action
):
    """Overwrite the state of the agents selected by ``new_agent_valid_mask``.

    For every selected agent the position, heading, velocity, delta and input
    action are replaced by the supplied values, the agent is marked valid, and
    its entry in ``step_info["output_token"]`` (if any) is zeroed.
    ``step_data_dict`` and ``step_info`` are modified in place.

    Args:
        step_data_dict: Current autoregressive state.
        step_info: Info dict of the current step; must contain ``"output_token"``.
        new_agent_valid_mask: Bool tensor (B, T, N) or None. None or all-False is a no-op.
        new_agent_pos: (B, T, N, 2) positions.
        new_agent_heading: (B, T, N) headings.
        new_agent_velocity: (B, T, N, 2) velocities.
        new_agent_delta: (B, T, N, 2) deltas.
        new_action: 3-D fixed-length or 4-D variable-length actions.

    Returns:
        ``(step_data_dict, step_info)``.

    Raises:
        ValueError: If ``new_action`` is neither 3-D nor 4-D.
    """
    if new_agent_valid_mask is None or not new_agent_valid_mask.any():
        return step_data_dict, step_info

    B, T, N = new_agent_valid_mask.shape
    assert new_agent_pos.shape == (B, T, N, 2)
    assert new_agent_heading.shape == (B, T, N)
    assert new_agent_velocity.shape == (B, T, N, 2)
    assert new_agent_delta.shape == (B, T, N, 2)

    current_valid_mask = step_data_dict["agent_valid_mask"]

    mask_2d = new_agent_valid_mask[..., None].expand_as(new_agent_pos)
    step_data_dict["agent_position"] = torch.where(mask_2d, new_agent_pos, step_data_dict["agent_position"])
    step_data_dict["agent_heading"] = torch.where(
        new_agent_valid_mask, new_agent_heading, step_data_dict["agent_heading"]
    )
    step_data_dict["agent_velocity"] = torch.where(mask_2d, new_agent_velocity, step_data_dict["agent_velocity"])
    # Selected agents become valid; every other agent keeps its previous validity.
    current_valid_mask = torch.where(new_agent_valid_mask, new_agent_valid_mask, current_valid_mask)
    step_data_dict["agent_valid_mask"] = current_valid_mask
    step_data_dict["agent_delta"] = torch.where(mask_2d, new_agent_delta, step_data_dict["agent_delta"])

    if new_action.ndim == 4:
        # Variable length action: pad both sides to a common length with -1 first.
        new_action, old_action = pad_sequences(new_action, step_data_dict["input_action"], x_value=-1, y_value=-1)
        step_data_dict["input_action"] = torch.where(new_agent_valid_mask[..., None], new_action, old_action)
    elif new_action.ndim == 3:
        step_data_dict["input_action"] = torch.where(new_agent_valid_mask, new_action, step_data_dict["input_action"])
    else:
        raise ValueError("Invalid new_action shape: {}".format(new_action.shape))
    step_data_dict["input_action_valid_mask"] = current_valid_mask

    output_token = step_info["output_token"]
    if output_token is not None:
        # The overwritten agents did not produce these outputs, so blank them.
        output_token = torch.where(
            new_agent_valid_mask[..., None].expand_as(output_token), torch.zeros_like(output_token), output_token
        )
        step_info["output_token"] = output_token

    assert_motion_step_data_dict(step_data_dict=step_data_dict, step_info=step_info)

    return step_data_dict, step_info


def interpolate_autoregressive_output(
    *, data_dict, num_skipped_steps, num_decoded_steps, agent_position, agent_heading, agent_velocity, input_valid_mask
):
    """Upsample the per-chunk rollout and store it under ``decoder/reconstructed_*``.

    Position and velocity go through ``interpolate`` and heading through
    ``interpolate_heading``. The validity of each chunk is repeated
    ``num_skipped_steps`` times and the last chunk's validity is appended once,
    so every output has ``num_skipped_steps * num_decoded_steps + 1`` steps
    (asserted). Invalid entries are zeroed by multiplying with the mask.

    Args:
        data_dict: Dict that receives the reconstructed tensors (modified in place).
        num_skipped_steps: Upsampling factor between chunks and fine steps.
        num_decoded_steps: Number of decoded chunks.
        agent_position: 4-D tensor; dims 0 and 2 are read as B and N.
        agent_heading: Headings per chunk, same leading dims as the position.
        agent_velocity: Velocities per chunk, same layout as the position.
        input_valid_mask: Bool validity per chunk, reshaped to (B, -1, 1, N).

    Returns:
        ``(data_dict, {})``.
    """
    B, _, N, _ = agent_position.shape
    T_generated_chunks = num_decoded_steps
    reconstructed_pos = interpolate(agent_position, num_skipped_steps, remove_first_step=False)
    # Every num_skipped_steps-th sample must reproduce the chunk-level input
    # (the stride used to be a literal 5, only right for num_skipped_steps == 5).
    assert (reconstructed_pos[:, ::num_skipped_steps] == agent_position).all()
    reconstructed_heading = interpolate_heading(agent_heading, num_skipped_steps, remove_first_step=False)
    reconstructed_vel = interpolate(agent_velocity, num_skipped_steps, remove_first_step=False)

    valid = input_valid_mask.reshape(B, -1, 1, N).expand(-1, -1, num_skipped_steps, -1).reshape(B, -1, N)
    valid = torch.cat([valid, input_valid_mask[:, -1:]], dim=1)
    reconstructed_valid_mask = valid

    # Mask out:
    reconstructed_pos = reconstructed_pos * reconstructed_valid_mask.unsqueeze(-1)
    reconstructed_vel = reconstructed_vel * reconstructed_valid_mask.unsqueeze(-1)
    reconstructed_heading = reconstructed_heading * reconstructed_valid_mask

    # We ensure that the output must be num_skipped_steps*T_chunks+1
    expected_steps = num_skipped_steps * T_generated_chunks + 1
    assert reconstructed_pos.shape[1] == expected_steps
    assert reconstructed_valid_mask.shape[1] == expected_steps
    assert reconstructed_vel.shape[1] == expected_steps
    assert reconstructed_heading.shape[1] == expected_steps

    data_dict["decoder/reconstructed_position"] = reconstructed_pos
    data_dict["decoder/reconstructed_heading"] = reconstructed_heading
    data_dict["decoder/reconstructed_velocity"] = reconstructed_vel
    data_dict["decoder/reconstructed_valid_mask"] = reconstructed_valid_mask

    return data_dict, {}


def evict_agents(
    *,
    data_dict,
    step_data_dict,
    step_info_dict,
    max_distance=10,
    remove_static_agent=False,
    remove_out_of_map_agent=False
):
    """Invalidate agents that left the map and/or are (almost) standing still.

    Only the two validity masks of ``step_data_dict`` are changed. Positions,
    velocities and actions of evicted agents are deliberately left untouched.

    Args:
        data_dict: Needs ``"encoder/map_position"`` when ``remove_out_of_map_agent`` is set.
        step_data_dict: State for the next step; modified in place.
        step_info_dict: Receives ``"evicted_agents"`` and ``"evicted_agent_mask"``.
        max_distance: Distance to the closest map point beyond which an agent is evicted.
        remove_static_agent: Evict agents whose speed is below 0.5.
        remove_out_of_map_agent: Evict agents farther than ``max_distance`` from every map point.

    Returns:
        ``(step_data_dict, step_info_dict)``.
    """
    should_evict = None

    if remove_out_of_map_agent:
        map_position = data_dict["encoder/map_position"][..., :2]
        agent_position = step_data_dict["agent_position"]
        assert agent_position.ndim == 4
        agent_position = agent_position[:, 0]

        # Distance from every agent to its closest map point.
        min_dist = torch.cdist(agent_position, map_position).min(dim=-1).values
        should_evict = min_dist > max_distance

    if remove_static_agent:
        agent_speed = step_data_dict["agent_velocity"].norm(dim=-1)[:, 0]
        static_agent = agent_speed < 0.5
        if should_evict is None:
            should_evict = static_agent
        else:
            should_evict = torch.logical_or(should_evict, static_agent)

    if should_evict is None or should_evict.sum().item() == 0:
        step_info_dict["evicted_agents"] = 0
        step_info_dict["evicted_agent_mask"] = None
        return step_data_dict, step_info_dict

    num_evicted = should_evict.sum().item()

    # We should inform the autoregressive process not to generate action in next step.
    # However, the current step's action is still valid (its input_action_valid_mask
    # was valid), and the outer process is still waiting for the new states of the
    # agents. Therefore only the masks are updated, not the states/actions.
    valid_mask = step_data_dict["input_action_valid_mask"]
    # should_evict has no time axis. Give it one when the mask has it, otherwise
    # (B, 1, N) & (B, N) would broadcast to (B, B, N) for B > 1.
    evict_mask = should_evict.unsqueeze(1) if valid_mask.ndim == should_evict.ndim + 1 else should_evict
    new_mask = valid_mask & (~evict_mask)
    step_data_dict["input_action_valid_mask"] = new_mask
    step_data_dict["agent_valid_mask"] = new_mask

    step_info_dict["evicted_agents"] = num_evicted
    step_info_dict["evicted_agent_mask"] = should_evict
    # The former positional call to the keyword-only assert_motion_step_data_dict
    # raised TypeError on every eviction, so it is dropped here. Its invariant
    # (invalid agents carry action -1 and zero logits) cannot hold for agents
    # evicted in this step, because their action is intentionally kept.
    # NOTE(review): decode_one_step runs the same check right after this function
    # and will presumably still trip when a valid agent is evicted - confirm.

    return step_data_dict, step_info_dict


def assert_motion_step_data_dict(*, step_data_dict, step_info):
    """Assert the invariants of an autoregressive motion step.

    Checks that all state keys are present, that invalid agents carry the
    action ``-1``, that both validity masks agree and, when logits are
    available, that invalid agents have all-zero logits.

    Args:
        step_data_dict: State dict to validate.
        step_info: Info dict; must contain ``"output_token"`` (may be None).

    Raises:
        AssertionError: If any invariant is violated.
    """
    required_keys = (
        "input_step",
        "input_action",
        "input_action_valid_mask",
        "agent_position",
        "agent_heading",
        "agent_velocity",
        "agent_valid_mask",
        "agent_delta",
        "agent_id",
        "agent_type",
        "agent_shape",
    )
    for key in required_keys:
        assert key in step_data_dict, key

    m = step_data_dict["input_action_valid_mask"]
    assert (step_data_dict["input_action"][~m] == -1).all()
    assert (m == step_data_dict["agent_valid_mask"]).all()
    if step_info["output_token"] is not None:
        assert (step_info["output_token"][~m] == 0).all()


def pad_sequences(x, y, x_value=0, y_value=0):
    """Right-pad ``x`` and ``y`` along the last dimension to a common length.

    Args:
        x: First tensor.
        y: Second tensor.
        x_value: Fill value used when padding ``x``.
        y_value: Fill value used when padding ``y``.

    Returns:
        ``(x, y)`` with equal size in the last dimension.
    """
    max_seq_len = max(x.shape[-1], y.shape[-1])
    x = torch.nn.functional.pad(x, (0, max_seq_len - x.shape[-1]), value=x_value)
    y = torch.nn.functional.pad(y, (0, max_seq_len - y.shape[-1]), value=y_value)
    return x, y
def evict_agents_function(
    *,
    data_dict,
    step_data_dict,
    step_info_dict,
    max_distance=10,
    remove_static_agent=False,
    remove_out_of_map_agent=False,
    static_speed_threshold=0.5,
):
    """Mark agents that should stop being rolled out and update the masks.

    An agent is evicted when any enabled criterion holds:

    * ``remove_out_of_map_agent``: its distance to the nearest entry of
      ``data_dict["encoder/map_position"]`` exceeds ``max_distance``.
    * ``remove_static_agent``: its speed (norm of ``agent_velocity``) is
      below ``static_speed_threshold``.

    Args:
        data_dict: Scene dictionary; ``"encoder/map_position"`` is read only
            when ``remove_out_of_map_agent`` is set.
        step_data_dict: Per-step dictionary. ``"input_action_valid_mask"`` and
            ``"agent_valid_mask"`` are overwritten in place when at least one
            agent is evicted.
        step_info_dict: Receives ``"evicted_agents"`` (int) and
            ``"evicted_agent_mask"`` (bool tensor, or ``None`` when nothing
            was evicted).
        max_distance: Distance threshold for the out-of-map criterion.
        remove_static_agent: Enable the low-speed criterion.
        remove_out_of_map_agent: Enable the out-of-map criterion.
        static_speed_threshold: Speed below which an agent counts as static.
            The default keeps the previously hard-coded value.

    Returns:
        The same ``(step_data_dict, step_info_dict)`` objects that were
        passed in.
    """
    should_evict = None

    if remove_out_of_map_agent:
        map_position = data_dict["encoder/map_position"][..., :2]
        agent_position = step_data_dict["agent_position"]
        assert agent_position.ndim == 4
        agent_position = agent_position[:, 0]

        # NOTE(review): map entries are not filtered by a validity mask here;
        # if padded map points exist they also count as "map" -- confirm.
        min_dist = torch.cdist(agent_position, map_position).min(dim=-1).values
        should_evict = min_dist > max_distance

    if remove_static_agent:
        agent_speed = step_data_dict["agent_velocity"].norm(dim=-1)[:, 0]
        static_agent = agent_speed < static_speed_threshold
        if should_evict is None:
            should_evict = static_agent
        else:
            should_evict = torch.logical_or(should_evict, static_agent)

    # Count once: every ``.item()`` forces a host/device synchronisation.
    num_evicted = 0 if should_evict is None else should_evict.sum().item()
    if num_evicted == 0:
        step_info_dict["evicted_agents"] = 0
        step_info_dict["evicted_agent_mask"] = None
        return step_data_dict, step_info_dict

    # Tell the autoregressive loop not to produce actions for these agents in
    # the next step. The current step's action is still valid, and the outer
    # loop is still waiting for these agents' new states, so only the
    # validity masks are changed and positions/headings/velocities are kept.
    new_mask = step_data_dict["input_action_valid_mask"] & (~should_evict)
    step_data_dict["input_action_valid_mask"] = new_mask
    step_data_dict["agent_valid_mask"] = new_mask

    step_info_dict["evicted_agents"] = num_evicted
    step_info_dict["evicted_agent_mask"] = should_evict

    return step_data_dict, step_info_dict
    def reset(self, new_sd=None, new_data_dict=None):
        """Start a new rollout from ``new_data_dict``.

        Stores a deep copy of ``new_data_dict``, rewinds the state machine to
        ``STATE_START`` / step 0, caches the sampling settings from the model
        config, encodes the scene, and precomputes the causal and force
        attention masks for the whole rollout.

        Args:
            new_sd: Not used by this method.
            new_data_dict: Batched scene dictionary. It is subscripted below,
                so the default ``None`` is not actually usable.
        """
        # Deep copy so the attribute does not alias the caller's dictionary.
        self.raw_data_dict = copy.deepcopy(new_data_dict)
        self.state = self.STATE_START
        self.current_step = 0

        model = self.model
        data_dict = self.raw_data_dict
        # assert teacher_forcing_dest is not None, "Please set teacher_forcing_dest to True or False"
        # ===== Some preprocessing =====
        self.topp = model.config.SAMPLING.TOPP
        self.temperature = model.config.SAMPLING.TEMPERATURE
        self.sampling_method = model.config.SAMPLING.SAMPLING_METHOD
        B, T_input, N = data_dict["decoder/input_action"].shape[:3]
        # The rollout length is fixed; the assert only documents that constant.
        num_decode_steps = 19
        assert num_decode_steps == 19
        assert data_dict["decoder/input_action_valid_mask"].shape == (B, T_input, N)

        # ===== Encode scenes =====
        # NOTE(review): only the local name is rebound to the returned dict.
        # Other methods read "model/..." keys from self.raw_data_dict, so
        # encode_scene presumably fills them in place -- confirm.
        data_dict, _ = scenestreamer_motion.encode_scene(data_dict=data_dict, model=model)

        # ===== Create a temporary input_dict removing the future information =====
        _, _, L = data_dict["encoder/traffic_light_state"].shape
        # The token container is created elsewhere; here it is only cleared.
        self.scenestreamer_tokens = None
        self.step_info_dict = {}

        # TODO: Are they still correct if we are using TG??
        G = get_num_tg(N)
        # Full-rollout masks; the step methods slice one query row at a time.
        all_token_casual_mask = model._build_all_tokens_mask(
            B=B, T=num_decode_steps, num_tl=L, num_tg=G, num_motion=N
        ).to(data_dict["decoder/input_action"].device)
        self.all_token_casual_mask = all_token_casual_mask
        all_force_mask = model._build_all_force_mask(
            B=B, T=num_decode_steps, num_tl=L, num_tg=G, num_motion=N
        ).to(data_dict["decoder/input_action"].device)
        self.all_force_mask = all_force_mask
tg_intra_step=torch.full((B, 1), tg_intra_step, device=device), + tg_feat=torch.zeros((B, 1, 8), device=device), + ) + tg_causal_mask = all_token_casual_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + 1, + :scenestreamer_tokens.seq_len + 1] + scenestreamer_tokens.add( + token=tg_token, + position=selected_map_pos, + heading=selected_map_heading, + valid_mask=agent_valid_mask, + width=torch.full((B, 1), 0.0, device=device), + length=torch.full((B, 1), 0.0, device=device), + causal_mask=tg_causal_mask, + current_step=current_step, + require_relation=agent_valid_mask, + force_mask=all_force_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + 1, + :scenestreamer_tokens.seq_len + 1] + ) + + # Call model if you don't want to teacher forcing: + if teacher_forcing_agent_state: + return None, {} + else: + assert agent_valid_mask.all() + output_dict = scenestreamer_tokens.call_model_with_cache(keep_output_token=self.keep_output_token) + output_token = output_dict["model/all_token"][:, -1:] + # call pred head to get agent feat. 
+ + output_token = model.trafficgen_prenorm(output_token) + + z = output_token.reshape(self.B, -1) + offset_output = model.trafficgen_head.generate(z=z) + + assert offset_output.shape == (B, 9) + + offset_action = offset_output[:, 1:] + assert offset_action.shape == (B, 8), "offset_action shape: {}".format(offset_action.shape) + offset_action = { + "length": offset_action[:, 0].reshape(B, 1), + "width": offset_action[:, 1].reshape(B, 1), + "height": offset_action[:, 2].reshape(B, 1), + "position_x": offset_action[:, 3].reshape(B, 1), + "position_y": offset_action[:, 4].reshape(B, 1), + "heading": offset_action[:, 5].reshape(B, 1), + "velocity_x": offset_action[:, 6].reshape(B, 1), + "velocity_y": offset_action[:, 7].reshape(B, 1), + } + agent_state_output = self.model.trafficgen_tokenizer.detokenize( + data_dict=self.raw_data_dict, action=tg_input_action, agent_type=agent_type, offset_action=offset_action + ) + return agent_state_output, {} + + def _get_map_pos_head(self, index: torch.Tensor): + assert index.numel() == self.B + M = self.raw_data_dict["model/map_token_position"].shape[1] + assert index.min() >= 0, "index: {}, M: {}".format(index, M) + assert index.max() < M, "index: {}, M: {}".format(index, M) + pos = torch.gather( + self.raw_data_dict["model/map_token_position"][..., :2], + index=index.reshape(self.B, 1, 1).expand(self.B, 1, 2), + dim=1 + ) + pos = pos.reshape(self.B, 1, 2) + heading = torch.gather( + self.raw_data_dict["model/map_token_heading"], + index=index.reshape(self.B, 1), + dim=1 + ) + heading = heading.reshape(self.B, 1) + return pos, heading + + def _tg_generate_agent_map_id( + self, *, agent_id, agent_type, tg_intra_step, tg_input_action, agent_valid_mask, teacher_forcing_map_id, + sdc_position, + ): + assert (agent_type == -1).all() + assert agent_id.shape == (self.B, 1) + assert agent_type.shape == (self.B, 1) + assert tg_input_action.shape == (self.B, 1) + assert agent_valid_mask.shape == (self.B, 1) + assert tg_input_action.min() >= 
self.model.veh_id + assert tg_input_action.max() <= self.model.cyc_id + + model = self.model + B = self.B + all_token_casual_mask = self.all_token_casual_mask + all_force_mask = self.all_force_mask + scenestreamer_tokens = self.scenestreamer_tokens + device = self.device + current_step = self.current_step + + tg_token = model.prepare_trafficgen_single_token( + tg_action=tg_input_action, + tg_type=agent_type, + tg_agent_id=agent_id, + tg_intra_step=torch.full((B, 1), tg_intra_step, device=device), + tg_feat=torch.zeros((B, 1, 8), device=device), + ) + tg_causal_mask = all_token_casual_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + 1, + :scenestreamer_tokens.seq_len + 1] + scenestreamer_tokens.add( + token=tg_token, + position=torch.zeros((B, 1, 2), device=device), + heading=torch.zeros((B, 1), device=device), + valid_mask=agent_valid_mask, + width=torch.full((B, 1), 0.0, device=device), + length=torch.full((B, 1), 0.0, device=device), + causal_mask=tg_causal_mask, + current_step=current_step, + # <<< When generating map id, input is agent type. so no relation is required. + require_relation=torch.full((B, 1), False, device=device), + force_mask=all_force_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + 1, + :scenestreamer_tokens.seq_len + 1] + ) + # Call model if you don't want to teacher forcing: + if teacher_forcing_map_id: + return None, {} + + else: + assert agent_valid_mask.all() + output_dict = scenestreamer_tokens.call_model_with_cache(keep_output_token=self.keep_output_token) + output_token = output_dict["model/all_token"][:, -1:] + # call pred head to get agent feat. 
+ + output_token = model.trafficgen_prenorm(output_token) + map_id_logit = model.trafficgen_head.map_id_head(output_token) + assert map_id_logit.shape[1] == 1 + + map_id_logit_mask = torch.full((B, 1, map_id_logit.shape[-1]), False, device=agent_type.device, + dtype=torch.bool) + + M = self.raw_data_dict["model/map_token_position"].shape[1] + map_id_logit_mask[:, :, :M] = self.raw_data_dict["model/map_token_valid_mask"][:, None] + + if self.config.EVALUATION.TG_SDC_DISTANCE_MASKING: + map_pos = self.raw_data_dict["model/map_token_position"][..., :2] # (B, M, 2) + THRESHOLD = 50.0 + closed_to_sdc_mask = (map_pos - sdc_position).norm(dim=-1) < THRESHOLD + map_id_logit_mask[:, :, :M] = map_id_logit_mask[:, :, :M] & closed_to_sdc_mask[:, None] + + # only_lane = True + # if only_lane: + # map_feature = self.raw_data_dict["encoder/map_feature"] + # map_id_logit_mask[:, :, :M] = (map_feature[:, :, 0, 13] == 1)[:, None] & map_id_logit_mask[:, :, :M] + + map_id_logit[~map_id_logit_mask] = float("-inf") + + map_id, _ = scenestreamer_motion.sample_action(map_id_logit, sampling_method="softmax") + + map_id_pad_mask = map_id == model.trafficgen_sequence_pad_id + map_id[map_id_pad_mask] = model.trafficgen_sequence_pad_id + + return map_id, { + "map_id_logit": map_id_logit, + } + + def _tg_generate_dest( + self, *, agent_pos, agent_heading, current_step, + agent_width, agent_length, agent_id, agent_type, tg_intra_step, + tg_input_action, agent_feature, agent_valid_mask + ): + assert agent_width.shape == (self.B, 1) + assert agent_length.shape == (self.B, 1) + assert agent_pos.shape == (self.B, 1, 2) + assert agent_heading.shape == (self.B, 1) + assert agent_id.shape == (self.B, 1) + assert agent_type.shape == (self.B, 1), "agent_type shape: {}".format(agent_type.shape) + assert tg_input_action.shape == (self.B, 1) + assert agent_feature.shape == (self.B, 1, 8) + assert agent_valid_mask.shape == (self.B, 1) + + model = self.model + B = self.B + all_token_casual_mask = 
self.all_token_casual_mask + all_force_mask = self.all_force_mask + scenestreamer_tokens = self.scenestreamer_tokens + + tg_token = model.prepare_trafficgen_single_token( + tg_action=tg_input_action, + tg_type=agent_type, + tg_agent_id=agent_id, + tg_intra_step=torch.full((B, 1), tg_intra_step, device=agent_type.device), + tg_feat=agent_feature, + ) + tg_causal_mask = all_token_casual_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + 1, + :scenestreamer_tokens.seq_len + 1] + scenestreamer_tokens.add( + token=tg_token, + position=agent_pos, + heading=agent_heading, + valid_mask=agent_valid_mask, + width=agent_width, + length=agent_length, + causal_mask=tg_causal_mask, + current_step=current_step, + require_relation=agent_valid_mask, + force_mask=all_force_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + 1, + :scenestreamer_tokens.seq_len + 1] + ) + + # if agent_valid_mask.any(): + # output_dict = scenestreamer_tokens.call_model_with_cache(keep_output_token=self.keep_output_token) + # output_token = output_dict["model/all_token"][:, -1:] + # # call pred head to get agent feat. 
+ # + # output_token = model.trafficgen_prenorm(output_token) + # + # dest_id_logit = model.trafficgen_head.dest_id_head(output_token) + # # tiny masked out here + # M = data_dict["model/map_token_valid_mask"].shape[1] + # assert dest_id_logit.shape[1] == 1 + # + # dest_id_logit_mask = torch.full((B, 1, dest_id_logit.shape[-1]), False, device=agent_type.device, + # dtype=torch.bool) + # dest_id_logit_mask[:, :, :M] = data_dict["model/map_token_valid_mask"][:, None] + # + # dest_pos_full = data_dict["model/map_token_position"][..., :2] # (B, M, 2) + # agent_pos = step_info_dict["agent_position"][:, agent_index][:, None] # (B, 2) + # dest_agent_dist = torch.cdist(dest_pos_full, agent_pos)[..., 0] # (B, M) + # + # speed = step_info_dict["agent_velocity"][:, agent_index].norm(dim=-1) # (B,) + # displacement = speed * 3 + # tolerance = displacement + 20 + # # print("Agent {} speed: {}, displacement: {}, tolerance: {}".format( + # # agent_index, speed[0].item(), displacement[0].item(), tolerance[0].item() + # # )) + # assert dest_agent_dist.ndim == 2 + # assert tolerance.ndim == 1 + # assert dest_id_logit_mask.ndim == 3 + # dest_id_logit_mask[:, :, :M] = dest_id_logit_mask[:, :, :M] & (dest_agent_dist < tolerance[:, None])[:, + # None] + # + # agent_heading = step_info_dict["agent_heading"][:, agent_index] + # + # # Only allow dest in front of the agent. 
+ # rel_pos = (dest_pos_full - agent_pos) + # rel_pos = utils.rotate(x=rel_pos[..., 0], y=rel_pos[..., 1], angle=-agent_heading[:, None].expand(B, M)) + # dest_id_logit_mask[:, :, :M] = dest_id_logit_mask[:, :, :M] & (rel_pos[..., 0] > 0)[:, None] + # + # dest_heading_full = data_dict["model/map_token_heading"] # (B, M) + # dest_agent_heading_dist = dest_heading_full - agent_heading[:, None] # (B, M) + # dest_agent_heading_dist = torch.abs(utils.wrap_to_pi(dest_agent_heading_dist)) + # dest_id_logit_mask[:, :, :M] = dest_id_logit_mask[:, :, :M] & (dest_agent_heading_dist < np.pi / 2)[:, + # None] + # + # dest_id_logit_mask[..., model.trafficgen_sequence_pad_id] = True + # + # only_lane = True + # if only_lane: + # map_feature = data_dict["encoder/map_feature"] + # dest_id_logit_mask[:, :, :M] = (map_feature[:, :, 0, 13] == 1)[:, None] & dest_id_logit_mask[:, :, + # :M] + # + # dest_id_logit[~dest_id_logit_mask] = float("-inf") + # + # # TODO: hardcoded + # # dest_id, _ = scenestreamer_motion.sample_action(dest_id_logit, sampling_method="softmax") + # dest_id, _ = scenestreamer_motion.sample_action(dest_id_logit, sampling_method="topp", topp=0.95) + # + # if teacher_forcing_dest: + # gt_dest = data_dict["decoder/dest_map_index"][:, current_step, agent_index].clone() + # gt_dest[gt_dest == -1] = model.trafficgen_sequence_pad_id + # dest_id = gt_dest.reshape(B, 1) + # + # dest_id_pad_mask = dest_id == model.trafficgen_sequence_pad_id + # + # dest_id[dest_id_pad_mask] = 0 + # + # dest_position = torch.gather( + # data_dict["model/map_token_position"][..., :2], + # index=dest_id.reshape(B, 1, 1).expand(B, 1, 2), + # dim=1 + # ) + # dest_position[dest_id_pad_mask] = step_info_dict["agent_position"][:, agent_index][:, None][ + # dest_id_pad_mask] + # + # dest_heading = torch.gather( + # data_dict["model/map_token_heading"], + # index=dest_id.reshape(B, 1), + # dim=1 + # ) + # dest_heading[dest_id_pad_mask] = step_info_dict["agent_heading"][:, agent_index][:, None][ + # 
dest_id_pad_mask] + # + # dest_id[dest_id_pad_mask] = model.trafficgen_sequence_pad_id + # + # # TODO: DEBUG + # # dest_dist = (step_info_dict["agent_position"][:, agent_index][0] - dest_position[0, 0]).norm(dim=-1) + # # print("agent {} dest id: {}, dest position: {}, dest heading: {}, dest dist: {}".format( + # # agent_index, dest_id[0].item(), dest_position[0, 0].tolist(), dest_heading[0, 0].item(), dest_dist.item() + # # )) + # else: + # dest_id = torch.full((B, 1), model.trafficgen_sequence_pad_id, device=agent_type.device) + # dest_position = torch.full((B, 1, 2), 0.0, device=agent_type.device) + # dest_heading = torch.full((B, 1), 0.0, device=agent_type.device) + # + # # print("Per agent index{} id{}, dest id: {}".format(agent_index, agent_id[0].item(), dest_id.tolist())) + # dest_id[~agent_valid_mask] = -1 + # + # return { + # "dest_id": dest_id, + # "dest_position": dest_position, + # "dest_heading": dest_heading, + # } + + def _step_generate_trafficgen_no_agent_state(self, *, teacher_forcing_from_gt, teacher_forcing_dest=None, + generate_agent_states=False): + assert self.state == self.STATE_TRAFFICLIGHT_DONE + + if generate_agent_states: + return self._step_generate_trafficgen_with_agent_state( + teacher_forcing_from_gt=teacher_forcing_from_gt, + teacher_forcing_dest=teacher_forcing_dest + ) + + model = self.model + data_dict = self.raw_data_dict + scenestreamer_tokens = self.scenestreamer_tokens + current_step = self.current_step + step_info_dict = self.step_info_dict + all_token_casual_mask = self.all_token_casual_mask + all_force_mask = self.all_force_mask + + if teacher_forcing_from_gt: + step_info_dict["agent_valid_mask"] = data_dict["decoder/input_action_valid_mask"][:, current_step].clone() + step_info_dict["agent_position"] = data_dict["decoder/modeled_agent_position"][:, current_step].clone() + step_info_dict["agent_heading"] = data_dict["decoder/modeled_agent_heading"][:, current_step].clone() + step_info_dict["agent_velocity"] = 
data_dict["decoder/modeled_agent_velocity"][:, current_step].clone() + step_info_dict["agent_type"] = data_dict["decoder/agent_type"].clone() + step_info_dict["agent_shape"] = data_dict["decoder/current_agent_shape"].clone() + step_info_dict["agent_id"] = data_dict["encoder/modeled_agent_id"].clone() + + B, N, G = scenestreamer_tokens.B, scenestreamer_tokens.N, scenestreamer_tokens.G + + # ===== call trafficgen tokenizer ===== + from scenestreamer.dataset.preprocessor import prepare_trafficgen_data_for_scenestreamer_a_step + # assert B == 1, "B should be 1 but got " + str(B) + device = scenestreamer_tokens.token.device + tg_map_id_list = [] + tg_valid_list = [] + tg_feat_list = [] + tg_target_offset_list = [] + tg_pos_list = [] + tg_head_list = [] + for b in range(B): + tg_map_id, tg_valid, tg_feat, tg_target_offset, tg_pos, tg_head = prepare_trafficgen_data_for_scenestreamer_a_step( + pos=step_info_dict["agent_position"].reshape(B, N, 2)[b].cpu().numpy(), + heading=step_info_dict["agent_heading"].reshape(B, N)[b].cpu().numpy(), + vel=step_info_dict["agent_velocity"].reshape(B, N, 2)[b].cpu().numpy(), + agent_valid_mask=step_info_dict["agent_valid_mask"].reshape(B, N)[b].cpu().numpy(), + agent_type=step_info_dict["agent_type"].reshape(B, N)[b].cpu().numpy(), + current_agent_shape=step_info_dict["agent_shape"].reshape(B, N, 3)[b].cpu().numpy(), + map_pos=data_dict["model/map_token_position"][0].cpu().numpy()[..., :2], + map_heading=data_dict["model/map_token_heading"][0].cpu().numpy(), + map_valid_mask=data_dict["model/map_token_valid_mask"][0].cpu().numpy(), + # start_action_id=model.trafficgen_agent_sos_id, + # end_action_id=model.trafficgen_agent_eos_id, + start_sequence_id=model.trafficgen_sequence_sos_id, + end_sequence_id=model.trafficgen_sequence_eos_id, + dest=None, + dest_pad_id=model.trafficgen_sequence_pad_id, + veh_id=model.veh_id, + ped_id=model.ped_id, + cyc_id=model.cyc_id, + start_agent_id=model.trafficgen_agent_sos_id, + ) + 
tg_map_id_list.append(tg_map_id) + tg_valid_list.append(tg_valid) + tg_feat_list.append(tg_feat) + tg_target_offset_list.append(tg_target_offset) + tg_pos_list.append(tg_pos) + tg_head_list.append(tg_head) + input_action_for_trafficgen = torch.from_numpy(np.stack(tg_map_id_list, axis=0)).to(device=device) + input_action_for_trafficgen = input_action_for_trafficgen.reshape(B, 1, G) + # input_action_valid_mask_for_trafficgen = torch.from_numpy(np.stack(tg_valid_list, axis=0)).to( + # device=device).reshape(B, 1, G) + agent_feature_for_trafficgen = (torch.from_numpy(np.stack(tg_feat_list, axis=0)).to(device=device) + .reshape(B, 1, G, 8).float()) + trafficgen_position = torch.from_numpy(np.stack(tg_pos_list, axis=0)).to(device=device).reshape(B, 1, G, + 2).float() + trafficgen_heading = torch.from_numpy(np.stack(tg_head_list, axis=0)).to(device=device).reshape(B, 1, G).float() + + # ===== prepare input data for trafficgen ===== + # -1, -1 -1 TYPE -1 -1, ..., -1 + G = scenestreamer_tokens.G + agent_type = step_info_dict["agent_type"] + agent_type_for_trafficgen = torch.full((B, N, NUM_TG_MULTI), -1, device=agent_type.device) + agent_type_for_trafficgen[..., 2:] = agent_type[:, :, None] + agent_type_for_trafficgen = torch.cat( + [ + torch.full((B, 1), -1, device=agent_type.device), + agent_type_for_trafficgen.flatten(1, 2), + torch.full((B, 1), -1, device=agent_type.device), + ], dim=1 + ).reshape(B, 1, G) + + # ===== call model for tg autoregressive ===== + + token_buffer = TGTokenBuffer() + + initial_seq_len = scenestreamer_tokens.seq_len + + # First, input the sequence_sos_id. 
+ intra_step = 0 + token_buffer.add( + tg_action=torch.full((B, 1), model.trafficgen_sequence_sos_id, device=agent_type.device), + tg_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1], + tg_agent_id=torch.full((B, 1), -1, device=agent_type.device), + tg_intra_step=torch.full((B, 1), intra_step, device=agent_type.device), + tg_feat=torch.full((B, 1, 8), 0.0, device=agent_type.device), + causal_mask=all_token_casual_mask[:, initial_seq_len+intra_step:initial_seq_len + intra_step + 1, + :initial_seq_len + intra_step + 1], + position=torch.full((B, 1, 2), 0, device=agent_type.device), + heading=torch.full((B, 1), 0, device=agent_type.device), + valid_mask=torch.full((B, 1), True, device=agent_type.device, dtype=torch.bool), + width=torch.full((B, 1), 0.0, device=agent_type.device), + length=torch.full((B, 1), 0.0, device=agent_type.device), + force_mask=all_force_mask[:, initial_seq_len+intra_step:initial_seq_len + intra_step + 1, + :initial_seq_len + intra_step + 1], + current_step=current_step, + require_relation=torch.full((B, 1), False, device=agent_type.device, dtype=torch.bool), + ) + + agent_destination_list = [] + agent_destination_pos_list = [] + for agent_index in range(N): + agent_id = step_info_dict["agent_id"][:, agent_index:agent_index + 1] + this_agent_valid_mask = step_info_dict["agent_valid_mask"][:, agent_index:agent_index + 1] + + # Step 0, agent start token. 
+ intra_step += 1 + token_buffer.add( + tg_action=torch.full((B, 1), model.trafficgen_agent_sos_id, device=agent_type.device), + tg_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1], + tg_agent_id=agent_id, + tg_intra_step=torch.full((B, 1), intra_step, device=agent_type.device), + tg_feat=torch.full((B, 1, 8), 0.0, device=agent_type.device), + position=torch.full((B, 1, 2), 0, device=agent_type.device), + heading=torch.full((B, 1), 0, device=agent_type.device), + valid_mask=this_agent_valid_mask, + width=torch.full((B, 1), 0.0, device=agent_type.device), + length=torch.full((B, 1), 0.0, device=agent_type.device), + causal_mask=all_token_casual_mask[:, initial_seq_len + intra_step:initial_seq_len + intra_step + 1, + :initial_seq_len + intra_step + 1], + current_step=current_step, + require_relation=torch.full((B, 1), False, device=agent_type.device, dtype=torch.bool), + force_mask=all_force_mask[:, initial_seq_len + intra_step:initial_seq_len + intra_step + 1, + :initial_seq_len + intra_step + 1] + ) + + # Step 1: input is the agent type. 
+ intra_step += 1 + token_buffer.add( + tg_action=agent_type[:, agent_index][:, None], + tg_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1], + tg_agent_id=agent_id, + tg_intra_step=torch.full((B, 1), intra_step, device=agent_type.device), + tg_feat=torch.full((B, 1, 8), 0.0, device=agent_type.device), + position=torch.full((B, 1, 2), 0.0, device=agent_type.device), + heading=torch.full((B, 1), 0.0, device=agent_type.device), + valid_mask=this_agent_valid_mask, + width=torch.full((B, 1), 0.0, device=agent_type.device), + length=torch.full((B, 1), 0.0, device=agent_type.device), + causal_mask=all_token_casual_mask[:, initial_seq_len + intra_step:initial_seq_len + intra_step + 1, + :initial_seq_len + intra_step + 1], + current_step=current_step, + require_relation=torch.full((B, 1), False, device=agent_type.device, dtype=torch.bool), + force_mask=all_force_mask[:, initial_seq_len + intra_step:initial_seq_len + intra_step + 1, + :initial_seq_len + intra_step + 1] + ) + + + # Step 2: input is the map id. 
+ intra_step += 1 + token_buffer.add( + tg_action=input_action_for_trafficgen[:, 0, intra_step:intra_step + 1], + tg_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1], + tg_agent_id=agent_id, + tg_intra_step=torch.full((B, 1), intra_step, device=agent_type.device), + tg_feat=agent_feature_for_trafficgen[:, 0, intra_step:intra_step + 1], + position=trafficgen_position[:, 0, intra_step:intra_step + 1], + heading=trafficgen_heading[:, 0, intra_step:intra_step + 1], + valid_mask=this_agent_valid_mask, + # TODO: hardcoded 5, 6 + width=agent_feature_for_trafficgen[:, 0, intra_step:intra_step + 1][..., 6], + length=agent_feature_for_trafficgen[:, 0, intra_step:intra_step + 1][..., 5], + causal_mask=all_token_casual_mask[:, initial_seq_len + intra_step:initial_seq_len + intra_step + 1, + :initial_seq_len + intra_step + 1], + current_step=current_step, + require_relation=this_agent_valid_mask, + force_mask=all_force_mask[:, initial_seq_len + intra_step:initial_seq_len + intra_step + 1, + :initial_seq_len + intra_step + 1] + ) + + + # Step 3: input is the agent feat. 
+ intra_step += 1 + token_buffer.add( + tg_action=input_action_for_trafficgen[:, 0, intra_step:intra_step + 1], + tg_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1], + tg_agent_id=agent_id, + tg_intra_step=torch.full((B, 1), intra_step, device=agent_type.device), + tg_feat=agent_feature_for_trafficgen[:, 0, intra_step:intra_step + 1], + position=trafficgen_position[:, 0, intra_step:intra_step + 1], + heading=trafficgen_heading[:, 0, intra_step:intra_step + 1], + valid_mask=this_agent_valid_mask, + causal_mask=all_token_casual_mask[:, initial_seq_len + intra_step:initial_seq_len + intra_step + 1, + :initial_seq_len + intra_step + 1], + force_mask=all_force_mask[:, initial_seq_len + intra_step:initial_seq_len + intra_step + 1, + :initial_seq_len + intra_step + 1], + require_relation=this_agent_valid_mask, + width=agent_feature_for_trafficgen[:, 0, intra_step:intra_step + 1][..., 6], + length=agent_feature_for_trafficgen[:, 0, intra_step:intra_step + 1][..., 5], + current_step=current_step, + ) + + # quick fix a stupid bug. if we buffer too much it might OOM the GPU............... + if len(token_buffer.current_step) > self.config.TOKEN_BUFFER_CACHE_LENGTH: + if scenestreamer_tokens.able_to_call_model(): + token_buffer.append_to_scenestreamer_tokens( + model=model, + scenestreamer_tokens=scenestreamer_tokens, + ) + scenestreamer_tokens.call_model_with_cache() + token_buffer = TGTokenBuffer() + + # Finally, input the sequence_eos_id. 
def _step_generate_trafficgen_densified_agent_state(self, *, teacher_forcing_from_gt,
                                                    veh_ratio, ped_ratio, num_new_agents,
                                                    teacher_forcing_dest=None,
                                                    generate_agent_states=False):
    """Feed one step of TrafficGen tokens, sampling a state for every empty agent slot.

    Must be called while ``self.state == self.STATE_TRAFFICLIGHT_DONE``; on return
    ``self.state`` is ``self.STATE_TRAFFICGEN_DONE`` and ``self.step_info_dict``
    holds the (possibly widened) per-agent tensors.

    Args:
        teacher_forcing_from_gt: When true, per-agent tensors are rebuilt from
            ``self.raw_data_dict`` at ``self.current_step`` and padded to
            ``max(num_new_agents, old_N)`` slots; every slot whose ground-truth
            valid mask is False (including the padding) is then sampled from the
            model. When false, the existing ``step_info_dict`` is replayed and no
            new agent is created.
        veh_ratio: Fraction of the quota ``max(new_N - old_N - 3, 0)`` given to
            extra vehicle slots (truncated with ``int``).
        ped_ratio: Same, for extra pedestrian slots. All remaining slots are
            typed as cyclists.
        num_new_agents: Requested total slot count after padding (not the number
            of additions); values below the current count are ignored.
        teacher_forcing_dest: Only forwarded when ``generate_agent_states`` is set.
        generate_agent_states: When true, delegate entirely to
            ``_step_generate_trafficgen_with_agent_state`` and return its result.

    Note:
        The slot bookkeeping uses per-batch count tensors as slice bounds and the
        agent loop asserts ``B == 1``, so only a single scene is supported.
    """
    assert self.state == self.STATE_TRAFFICLIGHT_DONE

    if generate_agent_states:
        return self._step_generate_trafficgen_with_agent_state(
            teacher_forcing_from_gt=teacher_forcing_from_gt,
            teacher_forcing_dest=teacher_forcing_dest,
        )

    model = self.model
    data_dict = self.raw_data_dict
    scenestreamer_tokens = self.scenestreamer_tokens
    current_step = self.current_step
    step_info_dict = self.step_info_dict
    all_token_casual_mask = self.all_token_casual_mask
    all_force_mask = self.all_force_mask
    reject_sampling = self.config.EVALUATION.TG_REJECT_SAMPLING
    # Sibling TrafficGen code reads the flush threshold from the config; the
    # previous hard-coded value (100) is kept as the fallback.
    flush_threshold = getattr(self.config, "TOKEN_BUFFER_CACHE_LENGTH", 100)

    if teacher_forcing_from_gt:
        atype = data_dict["decoder/agent_type"].clone()
        batch_size = atype.shape[0]
        old_N = atype.shape[1]
        # Per-batch counts, shape (B,). They are used as slice bounds below,
        # which only works for a single-element tensor (B == 1).
        num_veh = (atype == model.veh_id).sum(dim=1)
        num_ped = (atype == model.ped_id).sum(dim=1)

        new_N = max(num_new_agents, old_N)

        # Keep at least 3 of the added slots out of the veh/ped split.
        quota = max(new_N - old_N - 3, 0)

        new_num_veh = num_veh + int(veh_ratio * quota)
        new_num_ped = num_ped + int(ped_ratio * quota)
        # Every slot after the vehicle and pedestrian ranges is a cyclist slot.
        # (The old debug value added the existing cyclists a second time.)
        new_num_cyc = new_N - new_num_veh - new_num_ped

        print("num_veh: {}, num_ped: {}, num_cyc: {}".format(
            new_num_veh, new_num_ped, new_num_cyc
        ))
        new_atype = torch.full((batch_size, new_N), -1, device=atype.device)
        new_atype[:, :new_num_veh] = model.veh_id
        new_atype[:, new_num_veh:new_num_veh + new_num_ped] = model.ped_id
        new_atype[:, new_num_veh + new_num_ped:new_N] = model.cyc_id

        # Assume all valid.
        new_valid_mask = torch.full(
            (data_dict["decoder/input_action_valid_mask"].shape[0], new_N), True, device=atype.device
        )

        def _fill(arr):
            # Copy a per-agent array laid out as [vehicles | pedestrians | rest]
            # into the widened layout; slots without a source stay zero/False.
            if arr.ndim == 2:
                new_arr = arr.new_zeros((arr.shape[0], new_N))
            elif arr.ndim == 3:
                new_arr = arr.new_zeros((arr.shape[0], new_N, arr.shape[-1]))
            else:
                raise ValueError("Unsupported array shape: {}".format(tuple(arr.shape)))
            new_arr[:, :num_veh] = arr[:, :num_veh]
            new_arr[:, new_num_veh:new_num_veh + num_ped] = arr[:, num_veh:num_veh + num_ped]
            if num_veh + num_ped < old_N:
                rest = arr[:, num_veh + num_ped:]
                num_rest = rest.shape[1]
                if num_rest > 0:
                    rest_start = new_num_veh + new_num_ped
                    new_arr[:, rest_start:rest_start + num_rest] = rest
            return new_arr

        step_info_dict["agent_type"] = new_atype
        step_info_dict["agent_valid_mask"] = new_valid_mask
        step_info_dict["agent_position"] = _fill(
            data_dict["decoder/modeled_agent_position"][:, current_step].clone())
        step_info_dict["agent_heading"] = _fill(
            data_dict["decoder/modeled_agent_heading"][:, current_step].clone())
        step_info_dict["agent_velocity"] = _fill(
            data_dict["decoder/modeled_agent_velocity"][:, current_step].clone())
        step_info_dict["agent_shape"] = _fill(data_dict["decoder/current_agent_shape"].clone())
        step_info_dict["agent_id"] = torch.arange(0, new_N, device=atype.device).expand(batch_size, -1)

        # TODO FIXME (kept from the original): a slot is regenerated whenever
        # its widened ground-truth valid mask is False.
        should_create_new_agent = torch.logical_not(_fill(
            data_dict["decoder/input_action_valid_mask"][:, current_step].clone()
        ))

        scenestreamer_tokens.N = new_N
        scenestreamer_tokens.G = get_num_tg(new_N)

        new_input_action = _fill(data_dict["decoder/input_action"][:, current_step].clone())
        new_input_action.fill_(MOTION_START_ACTION)
        step_info_dict["motion_input_action"] = new_input_action
    else:
        should_create_new_agent = torch.zeros(
            (scenestreamer_tokens.B, scenestreamer_tokens.N),
            device=step_info_dict["agent_valid_mask"].device,
            dtype=torch.bool
        )

    B, N, G = scenestreamer_tokens.B, scenestreamer_tokens.N, scenestreamer_tokens.G

    # ===== call trafficgen tokenizer =====
    from scenestreamer.dataset.preprocessor import prepare_trafficgen_data_for_scenestreamer_a_step
    device = scenestreamer_tokens.token.device
    tg_map_id_list = []
    tg_feat_list = []
    tg_pos_list = []
    tg_head_list = []
    for b in range(B):
        tg_map_id, _tg_valid, tg_feat, _tg_target_offset, tg_pos, tg_head = \
            prepare_trafficgen_data_for_scenestreamer_a_step(
                pos=step_info_dict["agent_position"].reshape(B, N, 2)[b].cpu().numpy(),
                heading=step_info_dict["agent_heading"].reshape(B, N)[b].cpu().numpy(),
                vel=step_info_dict["agent_velocity"].reshape(B, N, 2)[b].cpu().numpy(),
                agent_valid_mask=step_info_dict["agent_valid_mask"].reshape(B, N)[b].cpu().numpy(),
                agent_type=step_info_dict["agent_type"].reshape(B, N)[b].cpu().numpy(),
                current_agent_shape=step_info_dict["agent_shape"].reshape(B, N, 3)[b].cpu().numpy(),
                # The map of batch element 0 is used for every b (as in the original).
                map_pos=data_dict["model/map_token_position"][0].cpu().numpy()[..., :2],
                map_heading=data_dict["model/map_token_heading"][0].cpu().numpy(),
                map_valid_mask=data_dict["model/map_token_valid_mask"][0].cpu().numpy(),
                start_sequence_id=model.trafficgen_sequence_sos_id,
                end_sequence_id=model.trafficgen_sequence_eos_id,
                dest=None,
                dest_pad_id=model.trafficgen_sequence_pad_id,
                veh_id=model.veh_id,
                ped_id=model.ped_id,
                cyc_id=model.cyc_id,
                start_agent_id=model.trafficgen_agent_sos_id,
            )
        tg_map_id_list.append(tg_map_id)
        tg_feat_list.append(tg_feat)
        tg_pos_list.append(tg_pos)
        tg_head_list.append(tg_head)
    input_action_for_trafficgen = torch.from_numpy(np.stack(tg_map_id_list, axis=0)).to(device=device)
    input_action_for_trafficgen = input_action_for_trafficgen.reshape(B, 1, G)
    agent_feature_for_trafficgen = (
        torch.from_numpy(np.stack(tg_feat_list, axis=0)).to(device=device).reshape(B, 1, G, 8).float()
    )
    trafficgen_position = (
        torch.from_numpy(np.stack(tg_pos_list, axis=0)).to(device=device).reshape(B, 1, G, 2).float()
    )
    trafficgen_heading = (
        torch.from_numpy(np.stack(tg_head_list, axis=0)).to(device=device).reshape(B, 1, G).float()
    )

    # ===== prepare input data for trafficgen =====
    # Per-token agent type: -1 everywhere except the last slots of each agent group.
    # -1, -1 -1 TYPE -1 -1, ..., -1
    agent_type = step_info_dict["agent_type"]
    agent_type_for_trafficgen = torch.full((B, N, NUM_TG_MULTI), -1, device=agent_type.device)
    agent_type_for_trafficgen[..., 2:] = agent_type[:, :, None]
    agent_type_for_trafficgen = torch.cat(
        [
            torch.full((B, 1), -1, device=agent_type.device),
            agent_type_for_trafficgen.flatten(1, 2),
            torch.full((B, 1), -1, device=agent_type.device),
        ], dim=1
    ).reshape(B, 1, G)

    # ===== call model for tg autoregressive =====
    token_buffer = TGTokenBuffer()
    initial_seq_len = scenestreamer_tokens.seq_len

    def _mask_row(mask, step):
        # Mask row of the token at `step` of this TrafficGen segment, cut to the
        # columns that exist once that token has been appended.
        row = initial_seq_len + step
        return mask[:, row:row + 1, :row + 1]

    def _add_marker_token(buffer, step, action_id, marker_agent_id, marker_valid_mask):
        # Buffer a structural token (sequence SOS/EOS, agent SOS): no geometry,
        # no relation.
        buffer.add(
            tg_action=torch.full((B, 1), action_id, device=agent_type.device),
            tg_type=agent_type_for_trafficgen[:, 0, step:step + 1],
            tg_agent_id=marker_agent_id,
            tg_intra_step=torch.full((B, 1), step, device=agent_type.device),
            tg_feat=torch.full((B, 1, 8), 0.0, device=agent_type.device),
            position=torch.full((B, 1, 2), 0, device=agent_type.device),
            heading=torch.full((B, 1), 0, device=agent_type.device),
            valid_mask=marker_valid_mask,
            width=torch.full((B, 1), 0.0, device=agent_type.device),
            length=torch.full((B, 1), 0.0, device=agent_type.device),
            causal_mask=_mask_row(all_token_casual_mask, step),
            force_mask=_mask_row(all_force_mask, step),
            current_step=current_step,
            require_relation=torch.full((B, 1), False, device=agent_type.device, dtype=torch.bool),
        )

    def _add_state_token(buffer, step, state_agent_id, state_valid_mask):
        # Buffer a token whose action/feature/pose come from the tokenizer output
        # (used for both the map-id token and the agent-feature token).
        feat = agent_feature_for_trafficgen[:, 0, step:step + 1]
        buffer.add(
            tg_action=input_action_for_trafficgen[:, 0, step:step + 1],
            tg_type=agent_type_for_trafficgen[:, 0, step:step + 1],
            tg_agent_id=state_agent_id,
            tg_intra_step=torch.full((B, 1), step, device=agent_type.device),
            tg_feat=feat,
            position=trafficgen_position[:, 0, step:step + 1],
            heading=trafficgen_heading[:, 0, step:step + 1],
            valid_mask=state_valid_mask,
            # TODO: hardcoded feature columns 5 (length) and 6 (width).
            width=feat[..., 6],
            length=feat[..., 5],
            causal_mask=_mask_row(all_token_casual_mask, step),
            force_mask=_mask_row(all_force_mask, step),
            current_step=current_step,
            require_relation=state_valid_mask,
        )

    no_agent_id = torch.full((B, 1), -1, device=agent_type.device)
    always_valid = torch.full((B, 1), True, device=agent_type.device, dtype=torch.bool)

    # First, input the sequence_sos_id.
    intra_step = 0
    _add_marker_token(token_buffer, intra_step, model.trafficgen_sequence_sos_id, no_agent_id, always_valid)

    offset_keys = (
        "position_x", "position_y", "heading", "velocity_x", "velocity_y", "length", "width", "height"
    )

    for agent_index in range(N):
        assert B == 1
        should_create_new = should_create_new_agent[0, agent_index].item()

        agent_id = step_info_dict["agent_id"][:, agent_index:agent_index + 1]
        this_agent_valid_mask = step_info_dict["agent_valid_mask"][:, agent_index:agent_index + 1]

        # Step 0, agent start token.
        intra_step += 1
        _add_marker_token(
            token_buffer, intra_step, model.trafficgen_agent_sos_id, agent_id, this_agent_valid_mask
        )

        if not should_create_new:
            # Replay the known agent: type token, map-id token, feature token.

            # Step 1: input is the agent type.
            intra_step += 1
            token_buffer.add(
                tg_action=agent_type[:, agent_index][:, None],
                tg_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1],
                tg_agent_id=agent_id,
                tg_intra_step=torch.full((B, 1), intra_step, device=agent_type.device),
                tg_feat=torch.full((B, 1, 8), 0.0, device=agent_type.device),
                position=torch.full((B, 1, 2), 0.0, device=agent_type.device),
                heading=torch.full((B, 1), 0.0, device=agent_type.device),
                valid_mask=this_agent_valid_mask,
                width=torch.full((B, 1), 0.0, device=agent_type.device),
                length=torch.full((B, 1), 0.0, device=agent_type.device),
                causal_mask=_mask_row(all_token_casual_mask, intra_step),
                force_mask=_mask_row(all_force_mask, intra_step),
                current_step=current_step,
                require_relation=torch.full((B, 1), False, device=agent_type.device, dtype=torch.bool),
            )

            # Step 2: input is the map id.
            intra_step += 1
            _add_state_token(token_buffer, intra_step, agent_id, this_agent_valid_mask)

            # Step 3: input is the agent feat.
            intra_step += 1
            _add_state_token(token_buffer, intra_step, agent_id, this_agent_valid_mask)
        else:
            # Flush what is buffered so the model cache is current before sampling.
            token_buffer.append_to_scenestreamer_tokens(
                model=model,
                scenestreamer_tokens=scenestreamer_tokens,
            )
            if scenestreamer_tokens.able_to_call_model():
                scenestreamer_tokens.call_model_with_cache()
            token_buffer = TGTokenBuffer()

            # Adding new agent!
            if reject_sampling:
                # Snapshot everything a rejected attempt mutates.
                self.scenestreamer_tokens = scenestreamer_tokens
                tmp_scenestreamer_tokens = copy.deepcopy(scenestreamer_tokens)
                tmp_step_info_dict = copy.deepcopy(self.step_info_dict)
                tmp_intra_step = intra_step

            this_agent_reject_count = 0
            while True:
                # Step 1: input is the agent type.
                teacher_forcing_this_agent = torch.full((B, 1), False, device=agent_type.device)
                intra_step += 1
                assert data_dict["decoder/sdc_index"][0].item() == 0, data_dict["decoder/sdc_index"]
                selected_map_id, _ = self._tg_generate_agent_map_id(
                    agent_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1],
                    agent_id=agent_id,
                    tg_intra_step=intra_step,
                    agent_valid_mask=this_agent_valid_mask,
                    teacher_forcing_map_id=teacher_forcing_this_agent,
                    tg_input_action=agent_type[:, agent_index][:, None],
                    sdc_position=step_info_dict["agent_position"][:, 0:1],
                )

                # Step 2: input is the map id.
                intra_step += 1
                assert selected_map_id is not None
                selected_map_id = selected_map_id.reshape(B, 1)
                selected_map_pos, selected_map_heading = self._get_map_pos_head(index=selected_map_id)
                selected_agent_state, _ = self._tg_generate_agent_agent_state(
                    agent_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1],
                    agent_id=agent_id,
                    tg_intra_step=intra_step,
                    agent_valid_mask=this_agent_valid_mask,
                    teacher_forcing_agent_state=teacher_forcing_this_agent,
                    tg_input_action=selected_map_id,
                    selected_map_pos=selected_map_pos,
                    selected_map_heading=selected_map_heading,
                )

                # Step 3: input is the agent feat.
                intra_step += 1
                as_position = selected_agent_state["position"]
                as_heading = selected_agent_state["heading"]
                as_feat = torch.zeros((B, 1, 8), device=device)
                for column, key in enumerate(offset_keys):
                    as_feat[:, :, column] = selected_agent_state["offset_values"][key]

                # Overwrite agent data.
                step_info_dict["agent_position"][:, agent_index] = as_position.clone().reshape(B, 2)
                step_info_dict["agent_heading"][:, agent_index] = as_heading.clone().reshape(B, )
                step_info_dict["agent_velocity"][:, agent_index] = (
                    selected_agent_state["velocity"].clone().reshape(B, 2)
                )
                step_info_dict["agent_shape"][:, agent_index] = (
                    selected_agent_state["shape"].clone().reshape(B, 3)
                )
                step_info_dict["agent_valid_mask"][:, agent_index] = this_agent_valid_mask.clone().reshape(B)

                # Called for its effect on the generator state; the return
                # value was never used by the original code either.
                self._tg_generate_dest(
                    agent_pos=as_position,
                    agent_heading=as_heading,
                    current_step=current_step,
                    agent_width=as_feat[..., 6],
                    agent_length=as_feat[..., 5],
                    agent_id=agent_id,
                    agent_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1],
                    tg_intra_step=intra_step,
                    tg_input_action=input_action_for_trafficgen[:, 0, intra_step:intra_step + 1],
                    agent_feature=as_feat,
                    agent_valid_mask=this_agent_valid_mask
                )

                # Count the attempt. Without this the `> 5` guard below never
                # fires and persistent collisions loop forever.
                this_agent_reject_count += 1

                # Detect whether collision happens.
                if agent_index == 0:
                    break
                if not reject_sampling:
                    # Skip the collision check.
                    break
                if this_agent_reject_count > 5:
                    # Give up and keep the latest sample.
                    break

                pos = self.step_info_dict["agent_position"][:, :agent_index + 1]
                head = self.step_info_dict["agent_heading"][:, :agent_index + 1]
                shape = self.step_info_dict["agent_shape"][:, :agent_index + 1]
                collided_batch = None
                for b in range(B):
                    contours = cal_polygon_contour(
                        x=pos[b, :, 0].cpu().numpy(),
                        y=pos[b, :, 1].cpu().numpy(),
                        theta=head[b].cpu().numpy(),
                        width=shape[b, :, 1].cpu().numpy(),
                        length=shape[b, :, 0].cpu().numpy()
                    )
                    # The newest agent is last; test it against all earlier ones.
                    newest = Polygon(contours[-1])
                    if any(newest.intersects(Polygon(contours[i])) for i in range(len(contours) - 1)):
                        collided_batch = b
                        break
                if collided_batch is None:
                    break
                print("Collision happens at batch {}, repeat the generation.".format(collided_batch))

                # Roll back to the snapshot and retry this agent.
                self.scenestreamer_tokens = copy.deepcopy(tmp_scenestreamer_tokens)
                self.step_info_dict = copy.deepcopy(tmp_step_info_dict)
                intra_step = tmp_intra_step
                step_info_dict = self.step_info_dict
                scenestreamer_tokens = self.scenestreamer_tokens

        # Flush periodically: buffering too many tokens might OOM the GPU.
        if len(token_buffer.current_step) > flush_threshold:
            if scenestreamer_tokens.able_to_call_model():
                token_buffer.append_to_scenestreamer_tokens(
                    model=model,
                    scenestreamer_tokens=scenestreamer_tokens,
                )
                scenestreamer_tokens.call_model_with_cache()
                token_buffer = TGTokenBuffer()

    # Finally, input the sequence_eos_id.
    intra_step += 1
    assert intra_step == G - 1, (intra_step, G, G - 1)
    _add_marker_token(token_buffer, intra_step, model.trafficgen_sequence_eos_id, no_agent_id, always_valid)

    token_buffer.append_to_scenestreamer_tokens(
        model=model,
        scenestreamer_tokens=scenestreamer_tokens,
    )
    self.step_info_dict = step_info_dict
    self.state = self.STATE_TRAFFICGEN_DONE
def _step_generate_trafficgen_with_agent_state(self, *, teacher_forcing_from_gt, teacher_forcing_dest=None):
    """Feed one step of TrafficGen tokens, sampling the state of every non-SDC agent.

    Must be called while ``self.state == self.STATE_TRAFFICLIGHT_DONE``; on return
    ``self.state`` is ``self.STATE_TRAFFICGEN_DONE``. Agent 0 (asserted to be the
    SDC via ``decoder/sdc_index``) is teacher-forced; every other agent gets a
    sampled map id, agent state and destination. With
    ``config.EVALUATION.TG_REJECT_SAMPLING`` enabled, a sample whose box
    intersects an earlier agent is rolled back and redrawn, for at most six
    attempts per agent.

    Args:
        teacher_forcing_from_gt: When true, ``step_info_dict`` is first refreshed
            from ``self.raw_data_dict`` at ``self.current_step``.
        teacher_forcing_dest: Accepted for interface compatibility; not read here.
    """
    assert self.state == self.STATE_TRAFFICLIGHT_DONE

    model = self.model
    data_dict = self.raw_data_dict
    scenestreamer_tokens = self.scenestreamer_tokens
    current_step = self.current_step
    step_info_dict = self.step_info_dict
    all_token_casual_mask = self.all_token_casual_mask
    all_force_mask = self.all_force_mask
    reject_sampling = self.config.EVALUATION.TG_REJECT_SAMPLING

    if teacher_forcing_from_gt:
        step_info_dict["agent_valid_mask"] = data_dict["decoder/input_action_valid_mask"][:, current_step].clone()
        step_info_dict["agent_position"] = data_dict["decoder/modeled_agent_position"][:, current_step].clone()
        step_info_dict["agent_heading"] = data_dict["decoder/modeled_agent_heading"][:, current_step].clone()
        step_info_dict["agent_velocity"] = data_dict["decoder/modeled_agent_velocity"][:, current_step].clone()
        step_info_dict["agent_type"] = data_dict["decoder/agent_type"].clone()
        step_info_dict["agent_shape"] = data_dict["decoder/current_agent_shape"].clone()
        step_info_dict["agent_id"] = data_dict["encoder/modeled_agent_id"].clone()

    B, N, G = scenestreamer_tokens.B, scenestreamer_tokens.N, scenestreamer_tokens.G
    device = scenestreamer_tokens.token.device

    # ===== call trafficgen tokenizer =====
    from scenestreamer.dataset.preprocessor import prepare_trafficgen_data_for_scenestreamer_a_step
    tg_map_id_list = []
    tg_feat_list = []
    for b in range(B):
        # Only the map ids and features are consumed below.
        tg_map_id, _tg_valid, tg_feat, _tg_target_offset, _tg_pos, _tg_head = \
            prepare_trafficgen_data_for_scenestreamer_a_step(
                pos=step_info_dict["agent_position"].reshape(B, N, 2)[b].cpu().numpy(),
                heading=step_info_dict["agent_heading"].reshape(B, N)[b].cpu().numpy(),
                vel=step_info_dict["agent_velocity"].reshape(B, N, 2)[b].cpu().numpy(),
                agent_valid_mask=step_info_dict["agent_valid_mask"].reshape(B, N)[b].cpu().numpy(),
                agent_type=step_info_dict["agent_type"].reshape(B, N)[b].cpu().numpy(),
                current_agent_shape=step_info_dict["agent_shape"].reshape(B, N, 3)[b].cpu().numpy(),
                # The map of batch element 0 is used for every b (as in the original).
                map_pos=data_dict["model/map_token_position"][0].cpu().numpy()[..., :2],
                map_heading=data_dict["model/map_token_heading"][0].cpu().numpy(),
                map_valid_mask=data_dict["model/map_token_valid_mask"][0].cpu().numpy(),
                start_sequence_id=model.trafficgen_sequence_sos_id,
                end_sequence_id=model.trafficgen_sequence_eos_id,
                dest=None,
                dest_pad_id=model.trafficgen_sequence_pad_id,
                veh_id=model.veh_id,
                ped_id=model.ped_id,
                cyc_id=model.cyc_id,
                start_agent_id=model.trafficgen_agent_sos_id,
            )
        tg_map_id_list.append(tg_map_id)
        tg_feat_list.append(tg_feat)
    input_action_for_trafficgen = torch.from_numpy(np.stack(tg_map_id_list, axis=0)).to(device=device)
    input_action_for_trafficgen = input_action_for_trafficgen.reshape(B, 1, G)
    agent_feature_for_trafficgen = (
        torch.from_numpy(np.stack(tg_feat_list, axis=0)).to(device=device).reshape(B, 1, G, 8).float()
    )

    # ===== prepare input data for trafficgen =====
    # Per-token agent type: -1 everywhere except the last slots of each agent group.
    # -1, -1 -1 TYPE -1 -1, ..., -1
    agent_type = step_info_dict["agent_type"]
    agent_type_for_trafficgen = torch.full((B, N, NUM_TG_MULTI), -1, device=agent_type.device)
    agent_type_for_trafficgen[..., 2:] = agent_type[:, :, None]
    agent_type_for_trafficgen = torch.cat(
        [
            torch.full((B, 1), -1, device=agent_type.device),
            agent_type_for_trafficgen.flatten(1, 2),
            torch.full((B, 1), -1, device=agent_type.device),
        ], dim=1
    ).reshape(B, 1, G)

    def _append_marker_token(tokens, step, action_id, tg_type, marker_agent_id, marker_valid_mask):
        # Append a structural token (sequence SOS/EOS, agent SOS) directly to
        # `tokens`: no geometry, no relation. Masks are sliced at the sequence
        # length before the append.
        tg_token = model.prepare_trafficgen_single_token(
            tg_action=torch.full((B, 1), action_id, device=device),
            tg_type=tg_type,
            tg_agent_id=marker_agent_id,
            tg_intra_step=torch.full((B, 1), step, device=device),
            tg_feat=torch.full((B, 1, 8), 0.0, device=device),
        )
        seq_len = tokens.seq_len
        tokens.add(
            token=tg_token,
            position=torch.full((B, 1, 2), 0, device=device),
            heading=torch.full((B, 1), 0, device=device),
            valid_mask=marker_valid_mask,
            width=torch.full((B, 1), 0.0, device=device),
            length=torch.full((B, 1), 0.0, device=device),
            causal_mask=all_token_casual_mask[:, seq_len:seq_len + 1, :seq_len + 1],
            force_mask=all_force_mask[:, seq_len:seq_len + 1, :seq_len + 1],
            current_step=current_step,
            require_relation=torch.full((B, 1), False, device=device, dtype=torch.bool),
        )

    no_agent_id = torch.full((B, 1), -1, device=device)
    always_valid = torch.full((B, 1), True, device=device, dtype=torch.bool)

    # ===== call model for tg autoregressive =====
    # First, input the sequence_sos_id.
    intra_step = 0
    _append_marker_token(
        scenestreamer_tokens, intra_step, model.trafficgen_sequence_sos_id,
        torch.full((B, 1), -1, device=device), no_agent_id, always_valid,
    )

    offset_keys = (
        "position_x", "position_y", "heading", "velocity_x", "velocity_y", "length", "width", "height"
    )

    for agent_index in range(N):
        agent_id = step_info_dict["agent_id"][:, agent_index:agent_index + 1]
        assert (agent_id >= 0).all(), agent_id

        # Teacher forcing the SDC agent!!!
        teacher_forcing_this_agent = agent_index == 0

        # Because we are generating new agents, we manually set them to be all True.
        this_agent_valid_mask = torch.full((B, 1), True, device=device, dtype=torch.bool)

        # Step 0, agent start token.
        intra_step += 1
        start_type = agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1]
        assert (start_type == -1).all()
        _append_marker_token(
            scenestreamer_tokens, intra_step, model.trafficgen_agent_sos_id,
            start_type, agent_id, this_agent_valid_mask,
        )

        if reject_sampling:
            # Snapshot everything a rejected attempt mutates.
            self.scenestreamer_tokens = scenestreamer_tokens
            tmp_scenestreamer_tokens = copy.deepcopy(scenestreamer_tokens)
            tmp_step_info_dict = copy.deepcopy(self.step_info_dict)
            tmp_intra_step = intra_step

        this_agent_reject_count = 0
        while True:
            # Step 1: input is the agent type.
            intra_step += 1
            assert data_dict["decoder/sdc_index"][0].item() == 0, data_dict["decoder/sdc_index"]
            selected_map_id, _ = self._tg_generate_agent_map_id(
                agent_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1],
                agent_id=agent_id,
                tg_intra_step=intra_step,
                agent_valid_mask=this_agent_valid_mask,
                teacher_forcing_map_id=teacher_forcing_this_agent,
                tg_input_action=agent_type[:, agent_index][:, None],
                sdc_position=step_info_dict["agent_position"][:, 0:1],
            )

            # Step 2: input is the map id.
            intra_step += 1
            if teacher_forcing_this_agent:
                assert selected_map_id is None
                selected_map_id = input_action_for_trafficgen[:, 0, intra_step:intra_step + 1]
            else:
                assert selected_map_id is not None
            selected_map_id = selected_map_id.reshape(B, 1)
            selected_map_pos, selected_map_heading = self._get_map_pos_head(index=selected_map_id)
            selected_agent_state, _ = self._tg_generate_agent_agent_state(
                agent_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1],
                agent_id=agent_id,
                tg_intra_step=intra_step,
                agent_valid_mask=this_agent_valid_mask,
                teacher_forcing_agent_state=teacher_forcing_this_agent,
                tg_input_action=selected_map_id,
                selected_map_pos=selected_map_pos,
                selected_map_heading=selected_map_heading,
            )

            # Step 3: input is the agent feat.
            intra_step += 1
            if teacher_forcing_this_agent:
                as_position = step_info_dict["agent_position"][:, agent_index].unsqueeze(1)
                as_heading = step_info_dict["agent_heading"][:, agent_index].unsqueeze(1)
                as_feat = agent_feature_for_trafficgen[:, 0, intra_step:intra_step + 1]
            else:
                as_position = selected_agent_state["position"]
                as_heading = selected_agent_state["heading"]
                as_feat = torch.zeros((B, 1, 8), device=device)
                for column, key in enumerate(offset_keys):
                    as_feat[:, :, column] = selected_agent_state["offset_values"][key]

                # Overwrite agent data.
                # NOTE(review): the original nesting was lost with the file's
                # indentation. These writes are kept in the sampled branch
                # because the teacher-forced branch reads the same entries as
                # its source of truth -- confirm against the upstream file.
                step_info_dict["agent_position"][:, agent_index] = as_position.clone().reshape(B, 2)
                step_info_dict["agent_heading"][:, agent_index] = as_heading.clone().reshape(B, )
                step_info_dict["agent_velocity"][:, agent_index] = (
                    selected_agent_state["velocity"].clone().reshape(B, 2)
                )
                step_info_dict["agent_shape"][:, agent_index] = (
                    selected_agent_state["shape"].clone().reshape(B, 3)
                )
                step_info_dict["agent_valid_mask"][:, agent_index] = this_agent_valid_mask.clone().reshape(B)

            # Called for its effect on the generator state; the return value
            # was never used by the original code either.
            self._tg_generate_dest(
                agent_pos=as_position,
                agent_heading=as_heading,
                current_step=current_step,
                agent_width=as_feat[..., 6],
                agent_length=as_feat[..., 5],
                agent_id=agent_id,
                agent_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1],
                tg_intra_step=intra_step,
                tg_input_action=input_action_for_trafficgen[:, 0, intra_step:intra_step + 1],
                agent_feature=as_feat,
                agent_valid_mask=this_agent_valid_mask
            )

            this_agent_reject_count += 1

            # Detect whether collision happens.
            if agent_index == 0:
                break
            if not reject_sampling:
                # Skip the collision check.
                break
            if this_agent_reject_count > 5:
                # Give up and keep the latest sample.
                break

            pos = self.step_info_dict["agent_position"][:, :agent_index + 1]
            head = self.step_info_dict["agent_heading"][:, :agent_index + 1]
            shape = self.step_info_dict["agent_shape"][:, :agent_index + 1]
            # Stop at the first colliding batch element. Previously the flag was
            # reset per element, so only the last element could trigger a retry.
            collided_batch = None
            for b in range(B):
                contours = cal_polygon_contour(
                    x=pos[b, :, 0].cpu().numpy(),
                    y=pos[b, :, 1].cpu().numpy(),
                    theta=head[b].cpu().numpy(),
                    width=shape[b, :, 1].cpu().numpy(),
                    length=shape[b, :, 0].cpu().numpy()
                )
                # The newest agent is last; test it against all earlier ones.
                newest = Polygon(contours[-1])
                if any(newest.intersects(Polygon(contours[i])) for i in range(len(contours) - 1)):
                    collided_batch = b
                    break
            if collided_batch is None:
                break
            print("Collision happens at batch {}, repeat the generation.".format(collided_batch))

            # Roll back to the snapshot and retry this agent.
            self.scenestreamer_tokens = copy.deepcopy(tmp_scenestreamer_tokens)
            self.step_info_dict = copy.deepcopy(tmp_step_info_dict)
            intra_step = tmp_intra_step
            step_info_dict = self.step_info_dict
            scenestreamer_tokens = self.scenestreamer_tokens

    # Finally, input the sequence_eos_id.
    intra_step += 1
    assert intra_step == G - 1, (intra_step, G, G - 1)
    _append_marker_token(
        scenestreamer_tokens, intra_step, model.trafficgen_sequence_eos_id,
        agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1], no_agent_id, always_valid,
    )

    self.step_info_dict = step_info_dict
    self.state = self.STATE_TRAFFICGEN_DONE
= self.sampling_method + temperature = self.temperature + topp = self.topp + keep_output_token = self.keep_output_token + + B, N = scenestreamer_tokens.B, scenestreamer_tokens.N + + agent_delta = utils.get_relative_velocity( + vel=step_info_dict["agent_velocity"].reshape(B, 1, N, 2), + heading=step_info_dict["agent_heading"].reshape(B, 1, N) + ) + motion_input_dict = { + "decoder/input_action_valid_mask": step_info_dict["agent_valid_mask"].reshape(B, 1, N), + "decoder/modeled_agent_position": step_info_dict["agent_position"].reshape(B, 1, N, 2), + "decoder/modeled_agent_heading": step_info_dict["agent_heading"].reshape(B, 1, N), + "decoder/modeled_agent_delta": agent_delta, + "decoder/current_agent_shape": step_info_dict["agent_shape"].reshape(B, N, 3), + "decoder/agent_type": step_info_dict["agent_type"].reshape(B, N), + + "encoder/modeled_agent_id": step_info_dict["agent_id"].reshape(B, N), + } + if teacher_forcing: + motion_input_dict["decoder/input_action"] = data_dict["decoder/input_action"][:, + current_step:current_step + 1] + else: + motion_input_dict["decoder/input_action"] = step_info_dict["motion_input_action"].reshape(B, 1, N) + + motion_input_dict = model.prepare_motion_tokens(motion_input_dict) + motion_tokens = motion_input_dict["model/motion_token"] + motion_position = motion_input_dict["model/motion_token_position"] + motion_heading = motion_input_dict["model/motion_token_heading"] + motion_valid_mask = motion_input_dict["model/motion_token_valid_mask"] + motion_width = motion_input_dict["model/motion_token_width"] + motion_length = motion_input_dict["model/motion_token_length"] + B, _, N, _ = motion_tokens.shape + + # ===== causal mask ===== + # causal_mask = model._build_all_tokens_mask_for_motion( + # B=scenestreamer_tokens.B, + # T=current_step + 1, + # num_tl=scenestreamer_tokens.L, + # num_tg=scenestreamer_tokens.G, + # num_motion=scenestreamer_tokens.N + # ) + # causal_mask = causal_mask[:, -1] + + scenestreamer_tokens.add( + 
token=motion_tokens.flatten(1, 2), + position=motion_position.flatten(1, 2), + heading=motion_heading.flatten(1, 2), + valid_mask=motion_valid_mask.flatten(1, 2), + width=motion_width.flatten(1, 2), + length=motion_length.flatten(1, 2), + causal_mask=all_token_casual_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + N, + :scenestreamer_tokens.seq_len + N], + current_step=current_step, + force_mask=all_force_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + N, :scenestreamer_tokens.seq_len + N], + require_relation=motion_valid_mask.flatten(1, 2), + ) + + # print("Step {}: motion position: {}, heading: {}, valid_mask: {}".format( + # current_step, + # motion_position.flatten(1, 2)[0, 0].tolist(), + # motion_heading.flatten(1, 2)[0, 0].tolist(), + # motion_valid_mask.flatten(1, 2)[0, 0].tolist() + # )) + + # debug code: save causal mask to files + # import matplotlib.pyplot as plt + # vis = scenestreamer_tokens.causal_mask[0].cpu().numpy() + # fig = plt.figure() + # ax = fig.add_subplot(111) + # ax.imshow(vis) + # plt.savefig("causal_mask_{}.png".format(current_step)) + + # ===== prepare dynamic relation ===== + output_dict = scenestreamer_tokens.call_model_with_cache(keep_output_token=keep_output_token) + all_token = output_dict["model/all_token"] + motion_token = all_token[:, -scenestreamer_tokens.N:] + if model.motion_prenorm is not None: + motion_token = model.motion_prenorm(motion_token) + output_token = model.motion_head(motion_token) + + # ===== Post-process the data ===== + selected_action, sampling_info = scenestreamer_motion.sample_action( + logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp + ) + + agent_valid_mask = step_info_dict["agent_valid_mask"] + agent_position = step_info_dict["agent_position"] + agent_heading = step_info_dict["agent_heading"] + agent_velocity = step_info_dict["agent_velocity"] + agent_type = step_info_dict["agent_type"] + + # Remove invalid actions + # assert 
selected_action.shape == input_action.shape + # correct_selected_action = torch.where(input_action_valid_mask, selected_action, -1) + selected_action = torch.where(agent_valid_mask, selected_action, -1) + + if teacher_forcing_sdc: + teacher_forcing_action = data_dict["decoder/target_action"][:, self.current_step].clone().reshape(B, N) + sdc_index = data_dict["decoder/sdc_index"] + assert (sdc_index == 0).all() + # if not (data_dict["decoder/target_action_valid_mask"][:, self.current_step, 0] == True).all(): + # print(111111) + # "teacher forcing SDC should always be valid in data." + selected_action[:, 0] = teacher_forcing_action[:, 0] + step_info_dict["agent_valid_mask"][:, 0] = data_dict["decoder/target_action_valid_mask"][:, + self.current_step, 0].clone() + agent_position[:, 0] = data_dict["decoder/modeled_agent_position"][:, self.current_step, 0].clone() + agent_heading[:, 0] = data_dict["decoder/modeled_agent_heading"][:, self.current_step, 0].clone() + agent_velocity[:, 0] = data_dict["decoder/modeled_agent_velocity"][:, self.current_step, 0].clone() + + # tokenizer = model.tokenizer + res = model.motion_tokenizer.detokenize_step( + current_pos=agent_position.reshape(B, 1, N, 2), + current_heading=agent_heading.reshape(B, 1, N), + current_valid_mask=agent_valid_mask.reshape(B, 1, N), + current_vel=agent_velocity.reshape(B, 1, N, 2), + action=selected_action.reshape(B, 1, N), + # agent_type=agent_type.reshape(B, 1, N), + ) + + # B, _, N = input_action.shape[:3] + new_agent_position = res["pos"].reshape(B, N, 2) + new_agent_heading = res["heading"].reshape(B, N) + new_agent_velocity = res["vel"].reshape(B, N, 2) + + step_info_dict["agent_position"] = new_agent_position.clone() + step_info_dict["agent_heading"] = new_agent_heading.clone() + step_info_dict["agent_velocity"] = new_agent_velocity.clone() + step_info_dict["motion_input_action"] = selected_action.reshape(B, N).clone() + + if allow_newly_added: + new_agent_valid_mask = ( + 
data_dict["decoder/input_action_valid_mask"][:, current_step + 1] & ( + ~step_info_dict["agent_valid_mask"]) + ) + + if new_agent_valid_mask.any(): + new_agent_pos = data_dict["decoder/modeled_agent_position"][:, current_step + 1] + new_agent_heading = data_dict["decoder/modeled_agent_heading"][:, current_step + 1] + new_agent_velocity = data_dict["decoder/modeled_agent_velocity"][:, current_step + 1] + new_action = data_dict["decoder/input_action"][:, current_step + 1] + + B, N = new_agent_valid_mask.shape + assert new_agent_pos.shape == (B, N, 2) + assert new_agent_heading.shape == (B, N) + assert new_agent_velocity.shape == (B, N, 2) + + current_pos = step_info_dict["agent_position"] + current_heading = step_info_dict["agent_heading"] + current_vel = step_info_dict["agent_velocity"] + current_valid_mask = step_info_dict["agent_valid_mask"] + + mask_2d = new_agent_valid_mask[..., None].expand_as(new_agent_pos) + current_pos = torch.where(mask_2d, new_agent_pos, current_pos) + current_heading = torch.where(new_agent_valid_mask, new_agent_heading, current_heading) + current_vel = torch.where(mask_2d, new_agent_velocity, current_vel) + current_valid_mask = torch.where(new_agent_valid_mask, new_agent_valid_mask, current_valid_mask) + + step_info_dict["agent_position"] = current_pos.clone() + step_info_dict["agent_heading"] = current_heading.clone() + step_info_dict["agent_velocity"] = current_vel.clone() + step_info_dict["agent_valid_mask"] = current_valid_mask.clone() + step_info_dict["motion_input_action"] = torch.where(new_agent_valid_mask, new_action, + step_info_dict["motion_input_action"]).clone() + + # # TODO: evict agents that moving out of the map (useful in SceneStreamer) + # if evict_agent: + # next_step_data_dict, info_dict = evict_agents_function( + # data_dict=data_dict, + # step_data_dict=next_step_data_dict, + # step_info_dict=info_dict, + # remove_static_agent=remove_static_agent, + # remove_out_of_map_agent=remove_out_of_map_agent + # ) + + 
tmp_action = step_info_dict["motion_input_action"].clone() + tmp_valid_mask = agent_valid_mask.clone() + tmp_valid_mask[tmp_action == -1] = False + tmp_valid_mask[tmp_action == MOTION_START_ACTION] = False + tmp_action[tmp_action == -1] = 0 + tmp_action[tmp_action == MOTION_START_ACTION] = 0 + log_prob = sampling_info["dist"].log_prob(tmp_action) + step_info_dict["motion_input_action_log_prob"] = (log_prob * tmp_valid_mask).clone() + + self.step_info_dict = step_info_dict + self.state = self.STATE_MOTION_DONE + + def _step_generate_trafficlight(self, teacher_forcing=False): + step_info_dict = self.step_info_dict + data_dict = self.raw_data_dict + current_step = self.current_step + model = self.model + scenestreamer_tokens = self.scenestreamer_tokens + keep_output_token = False + all_token_casual_mask = self.all_token_casual_mask + all_force_mask = self.all_force_mask + + assert self.state in [self.STATE_START, self.STATE_MOTION_DONE], "State should be either start or motion done" + + tl_input_dict = { + # no time dim: + "encoder/traffic_light_position": data_dict["encoder/traffic_light_position"][..., :2], + "encoder/traffic_light_heading": data_dict["encoder/traffic_light_heading"], + "encoder/traffic_light_map_id": data_dict["encoder/traffic_light_map_id"], + } + + if teacher_forcing: + tl_input_dict.update({ + "encoder/traffic_light_state": data_dict["encoder/traffic_light_state"][:, + current_step:current_step + 1], + "encoder/traffic_light_valid_mask": data_dict["encoder/traffic_light_valid_mask"][:, + current_step:current_step + 1], + }) + else: + tl_input_dict.update({ + "encoder/traffic_light_state": step_info_dict["traffic_light_state"], + "encoder/traffic_light_valid_mask": step_info_dict["traffic_light_valid_mask"], + }) + + B, _, L = tl_input_dict["encoder/traffic_light_state"].shape + + tl_input_dict = model.prepare_traffic_light_tokens(tl_input_dict) + tl_token = tl_input_dict["model/traffic_light_token"] + tl_position = 
tl_input_dict["model/traffic_light_token_position"] + tl_heading = tl_input_dict["model/traffic_light_token_heading"] + tl_valid_mask = tl_input_dict["model/traffic_light_token_valid_mask"] + assert tl_token.shape == (B, 1, L, model.d_model) + assert tl_position.shape == (B, 1, L, 2) + assert tl_heading.shape == (B, 1, L) + assert tl_valid_mask.shape == (B, 1, L) + traffic_light_width = torch.zeros_like(tl_position[..., 0]) + traffic_light_length = torch.zeros_like(tl_position[..., 0]) + + # ===== causal mask ===== + N = self.N + if model.no_tg: + G = 0 + else: + G = get_num_tg(N) + + # ===== token ===== + if scenestreamer_tokens is None: + tl_causal_mask = all_token_casual_mask[:, :L, :L] + steps = torch.full((B, L), current_step, dtype=torch.long, device=tl_position.device) + scenestreamer_tokens = scenestreamer_motion.SceneStreamerTokens( + token=tl_token.flatten(1, 2), + position=tl_position.flatten(1, 2), + heading=tl_heading.flatten(1, 2), + valid_mask=tl_valid_mask.flatten(1, 2), + width=traffic_light_width.flatten(1, 2), + length=traffic_light_length.flatten(1, 2), + causal_mask=tl_causal_mask, + force_mask=all_force_mask[:, :L, :L], + step=steps, + current_step=current_step, + L=L, + N=N, + G=G, + require_relation=tl_valid_mask.flatten(1, 2), + + model=model, + data_dict=data_dict, + ) + else: + tl_causal_mask = all_token_casual_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + L, + :scenestreamer_tokens.seq_len + L] + scenestreamer_tokens.add( + token=tl_token.flatten(1, 2), + position=tl_position.flatten(1, 2), + heading=tl_heading.flatten(1, 2), + valid_mask=tl_valid_mask.flatten(1, 2), + width=traffic_light_width.flatten(1, 2), + length=traffic_light_length.flatten(1, 2), + causal_mask=tl_causal_mask, + current_step=current_step, + require_relation=tl_valid_mask.flatten(1, 2), + force_mask=all_force_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + L, + :scenestreamer_tokens.seq_len + L], + ) + + # import 
@property
def B(self):
    """Size of dimension 0 of ``raw_data_dict["decoder/input_action"]`` (the batch dimension)."""
    action_shape = self.raw_data_dict["decoder/input_action"].shape
    return action_shape[0]


@property
def N(self):
    """Size of dimension 2 of ``raw_data_dict["decoder/input_action"]`` (the agent dimension)."""
    action_shape = self.raw_data_dict["decoder/input_action"].shape
    return action_shape[2]


@property
def L(self):
    """Size of dimension 2 of ``raw_data_dict["encoder/traffic_light_state"]``."""
    light_shape = self.raw_data_dict["encoder/traffic_light_state"].shape
    return light_shape[2]


@property
def G(self):
    """Result of ``get_num_tg`` for the current agent count ``N``."""
    agent_count = self.N
    return get_num_tg(agent_count)
*, progress_bar=False, num_decode_steps=19, teacher_forcing_sdc=False): + """ + This is the WOSAC generate where no initial state is generated. + """ + model = self.model + if progress_bar: + pbar = tqdm.trange(num_decode_steps, desc="Decoding Step") + else: + pbar = range(num_decode_steps) + + data_dict = self.raw_data_dict + valid_mask = [data_dict["decoder/input_action_valid_mask"][:, :1].clone()] + pos = [data_dict["decoder/modeled_agent_position"][:, :1].clone()] + head = [data_dict["decoder/modeled_agent_heading"][:, :1].clone()] + vel = [data_dict["decoder/modeled_agent_velocity"][:, :1].clone()] + dest = [] + dest_pos = [] + tl_state = [] + log_prob = [] + action = [] + B, N, G, L = self.B, self.N, self.G, self.L + + # TODO ========================================= + # TODO ========================================= + # TODO ========================================= + # TODO ========================================= + teacher_forcing_dest = True + + no_tg = model.no_tg + for decoding_step in pbar: + self.current_step = decoding_step + if model.no_tg is False: + if decoding_step % TG_SKIP_STEP == 0: + no_tg = False + else: + no_tg = True + # TODO: not hardcoded. 
+ if decoding_step < 2: + teacher_forcing_motion = True + allow_newly_added = True + teacher_forcing_tl = True + else: + teacher_forcing_motion = False + allow_newly_added = False + teacher_forcing_tl = False + if decoding_step <= 2: + teacher_forcing_tg = True + else: + teacher_forcing_tg = False + + # ===== Traffic light ===== + self._step_generate_trafficlight(teacher_forcing=teacher_forcing_tl) + if self.step_info_dict["traffic_light_state"].shape[1] > 0: + tl_state.append(self.step_info_dict["traffic_light_state"].reshape(B, 1, L)) + + # ===== Trafficgen ===== + if no_tg: + if teacher_forcing_tg: + current_step = decoding_step + self.step_info_dict["agent_valid_mask"] = \ + data_dict["decoder/input_action_valid_mask"][:, current_step].clone() + self.step_info_dict["agent_position"] = data_dict["decoder/modeled_agent_position"][:, + current_step].clone() + self.step_info_dict["agent_heading"] = data_dict["decoder/modeled_agent_heading"][:, + current_step].clone() + self.step_info_dict["agent_velocity"] = data_dict["decoder/modeled_agent_velocity"][:, + current_step].clone() + self.step_info_dict["agent_type"] = data_dict["decoder/agent_type"].clone() + self.step_info_dict["agent_shape"] = data_dict["decoder/current_agent_shape"].clone() + self.step_info_dict["agent_id"] = data_dict["encoder/modeled_agent_id"].clone() + self.state = self.STATE_TRAFFICGEN_SKIPPED + else: + self._step_generate_trafficgen_no_agent_state( + teacher_forcing_from_gt=teacher_forcing_tg, teacher_forcing_dest=teacher_forcing_dest + ) + # dest.append(self.step_info_dict["agent_destination"].reshape(B, 1, N)) + # dest_pos.append(self.step_info_dict["agent_destination_position"].reshape(B, 1, N, 2)) + + # ===== Motion ===== + self._step_generate_motion( + teacher_forcing=teacher_forcing_motion, + allow_newly_added=allow_newly_added, + teacher_forcing_sdc=teacher_forcing_sdc + ) + pos.append(self.step_info_dict["agent_position"].reshape(B, 1, N, 2).clone()) + 
head.append(self.step_info_dict["agent_heading"].reshape(B, 1, N).clone()) + vel.append(self.step_info_dict["agent_velocity"].reshape(B, 1, N, 2).clone()) + valid_mask.append(self.step_info_dict["agent_valid_mask"].reshape(B, 1, N).clone()) + log_prob.append(self.step_info_dict["motion_input_action_log_prob"].reshape(B, 1, N).clone()) + action.append(self.step_info_dict["motion_input_action"].reshape(B, 1, N).clone()) + + assert self.all_token_casual_mask.shape[1] == self.all_token_casual_mask.shape[ + 2] == self.scenestreamer_tokens.seq_len, ( + "{} vs {}".format( + self.all_token_casual_mask.shape, self.scenestreamer_tokens.seq_len + ) + ) + assert self.all_force_mask.shape[1] == self.all_force_mask.shape[2] == self.scenestreamer_tokens.seq_len, ( + "{} vs {}".format( + self.all_force_mask.shape, self.scenestreamer_tokens.seq_len + ) + ) + + pos = torch.cat(pos, dim=1) + head = torch.cat(head, dim=1) + vel = torch.cat(vel, dim=1) + action = torch.cat(action, dim=1) + if dest: + dest = torch.cat(dest, dim=1) + dest_pos = torch.cat(dest_pos, dim=1) + else: + dest = None + dest_pos = None + + # Evict the last step's input_action_valid_mask_list as it is not used. 
+ # valid_mask = valid_mask[:-1] + valid_mask = torch.cat(valid_mask, dim=1) + + tl_state = torch.cat(tl_state, dim=1) + + log_prob = torch.cat(log_prob, dim=1) + + # ===== Interpolate the output ===== + output_dict = {} + + output_dict, _ = scenestreamer_motion.interpolate_autoregressive_output( + data_dict=output_dict, + agent_heading=head, + agent_position=pos, + agent_velocity=vel, + agent_destination=dest, + agent_destination_position=dest_pos, + input_valid_mask=valid_mask, + num_skipped_steps=model.motion_tokenizer.num_skipped_steps, + num_decoded_steps=num_decode_steps, + teacher_forcing_sdc=teacher_forcing_sdc, + ) + + assert log_prob.shape == (B, 19, N) + scores = (log_prob * valid_mask[:, :-1])[:, 2:].sum(1) + + # from scenestreamer.models import relation + # scenestreamer_tokens = self.scenestreamer_tokens + # knn = self.model.config.SCENESTREAMER_ATTENTION_KNN + # max_distance = self.model.config.SCENESTREAMER_ATTENTION_MAX_DISTANCE + # relation_valid_mask = relation.compute_relation_for_scenestreamer( + # query_pos=scenestreamer_tokens.position[:, :], + # query_heading=scenestreamer_tokens.heading[:, :], + # query_valid_mask=scenestreamer_tokens.valid_mask[:, :], + # query_step=scenestreamer_tokens.step[:, :], + # key_pos=scenestreamer_tokens.position, + # key_heading=scenestreamer_tokens.heading, + # key_valid_mask=scenestreamer_tokens.valid_mask, + # key_step=scenestreamer_tokens.step, + # causal_valid_mask=scenestreamer_tokens.causal_mask[:, :], + # force_attention_mask=scenestreamer_tokens.force_mask[:, :], + # + # knn=knn, + # max_distance=max_distance, + # + # gather=False, + # query_width=None, + # # set query's w/l to 0 so that we get the rel of contour of key w.r.t. 
center of query + # query_length=None, + # key_width=scenestreamer_tokens.width, + # key_length=scenestreamer_tokens.length, + # non_agent_relation=True, + # + # require_relation=scenestreamer_tokens.require_relation[:, :], + # require_relation_for_key=scenestreamer_tokens.require_relation, + # )[1] + # import matplotlib.pyplot as plt + # vis = relation_valid_mask[0].cpu().numpy() + # plt.imshow(vis) + # + # data_dict = scenestreamer_tokens.data_dict + # map_position = data_dict["model/map_token_position"] + # map_heading = data_dict["model/map_token_heading"] + # map_token_valid_mask = data_dict["model/map_token_valid_mask"] + # relation_valid_mask = relation.compute_relation_for_scenestreamer( + # query_pos=scenestreamer_tokens.position[:, :], + # query_heading=scenestreamer_tokens.heading[:, :], + # query_valid_mask=scenestreamer_tokens.valid_mask[:, :], + # query_step=scenestreamer_tokens.step[:, :], + # + # + # # =========================== + # + # key_pos=map_position, + # key_heading=map_heading, + # key_valid_mask=map_token_valid_mask, + # key_step=torch.zeros_like(map_heading, dtype=torch.int64), + # key_width=None, + # key_length=None, + # causal_valid_mask=None, + # knn=knn, + # max_distance=max_distance, + # gather=False, + # non_agent_relation=True, + # require_relation_for_key=map_token_valid_mask, + # + # require_relation=scenestreamer_tokens.require_relation, + # )[1] + # import matplotlib.pyplot as plt + # vis = relation_valid_mask[0].cpu().numpy() + # plt.imshow(vis) + + output_dict.update({ + + # TODO: Not accumulated across steps? now is the last. 
def generate_scenestreamer_motion_with_densified_scenario(self, *,
                                                          veh_ratio,
                                                          ped_ratio,
                                                          num_new_agents,
                                                          progress_bar=False, num_decode_steps=19,
                                                          teacher_forcing_sdc=False):
    """
    Roll out motion for a scenario densified to ``num_new_agents`` agent slots.

    The per-step history is not seeded from the raw data: the step-0 agent state is
    recorded after the traffic-generation stage of the first decoding step. The full
    causal / force masks are rebuilt here for ``N = num_new_agents``.

    Keyword Args:
        veh_ratio, ped_ratio: forwarded unchanged to
            ``_step_generate_trafficgen_densified_agent_state``.
        num_new_agents: number of motion tokens per step (``N``); also forwarded to the
            densifying trafficgen step.
        progress_bar: wrap the decoding loop in a ``tqdm`` progress bar.
        num_decode_steps: number of autoregressive decoding steps.
        teacher_forcing_sdc: must be ``False``; SDC teacher forcing is not supported here.

    Returns:
        The dict produced by ``interpolate_autoregressive_output`` extended with the last
        agent shape, traffic light states, map / traffic light inputs copied from the raw
        data, per-agent scores, the sampled actions and ``"scenestreamer_tokens"``.
    """
    # Fail before any state on `self` is touched (was checked inside the loop before).
    assert teacher_forcing_sdc is False

    model = self.model
    data_dict = self.raw_data_dict
    if progress_bar:
        pbar = tqdm.trange(num_decode_steps, desc="Decoding Step")
    else:
        pbar = range(num_decode_steps)

    valid_mask, pos, head, vel = [], [], [], []
    agent_shape = []
    tl_state, log_prob, action = [], [], []

    B, L = self.B, self.L

    # TODO: reset the scenestreamer tokens
    # The agent count comes from the request, not from the raw data.
    N = num_new_agents
    G = get_num_tg(N)
    device = data_dict["decoder/input_action"].device
    self.all_token_casual_mask = model._build_all_tokens_mask(
        B=B, T=num_decode_steps, num_tl=L, num_tg=G, num_motion=N
    ).to(device)
    self.all_force_mask = model._build_all_force_mask(
        B=B, T=num_decode_steps, num_tl=L, num_tg=G, num_motion=N
    ).to(device)

    # TODO: destination teacher forcing is hard-coded on.
    teacher_forcing_dest = True

    # TODO: just disable TF for now.
    teacher_forcing_motion = False
    allow_newly_added = False

    no_tg = model.no_tg
    for decoding_step in pbar:
        self.current_step = decoding_step
        if model.no_tg is False:
            # Trafficgen only runs on every TG_SKIP_STEP-th step.
            no_tg = decoding_step % TG_SKIP_STEP != 0

        # Only the very first step teacher-forces trafficgen and traffic lights.
        first_step = decoding_step <= 0
        teacher_forcing_tg = first_step
        teacher_forcing_tl = first_step

        # ===== Traffic light =====
        self._step_generate_trafficlight(teacher_forcing=teacher_forcing_tl)
        if self.step_info_dict["traffic_light_state"].shape[1] > 0:
            tl_state.append(self.step_info_dict["traffic_light_state"].reshape(B, 1, L))

        # ===== Trafficgen =====
        if no_tg:
            if teacher_forcing_tg:
                # Load the agent state of this step straight from the raw data.
                step_info = self.step_info_dict
                step_info["agent_valid_mask"] = \
                    data_dict["decoder/input_action_valid_mask"][:, decoding_step].clone()
                step_info["agent_position"] = \
                    data_dict["decoder/modeled_agent_position"][:, decoding_step].clone()
                step_info["agent_heading"] = \
                    data_dict["decoder/modeled_agent_heading"][:, decoding_step].clone()
                step_info["agent_velocity"] = \
                    data_dict["decoder/modeled_agent_velocity"][:, decoding_step].clone()
                step_info["agent_type"] = data_dict["decoder/agent_type"].clone()
                step_info["agent_shape"] = data_dict["decoder/current_agent_shape"].clone()
                step_info["agent_id"] = data_dict["encoder/modeled_agent_id"].clone()
            self.state = self.STATE_TRAFFICGEN_SKIPPED
        else:
            self._step_generate_trafficgen_densified_agent_state(
                teacher_forcing_from_gt=teacher_forcing_tg, teacher_forcing_dest=teacher_forcing_dest,
                veh_ratio=veh_ratio,
                ped_ratio=ped_ratio,
                num_new_agents=num_new_agents,
            )

        # NOTE(review): the collapsed source does not show whether this snapshot sat inside
        # the `else` branch above. Both placements agree whenever trafficgen runs at step 0
        # (always the case when model.no_tg is False) — confirm.
        if decoding_step == 0:
            # Record the initial state so the history has num_decode_steps + 1 entries.
            pos.append(self.step_info_dict["agent_position"].reshape(B, 1, N, 2).clone())
            head.append(self.step_info_dict["agent_heading"].reshape(B, 1, N).clone())
            vel.append(self.step_info_dict["agent_velocity"].reshape(B, 1, N, 2).clone())
            valid_mask.append(self.step_info_dict["agent_valid_mask"].reshape(B, 1, N).clone())
            agent_shape.append(self.step_info_dict["agent_shape"].reshape(B, N, 3).clone())

        # ===== Motion =====
        self._step_generate_motion(
            teacher_forcing=teacher_forcing_motion,
            allow_newly_added=allow_newly_added,
            teacher_forcing_sdc=teacher_forcing_sdc
        )
        pos.append(self.step_info_dict["agent_position"].reshape(B, 1, N, 2).clone())
        head.append(self.step_info_dict["agent_heading"].reshape(B, 1, N).clone())
        vel.append(self.step_info_dict["agent_velocity"].reshape(B, 1, N, 2).clone())
        valid_mask.append(self.step_info_dict["agent_valid_mask"].reshape(B, 1, N).clone())
        log_prob.append(self.step_info_dict["motion_input_action_log_prob"].reshape(B, 1, N).clone())
        action.append(self.step_info_dict["motion_input_action"].reshape(B, 1, N).clone())

    # After the last step the token sequence must exactly fill the precomputed masks.
    assert self.all_token_casual_mask.shape[1] == self.all_token_casual_mask.shape[
        2] == self.scenestreamer_tokens.seq_len, (
        "{} vs {}".format(
            self.all_token_casual_mask.shape, self.scenestreamer_tokens.seq_len
        )
    )
    assert self.all_force_mask.shape[1] == self.all_force_mask.shape[2] == self.scenestreamer_tokens.seq_len, (
        "{} vs {}".format(
            self.all_force_mask.shape, self.scenestreamer_tokens.seq_len
        )
    )

    pos = torch.cat(pos, dim=1)
    head = torch.cat(head, dim=1)
    vel = torch.cat(vel, dim=1)
    action = torch.cat(action, dim=1)
    valid_mask = torch.cat(valid_mask, dim=1)
    tl_state = torch.cat(tl_state, dim=1)
    log_prob = torch.cat(log_prob, dim=1)

    # Destinations are never collected in this rollout, so None is passed downstream.
    dest = None
    dest_pos = None

    # ===== Interpolate the output =====
    output_dict, _ = scenestreamer_motion.interpolate_autoregressive_output(
        data_dict={},
        agent_heading=head,
        agent_position=pos,
        agent_velocity=vel,
        agent_destination=dest,
        agent_destination_position=dest_pos,
        input_valid_mask=valid_mask,
        num_skipped_steps=model.motion_tokenizer.num_skipped_steps,
        num_decoded_steps=num_decode_steps,
        agent_shape=agent_shape,
        teacher_forcing_sdc=teacher_forcing_sdc
    )

    # One log-prob entry per decoding step (was hard-coded to 19).
    assert log_prob.shape == (B, num_decode_steps, N)
    # valid_mask has one more entry than log_prob (the initial state), hence [:, :-1].
    # NOTE(review): the first two steps are excluded from the score, as in the original —
    # confirm this is intended now that motion is never teacher-forced here.
    scores = (log_prob * valid_mask[:, :-1])[:, 2:].sum(1)

    output_dict.update({

        # TODO: Not accumulated across steps? now is the last.
        "decoder/current_agent_shape": self.step_info_dict["agent_shape"],
        "model/traffic_light_state": tl_state,

        # feed forward
        "encoder/map_feature_valid_mask": data_dict["encoder/map_feature_valid_mask"],
        "encoder/traffic_light_position": data_dict["encoder/traffic_light_position"],
        "encoder/traffic_light_valid_mask": data_dict["encoder/traffic_light_valid_mask"],

        "decoder/output_score": scores,
        "model/output_action": action,
    })
    # Optional inputs are passed through only when present in the raw data.
    for optional_key in (
            "decoder/sdc_index",
            "raw/map_feature",
            "vis/map_feature",
            "decoder/object_of_interest_id",
    ):
        if optional_key in data_dict:
            output_dict[optional_key] = data_dict[optional_key]

    output_dict["scenestreamer_tokens"] = self.scenestreamer_tokens
    return output_dict
+ # valid_mask = [data_dict["decoder/input_action_valid_mask"][:, :1].clone()] + # pos = [data_dict["decoder/modeled_agent_position"][:, :1].clone()] + # head = [data_dict["decoder/modeled_agent_heading"][:, :1].clone()] + # vel = [data_dict["decoder/modeled_agent_velocity"][:, :1].clone()] + valid_mask = [] + pos = [] + head = [] + vel = [] + + dest = [] + dest_pos = [] + tl_state = [] + log_prob = [] + action = [] + agent_shape = [] + B, N, G, L = self.B, self.N, self.G, self.L + no_tg = model.no_tg + assert no_tg is False + for decoding_step in pbar: + self.current_step = decoding_step + if decoding_step % TG_SKIP_STEP == 0: + no_tg = False + else: + no_tg = True + if decoding_step < 2: + teacher_forcing_tl = True + else: + teacher_forcing_tl = False + allow_newly_added = False + teacher_forcing_motion = False + + # ===== Traffic light ===== + self._step_generate_trafficlight(teacher_forcing=teacher_forcing_tl) + if self.step_info_dict["traffic_light_state"].shape[1] > 0: + tl_state.append(self.step_info_dict["traffic_light_state"].reshape(B, 1, L)) + + # ===== Trafficgen ===== + if no_tg: + self.state = self.STATE_TRAFFICGEN_SKIPPED + else: + if self.current_step == 0: + self._step_generate_trafficgen_with_agent_state(teacher_forcing_from_gt=True) + assert self.step_info_dict["agent_valid_mask"].shape == (B, N) + assert (self.step_info_dict["agent_valid_mask"]).all() + self.step_info_dict["motion_input_action"] = torch.full( + (B, N), MOTION_START_ACTION, device=self.device + ) + pos.append(self.step_info_dict["agent_position"].reshape(B, 1, N, 2).clone()) + head.append(self.step_info_dict["agent_heading"].reshape(B, 1, N).clone()) + vel.append(self.step_info_dict["agent_velocity"].reshape(B, 1, N, 2).clone()) + valid_mask.append(self.step_info_dict["agent_valid_mask"].reshape(B, 1, N).clone()) + + agent_shape.append(self.step_info_dict["agent_shape"].reshape(B, N, 3).clone()) + else: + self._step_generate_trafficgen_no_agent_state( + 
teacher_forcing_from_gt=False, + generate_agent_states=False + ) + # dest.append(self.step_info_dict["agent_destination"].reshape(B, 1, N)) + # dest_pos.append(self.step_info_dict["agent_destination_position"].reshape(B, 1, N, 2)) + + # ===== Motion ===== + self._step_generate_motion( + teacher_forcing=teacher_forcing_motion, + allow_newly_added=allow_newly_added, + teacher_forcing_sdc=teacher_forcing_sdc, + ) + # print("MOTION Step {}, scenestreamer len {}".format( + # self.current_step, self.scenestreamer_tokens.seq_len + # )) + pos.append(self.step_info_dict["agent_position"].reshape(B, 1, N, 2).clone()) + head.append(self.step_info_dict["agent_heading"].reshape(B, 1, N).clone()) + vel.append(self.step_info_dict["agent_velocity"].reshape(B, 1, N, 2).clone()) + valid_mask.append(self.step_info_dict["agent_valid_mask"].reshape(B, 1, N).clone()) + log_prob.append(self.step_info_dict["motion_input_action_log_prob"].reshape(B, 1, N).clone()) + action.append(self.step_info_dict["motion_input_action"].reshape(B, 1, N).clone()) + + assert self.all_token_casual_mask.shape[1] == self.all_token_casual_mask.shape[ + 2] == self.scenestreamer_tokens.seq_len, ( + "{} vs {}".format( + self.all_token_casual_mask.shape, self.scenestreamer_tokens.seq_len + ) + ) + assert self.all_force_mask.shape[1] == self.all_force_mask.shape[2] == self.scenestreamer_tokens.seq_len, ( + "{} vs {}".format( + self.all_force_mask.shape, self.scenestreamer_tokens.seq_len + ) + ) + + pos = torch.cat(pos, dim=1) + head = torch.cat(head, dim=1) + vel = torch.cat(vel, dim=1) + action = torch.cat(action, dim=1) + if dest: + dest = torch.cat(dest, dim=1) + dest_pos = torch.cat(dest_pos, dim=1) + else: + dest = None + dest_pos = None + + # Evict the last step's input_action_valid_mask_list as it is not used. 
+ # valid_mask = valid_mask[:-1] + valid_mask = torch.cat(valid_mask, dim=1) + tl_state = torch.cat(tl_state, dim=1) + log_prob = torch.cat(log_prob, dim=1) + + # ===== Interpolate the output ===== + output_dict = {} + output_dict, _ = scenestreamer_motion.interpolate_autoregressive_output( + data_dict=output_dict, + agent_heading=head, + agent_position=pos, + agent_velocity=vel, + agent_destination=dest, + agent_destination_position=dest_pos, + input_valid_mask=valid_mask, + num_skipped_steps=model.motion_tokenizer.num_skipped_steps, + num_decoded_steps=num_decode_steps, + agent_shape=agent_shape, + teacher_forcing_sdc=teacher_forcing_sdc, + sdc_index=data_dict["decoder/sdc_index"], + ) + + assert log_prob.shape == (B, 19, N) + scores = (log_prob * valid_mask[:, :-1])[:, 2:].sum(1) + + output_dict.update({ + + # TODO: Not accumulated across steps? now is the last. + "decoder/current_agent_shape": self.step_info_dict["agent_shape"], + "model/traffic_light_state": tl_state, + + # feed forward + "encoder/map_feature_valid_mask": data_dict["encoder/map_feature_valid_mask"], + "encoder/traffic_light_position": data_dict["encoder/traffic_light_position"], + "encoder/traffic_light_valid_mask": data_dict["encoder/traffic_light_valid_mask"], + # "decoder/labeled_agent_id" + # "decoder/object_of_interest_id" + + "decoder/output_score": scores, + "model/output_action": action, + }) + if "decoder/sdc_index" in data_dict: + output_dict["decoder/sdc_index"] = data_dict["decoder/sdc_index"] + if "raw/map_feature" in data_dict: + output_dict["raw/map_feature"] = data_dict["raw/map_feature"] + if "vis/map_feature" in data_dict: + output_dict["vis/map_feature"] = data_dict["vis/map_feature"] + if "decoder/object_of_interest_id" in data_dict: + output_dict["decoder/object_of_interest_id"] = data_dict["decoder/object_of_interest_id"] + + # plot_dict = utils.unbatch_data(utils.torch_to_numpy(output_dict)) + # from scenestreamer.gradio_ui.plot import plot_pred + # plot_pred(plot_dict, 
show=True) + + output_dict["scenestreamer_tokens"] = self.scenestreamer_tokens + return output_dict + + def generate_scenestreamer_initial_state(self, *, progress_bar=False, num_decode_steps=19): + """ + This is the WOSAC generate where no initial state is generated. + """ + data_dict = self.raw_data_dict + + # Hardcode here, process the data and set "current_step" to 10. + assert data_dict['metadata/current_time_index'].item() == 10 + new_data_dict = {} + for k in data_dict: + if ("encoder/traffic_light_" in k) or (k == "decoder/input_action_valid_mask") or \ + ("decoder/modeled_agent_" in k): + new_data_dict[k] = data_dict[k][:, 10:11] + data_dict = new_data_dict + + tl_state = [] + B, N, G, L = self.B, self.N, self.G, self.L + teacher_forcing_tl = True + assert self.model.no_tg is False, "SceneStreamer should be used with trafficgen" + # ===== Traffic light ===== + self._step_generate_trafficlight(teacher_forcing=teacher_forcing_tl) + if self.step_info_dict["traffic_light_state"].shape[1] > 0: + tl_state.append(self.step_info_dict["traffic_light_state"].reshape(B, 1, L)) + # ===== Trafficgen ===== + self._step_generate_trafficgen_with_agent_state( + teacher_forcing_from_gt=True, teacher_forcing_dest=111111, + ) + # Do some postprocessing + step_data_dict = self.step_info_dict + data_dict.update( + { + "decoder/modeled_agent_position_for_trafficgen": step_data_dict["agent_position"].clone(), + "decoder/modeled_agent_heading_for_trafficgen": step_data_dict["agent_heading"].clone(), + "decoder/modeled_agent_velocity_for_trafficgen": step_data_dict["agent_velocity"].clone(), + "decoder/current_agent_shape_for_trafficgen": step_data_dict["agent_shape"].clone(), + "decoder/agent_type_for_trafficgen": step_data_dict["agent_type"].clone(), + "decoder/input_action_valid_mask_for_trafficgen": step_data_dict["agent_valid_mask"].clone(), + } + ) + assert step_data_dict["agent_valid_mask"].all() + + # from scenestreamer.infer.initial_state import 
def plot_initial_state(data_dict, save_path, draw_line=False, draw_text=True):
    """
    Draw the map, traffic lights and all agents of a scenario and save the figure.

    The ego (SDC) agent is drawn first with the label ``"<id>-SDC"`` and its
    text enabled. All other agents are drawn afterwards with their text
    disabled. Agents are colored by drawing order using the ``crest_r``
    palette, and a matching (inverted) color bar is added.

    Args:
        data_dict: Scenario dict. The keys read here are
            ``decoder/agent_position``, ``decoder/agent_heading``,
            ``decoder/agent_shape``, ``decoder/agent_valid_mask``,
            ``decoder/sdc_index``, ``decoder/object_of_interest_id`` and either
            ``vis/map_feature`` + ``encoder/map_feature_valid_mask`` or
            ``encoder/map_position`` + ``encoder/map_valid_mask``.
            NOTE(review): entries are indexed like numpy arrays; presumably
            the dict is already unbatched and converted to numpy - confirm
            against the caller.
        save_path: Destination passed to ``fig.savefig``.
        draw_line: Forwarded to ``draw_trajectory`` for every agent.
        draw_text: Accepted for interface compatibility; currently not
            consulted (the ego text is always drawn, other texts never).

    Returns:
        None. The figure is written to ``save_path`` and always closed, even
        when plotting fails.
    """
    from scenestreamer.gradio_ui.plot import (
        BOUNDARY,
        EGO_FONT_SIZE,
        MODELED_FONT_SIZE,
        NON_EGO_FONT_SIZE,
        _plot_map,
        _plot_traffic_light,
        draw_trajectory,
        get_limit,
    )
    import seaborn as sns
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import matplotlib.colors as colors

    fig = plt.figure(figsize=(10, 10), dpi=300)
    try:
        ax = fig.add_subplot(111)
        ax.set_aspect(1)

        agent_pos = data_dict["decoder/agent_position"][:, :, :2]  # keep x/y only
        agent_heading = data_dict["decoder/agent_heading"]
        agent_shape = data_dict["decoder/agent_shape"]  # index 0 is used as length, index 1 as width
        agent_mask = data_dict["decoder/agent_valid_mask"]
        ego_agent_id = data_dict['decoder/sdc_index']

        _plot_map(data_dict, ax, dont_draw_lane=True)

        _plot_traffic_light(data_dict, ax)

        _, N, _ = agent_pos.shape

        modeled_agents_indicies = np.concatenate(
            [data_dict["decoder/object_of_interest_id"], np.atleast_1d(ego_agent_id)]
        )

        cmap = sns.color_palette("crest_r", as_cmap=False, n_colors=N)
        cmap_cbar = sns.color_palette("crest_r", as_cmap=True, n_colors=N)

        def draw_agent(agent_ind, color, text, fontsize, with_text):
            # Shared draw_trajectory call for the ego and the other agents.
            draw_trajectory(
                ax=ax,
                pos=agent_pos[:, agent_ind],
                heading=agent_heading[:, agent_ind],
                width=agent_shape[:, agent_ind, 1],
                length=agent_shape[:, agent_ind, 0],
                mask=agent_mask[:, agent_ind],
                fill_color=color,
                traj_kwargs=dict(),
                contour_kwargs=dict(
                    edgecolor="k",
                    linewidth=0.1,
                    fill=False,
                ),
                text=text,
                fontsize=fontsize,
                draw_line=draw_line,
                draw_text=with_text,
            )

        # The ego agent always takes the first palette color.
        draw_agent(ego_agent_id, cmap[0], "{}-SDC".format(str(ego_agent_id)), EGO_FONT_SIZE, True)
        plotted_count = 1

        for agent_ind in range(N):
            if agent_ind == ego_agent_id:
                continue
            if agent_ind in modeled_agents_indicies:
                text = "{}-OOI".format(str(agent_ind))
                fontsize = MODELED_FONT_SIZE
            else:
                text = str(agent_ind)
                fontsize = NON_EGO_FONT_SIZE
            draw_agent(agent_ind, cmap[plotted_count], text, fontsize, False)
            plotted_count += 1

        if "vis/map_feature" in data_dict:
            map_pos = data_dict["vis/map_feature"][:, :, :2][data_dict["encoder/map_feature_valid_mask"]]
        else:
            map_pos = data_dict["encoder/map_position"][..., :2][data_dict["encoder/map_valid_mask"]]
        limits = get_limit(agent_pos=agent_pos[agent_mask], map_pos=map_pos)

        ax.set_xlim(limits["xmin"] - BOUNDARY, limits["xmax"] + BOUNDARY)
        ax.set_ylim(limits["ymin"] - BOUNDARY, limits["ymax"] + BOUNDARY)
        ax.set_aspect(1)

        # turn on color bar
        norm = colors.Normalize(vmin=0, vmax=N - 1)
        sm = cm.ScalarMappable(cmap=cmap_cbar, norm=norm)
        sm.set_array([])  # dummy array for compatibility
        cbar = plt.colorbar(sm, ax=ax)
        cbar.ax.invert_yaxis()  # Flip the colorbar

        fig.tight_layout(pad=0.05)
        fig.canvas.draw()

        # The former PIL.Image.frombytes(..., fig.canvas.tostring_rgb()) call
        # was removed: its result was never used and tostring_rgb is no longer
        # available on recent Matplotlib releases.
        fig.savefig(save_path)
    finally:
        # Always release the figure, also when plotting raised.
        plt.close(fig)
@dataclasses.dataclass
class SceneStreamerTokens:
    """
    Growing token sequence (plus per-token metadata) fed to the decoder.

    New tokens are appended along dim 1 with ``add``. ``call_model_with_cache``
    runs the decoder on the tokens appended since the last cached call.
    """

    # Not dataclass fields (no annotation): the decoder cache and the number
    # of leading tokens already covered by it.
    _cache = None
    _cache_length = 0

    token: torch.Tensor
    position: torch.Tensor
    heading: torch.Tensor
    valid_mask: torch.Tensor
    width: torch.Tensor
    length: torch.Tensor
    causal_mask: torch.Tensor
    force_mask: torch.Tensor
    step: torch.Tensor
    require_relation: torch.Tensor
    N: int  # Number of motion tokens (agents)
    G: int  # Number of trafficgen tokens
    L: int  # Number of traffic light tokens
    current_step: int

    model: torch.nn.Module = None
    data_dict: dict = None

    output_token: torch.Tensor = None

    @staticmethod
    def _grow_square_mask(existing, new_rows, batch_size, device):
        """Return ``existing`` enlarged to ``new_rows.shape[2]`` keys, and that key count.

        The old square mask is copied into the top-left corner, ``new_rows``
        fills the rows of the newly added queries, and the remaining entries
        (old queries vs. new keys) stay zero.
        """
        old_len = existing.shape[2]
        assert existing.shape == (batch_size, old_len, old_len), existing.shape
        new_len = new_rows.shape[2]
        assert new_len > old_len, (new_len, old_len)
        grown = existing.new_zeros(batch_size, new_len, new_len)
        grown[:, :old_len, :old_len] = existing
        grown[:, old_len:, :] = new_rows.to(device)
        return grown, new_len

    def add(self, *, token, position, heading, valid_mask, width, length, causal_mask, force_mask, current_step, require_relation):
        """
        Append a group of new tokens (all produced at ``current_step``).

        ``causal_mask`` and ``force_mask`` hold the mask rows of the new
        tokens against all keys, old and new, so their last dimension must be
        larger than the number of tokens stored so far.
        """
        assert token.ndim == 3, token.shape  # B, seq_len, D
        assert position.ndim == 3  # B, seq_len, 2
        assert heading.ndim == 2  # B, seq_len
        assert valid_mask.ndim == 2
        assert width.ndim == 2
        assert length.ndim == 2
        assert causal_mask.ndim == 3, causal_mask.shape
        assert force_mask.ndim == 3, force_mask.shape
        assert current_step >= self.current_step, "current step {} < {}".format(current_step, self.current_step)

        device = self.token.device

        # Per-token tensors are simply concatenated along the sequence axis.
        appended = (
            ("token", token),
            ("position", position),
            ("heading", heading),
            ("valid_mask", valid_mask),
            ("width", width),
            ("length", length),
        )
        for attr_name, new_value in appended:
            merged = torch.cat([getattr(self, attr_name), new_value.to(device)], dim=1)
            setattr(self, attr_name, merged)

        new_steps = torch.full(
            (token.shape[0], token.shape[1]), current_step, dtype=torch.long, device=token.device
        )
        self.step = torch.cat([self.step, new_steps.to(device)], dim=1)
        self.current_step = current_step

        self.causal_mask, total_keys = self._grow_square_mask(self.causal_mask, causal_mask, self.B, device)
        assert self.token.shape[-2] == total_keys, self.token.shape

        self.force_mask, total_keys = self._grow_square_mask(self.force_mask, force_mask, self.B, device)
        assert self.token.shape[-2] == total_keys, self.token.shape

        self.require_relation = torch.cat(
            [self.require_relation, require_relation.to(device)], dim=1
        )

    @property
    def B(self):
        """Batch size (dim 0 of ``token``)."""
        return self.token.shape[0]

    @property
    def seq_len(self):
        """Number of tokens stored so far (dim 1 of ``token``)."""
        return self.token.shape[1]

    def able_to_call_model(self):
        """Return True if any token not yet covered by the cache is valid."""
        pending_valid = self.valid_mask[:, self._cache_length:]
        return pending_valid.any().item()

    def call_model_with_cache(self, knn=None, max_distance=None, use_cache=True, keep_output_token=False):
        """
        Run the decoder on the tokens appended since the last cached call.

        Relations are computed from the pending tokens (queries) to all stored
        tokens and to the map tokens, embedded, and passed to
        ``self.model.decoder``. With ``use_cache`` the returned cache is kept
        and the cached length advances to ``seq_len``; otherwise the cache is
        reset. ``knn`` and ``max_distance`` default to the model config
        values. With ``keep_output_token`` the decoder's ``"model/all_token"``
        output is accumulated in ``self.output_token``.

        Returns:
            The decoder output dict.
        """
        # ===== prepare dynamic relation =====
        scene = self.data_dict
        map_position = scene["model/map_token_position"]
        map_heading = scene["model/map_token_heading"]
        map_token_valid_mask = scene["model/map_token_valid_mask"]
        if knn is None:
            knn = self.model.config.SCENESTREAMER_ATTENTION_KNN
        if max_distance is None:
            max_distance = self.model.config.SCENESTREAMER_ATTENTION_MAX_DISTANCE

        # Queries are the tokens that are not yet covered by the cache.
        start = self._cache_length
        pending_position = self.position[:, start:]
        pending_heading = self.heading[:, start:]
        pending_valid = self.valid_mask[:, start:]
        pending_step = self.step[:, start:]
        pending_require = self.require_relation[:, start:]

        # --- pending tokens -> all tokens ---
        token_relation, token_relation_valid, token_require_pairwise = relation.compute_relation_for_scenestreamer(
            query_pos=pending_position,
            query_heading=pending_heading,
            query_valid_mask=pending_valid,
            query_step=pending_step,
            key_pos=self.position,
            key_heading=self.heading,
            key_valid_mask=self.valid_mask,
            key_step=self.step,
            causal_valid_mask=self.causal_mask[:, start:],
            force_attention_mask=self.force_mask[:, start:],
            knn=knn,
            max_distance=max_distance,
            gather=False,
            query_width=None,
            query_length=None,
            key_width=None,
            key_length=None,
            non_agent_relation=True,
            require_relation=pending_require,
            require_relation_for_key=self.require_relation,
        )
        all_to_all_info = get_edge_info_for_scenestreamer(
            q_k_relation=token_relation,
            q_k_valid_mask=token_relation_valid,
            relation_model=self.model.relation_embed_4d,
            relation_model_1d=self.model.relation_embed_1d,
            require_relation_pairwise=token_require_pairwise,
        )

        # --- pending tokens -> map tokens ---
        a2m_3d = self.model.config.MODEL.ALL_TO_MAP_3D
        assert a2m_3d is False
        if a2m_3d:
            map_query_step = None
            map_key_step = None
            map_relation_model = self.model.relation_embed_3d
        else:
            map_query_step = pending_step
            map_key_step = torch.zeros_like(map_heading, dtype=torch.int64)
            map_relation_model = self.model.relation_embed_4d

        map_relation, map_relation_valid, map_require_pairwise = relation.compute_relation_for_scenestreamer(
            query_pos=pending_position,
            query_heading=pending_heading,
            query_valid_mask=pending_valid,
            query_step=map_query_step,
            query_width=None,
            query_length=None,
            key_pos=map_position,
            key_heading=map_heading,
            key_valid_mask=map_token_valid_mask,
            key_step=map_key_step,
            key_width=None,
            key_length=None,
            causal_valid_mask=None,
            knn=knn,
            max_distance=max_distance,
            gather=False,
            non_agent_relation=True,
            require_relation=pending_require,
            require_relation_for_key=map_token_valid_mask,
        )
        all_to_map_info = get_edge_info_for_scenestreamer(
            q_k_relation=map_relation,
            q_k_valid_mask=map_relation_valid,
            relation_model=map_relation_model,
            relation_model_1d=self.model.relation_embed_1d,
            require_relation_pairwise=map_require_pairwise,
        )

        input_dict = {
            "model/map_token": scene["model/map_token"],
            "model/all_token": self.token[:, start:],
            "model/all_to_map_info": all_to_map_info,
            "model/all_to_all_info": all_to_all_info,
        }

        # ===== Call Model =====
        if not use_cache:
            output_dict = self.model.decoder(input_dict=input_dict, use_cache=use_cache)
            self._cache = None
            self._cache_length = 0
        else:
            output_dict, layer_caches = self.model.decoder(
                input_dict=input_dict, use_cache=use_cache, cache=self._cache
            )
            # Each layer's cache entry gets a (batch, sequence length) tuple appended.
            self._cache = [
                layer_caches[layer] + [(self.B, self.token.shape[1])]
                for layer in range(len(layer_caches))
            ]
            self._cache_length = self.seq_len

        if keep_output_token:
            latest = output_dict["model/all_token"]
            if self.output_token is None:
                self.output_token = latest
            else:
                self.output_token = torch.cat([self.output_token, latest], dim=1)
        return output_dict
def motion_prediction_task(
    *,
    data_dict,
    model,
    autoregressive_start_step=None,
    allow_newly_added_agent_step=None,
    temperature=None,
    topp=None,
    num_decode_steps=None,
    sampling_method=None,
    interpolation=True,
    remove_out_of_map_agent=False,
    remove_static_agent=False,
    teacher_forcing_sdc=False,
    use_cache=True,
    progress_bar=True,
    keep_output_token=False,
    teacher_forcing_dest=None,
):
    """
    Autoregressively decode traffic light, trafficgen and motion tokens.

    Every decoding step calls, in order, ``call_model_for_traffic_light``,
    (optionally) ``call_model_for_trafficgen`` and ``call_model_for_motion``,
    and records the agent state after the motion stage. The recorded states
    are finally passed to ``interpolate_autoregressive_output``.

    Teacher forcing schedule (hard-coded):
        * motion and ``allow_newly_added``: decoding steps 0 and 1;
        * traffic light and trafficgen: decoding steps 0, 1 and 2.
    When ``model.no_tg`` is False, trafficgen only runs on steps that are a
    multiple of ``TG_SKIP_STEP``. On skipped steps that are still teacher
    forced, the agent state is copied from the ground truth in ``data_dict``.

    Args:
        data_dict: Input batch; it is passed through ``encode_scene`` first.
        model: The model; also supplies sampling defaults via ``model.config``.
        temperature, topp, sampling_method: Sampling settings; ``None`` means
            "use ``model.config.SAMPLING``".
        num_decode_steps: Number of decoding steps; ``None`` means 19. Any
            explicit value triggers a printed warning.
        teacher_forcing_sdc: Forwarded to ``interpolate_autoregressive_output``.
        use_cache, keep_output_token: Forwarded to every ``call_model_for_*``.
        progress_bar: Wrap the decoding loop in a ``tqdm`` progress bar.
        teacher_forcing_dest: Must not be ``None``; forwarded to
            ``call_model_for_trafficgen``.
        autoregressive_start_step, allow_newly_added_agent_step, interpolation,
        remove_out_of_map_agent, remove_static_agent: Accepted for interface
            compatibility; not used in this function.

    Returns:
        The interpolated output dict extended with the traffic light states,
        per-agent scores, sampled actions, a few pass-through entries of
        ``data_dict`` and the final ``scenestreamer_tokens`` object.
    """
    assert teacher_forcing_dest is not None, "Please set teacher_forcing_dest to True or False"
    # ===== Some preprocessing =====
    if topp is None:
        topp = model.config.SAMPLING.TOPP
    if temperature is None:
        temperature = model.config.SAMPLING.TEMPERATURE
    if sampling_method is None:
        sampling_method = model.config.SAMPLING.SAMPLING_METHOD
    B, T_input, N = data_dict["decoder/input_action"].shape[:3]
    if num_decode_steps is None:
        num_decode_steps = 19
        assert data_dict["decoder/input_action_valid_mask"].shape == (B, T_input, N)
    else:
        print("WARNING: You are freely generating future trajectory! num_decode_steps (was 19) =", num_decode_steps)

    # ===== Encode scenes =====
    data_dict, _ = encode_scene(data_dict=data_dict, model=model)

    _, _, L = data_dict["encoder/traffic_light_state"].shape

    scenestreamer_tokens = None
    step_info_dict = {}

    knn = model.config.SCENESTREAMER_ATTENTION_KNN
    max_distance = model.config.SCENESTREAMER_ATTENTION_MAX_DISTANCE

    # Frame 0 is the ground-truth state; one more frame is appended per step.
    valid_mask = [data_dict["decoder/input_action_valid_mask"][:, :1].clone()]
    pos = [data_dict["decoder/modeled_agent_position"][:, :1].clone()]
    head = [data_dict["decoder/modeled_agent_heading"][:, :1].clone()]
    vel = [data_dict["decoder/modeled_agent_velocity"][:, :1].clone()]
    tl_state = []
    log_prob = []
    action = []

    if progress_bar:
        pbar = tqdm.trange(num_decode_steps, desc="Decoding Step")
    else:
        pbar = range(num_decode_steps)

    G = get_num_tg(N)
    device = data_dict["decoder/input_action"].device
    all_token_casual_mask = model._build_all_tokens_mask(
        B=B, T=num_decode_steps, num_tl=L, num_tg=G, num_motion=N
    ).to(device)

    all_force_mask = model._build_all_force_mask(
        B=B, T=num_decode_steps, num_tl=L, num_tg=G, num_motion=N
    ).to(device)

    no_tg = model.no_tg

    for decoding_step in pbar:
        if model.no_tg is False:
            # Trafficgen only runs every TG_SKIP_STEP-th step.
            no_tg = decoding_step % TG_SKIP_STEP != 0

        # TODO: not hardcoded.
        teacher_forcing_motion = decoding_step < 2
        allow_newly_added = decoding_step < 2
        teacher_forcing_tg = decoding_step <= 2
        teacher_forcing_tl = decoding_step <= 2

        # ===== Traffic light =====
        scenestreamer_tokens, step_info_dict = call_model_for_traffic_light(
            model=model, data_dict=data_dict, knn=knn, max_distance=max_distance,
            scenestreamer_tokens=scenestreamer_tokens,
            current_step=decoding_step, step_info_dict=step_info_dict, teacher_forcing=teacher_forcing_tl,
            use_cache=use_cache, all_token_casual_mask=all_token_casual_mask, all_force_mask=all_force_mask,
            keep_output_token=keep_output_token,
        )
        if step_info_dict["traffic_light_state"].shape[1] > 0:
            tl_state.append(step_info_dict["traffic_light_state"].reshape(B, 1, L))

        # ===== Trafficgen =====
        if no_tg:
            if teacher_forcing_tg:
                # Trafficgen is skipped but still teacher forced: take the
                # agent state of this step from the ground truth.
                step_info_dict["agent_valid_mask"] = data_dict["decoder/input_action_valid_mask"][:, decoding_step]
                step_info_dict["agent_position"] = data_dict["decoder/modeled_agent_position"][:, decoding_step]
                step_info_dict["agent_heading"] = data_dict["decoder/modeled_agent_heading"][:, decoding_step]
                step_info_dict["agent_velocity"] = data_dict["decoder/modeled_agent_velocity"][:, decoding_step]
                step_info_dict["agent_type"] = data_dict["decoder/agent_type"]
                step_info_dict["agent_shape"] = data_dict["decoder/current_agent_shape"]
                step_info_dict["agent_id"] = data_dict["encoder/modeled_agent_id"]
        else:
            # The former "generate all agents" branch was unreachable and
            # unconditionally raised ValueError; it has been removed.
            scenestreamer_tokens, step_info_dict = call_model_for_trafficgen(
                model=model, data_dict=data_dict, knn=knn, max_distance=max_distance,
                scenestreamer_tokens=scenestreamer_tokens,
                current_step=decoding_step, step_info_dict=step_info_dict,
                teacher_forcing_from_gt=teacher_forcing_tg,
                use_cache=use_cache, all_token_casual_mask=all_token_casual_mask,
                teacher_forcing_dest=teacher_forcing_dest,
                all_force_mask=all_force_mask, keep_output_token=keep_output_token,
            )

        # ===== Motion =====
        scenestreamer_tokens, step_info_dict = call_model_for_motion(
            model=model, data_dict=data_dict, knn=knn, max_distance=max_distance,
            scenestreamer_tokens=scenestreamer_tokens, current_step=decoding_step,
            sampling_method=sampling_method, temperature=temperature, topp=topp,
            step_info_dict=step_info_dict, teacher_forcing=teacher_forcing_motion,
            use_cache=use_cache, allow_newly_added=allow_newly_added,
            all_token_casual_mask=all_token_casual_mask, all_force_mask=all_force_mask,
            keep_output_token=keep_output_token,
        )
        pos.append(step_info_dict["agent_position"].reshape(B, 1, N, 2))
        head.append(step_info_dict["agent_heading"].reshape(B, 1, N))
        vel.append(step_info_dict["agent_velocity"].reshape(B, 1, N, 2))
        valid_mask.append(step_info_dict["agent_valid_mask"].reshape(B, 1, N))
        log_prob.append(step_info_dict["motion_input_action_log_prob"].reshape(B, 1, N))
        action.append(step_info_dict["motion_input_action"].reshape(B, 1, N))

    # Sanity check: the prebuilt square masks must match the final sequence length.
    assert all_token_casual_mask.shape[1] == all_token_casual_mask.shape[2] == scenestreamer_tokens.seq_len, (
        "{} vs {}".format(
            all_token_casual_mask.shape, scenestreamer_tokens.seq_len
        )
    )
    assert all_force_mask.shape[1] == all_force_mask.shape[2] == scenestreamer_tokens.seq_len, (
        "{} vs {}".format(
            all_force_mask.shape, scenestreamer_tokens.seq_len
        )
    )

    pos = torch.cat(pos, dim=1)
    head = torch.cat(head, dim=1)
    vel = torch.cat(vel, dim=1)
    action = torch.cat(action, dim=1)

    # Evict the last step's input_action_valid_mask_list as it is not used.
    valid_mask = torch.cat(valid_mask[:-1], dim=1)

    tl_state = torch.cat(tl_state, dim=1)

    log_prob = torch.cat(log_prob, dim=1)

    # ===== Interpolate the output =====
    # Destinations are not collected by this rollout, hence None.
    output_dict, _ = interpolate_autoregressive_output(
        data_dict={},
        agent_heading=head,
        agent_position=pos,
        agent_velocity=vel,
        agent_destination=None,
        agent_destination_position=None,
        input_valid_mask=valid_mask,
        num_skipped_steps=model.motion_tokenizer.num_skipped_steps,
        num_decoded_steps=num_decode_steps,
        teacher_forcing_sdc=teacher_forcing_sdc,
    )

    # One log-prob entry is appended per decoding step (previously this was
    # hard-coded to 19, which broke every explicit num_decode_steps != 19).
    assert log_prob.shape == (B, num_decode_steps, N)
    # Sum the masked log-probs from step index 2 on.
    scores = (log_prob * valid_mask)[:, 2:].sum(1)

    output_dict.update({

        # TODO: Not accumulated across steps? now is the last.
        "decoder/current_agent_shape": step_info_dict["agent_shape"],
        "model/traffic_light_state": tl_state,

        # feed forward
        "encoder/map_feature_valid_mask": data_dict["encoder/map_feature_valid_mask"],
        "encoder/traffic_light_position": data_dict["encoder/traffic_light_position"],
        "encoder/traffic_light_valid_mask": data_dict["encoder/traffic_light_valid_mask"],

        "decoder/output_score": scores,

        "model/output_action": action,
    })
    # Optional pass-through entries.
    for key in (
        "decoder/sdc_index",
        "raw/map_feature",
        "vis/map_feature",
        "decoder/object_of_interest_id",
    ):
        if key in data_dict:
            output_dict[key] = data_dict[key]

    output_dict["scenestreamer_tokens"] = scenestreamer_tokens
    return output_dict
current_step:current_step + 1], + }) + else: + tl_input_dict.update({ + "encoder/traffic_light_state": step_info_dict["traffic_light_state"], + "encoder/traffic_light_valid_mask": step_info_dict["traffic_light_valid_mask"], + }) + + B, _, L = tl_input_dict["encoder/traffic_light_state"].shape + + tl_input_dict = model.prepare_traffic_light_tokens(tl_input_dict) + tl_token = tl_input_dict["model/traffic_light_token"] + tl_position = tl_input_dict["model/traffic_light_token_position"] + tl_heading = tl_input_dict["model/traffic_light_token_heading"] + tl_valid_mask = tl_input_dict["model/traffic_light_token_valid_mask"] + assert tl_token.shape == (B, 1, L, model.d_model) + assert tl_position.shape == (B, 1, L, 2) + assert tl_heading.shape == (B, 1, L) + assert tl_valid_mask.shape == (B, 1, L) + traffic_light_width = torch.zeros_like(tl_position[..., 0]) + traffic_light_length = torch.zeros_like(tl_position[..., 0]) + + # ===== causal mask ===== + N = data_dict["decoder/agent_id"].shape[1] + if model.no_tg: + G = 0 + else: + G = get_num_tg(N) + # tl_causal_mask = model._build_all_tokens_mask_for_tl( + # B=B, T=current_step + 1, num_tl=L, num_tg=G, num_motion=N + # ).to(tl_position.device) + + # import matplotlib.pyplot as plt + # vis = tl_causal_mask[0][0].cpu().numpy() + # plt.imshow(vis) + + # tl_causal_mask is in shape (B, T, N, L+G+N). + # the final G+N tokens are the trafficgen and motion tokens, which is in future. + # We need to remove them. 
+ # tl_causal_mask = tl_causal_mask[:, -1, :, :-G - N] + + # ===== token ===== + if scenestreamer_tokens is None: + tl_causal_mask = all_token_casual_mask[:, :L, :L] + steps = torch.full((B, L), current_step, dtype=torch.long, device=tl_position.device) + scenestreamer_tokens = SceneStreamerTokens( + token=tl_token.flatten(1, 2), + position=tl_position.flatten(1, 2), + heading=tl_heading.flatten(1, 2), + valid_mask=tl_valid_mask.flatten(1, 2), + width=traffic_light_width.flatten(1, 2), + length=traffic_light_length.flatten(1, 2), + causal_mask=tl_causal_mask, + force_mask=all_force_mask[:, :L, :L], + step=steps, + current_step=current_step, + L=L, + N=N, + G=G, + require_relation=tl_valid_mask.flatten(1, 2), + + model=model, + data_dict=data_dict, + ) + else: + tl_causal_mask = all_token_casual_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len+L, :scenestreamer_tokens.seq_len+L] + scenestreamer_tokens.add( + token=tl_token.flatten(1, 2), + position=tl_position.flatten(1, 2), + heading=tl_heading.flatten(1, 2), + valid_mask=tl_valid_mask.flatten(1, 2), + width=traffic_light_width.flatten(1, 2), + length=traffic_light_length.flatten(1, 2), + causal_mask=tl_causal_mask, + current_step=current_step, + require_relation=tl_valid_mask.flatten(1, 2), + force_mask=all_force_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len+L, :scenestreamer_tokens.seq_len+L], + ) + + # import matplotlib.pyplot as plt + # vis=all_token_casual_mask[:, scenestreamer_tokens.seq_len:scenestreamer_tokens.seq_len + L, :scenestreamer_tokens.seq_len + L][0].cpu().numpy() + # plt.imshow(vis) + + + # Note that if teacher_forcing is False while there is no traffic light, + # we will have L=1 and there will be error when calling the model in the first step. + # Because at that time num_Q = num_K = 0. + # This won't be a problem if we use teacher_forcing or in any future step > 0. 
def call_model_for_trafficgen(
    *,
    model,
    data_dict,
    scenestreamer_tokens: SceneStreamerTokens,
    step_info_dict,
    current_step,
    knn,
    max_distance,
    teacher_forcing_from_gt,
    teacher_forcing_dest,
    use_cache,
    all_token_casual_mask,
    all_force_mask,
    keep_output_token,
):
    """Append the trafficgen token sequence of ``current_step`` to the buffer.

    The sequence is: one sequence-SOS token, then for every agent four tokens
    (agent-SOS, agent type, map id, agent feature), then one sequence-EOS
    token.  The per-agent inputs are produced by
    ``prepare_trafficgen_data_for_scenestreamer_a_step`` from the agent state
    stored in ``step_info_dict``.

    When ``teacher_forcing_from_gt`` is true, the agent state in
    ``step_info_dict`` is first overwritten with the ground truth of
    ``current_step`` from ``data_dict``.

    Destination prediction is currently disabled: no destination token is
    appended and ``step_info_dict`` receives no destination entries, so
    ``teacher_forcing_dest`` has no effect.  ``knn``, ``max_distance``,
    ``use_cache`` and ``keep_output_token`` are accepted for interface
    compatibility but are not used here (the model is not called).

    Returns:
        ``(scenestreamer_tokens, step_info_dict)``.
    """
    if teacher_forcing_from_gt:
        step_info_dict["agent_valid_mask"] = data_dict["decoder/input_action_valid_mask"][:, current_step]
        step_info_dict["agent_position"] = data_dict["decoder/modeled_agent_position"][:, current_step]
        step_info_dict["agent_heading"] = data_dict["decoder/modeled_agent_heading"][:, current_step]
        step_info_dict["agent_velocity"] = data_dict["decoder/modeled_agent_velocity"][:, current_step]
        step_info_dict["agent_type"] = data_dict["decoder/agent_type"]
        step_info_dict["agent_shape"] = data_dict["decoder/current_agent_shape"]
        step_info_dict["agent_id"] = data_dict["encoder/modeled_agent_id"]

    B, N, G = scenestreamer_tokens.B, scenestreamer_tokens.N, scenestreamer_tokens.G

    # ===== call trafficgen tokenizer =====
    # Kept as a function-scope import, as in the original code
    # (presumably to avoid an import cycle -- TODO confirm).
    from scenestreamer.dataset.preprocessor import prepare_trafficgen_data_for_scenestreamer_a_step

    device = scenestreamer_tokens.token.device

    agent_position = step_info_dict["agent_position"].reshape(B, N, 2)
    agent_heading = step_info_dict["agent_heading"].reshape(B, N)
    agent_velocity = step_info_dict["agent_velocity"].reshape(B, N, 2)
    agent_valid_mask = step_info_dict["agent_valid_mask"].reshape(B, N)
    agent_type_2d = step_info_dict["agent_type"].reshape(B, N)
    agent_shape = step_info_dict["agent_shape"].reshape(B, N, 3)

    tg_map_ids, tg_feats, tg_positions, tg_headings = [], [], [], []
    for b in range(B):
        # The preprocessor works on one scenario at a time with numpy inputs.
        # The map tensors must be indexed with ``b`` as well: using sample 0
        # for every scenario is only correct when B == 1.
        tg_map_id, _tg_valid, tg_feat, _tg_target_offset, tg_pos, tg_head = (
            prepare_trafficgen_data_for_scenestreamer_a_step(
                pos=agent_position[b].cpu().numpy(),
                heading=agent_heading[b].cpu().numpy(),
                vel=agent_velocity[b].cpu().numpy(),
                agent_valid_mask=agent_valid_mask[b].cpu().numpy(),
                agent_type=agent_type_2d[b].cpu().numpy(),
                current_agent_shape=agent_shape[b].cpu().numpy(),
                map_pos=data_dict["model/map_token_position"][b].cpu().numpy()[..., :2],
                map_heading=data_dict["model/map_token_heading"][b].cpu().numpy(),
                map_valid_mask=data_dict["model/map_token_valid_mask"][b].cpu().numpy(),
                start_sequence_id=model.trafficgen_sequence_sos_id,
                end_sequence_id=model.trafficgen_sequence_eos_id,
                dest=None,
                dest_pad_id=model.trafficgen_sequence_pad_id,
                veh_id=model.veh_id,
                ped_id=model.ped_id,
                cyc_id=model.cyc_id,
                start_agent_id=model.trafficgen_agent_sos_id,
            )
        )
        tg_map_ids.append(tg_map_id)
        tg_feats.append(tg_feat)
        tg_positions.append(tg_pos)
        tg_headings.append(tg_head)

    def _batched(arrays, *trailing):
        # Stack the per-scenario arrays and view them as (B, 1, G, *trailing).
        return torch.from_numpy(np.stack(arrays, axis=0)).to(device=device).reshape(B, 1, G, *trailing)

    input_action_for_trafficgen = _batched(tg_map_ids)
    agent_feature_for_trafficgen = _batched(tg_feats, 8).float()
    trafficgen_position = _batched(tg_positions, 2).float()
    trafficgen_heading = _batched(tg_headings).float()

    # ===== prepare input data for trafficgen =====
    # Per-token agent type: -1 for the SOS/EOS tokens and for the first two
    # tokens of every agent, the agent type for the remaining agent tokens.
    agent_type = step_info_dict["agent_type"]
    tg_device = agent_type.device
    agent_type_for_trafficgen = torch.full((B, N, NUM_TG_MULTI), -1, device=tg_device)
    agent_type_for_trafficgen[..., 2:] = agent_type[:, :, None]
    agent_type_for_trafficgen = torch.cat(
        [
            torch.full((B, 1), -1, device=tg_device),
            agent_type_for_trafficgen.flatten(1, 2),
            torch.full((B, 1), -1, device=tg_device),
        ], dim=1
    ).reshape(B, 1, G)

    def _filled(shape, value, **kwargs):
        # Fresh constant tensor on the trafficgen device.
        return torch.full(shape, value, device=tg_device, **kwargs)

    def _append(*, token, position, heading, valid_mask, width, length, require_relation):
        # Append one token; its mask row attends to everything up to itself.
        start = scenestreamer_tokens.seq_len
        stop = start + 1
        scenestreamer_tokens.add(
            token=token,
            position=position,
            heading=heading,
            valid_mask=valid_mask,
            width=width,
            length=length,
            causal_mask=all_token_casual_mask[:, start:stop, :stop],
            force_mask=all_force_mask[:, start:stop, :stop],
            current_step=current_step,
            require_relation=require_relation,
        )

    def _append_marker(*, tg_action, tg_agent_id, intra_step, valid_mask):
        # Token without geometry: sequence/agent markers and the agent type.
        token = model.prepare_trafficgen_single_token(
            tg_action=tg_action,
            tg_type=agent_type_for_trafficgen[:, 0, intra_step:intra_step + 1],
            tg_agent_id=tg_agent_id,
            tg_intra_step=_filled((B, 1), intra_step),
            tg_feat=_filled((B, 1, 8), 0.0),
        )
        _append(
            token=token,
            position=_filled((B, 1, 2), 0),
            heading=_filled((B, 1), 0),
            valid_mask=valid_mask,
            width=_filled((B, 1), 0.0),
            length=_filled((B, 1), 0.0),
            require_relation=_filled((B, 1), False, dtype=torch.bool),
        )

    def _append_state(*, tg_agent_id, intra_step, valid_mask):
        # Token that carries the tokenizer output (action, feature, pose).
        at = slice(intra_step, intra_step + 1)
        feat = agent_feature_for_trafficgen[:, 0, at]
        token = model.prepare_trafficgen_single_token(
            tg_action=input_action_for_trafficgen[:, 0, at],
            tg_type=agent_type_for_trafficgen[:, 0, at],
            tg_agent_id=tg_agent_id,
            tg_intra_step=_filled((B, 1), intra_step),
            tg_feat=feat,
        )
        _append(
            token=token,
            position=trafficgen_position[:, 0, at],
            heading=trafficgen_heading[:, 0, at],
            valid_mask=valid_mask,
            # TODO: hardcoded feature columns 5 (length) and 6 (width).
            width=feat[..., 6],
            length=feat[..., 5],
            require_relation=valid_mask,
        )

    # ===== build the token sequence =====
    # First, input the sequence_sos_id.
    intra_step = 0
    _append_marker(
        tg_action=_filled((B, 1), model.trafficgen_sequence_sos_id),
        tg_agent_id=_filled((B, 1), -1),
        intra_step=intra_step,
        valid_mask=_filled((B, 1), True, dtype=torch.bool),
    )

    for agent_index in range(N):
        agent_id = step_info_dict["agent_id"][:, agent_index:agent_index + 1]
        this_agent_valid_mask = step_info_dict["agent_valid_mask"][:, agent_index:agent_index + 1]

        # Step 0: agent start token.
        intra_step += 1
        _append_marker(
            tg_action=_filled((B, 1), model.trafficgen_agent_sos_id),
            tg_agent_id=agent_id,
            intra_step=intra_step,
            valid_mask=this_agent_valid_mask,
        )

        # Step 1: input is the agent type.
        # TODO(PZH): Should change in TF.
        intra_step += 1
        _append_marker(
            tg_action=agent_type[:, agent_index][:, None],
            tg_agent_id=agent_id,
            intra_step=intra_step,
            valid_mask=this_agent_valid_mask,
        )

        # Step 2: input is the map id.
        intra_step += 1
        _append_state(tg_agent_id=agent_id, intra_step=intra_step, valid_mask=this_agent_valid_mask)

        # Step 3: input is the agent feat.
        intra_step += 1
        _append_state(tg_agent_id=agent_id, intra_step=intra_step, valid_mask=this_agent_valid_mask)

        # NOTE: a "step 4" that samples a destination map token per agent
        # (and appends it as a fifth token) is disabled; re-enabling it
        # requires calling the model here and changes the value of G.

    # Finally, input the sequence_eos_id.
    intra_step += 1
    assert intra_step == G - 1, (intra_step, G, G - 1)
    _append_marker(
        tg_action=_filled((B, 1), model.trafficgen_sequence_eos_id),
        tg_agent_id=_filled((B, 1), -1),
        intra_step=intra_step,
        valid_mask=_filled((B, 1), True, dtype=torch.bool),
    )

    return scenestreamer_tokens, step_info_dict
def call_model_for_trafficgen_generate_all_agents(
    *,
    model,
    data_dict,
    scenestreamer_tokens: SceneStreamerTokens,
    step_info_dict,
    current_step,
    use_cache,
    all_token_casual_mask,
):
    """Placeholder: generating all agents from scratch is not available.

    Raises:
        ValueError: always.
    """
    raise ValueError("call_model_for_trafficgen_generate_all_agents is not implemented.")


def call_model_for_motion(
    *,
    model,
    data_dict,
    scenestreamer_tokens,
    step_info_dict,
    current_step,
    knn,
    max_distance,
    sampling_method,
    temperature,
    topp,
    teacher_forcing,
    allow_newly_added,
    use_cache,
    all_token_casual_mask,
    all_force_mask,
    keep_output_token,
):
    """Run one motion step: append motion tokens, sample actions, advance agents.

    The agent state in ``step_info_dict`` is tokenized and appended to
    ``scenestreamer_tokens``; the model is called and one action per agent is
    sampled with ``sample_action``.  The sampled actions are detokenized into
    the next position/heading/velocity, which are written back to
    ``step_info_dict`` together with ``"motion_input_action"`` and
    ``"motion_input_action_log_prob"``.

    With ``teacher_forcing`` the input action comes from
    ``data_dict["decoder/input_action"]`` at ``current_step``; otherwise from
    ``step_info_dict["motion_input_action"]``.  With ``allow_newly_added``,
    agents that are valid in the ground truth at ``current_step + 1`` but not
    valid now are inserted with their ground-truth state and action.

    ``knn`` and ``max_distance`` are accepted but not used in this function.

    Returns:
        ``(scenestreamer_tokens, step_info_dict)``.
    """
    B, N = scenestreamer_tokens.B, scenestreamer_tokens.N

    agent_delta = utils.get_relative_velocity(
        vel=step_info_dict["agent_velocity"].reshape(B, 1, N, 2),
        heading=step_info_dict["agent_heading"].reshape(B, 1, N)
    )
    motion_input_dict = {
        "decoder/input_action_valid_mask": step_info_dict["agent_valid_mask"].reshape(B, 1, N),
        "decoder/modeled_agent_position": step_info_dict["agent_position"].reshape(B, 1, N, 2),
        "decoder/modeled_agent_heading": step_info_dict["agent_heading"].reshape(B, 1, N),
        "decoder/modeled_agent_delta": agent_delta,
        "decoder/current_agent_shape": step_info_dict["agent_shape"].reshape(B, N, 3),
        "decoder/agent_type": step_info_dict["agent_type"].reshape(B, N),
        "encoder/modeled_agent_id": step_info_dict["agent_id"].reshape(B, N),
    }
    if teacher_forcing:
        motion_input_dict["decoder/input_action"] = data_dict["decoder/input_action"][:, current_step:current_step + 1]
    else:
        motion_input_dict["decoder/input_action"] = step_info_dict["motion_input_action"].reshape(B, 1, N)

    motion_input_dict = model.prepare_motion_tokens(motion_input_dict)
    motion_tokens = motion_input_dict["model/motion_token"]
    motion_position = motion_input_dict["model/motion_token_position"]
    motion_heading = motion_input_dict["model/motion_token_heading"]
    motion_valid_mask = motion_input_dict["model/motion_token_valid_mask"]
    motion_width = motion_input_dict["model/motion_token_width"]
    motion_length = motion_input_dict["model/motion_token_length"]
    B, _, N, _ = motion_tokens.shape

    # Rows of the N new tokens, columns of everything up to and including them.
    start = scenestreamer_tokens.seq_len
    stop = start + N
    scenestreamer_tokens.add(
        token=motion_tokens.flatten(1, 2),
        position=motion_position.flatten(1, 2),
        heading=motion_heading.flatten(1, 2),
        valid_mask=motion_valid_mask.flatten(1, 2),
        width=motion_width.flatten(1, 2),
        length=motion_length.flatten(1, 2),
        causal_mask=all_token_casual_mask[:, start:stop, :stop],
        current_step=current_step,
        force_mask=all_force_mask[:, start:stop, :stop],
        require_relation=motion_valid_mask.flatten(1, 2),
    )

    # ===== call the model =====
    output_dict = scenestreamer_tokens.call_model_with_cache(use_cache=use_cache, keep_output_token=keep_output_token)
    # The last N output tokens belong to the motion tokens just appended.
    motion_token = output_dict["model/all_token"][:, -scenestreamer_tokens.N:]
    # TODO: dest is not conditioning on anyone.
    if model.motion_prenorm is not None:
        motion_token = model.motion_prenorm(motion_token)
    output_token = model.motion_head(motion_token)

    # ===== Post-process the data =====
    selected_action, sampling_info = sample_action(
        logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp
    )

    agent_valid_mask = step_info_dict["agent_valid_mask"]
    agent_position = step_info_dict["agent_position"]
    agent_heading = step_info_dict["agent_heading"]
    agent_velocity = step_info_dict["agent_velocity"]

    # Remove invalid actions.
    selected_action = torch.where(agent_valid_mask, selected_action, -1)

    # TODO: Teacher forcing a subset of agents here.

    res = model.motion_tokenizer.detokenize_step(
        current_pos=agent_position.reshape(B, 1, N, 2),
        current_heading=agent_heading.reshape(B, 1, N),
        current_valid_mask=agent_valid_mask.reshape(B, 1, N),
        current_vel=agent_velocity.reshape(B, 1, N, 2),
        action=selected_action.reshape(B, 1, N),
    )

    step_info_dict["agent_position"] = res["pos"].reshape(B, N, 2)
    step_info_dict["agent_heading"] = res["heading"].reshape(B, N)
    step_info_dict["agent_velocity"] = res["vel"].reshape(B, N, 2)
    step_info_dict["motion_input_action"] = selected_action.reshape(B, N)

    if allow_newly_added:
        new_agent_valid_mask = (
            data_dict["decoder/input_action_valid_mask"][:, current_step + 1] & (~step_info_dict["agent_valid_mask"])
        )

        if new_agent_valid_mask.any():
            new_agent_pos = data_dict["decoder/modeled_agent_position"][:, current_step + 1]
            new_agent_heading = data_dict["decoder/modeled_agent_heading"][:, current_step + 1]
            new_agent_velocity = data_dict["decoder/modeled_agent_velocity"][:, current_step + 1]
            new_action = data_dict["decoder/input_action"][:, current_step + 1]

            B, N = new_agent_valid_mask.shape
            assert new_agent_pos.shape == (B, N, 2)
            assert new_agent_heading.shape == (B, N)
            assert new_agent_velocity.shape == (B, N, 2)

            current_pos = step_info_dict["agent_position"]
            current_heading = step_info_dict["agent_heading"]
            current_vel = step_info_dict["agent_velocity"]
            current_valid_mask = step_info_dict["agent_valid_mask"]

            # Newly valid agents take their ground-truth state; all others keep
            # the state produced by the detokenizer above.
            mask_2d = new_agent_valid_mask[..., None].expand_as(new_agent_pos)
            step_info_dict["agent_position"] = torch.where(mask_2d, new_agent_pos, current_pos)
            step_info_dict["agent_heading"] = torch.where(new_agent_valid_mask, new_agent_heading, current_heading)
            step_info_dict["agent_velocity"] = torch.where(mask_2d, new_agent_velocity, current_vel)
            step_info_dict["agent_valid_mask"] = torch.where(
                new_agent_valid_mask, new_agent_valid_mask, current_valid_mask
            )
            step_info_dict["motion_input_action"] = torch.where(
                new_agent_valid_mask, new_action, step_info_dict["motion_input_action"]
            )

    # TODO: evict agents that moving out of the map (useful in SceneStreamer)

    # ===== log-probability of the actions fed to the next step =====
    # Padding (-1) and start actions are not sampled from the head, so they
    # are mapped to a dummy index and masked out of the result.
    tmp_action = step_info_dict["motion_input_action"].clone()
    tmp_valid_mask = agent_valid_mask.clone()
    tmp_valid_mask[tmp_action == -1] = False
    tmp_valid_mask[tmp_action == MOTION_START_ACTION] = False
    tmp_action[tmp_action == -1] = 0
    tmp_action[tmp_action == MOTION_START_ACTION] = 0

    # Not every sampling method reports the distribution it sampled from;
    # fall back to the categorical distribution of the head logits instead of
    # failing with a KeyError.
    dist = sampling_info.get("dist")
    if dist is None:
        scale = temperature if temperature else 1.0
        dist = torch.distributions.Categorical(logits=output_token / scale, validate_args=False)
    log_prob = dist.log_prob(tmp_action)
    step_info_dict["motion_input_action_log_prob"] = log_prob * tmp_valid_mask

    return scenestreamer_tokens, step_info_dict
def encode_scene(*, data_dict, model):
    """Make sure ``data_dict`` contains the map tokens.

    ``model.prepare_map_tokens`` is only called when the key
    ``"model/map_token"`` is missing.

    Returns:
        ``(data_dict, {})`` -- the second element is always an empty dict.
    """
    if "model/map_token" not in data_dict:
        data_dict = model.prepare_map_tokens(data_dict)
    return data_dict, {}


def sample_action(logits, sampling_method, temperature=1.0, topp=None, topk=5):
    """Sample one index along the last dimension of ``logits``.

    Args:
        logits: unnormalized scores; the last dimension is the vocabulary.
        sampling_method: ``"argmax"``, ``"softmax"``, ``"topp"`` or ``"topk"``.
        temperature: divisor applied to the logits for ``"softmax"`` and
            ``"topp"``.  ``None`` is treated as ``1.0``.
        topp: nucleus threshold for ``"topp"`` (see ``nucleus_sampling``).
        topk: number of candidates for ``"topk"``; clipped to the vocabulary
            size.  The candidates are sampled uniformly.

    Returns:
        ``(selected_action, info)``.  ``info["dist"]`` is always a
        ``torch.distributions.Categorical`` so callers can query
        ``log_prob``: the sampling distribution for the stochastic methods,
        and the temperature-scaled softmax distribution for ``"argmax"``.

    Raises:
        ValueError: for an unknown ``sampling_method`` or ``topk < 1``.
    """
    if temperature is None:
        temperature = 1.0

    info = {}
    if sampling_method == "argmax":
        selected_action = logits.argmax(-1)
        # The selection ignores the temperature; it is only used (when
        # non-zero) for the reported distribution.  Validation is disabled so
        # that argmax never becomes stricter about its input than before.
        scaled = logits / temperature if temperature else logits
        info["dist"] = torch.distributions.Categorical(logits=scaled, validate_args=False)
    elif sampling_method == "softmax":
        dist = torch.distributions.Categorical(logits=logits / temperature)
        selected_action = dist.sample()
        info["dist"] = dist
    elif sampling_method == "topp":
        selected_action, info = nucleus_sampling(logits=logits / temperature, p=topp)
    elif sampling_method == "topk":
        if topk < 1:
            raise ValueError("topk must be >= 1, got {}".format(topk))
        # Never ask for more candidates than the vocabulary holds.
        num_candidates = min(topk, logits.shape[-1])
        candidates = logits.topk(num_candidates, dim=-1).indices
        # Uniform distribution over the candidates, expressed over the full
        # vocabulary so that ``log_prob`` accepts vocabulary indices.
        candidate_probs = torch.zeros_like(logits, dtype=torch.float32)
        candidate_probs.scatter_(dim=-1, index=candidates, value=1.0)
        dist = torch.distributions.Categorical(probs=candidate_probs)
        selected_action = dist.sample()
        info["dist"] = dist
    else:
        raise ValueError("Unknown sampling method: {}".format(sampling_method))
    return selected_action, info


def nucleus_sampling(logits, p=None, epsilon=1e-8):
    """Sample from the smallest set of tokens whose probability mass exceeds ``p``.

    Args:
        logits: unnormalized scores; the last dimension is the vocabulary.
        p: cumulative-probability threshold.  ``None`` selects the default
            ``0.9``; ``0`` keeps only the most likely token.
        epsilon: unused; kept for backward compatibility.

    Returns:
        ``(sampled_token_index, info)`` where ``info["dist"]`` is the
        truncated categorical distribution and ``info["cutoff_index"]`` marks
        the removed entries **in descending-probability order**.
    """
    if p is None:
        p = 0.9

    # Sanitize non-finite logits while keeping their meaning: NaN and -inf
    # become "practically impossible", +inf becomes "practically certain".
    logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

    probs = torch.softmax(logits, dim=-1)

    # Sort the probabilities to identify the top-p cutoff.
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Mark tokens beyond the threshold, then shift the mask right by one so
    # the first token that crosses the threshold is still kept.
    cutoff_index = cumulative_probs > p
    cutoff_index[..., 1:] = cutoff_index[..., :-1].clone()
    cutoff_index[..., 0] = False

    # Zero out the probabilities for tokens not in the top-p set.
    sorted_probs.masked_fill_(cutoff_index, 0)

    # Recover the original order; Categorical renormalizes the kept mass.
    original_probs = torch.zeros_like(probs)
    original_probs.scatter_(dim=-1, index=sorted_indices, src=sorted_probs)
    dist = torch.distributions.Categorical(probs=original_probs)
    sampled_token_index = dist.sample()
    return sampled_token_index, {"cutoff_index": cutoff_index, "dist": dist}
torch.softmax(logits, dim=-1) + + # Sort the probabilities to identify the top-p cutoff + sorted_probs, sorted_indices = torch.sort(probs, descending=True) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + + # Remove tokens with cumulative probability above the threshold p + cutoff_index = cumulative_probs > p + # Shift the mask to the right to keep the first token above the threshold + cutoff_index[..., 1:] = cutoff_index[..., :-1].clone() + cutoff_index[..., 0] = False + + # Zero out the probabilities for tokens not in the top-p set + sorted_probs.masked_fill_(cutoff_index, 0) + + # Recover the original order of the probabilities + original_probs = torch.zeros_like(probs) + original_probs.scatter_(dim=-1, index=sorted_indices, src=sorted_probs) + dist = torch.distributions.Categorical(probs=original_probs) + sampled_token_index = dist.sample() + return sampled_token_index, {"cutoff_index": cutoff_index, "dist": dist} + + +def add_new_agent( + *, step_data_dict, step_info, new_agent_valid_mask, new_agent_pos, new_agent_heading, new_agent_velocity, + new_agent_delta, new_action +): + if new_agent_valid_mask is None or not new_agent_valid_mask.any(): + return step_data_dict, step_info + + B, T, N = new_agent_valid_mask.shape + assert new_agent_pos.shape == (B, T, N, 2) + assert new_agent_heading.shape == (B, T, N) + assert new_agent_velocity.shape == (B, T, N, 2) + assert new_agent_delta.shape == (B, T, N, 2) + + current_pos = step_data_dict["agent_position"] + current_heading = step_data_dict["agent_heading"] + current_vel = step_data_dict["agent_velocity"] + current_valid_mask = step_data_dict["agent_valid_mask"] + current_delta = step_data_dict["agent_delta"] + + mask_2d = new_agent_valid_mask[..., None].expand_as(new_agent_pos) + current_pos = torch.where(mask_2d, new_agent_pos, current_pos) + current_heading = torch.where(new_agent_valid_mask, new_agent_heading, current_heading) + current_vel = torch.where(mask_2d, new_agent_velocity, current_vel) + 
current_valid_mask = torch.where(new_agent_valid_mask, new_agent_valid_mask, current_valid_mask) + current_delta = torch.where(mask_2d, new_agent_delta, current_delta) + + step_data_dict["agent_position"] = current_pos + step_data_dict["agent_heading"] = current_heading + step_data_dict["agent_velocity"] = current_vel + step_data_dict["agent_valid_mask"] = current_valid_mask + step_data_dict["agent_delta"] = current_delta + + if new_action.ndim == 4: + # Variable length action + new_action, old_action = pad_sequences(new_action, step_data_dict["input_action"], x_value=-1, y_value=-1) + step_data_dict["input_action"] = torch.where(new_agent_valid_mask[..., None], new_action, old_action) + elif new_action.ndim == 3: + step_data_dict["input_action"] = torch.where(new_agent_valid_mask, new_action, step_data_dict["input_action"]) + else: + raise ValueError("Invalid new_action shape: {}".format(new_action.shape)) + step_data_dict["input_action_valid_mask"] = current_valid_mask + + output_token = step_info["output_token"] + if output_token is not None: + output_token = torch.where( + new_agent_valid_mask[..., None].expand_as(output_token), torch.zeros_like(output_token), output_token + ) + step_info["output_token"] = output_token + + assert_motion_step_data_dict(step_data_dict=step_data_dict, step_info=step_info) + + return step_data_dict, step_info + + +def interpolate_autoregressive_output( + *, data_dict, num_skipped_steps, num_decoded_steps, agent_position, agent_heading, agent_velocity, + input_valid_mask, agent_destination, agent_destination_position, teacher_forcing_sdc, agent_shape=None, + sdc_index=None +): + B, _, N, _ = agent_position.shape + T_generated_chunks = num_decoded_steps + reconstructed_pos = interpolate(agent_position, num_skipped_steps, remove_first_step=False) + assert (reconstructed_pos[:, ::5] == agent_position).all() + reconstructed_heading = interpolate_heading(agent_heading, num_skipped_steps, remove_first_step=False) + reconstructed_vel = 
interpolate(agent_velocity, num_skipped_steps, remove_first_step=False) + + assert input_valid_mask.shape[1] == T_generated_chunks + 1 + + valid = input_valid_mask[:, :-1].reshape(B, -1, 1, N).expand(-1, -1, num_skipped_steps, -1).reshape(B, -1, N) + valid = torch.cat([valid, input_valid_mask[:, -1:]], dim=1) + + if teacher_forcing_sdc: + sdc_index = sdc_index[0].item() + assert sdc_index == 0 + valid[:, 91:, sdc_index] = False + + reconstructed_valid_mask = valid + + if agent_destination is not None: + step = TG_SKIP_STEP * 5 + agent_destination = agent_destination[:, :, None].repeat((1, 1, step, 1)).flatten(1, 2)[:, :96] + agent_destination_position = agent_destination_position[:, :, None, :].repeat((1, 1, step, 1, 1)).flatten(1, 2)[:, :96] + + # Mask out: + reconstructed_pos = reconstructed_pos * reconstructed_valid_mask.unsqueeze(-1) + reconstructed_vel = reconstructed_vel * reconstructed_valid_mask.unsqueeze(-1) + reconstructed_heading = reconstructed_heading * reconstructed_valid_mask + + # We ensure that the output must be 5*T_chunks+1 + assert reconstructed_pos.shape[1] == num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_valid_mask.shape[1] == num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_vel.shape[1] == num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_heading.shape[1] == num_skipped_steps * T_generated_chunks + 1 + + data_dict["decoder/reconstructed_position"] = reconstructed_pos + data_dict["decoder/reconstructed_heading"] = reconstructed_heading + data_dict["decoder/reconstructed_velocity"] = reconstructed_vel + data_dict["decoder/reconstructed_valid_mask"] = reconstructed_valid_mask + data_dict["decoder/reconstructed_agent_destination"] = agent_destination + data_dict["decoder/reconstructed_agent_destination_position"] = agent_destination_position + if agent_shape is not None: + data_dict["decoder/reconstructed_shape"] = \ + torch.stack(agent_shape, dim=1).expand(-1, reconstructed_vel.shape[1], -1, 
def evict_agents(
    *,
    data_dict,
    step_data_dict,
    step_info_dict,
    max_distance=10,
    remove_static_agent=False,
    remove_out_of_map_agent=False
):
    """Mark agents that should stop being modeled as invalid for the next step.

    An agent is evicted when ``remove_out_of_map_agent`` is set and its nearest map
    point is farther than ``max_distance``, or when ``remove_static_agent`` is set
    and its speed is below 0.5. Only the validity masks are cleared; positions,
    headings, velocities and actions of evicted agents are left untouched because
    the caller still needs the state produced by the current step.

    Args:
        data_dict: Scene data; ``"encoder/map_position"`` is read when
            ``remove_out_of_map_agent`` is set.
        step_data_dict: Per-step agent state. ``"agent_position"`` must be 4-D and
            only index 0 of its second axis is inspected; the same holds for
            ``"agent_velocity"``. Updated in place.
        step_info_dict: Receives ``"evicted_agents"`` (count) and
            ``"evicted_agent_mask"`` (boolean mask, or ``None`` if nothing was evicted).
        max_distance: Distance threshold to the nearest map point.
        remove_static_agent: Evict agents whose speed is below 0.5.
        remove_out_of_map_agent: Evict agents that are far from every map point.

    Returns:
        ``(step_data_dict, step_info_dict)``, the same objects that were passed in.
    """
    should_evict = None

    if remove_out_of_map_agent:
        map_position = data_dict["encoder/map_position"][..., :2]
        agent_position = step_data_dict["agent_position"]
        assert agent_position.ndim == 4
        agent_position = agent_position[:, 0]
        # Distance from every agent to its closest map point.
        min_dist = torch.cdist(agent_position, map_position).min(dim=-1).values
        should_evict = min_dist > max_distance

    if remove_static_agent:
        agent_speed = step_data_dict["agent_velocity"].norm(dim=-1)[:, 0]
        static_agent = agent_speed < 0.5
        if should_evict is None:
            should_evict = static_agent
        else:
            should_evict = torch.logical_or(should_evict, static_agent)

    num_evicted = 0 if should_evict is None else should_evict.sum().item()
    if num_evicted == 0:
        step_info_dict["evicted_agents"] = 0
        step_info_dict["evicted_agent_mask"] = None
        return step_data_dict, step_info_dict

    # We should inform the autoregressive process not to generate an action in the
    # next step. However, the current step's action is still valid (the
    # input_action_valid_mask for this agent was valid), so the outer process is
    # still waiting for the new state of the agent. Therefore only the masks are
    # cleared and the remaining per-agent data is kept as is.
    old_mask = step_data_dict["input_action_valid_mask"]
    evict_mask = should_evict
    if old_mask.ndim == should_evict.ndim + 1:
        # should_evict has no step axis (it was taken at index 0 of that axis);
        # re-insert it so the mask broadcasts per batch element rather than being
        # aligned against the wrong leading dimension.
        evict_mask = should_evict.unsqueeze(1)
    new_mask = old_mask & (~evict_mask)
    step_data_dict["input_action_valid_mask"] = new_mask
    step_data_dict["agent_valid_mask"] = new_mask

    step_info_dict["evicted_agents"] = num_evicted
    step_info_dict["evicted_agent_mask"] = should_evict

    # assert_motion_step_data_dict is keyword-only and names its second parameter
    # ``step_info``; a positional call raises TypeError.
    # NOTE(review): that check also requires input_action == -1 wherever the mask is
    # False, which this function deliberately does not enforce -- confirm whether
    # the invariant or the "keep the data" policy above should win.
    assert_motion_step_data_dict(step_data_dict=step_data_dict, step_info=step_info_dict)

    return step_data_dict, step_info_dict
def animate_scenestreamer(
    save_path, data_dict, fps=10, dpi=300, draw_traffic=True
):
    """Render the reconstructed agent trajectories in ``data_dict`` to a video file.

    Every agent is drawn as a rectangle rotated by its heading, labelled with its
    index, and connected by an arrow to its destination position when destination
    data is present. One frame is written per time step.

    Args:
        save_path: Output path handed to the FFMpeg writer.
        data_dict: Must provide ``decoder/reconstructed_position`` (indexed as
            ``[t, agent]`` and holding x, y), ``decoder/reconstructed_valid_mask``,
            ``decoder/reconstructed_heading`` (radians; converted with
            ``np.degrees``), ``decoder/current_agent_shape`` (column 0 is used as
            length, column 1 as width),
            ``decoder/reconstructed_agent_destination_position`` and
            ``decoder/reconstructed_agent_destination``. Optional keys:
            ``decoder/labeled_agent_id`` and ``decoder/sdc_index``.
        fps: Frames per second of the written video.
        dpi: Resolution of the figure and of the saved frames.
        draw_traffic: When False, only agents listed in
            ``decoder/labeled_agent_id`` are drawn.
    """
    from scenestreamer.gradio_ui.plot import FFMpegWriter, _plot_map, _plot_traffic_light
    import seaborn as sns
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle, FancyArrowPatch
    from matplotlib import transforms

    agent_pos = data_dict["decoder/reconstructed_position"]
    agent_mask = data_dict["decoder/reconstructed_valid_mask"]
    agent_heading = data_dict["decoder/reconstructed_heading"]
    agent_shape = data_dict["decoder/current_agent_shape"]

    agent_dest_pos = data_dict["decoder/reconstructed_agent_destination_position"]
    agent_dest = data_dict["decoder/reconstructed_agent_destination"]

    if "decoder/labeled_agent_id" in data_dict:
        ooi = data_dict["decoder/labeled_agent_id"]
    else:
        ooi = []

    if 'decoder/sdc_index' in data_dict:
        ego_agent_id = int(data_dict['decoder/sdc_index'])
    else:
        ego_agent_id = 0

    assert agent_pos.ndim == 3
    T = agent_pos.shape[0]  # Number of timesteps
    N = agent_pos.shape[1]  # Number of agents

    cmap = sns.color_palette("colorblind", n_colors=N)  # Color for each agent

    # Axis limits: bounding box of every stored position plus a 10-unit margin.
    all_agent_positions = agent_pos[:, :, ...].reshape(-1, 2)
    xmin, ymin = all_agent_positions.min(axis=0)
    xmax, ymax = all_agent_positions.max(axis=0)
    xlim, ylim = (xmin - 10, xmax + 10), (ymin - 10, ymax + 10)

    writer = FFMpegWriter(fps=fps, codec='libx264', extra_args=['-preset', 'ultrafast', '-crf', '23', '-threads', '4'])
    fig, ax = plt.subplots(figsize=(10, 10), dpi=dpi)
    try:
        ax.set_aspect(1)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        _plot_map(data_dict, ax, dont_draw_lane=True)
        _plot_traffic_light(data_dict, ax)

        # One entry per agent in each list; None marks an agent that is not drawn.
        agent_patches = []
        agent_texts = []
        agent_arrows = []
        agent_stars = []

        for agent_ind in range(N):
            if not draw_traffic and agent_ind not in ooi:
                agent_patches.append(None)
                agent_texts.append(None)
                agent_arrows.append(None)
                agent_stars.append(None)
                continue
            face_color = cmap[0] if agent_ind == ego_agent_id else cmap[agent_ind]
            label = "{}-SDC".format(ego_agent_id) if agent_ind == ego_agent_id else \
                "{}-OOI".format(agent_ind) if agent_ind in ooi else str(agent_ind)

            # Rectangle centred at the origin; it is moved and rotated per frame
            # through its transform.
            length = agent_shape[agent_ind, 0]
            width = agent_shape[agent_ind, 1]
            rect = Rectangle(
                (-length / 2, -width / 2),
                width=length,
                height=width,
                facecolor=face_color,
                edgecolor='black',
                linewidth=0.6,
                zorder=10
            )
            agent_patches.append(rect)
            ax.add_patch(rect)

            text = ax.text(0, 0, label, color=face_color, fontsize=11, ha='center', va='center', zorder=15)
            agent_texts.append(text)

            # Arrow from agent to destination, hidden until a frame positions it.
            arrow = FancyArrowPatch((0, 0), (0, 0),
                                    facecolor=face_color,
                                    edgecolor='black',
                                    arrowstyle='->',
                                    mutation_scale=10,
                                    linewidth=0.8,
                                    zorder=5)
            arrow.set_visible(False)
            agent_arrows.append(arrow)
            ax.add_patch(arrow)

            # Created invisible and not updated by the frame loop below.
            star = ax.scatter(0, 0, s=30, c='green', marker='*', zorder=12, visible=False)
            agent_stars.append(star)

        with writer.saving(fig, save_path, dpi=dpi):
            for t in range(T):
                pos = agent_pos[t]
                heading = agent_heading[t]

                for agent_ind, (rect, text) in enumerate(zip(agent_patches, agent_texts)):
                    x, y = pos[agent_ind]
                    arrow = agent_arrows[agent_ind]
                    star = agent_stars[agent_ind]
                    if rect is None or text is None or arrow is None or star is None:
                        continue

                    if not agent_mask[t, agent_ind]:
                        rect.set_visible(False)
                        text.set_visible(False)
                        arrow.set_visible(False)
                        continue

                    rect.set_visible(True)
                    text.set_visible(True)
                    arrow.set_visible(True)

                    length = agent_shape[agent_ind, 0]
                    width = agent_shape[agent_ind, 1]

                    # Reset base rectangle (centered at origin)
                    rect.set_xy((-length / 2, -width / 2))

                    # Rotate around the agent centre, then move to (x, y) in data space.
                    theta_deg = np.degrees(heading[agent_ind])
                    trans = (
                        transforms.Affine2D()
                        .rotate_deg_around(0, 0, theta_deg)
                        .translate(x, y)
                        + ax.transData
                    )
                    rect.set_transform(trans)

                    rect.set_edgecolor('black')
                    rect.set_linewidth(0.8)

                    text.set_position((x, y))
                    text.set_text(text.get_text())  # forces the text to render

                    if agent_dest_pos is not None:
                        dest_x, dest_y = agent_dest_pos[t, agent_ind]
                        dest = agent_dest[t, agent_ind]
                        # FIXME: hardcoded destination id that is drawn in red.
                        if dest == 3002:
                            arrow.set_positions((x, y), (dest_x, dest_y))
                            arrow.set_color("red")
                            arrow.set_visible(True)
                        elif dest == -1:
                            # -1 is treated as "no destination": hide the arrow.
                            arrow.set_visible(False)
                        else:
                            arrow.set_positions((x, y), (dest_x, dest_y))
                            arrow.set_color("black")
                            arrow.set_visible(True)
                    else:
                        arrow.set_visible(False)

                writer.grab_frame()
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; release it
        # even when rendering fails so repeated calls do not accumulate figures.
        plt.close(fig)
+ for pattern in [ + "eval/", + "encoder/current_", + "encoder/future_", + ]: + data_dict = {k: v for k, v in data_dict.items() if not k.startswith(pattern)} + data_dict = {k: v for k, v in data_dict.items() if not k.startswith("encoder/agent_")} + new_data_dict = {} + for pattern in ["decoder/agent_id", "decoder/agent_type", "decoder/cache", "decoder/current_", + "decoder/modeled_agent_", "decoder/input_", "decoder/target_", "encoder/", "in_evaluation", + "batch_idx", "in_backward_prediction", "decoder/randomized_modeled_agent_id"]: + new_data_dict.update({k: v for k, v in data_dict.items() if k.startswith(pattern)}) + data_dict = new_data_dict + + # To avoid those overwriting operation. + data_dict = copy.deepcopy(data_dict) + + original_B, original_T, original_N, _ = data_dict["decoder/modeled_agent_position"].shape + data_dict = { + k: ( + utils.expand_for_modes(data_dict[k], num_modes=num_search_width) + if k not in ["decoder/cache", "decoder/input_step", "decoder/modeled_agent_step_history"] else data_dict[k] + ) + for k in data_dict.keys() if ( + k.startswith("encoder/") or k.startswith("decoder/") or k == "in_evaluation" + or k == "in_backward_prediction" or k == "batch_idx" + ) + } + bin_centers = utils.expand_for_modes(bin_centers, num_modes=num_search_width) + if "decoder/cache" in data_dict: + + def _expand_cache(tensor, shape): + # expanding cache is not a easy task. + # our tensor is only used in A2T attention, where the cache (K, V) is in shape: + # (BN, T, D) + # In the first dim the order is: b1n1, b1n2, b2n1, b2n2, ... + # After expanding, say expanded to W, the shape should be: + # (BWN, T, D) + # In the first dim the order should be: b1w1n1, b1w1n2, b1w2n1, b1w2n2, ... + # However, if we simply repeat dim, will make the shape be: + # (BNW, T, D), but the order is wrong: + # b1n1w1, b1n1w2, b1n2w1, b1n2w2, ... + # So we need to do some reshaping. 
+ tensor = tensor.reshape(original_B, -1, *tensor.shape[1:]) + tensor = utils.expand_for_modes(tensor, num_modes=num_search_width) + tensor = tensor.reshape(shape[0] * num_search_width, shape[1], -1) + return tensor + + def _new_cache(c): + return [ + _expand_cache(c[0], c[2]), + _expand_cache(c[1], c[2]), + (c[2][0] * num_search_width, c[2][1]), + ] + + data_dict["decoder/cache"] = [_new_cache(v) for v in data_dict["decoder/cache"]] + + # Another trick is to re-randomize the modeled agent id to further improve diversity. + # data_dict["decoder/randomized_modeled_agent_id"] = model.motion_decoder.randomize_modeled_agent_id( + # data_dict["decoder/agent_id"], clip_agent_id=True + # ) + # The above is wrong. We can't do this because the cache is for original randomized agent id. + + sampling_method = config.SAMPLING.SAMPLING_METHOD + temperature = config.SAMPLING.TEMPERATURE + + topp = config.SAMPLING.TOPP + tokenizer = get_tokenizer(config=config) + assert "encoder/scenario_token" in data_dict + + current_pos = data_dict["decoder/modeled_agent_position"].clone() + current_heading = data_dict["decoder/modeled_agent_heading"].clone() + current_vel = data_dict["decoder/modeled_agent_velocity"].clone() + current_valid_mask = data_dict["decoder/modeled_agent_valid_mask"].clone() + current_delta = data_dict["decoder/modeled_agent_delta"].clone() + current_model_step = data_dict["decoder/input_step"].clone() + assert (current_model_step == start_steps).all() + current_input_action = data_dict["decoder/input_action"].clone() + agent_shape = data_dict["decoder/current_agent_shape"].clone() + agent_type = data_dict["decoder/agent_type"].clone() + B, T, N, _ = current_pos.shape + + # ===== Run forward prediction to get forward trajectory ===== + if "decoder/modeled_agent_position_history" in data_dict: + pos = [data_dict["decoder/modeled_agent_position_history"]] + head = [data_dict["decoder/modeled_agent_heading_history"]] + vel = 
[data_dict["decoder/modeled_agent_velocity_history"]] + # delta = data_dict["decoder/modeled_agent_delta_history"].clone() + else: + pos = [current_pos.clone()] + head = [current_heading.clone()] + vel = [current_vel.clone()] + # delta = [current_delta.clone()] + output_logit_list = [] + output_action_list = [current_input_action.clone()] + input_action_valid_mask_list = [] + for decode_step in range(num_search_steps): + # Overwrite all necessary data: + data_dict["decoder/modeled_agent_position"] = current_pos + data_dict["decoder/modeled_agent_heading"] = current_heading + data_dict["decoder/modeled_agent_velocity"] = current_vel + data_dict["decoder/modeled_agent_valid_mask"] = current_valid_mask + data_dict["decoder/modeled_agent_delta"] = current_delta + data_dict["decoder/input_step"] = current_model_step + decode_step + data_dict["decoder/input_action"] = current_input_action + data_dict["decoder/input_action_valid_mask"] = current_valid_mask + assert not (current_input_action == END_ACTION).any() + assert (data_dict["in_backward_prediction"] == False).all() + with torch.no_grad(): + data_dict = model.decode_motion(data_dict, use_cache=True) + output_token = data_dict["decoder/output_logit"] + selected_action = sample_action( + logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp + ) + res = tokenizer.detokenize_step( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=current_valid_mask, + current_vel=current_vel, + action=selected_action, + agent_shape=data_dict["decoder/current_agent_shape"], + bin_centers=bin_centers, + dt=tokenizer.dt, + ) + recon_next_pos, recon_next_heading, recon_next_vel = res["pos"], res["heading"], res["vel"] + + # TODO: delta_pos computing is updated. 
+ raise ValueError() + relative_delta_pos = recon_next_pos.reshape(B, 1, N, 2) - current_pos + relative_delta_pos = utils.rotate( + relative_delta_pos[..., 0], relative_delta_pos[..., 1], angle=-recon_next_heading.reshape(B, 1, N) + ) + current_pos = recon_next_pos.reshape(B, 1, N, 2) + current_heading = recon_next_heading.reshape(B, 1, N) + current_vel = recon_next_vel.reshape(B, 1, N, 2) + current_delta = relative_delta_pos.reshape(B, 1, N, 2) + current_input_action = selected_action + pos.append(current_pos.clone()) + head.append(current_heading.clone()) + vel.append(current_vel.clone()) + # delta.append(current_delta.clone()) + output_logit_list.append(output_token.clone()) + output_action_list.append(current_input_action.clone()) + data_dict.pop("decoder/cache") + output_action_list = torch.concatenate(output_action_list, dim=1) + output_logit_list = torch.concatenate(output_logit_list, dim=1) + pos = torch.cat(pos, dim=1) + head = torch.cat(head, dim=1) + vel = torch.cat(vel, dim=1) + + # # ===== Backward-tokenize the forward trajectory ===== + current_pos = pos[:, -1:] + current_heading = head[:, -1:] + current_vel = vel[:, -1:] + backward_first_action = torch.full_like(current_input_action, -1) + backward_first_action[current_valid_mask] = END_ACTION + backward_actions = [backward_first_action.reshape(B, N)] + init_delta = _reconstruct_delta_pos_from_abs_vel(current_vel, current_heading + np.pi, dt=tokenizer.dt) + backward_pos = [current_pos.clone()] + backward_head = [current_heading.clone()] + backward_vel = [current_vel.clone()] + backward_delta = [init_delta.clone()] + total_forward_steps = pos.shape[1] - 1 # minus one because the first step is already in the pos + + if backward_run_full_length: + backward_tokenize_steps = min(total_forward_steps, backward_inference_horizon + num_search_steps) + else: + backward_tokenize_steps = num_search_steps + for backward_step in range(backward_tokenize_steps): + # backward_step = 0, ..., D-1 + forward_next_step = 
total_forward_steps - backward_step - 1 + # forward_next_step = D-1, ..., 0 + res = tokenizer._tokenize_a_step( + current_pos=current_pos, + current_heading=current_heading, + current_vel=current_vel, + current_valid_mask=current_valid_mask, + next_pos=pos[:, forward_next_step:forward_next_step + 1], + next_heading=head[:, forward_next_step:forward_next_step + 1], + next_valid_mask=current_valid_mask, + next_velocity=vel[:, forward_next_step:forward_next_step + 1], + bin_centers=bin_centers, + add_noise=False, + topk=0, + agent_shape=agent_shape, + agent_type=agent_type, + dt=-tokenizer.dt, + ) + backward_actions.append(res["action"]) + current_pos = res["pos"] + current_heading = res["heading"] + current_vel = res["vel"] + backward_pos.append(current_pos) + backward_head.append(current_heading) + backward_vel.append(current_vel) + backward_delta.append(res["delta_pos"]) + backward_pos = torch.cat(backward_pos, dim=1) + backward_head = torch.cat(backward_head, dim=1) + backward_vel = torch.cat(backward_vel, dim=1) + backward_delta = torch.cat(backward_delta, dim=1) + + # ===== Run backward prediction with teacher forcing ===== + backward_input_action = torch.stack(backward_actions, dim=1) + backward_input_valid_mask = current_valid_mask.expand(-1, backward_tokenize_steps + 1, -1) + backward_input_dict = { + "decoder/modeled_agent_position": backward_pos, + "decoder/modeled_agent_heading": backward_head, + "decoder/modeled_agent_velocity": backward_vel, + # "decoder/modeled_agent_valid_mask": current_valid_mask, + "decoder/modeled_agent_delta": backward_delta, + "decoder/input_step": torch.arange(backward_tokenize_steps + 1).to(current_pos.device), + "decoder/input_action": backward_input_action, + "decoder/input_action_valid_mask": backward_input_valid_mask, + "encoder/scenario_token": data_dict["encoder/scenario_token"], + "encoder/scenario_valid_mask": data_dict["encoder/scenario_valid_mask"], + "encoder/scenario_position": data_dict["encoder/scenario_position"], 
+ "encoder/scenario_heading": data_dict["encoder/scenario_heading"], + "in_backward_prediction": torch.ones_like(data_dict["in_backward_prediction"]), + "in_evaluation": torch.zeros_like(data_dict["in_evaluation"]), + "decoder/agent_id": data_dict["decoder/agent_id"], + "decoder/agent_type": data_dict["decoder/agent_type"], + "decoder/current_agent_shape": data_dict["decoder/current_agent_shape"], + "batch_idx": data_dict["batch_idx"], + # "decoder/randomized_modeled_agent_id": data_dict["decoder/randomized_modeled_agent_id"], + } + del data_dict + with torch.no_grad(): + backward_output_dict = model.decode_motion(backward_input_dict, use_cache=False) + + # ===== Calculate the scores ===== + backward_logit = backward_output_dict["decoder/output_logit"][:, :-1] + dist = torch.distributions.Categorical(logits=backward_logit / temperature) + backward_target_action = backward_input_action[:, 1:].clone() + backward_input_action_mask = backward_input_valid_mask.clone()[:, :-1] + backward_input_action_mask = backward_input_action_mask & (backward_target_action != + END_ACTION) & (backward_target_action != START_ACTION) + del backward_input_dict + del backward_output_dict + + # === Use log_prob as the score === + backward_target_action[~backward_input_action_mask] = 0 + backward_log_prob = dist.log_prob(backward_target_action) + backward_log_prob[~backward_input_action_mask] = 0 + assert backward_log_prob.ndim == 3 + + backward_entropy = dist.entropy() + backward_entropy[~backward_input_action_mask] = 0 + + # === Forward log_prob === + forward_dist = torch.distributions.Categorical(logits=output_logit_list / temperature) + forward_input_action_mask = current_valid_mask.clone().expand(-1, num_search_steps, -1) + forward_input_action_mask = forward_input_action_mask & (output_action_list[:, 1:] != + END_ACTION) & (output_action_list[:, 1:] != START_ACTION) + forward_log_prob = forward_dist.log_prob(output_action_list[:, 1:]) + forward_log_prob[~forward_input_action_mask] = 0 
+ + forward_entropy = forward_dist.entropy() + forward_entropy[~forward_input_action_mask] = 0 + + # === Combine forward and backward log_prob === + # backward_scores = forward_log_prob.sum(1) + backward_log_prob.sum(1) # Sum over time + + # # === Normalized scores === + # forward_mean = forward_log_prob.mean(dim=1) + # forward_variance = forward_log_prob.var(dim=1) + # backward_mean = backward_log_prob.mean(dim=1) + # backward_variance = backward_log_prob.var(dim=1) + # backward_scores = (forward_mean + backward_mean) / (1 + forward_variance + backward_variance) + + # === Entropy-regularized scores === + backward_scores = (forward_log_prob * forward_entropy).mean(1) + (backward_log_prob * backward_entropy).mean(1) + # Sum of probs: + # backward_scores = (forward_log_prob).mean(1) + (backward_log_prob).mean(1) + # backward_scores = (forward_log_prob).mean(1) * 0 + + # # # ===== Use GT data to evaluate the trajectories ===== + # agent_pos = data_dict["decoder/agent_position"][..., :2][:, ::5] + # if start_steps + num_search_steps >= agent_pos.shape[1]: + # gt_final_pos = agent_pos[:, -1] + # final_pos = pos[:, - start_steps + agent_pos.shape[1]] + # gt_mask = data_dict["decoder/agent_valid_mask"][:, ::5][:, -1] + # else: + # gt_final_pos = agent_pos[:, start_steps + num_search_steps] + # final_pos = pos[:, -1] + # gt_mask = data_dict["decoder/agent_valid_mask"][:, ::5][:, start_steps + num_search_steps] + # error = torch.norm(gt_final_pos - final_pos, dim=-1) + # gt_mask = gt_mask & current_valid_mask.squeeze(1) + # error = error * gt_mask + # backward_scores = -error.sum(-1) + # backward_scores = backward_scores.reshape(-1, num_search_width) + + # # ===== Another Option, Run backward prediction to get backward ADE ===== + # backward_input_action = torch.stack(backward_actions, dim=1) + # backward_input_valid_mask = current_valid_mask.expand(-1, num_search_steps+1, -1) + # backward_input_dict = { + # "decoder/modeled_agent_position": backward_pos, + # 
"decoder/modeled_agent_heading": backward_head, + # "decoder/modeled_agent_velocity": backward_vel, + # "decoder/modeled_agent_valid_mask": current_valid_mask, + # "decoder/modeled_agent_delta": backward_delta, + # "decoder/input_step": torch.arange(num_search_steps+1).to(current_pos.device), + # "decoder/input_action": backward_input_action, + # "decoder/input_action_valid_mask": backward_input_valid_mask, + # "encoder/scenario_token": data_dict["encoder/scenario_token"], + # "encoder/scenario_valid_mask": data_dict["encoder/scenario_valid_mask"], + # "encoder/scenario_position": data_dict["encoder/scenario_position"], + # "encoder/scenario_heading": data_dict["encoder/scenario_heading"], + # "in_backward_prediction": torch.ones_like(data_dict["in_backward_prediction"]), + # "in_evaluation": torch.zeros_like(data_dict["in_evaluation"]), + # "decoder/agent_id": data_dict["decoder/agent_id"], + # "decoder/agent_type": data_dict["decoder/agent_type"], + # "decoder/current_agent_shape": data_dict["decoder/current_agent_shape"], + # "decoder/randomized_modeled_agent_id": data_dict["decoder/randomized_modeled_agent_id"], + # } + # backward_current_pos = backward_pos[:, :1] + # backward_current_heading = backward_head[:, :1] + # backward_current_vel = backward_vel[:, :1] + # backward_current_delta = backward_delta[:, :1] + # backward_input_action = backward_input_action[:, :1] + # backward_pos = [current_pos.clone()] + # backward_head = [current_heading.clone()] + # backward_vel = [current_vel.clone()] + # backward_delta = [current_delta.clone()] + # backward_output_logit_list = [] + # backward_output_action_list = [current_input_action.clone()] + # backward_input_action_valid_mask_list = [] + # for decode_step in range(num_search_steps): + # # Overwrite all necessary data: + # backward_input_dict["decoder/modeled_agent_position"] = backward_current_pos + # backward_input_dict["decoder/modeled_agent_heading"] = backward_current_heading + # 
backward_input_dict["decoder/modeled_agent_velocity"] = backward_current_vel + # backward_input_dict["decoder/modeled_agent_valid_mask"] = current_valid_mask + # backward_input_dict["decoder/modeled_agent_delta"] = backward_current_delta + # backward_input_dict["decoder/input_step"] = torch.full_like(current_model_step, decode_step) + # backward_input_dict["decoder/input_action"] = backward_input_action + # backward_input_dict["decoder/input_action_valid_mask"] = current_valid_mask + # assert not (current_input_action == START_ACTION).any() + # assert (backward_input_dict["in_backward_prediction"] == True).all() + # backward_input_dict = model.decode_motion(backward_input_dict, use_cache=True) + # output_token = backward_input_dict["decoder/output_logit"] + # selected_action = sample_action( + # logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp + # ) + # res = tokenizer.detokenize_step( + # current_pos=current_pos, + # current_heading=current_heading, + # current_valid_mask=current_valid_mask, + # current_vel=current_vel, + # action=selected_action, + # agent_shape=backward_input_dict["decoder/current_agent_shape"], + # bin_centers=bin_centers, + # dt=tokenizer.dt, + # ) + # recon_next_pos, recon_next_heading, recon_next_vel = res["pos"], res["heading"], res["vel"] + # relative_delta_pos = recon_next_pos.reshape(B, 1, N, 2) - current_pos + # relative_delta_pos = utils.rotate( + # relative_delta_pos[..., 0], relative_delta_pos[..., 1], angle=-recon_next_heading.reshape(B, 1, N) + # ) + # backward_current_pos = recon_next_pos.reshape(B, 1, N, 2) + # backward_current_heading = recon_next_heading.reshape(B, 1, N) + # backward_current_vel = recon_next_vel.reshape(B, 1, N, 2) + # backward_current_delta = relative_delta_pos.reshape(B, 1, N, 2) + # #current_model_step.fill_(decode_step + 1 - start_steps) + # backward_input_action = selected_action + # backward_pos.append(backward_current_pos.clone()) + # 
backward_head.append(backward_current_heading.clone()) + # backward_vel.append(backward_current_vel.clone()) + # backward_delta.append(backward_current_delta.clone()) + # backward_output_logit_list.append(output_token.clone()) + # backward_output_action_list.append(backward_input_action.clone()) + # backward_output_action_list = torch.concatenate(backward_output_action_list, dim=1) + # backward_output_logit_list = torch.concatenate(backward_output_logit_list, dim=1) + # backward_pos = torch.cat(backward_pos, dim=1) + # backward_head = torch.cat(backward_head, dim=1) + # backward_vel = torch.cat(backward_vel, dim=1) + # backward_delta = torch.cat(backward_delta, dim=1) + # + # backward_final_pos = backward_pos[:, -1] + # gt_current_pos = data_dict["decoder/modeled_agent_position"].clone().squeeze(1) + # + # # TODO: Can compute contour error here! + # error = torch.norm(gt_current_pos - backward_final_pos, dim=-1) + # error = error * current_valid_mask.squeeze(1) + # scenario_error = error.sum(-1) / current_valid_mask.squeeze(1).sum(-1) + # scenario_error = scenario_error.reshape(-1, num_search_width) + # backward_scores = -scenario_error + + if per_agent_argmax: + backward_scores = backward_scores.reshape(-1, num_search_width, N) + + else: + score_mask = current_valid_mask.squeeze(1) + backward_scores = backward_scores.sum(1) / score_mask.sum(-1) # Avg over agent + assert backward_scores.ndim == 1 + backward_scores = backward_scores.reshape(-1, num_search_width) + + # ===== Get the best trajectory ===== + best_idx = torch.argmax(backward_scores, dim=1) + # print("[MCTS Step: {}] Agent 0 Scores: ".format(start_steps), backward_scores[..., 0].cpu().numpy()) + # print("[MCTS Step: {}] Agent 1 Scores: ".format(start_steps), backward_scores[..., 1].cpu().numpy()) + # print("[MCTS Step: {}] Best idx: {}, scores: {}".format(start_steps, best_idx.cpu().numpy(), best_scores.values.cpu().numpy())) + + output_action_tmp = output_action_list[:, 1:] + output_action = 
output_action_tmp[:, 0].reshape(original_B, num_search_width, N) + + if per_agent_argmax: + selected_action = torch.gather( + output_action, dim=1, index=best_idx[:, None, :].expand(original_B, 1, N) + ).squeeze(1) + + else: + selected_action = torch.gather( + output_action, dim=1, index=best_idx[:, None, None].expand(original_B, 1, N) + ).squeeze(1) + + return selected_action.reshape(original_B, 1, N), {"scores": backward_scores} + + +def autoregressive_rollout_with_mcts( + model, + data_dict, + config, + # num_decode_steps, + num_decode_steps=None, + use_cache=True, + sampling_method="softmax", + temperature=None, + topp=None, + num_modes_for_eval=None, + **kwargs +): + + raw_data = data_dict + # To avoid those overwriting operation. + data_dict = copy.deepcopy(data_dict) + + tokenizer = get_tokenizer(config=config) + + if temperature is None: + temperature = config.SAMPLING.TEMPERATURE + if topp is None: + topp = config.SAMPLING.TOPP + + B, T_input, N = data_dict["decoder/input_action"].shape + + if config.GPT_STYLE: + start_action_step = 0 + assert T_input == 19 + else: + start_action_step = 2 + assert T_input == 17 + autoregressive_start_step = 2 + + if num_decode_steps is None: + num_decode_steps = 19 + assert start_action_step + T_input == num_decode_steps + assert num_decode_steps == 19 + assert data_dict["decoder/input_action_valid_mask"].shape == (B, T_input, N) + else: + print("WARNING: You are freely generating future trajectory! 
num_decode_steps (was 19) =", num_decode_steps) + + # ===== Get initial data ===== + agent_pos = data_dict["decoder/agent_position"] # .clone() + agent_heading = data_dict["decoder/agent_heading"] # .clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"] # .clone() + agent_velocity = data_dict["decoder/agent_velocity"] # .clone() + agent_shape = data_dict["decoder/current_agent_shape"] # .clone() + B, T_full, N, _ = agent_pos.shape + # TODO: hardcoded + assert T_full == 91 + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + agent_pos = agent_pos[:, ::tokenizer.num_skipped_steps] + agent_heading = agent_heading[:, ::tokenizer.num_skipped_steps] + agent_valid_mask = agent_valid_mask[:, ::tokenizer.num_skipped_steps] + agent_velocity = agent_velocity[:, ::tokenizer.num_skipped_steps] + gt_agent_delta = data_dict["decoder/modeled_agent_delta"].clone() + T_chunks = agent_pos.shape[1] + assert T_chunks == 19 + + # ===== Build up some variables ===== + # Should note that the modeled_agent_* is starting from t=0 (GPT) and t=10 (non-GPT). So using 0:1 to get the + # first step for decoder is correct. 
+ current_pos = data_dict["decoder/modeled_agent_position"][:, :1].clone() + current_heading = data_dict["decoder/modeled_agent_heading"][:, :1].clone() + current_vel = data_dict["decoder/modeled_agent_velocity"][:, :1].clone() + current_valid_mask = data_dict["decoder/input_action_valid_mask"][:, :1].clone() + current_delta = data_dict["decoder/modeled_agent_delta"][:, :1].clone() + current_model_step = torch.arange(1).to(current_pos.device) # it's 0 + gt_input_action = data_dict["decoder/input_action"].clone() + gt_target_action = data_dict["decoder/target_action"].clone() + current_input_action = gt_input_action[:, :1].clone() + + output_logit_list = [] + output_action_list = [] + input_action_valid_mask_list = [] + assert use_cache + + pos = [] + head = [] + vel = [] + + # Select correct bins: + agent_type = data_dict["decoder/agent_type"] + bin_centers = tokenizer.get_bin_centers(agent_type) + + data_dict = model.encode_scene(data_dict) + data_dict["decoder/randomized_modeled_agent_id"] = model.motion_decoder.randomize_modeled_agent_id( + data_dict, clip_agent_id=True + ) + for decode_step in range(num_decode_steps): + if decode_step == autoregressive_start_step: + assert (current_valid_mask == agent_valid_mask[:, autoregressive_start_step:autoregressive_start_step + + 1]).all() + assert (current_valid_mask == data_dict["decoder/current_agent_valid_mask"][:, None]).all() + + # ===== Fill a lot of stuff ===== + # Overwrite all necessary data: + data_dict["decoder/modeled_agent_position"] = current_pos + data_dict["decoder/modeled_agent_heading"] = current_heading + data_dict["decoder/modeled_agent_velocity"] = current_vel + data_dict["decoder/modeled_agent_valid_mask"] = current_valid_mask + data_dict["decoder/modeled_agent_delta"] = current_delta + data_dict["decoder/input_step"] = current_model_step + data_dict["decoder/input_action"] = current_input_action + data_dict["decoder/input_action_valid_mask"] = current_valid_mask + 
input_action_valid_mask_list.append(current_valid_mask.clone()) + assert not (current_input_action == END_ACTION).any() + + selected_action, mcts_info = mcts_search( + model, + data_dict, + config, + start_steps=decode_step, + num_search_steps=4, + num_search_width=4, + bin_centers=bin_centers, + ) + + # Note: Call model after MCTS search so the cache will not be used in MCTS. + data_dict = model.decode_motion(data_dict, use_cache=use_cache) + + if "decoder/modeled_agent_position_history" in data_dict: + assert data_dict["decoder/modeled_agent_position_history"].shape[1] == decode_step + 1 - start_action_step + output_token = data_dict["decoder/output_logit"] + assert output_token.shape[:3] == (B, 1, N) + # selected_action = sample_action( + # logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp + # ) + + if decode_step < autoregressive_start_step: + # Overwrite the action by GT action + selected_action = gt_target_action[:, decode_step:decode_step + 1] + + res = tokenizer.detokenize_step( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=current_valid_mask, + current_vel=current_vel, + action=selected_action, + agent_shape=data_dict["decoder/current_agent_shape"], + bin_centers=bin_centers, + dt=tokenizer.dt, + ) + recon_next_pos, recon_next_heading, recon_next_vel = res["pos"], res["heading"], res["vel"] + + # TODO: delta_pos computing is updated. 
+ raise ValueError + relative_delta_pos = recon_next_pos.reshape(B, 1, N, 2) - current_pos + relative_delta_pos = utils.rotate( + relative_delta_pos[..., 0], relative_delta_pos[..., 1], angle=-recon_next_heading.reshape(B, 1, N) + ) + + current_pos = recon_next_pos.reshape(B, 1, N, 2) + current_heading = recon_next_heading.reshape(B, 1, N) + current_vel = recon_next_vel.reshape(B, 1, N, 2) + current_delta = relative_delta_pos.reshape(B, 1, N, 2) + current_model_step.fill_(decode_step + 1 - start_action_step) + current_input_action = selected_action + + # Overwrite the data FOR NEXT STEP by the GT data: + if decode_step < autoregressive_start_step: + newly_added = agent_valid_mask[:, decode_step + 1:decode_step + 2] & (~current_valid_mask) + if newly_added.any(): + current_pos[newly_added] = agent_pos[:, decode_step + 1:decode_step + 2, ..., :2][newly_added] + current_heading[newly_added] = agent_heading[:, decode_step + 1:decode_step + 2][newly_added] + current_vel[newly_added] = agent_velocity[:, decode_step + 1:decode_step + 2][newly_added] + current_valid_mask[newly_added] = agent_valid_mask[:, decode_step + 1:decode_step + 2][newly_added] + current_delta[newly_added] = gt_agent_delta[:, decode_step + 1:decode_step + 2][newly_added] + + # Overwrite the input action by GT action + current_input_action = gt_input_action[:, decode_step + 1:decode_step + 2] + output_token = torch.zeros_like(output_token) + + pos.append(current_pos.clone()) + head.append(current_heading.clone()) + vel.append(current_vel.clone()) + output_logit_list.append(output_token.clone()) + output_action_list.append(current_input_action.clone()) + + output_action_list = torch.concatenate(output_action_list, dim=1) + assert output_action_list.shape == (B, num_decode_steps - start_action_step, N) + + output_logit_list = torch.concatenate(output_logit_list, dim=1) + data_dict["decoder/output_logit"] = output_logit_list + data_dict["decoder/output_action"] = output_action_list + + # FIXME + # FIXME 
+ # FIXME What is the score? + data_dict["decoder/output_score"] = calculate_trajectory_probabilities( + output_logit_list, output_action_list, mask=current_valid_mask + ) # (B, N) + + input_action_valid_mask = torch.cat(input_action_valid_mask_list, dim=1) + data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask + + data_dict["decoder/debug_ar_pos"] = torch.cat(pos, dim=1) + data_dict["decoder/debug_ar_head"] = torch.cat(head, dim=1) + data_dict["decoder/debug_ar_vel"] = torch.cat(vel, dim=1) + + valid_output_action = output_action_list[input_action_valid_mask] + assert valid_output_action.max() < tokenizer.num_actions + assert valid_output_action.min() >= 0 + + # ===== Debug! rewrite output action by GT ===== + # tokenizer = get_tokenizer(config=self.config) + # input_dict["decoder/output_action"] = input_dict["decoder/target_action"].clone() + # fill_zero = ((input_dict["decoder/output_action"] == -1) & input_dict["decoder/input_action_valid_mask"]) + # input_dict["decoder/output_action"][fill_zero] = tokenizer.default_action + + return data_dict + + +@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="1026_gpt.yaml") +def debug_run_model(config): + omegaconf.OmegaConf.set_struct(config, False) + config.PREPROCESSING.keep_all_data = True + config.DATA.SD_PASSTHROUGH = False + omegaconf.OmegaConf.set_struct(config, True) + + # Load model + from scenestreamer.utils import utils + # path = config.pretrain + # path="/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/1028_tinygpt_wBACKWARD_onHybrid_2024-10-28_2232/checkpoints/last.ckpt" + # model = utils.load_from_checkpoint( + # checkpoint_path=path, cls=MotionLMLightning, config=None, + # ) + # import torch + # model = model.to("cuda") + + model = utils.get_model(config, device="cuda") + device = model.device + + test_dataset = SceneStreamerDataset(config, "training") + ddd = iter(test_dataset) + + backward_prediction = False + + search_width = 4 + + while 
True: + try: + raw_data_dict = data_dict = next(ddd) + + # Create a new ADV in the data so backward prediction will help us generate it. + # data_dict = create_new_adv(data_dict) + + from scenestreamer.tokenization import get_tokenizer + tokenizer = get_tokenizer(config) + + # Force to run backward prediction first to make sure the data is tokenized correctly. + tok_data_dict, _ = tokenizer.tokenize_numpy_array( + data_dict, + backward_prediction=backward_prediction, + ) + data_dict.update(tok_data_dict) + + input_data_dict = utils.numpy_to_torch(data_dict, device=device) + # Extend the batch dim: + input_data_dict = { + k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v + for k, v in input_data_dict.items() + } + input_data_dict["in_evaluation"] = torch.tensor([1], dtype=bool).to(device) + + if backward_prediction: + input_data_dict["in_backward_prediction"] = torch.tensor([1], dtype=bool).to(device) + else: + input_data_dict["in_backward_prediction"] = torch.tensor([0], dtype=bool).to(device) + + with torch.no_grad(): + output_dict = autoregressive_rollout_with_mcts( + model=model.model, + data_dict=input_data_dict, + config=config, + num_decode_steps=None, + sampling_method=config.SAMPLING.SAMPLING_METHOD, + temperature=config.SAMPLING.TEMPERATURE, + ) + + output_dict = tokenizer.detokenize( + output_dict, + + # detokenizing_gt=True, + detokenizing_gt=False, + backward_prediction=backward_prediction, + ) + + # Get the first batch + output_dict = {k: v[:1] if isinstance(v, torch.Tensor) else v for k, v in output_dict.items()} + + output_dict = { + k: (v.squeeze(0).cpu().numpy() if isinstance(v, torch.Tensor) else v) + for k, v in output_dict.items() + } + + raw_data_dict.update(output_dict) + # plot_pred(raw_data) + plot_pred(raw_data_dict, show=True) + + except StopIteration: + break + print("End") + + +if __name__ == '__main__': + # debug() + # debug_backward_prediction() + debug_run_model() diff --git a/scenestreamer/models/__deprecated__initializer.py 
b/scenestreamer/models/__deprecated__initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..eb2fc213f9375e30df8a0722b0815bd6a5a7b7bc --- /dev/null +++ b/scenestreamer/models/__deprecated__initializer.py @@ -0,0 +1,962 @@ +""" + +""" +from collections import defaultdict +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from scenestreamer.models.layers.initializer_predictor import InitializerPredictor +from torch import Tensor + +from scenestreamer.dataset import constants +from scenestreamer.models.layers import polyline_encoder, common_layers +from scenestreamer.models.layers import position_encoding_utils +from scenestreamer.models.layers.multi_head_attention import MultiheadAttention +# from scenestreamer.models.layers.multi_head_attention_local import MultiheadAttentionLocal +from scenestreamer.models.ops.knn import knn_utils +from scenestreamer.utils import utils + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + use_local_attn=False, + ): + super().__init__() + self.use_local_attn = use_local_attn + + if self.use_local_attn: + raise ValueError() + self.self_attn = MultiheadAttentionLocal(d_model, nhead, dropout=dropout) + else: + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) + + # Implementation of Feedforward layers + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = common_layers.get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward( + self, + src, + src_mask: 
Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + index_pair=None, + query_batch_cnt=None, + key_batch_cnt=None, + index_pair_batch=None + ): + q = k = self.with_pos_embed(src, pos) + + src2 = self.self_attn( + q, + k, + value=src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + index_pair=index_pair, + query_batch_cnt=query_batch_cnt, + key_batch_cnt=key_batch_cnt, + index_pair_batch=index_pair_batch + )[0] + + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + +class SceneEncoderFuser(nn.Module): + """ + A stack of transformer layers to fuse multi-modal information of a scenario. + Note that the embedding layer for each modality is not included. + + Input: The embedding of different modality. + Output: A set of tokens of the scene. + """ + def __init__(self, config): + super().__init__() + self.model_cfg = config + self.d_model = d_model = self.model_cfg.D_MODEL + self.num_decoder_layers = self.model_cfg.NUM_ATTN_LAYERS + self.num_modes = self.model_cfg.NUM_MOTION_MODES + self.num_of_neighbors = self.model_cfg.NUM_NEIGHBORS + + nhead = self.model_cfg.NUM_ATTN_HEAD + dropout = self.model_cfg.get('DROPOUT_OF_ATTN', 0.1) + + # build transformer encoder layers + self.use_local_attn = True + self_attn_layers = [] + for _ in range(self.num_decoder_layers): + self_attn_layers.append( + TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=d_model * 4, + dropout=dropout, + normalize_before=False, + use_local_attn=self.use_local_attn + ) + ) + + self.self_attn_layers = nn.ModuleList(self_attn_layers) + + def apply_local_attn( + self, *, map_token, agent_token, map_position, agent_position, map_valid_mask, agent_valid_mask, + num_of_neighbors, light_token, light_position, light_valid_mask + ): + + B, M, D = map_token.shape + + 
if agent_token is not None: + _, N, _ = agent_token.shape + + x = torch.cat([map_token, agent_token], axis=1) + x_mask = torch.cat([map_valid_mask, agent_valid_mask], axis=1) + x_pos = torch.cat([map_position, agent_position], axis=1)[..., :2] + + else: + N = 0 + + x = map_token + x_mask = map_valid_mask + x_pos = map_position[..., :2] + + if light_token is not None: + _, L, _ = light_token.shape + x = torch.cat([x, light_token], axis=1) + x_mask = torch.cat([x_mask, light_valid_mask], axis=1) + x_pos = torch.cat([x_pos, light_position[..., :2]], axis=1) + else: + L = 0 + + assert torch.all(x_mask.sum(dim=-1) > 0) + # batch_size, N, d_model = x.shape + _, num_tokens, _ = x.shape + + x_stack_full = x.view(-1, D) # (batch_size * N, d_model) + x_mask_stack = x_mask.view(-1) + + x_pos_stack_full = x_pos.view(-1, 2) + batch_idxs_full = torch.arange(B).type_as(x)[:, None].repeat(1, num_tokens).view(-1).int() # (batch_size * N) + + # filter invalid elements + x_stack = x_stack_full[x_mask_stack] + x_pos_stack = x_pos_stack_full[x_mask_stack] + + # It is in shape (BS * N,). It record the batch index of each selected "map feat". 
+ batch_idxs = batch_idxs_full[x_mask_stack] + + # knn + batch_offsets = utils.get_batch_offsets(batch_idxs=batch_idxs, bs=B, device=x.device) # (batch_size + 1) + + # in shape (bs,) + # how many map_feat's index==i + batch_cnt = batch_offsets[1:] - batch_offsets[:-1] + + x_pos_stack_3d = F.pad(x_pos_stack, (0, 1)) + index_pair = knn_utils.knn_batch_mlogk( + x_pos_stack_3d, x_pos_stack_3d, batch_idxs, batch_offsets, num_of_neighbors + ) # (num_valid_elems, K) + + # positional encoding + pos_embedding = \ + position_encoding_utils.gen_sineembed_for_position(x_pos_stack[None, :, 0:2], hidden_dim=D)[0] + + output = x_stack + for k in range(len(self.self_attn_layers)): + output = self.self_attn_layers[k]( + src=output, + pos=pos_embedding, + index_pair=index_pair, + query_batch_cnt=batch_cnt, + key_batch_cnt=batch_cnt, + index_pair_batch=batch_idxs, + ) + + ret_full_feature = utils.unwrap(output, x_mask) + + output_map_token = ret_full_feature[:, :M] + + if agent_token is not None: + output_agent_token = ret_full_feature[:, M:M + N] + assert output_agent_token.shape[1] == N + + else: + output_agent_token = None + + if light_token is not None: + output_light_token = ret_full_feature[:, M + N:] + assert output_light_token.shape[1] == L + else: + output_light_token = None + + return output_map_token, output_agent_token, output_light_token + + def forward( + self, *, map_token, agent_token, map_position, agent_position, map_valid_mask, agent_valid_mask, light_token, + light_position, light_valid_mask + ): + if self.use_local_attn: + out = self.apply_local_attn( + map_token=map_token, + agent_token=agent_token, + map_position=map_position, + agent_position=agent_position, + map_valid_mask=map_valid_mask, + agent_valid_mask=agent_valid_mask, + num_of_neighbors=self.num_of_neighbors, + light_token=light_token, + light_position=light_position, + light_valid_mask=light_valid_mask + ) + else: + raise ValueError() + # global_token_feature = self.apply_global_attn( + # 
x=global_token_feature, x_mask=global_token_mask, x_pos=global_token_pos + # ) + return out + + +def get_distributions_for_evaluation(data_dict, model_output, selected_map, actor_type): + B, M, _ = data_dict["encoder/map_position"].shape + + actor_type_index = (actor_type - 1).reshape(B, 1, 1).expand(B, M, 1) + + selected_map_pos = torch.gather( + data_dict["encoder/map_position"], # [B, M, 2] + index=selected_map.reshape(B, 1, 1).expand(B, 1, 3), # [B, N] + dim=1 + ).squeeze(1) # [B, 3] + selected_map_heading = torch.gather( + data_dict["encoder/map_heading"], dim=1, index=selected_map.reshape(B, 1) + ).squeeze(1) # [B] + # Recenter the map position here. This is a little troublesome. + + # ===== Size ===== + selected_map = selected_map.reshape(B, 1, 1, 1) + actor_type_index = actor_type_index.reshape(B, M, 1, 1, 1) + nearest_size_logit = torch.gather( + model_output["fake_size"], dim=1, index=actor_type_index.expand(B, M, 1, *model_output["fake_size"].shape[3:]) + ).squeeze(2) + nearest_size_logit = torch.gather( + nearest_size_logit, dim=1, index=selected_map.expand(B, 1, *model_output["fake_size"].shape[3:]) + ).squeeze(1) + size_dist = utils.get_distribution(nearest_size_logit) + + nearest_position_logit = torch.gather( + model_output["fake_position"], + dim=1, + index=actor_type_index.expand(B, M, 1, *model_output["fake_position"].shape[3:]) + ).squeeze(2) + nearest_position_logit = torch.gather( + nearest_position_logit, dim=1, index=selected_map.expand(B, 1, *model_output["fake_position"].shape[3:]) + ).squeeze(1) + pos_dist = utils.get_distribution(nearest_position_logit) + + nearest_heading_logit = torch.gather( + model_output["fake_heading"], + dim=1, + index=actor_type_index.expand(B, M, 1, *model_output["fake_heading"].shape[3:]) + ).squeeze(2) + nearest_heading_logit = torch.gather( + nearest_heading_logit, dim=1, index=selected_map.expand(B, 1, *model_output["fake_heading"].shape[3:]) + ).squeeze(1) + head_dist = 
utils.get_distribution(nearest_heading_logit) + + nearest_velocity_logit = torch.gather( + model_output["fake_velocity"], + dim=1, + index=actor_type_index.expand(B, M, 1, *model_output["fake_velocity"].shape[3:]) + ).squeeze(2) + nearest_velocity_logit = torch.gather( + nearest_velocity_logit, dim=1, index=selected_map.expand(B, 1, *model_output["fake_velocity"].shape[3:]) + ).squeeze(1) + vel_dist = utils.get_distribution(nearest_velocity_logit) + + return selected_map_pos, selected_map_heading, pos_dist, vel_dist, head_dist, size_dist + + +def get_distributions_for_training(data_dict, model_output, selected_map, actor_type): + B, M, _ = data_dict["encoder/map_position"].shape + + _, _, N, _ = data_dict["encoder/agent_feature"].shape + + selected_map_pos = torch.gather( + data_dict["encoder/map_position"], # [B, M, 2] + index=selected_map.reshape(B, N, 1).expand(B, N, 3), # [B, N] + dim=1 + ) # [B, N, 3] + selected_map_heading = torch.gather( + data_dict["encoder/map_heading"], dim=1, index=selected_map.reshape(B, N) + ) # [B, N] + # Recenter the map position here. This is a little troublesome. 
+ + # ===== Size ===== + selected_map = selected_map.reshape(B, N, 1, 1, 1) + + actor_type = actor_type.clone() + actor_type[(actor_type < 1) | (actor_type > 3)] = 1 + actor_type = actor_type - 1 + + actor_type_index = actor_type.reshape(B, N, 1, 1, 1) + nearest_size_logit = torch.gather( + model_output["fake_size"], dim=1, index=selected_map.expand(B, N, *model_output["fake_size"].shape[2:]) + ) + nearest_size_logit = torch.gather( + nearest_size_logit, dim=2, index=actor_type_index.expand(B, N, 1, *model_output["fake_size"].shape[3:]) + ).squeeze(2) + size_dist = utils.get_distribution(nearest_size_logit) + + nearest_position_logit = torch.gather( + model_output["fake_position"], dim=1, index=selected_map.expand(B, N, *model_output["fake_position"].shape[2:]) + ) + nearest_position_logit = torch.gather( + nearest_position_logit, dim=2, index=actor_type_index.expand(B, N, 1, *model_output["fake_position"].shape[3:]) + ).squeeze(2) + pos_dist = utils.get_distribution(nearest_position_logit) + + nearest_heading_logit = torch.gather( + model_output["fake_heading"], dim=1, index=selected_map.expand(B, N, *model_output["fake_heading"].shape[2:]) + ).squeeze(1) + nearest_heading_logit = torch.gather( + nearest_heading_logit, dim=2, index=actor_type_index.expand(B, N, 1, *model_output["fake_heading"].shape[3:]) + ).squeeze(2) + head_dist = utils.get_distribution(nearest_heading_logit) + + nearest_velocity_logit = torch.gather( + model_output["fake_velocity"], dim=1, index=selected_map.expand(B, N, *model_output["fake_velocity"].shape[2:]) + ) + nearest_velocity_logit = torch.gather( + nearest_velocity_logit, dim=2, index=actor_type_index.expand(B, N, 1, *model_output["fake_velocity"].shape[3:]) + ).squeeze(2) + vel_dist = utils.get_distribution(nearest_velocity_logit) + + return selected_map_pos, selected_map_heading, pos_dist, vel_dist, head_dist, size_dist + + +@torch.no_grad() +def sample_from_distributions( + pos_dist, vel_dist, head_dist, size_dist, 
deterministic_state, selected_map_pos, selected_map_heading +): + if deterministic_state: + sampled_size = size_dist.mean # * constants.SIZE_RANGE + else: + sampled_size = size_dist.sample() # * constants.SIZE_RANGE + sampled_size = sampled_size.clamp(0.1) + + if deterministic_state: + sampled_pos = pos_dist.mean # * constants.LOCAL_POSITION_XY_RANGE + else: + sampled_pos = pos_dist.sample() # * constants.LOCAL_POSITION_XY_RANGE + # sampled_pos = sampled_pos.clamp(-LOCAL_POSITION_XY_RANGE, constants.LOCAL_POSITION_XY_RANGE) + sampled_pos = torch.cat([sampled_pos, sampled_size[:, 2:] / 2], dim=-1) # Add Z axis + sampled_pos = utils.relative_to_absolute(sampled_pos, selected_map_heading) + sampled_pos = sampled_pos + selected_map_pos + + if deterministic_state: + sampled_head = head_dist.mean # * constants.HEADING_RANGE + else: + sampled_head = head_dist.sample() # * constants.HEADING_RANGE + # sampled_head = sampled_head.clamp(-np.pi/2, np.pi/2) + sampled_head = utils.wrap_to_pi(sampled_head + selected_map_heading) + + if deterministic_state: + sampled_vel = vel_dist.mean # * constants.VELOCITY_XY_RANGE + else: + sampled_vel = vel_dist.sample() # * constants.VELOCITY_XY_RANGE + sampled_vel = utils.relative_to_absolute(sampled_vel, selected_map_heading) + + return sampled_pos, sampled_vel, sampled_head, sampled_size + + +def if_intersection(new, actor_position, actor_feature, actor_valid_mask): + if actor_position is None: + return False + pos = actor_position[:, 0] # [B, N, 3] + size = actor_feature[:, 0, :, 6:9] # * constants.SIZE_RANGE + max_size = size.max(-1)[0] # [B, N] + dist = torch.cdist(new["sampled_pos"].unsqueeze(1), pos).squeeze(1) # [B, N] + ret = (dist < max_size).any(-1) # [B, ] + return ret + + +@torch.no_grad() +def sample_new_actor( + data_dict, + model_output, + sampling_method, + actor_type, + temperature=1.0, + topk=10, + topp=0.9, + deterministic_state=False, + use_nature_probability=False +): + # [B, N, num_modes] + B, M, _, num_modes, _ = 
model_output["fake_position"].shape + + # actor_type is in [B,] in range {0, 1, 2, 3, 4} + # We can use it to select map feature. + actor_type = actor_type.clone() + actor_type[(actor_type < 1) | (actor_type > 3)] = 1 + actor_type_index = (actor_type - 1).reshape(B, 1, 1).expand(B, M, 1) + + if use_nature_probability: + map_prob = F.sigmoid(model_output["fake_map_feat_score"]).reshape(B, M, constants.NUM_TYPES) # [B, M, 3] + + pos_log_prob = model_output["fake_position_dist"].log_prob(model_output["fake_position_dist"].mean) # [B, M, 3] + vel_log_prob = model_output["fake_velocity_dist"].log_prob(model_output["fake_velocity_dist"].mean) # [B, M, 3] + head_log_prob = model_output["fake_heading_dist"].log_prob(model_output["fake_heading_dist"].mean) # [B, M, 3] + size_log_prob = model_output["fake_size_dist"].log_prob(model_output["fake_size_dist"].mean) # [B, M, 3] + + prob = map_prob * torch.exp(pos_log_prob + vel_log_prob + head_log_prob + size_log_prob) + + # mask out invalid map feature + prob[~model_output["encoder/map_valid_mask"]] = 0 + + if temperature is not None: + prob = prob**(1.0 / temperature) + score = prob + else: + score = model_output["fake_map_feat_score"].reshape(B, M, constants.NUM_TYPES) + score[~model_output["encoder/map_valid_mask"]] = float("-inf") + score = score / temperature # [B, M, 3] + + score = torch.gather(input=score, index=actor_type_index, dim=2).reshape(B, M) + + if sampling_method == "softmax": + if use_nature_probability: + selected_map = torch.distributions.Categorical(probs=score.clamp(-100, 100)).sample() + else: + selected_map = torch.distributions.Categorical(logits=score.clamp(-100, 100)).sample() + + elif sampling_method == "topk": + indices_to_remove = score < score.topk(topk, dim=-1)[0][..., -1, None] + if use_nature_probability: + score[indices_to_remove] = 0.0 + selected_map = torch.distributions.Categorical(probs=score.clamp(-100, 100)).sample() + else: + score[indices_to_remove] = float("-inf") + selected_map = 
torch.distributions.Categorical(logits=score.clamp(-100, 100)).sample() + + elif sampling_method == "topp": + sorted_logits, sorted_indices = torch.sort(score, descending=True) + + # Compute cumulative probabilities + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Create a mask for all tokens whose cumulative probability exceeds the threshold + sorted_indices_to_remove = cumulative_probs > topp + + # Since we want to keep at least one token, shift the mask to the right + # This way the first token (with the highest probability) will always be kept + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # Scatter the sorted tensor to match the original indices + indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove) + + # Set all logits that are masked out to a very large negative number, + # so that they become zero after applying softmax + if use_nature_probability: + score[indices_to_remove] = 0.0 + selected_map = torch.distributions.Categorical(probs=score.clamp(-100, 100)).sample() + else: + score[indices_to_remove] = float("-inf") + selected_map = torch.distributions.Categorical(logits=score.clamp(-100, 100)).sample() + + elif sampling_method == "argmax": + selected_map = score.argmax(-1) + + else: + raise ValueError(f"Unknown {sampling_method}") + assert selected_map.shape == (B, ) + + selected_map_pos, selected_map_heading, pos_dist, vel_dist, head_dist, size_dist = get_distributions_for_evaluation( + data_dict, model_output, selected_map, actor_type + ) + + new_actor_feature = score.new_zeros([B, constants.AGENT_STATE_DIM]) + + sampled_pos, sampled_vel, sampled_head, sampled_size = sample_from_distributions( + size_dist=size_dist, + pos_dist=pos_dist, + head_dist=head_dist, + vel_dist=vel_dist, + deterministic_state=deterministic_state, + selected_map_pos=selected_map_pos, + 
def __init__(self, config):
    """Create the per-modality encoders, the scene fuser and the start-token predictor.

    Args:
        config: Model config exposing ``D_MODEL``, ``NUM_ATTN_LAYERS`` and
            ``NUM_MOTION_MODES``; it is also forwarded as-is to
            ``SceneEncoderFuser``.
    """
    super().__init__()

    # Hyper-parameters cached from the config.
    self.model_cfg = config
    self.d_model = config.D_MODEL
    self.num_decoder_layers = config.NUM_ATTN_LAYERS
    self.num_modes = config.NUM_MOTION_MODES
    width = self.d_model

    # Map polylines are embedded with a PointNet-style encoder.
    self.map_polyline_encoder = polyline_encoder.PointNetPolylineEncoder(
        in_channels=constants.MAP_FEATURE_STATE_DIM,
        hidden_dim=64,
        num_layers=2,
        num_pre_layers=1,
        out_channels=width,
    )

    # Actors and traffic lights each get a three-layer MLP of width d_model.
    self.actor_mlps = common_layers.build_mlps(
        c_in=constants.AGENT_STATE_DIM,
        mlp_channels=[width, width, width],
        ret_before_act=True,
    )
    self.light_mlps = common_layers.build_mlps(
        c_in=constants.TRAFFIC_LIGHT_STATE_DIM,
        mlp_channels=[width, width, width],
        ret_before_act=True,
    )

    # Head that turns fused map tokens into initial-state predictions.
    self.start_token_predictor = InitializerPredictor(
        d_model=width,
        num_modes=config.NUM_MOTION_MODES,
    )

    self.encoder = SceneEncoderFuser(config)
# [B, N, 2] + raw_actor_valid_mask = data_dict["encoder/agent_valid_mask"][:, 0] + + map_mask = forward_ret_dict["map_mask_for_initializer"] + + map_position = data_dict["encoder/map_position"] + map_dist = torch.cdist(actor_position, map_position, compute_mode="donot_use_mm_for_euclid_dist") # [B, N, M] + B, N, M = map_dist.shape + + map_mask = map_mask.unsqueeze(1).expand(B, N, M) + map_dist[~map_mask] = +100000 + selected_map_distance, selected_map = map_dist.min(-1) + selected_map = selected_map.reshape(B, N, 1, 1) + + # Remove objects that are too far from any existing map feature. + actor_is_closed = selected_map_distance <= 5 + actor_valid_mask = torch.logical_and(raw_actor_valid_mask, actor_is_closed) + + actor_type = data_dict["decoder/actor_type"] + + selected_map_pos, selected_map_heading, pos_dist, vel_dist, head_dist, size_dist = \ + get_distributions_for_training(data_dict, forward_ret_dict, selected_map, actor_type) + + local_actor_position = actor_position - selected_map_pos + local_actor_position = utils.absolute_to_relative(local_actor_position, selected_map_heading) + # local_actor_position = local_actor_position / constants.LOCAL_POSITION_XY_RANGE + # local_actor_position = local_actor_position + init_pos_loss = -pos_dist.log_prob(local_actor_position[..., :2]) # THE TARGET IS NORMALIZED + init_pos_loss = (torch.sum(init_pos_loss * actor_valid_mask, dim=1) / actor_valid_mask.sum(-1).clamp(1)).mean() + + local_actor_velocity = data_dict["encoder/agent_feature"][:, 0, :, 4:6] + # local_actor_velocity = data_dict["encoder/agent_feature"][:, 0, :, 4:6] * constants.VELOCITY_XY_RANGE + + local_actor_velocity = utils.absolute_to_relative(local_actor_velocity, selected_map_heading) + + # local_actor_velocity = local_actor_velocity / constants.VELOCITY_XY_RANGE + local_actor_velocity = local_actor_velocity + + init_vel_loss = -vel_dist.log_prob(local_actor_velocity) # THE TARGET IS NORMALIZED + init_vel_loss = (torch.sum(init_vel_loss * actor_valid_mask, 
dim=1) / actor_valid_mask.sum(-1).clamp(1)).mean() + + # head_tar = data_dict["encoder/agent_feature"][:, 0, :, 3] * constants.HEADING_RANGE + head_tar = data_dict["encoder/agent_feature"][:, 0, :, 3] + + # Pred Head + Map Head = GT Head -> Pred Head = GT Head - Map Head + head_tar = utils.wrap_to_pi(head_tar - selected_map_heading) + + # head_tar = head_tar / constants.HEADING_RANGE + head_tar = head_tar + + init_head_loss = -head_dist.log_prob(head_tar) # THE TARGET IS NORMALIZED + init_head_loss = (torch.sum(init_head_loss * actor_valid_mask, dim=1) / + actor_valid_mask.sum(-1).clamp(1)).mean() + + size_tar = (data_dict["encoder/agent_feature"][:, 0, :, 6:9]) + init_size_loss = -size_dist.log_prob(size_tar) # THE TARGET IS NORMALIZED + init_size_loss = (torch.sum(init_size_loss * actor_valid_mask, dim=1) / + actor_valid_mask.sum(-1).clamp(1)).mean() + + # ===== Score Loss ===== + pred_score = forward_ret_dict["fake_map_feat_score"].squeeze(-1) # [B, M] + + # map_feat_score_loss = F.binary_cross_entropy_with_logits(input=pred_score, target=is_lane_mask.float(), reduction="none") + + map_feat_score_loss_total = [] + for type_int in range(constants.NUM_TYPES): + + gt_score = pred_score.new_zeros([B, M]) + for i in range(B): + m = actor_valid_mask[i] + m = torch.logical_and(m, actor_type[i] == type_int + 1) + gt_score[i].index_fill_(dim=0, index=selected_map[i].reshape(-1)[m], value=1) + + map_feat_score_loss = F.binary_cross_entropy_with_logits( + input=pred_score[..., type_int], target=gt_score, reduction='none' + ) + original_map_mask = forward_ret_dict["encoder/map_valid_mask"] + map_feat_score_loss = ( + torch.sum(map_feat_score_loss.nan_to_num() * original_map_mask, dim=1) / + original_map_mask.sum(-1).clamp(1) + ).mean() + map_feat_score_loss_total.append(map_feat_score_loss) + + # ===== Actor Type Loss ===== + # type_mask = forward_ret_dict["compress_actor_valid_mask"].unsqueeze(-1).expand(B, compress_T, N, num_modes) + nearest_type_logit = torch.gather( + 
def forward(self, input_dict):
    """Encode the scene and predict initial-actor distributions on every map token.

    Args:
        input_dict: Dict with the ``encoder/*`` map, traffic-light and agent
            entries. ``encoder/agent_feature`` may be ``None`` (no actor placed
            yet). The optional flag ``in_evaluation`` (default ``False``)
            disables the random actor dropping described below.

    Returns:
        Dict with the predictor outputs (``fake_position``, ``fake_velocity``,
        ``fake_heading``, ``fake_size``, ``fake_map_feat_score``,
        ``fake_actor_type`` and the four ``*_dist`` entries), the map-token
        masks, and the first-step actor mask that was actually fed to the
        encoder.
    """
    actor_feature = input_dict["encoder/agent_feature"]
    actor_valid_mask = input_dict["encoder/agent_valid_mask"]
    actor_position = input_dict["encoder/agent_position"]

    in_evaluation = input_dict.get("in_evaluation", False)

    map_feature = input_dict["encoder/map_feature"]
    map_valid_mask = input_dict["encoder/map_feature_valid_mask"]
    # A map token is valid as soon as any of its entries along the last axis is valid.
    map_token_valid_mask = map_valid_mask.sum(axis=-1) != 0
    map_position = input_dict["encoder/map_position"]
    B = map_position.shape[0]

    if actor_feature is not None:
        first_step_actor_valid_mask = actor_valid_mask[:, 0]
        first_step_actor_feature = actor_feature[:, 0]
        first_step_actor_position = actor_position[:, 0]
        _, N, _ = first_step_actor_feature.shape

        if not in_evaluation:  # TODO: Use a config to control this
            # Keep a uniformly random number (0..num_valid) of the valid actors
            # so the model sees partially populated scenes during training.
            num_valid = first_step_actor_valid_mask.sum(-1)
            selected_num = torch.minimum(
                (num_valid * torch.rand(B, device=actor_valid_mask.device)).int(),
                num_valid,
            ).clamp(0)  # [B, ]
            keep_mask = map_valid_mask.new_zeros((B, N))
            for i in range(B):
                st_valids = first_step_actor_valid_mask[i].nonzero()[:, 0]
                st_ind = torch.randperm(len(st_valids))[:selected_num[i]]
                keep_mask[i, st_valids[st_ind]] = 1

            first_step_actor_valid_mask = torch.logical_and(first_step_actor_valid_mask, keep_mask)

            # TODO: In future, we can try to add some "empty token" to tell layers
            #  where there is a car but masked out.

        # Only valid actors go through the MLP; unwrap scatters them back into the mask layout.
        agent_enc = self.actor_mlps(first_step_actor_feature[first_step_actor_valid_mask])
        agent_enc = utils.unwrap(agent_enc, first_step_actor_valid_mask)
    else:
        agent_enc = None
        first_step_actor_valid_mask = None
        first_step_actor_position = None

    map_token = self.map_polyline_encoder(map_feature, map_valid_mask)

    # Only the first time step of the traffic lights is used.
    traffic_light_feature = input_dict["encoder/traffic_light_feature"][:, 0]
    traffic_light_position = input_dict["encoder/traffic_light_position"]
    traffic_light_valid_mask = input_dict["encoder/traffic_light_valid_mask"][:, 0]
    _, L, _ = traffic_light_feature.shape

    if L != 0:
        light_token = self.light_mlps(traffic_light_feature[traffic_light_valid_mask])
        # TODO: Can add PE for traffic light type.
        light_token = utils.unwrap(light_token, traffic_light_valid_mask)
    else:
        # No traffic light in this batch: hand the encoder three Nones.
        light_token = None
        traffic_light_position = None
        traffic_light_valid_mask = None

    output_map_token, _, _ = self.encoder(
        map_token=map_token,
        actor_token=agent_enc,
        map_position=map_position,
        actor_position=first_step_actor_position,
        map_valid_mask=map_token_valid_mask,
        actor_valid_mask=first_step_actor_valid_mask,
        light_token=light_token,
        light_position=traffic_light_position,
        light_valid_mask=traffic_light_valid_mask,
    )

    # Note: The map feature here includes all types. But when computing prediction,
    # we will only select the lane map feature.
    pred_pos, pred_vel, pred_head, pred_size, map_feat_score, actor_type, pos_dist, vel_dist, head_dist, size_dist = \
        self.start_token_predictor(output_map_token, map_token_valid_mask)

    # Recomputed on purpose: the returned mask stays a tensor distinct from the
    # one that was handed to the encoder and the predictor.
    original_map_mask = input_dict["encoder/map_feature_valid_mask"].sum(axis=-1) != 0
    map_mask = original_map_mask  # TODO: We can filter out some useless map feature avoid them be the anchor.

    return {
        "fake_position": pred_pos,
        "fake_velocity": pred_vel,
        "fake_heading": pred_head,
        "fake_size": pred_size,
        "fake_map_feat_score": map_feat_score,
        "fake_actor_type": actor_type,
        "fake_position_dist": pos_dist,
        "fake_velocity_dist": vel_dist,
        "fake_heading_dist": head_dist,
        "fake_size_dist": size_dist,
        "map_mask_for_initializer": map_mask,
        "encoder/map_valid_mask": original_map_mask,
        "first_step_actor_valid_mask": first_step_actor_valid_mask,
    }
torch.gather( + input_dict["encoder/agent_valid_mask"][:, 0], # [B, N] + index=sdc_index.reshape(B, 1), # [B, 1] + dim=1 + ).unsqueeze(1) # [B, 1, 1] + actor_valid_mask_list = [actor_valid_mask.reshape(B, )] + + assert num_v > 1, num_v + num_v = num_v - 1 + + else: + actor_feature = None + actor_position = None + actor_valid_mask = None + actor_valid_mask_list = [] + actor_feature_list = [] + actor_position_list = [] + + intermediate_model_output = [] + + return_sampled_dict = defaultdict(list) + + # TODO: In future we should support user-specified different actor type + gt_actor_type = input_dict["decoder/actor_type"].clone() + if condition_on_sdc: + # The actor type of known actors + actor_type = torch.gather( + gt_actor_type, # [B, N, 3] + index=sdc_index.reshape(B, 1), # [B, 1, 3] + dim=1 + ) # [B, 1] + actor_type_list = [actor_type.reshape(B, )] + + new_gt_actor_type = [] + for i in range(B): + N = len(gt_actor_type[i]) + ind = gt_actor_type.new_ones(N, dtype=bool) + ind[sdc_index[i]] = 0 + new_gt_actor_type.append(gt_actor_type[i][ind]) + new_gt_actor_type = torch.stack(new_gt_actor_type, dim=0) + gt_actor_type = new_gt_actor_type + + else: + actor_type = None + actor_type_list = [] + + new_gt_actor_type = torch.zeros_like(gt_actor_type) + for i in range(B): + valid_actor_type = gt_actor_type[i][gt_actor_type[i] != -1] + new_gt_actor_type[i, :len(valid_actor_type)] = valid_actor_type + gt_actor_type = new_gt_actor_type + + actor_valid_sum = input_dict["encoder/agent_valid_mask"][:, 0].sum(-1) # [B, ] + intersection_count = 0 + + for i in range(num_v): + input_dict_tmp = { + "encoder/map_feature": input_dict["encoder/map_feature"], + "encoder/map_feature_valid_mask": input_dict["encoder/map_feature_valid_mask"], + "encoder/map_position": input_dict["encoder/map_position"], + "encoder/map_heading": input_dict["encoder/map_heading"], + # "map_feature_type": input_dict["map_feature_type"], + "encoder/traffic_light_feature": 
input_dict["encoder/traffic_light_feature"], + "encoder/traffic_light_position": input_dict["encoder/traffic_light_position"], + "encoder/traffic_light_valid_mask": input_dict["encoder/traffic_light_valid_mask"], + + # These data should be updated + "encoder/agent_feature": actor_feature, + "encoder/agent_valid_mask": actor_valid_mask, + "encoder/agent_position": actor_position, + "decoder/actor_type": actor_type, + "in_evaluation": True, + } + out = self(input_dict_tmp) + + if record_intermediate_model_output: + intermediate_model_output.append(out) + + # Do the sampling and create new token here + trial = 0 + while trial < 5: + new_token, sampled_dict = sample_new_actor( + data_dict=input_dict_tmp, + model_output=out, + sampling_method=sampling_method, + actor_type=gt_actor_type[:, i], + temperature=temperature, + topk=topk, + topp=topp, + deterministic_state=deterministic_state, + use_nature_probability=use_nature_probability + ) + intersect = if_intersection(sampled_dict, actor_position, actor_feature, actor_valid_mask) + + if actor_position is not None and intersect.any(): + # print(f"Trial {trial} Find {intersect.sum()} intersection") + intersection_count += intersect.sum() + trial += 1 + else: + break + + for k, v in sampled_dict.items(): + return_sampled_dict[k].append(v) + + # Fill up information here + actor_feature_list.append(new_token) # [B, D] + actor_feature = torch.stack(actor_feature_list, dim=1).unsqueeze(1) # -> B, 1, N, 16 + actor_position_list.append(sampled_dict["sampled_pos"]) + actor_position = torch.stack(actor_position_list, dim=1).unsqueeze(1) # -> B, 1, N, 3 + + actor_type_list.append(gt_actor_type[:, i]) + actor_type = torch.stack(actor_type_list, dim=1) + + # actor_valid_mask = map_valid_mask.new_ones(actor_feature.shape[:3]) if actor_feature is not None else None + actor_valid_mask_list.append(i < actor_valid_sum) # [B, ] + actor_valid_mask = torch.stack(actor_valid_mask_list, dim=1).unsqueeze(1) # -> B, 1, N + + return_sampled_dict = 
{k: torch.stack(v, dim=1) for k, v in return_sampled_dict.items()} + return_sampled_dict.update( + { + "encoder/agent_feature": actor_feature, + "encoder/agent_valid_mask": actor_valid_mask, + "encoder/agent_position": actor_position, + "intersection_count": intersection_count / B, + "total_count": actor_valid_sum.sum().item() / B + } + ) + + if record_intermediate_model_output: + return_sampled_dict["intermediate_model_output"] = intermediate_model_output + return return_sampled_dict diff --git a/scenestreamer/models/__deprecated__initializer_pl.py b/scenestreamer/models/__deprecated__initializer_pl.py new file mode 100644 index 0000000000000000000000000000000000000000..a89dc09b0fefb2cee69fa120c1067ac3b87b7f4d --- /dev/null +++ b/scenestreamer/models/__deprecated__initializer_pl.py @@ -0,0 +1,243 @@ +import lightning.pytorch as pl +import torch +from scenestreamer.models.initializer import SceneStreamerModel +from torch.optim.lr_scheduler import LambdaLR, LinearLR, CosineAnnealingWarmRestarts + +from scenestreamer.dataset import constants +from scenestreamer.eval import metrics + + +class SceneStreamerInitializer(pl.LightningModule): + def __init__(self, cfg): + if "SEED" in cfg: + pl.seed_everything(cfg.SEED) + print("Everything is seeded to: ", cfg.SEED) + + super().__init__() + self.cfg = cfg + self.model_cfg = self.cfg.MODEL + self.sampling_method = self.cfg.INITIALIZER.SAMPLING_METHOD + + self.decoder = SceneStreamerModel(config=self.model_cfg) + + self.save_hyperparameters() + + self.validation_outputs = [] + self.validation_ground_truth = [] + + mmd_metrics = {} + for word in ["", "_vehicle", "_pedestrian", "_cyclist"]: + mmd_metrics.update( + { + f"mmd_pos{word}": metrics.MMD(kernel_mul=1.0, kernel_num=1), + f"mmd_size{word}": metrics.MMD(kernel_mul=1.0, kernel_num=1), + f"mmd_head{word}": metrics.MMD(kernel_mul=1.0, kernel_num=1), + f"mmd_vel{word}": metrics.MMD(kernel_mul=1.0, kernel_num=1), + } + ) + for k, v in mmd_metrics.items(): + 
def training_step(self, batch, batch_idx):
    """Run one optimisation step: forward pass, loss, and scalar logging.

    Args:
        batch: ``(data_dict, gt_dict)`` pair produced by the dataloader.
        batch_idx: Index of the batch within the epoch (unused).

    Returns:
        The loss returned by ``self.get_loss``.
    """
    inputs, targets = batch
    predictions = self(inputs)
    loss, scalars, _ = self.get_loss(inputs, targets, predictions)

    # Every scalar is logged under the "train/" namespace as a plain float.
    train_scalars = {}
    for name, value in scalars.items():
        train_scalars[f"train/{name}"] = float(value)
    num_samples = inputs["encoder/agent_feature"].shape[0]
    self.log_dict(train_scalars, batch_size=num_samples)

    self.log('monitoring_step', float(self.global_step))
    return loss
def configure_optimizers(self):
    """Build the optimizer and, when configured, the learning-rate scheduler.

    Reads ``self.cfg.OPTIMIZATION``:

    * ``OPTIMIZER``: ``'Adam'`` or ``'AdamW'`` (AdamW uses betas ``(0.9, 0.95)``).
    * ``LR`` and optional ``WEIGHT_DECAY`` (default 0).
    * optional ``SCHEDULER``: ``'cosine'``, ``'lambdaLR'`` or ``'linearLR'``;
      any other value (or none) means no scheduler.

    Returns:
        Dict with ``"optimizer"`` and, only when a scheduler was built,
        ``"lr_scheduler"``.

    Raises:
        ValueError: If ``OPTIMIZER`` names an unsupported optimizer.
    """
    opt_cfg = self.cfg.OPTIMIZATION
    weight_decay = opt_cfg.get('WEIGHT_DECAY', 0)

    if opt_cfg.OPTIMIZER == 'Adam':
        optimizer = torch.optim.Adam(
            [param for _, param in self.named_parameters()],
            lr=opt_cfg.LR,
            weight_decay=weight_decay,
        )
    elif opt_cfg.OPTIMIZER == 'AdamW':
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=opt_cfg.LR, weight_decay=weight_decay, betas=(0.9, 0.95)
        )
    else:
        # An explicit error survives `python -O`, unlike the former `assert False`.
        raise ValueError(f"Unsupported optimizer: {opt_cfg.OPTIMIZER!r}")

    scheduler_name = opt_cfg.get('SCHEDULER', None)
    if scheduler_name == 'cosine':
        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=2,
            T_mult=1,
            eta_min=max(1e-2 * opt_cfg.LR, 1e-6),
            last_epoch=-1,
        )
    elif scheduler_name == 'lambdaLR':

        def lr_lbmd(cur_epoch):
            # Multiply by LR_DECAY once per passed milestone, floored at LR_CLIP / LR.
            cur_decay = 1
            for decay_step in opt_cfg.get('DECAY_STEP_LIST', [5, 10, 15, 20]):
                if cur_epoch >= decay_step:
                    cur_decay = cur_decay * opt_cfg.LR_DECAY
            return max(cur_decay, opt_cfg.LR_CLIP / opt_cfg.LR)

        scheduler = LambdaLR(optimizer, lr_lbmd)
    elif scheduler_name == 'linearLR':
        scheduler = LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=opt_cfg.LR_CLIP / opt_cfg.LR,
            total_iters=opt_cfg.NUM_EPOCHS,
        )
    else:
        scheduler = None

    optim_conf = {"optimizer": optimizer}
    if scheduler is not None:
        # Lightning validates whatever is stored under "lr_scheduler", so the key
        # is only added when a scheduler exists instead of carrying a None.
        # PZH NOTE: The scheduler step will be added 1 after each epoch.
        optim_conf["lr_scheduler"] = scheduler
    return optim_conf
actor_feature = data_dict["encoder/map_feature"].new_zeros([B, 91, N, constants.AGENT_STATE_DIM]) + actor_feature[:, :5] = torch.stack(new_feat, dim=2) + + fake_data = { + "encoder/map_feature": data_dict["encoder/map_feature"], + "encoder/map_position": data_dict["encoder/map_position"], + "encoder/map_feature_valid_mask": data_dict["encoder/map_feature_valid_mask"], + "encoder/traffic_light_feature": data_dict["encoder/traffic_light_feature"], + "encoder/traffic_light_position": data_dict["encoder/traffic_light_position"], + "encoder/traffic_light_valid_mask": data_dict["encoder/traffic_light_valid_mask"], + "encoder/agent_feature": actor_feature, + "encoder/agent_position": fake_pos, + "encoder/agent_valid_mask": data_dict["encoder/map_feature_valid_mask"].new_ones([B, 91, N]), + "decoder/actor_type": actor_type + } + + gt_dict["current_time_index"].fill_(5) + + ar_ret = self.autoregressive_generate((fake_data, gt_dict), compress_step=compress_step) + + return ar_ret, fake_data, sampled diff --git a/scenestreamer/models/__deprecated__motion.py b/scenestreamer/models/__deprecated__motion.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a6554cf5876321ba435b4ed1c533c0eda3c9ed --- /dev/null +++ b/scenestreamer/models/__deprecated__motion.py @@ -0,0 +1,1585 @@ +import copy + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from scenestreamer.dataset.constants import AGENT_STATE_DIM, MAP_FEATURE_STATE_DIM, TRAFFIC_LIGHT_STATE_DIM, \ + TRAFFIC_LIGHT_PREDICT_DIM, VELOCITY_XY_RANGE, HEADING_RANGE, POSITION_XY_RANGE +from scenestreamer.models.layers import common_layers, polyline_encoder, our_decoder_layer, position_encoding_utils +from scenestreamer.models.ops.search_knn_indices import search_k_nearest_object_indices, \ + search_k_nearest_map_feature_indicies, \ + search_k_nearest_map_feature_indicies_for_map +from scenestreamer.utils import wrap_to_pi, rotate + +NUM_TYPES = 3 + + +class 
def __init__(self, d_model, num_modes, step_per_token, small, use_gaussian=False):
    """Create the decompression MLP and the per-quantity prediction heads.

    Args:
        d_model: Width of the incoming actor tokens.
        num_modes: Number of motion modes each head predicts.
        step_per_token: Number of time steps decoded from one token.
        small: If true, heads work at width ``d_model`` and only the position
            and score heads are built; otherwise heads work at ``4 * d_model``
            and velocity, heading and actor-type heads are added.
        use_gaussian: If true, heads emit twice as many values per quantity
            and are not replicated per actor type.
    """
    super().__init__()

    self.d_model = d_model
    self.use_gaussian = use_gaussian
    self.num_modes = num_modes
    self.step_per_token = step_per_token
    self.head_input_dim = self.d_model * (1 if small else 4)
    self.extra = not small

    head_in = self.head_input_dim
    per_token = self.num_modes * self.step_per_token

    def make_state_head(gaussian_dim, plain_dim):
        # Single linear head; the non-gaussian variant is replicated per actor type.
        if use_gaussian:
            out_dim = gaussian_dim * per_token
        else:
            out_dim = plain_dim * per_token * NUM_TYPES
        return common_layers.build_mlps(c_in=head_in, mlp_channels=[out_dim], ret_before_act=True)

    # Both size variants share this layout; only the hidden width differs.
    self.actor_mlps_decompress = common_layers.build_mlps(
        c_in=self.d_model,
        mlp_channels=[head_in, head_in],
        ret_before_act=False,
    )
    self.position_head = make_state_head(6, 3)
    self.score_head = common_layers.build_mlps(
        c_in=head_in,
        mlp_channels=[self.num_modes * NUM_TYPES],
        ret_before_act=True,
    )

    if self.extra:
        self.velocity_head = make_state_head(4, 2)
        self.heading_head = make_state_head(2, 1)

        # TODO: This is a little weird, remove this.
        self.actor_type_head = common_layers.build_mlps(
            c_in=head_in,
            mlp_channels=[5 * self.num_modes],
            ret_before_act=True,
        )
+ # Note that there are some nan in pred_score! + + if not self.extra: + return pred_pos, pred_score + + # Get predicted heading in shape: [B, compress_T, N, num_modes, step_per_token, 1/2] + pred_heading = unwrap(self.heading_head(actor_prediction_feat), actor_valid_mask) + pred_heading = pred_heading.reshape( + B, compress_T, N, NUM_TYPES, num_modes, step_per_token, (2 if self.use_gaussian else 1) + ) + pred_heading = torch.gather( + pred_heading, + index=actor_type.reshape(B, 1, N, 1, 1, 1, 1).expand( + B, compress_T, N, 1, num_modes, step_per_token, (2 if self.use_gaussian else 1) + ), + dim=3 + ).squeeze(3) + + # Get predicted velocity in shape: [B, compress_T, N, num_modes, step_per_token, 2/4] + pred_velocity = unwrap(self.velocity_head(actor_prediction_feat), actor_valid_mask) + pred_velocity = pred_velocity.reshape( + B, compress_T, N, NUM_TYPES, num_modes, step_per_token, (4 if self.use_gaussian else 2) + ) + pred_velocity = torch.gather( + pred_velocity, + index=actor_type.reshape(B, 1, N, 1, 1, 1, 1).expand( + B, compress_T, N, 1, num_modes, step_per_token, (4 if self.use_gaussian else 2) + ), + dim=3 + ).squeeze(3) + + # Get predicted vehicle type in shape: [B, compress_T, N, step_per_token, 5] + pred_actor_type = unwrap(self.actor_type_head(actor_prediction_feat), actor_valid_mask) + pred_actor_type = pred_actor_type.reshape(B, compress_T, N, num_modes, 5) + + return pred_pos, pred_velocity, pred_heading, pred_actor_type, pred_score + + +class QueryPE(nn.Module): + def __init__(self, d_model): + super().__init__() + self.map_pe = nn.Embedding(2000, d_model) + self.actor_pe = nn.Embedding(500, d_model) + self.light_pe = nn.Embedding(500, d_model) + self.time_pe = nn.Embedding(500, d_model) + self.d_model = d_model + + max_seq_len = 2000 + # Initialize the positional encoding matrix + pos_enc = torch.zeros(max_seq_len, d_model) + + # Compute the positional encodings + for pos in range(max_seq_len): + for i in range(0, d_model, 2): + pos_enc[pos, i] = 
np.sin(pos / (10000**((2 * i) / d_model))) + if i + 1 < d_model: + pos_enc[pos, i + 1] = np.cos(pos / (10000**((2 * (i + 1)) / d_model))) + + # pos_enc = pos_enc.unsqueeze(0) # Add a batch dimension + self.register_buffer("pos_enc", pos_enc) + + def forward(self, map_token, actor_token, light_token): + B, T, N, _ = actor_token.shape + _, _, L, _ = light_token.shape + + # Apply position embeddings for map features + map_pos = torch.arange(map_token.size(1), device=map_token.device) + map_pos_emb = self.map_pe(map_pos) + map_pos_emb += self.pos_enc[:map_token.size(1)] + map_pos_emb = map_pos_emb.unsqueeze(0) + + map_token_emb = map_token + map_pos_emb + + # Apply position and time embeddings for actors and traffic lights + time_pos = torch.arange(T, device=actor_token.device) + time_pos_emb = self.time_pe(time_pos) + time_pos_emb += self.pos_enc[:T] + time_pos_emb = time_pos_emb.reshape(1, T, 1, self.d_model) + + # Actors (cars) + actor_pos = torch.arange(N, device=actor_token.device) + actor_pos_emb = self.actor_pe(actor_pos) + actor_pos_emb += self.pos_enc[:N] + actor_pos_emb = actor_pos_emb.reshape(1, 1, N, self.d_model) + actor_token_emb = actor_token + actor_pos_emb + time_pos_emb + + # Traffic lights + if L > 0: + light_pos = torch.arange(L, device=light_token.device) + light_pos_emb = self.light_pe(light_pos) + light_pos_emb += self.pos_enc[:L] + light_pos_emb = light_pos_emb.reshape(1, 1, L, self.d_model) + light_token_emb = light_token + light_pos_emb + time_pos_emb + else: + light_token_emb = light_token + + return map_token_emb, actor_token_emb, light_token_emb + + +class MotionLM(nn.Module): + def __init__(self, config): + super().__init__() + + # Set up config + self.model_cfg = config + self.d_model = self.model_cfg.D_MODEL + self.num_decoder_layers = self.model_cfg.NUM_ATTN_LAYERS + self.compress_step = self.model_cfg.INPUT_STEP_PER_TOKEN + self.step_per_token = self.model_cfg.PREDICT_STEP_PER_TOKEN + # self.num_modes = self.config.NUM_MOTION_MODES + # 
use_gaussian = self.config.USE_GAUSSIAN + # hidden_size = self.d_model + # self.discrete = self.config.DISCRETE + + # ========== A general encoder of everything: map features & obj features ========== + + # Allow three types of starting token, so that user can control which actor type to create. + # self.start_token = nn.Embedding(5, self.d_model) + # self.start_token_pe = nn.Embedding(500, self.d_model) + # self.start_token_pe_for_empty = nn.Embedding(1, self.d_model) + + self.map_polyline_encoder = polyline_encoder.PointNetPolylineEncoder( + in_channels=MAP_FEATURE_STATE_DIM, hidden_dim=64, num_layers=5, num_pre_layers=3, out_channels=self.d_model + ) + self.actor_mlps = common_layers.build_mlps( + c_in=AGENT_STATE_DIM, + mlp_channels=[self.d_model] * 2, + ret_before_act=True, + ) + self.light_mlps = common_layers.build_mlps( + c_in=TRAFFIC_LIGHT_STATE_DIM, mlp_channels=[self.d_model] * 2, ret_before_act=True, without_norm=True + ) + self.actor_mlps_compress = common_layers.build_mlps( + c_in=self.d_model * self.compress_step, + mlp_channels=[self.d_model * 4, self.d_model * 4, self.d_model], + ret_before_act=True, + ) + self.light_mlps_compress = common_layers.build_mlps( + c_in=self.d_model * self.compress_step, + mlp_channels=[self.d_model * 2, self.d_model * 2, self.d_model], + ret_before_act=True, + without_norm=True + ) + self.decoder_tokenizer = common_layers.build_mlps( + c_in=self.d_model, + mlp_channels=[self.d_model, self.d_model], + ret_before_act=True, + ) + self.decoder_layers = self.build_transformer_decoder( + d_model=self.d_model, # 256 + nhead=self.model_cfg.NUM_ATTN_HEAD, + dropout=self.model_cfg.DROPOUT_OF_ATTN, + num_decoder_layers=self.model_cfg.NUM_ATTN_LAYERS, + use_local_attn=True + ) + + self.pe = QueryPE(d_model=self.d_model) + + self.actor_predictor = ActorPredictor( + d_model=self.d_model, + num_modes=self.num_modes, + step_per_token=self.step_per_token, + use_gaussian=use_gaussian, + small=False + ) + + self.traffic_light_predictor = 
common_layers.build_mlps( + c_in=self.d_model, + mlp_channels=[self.d_model, self.d_model * 2, TRAFFIC_LIGHT_PREDICT_DIM * self.step_per_token], + ret_before_act=True, + without_norm=True + ) + + def build_transformer_decoder(self, d_model, nhead, dropout=0.1, num_decoder_layers=1, use_local_attn=False): + decoder_layer_1 = our_decoder_layer.TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + relative_pe=self.model_cfg.RELATIVE_PE, + dim_feedforward=d_model * 4, + # dim_feedforward=d_model, + dropout=dropout, + activation="gelu", + normalize_before=False, + keep_query_pos=False, # Not using Query Position (but use for the first decoder layer) + rm_self_attn_decoder=False, + use_local_attn=use_local_attn, + is_first=True + ) + decoder_layer = our_decoder_layer.TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + relative_pe=self.model_cfg.RELATIVE_PE, + dim_feedforward=d_model * 4, + # dim_feedforward=d_model, + dropout=dropout, + activation="gelu", + normalize_before=False, + keep_query_pos=False, # Not using Query Position (but use for the first decoder layer) + rm_self_attn_decoder=False, + use_local_attn=use_local_attn, + ) + decoder_layers = nn.ModuleList( + [decoder_layer_1] + [copy.deepcopy(decoder_layer) for _ in range(num_decoder_layers - 1)] + ) + return decoder_layers + + def apply_transformer_decoder( + self, + *, + all_token, + all_token_valid_mask, + all_position, + all_token_valid_mask_without_map, + # all_token_valid_mask_without_start_token, + all_position_without_map, + actor_token, + actor_position, + actor_valid_mask, + stacked_actor_valid_mask, + # start_token_pe, + # st_drop_mask, + # start_token_valid_mask, + traffic_light_token, + traffic_light_position, + traffic_light_valid_mask, + map_token, + map_position, + map_valid_mask, + anchor_heading, + anchor_velocity, + anchor_position, + query_cache=None, + in_evaluation=False, + ): + """ + Notations: + + B - Batch size + T - Total time steps of the trajectory, does not consider 
which are history and which are for prediction. + M - Number of map features + N - Number of actors + L - number of traffic lights + """ + B, num_total_tokens, d_model = all_token.shape + _, compress_T, N, _ = actor_token.shape + _, _, L, _ = traffic_light_token.shape + _, M, _ = map_token.shape + + T = compress_T + + H = self.model_cfg.HISTORY_TOKENS + + num_of_neighbors_traffic_light = 10 + num_of_neighbors_actor = 20 + num_of_neighbors_map = 128 + num_of_neighbors_map_for_map = 128 + + # New output size: [num valid in all_position_without_map, K] + # Fall in range [0, N] + neighbor_index_actor = search_k_nearest_object_indices( + ego_position_full=all_position_without_map, # [B, T, N+L, 2] + ego_valid_mask=all_token_valid_mask_without_map, + neighbor_position_full=actor_position, # [B, T, N, 2] + neighbor_valid_mask=actor_valid_mask, + num_neighbors=num_of_neighbors_actor + ) + + # New output size: [B, T, N+L, K] + # Fall in range [0, L] + if L == 0: + neighbor_index_traffic_light = neighbor_index_actor.new_zeros([B, T, N + L, num_of_neighbors_traffic_light]) + neighbor_index_traffic_light.fill_(-1) + else: + neighbor_index_traffic_light = search_k_nearest_object_indices( + ego_position_full=all_position_without_map, # [B, T, N+L, 2] + ego_valid_mask=all_token_valid_mask_without_map, + neighbor_position_full=traffic_light_position, # [B, T, L, 2] + neighbor_valid_mask=traffic_light_valid_mask, + num_neighbors=num_of_neighbors_traffic_light + ) + neighbor_index_traffic_light[neighbor_index_traffic_light != -1] += N + + # Fall in range [0, num valid map feat] + neighbor_index_map = search_k_nearest_map_feature_indicies( + ego_position_full=all_position_without_map, # [B, T, N+L, 2] + ego_valid_mask=all_token_valid_mask_without_map, + neighbor_position_full=map_position, # [B, M, 2] + neighbor_valid_mask=map_valid_mask, # [B, M] + num_neighbors=num_of_neighbors_map + ) + + # Fall in range [0, num valid map feat] + map_neighbor = 
search_k_nearest_map_feature_indicies_for_map( + ego_position_full=map_position, # [B, M, 2] + ego_valid_mask=map_valid_mask, + num_neighbors=num_of_neighbors_map_for_map + ) + + # PZH NOTE: At this moment, we already collected 4 set of neighbor indices. + # Each of them are fallen into different domain, which is: the number of valid "neighbors" at given time step. + # Now, we need to convert the domain from "per-step neighbors" to the "whole input sequence". + + # Number of valid map feats in each batch + num_valid_maps = map_valid_mask.sum(-1, keepdims=True).repeat(1, T) # [B, T] + num_valid_actors = F.pad(actor_valid_mask.sum(-1).cumsum(-1), (1, 0)) # [B, T+1] + num_valid_lights = F.pad(traffic_light_valid_mask.sum(-1).cumsum(-1), (1, 0)) # [B, T+1] + + # ST: Offset the number of valid start token + # num_valid_start_tokens = start_token_valid_mask.sum(-1, keepdims=True).repeat(1, T) + + # We can build a mapping in shape [B, T, N+L] mapping every box-index to flatten-index. + mapping_to_flatten = num_valid_maps.new_zeros([B, T, N + L]) + mapping_to_flatten += (num_valid_maps + num_valid_actors[:, :T] + num_valid_lights[:, :T])[..., None] + # Till now, the row (b, t) is filled with the "number of tokens" before (b, t). + # But for the token in the row [b, t, :], we still need to offset them one by one. 
+ + # Now add the "in row offset": + mapping_to_flatten[..., :N] += F.pad(actor_valid_mask.cumsum(-1), (1, 0))[..., :N] + + mapping_to_flatten[..., :N][~actor_valid_mask] = -1 + mapping_to_flatten[..., + N:] += actor_valid_mask.sum(-1)[..., + None] + F.pad(traffic_light_valid_mask.cumsum(-1), + (1, 0))[..., :L] + mapping_to_flatten[..., N:][~traffic_light_valid_mask] = -1 + + neighbor_index_object = torch.cat([neighbor_index_actor, neighbor_index_traffic_light], dim=-1) + + max_token_to_attend = H * neighbor_index_object.shape[-1] + + neighbor_index_per_obj = neighbor_index_object.new_empty([B, T, N + L, max_token_to_attend]) + neighbor_index_per_obj.fill_(-1) + history_ind = torch.arange(H) + for t in range(T): + selected_t_dim = history_ind[:t + 1] + max(0, t - H) + selected_length = min(t + 1, H) + + # Shape: [B, selected_length, N+L, N+L] + key = mapping_to_flatten[:, selected_t_dim].reshape(B, selected_length, 1, N + L).repeat(1, 1, N + L, 1) + + query = neighbor_index_object[:, t].reshape(B, 1, N + L, -1).repeat(1, selected_length, 1, 1).long() + + query_mask = query == -1 + + query[query_mask] = 0 + + flatten_index = torch.gather(key, index=query, dim=3) + + flatten_index[query_mask] = -1 + + # [B, selected_length, N, K] -> [B, N, K * selected_length] + flatten_index = flatten_index.permute(0, 2, 3, 1).flatten(2, 3).reshape(B, N + L, -1) + + neighbor_index_per_obj[:, t, :, :flatten_index.shape[-1]] = flatten_index + + assert (neighbor_index_map.max(-1)[0] < map_valid_mask.sum(-1)[:, None, None]).all() + assert neighbor_index_per_obj[neighbor_index_per_obj != -1].min() >= map_valid_mask.sum(-1).min() + assert (map_neighbor.max(-1)[0].max(-1)[0] < map_valid_mask.sum(-1)).all() + + # -> [B, T, N, K_map + K_obj] + neighbor_index_per_obj = torch.cat([neighbor_index_map, neighbor_index_per_obj], dim=-1) + + max_token_to_attend = max(max_token_to_attend, neighbor_index_per_obj.shape[-1]) + max_token_to_attend = max(max_token_to_attend, N + num_of_neighbors_map) + + 
neighbor_index_per_obj = neighbor_index_per_obj.reshape(B, T * (N + L), -1) + if neighbor_index_per_obj.shape[-1] < max_token_to_attend: + neighbor_index_per_obj = F.pad( + neighbor_index_per_obj, (0, max_token_to_attend - neighbor_index_per_obj.shape[-1]), + mode="constant", + value=-1 + ) + + if map_neighbor.shape[-1] < max_token_to_attend: + # now neighbor_index_per_step is in shape [B, T*(N+L), sum of K] + map_neighbor = F.pad( + map_neighbor, (0, max_token_to_attend - map_neighbor.shape[-1]), mode="constant", value=-1 + ) + + # valid_map_counts = map_valid_mask.sum(-1).cpu().numpy() + # neighbor_index_start_token = np.zeros([B, N, max_token_to_attend], dtype=int) - 1 + + # causal_mask = np.triu(np.full((N, N), -1), k=1) + # causal_indices = np.tril(np.arange(N)) + # causal_mask += causal_indices + # causal_mask = np.repeat(causal_mask[np.newaxis, :, :], B, axis=0) + + # start_token_valid_mask_np = start_token_valid_mask.cpu().numpy() + # start_token_valid_mask_cumsum_np = F.pad(start_token_valid_mask.cumsum(-1), (1, 0)).cpu().numpy()[:, :-1] + # start_token_valid_mask_cumsum_np[~start_token_valid_mask_np] = -1 + # start_token_valid_mask_cumsum_np = np.repeat(start_token_valid_mask_cumsum_np[:, np.newaxis, :], N, axis=1) + # start_token_valid_mask_cumsum_np[~start_token_valid_mask_np] = -1 + # for i in range(B): + # m = start_token_valid_mask_cumsum_np[i] + # m[m != -1] += valid_map_counts[i] + # neighbor_index_start_token[i, :, :N] = m + # if valid_map_counts[i] > max_token_to_attend - N: + # + # tar = map_valid_mask[i].nonzero().cpu().numpy().reshape(-1) + # for j in range(N): + # neighbor_index_start_token[i, j, N:] = np.random.choice(tar, size=(max_token_to_attend - N), replace=False) + # + # else: + # neighbor_index_start_token[i, :, N:N+valid_map_counts[i]] = map_valid_mask[i].nonzero().cpu().numpy()[None, :, 0] + # + # neighbor_index_start_token = torch.from_numpy(neighbor_index_start_token).to(map_neighbor) + + # rand_map_neighbor = torch.rand((B, N, 
max_token_to_attend), device=map_valid_mask.device) + # max_map_feat = map_valid_mask.sum(-1)[..., None, None].expand(B, N, max_token_to_attend) + # rand_map_neighbor = (rand_map_neighbor * max_map_feat).floor().int() + # rand_map_neighbor = torch.minimum(rand_map_neighbor, max_map_feat - 1).clamp(0).int() + # assert (rand_map_neighbor.max(-1)[0].max(-1)[0] < map_valid_mask.sum(-1)).all() + + all_neighbor_index_full = torch.cat([map_neighbor, neighbor_index_per_obj], dim=1) # [B, M+N+T*(N+L), sum of K] + + assert all_neighbor_index_full.max() < M + N + T * (N + L) + assert (all_neighbor_index_full.max(dim=-1)[0].max(dim=-1)[0] < all_token_valid_mask.sum(-1)).all() + + batch_index = torch.arange(0, B, device=all_token.device, dtype=torch.int) # [B,] + batch_index = batch_index.reshape(B, 1, 1) # [B, 1, 1] + batch_index = batch_index.repeat(1, T, N + L) # [B, T, max_ego_objects] + batch_index = batch_index.reshape(B, -1) # [B, T*(N+L)] + + batch_index_map = torch.arange(0, B, device=all_token.device, dtype=torch.int).reshape(B, 1) + batch_index_map = batch_index_map.repeat(1, M) # [B, M] + + # batch_index_start_token = torch.arange(0, B, device=all_token.device, dtype=torch.int).reshape(B, 1) + # batch_index_start_token = batch_index_start_token.repeat(1, N) # [B, M] + + batch_index = torch.cat([batch_index_map, batch_index], dim=1) # [B, M+T*(N+L)] + + _, num_keys, _ = all_token.shape + query_sine_embed = position_encoding_utils.gen_sineembed_for_position(all_position[..., :2], hidden_dim=d_model) + + # query_sine_embed[:, M: M + N][st_drop_mask] = self.start_token_pe_for_empty( + # batch_index.new_zeros(query_sine_embed[:, M: M + N][st_drop_mask].shape[:1]) + # ) + + # query_sine_embed = torch.cat([ + # query_sine_embed[:, :M], + # # start_token_pe, + # query_sine_embed[:, M:] + # ], dim=1) + + kv_pos_embed_stack = query_sine_embed[all_token_valid_mask] # [num valid tokens, DIM] + + key_batch_cnt = all_token_valid_mask.sum(-1).int() + + kv_pos_raw = 
all_position[..., :2][all_token_valid_mask] + + if query_cache: + last_token_num = query_cache["last_token_num"] # [B, num total tokens] + + query_cache_list = [] + if in_evaluation: + # pre-allocate space + for i in range(self.num_decoder_layers): + v = torch.zeros_like(all_token) + if query_cache: + v[:, :last_token_num] = query_cache[f"query_cache_{i}"] + query_cache_list.append(v) + + prediction_list = [] + + # all_position_with_start_token = torch.cat([ + # all_position[:, :M], + # all_position.new_zeros([B, N, 2]), + # all_position[:, M:] + # ], dim=1) + + if query_cache: # Need to have full shape query_feature since we need to slice it later. + assert in_evaluation + diff_all_token_valid_mask = all_token_valid_mask[:, last_token_num:] + query_feature = all_token + all_neighbor_index = all_neighbor_index_full[:, last_token_num:][diff_all_token_valid_mask] + index_pair_batch = batch_index[:, last_token_num:][diff_all_token_valid_mask] + query_pos = all_position[:, last_token_num:][diff_all_token_valid_mask] + query_sine_embed_stack = query_sine_embed[:, last_token_num:][diff_all_token_valid_mask] + + else: + query_feature = all_token[all_token_valid_mask] + # they are share the same first dim size = num valid objects (across time steps and batch) + all_neighbor_index = all_neighbor_index_full[all_token_valid_mask] + # (num valid) -> the batch index of each valid object + index_pair_batch = batch_index[all_token_valid_mask] + assert len(all_neighbor_index) == len(index_pair_batch) + assert (all_neighbor_index_full.max(-1)[0].max(-1)[0] < key_batch_cnt).all() + query_sine_embed_stack = query_sine_embed[all_token_valid_mask] + query_pos = all_position[all_token_valid_mask] + + for layer_idx in range(self.num_decoder_layers): + + if query_cache: + kv_feature_stack = query_feature[all_token_valid_mask] + query_feature = query_feature[:, last_token_num:][diff_all_token_valid_mask] + else: + kv_feature_stack = query_feature + + query_feature = 
self.decoder_layers[layer_idx]( + tgt=query_feature, + # tgt_valid_mask=diff_all_token_valid_mask, + query_pos=query_pos, + query_sine_embed=query_sine_embed_stack, + memory=kv_feature_stack, + memory_pos_emb=kv_pos_embed_stack, + memory_pos=kv_pos_raw, + is_first=(layer_idx == 0), + key_batch_cnt=key_batch_cnt, + index_pair=all_neighbor_index, + index_pair_batch=index_pair_batch, + ) + assert query_feature.ndim == 2 + + # If using query_cache from previous forward pass, we now need to fill the new query into existing queries. + # This should be done even if we need to add future embedding into actors' tokens. + # Everything in the query cache is not added with future embedding. + if query_cache: + query_cache_list[layer_idx][:, last_token_num:][diff_all_token_valid_mask] = query_feature.clone() + query_feature = query_cache_list[layer_idx] + else: + if in_evaluation: + query_cache_list[layer_idx][all_token_valid_mask] = query_feature.clone() + + if self.model_cfg.LOSS_EACH_LAYER: + if query_cache: + all_output_tokens = query_feature + else: + all_output_tokens = unwrap(query_feature, all_token_valid_mask) + object_output_tokens = all_output_tokens[:, M + N:] + object_output_tokens = object_output_tokens.reshape(B, compress_T, N + L, self.d_model) + actor_output_tokens = object_output_tokens[:, :, :N] + ret = self.get_prediction_for_actor( + anchor_heading=anchor_heading, + anchor_velocity=anchor_velocity, + anchor_position=anchor_position, + actor_output_tokens=actor_output_tokens, + actor_valid_mask=actor_valid_mask, + stacked_actor_valid_mask=stacked_actor_valid_mask, + in_evaluation=in_evaluation if (layer_idx == self.num_decoder_layers - 1) else False, + layer_index=layer_idx, + actor_type=actor_type + ) + prediction_list.append(ret) + + if layer_idx < self.num_decoder_layers - 1: + embedding = self.get_internal_future_embedding(ret, actor_valid_mask, layer_index=layer_idx) + actor_output_tokens += embedding + + all_output_tokens = torch.cat( + [ + 
all_output_tokens[:, :M + N], + torch.cat([actor_output_tokens, object_output_tokens[:, :, N:]], dim=2).flatten(1, 2) + ], + dim=1 + ) + if query_cache: + query_feature = all_output_tokens + else: + query_feature = all_output_tokens[all_token_valid_mask] + + if query_feature.ndim == 3: + ret = query_feature + else: + ret = unwrap(query_feature, all_token_valid_mask) + + ret_dict = {"last_token_num": all_token_valid_mask.shape[1]} + for i in range(len(query_cache_list)): + ret_dict[f"query_cache_{i}"] = query_cache_list[i] + return ret, ret_dict, prediction_list + + def build_our_predict_head(self, hidden_size, num_modes, actor_state_dim, light_state_dim): + actor_predict_heads = common_layers.build_mlps( + c_in=self.d_model, + mlp_channels=[hidden_size, hidden_size, actor_state_dim * num_modes], + ret_before_act=True, + ) + light_predict_head = common_layers.build_mlps( + c_in=self.d_model, + mlp_channels=[hidden_size, hidden_size, light_state_dim], + ret_before_act=True, + ) + return actor_predict_heads, light_predict_head + + def get_position_loss(self, data_dict, model_output, step, start_time, future_end): + B, T, N, num_modes, _ = model_output["sampled_position"].shape + predicted_traj = model_output["sampled_position"][:, start_time:future_end] + gt_traj = data_dict["encoder/agent_position"][:, start_time:future_end].unsqueeze(3).expand( + B, future_end - start_time, N, num_modes, 3 + ) + assert predicted_traj.shape == gt_traj.shape + + original_mask = data_dict["encoder/agent_valid_mask"][:, start_time:future_end] + mask = original_mask.unsqueeze(3).expand(B, future_end - start_time, N, num_modes) + assert mask.shape == predicted_traj.shape[:4] + + assert gt_traj.shape == predicted_traj.shape + diff = (gt_traj - predicted_traj).norm(dim=-1) # [B, T, N, num_modes] + + mode_diff = (diff.sum(1) / mask.sum(1).clamp(1)) # [B, N, num_modes] + + gt_selected = (diff * mask).sum(1).argmin(-1) + + selected = model_output["log_probability"][:, 
start_time:future_end].sum(1).argmax(-1) # [B, N] + + best_diff = torch.gather(mode_diff, index=selected.unsqueeze(-1), dim=2).squeeze(-1) # [B, N] + + acc = (gt_selected == selected)[original_mask.any(1)].float().mean().item() + + return { + f"eval/best_diff_{step}": best_diff.mean().item(), + f"eval/avg_diff_{step}": diff[mask].mean().item(), + f"eval/score_acc_{step}": acc + } + + def get_actor_loss( + self, + *, + stacked_actor_feat, + stacked_actor_valid_mask, + gt_position, + gt_dict, + forward_ret_dict, + anchor_velocity, + anchor_position, + anchor_heading, + layer_index=None + ): + + stacked_actor_feat = stacked_actor_feat + gt_position = gt_position + + B, compress_T, N, num_modes, step_per_token, _ = forward_ret_dict["position_logit"].shape + + if num_modes > 1: + nearest_predicted_dict, per_step_distance = generate_predicted_trajectory_for_training( + gt_position, forward_ret_dict, use_ade=self.model_cfg.USE_ADE + ) + + # ========== Step 2: Compute loss using the selected trajectory ========== + pos_target = (stacked_actor_feat[:, 1:, ..., :3] * POSITION_XY_RANGE - anchor_position[:, :-1]) + if self.model_cfg.RELATIVE_POSITION_HEADING: + pos_target = rotate( + pos_target[..., 0], pos_target[..., 1], -anchor_heading[:, :-1].squeeze(-1), z=pos_target[..., 2] + ) + stacked_actor_valid_mask = torch.logical_and(stacked_actor_valid_mask[:, 1:], stacked_actor_valid_mask[:, :-1]) + + before_count = stacked_actor_valid_mask.sum() + if self.model_cfg.USE_SLOW_MASK: + # rule out those samples where the actor is static. 
+ # current_actor_valid_mask = torch.logical_and(current_actor_valid_mask, (pos_target != 0).any(-1)) + offset_norm = pos_target.norm(dim=-1) + slow_mask = offset_norm > 0.01 + stacked_actor_valid_mask = torch.logical_and(stacked_actor_valid_mask, slow_mask) + after_count = stacked_actor_valid_mask.sum() + assert stacked_actor_valid_mask.sum() > 0, stacked_actor_valid_mask.sum() + + # ===== Position Loss ===== + if self.model_cfg.USE_NEAREST_LOSS and num_modes > 1: + position_logit = nearest_predicted_dict["nearest_position_logit"] + else: + position_logit = forward_ret_dict["position_logit"] + assert position_logit.shape[:5] == ( + B, compress_T, N, 1 if self.model_cfg.USE_NEAREST_LOSS else num_modes, step_per_token + ) + + if self.model_cfg.USE_CUMSUM: + position_logit = position_logit.cumsum(4) + + # Fast actor loss: + position_logit1 = position_logit[:, :-1][stacked_actor_valid_mask] + pos_target1 = pos_target[stacked_actor_valid_mask] + assert position_logit1.shape == pos_target1.shape + if self.model_cfg.USE_HUBER_LOSS: + pos_loss = F.huber_loss(input=position_logit1, target=pos_target1) + else: + pos_loss = F.mse_loss(input=position_logit1, target=pos_target1) + + # Static actor loss + # position_logit2 = position_logit[:, :-1][torch.logical_and(stacked_actor_valid_mask, ~slow_mask)] + # pos_target2 = pos_target[torch.logical_and(stacked_actor_valid_mask, ~slow_mask)] + # assert position_logit2.shape == pos_target2.shape + # pos_loss2 = F.huber_loss(input=position_logit2, target=pos_target2) + # + # pos_loss = pos_loss1 + pos_loss2 / 5 + + if self.model_cfg.SCORE_FORM == "class" and num_modes > 1: + score_input = forward_ret_dict["score_logit"][forward_ret_dict["compress_actor_valid_mask"]] + score_target = nearest_predicted_dict["nearest_index_target"][forward_ret_dict["compress_actor_valid_mask"]] + assert score_input.shape[:-1] == score_target.shape + score_loss = F.cross_entropy(input=score_input, target=score_target) + + elif self.model_cfg.SCORE_FORM 
== "reward" and num_modes > 1: + raise ValueError() + # with torch.no_grad(): + # gt_score = 1 / per_step_distance.clamp(1e-3, 1000).detach() + # gt_score = gt_score[forward_ret_dict["compress_actor_valid_mask"]] + # + # score_input = forward_ret_dict["score_logit"][forward_ret_dict["compress_actor_valid_mask"]] + # + # assert score_input.shape == gt_score.shape + # score_loss = F.huber_loss( + # input=score_input, + # target=gt_score + # ) + + elif num_modes == 1: + score_loss = 0.0 + + else: + raise ValueError() + + if self.model_cfg.LOSS_EACH_LAYER and layer_index < self.num_decoder_layers - 1: + # return since we don't regress other values. + return dict( + position_loss=pos_loss, + velocity_loss=0.0, + heading_loss=0.0, + actor_type_loss=0.0, + score_loss=score_loss, + slow_mask_remove=(before_count - after_count) / before_count, + ) + + if self.model_cfg.USE_NEAREST_LOSS and num_modes > 1: + velocity_logit = nearest_predicted_dict["nearest_velocity_logit"] + else: + velocity_logit = forward_ret_dict["velocity_logit"] + assert velocity_logit.ndim == 6 + if self.model_cfg.USE_CUMSUM: + velocity_logit = velocity_logit.cumsum(4) + + if self.model_cfg.RELATIVE_VELOCITY: + # Relative velocity + vel_target = stacked_actor_feat[:, 1:, ..., 4:6] * VELOCITY_XY_RANGE - anchor_velocity[:, :-1] + + if self.model_cfg.RELATIVE_POSITION_HEADING: + vel_target = rotate(vel_target[..., 0], vel_target[..., 1], -anchor_heading[:, :-1].squeeze(-1)) + + else: + vel_target = stacked_actor_feat[:, 1:, ..., 4:6] + + velocity_logit = velocity_logit[:, :-1][stacked_actor_valid_mask] + vel_target = vel_target[stacked_actor_valid_mask] + + assert velocity_logit.shape == vel_target.shape + if self.model_cfg.USE_HUBER_LOSS: + vel_loss = F.huber_loss(input=velocity_logit, target=vel_target) + else: + vel_loss = F.mse_loss(input=velocity_logit, target=vel_target) + + if self.model_cfg.RELATIVE_HEADING: + # Absolute heading + heading_target = wrap_to_pi(stacked_actor_feat[:, 1:, ..., 3:4] * 
HEADING_RANGE - anchor_heading[:, :-1]) + else: + heading_target = wrap_to_pi(stacked_actor_feat[:, 1:, ..., 3:4] * HEADING_RANGE) + + if self.model_cfg.USE_NEAREST_LOSS and num_modes > 1: + heading_logit = nearest_predicted_dict["nearest_heading_logit"][:, :-1] + else: + heading_logit = forward_ret_dict["heading_logit"][:, :-1] + assert heading_logit.shape == heading_target.shape + assert heading_logit.ndim == 6 + if self.model_cfg.USE_CUMSUM: + heading_logit = heading_logit.cumsum(4) + + heading_logit = heading_logit[stacked_actor_valid_mask] + heading_target = heading_target[stacked_actor_valid_mask] + + assert heading_logit.shape == heading_target.shape, (heading_logit.shape, heading_target.shape) + if self.model_cfg.USE_HUBER_LOSS: + head_loss = F.huber_loss(input=wrap_to_pi(heading_logit - heading_target), target=heading_target * 0) + else: + head_loss = F.mse_loss(input=wrap_to_pi(heading_logit - heading_target), target=heading_target * 0) + + type_mask = forward_ret_dict["compress_actor_valid_mask"].unsqueeze(-1).expand(B, compress_T, N, num_modes) + + actor_type_logit = forward_ret_dict["actor_type_logit"] # [B, compress_T, N, num_modes, 5] + actor_type_logit = actor_type_logit[type_mask] + + gt_actor_type = gt_dict["decoder/actor_type"].reshape(B, 1, N, 1).expand(B, compress_T, N, num_modes) + gt_actor_type = gt_actor_type[type_mask] + + type_loss = F.cross_entropy(input=actor_type_logit, target=gt_actor_type) + + return dict( + position_loss=pos_loss, + velocity_loss=vel_loss, + heading_loss=head_loss, + actor_type_loss=type_loss, + score_loss=score_loss, + slow_mask_remove=(before_count - after_count) / before_count, + ) + + def get_loss(self, data_dict, gt_dict, forward_ret_dict, tb_pre_tag=''): + anchor_position = forward_ret_dict["anchor_position"] + anchor_heading = forward_ret_dict["anchor_heading"] + anchor_velocity = forward_ret_dict["anchor_velocity"] + _, _, L, _ = data_dict["encoder/traffic_light_feature"].shape + B, compress_T, N, num_modes, 
step_per_token, _ = forward_ret_dict["sampled_position"].shape + + weight_pos = self.model_cfg.LOSS_WEIGHTS.get('position', 1.0) + weight_vel = self.model_cfg.LOSS_WEIGHTS.get('velocity', 0.2) + weight_heading = self.model_cfg.LOSS_WEIGHTS.get('heading', 0.5) + weight_type = self.model_cfg.LOSS_WEIGHTS.get('actor_type', 0.5) + weight_light = self.model_cfg.LOSS_WEIGHTS.get('traffic_light_state', 0.5) + weight_score = self.model_cfg.LOSS_WEIGHTS.get('score', 1.0) + weight_token = self.model_cfg.LOSS_WEIGHTS.get('token', 1.0) + weight_start = self.model_cfg.LOSS_WEIGHTS.get('start', 1.0) + + total_loss = 0.0 + layer_loss_list = [] + + stacked_actor_feat = roll_and_stack( + data_dict["encoder/agent_feature"], + step_per_token=self.step_per_token, + num_modes=1 if self.model_cfg.USE_NEAREST_LOSS else num_modes, + compress_T=compress_T, + compress_step=self.compress_step + ) + stacked_actor_valid_mask = forward_ret_dict["stacked_actor_valid_mask"] + + # [B, T, N, num_modes, 2] + gt_position = roll_and_stack( + data_dict["encoder/agent_position"], + compress_T=compress_T, + step_per_token=step_per_token, + num_modes=num_modes, + compress_step=self.compress_step + ) + + # ========== Loss for start token predictor ========== + init_actor_loss = torch.zeros(1) + init_actor_loss_dict = {} + if self.model_cfg.ENABLE_START_TOKEN: + raise ValueError() + # init_actor_loss_dict = self.get_init_actor_loss( + # data_dict=data_dict, + # stacked_actor_valid_mask=stacked_actor_valid_mask[:, 0, :, 0], + # forward_ret_dict=forward_ret_dict, + # ) + # init_actor_loss = ( + # init_actor_loss_dict["init_pos_loss"] + + # init_actor_loss_dict["init_vel_loss"] + + # init_actor_loss_dict["init_head_loss"] + + # init_actor_loss_dict["init_size_loss"] + + # # init_actor_loss_dict["init_score_loss"] + + # init_actor_loss_dict["init_map_feat_score_loss"] + # ) + # if self.config.ENABLE_START_TOKEN_ACTOR_LOSS: + # init_actor_loss += ( + # weight_pos * init_actor_loss_dict["init_position_loss"] + + 
# weight_vel * init_actor_loss_dict["init_velocity_loss"] + + # weight_heading * init_actor_loss_dict["init_heading_loss"] + # ) + # total_loss += weight_start * init_actor_loss + + # ========== Loss for actor predictor ========== + if self.model_cfg.USE_NEAREST_LOSS: + anchor_velocity = anchor_velocity[:, :, :, :1] + anchor_position = anchor_position[:, :, :, :1] + anchor_heading = anchor_heading[:, :, :, :1] + stacked_actor_valid_mask = stacked_actor_valid_mask[:, :, :, :1] + if self.model_cfg.LOSS_EACH_LAYER: + for layer_index in range(self.num_decoder_layers): + actor_loss = self.get_actor_loss( + stacked_actor_feat=stacked_actor_feat, + stacked_actor_valid_mask=stacked_actor_valid_mask, + gt_position=gt_position, + gt_dict=gt_dict, + forward_ret_dict=forward_ret_dict[f"prediction_list_{layer_index}"], + anchor_heading=anchor_heading, + anchor_velocity=anchor_velocity, + anchor_position=anchor_position, + layer_index=layer_index + ) + layer_loss = ( + weight_pos * actor_loss["position_loss"] + weight_vel * actor_loss["velocity_loss"] + + weight_heading * actor_loss["heading_loss"] + weight_type * actor_loss["actor_type_loss"] + + weight_score * actor_loss["score_loss"] + ) + total_loss += layer_loss + layer_loss_list.append(layer_loss) + else: + actor_loss = self.get_actor_loss( + stacked_actor_feat=stacked_actor_feat, + stacked_actor_valid_mask=stacked_actor_valid_mask, + gt_dict=gt_dict, + gt_position=gt_position, + forward_ret_dict=forward_ret_dict, + anchor_heading=anchor_heading, + anchor_velocity=anchor_velocity, + anchor_position=anchor_position, + ) + total_loss += ( + weight_pos * actor_loss["position_loss"] + weight_vel * actor_loss["velocity_loss"] + + weight_heading * actor_loss["heading_loss"] + weight_type * actor_loss["actor_type_loss"] + + weight_score * actor_loss["score_loss"] + ) + + # ========== Loss for traffic light predictor ========== + # TODO: Add traffic light loss back. 
+ # if L > 0: + # light_gt = roll_and_stack( + # gt_dict["traffic_light_state"].unsqueeze(-1), + # step_per_token=self.step_per_token, + # num_modes=1, + # compress_T=compress_T, + # compress_step=self.compress_step + # ) + # light_mask = roll_and_stack_for_mask( + # data_dict["encoder/traffic_light_valid_mask"], + # step_per_token=self.step_per_token, + # num_modes=1, + # compress_T=compress_T, + # compress_step=self.compress_step + # ) + # light_gt = light_gt.squeeze(3) # Squeeze the "modes" dim since we don't output multi-mode traffic lights. + # light_mask = light_mask.squeeze(3) + # compress_traffic_light_mask = forward_ret_dict["compress_traffic_light_mask"] + # compress_traffic_light_mask = compress_traffic_light_mask.unsqueeze(-1).expand( + # B, compress_T, L, step_per_token + # ) + # light_mask = torch.logical_and(light_mask, compress_traffic_light_mask) + # light_gt = light_gt.squeeze(-1) + # light_input = forward_ret_dict["traffic_light_state_logit"] + # light_input = light_input[light_mask] + # light_gt = light_gt[light_mask] + # light_loss = F.cross_entropy( + # input=light_input, + # target=light_gt + # ) + # else: + light_loss = 0.0 + + # ========== Loss for token-evolution ========== + token_loss = 0.0 + if self.model_cfg.TOKEN_EVOLUTION: + token_mask = forward_ret_dict["input_token_valid_mask"][:, 1:] + out_t = forward_ret_dict["output_token"][:, :-1] + next_in_t = forward_ret_dict["input_token"].detach()[:, 1:] + token_loss = F.mse_loss(input=out_t[token_mask], target=next_in_t[token_mask]) + + total_loss += weight_token * token_loss + weight_light * light_loss + + tb_dict = {k: v.mean().item() if torch.is_tensor(v) else float(v) for k, v in actor_loss.items()} + tb_dict.update( + total_loss=total_loss.item(), + token_loss=token_loss.item() if self.model_cfg.TOKEN_EVOLUTION else float("nan"), + traffic_light_loss=light_loss.item() if isinstance(light_loss, torch.Tensor) else float("nan"), + init_actor_loss=init_actor_loss.item(), + 
    def forward(self, input_dict):
        """Tokenize map / actor / traffic-light inputs, run the decoder and build the prediction dict.

        Keys read from ``input_dict``:
            - ``"in_evaluation"`` (optional, defaults to False)
            - ``"encoder/agent_feature"``, ``"encoder/agent_valid_mask"``, ``"encoder/agent_position"``
            - ``"encoder/map_feature"``, ``"encoder/map_feature_valid_mask"``, ``"encoder/map_position"``
            - ``"encoder/traffic_light_feature"``, ``"encoder/traffic_light_position"``,
              ``"encoder/traffic_light_valid_mask"``
            - ``"map_token"`` (optional; when present the map polyline encoder is skipped)
            - ``"query_cache"`` (optional; forwarded to ``apply_transformer_decoder``)
            - ``"decoder/actor_type"`` (only read when ``model_cfg.LOSS_EACH_LAYER`` is falsy)

        The time axis ``T`` of the actor tensors must satisfy ``(T - 1) % self.compress_step == 0``;
        every ``compress_step`` consecutive steps are folded into one token, giving ``compress_T`` tokens
        per actor / traffic light.

        Returns:
            dict: the actor prediction dict, updated with the traffic-light prediction dict, plus
            ``"encoder/map_valid_mask"``, ``"compress_actor_valid_mask"`` and ``"compress_light_mask"``.
            With ``model_cfg.TOKEN_EVOLUTION`` it also carries ``"input_token"``,
            ``"input_token_valid_mask"`` and ``"output_token"``; in evaluation it also carries
            ``"map_token"`` and ``"query_cache"``.
        """
        in_evaluation = input_dict.get("in_evaluation", False)

        actor_feature = input_dict["encoder/agent_feature"]
        actor_valid_mask = input_dict["encoder/agent_valid_mask"]
        actor_position = input_dict["encoder/agent_position"]
        B, T, N, D_actor = actor_feature.shape
        assert actor_feature.shape[:3] == actor_position.shape[:3]

        # The first (T - 1) steps must split evenly into groups of `compress_step`.
        assert (T - 1) % self.compress_step == 0
        compress_T = (T - 1) // self.compress_step

        map_feature = input_dict["encoder/map_feature"]
        map_valid_mask = input_dict["encoder/map_feature_valid_mask"]
        map_position = input_dict["encoder/map_position"]
        _, M, num_vector, D_vector = map_feature.shape

        traffic_light_feature = input_dict["encoder/traffic_light_feature"]
        traffic_light_position = input_dict["encoder/traffic_light_position"]
        traffic_light_valid_mask = input_dict["encoder/traffic_light_valid_mask"]
        _, _, L, D_light = traffic_light_feature.shape

        # ========== Tokenize all objects (actor & map feat) ==========
        # [B, M, token dim]
        if "map_token" in input_dict:
            # Reuse the map tokens handed back by a previous evaluation call.
            map_token = input_dict["map_token"]
        else:
            map_token = self.map_polyline_encoder(map_feature, map_valid_mask)

        # [B, M]
        # A map element is valid when at least one entry along its last mask axis is set.
        map_token_valid_mask = map_valid_mask.sum(axis=-1) != 0

        actor_token = unwrap(self.actor_mlps(actor_feature[actor_valid_mask]), actor_valid_mask)

        # ===== Code fragment that stack tokens from [B, T, N, D] to [B, T/5, N, 5*D] =====
        token_dim = actor_token.shape[-1]
        # [B, T, N, token dim] -> [B, N, T, token dim]
        actor_token = actor_token.permute(0, 2, 1, 3)
        # -> [B, N, compress_T, token dim*compress_step]
        actor_token = actor_token[:, :, :compress_T *
                                  self.compress_step].reshape(B, N, compress_T, self.compress_step * token_dim)
        # -> [B, compress_T, N, token_dim*compress_step]
        actor_token = actor_token.permute(0, 2, 1, 3)

        decompress_actor_valid_mask = actor_valid_mask.clone()  # [:, :compress_T * self.compress_step]

        # [B, T, N] -> [B, compress_T, compress_step, N] -> [B, compress_T, N]
        # A compressed token is valid when any of its `compress_step` underlying steps is valid.
        compress_actor_valid_mask = actor_valid_mask[:, :compress_T *
                                                     self.compress_step].reshape(B, compress_T, self.compress_step,
                                                                                 N).any(dim=2).clone()

        # -> [B, compress_T, N, token dim]
        actor_token = unwrap(
            self.actor_mlps_compress(actor_token[compress_actor_valid_mask]), compress_actor_valid_mask
        )

        # NOTE(review): `displacement` and `last_pos` are only referenced by the commented-out
        # start-token code below; nothing live in this method reads them.
        displacement, last_pos = get_displacement(
            input_dict["encoder/agent_position"][:, :10, ..., :2], input_dict["encoder/agent_valid_mask"][:, :10]
        )
        # start_token = actor_token[:, 0]
        # start_token_valid_mask = compress_actor_valid_mask[:, 0].clone()
        # start_token_valid_mask = torch.logical_and(start_token_valid_mask, displacement.squeeze(1) > 0.01)
        # start_token_pe = self.start_token_pe(torch.arange(N, device=actor_token.device).unsqueeze(0).expand(B, N))

        # prepare the starting token. in shape [B, N, d_model]
        # actor_type_valid_mask = input_dict["decoder/actor_type"] != -1
        # empty_start_token = unwrap(
        #     self.start_token(input_dict["decoder/actor_type"][actor_type_valid_mask]),
        #     actor_type_valid_mask
        # )
        # assert start_token.shape == empty_start_token.shape

        # selected_num = torch.minimum(
        #     (start_token_valid_mask.sum(-1) * torch.rand(B, device=start_token_valid_mask.device)).int(),
        #     start_token_valid_mask.sum(-1)
        # ).clamp(1)  # [B, ]
        #
        # st_drop_mask = actor_type_valid_mask.new_zeros((B, N))
        # for i in range(B):
        #     st_valids = start_token_valid_mask[i].nonzero()[:, 0]
        #     st_ind = torch.randperm(len(st_valids))[:selected_num[i]]
        #     st_selected_ind = st_valids[st_ind]
        #     st_drop_mask[i, st_selected_ind] = 1
        # # assert (st_drop_mask.sum(-1) == selected_num).all(), (st_drop_mask.sum(-1), selected_num)
        # start_token[st_drop_mask] = empty_start_token[st_drop_mask]
        # start_token += start_token_pe

        compress_actor_position = find_last_valid_in_compress_step(
            actor_position, actor_valid_mask, compress_step=self.compress_step
        )

        # start_token_position = compress_actor_position[:, 0].clone()
        # start_token_position[st_drop_mask] = 0.0

        # st_training_mask = torch.logical_and(st_drop_mask, start_token_valid_mask)
        # st_training_mask = start_token_valid_mask

        # [B, T, L] -> [B, compress_T, L]
        compress_light_mask = traffic_light_valid_mask[:, :compress_T * self.compress_step].reshape(
            B, compress_T, self.compress_step, L
        ).any(dim=2).clone()

        # [B, T, num light, token dim]
        if L != 0:
            light_token = unwrap(
                self.light_mlps(traffic_light_feature[traffic_light_valid_mask]), traffic_light_valid_mask
            )

            # [B, T, L, token dim] -> [B, L, T, token dim]
            light_token = light_token.permute(0, 2, 1, 3)
            # [B, L, T, token dim] -> [B, L, compress_T, token dim*compress_step]
            light_token = light_token[:, :, :compress_T *
                                      self.compress_step].reshape(B, L, compress_T, self.compress_step * token_dim)
            # [B, L, compress_T, token dim*compress_step] -> [B, compress_T, L, token dim*compress_step]
            light_token = light_token.permute(0, 2, 1, 3)

            light_token = unwrap(self.light_mlps_compress(light_token[compress_light_mask]), compress_light_mask)

        else:
            # No traffic lights in this batch: keep an empty (L == 0) tensor so the concatenations below work.
            light_token = traffic_light_feature.new_zeros([B, compress_T, L, self.d_model])

        map_token, actor_token, light_token = self.pe(map_token, actor_token, light_token)

        cat_token = torch.concatenate([actor_token, light_token], dim=2)  # [B, T, N+L, token dim]

        cat_mask = torch.concatenate([compress_actor_valid_mask, compress_light_mask], dim=2)  # [B, T, N+L]
        all_token_valid_mask = torch.concatenate(
            [
                map_token_valid_mask,  # [B, M]
                # start_token_valid_mask,  # [B, N, token dim]
                cat_mask.reshape(B, compress_T * (L + N))  # [B, T*(N+L)]
            ],
            dim=1
        )
        all_token_valid_mask_without_map = cat_mask
        # all_token_valid_mask_without_start_token = torch.concatenate([
        #     map_token_valid_mask,  # [B, M, token dim]
        #     cat_mask.reshape(B, compress_T * (L + N))  # [B, T*(N+L), token dim]
        # ], dim=1)

        # all_token is in shape [B, M+T*(N+L), token dim] (map tokens first, then object tokens)
        all_token = torch.concatenate(
            [
                map_token,  # [B, M, token dim]
                # start_token,
                cat_token.reshape(B, compress_T * (L + N), -1),  # [B, T*(N+L), token dim]
            ],
            dim=1
        )

        all_token = unwrap(self.decoder_tokenizer(all_token[all_token_valid_mask]), all_token_valid_mask)

        # The traffic-light position has no time axis here; it is repeated T times before compression.
        compress_light_position = find_last_valid_in_compress_step(
            traffic_light_position.unsqueeze(1).repeat(1, T, 1, 1),
            traffic_light_valid_mask,
            compress_step=self.compress_step
        )

        # Feature channels 4:6 are rescaled by VELOCITY_XY_RANGE and channel 3 by HEADING_RANGE.
        compress_actor_velocity = find_last_valid_in_compress_step(
            input_dict["encoder/agent_feature"][..., 4:6] * VELOCITY_XY_RANGE,
            actor_valid_mask,
            compress_step=self.compress_step
        )

        compress_actor_heading = find_last_valid_in_compress_step(
            wrap_to_pi(input_dict["encoder/agent_feature"][..., 3:4] * HEADING_RANGE),
            actor_valid_mask,
            compress_step=self.compress_step
        )

        cat_position = torch.concatenate(
            [compress_actor_position[..., :2], compress_light_position[..., :2]], dim=2
        )  # [B, T, N+L, 2]
        all_position = torch.concatenate(
            [
                map_position[..., :2],  # [B, M, 2]
                # start_token_position[..., :2],
                cat_position.reshape(B, compress_T * (L + N), 2),  # [B, T*(N+L), 2]
            ],
            dim=1
        )
        all_position_without_map = cat_position

        stacked_actor_valid_mask = roll_and_stack_for_mask(
            decompress_actor_valid_mask,
            compress_T=compress_T,
            step_per_token=self.step_per_token,
            num_modes=self.num_modes,
            compress_step=self.compress_step,
            set_unknown_to_true=False
        )
        # Also mask out predictions whose input token is itself invalid.
        input_token_valid = compress_actor_valid_mask.reshape(B, compress_T, N, 1,
                                                              1).expand(*stacked_actor_valid_mask.shape)
        stacked_actor_valid_mask = torch.logical_and(stacked_actor_valid_mask, input_token_valid)

        output_tokens, query_cache, pred_list = self.apply_transformer_decoder(
            all_token=all_token,
            all_token_valid_mask=all_token_valid_mask,
            all_position=all_position,
            all_token_valid_mask_without_map=all_token_valid_mask_without_map,
            # all_token_valid_mask_without_start_token=all_token_valid_mask_without_start_token,
            all_position_without_map=all_position_without_map,
            actor_token=actor_token,
            actor_position=compress_actor_position,
            actor_valid_mask=compress_actor_valid_mask,
            stacked_actor_valid_mask=stacked_actor_valid_mask,
            # start_token_pe=start_token_pe,
            # st_drop_mask=st_drop_mask,
            # start_token_valid_mask=start_token_valid_mask,
            traffic_light_token=light_token,
            traffic_light_position=compress_light_position,
            traffic_light_valid_mask=compress_light_mask,
            map_token=map_token,
            map_position=map_position,
            map_valid_mask=map_token_valid_mask,
            query_cache=input_dict["query_cache"] if "query_cache" in input_dict else None,
            in_evaluation=in_evaluation,
            anchor_heading=compress_actor_heading,
            anchor_velocity=compress_actor_velocity,
            anchor_position=compress_actor_position
        )
        # output_tokens = all_token

        assert output_tokens.shape == (B, compress_T * (N + L) + M, self.d_model)

        # output_start_tokens = output_tokens[:, M: M + N]
        # Drop the leading M map tokens; what remains are the per-step object tokens.
        object_output_tokens = output_tokens[:, M:]
        object_output_tokens = object_output_tokens.reshape(B, compress_T, N + L, self.d_model)

        actor_output_tokens = object_output_tokens[:, :, :N]  # (B, compress_T, N, d_model)
        traffic_light_output_tokens = object_output_tokens[:, :, N:]  # (B, compress_T, L, d_model)

        # ==================== Translate output tokens to prediction ====================
        if self.model_cfg.LOSS_EACH_LAYER:
            # The decoder already produced one prediction dict per layer; the last one is the result.
            ret = pred_list[-1]
            if not in_evaluation:
                # NOTE(review): for the last index this stores `ret` inside itself
                # (pred_list[-1] is `ret`) -- confirm downstream code tolerates the self-reference.
                for layer_index in range(len(pred_list)):
                    ret[f"prediction_list_{layer_index}"] = pred_list[layer_index]

        else:
            ret = self.get_prediction_for_actor(
                anchor_position=compress_actor_position,
                anchor_velocity=compress_actor_velocity,
                anchor_heading=compress_actor_heading,
                actor_output_tokens=actor_output_tokens,
                actor_valid_mask=compress_actor_valid_mask,
                in_evaluation=in_evaluation,
                stacked_actor_valid_mask=stacked_actor_valid_mask,
                actor_type=input_dict["decoder/actor_type"]
            )

        # ret.update(self.get_prediction_for_start_token(
        #     output_start_tokens=output_start_tokens,
        #     start_token_valid_mask=start_token_valid_mask,
        # ))

        ret.update(
            self.get_prediction_for_traffic_light(
                traffic_light_output_tokens=traffic_light_output_tokens,
                traffic_light_valid_mask=compress_light_mask,
                in_evaluation=in_evaluation,
            )
        )

        # ret["st_training_mask"] = st_training_mask
        ret["encoder/map_valid_mask"] = map_token_valid_mask
        ret["compress_actor_valid_mask"] = compress_actor_valid_mask
        ret["compress_light_mask"] = compress_light_mask
        if self.model_cfg.TOKEN_EVOLUTION:
            ret["input_token"] = cat_token  # [B, compress_T, N+L, d_model]
            ret["input_token_valid_mask"] = all_token_valid_mask_without_map  # [B, comT, N+L]
            ret["output_token"] = object_output_tokens  # [B, compress_T, N+L, d_model]

        if in_evaluation:
            # Hand the map tokens and the decoder cache back so the next call can reuse them.
            ret["map_token"] = map_token
            query_cache["last_T"] = T
            ret["query_cache"] = query_cache

        return ret
    def get_prediction_for_actor(
        self,
        *,
        anchor_position,
        anchor_velocity,
        anchor_heading,
        actor_output_tokens,
        actor_valid_mask,
        stacked_actor_valid_mask,
        in_evaluation,
        actor_type,
        layer_index=None,
        force_loss_at_final_layer=False
    ):
        """Decode actor output tokens into position / velocity / heading / type predictions.

        The raw predictor outputs are returned as ``*_logit`` entries; the ``sampled_*`` entries are
        those outputs after the optional cumulative sum (``USE_CUMSUM``) and after being offset by /
        rotated with the anchors, depending on ``RELATIVE_POSITION_HEADING``, ``RELATIVE_HEADING`` and
        ``RELATIVE_VELOCITY`` in ``self.model_cfg``.

        Args:
            anchor_position: reshaped below to ``[B, compress_T, N, 1, 1, 3]``.
            anchor_velocity: reshaped below to ``[B, compress_T, N, 1, 1, 2]``.
            anchor_heading: reshaped below to ``[B, compress_T, N, 1, 1, 1]``.
            actor_output_tokens: 4-D tensor whose first three dims are ``B, compress_T, N``.
            actor_valid_mask: passed to the predictor and returned as ``"compress_actor_valid_mask"``.
            stacked_actor_valid_mask: returned untouched in the training outputs.
            in_evaluation: selects which set of keys is returned.
            actor_type: only forwarded to the predictor when ``LOSS_EACH_LAYER`` is off.
            layer_index: required when ``LOSS_EACH_LAYER`` is on, must be ``None`` otherwise.
            force_loss_at_final_layer: not supported together with ``LOSS_EACH_LAYER``.

        Returns:
            dict: always contains ``"score_logit"``, ``"sampled_position"`` and ``"position_logit"``;
            remaining keys depend on ``in_evaluation`` and ``layer_index`` (see the branches below).

        Raises:
            NotImplementedError: if ``LOSS_EACH_LAYER`` and ``force_loss_at_final_layer`` are both set.
        """
        B, compress_T, N, _ = actor_output_tokens.shape
        step_per_token = self.step_per_token
        num_modes = self.num_modes

        # Intermediate layers only predict position and score, so everything else defaults to None.
        pred_heading = pred_velocity = pred_actor_type = None
        sampled_heading = sampled_actor_type = sampled_velocity = None
        if self.model_cfg.LOSS_EACH_LAYER and (not force_loss_at_final_layer):
            assert layer_index is not None
            if layer_index == self.num_decoder_layers - 1:  # Last layer:
                pred_pos, pred_velocity, pred_heading, pred_actor_type, pred_score = self.actor_predictor[layer_index](
                    actor_output_tokens, actor_valid_mask, step_per_token
                )
            else:
                # Only return 1 mode if in internal layers
                pred_pos, pred_score = self.actor_predictor[layer_index](
                    actor_output_tokens, actor_valid_mask, step_per_token
                )

        elif self.model_cfg.LOSS_EACH_LAYER and force_loss_at_final_layer:
            raise NotImplementedError()

        else:
            assert layer_index is None
            pred_pos, pred_velocity, pred_heading, pred_actor_type, pred_score = self.actor_predictor(
                actor_output_tokens, actor_valid_mask, step_per_token, actor_type
            )

        ret = {"score_logit": pred_score}
        if in_evaluation:
            # Evaluation returns the anchors in their original (un-broadcast) shape.
            ret.update(
                {
                    # "score": F.softmax(pred_score, dim=-1),
                    "anchor_position": anchor_position,
                    "anchor_velocity": anchor_velocity,
                    "anchor_heading": anchor_heading,
                }
            )

        # Broadcast anchors from [B, compress_T, N, D] to [B, compress_T, N, num_modes, step_per_token, D]
        # (D is 3 for position, 2 for velocity, 1 for heading).
        anchor_position = anchor_position.reshape(B, compress_T, N, 1, 1,
                                                  3).repeat(1, 1, 1, num_modes, step_per_token, 1)
        anchor_velocity = anchor_velocity.reshape(B, compress_T, N, 1, 1,
                                                  2).repeat(1, 1, 1, num_modes, step_per_token, 1)
        anchor_heading = anchor_heading.reshape(B, compress_T, N, 1, 1, 1).repeat(1, 1, 1, num_modes, step_per_token, 1)

        sampled_position = pred_pos.clone()
        assert sampled_position.ndim == 6
        if self.model_cfg.USE_CUMSUM:
            # Accumulate along dim 4, the step_per_token axis of the broadcast anchors.
            sampled_position = sampled_position.cumsum(4)
        assert anchor_position.shape == sampled_position.shape
        if self.model_cfg.RELATIVE_POSITION_HEADING:
            # Rotate the predicted offsets by the anchor heading before adding the anchor position.
            sampled_position = rotate(
                sampled_position[..., 0],
                sampled_position[..., 1],
                angle=anchor_heading.squeeze(-1),
                z=sampled_position[..., 2]
            )
            sampled_position += anchor_position
        else:
            sampled_position += anchor_position

        if pred_heading is not None:
            sampled_heading = pred_heading.clone()
            assert sampled_heading.ndim == 6
            if self.model_cfg.USE_CUMSUM:
                sampled_heading = sampled_heading.cumsum(4)
            if self.model_cfg.RELATIVE_HEADING:
                sampled_heading = sampled_heading + anchor_heading
            else:
                pass
            sampled_heading = wrap_to_pi(sampled_heading)

        if pred_velocity is not None:
            sampled_velocity = pred_velocity.clone()
            assert sampled_velocity.ndim == 6
            if self.model_cfg.USE_CUMSUM:
                sampled_velocity = sampled_velocity.cumsum(4)
            if self.model_cfg.RELATIVE_VELOCITY:
                if self.model_cfg.RELATIVE_POSITION_HEADING:
                    sampled_velocity = rotate(
                        sampled_velocity[..., 0], sampled_velocity[..., 1], angle=anchor_heading.squeeze(-1)
                    )
                assert anchor_velocity.shape == sampled_velocity.shape
                sampled_velocity = sampled_velocity + anchor_velocity
            else:
                # Non-relative velocities are rescaled by VELOCITY_XY_RANGE instead of being offset.
                sampled_velocity = sampled_velocity * VELOCITY_XY_RANGE

        if pred_actor_type is not None:
            # Stochastic: draws one actor type per entry from the predicted logits.
            actor_type_dist = torch.distributions.Categorical(logits=pred_actor_type)
            sampled_actor_type = actor_type_dist.sample()

        # sampled_position[~stacked_actor_valid_mask222] = 0
        # sampled_velocity[~stacked_actor_valid_mask222] = 0
        # sampled_heading[~stacked_actor_valid_mask222] = 0

        if in_evaluation:
            ret.update(
                {
                    # "sampled_position_before_offset": sampled_position_before_offset,
                    "sampled_position": sampled_position,
                    "sampled_heading": sampled_heading,
                    "sampled_velocity": sampled_velocity,
                    "sampled_actor_type": sampled_actor_type,
                    "position_logit": pred_pos,
                }
            )
        else:
            ret.update(
                {
                    "sampled_position": sampled_position,
                    "position_logit": pred_pos,
                    "heading_logit": pred_heading,
                    "velocity_logit": pred_velocity,
                    "actor_type_logit": pred_actor_type,
                }
            )
            # NOTE(review): when layer_index is None (LOSS_EACH_LAYER off) this comparison is False,
            # so the sampled heading / velocity / actor type are not returned in training -- confirm
            # that this is intended.
            if layer_index == self.num_decoder_layers - 1:
                ret.update(
                    {
                        "sampled_heading": sampled_heading,
                        "sampled_velocity": sampled_velocity,
                        "sampled_actor_type": sampled_actor_type,
                    }
                )

        if not in_evaluation:
            if self.model_cfg.LOSS_EACH_LAYER and layer_index < self.num_decoder_layers - 1:
                # Intermediate layers only need the masks.
                ret.update(
                    {
                        "stacked_actor_valid_mask": stacked_actor_valid_mask,
                        "compress_actor_valid_mask": actor_valid_mask,  # [B, compress_T, N]
                    }
                )
            else:
                # Training returns the broadcast (6-D) anchors, unlike the evaluation branch above.
                ret.update(
                    {
                        "anchor_position": anchor_position,
                        "anchor_velocity": anchor_velocity,
                        "anchor_heading": anchor_heading,
                        "stacked_actor_valid_mask": stacked_actor_valid_mask,
                        "compress_actor_valid_mask": actor_valid_mask,  # [B, compress_T, N]
                    }
                )
        return ret
# "sampled_position_before_offset": sampled_position_before_offset, + "sampled_position": sampled_position, + "sampled_heading": sampled_heading, + "sampled_velocity": sampled_velocity, + "sampled_actor_type": sampled_actor_type, + "position_logit": pred_pos, + } + ) + else: + ret.update( + { + "sampled_position": sampled_position, + "position_logit": pred_pos, + "heading_logit": pred_heading, + "velocity_logit": pred_velocity, + "actor_type_logit": pred_actor_type, + } + ) + if layer_index == self.num_decoder_layers - 1: + ret.update( + { + "sampled_heading": sampled_heading, + "sampled_velocity": sampled_velocity, + "sampled_actor_type": sampled_actor_type, + } + ) + + if not in_evaluation: + if self.model_cfg.LOSS_EACH_LAYER and layer_index < self.num_decoder_layers - 1: + ret.update( + { + "stacked_actor_valid_mask": stacked_actor_valid_mask, + "compress_actor_valid_mask": actor_valid_mask, # [B, compress_T, N] + } + ) + else: + ret.update( + { + "anchor_position": anchor_position, + "anchor_velocity": anchor_velocity, + "anchor_heading": anchor_heading, + "stacked_actor_valid_mask": stacked_actor_valid_mask, + "compress_actor_valid_mask": actor_valid_mask, # [B, compress_T, N] + } + ) + return ret + + def get_prediction_for_traffic_light(self, *, traffic_light_output_tokens, traffic_light_valid_mask, in_evaluation): + step_per_token = self.step_per_token + B, compress_T, L = traffic_light_valid_mask.shape + + # Get predicted vehicle type in shape: [B, compress_T, L, step_per_token, TRAFFIC_LIGHT_PREDICT_DIM] + pred_traffic_light = traffic_light_output_tokens.new_zeros( + B, compress_T, L, step_per_token, TRAFFIC_LIGHT_PREDICT_DIM + ) + sampled_state = torch.zeros( + [B, compress_T, L, step_per_token], dtype=torch.long, device=pred_traffic_light.device + ) + # -> [B, compressed_T, L, compress_step, token_dim] + if L > 0: + traffic_light_prediction = self.traffic_light_predictor( + traffic_light_output_tokens[traffic_light_valid_mask] + ) + 
traffic_light_prediction = traffic_light_prediction.reshape(-1, step_per_token, TRAFFIC_LIGHT_PREDICT_DIM) + if traffic_light_prediction.shape[0] > 0: + state_dist = torch.distributions.Categorical(logits=traffic_light_prediction) + sampled_state_valid = state_dist.sample() + sampled_state[traffic_light_valid_mask] = sampled_state_valid + pred_traffic_light[traffic_light_valid_mask] = traffic_light_prediction + + assert pred_traffic_light.shape[-1] == TRAFFIC_LIGHT_PREDICT_DIM + B, compress_T, L, step_per_token, _ = pred_traffic_light.shape + + if in_evaluation: + return { + "sampled_traffic_light_state": sampled_state, + } + else: + return { + "traffic_light_state_logit": pred_traffic_light, + "sampled_traffic_light_state": sampled_state, + "compress_traffic_light_mask": traffic_light_valid_mask + } + + # def get_internal_future_embedding(self, pred_dict, actor_valid_mask, layer_index): + # """ + # This function process predict futures to the embedding. + # The embedding will be used to added to the tokens for next self-attention layer. 
def get_derivative(array2d, dt=0.1):
    """Return the forward finite-difference derivative of *array2d* along axis 0.

    The first difference is duplicated at the front so the output has as many rows as the input.

    Args:
        array2d: array-like indexed by time along axis 0.
        dt: time step between consecutive rows.

    Returns:
        numpy.ndarray with the same number of rows as ``array2d``. Inputs with fewer than two rows
        have no difference to take and yield zeros of the input's shape.
    """
    # Imported locally because this module does not import numpy at file level;
    # the original referenced `np` without any import and raised NameError.
    import numpy as np

    array2d = np.asarray(array2d)
    if array2d.shape[0] < 2:
        return np.zeros_like(array2d, dtype=float)
    diff = (array2d[1:] - array2d[:-1]) / dt
    return np.concatenate([diff[:1], diff], axis=0)


def find_last_valid(array, mask):
    """Gather, for every (batch, agent) pair, the entry of *array* at its last valid time step.

    Args:
        array: tensor of shape ``[B, T, N, D]``.
        mask: validity tensor of shape ``[B, T, N]`` (bool, or anything ``.bool()`` can convert).

    Returns:
        Tensor of shape ``[B, 1, N, D]``. Agents without any valid step get zeros.
    """
    assert mask.ndim + 1 == array.ndim
    assert mask.shape == array.shape[:-1]
    assert array.ndim == 4
    B, T, N, D = array.shape

    # Normalise to bool so the `~` below is a logical NOT even for integer masks.
    valid = mask.bool()
    steps = torch.arange(T, device=mask.device).reshape(1, T, 1).expand(B, T, N)
    # Invalid steps score -1, so the argmax is the largest valid step index (0 if none is valid).
    last_index = steps.masked_fill(~valid, -1).argmax(dim=1, keepdim=True)
    ret = torch.gather(array, index=last_index.unsqueeze(-1).expand(B, 1, N, D), dim=1)  # [B, 1, N, D]
    ret[~valid.any(dim=1, keepdim=True)] = 0
    return ret
@torch.no_grad()
def sample_from_distributions_and_merge(
    step, copy_data_dict, model_output, batch_size, compress_step, sampling_method, max_known_step
):
    """
    Merge the model's sampled outputs (position / heading / velocity / traffic-light state) into the
    input dict, preparing the next autoregressive forward pass during inference.

    ``copy_data_dict`` is modified IN PLACE: the time window ``[step, min(step + compress_step, T))``
    of the agent and traffic-light features, positions and valid masks is overwritten.

    Args:
        step: first time index to fill.
        copy_data_dict: dict holding the ``"encoder/..."`` tensors; mutated and returned.
        model_output: dict returned by the model; the last token (index -1 on dim 1) is used.
        batch_size: original batch size; ``effective_bs // batch_size`` is treated as ``num_modes``.
        compress_step: number of time steps covered by one token.
        sampling_method: ``"argmax"`` or ``"softmax"``; ``"native"`` is unfinished and raises.
        max_known_step: static features (size, type, stop point) are copied from the last valid step
            before this index.

    Returns:
        tuple: ``(copy_data_dict, pred_agent_pos, future_start, future_end, log_probability)``.

    Raises:
        ValueError: for ``sampling_method == "native"`` or any unknown sampling method.

    NOTE(review): this function has the same name as the ``sample_from_distributions_and_merge``
    imported from ``scenestreamer.dataset.preprocessor`` at the top of this file and shadows it.
    NOTE(review): ``utils`` and ``constants`` are used below but are not among the imports visible at
    the top of this file -- confirm they are in scope.
    """

    # model_output["sampled_traffic_light_state"]
    effective_bs, _, L, _ = copy_data_dict["encoder/traffic_light_feature"].shape
    _, T, N, _ = copy_data_dict["encoder/agent_feature"].shape
    B = batch_size

    num_modes = effective_bs // batch_size

    # NOTE(review): pred_start / pred_end are only referenced by commented-out code below.
    pred_start = max(step - compress_step, 0)
    pred_end = min(step, T - compress_step)
    pred_length = min(compress_step, T - step)

    future_start = step
    future_end = min(step + compress_step, T)

    # [B, N, num_modes]
    score = model_output["score_logit"][:, -1].clone()  # This is already in logit! Not probability!

    # Only take the last token's first compress_step predictions -> [B, N, num_modes, pred_length (T), 3]
    sampled_position = model_output["sampled_position"][:, -1, :, :, :pred_length].clone()
    # [B, N, num_modes, pred_length, 3] -> [B, pred_length, N, num_modes, 3]
    sampled_position = sampled_position.permute(0, 3, 1, 2, 4)

    sampled_heading = model_output["sampled_heading"][:, -1, :, :, :pred_length].clone()
    sampled_heading = sampled_heading.permute(0, 3, 1, 2, 4)

    sampled_velocity = model_output["sampled_velocity"][:, -1, :, :, :pred_length].clone()
    sampled_velocity = sampled_velocity.permute(0, 3, 1, 2, 4)

    assert copy_data_dict["encoder/agent_feature"].shape[1] >= future_end

    # Build new agent feature
    new_agent_feature = copy_data_dict["encoder/agent_feature"][:, future_start:future_end].clone()

    # print("We are filling [{}, {}).".format(future_start, future_end))

    # pred_agent_pos = sampled_position[:, pred_start:pred_end]
    # pred_score = score[:, pred_start // compress_step:pred_start // compress_step + 1]

    if sampling_method == "native":
        # Unfinished branch: it always ends in the ValueError below.
        # non-sampled based:
        # [B*num_modes, compress_T, N, num_modes, 2] -> [B, num_modes, compress_T, N, num_modes, 2]
        pred_agent_pos = sampled_position.reshape(B, num_modes, *sampled_position.shape[1:])

        # [B, num_modes, compress_T, N, num_modes, 2] -> [B, num_modes * num_modes, compress_T, N, 2]
        pred_agent_pos = pred_agent_pos.permute(0, 1, 4, 2, 3, 5).flatten(1, 2)

        # [0+0, 6+1, 12+2, 18+3, 24+4, 30+5]
        ind = torch.arange(num_modes) * num_modes + torch.arange(num_modes)

        # [B, num_modes * num_modes, compress_T, N, 2] -> [B*num_modes, compress_T, N, 2]
        pred_agent_pos = pred_agent_pos[:, ind].flatten(0, 1)

        raise ValueError("Not finished yet")

    else:
        # pred_agent_pos is in shape [B*num_modes, compress_step (T), N, num_modes, 2]

        # NOTE(review): the model code in this repository emits "compress_actor_valid_mask";
        # confirm that "compress_agent_valid_mask" is really present in `model_output`.
        comp_mask = model_output["compress_agent_valid_mask"][:, -1, :]
        score = score[comp_mask]
        dist = torch.distributions.Categorical(logits=score)
        if sampling_method == "argmax":
            pred_score_ind = score.argmax(-1)
        elif sampling_method == "softmax":
            pred_score_ind = dist.sample()
        else:
            raise ValueError()

        log_probability = dist.log_prob(pred_score_ind)

        pred_score_ind = utils.unwrap(pred_score_ind.unsqueeze(-1), comp_mask)

        # This is extremely important!! Need to fill "-inf" to log_probability!
        log_probability = utils.unwrap(
            log_probability.unsqueeze(-1), comp_mask, fill=float("-inf")
        ).squeeze(-1)  # [B, N]

        # assert pred_score_ind.shape[1] == 1, pred_score_ind.shape
        pred_score_ind = pred_score_ind.reshape(effective_bs, 1, N, 1, 1)

        # Select the chosen mode (dim 3) for every agent and predicted step.
        pred_score_ind_pos = pred_score_ind.expand(effective_bs, sampled_position.shape[1], N, 1, 3)
        pred_agent_pos = torch.gather(sampled_position, index=pred_score_ind_pos, dim=3)
        assert pred_agent_pos.shape == (effective_bs, pred_agent_pos.shape[1], N, 1, 3)
        pred_agent_pos = pred_agent_pos.squeeze(3)

    # pred_agent_pos_feat = pred_agent_pos.clone() / constants.POSITION_XY_RANGE
    pred_agent_pos_feat = pred_agent_pos.clone()
    assert new_agent_feature[..., :3].shape == pred_agent_pos_feat.shape
    new_agent_feature[..., :3] = pred_agent_pos_feat
    # Let the layers predict Z axis! Below is use old way to get Z.
    # new_agent_feature[..., 2:3] = find_last_valid(
    #     copy_data_dict["encoder/agent_feature"][:, :max_known_step, :, 2:3],
    #     copy_data_dict["encoder/agent_valid_mask"][:, :max_known_step],
    # )

    # Repeat above process for heading
    pred_score_ind_heading = pred_score_ind.expand(effective_bs, sampled_heading.shape[1], N, 1, 1)
    pred_agent_heading = torch.gather(sampled_heading, index=pred_score_ind_heading, dim=3)
    assert pred_agent_heading.ndim == 5
    pred_agent_heading = pred_agent_heading.squeeze(-1).squeeze(-1)

    # non-sampled based:
    # pred_agent_heading = pred_agent_heading.reshape(B, num_modes, *pred_agent_heading.shape[1:])
    # pred_agent_heading = pred_agent_heading.permute(0, 1, 4, 2, 3, 5).flatten(1, 2)
    # pred_agent_heading = pred_agent_heading[:, ind].flatten(0, 1).squeeze(-1)

    pred_agent_heading_feat = pred_agent_heading.clone()
    pred_agent_heading_feat = utils.wrap_to_pi(pred_agent_heading_feat)
    pred_agent_heading_feat /= constants.HEADING_RANGE

    # print("HEADING MIN: ", pred_agent_heading_feat.min())

    # Feature channel 3 stores the scaled heading; channels 9 / 10 store its sine / cosine.
    assert new_agent_feature[..., 3].shape == pred_agent_heading_feat.shape
    new_agent_feature[..., 3] = pred_agent_heading_feat
    new_agent_feature[..., 9] = torch.sin(pred_agent_heading)
    new_agent_feature[..., 10] = torch.cos(pred_agent_heading)

    # Repeat above process for velocity
    # pred_agent_vel = sampled_velocity[:, pred_start:pred_end]
    pred_score_ind_vel = pred_score_ind.expand(effective_bs, sampled_velocity.shape[1], N, 1, 2)
    pred_agent_vel = torch.gather(sampled_velocity, index=pred_score_ind_vel, dim=3)
    assert pred_agent_vel.ndim == 5
    pred_agent_vel = pred_agent_vel.squeeze(3)

    pred_agent_vel_feat = pred_agent_vel.clone() / constants.VELOCITY_XY_RANGE
    assert new_agent_feature[..., 4:6].shape == pred_agent_vel_feat.shape
    new_agent_feature[..., 4:6] = pred_agent_vel_feat

    # length width height
    new_agent_feature[..., [6, 7, 8]] = find_last_valid(
        copy_data_dict["encoder/agent_feature"][:, :max_known_step, :, [6, 7, 8]],
        copy_data_dict["encoder/agent_valid_mask"][:, :max_known_step],
    )

    # Input data already scaled.

    # speed
    speed = pred_agent_vel.norm(dim=-1) / constants.VELOCITY_XY_RANGE
    assert new_agent_feature[..., 11].shape == speed.shape
    new_agent_feature[..., 11] = speed

    # agent type
    new_agent_feature[..., 12:15] = find_last_valid(
        copy_data_dict["encoder/agent_feature"][:, :max_known_step, :, 12:15],
        copy_data_dict["encoder/agent_valid_mask"][:, :max_known_step],
    )

    # valid
    new_agent_feature[..., 15] = 1

    # Write the filled window back; every agent is marked valid for the newly generated steps.
    assert pred_agent_pos.shape == copy_data_dict["encoder/agent_position"][:, future_start:future_end].shape
    assert new_agent_feature.shape == copy_data_dict["encoder/agent_feature"][:, future_start:future_end].shape
    copy_data_dict["encoder/agent_position"][:, future_start:future_end] = pred_agent_pos.clone()
    copy_data_dict["encoder/agent_feature"][:, future_start:future_end] = new_agent_feature.clone()
    copy_data_dict["encoder/agent_valid_mask"][:, future_start:future_end] = 1

    # Build new traffic light feature
    # new_traffic_light_feature = data_dict["encoder/traffic_light_feature"].new_zeros((effective_bs, 1, L, TRAFFIC_LIGHT_STATE_DIM))
    new_traffic_light_feature = copy_data_dict["encoder/traffic_light_feature"][:, future_start:future_end].clone()
    new_T = new_traffic_light_feature.shape[1]
    assert new_T > 0
    if L > 0:
        # Fill "stop_point"
        new_traffic_light_feature[..., :3] = find_last_valid(
            copy_data_dict["encoder/traffic_light_feature"][:, :max_known_step, :, :3],
            copy_data_dict["encoder/traffic_light_valid_mask"][:, :max_known_step],
        )
        # [B, T, L]
        pred_light_state = model_output["sampled_traffic_light_state"][:, -1, :, :pred_length].permute(0, 2, 1)
        # One-hot encode the sampled state into the feature channels from index 3 onwards.
        st = torch.nn.functional.one_hot(pred_light_state, num_classes=9)
        new_traffic_light_feature[:, :st.shape[1], ..., 3:] = st

    assert new_traffic_light_feature.shape == copy_data_dict["encoder/traffic_light_feature"][:, future_start:future_end
                                                                                              ].shape
    copy_data_dict["encoder/traffic_light_feature"][:, future_start:future_end] = new_traffic_light_feature.clone()
    copy_data_dict["encoder/traffic_light_valid_mask"][:, future_start:future_end] = 1

    return copy_data_dict, pred_agent_pos, future_start, future_end, log_probability
new_traffic_light_feature.clone() + copy_data_dict["encoder/traffic_light_valid_mask"][:, future_start:future_end] = 1 + + return copy_data_dict, pred_agent_pos, future_start, future_end, log_probability + + +# TODO: This might be helpful +# Handle unsupervised learning by using an IterableDataset where the dataset itself is constantly updated during training +# https://lightning.ai/docs/pytorch/latest/notebooks/lightning_examples/reinforce-learning-DQN.html?highlight=target + + +# TODO: Could move this to a util file? +def get_displacement(array, mask): + assert mask.ndim + 1 == array.ndim + assert mask.shape == array.shape[:-1] + assert array.ndim == 4 + B, T, N, D = array.shape + assert D == 2 or D == 3 + indices = mask * torch.arange(T, device=mask.device).reshape(1, T, 1).expand(*mask.shape) + last_indices = indices.argmax(dim=1, keepdims=True).unsqueeze(-1).expand(B, 1, N, D) + last_pos = torch.gather(array, index=last_indices, dim=1) + + first_indices = indices.argmin(dim=1, keepdims=True).unsqueeze(-1).expand(B, 1, N, D) + first_pos = torch.gather(array, index=first_indices, dim=1) + + return (last_pos - first_pos).norm(dim=-1), last_pos + + +class MotionLMPL(pl.LightningModule): + def __init__(self, cfg): + if "SEED" in cfg: + pl.seed_everything(cfg.SEED) + print("Everything is seeded to: ", cfg.SEED) + + super().__init__() + self.cfg = cfg + self.model_cfg = self.cfg.MODEL + self.sampling_method = self.cfg.EVALUATION.SAMPLING_METHOD + + self.motion_decoder = MotionLM(config=self.model_cfg) + + self.save_hyperparameters() + + self.validation_outputs = [] + self.validation_ground_truth = [] + + def forward(self, batch_dict): + forward_ret_dict = self.motion_decoder(batch_dict) + return forward_ret_dict + + def get_loss(self, data_dict, gt_dict, forward_ret_dict): + loss, tb_dict, disp_dict = self.motion_decoder.get_loss(data_dict, gt_dict, forward_ret_dict) + return loss, tb_dict, disp_dict + + def training_step(self, batch, batch_idx): + data_dict, gt_dict 
= batch + forward_ret_dict = self(data_dict) + loss, tb_dict, disp_dict = self.get_loss(data_dict, gt_dict, forward_ret_dict) + self.log_dict( + {f"train/{k}": float(v) + for k, v in tb_dict.items()}, + batch_size=data_dict["encoder/agent_feature"].shape[0], + # on_epoch=True, + # prog_bar=True, + ) + self.log('monitoring_step', float(self.global_step)) + return loss + + def on_validation_start(self): + torch.cuda.empty_cache() + + @torch.no_grad() + def autoregressive_generate( + self, batch, compress_step=None, total_step=None, return_step_model_out=False, sampling_method=None + ): + data_dict, gt_dict = batch + + start_time = gt_dict['current_time_index'].unique().item() + + B, T, N, D_actor = data_dict['actor_feature'].shape + _, _, L, D_light = data_dict['traffic_light_feature'].shape + + if total_step is None: + rollout_T = T + else: + assert total_step <= T + rollout_T = total_step + + feat = data_dict["encoder/agent_feature"] + + num_modes = self.motion_decoder.num_modes + + num_repeat = num_modes + + # num_repeat = 8 + + def _repeat_for_modes(v): + d = v.ndim + v = v.unsqueeze(1) + v = v.repeat(1, num_repeat, *((1, ) * (d - 1))) + v = v.flatten(0, 1) + return v + + copy_data_dict = { + "encoder/agent_feature": feat.new_zeros([B * num_repeat, T, N, D_actor]), + "encoder/agent_position": feat.new_zeros([B * num_repeat, T, N, 3]), + "encoder/agent_valid_mask": feat.new_zeros([B * num_repeat, T, N], dtype=bool), + "decoder/actor_type": _repeat_for_modes(data_dict["decoder/actor_type"]), + "encoder/map_feature": _repeat_for_modes(data_dict["encoder/map_feature"]), + "encoder/map_feature_valid_mask": _repeat_for_modes(data_dict["encoder/map_feature_valid_mask"]), + "encoder/map_position": _repeat_for_modes(data_dict["encoder/map_position"]), + "encoder/traffic_light_feature": feat.new_zeros([B * num_repeat, T, L, D_light]), + "encoder/traffic_light_position": _repeat_for_modes(data_dict["encoder/traffic_light_position"]), + "encoder/traffic_light_valid_mask": 
feat.new_zeros([B * num_repeat, T, L], dtype=bool) + } + + for k in [ + "encoder/agent_feature", + "encoder/agent_valid_mask", + "encoder/agent_position", + "encoder/traffic_light_valid_mask", + "encoder/traffic_light_feature", + ]: + v = data_dict[k] + v = v.reshape(B, 1, *(v.shape[1:])) + v = v.repeat(1, num_repeat, *(1, ) * len(v.shape[2:])) + v = v.flatten(0, 1) + assert v.shape[1] == T, (k, v.shape) + copy_data_dict[k][:, :start_time + 1] = v[:, :start_time + 1] + + model_output_collection = { + "sampled_position": copy_data_dict["encoder/map_position"].new_zeros([B, T, N, num_repeat, 3]), + "log_probability": copy_data_dict["encoder/map_position"].new_zeros([B, T, N, num_repeat]) + } + model_output_collection["log_probability"].fill_(float("-inf")) + + if compress_step is None: + compress_step = self.motion_decoder.compress_step + + model_output = {} + + if return_step_model_out: + step_model_out = [] + + for step in range(start_time, rollout_T, compress_step): + + # input_dict is a snapshot of data_dict + input_dict = { + "encoder/map_feature": copy_data_dict["encoder/map_feature"].clone(), + "encoder/map_feature_valid_mask": copy_data_dict["encoder/map_feature_valid_mask"].clone(), + "encoder/map_position": copy_data_dict["encoder/map_position"].clone(), + "encoder/traffic_light_position": copy_data_dict["encoder/traffic_light_position"].clone(), + "encoder/traffic_light_valid_mask": copy_data_dict["encoder/traffic_light_valid_mask"][:, :step + + 1].clone(), + "encoder/agent_feature": copy_data_dict["encoder/agent_feature"][:, :step + 1].clone(), + "encoder/agent_valid_mask": copy_data_dict["encoder/agent_valid_mask"][:, :step + 1].clone(), + "encoder/agent_position": copy_data_dict["encoder/agent_position"][:, :step + 1].clone(), + "encoder/traffic_light_feature": copy_data_dict["encoder/traffic_light_feature"][:, :step + 1].clone(), + "in_evaluation": True, + "output_compress_step": compress_step, + "decoder/actor_type": 
copy_data_dict["decoder/actor_type"] + } + + for k in ["map_token", "query_cache"]: + if k in model_output: + input_dict[k] = model_output[k] + + model_output = self(input_dict) + + if return_step_model_out: + step_model_out.append(model_output) + + # Fuse the predicted state into data dict directly. + copy_data_dict, pred_actor_pos, future_start, future_end, log_probability = sample_from_distributions_and_merge( + step=step, + copy_data_dict=copy_data_dict, + model_output=model_output, + batch_size=B, + compress_step=compress_step, + sampling_method=sampling_method or self.sampling_method, + max_known_step=start_time + 1 + ) + + if step == start_time: + # overwrite the predicted position to initial position, if the actor is not moved in the initial + # interval + + # [B, 1, N], [B, 1, N, 2] + displacement, last_pos = get_displacement( + input_dict["encoder/agent_position"][:, :start_time], + input_dict["encoder/agent_valid_mask"][:, :start_time] + ) + last_pos = last_pos.reshape(B, num_repeat, 1, N, 3) + displacement = displacement.reshape(B, num_repeat, 1, N) + + last_pos_use = last_pos.repeat(1, 1, pred_actor_pos.shape[1], 1, 1) + last_pos_use = last_pos_use.permute(0, 2, 3, 1, 4) + + static_actor_mask = displacement < 0.001 + static_actor_mask_pred = static_actor_mask.reshape(B, num_repeat, 1, + N).repeat(1, 1, pred_actor_pos.shape[1], 1) + static_actor_mask_pred = static_actor_mask_pred.permute(0, 2, 3, 1) # [B, step, N, num_modes] + + # [B*num_modes, T, N, 3] -> [B, num_modes, T, N, 3] + pred_actor_pos_reform = pred_actor_pos.reshape(B, num_repeat, *pred_actor_pos.shape[1:]) + # -> [B, T, N, num_modes, 3] + pred_actor_pos_reform = pred_actor_pos_reform.permute(0, 2, 3, 1, 4) + + pred_actor_pos_reform[static_actor_mask_pred] = last_pos_use[static_actor_mask_pred] + + model_output_collection["sampled_position"][:, future_start:future_end] = pred_actor_pos_reform + + # Since we will sum up all scores for each mode, + # We can fill divide the score by 5 
(compress_step) before filling them into matrix, + # By doing so we avoid sum the same the log probability 5 times when computing the trajectory-level scores. + log_probability_reform = log_probability.reshape(B, num_repeat, 1, N) + log_probability_reform = log_probability_reform.permute(0, 2, 3, 1) # -> [B, 1, N, num_modes] + model_output_collection["log_probability"][:, future_start:future_end] = \ + log_probability_reform / (future_end - future_start) + + if step in [10, 15, 20, 40, 60, 90]: + # Compute the first batch output's loss. It should be similar to training loss. + self.log_dict( + self.motion_decoder.get_position_loss( + data_dict, model_output_collection, step, start_time, future_end + ), + # prog_bar=True + ) + + if return_step_model_out: + return copy_data_dict, model_output_collection, step_model_out + + return copy_data_dict, model_output_collection + + def validation_step( + self, batch, batch_idx, output_compress_step=None, total_step=None, return_dict=False, sampling_method=None + ): + + # TODO: Add this back + return + + data_dict, gt_dict = batch + + if output_compress_step is None: + output_compress_step = self.config.EVALUATION.OUTPUT_COMPRESS_STEP + + copy_data_dict, model_output_collection = self.autoregressive_generate( + batch, compress_step=output_compress_step, total_step=total_step, sampling_method=sampling_method + ) + final_pred_dicts = generate_predicted_trajectory_for_eval(data_dict, gt_dict, model_output_collection) + final_pred_dicts.update(gt_dict) + self.validation_outputs.append(final_pred_dicts) + if return_dict: + final_pred_dicts["model_output_collection"] = model_output_collection + final_pred_dicts["data_dict"] = data_dict + return final_pred_dicts + + def on_validation_epoch_end(self): + + # TODO: Add this + return + + st = time.time() + + # https://lightning.ai/docs/pytorch/latest/accelerators/accelerator_prepare.html?highlight=hardware + torch.cuda.empty_cache() + + # PZH NOTE: Hack to implement our own all_gather 
across ranks. + self.trainer.strategy.barrier() + + # if "Wandb" in str(self.trainer.logger): + # tmpdir = os.path.join(self.trainer.logger.version, "validation_tmpdir") + # os.makedirs(self.trainer.logger.version, exist_ok=True) + # os.makedirs(tmpdir, exist_ok=True) + # else: + tmpdir = os.path.join(self.trainer.log_dir, "validation_tmpdir") + # os.makedirs(self.trainer.log_dir, exist_ok=True) + os.makedirs(tmpdir, exist_ok=True) + + self.validation_outputs = [ + { + k: v.detach().cpu().float().numpy() if isinstance(v, torch.Tensor) else v + for k, v in final_pred_dicts.items() + } for final_pred_dicts in self.validation_outputs + ] + + with open(os.path.join(tmpdir, 'result_part_{}.pkl'.format(self.global_rank)), 'wb') as f: + pickle.dump(self.validation_outputs, f) + self.validation_outputs.clear() + self.trainer.strategy.barrier() + + self.log("monitoring_step", float(self.global_step)) + + if self.trainer.is_global_zero: + validation_list = [] + for i in range(self.trainer.world_size): + file = os.path.join(tmpdir, 'result_part_{}.pkl'.format(i)) + + for sleep in range(10): + if not os.path.isfile(file): + time.sleep(1) + print(f"Can't find file: {file}. 
Sleep {sleep} seconds.") + with open(file, "rb") as f: + val_outputs = pickle.load(f) + validation_list.extend(val_outputs) + + if self.config.DELETE_EVAL_RESULT: + shutil.rmtree(tmpdir) + + # print("==== log eval dir: ", eval_output_dir) + + # with open(os.path.join(eval_output_dir, 'validation_result.pkl'), 'wb') as f: + # pickle.dump(validation_list, f) + + # print(f"===== Start evaluation: {time.time() - st:.3f}") + # scenario_id_map = self.trainer.val_dataloaders.dataset.scenario_id_map + result_str, result_dict = waymo_evaluation(validation_list) + result_dict = {f"eval/{k}": float(v) for k, v in result_dict.items()} + self.log_dict(result_dict, rank_zero_only=True) + + for k in ['eval/minADE', 'eval/minFDE', 'eval/MissRate', 'eval/mAP']: + self.log(name=k.split("/")[1], value=result_dict[k], rank_zero_only=True) + + self.print(result_str) + print(f"===== Finish evaluation: {time.time() - st:.3f}") + + torch.cuda.empty_cache() + + def configure_optimizers(self): + opt_cfg = self.cfg.OPTIMIZATION + + if opt_cfg.OPTIMIZER == 'Adam': + optimizer = torch.optim.Adam( + [each[1] for each in self.named_parameters()], + lr=opt_cfg.LR, + weight_decay=opt_cfg.get('WEIGHT_DECAY', 0) + ) + elif opt_cfg.OPTIMIZER == 'AdamW': + optimizer = torch.optim.AdamW( + self.parameters(), lr=opt_cfg.LR, weight_decay=opt_cfg.get('WEIGHT_DECAY', 0), betas=(0.9, 0.95) + ) + else: + assert False + + if opt_cfg.get('SCHEDULER', None) == 'cosine': + scheduler = CosineAnnealingWarmRestarts( + optimizer, + T_0=2, + T_mult=1, + eta_min=max(1e-2 * opt_cfg.LR, 1e-6), + last_epoch=-1, + ) + elif opt_cfg.get('SCHEDULER', None) == 'lambdaLR': + + def lr_lbmd(cur_epoch): + cur_decay = 1 + for decay_step in opt_cfg.get('DECAY_STEP_LIST', [5, 10, 15, 20]): + if cur_epoch >= decay_step: + cur_decay = cur_decay * opt_cfg.LR_DECAY + return max(cur_decay, opt_cfg.LR_CLIP / opt_cfg.LR) + + scheduler = LambdaLR(optimizer, lr_lbmd) + + elif opt_cfg.get('SCHEDULER', None) == 'linearLR': + scheduler = 
LinearLR( + optimizer, + start_factor=1.0, + end_factor=opt_cfg.LR_CLIP / opt_cfg.LR, + total_iters=opt_cfg.NUM_EPOCHS, + ) + else: + scheduler = None + + return { + "optimizer": optimizer, + + # PZH NOTE: The scheduler step will be added 1 after each epoch. + "lr_scheduler": scheduler + } + + +def generate_predicted_trajectory_for_eval(data_dict, gt_dict, forward_ret_dict): + """ + This function will extract the predicted state for all objects. + We can then compare those state with the ground truths for formal evaluation in the motion forecasting task. + Note that this function is only called in validation_step and for evaluation only. + We don't use the result here to compute loss. + """ + actor_valid_mask = data_dict["encoder/agent_valid_mask"] # [B, T, N] + predicted_position = forward_ret_dict["sampled_position"] # [B*num_modes, T, N, num_modes, 2] + + # ======================================================================= + # For debug use only, overwrite the predicted result by GT to check evaluation pipeline. + # pred_pos = data_dict["encoder/agent_position"][data_dict["encoder/agent_valid_mask"]] + # pred_pos = pred_pos.reshape(-1, 1, 3).repeat(1, 1, 1) # Extend the num_modes dim. 
+ # ======================================================================= + + B, T, N = actor_valid_mask.shape + _, _, _, num_modes, _ = predicted_position.shape + + track_index_to_predict = data_dict["track_index_to_predict"].long() # [B, num interested actors] + _, max_actor_to_predict = track_index_to_predict.shape + + map_center = gt_dict["map_center"][..., :2] # [B, 2] + map_heading = gt_dict["encoder/map_heading"] # [B, ] + + map_center = map_center.reshape(B, 1, 1, 1, 2).repeat(1, T, N, num_modes, 1) + + map_heading = map_heading.reshape(B, 1, 1, 1).repeat(1, T, N, num_modes) + + predicted_position = rotate(x=predicted_position[..., 0], y=predicted_position[..., 1], angle=map_heading) + + predicted_position = predicted_position + map_center + + predicted_position[~data_dict["encoder/agent_valid_mask"][..., None, None].expand(B, T, N, num_modes, 2)] = 0 + + # [B, T, N, num_modes, 2] -> [B, N, num_modes, T, 2] + predicted_position = predicted_position.permute(0, 2, 3, 1, 4) + + # [B, N, num_modes, T, 2] -> [B * N, num_modes, T, 2] + predicted_position = predicted_position.flatten(0, 1) + + # change the index range from [0, N] to [0, B*N] + valid_track_index = torch.where(track_index_to_predict != -1) + track_index_to_predict += torch.arange(B).to(track_index_to_predict).reshape(-1, 1) * N + track_index_flatten = track_index_to_predict[valid_track_index] + + assert track_index_flatten.max().item() < predicted_position.shape[0], ( + track_index_flatten, predicted_position.shape + ) + + # [sum valid actor to predict, num_modes, T, 2] + predicted_traj = predicted_position[track_index_flatten] + + # Assume we only have one current_time_index = 10. 
+ current_time_index = gt_dict["current_time_index"].unique().item() + + predicted_traj = predicted_traj[:, :, current_time_index + 1:, :2] + + if "score" in forward_ret_dict: + score = forward_ret_dict["score"] + else: + score = forward_ret_dict["log_probability"] + + # [B, T, N, num_modes] -> [B, N, num_modes, T] + score = score.permute(0, 2, 3, 1) + # [B, N, num_modes, T] -> [B * N, num_modes, T] + score = score.flatten(0, 1) + # [sum valid actor to predict, num_modes, T] + score = score[track_index_flatten] + score = score[:, :, current_time_index + 1:] + + # ===== DEBUG CODE ===== + # score[:, :1] = 10000 + # score[:, 1:] = 0 + # score.fill_(1) + # ===== DEBUG CODE ===== + + # ======================================================================= + # debug_mask in [num valid, 80] + # debug_mask = actor_valid_mask.permute(0, 2, 1, 3).flatten(0, 1)[track_index_flatten][:, 11:][..., 0] + # pred_origin = predicted_traj[:, 0] + # gt_origin = gt_dict["center_gt_trajs_src"][:, 11:, :2] + # assert (pred_origin[debug_mask] - gt_origin[debug_mask]).abs().max().item() < 0.1 + # ======================================================================= + assert predicted_traj.shape[-1] == 2 + return { + "pred_trajs": predicted_traj, + "pred_scores": score.sum(-1) # sum of log probability + } diff --git a/scenestreamer/models/__init__.py b/scenestreamer/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenestreamer/models/diffusion.py b/scenestreamer/models/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5dd293c1569fdf9b8315f0e9659b0828b6fae7 --- /dev/null +++ b/scenestreamer/models/diffusion.py @@ -0,0 +1,541 @@ +from typing import Any, Mapping, Tuple, Union +from typing import Callable +from typing import Dict, Optional, Sequence + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from einops import 
rearrange +from jax import Array +from jax.typing import ArrayLike + +# adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py +# from octo.model.components.base import TokenGroup +# from octo.utils.typing import Dtype, PRNGKey, Shape, Union + +# from octo.model.components.base import TokenGroup +# from octo.model.components.diffusion import cosine_beta_schedule, create_diffusion_model +# from octo.model.components.tokenizers import BinTokenizer +# from octo.model.components.transformer import MAPHead +# from octo.utils.typing import PRNGKey + +PRNGKey = jax.random.KeyArray +PyTree = Union[jax.typing.ArrayLike, Mapping[str, "PyTree"]] +Config = Union[Any, Mapping[str, "Config"]] +Params = Mapping[str, PyTree] +Data = Mapping[str, PyTree] +Shape = Sequence[int] +Dtype = jax.typing.DTypeLike + +default_init = nn.initializers.xavier_uniform + + +@flax.struct.dataclass +class TokenGroup: + """A group of tokens that have semantic meaning together (e.g. 
the tokens for a single observation) + + Attributes: + tokens: jax.Array of shape (..., n_tokens, token_dim) + mask: jax.Array of shape (..., n_tokens) indicating which tokens are valid (1) vs padding (0) + """ + + tokens: jax.typing.ArrayLike + mask: jax.typing.ArrayLike + + @classmethod + def create(cls, tokens: jax.typing.ArrayLike, mask: jax.typing.ArrayLike = None, **kwargs): + if mask is None: + mask = jnp.ones(tokens.shape[:-1]) + assert mask.ndim == tokens.ndim - 1 + return cls(tokens, mask, **kwargs) + + @classmethod + def concatenate(cls, group_list: Sequence["TokenGroup"], axis=-2): + data = jnp.concatenate([t.tokens for t in group_list], axis=axis) + mask = jnp.concatenate([t.mask for t in group_list], axis=axis + 1) + return cls(data, mask) + + +def masked_mean(x, mask): + mask = jnp.broadcast_to(mask, x.shape) + return jnp.mean(x * mask) / jnp.clip(jnp.mean(mask), a_min=1e-5, a_max=None) + + +def continuous_loss( + pred_value: ArrayLike, + ground_truth_value: ArrayLike, + mask: ArrayLike, + loss_type: str = "mse", +) -> Array: + """ + Args: + pred_value: shape (batch_dims...) + ground_truth_value: continuous values w/ shape (batch_dims...) + mask: broadcastable to ground_truth + """ + if loss_type == "mse": + loss = jnp.square(pred_value - ground_truth_value) + elif loss_type == "l1": + loss = jnp.abs(pred_value - ground_truth_value) + else: + raise ValueError(f"Invalid loss type: {loss_type}") + + loss = masked_mean(loss, mask) + + mse = jnp.square(pred_value - ground_truth_value) + mse = masked_mean(mse, mask) + return loss, { + "loss": loss, + "mse": mse, + } + + +def chunk_actions(actions: ArrayLike, pred_horizon: int) -> Array: + """Chunk actions for predicting actions `pred_horizon` steps into the future. 
+ + The resulting actions have shape (batch, actions.shape[-2] - (pred_horizon - 1), pred_horizon, action_dim) + + For example: chunk_actions([a_1, a_2, a_3, a_4, a_5], 3) -> + [ + [a_1, a_2, a_3], + [a_2, a_3, a_4], + [a_3, a_4, a_5], + ] + + """ + assert ( + actions.ndim == 3 + ), f"Expected actions to have shape (batch, window_size, action_dim), but got shape {actions.shape}" + window_size = actions.shape[1] + assert (window_size >= pred_horizon), f"pred_horizon {pred_horizon} too large for window size {window_size}" + chunk_window_size = window_size - (pred_horizon - 1) + + curr_step = jnp.arange(chunk_window_size) + action_offset = jnp.arange(pred_horizon) + chunk_indices = curr_step[:, None] + action_offset[None, :] + return actions[:, chunk_indices] + + +def _check_action_window_size(actions, window_size, pred_horizon): + assert ( + actions.shape[1] >= window_size + pred_horizon - 1 + ), f""" + To predict actions for window_size {window_size} and future prediction horizon {pred_horizon}, + the ground-truth actions must have at least {window_size + pred_horizon - 1} timesteps, but got shape {actions.shape}. + + Did you make sure to set "future_action_window_size" correctly in the data config? 
+ """ + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + t = jnp.linspace(0, timesteps, steps) / timesteps + alphas_cumprod = jnp.cos((t + s) / (1 + s) * jnp.pi * 0.5)**2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return jnp.clip(betas, 0, 0.999) + + +class ScoreActor(nn.Module): + time_preprocess: nn.Module + cond_encoder: nn.Module + reverse_network: nn.Module + + def __call__(self, obs_enc, actions, time, train=False): + t_ff = self.time_preprocess(time) + cond_enc = self.cond_encoder(t_ff, train=train) + reverse_input = jnp.concatenate([cond_enc, obs_enc, actions], axis=-1) + eps_pred = self.reverse_network(reverse_input, train=train) + return eps_pred + + +class FourierFeatures(nn.Module): + output_size: int + learnable: bool = True + + @nn.compact + def __call__(self, x: jax.Array): + if self.learnable: + w = self.param( + "kernel", + nn.initializers.normal(0.2), + (self.output_size // 2, x.shape[-1]), + jnp.float32, + ) + f = 2 * jnp.pi * x @ w.T + else: + half_dim = self.output_size // 2 + f = jnp.log(10000) / (half_dim - 1) + f = jnp.exp(jnp.arange(half_dim) * -f) + f = x * f + return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1) + + +class MLP(nn.Module): + hidden_dims: Sequence[int] + activation: Callable = nn.swish + activate_final: bool = False + use_layer_norm: bool = False + dropout_rate: Optional[float] = None + + @nn.compact + def __call__(self, x: jax.Array, train: bool = False) -> jax.Array: + for i, size in enumerate(self.hidden_dims): + x = nn.Dense(size, kernel_init=default_init())(x) + + if i + 1 < len(self.hidden_dims) or self.activate_final: + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) + if self.use_layer_norm: + x = nn.LayerNorm()(x) + x = self.activation(x) + return 
x + + +class MLPResNetBlock(nn.Module): + features: int + act: Callable + dropout_rate: float = None + use_layer_norm: bool = False + + @nn.compact + def __call__(self, x, train: bool = False): + residual = x + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) + if self.use_layer_norm: + x = nn.LayerNorm()(x) + x = nn.Dense(self.features * 4)(x) + x = self.act(x) + x = nn.Dense(self.features)(x) + + if residual.shape != x.shape: + residual = nn.Dense(self.features)(residual) + + return residual + x + + +class MLPResNet(nn.Module): + num_blocks: int + out_dim: int + dropout_rate: float = None + use_layer_norm: bool = False + hidden_dim: int = 256 + activation: Callable = nn.swish + + @nn.compact + def __call__(self, x: jax.typing.ArrayLike, train: bool = False) -> jax.Array: + x = nn.Dense(self.hidden_dim, kernel_init=default_init())(x) + for _ in range(self.num_blocks): + x = MLPResNetBlock( + self.hidden_dim, + act=self.activation, + use_layer_norm=self.use_layer_norm, + dropout_rate=self.dropout_rate, + )(x, train=train) + + x = self.activation(x) + x = nn.Dense(self.out_dim, kernel_init=default_init())(x) + return x + + +def create_diffusion_model( + out_dim: int, + time_dim: int, + num_blocks: int, + dropout_rate: float, + hidden_dim: int, + use_layer_norm: bool, +): + return ScoreActor( + FourierFeatures(time_dim, learnable=True), + MLP((2 * time_dim, time_dim)), + MLPResNet( + num_blocks, + out_dim, + dropout_rate=dropout_rate, + hidden_dim=hidden_dim, + use_layer_norm=use_layer_norm, + ), + ) + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block.""" + + mlp_dim: int + dtype: Dtype = jnp.float32 + out_dim: Optional[int] = None + dropout_rate: float = 0.1 + kernel_init: Callable[[PRNGKey, Shape, Dtype], jax.Array] = nn.initializers.xavier_uniform() + bias_init: Callable[[PRNGKey, Shape, Dtype], jax.Array] = nn.initializers.normal(stddev=1e-6) + + @nn.compact + 
def __call__(self, inputs, *, deterministic): + """Applies Transformer MlpBlock module.""" + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim + x = nn.Dense( + features=self.mlp_dim, + dtype=self.dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + )(inputs) + x = nn.gelu(x) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + output = nn.Dense( + features=actual_out_dim, + dtype=self.dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + )(x) + output = nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic) + return output + + +class MAPHead(nn.Module): + """Multihead Attention Pooling. + + From https://github.com/google-research/big_vision/blob/main/big_vision/models/vit.py + """ + + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 8 + num_readouts: int = 1 + + @nn.compact + def __call__(self, x: Union[jax.Array, TokenGroup], train=True): + if isinstance(x, TokenGroup): + x, mask = x.tokens, x.mask + else: + mask = None + + *batch_dims, l, d = x.shape + x = x.reshape(-1, l, d) + batch_size = x.shape[0] + + probe = self.param( + "probe", + nn.initializers.xavier_uniform(), + (1, self.num_readouts, d), + x.dtype, + ) + probe = jnp.tile(probe, [batch_size, 1, 1]) + + if mask is not None: + mask = mask.reshape(-1, l) + mask = jnp.broadcast_to(mask[:, None, None, :], (batch_size, 1, self.num_readouts, l)) + + out = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform() + )(probe, x, mask=mask) + + # TODO: dropout on head? + y = nn.LayerNorm()(out) + + out = out + MlpBlock(mlp_dim=nn.merge_param("mlp_dim", self.mlp_dim, 4 * d))(y, deterministic=not train) + out = out.reshape(*batch_dims, self.num_readouts, d) + return out + + +class DiffusionActionHead(nn.Module): + """Predicts actions uses a diffusion process. 
+ + Only a single pass through the transformer is done to obtain an action embedding at each timestep. The + action is then predicted using a diffusion process conditioned on this embedding. The diffusion model + architecture is an MLP with residual connections (see `octo.model.components.diffusion`). + + You may create an embedding by either mean-pooling across tokens (use_map=False) or using multi-head + attention pooling (use_map=True). It is recommended to use MAP when decoding from the observation token + stream. + """ + + readout_key: str + use_map: bool = False + pred_horizon: int = 1 + action_dim: int = 7 + max_action: float = 5.0 + loss_type: str = "mse" + + # diffusion-specific config with sane defaults + time_dim: int = 32 + num_blocks: int = 3 + dropout_rate: float = 0.1 + hidden_dim: int = 256 + use_layer_norm: bool = True + diffusion_steps: int = 20 + + def setup(self): + if self.use_map: + self.map_head = MAPHead() + + # create the diffusion model (score network) + self.diffusion_model = create_diffusion_model( + self.action_dim * self.pred_horizon, + time_dim=self.time_dim, + num_blocks=self.num_blocks, + dropout_rate=self.dropout_rate, + hidden_dim=self.hidden_dim, + use_layer_norm=self.use_layer_norm, + ) + + # create beta schedule + self.betas = jnp.array(cosine_beta_schedule(self.diffusion_steps)) + self.alphas = 1 - self.betas + self.alpha_hats = jnp.array([jnp.prod(self.alphas[:i + 1]) for i in range(self.diffusion_steps)]) + + def __call__( + self, + transformer_outputs: Dict[str, TokenGroup], + time: Optional[ArrayLike] = None, + noisy_actions: Optional[ArrayLike] = None, + train: bool = True, + ) -> jax.Array: + """Performs a single forward pass through the diffusion model.""" + token_group = transformer_outputs[self.readout_key] + assert token_group.tokens.ndim == 4, ( + f"Expected token_group.tokens to have shape (batch_size, window_size, num_tokens, embedding_size), " + f"but got shape {token_group.tokens.shape}" + ) + if self.use_map: # 
Multi-head attention pooling + embeddings = self.map_head(token_group, train=train)[:, :, 0] + else: # mean pooling + embeddings = token_group.tokens.mean(axis=-2) + # Now, embeddings is (batch_size, window_size, embedding_size) + + # time and noisy_actions are None during initialization, so we replace them with a dummy array + if (time is None or noisy_actions is None) and not self.is_initializing(): + raise ValueError("Must provide time and noisy_actions when calling diffusion action head") + elif self.is_initializing(): + time = jnp.zeros((*embeddings.shape[:2], 1), dtype=jnp.float32) + noisy_actions = jnp.zeros( + (*embeddings.shape[:2], self.action_dim * self.pred_horizon), + dtype=jnp.float32, + ) + + pred_eps = self.diffusion_model(embeddings, noisy_actions, time, train=train) + return pred_eps + + def loss( + self, + transformer_outputs: Dict[str, TokenGroup], + actions: ArrayLike, + pad_mask: ArrayLike, + train: bool = True, + ) -> Tuple[Array, Dict[str, Array]]: + """Computes the loss for the diffusion objective. 
+ + Args: + transformer_ouputs: must contain self.readout_key with shape (batch_size, window_size, num_tokens, + embedding_size) + actions: shape (batch_size, >= window_size + pred_horizon - 1, action_dim) + pad_mask: boolean array (batch, window_size) which is True if the timestep is not a padding timestep + + Returns: + loss: float + metrics: dict + """ + batch_size, window_size = pad_mask.shape + _check_action_window_size(actions, window_size, self.pred_horizon) + actions_chunked = chunk_actions(actions, self.pred_horizon) + actions_chunked = actions_chunked[:, :window_size] + # fold action_dim and pred_horizon into one dimension + actions_flat = rearrange(actions_chunked, "b w p a -> b w (p a)") + actions_flat = jnp.clip(actions_flat, -self.max_action, self.max_action) + + # piggy-back on the dropout rng chain for diffusion rng + rng = self.make_rng("dropout") + time_key, noise_key = jax.random.split(rng) + time = jax.random.randint(time_key, (batch_size, window_size, 1), 0, self.diffusion_steps) + noise = jax.random.normal(noise_key, actions_flat.shape) + + alpha_hat = self.alpha_hats[time] + alpha_1 = jnp.sqrt(alpha_hat) + alpha_2 = jnp.sqrt(1 - alpha_hat) + noisy_actions = alpha_1 * actions_flat + alpha_2 * noise + + pred_eps = self(transformer_outputs, train=train, time=time, noisy_actions=noisy_actions) + + loss, metrics = continuous_loss(pred_eps, noise, pad_mask[:, :, None], loss_type=self.loss_type) + # Sum over action dimension instead of averaging + loss = loss * self.action_dim + metrics["loss"] = metrics["loss"] * self.action_dim + metrics["mse"] = metrics["mse"] * self.action_dim + return loss, metrics + + def predict_action( + self, + transformer_outputs: Dict[str, TokenGroup], + rng: PRNGKey, + train: bool = True, + *args, + sample_shape: tuple = (), + **kwargs, + ) -> jax.Array: + """Convenience methods for predicting actions for the final timestep in the window.""" + module, variables = self.unbind() + + def scan_fn(carry, time): + current_x, 
rng = carry + input_time = jnp.broadcast_to(time, (*current_x.shape[:-1], 1)) + + eps_pred = module.apply(variables, transformer_outputs, input_time, current_x, train=train) + + alpha_1 = 1 / jnp.sqrt(self.alphas[time]) + alpha_2 = (1 - self.alphas[time]) / (jnp.sqrt(1 - self.alpha_hats[time])) + current_x = alpha_1 * (current_x - alpha_2 * eps_pred) + + rng, key = jax.random.split(rng) + z = jax.random.normal(key, shape=current_x.shape) + current_x = current_x + (time > 0) * (jnp.sqrt(self.betas[time]) * z) + + current_x = jnp.clip(current_x, -self.max_action, self.max_action) + + return (current_x, rng), () + + def sample_actions(rng): + rng, key = jax.random.split(rng) + batch_size, window_size = transformer_outputs[self.readout_key].tokens.shape[:2] + + (actions_flat, _), () = jax.lax.scan( + scan_fn, + ( + jax.random.normal( + key, + (batch_size, window_size, self.pred_horizon * self.action_dim), + ), + rng, + ), + jnp.arange(self.diffusion_steps - 1, -1, -1), + ) + + actions = rearrange( + actions_flat, + "b w (p a) -> b w p a", + p=self.pred_horizon, + a=self.action_dim, + ) + # only get the last timestep in the window + return actions[:, -1] + + n_samples = int(np.prod(sample_shape)) + actions = jax.vmap(sample_actions)(jax.random.split(rng, n_samples)) + actions = actions.reshape(sample_shape + actions.shape[1:]) + return actions + + +if __name__ == '__main__': + mod = DiffusionActionHead( + pred_horizon=5, + action_dim=2, + readout_key="readout_action", + ) + + input_to_model = TokenGroup(tokens=None, mask=None) + + print(111) diff --git a/scenestreamer/models/gen_model.py b/scenestreamer/models/gen_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2663fc1530fb0922b3d0955dea62b29edec274cd --- /dev/null +++ b/scenestreamer/models/gen_model.py @@ -0,0 +1,832 @@ +import logging + +import torch.nn as nn + +from scenestreamer.dataset import constants +from scenestreamer.models.layers import polyline_encoder, common_layers, 
def create_causal_mask(causal_mask_offset, num_heads=None):
    """Create the causal mask for a flattened token sequence.

    Tokens will not attend to tokens with a larger offset. Tokens that share
    the same offset (e.g. the agents of one step) can attend to each other.

    In the returned mask a row is a query and a column is a key, so a query
    with a larger offset sees more keys than a query with a smaller one.

    Args:
        causal_mask_offset: integer tensor of shape (B, L). An entry of -1 is
            treated as offset ``L``, i.e. it is placed after every regular
            offset. The tensor is NOT modified.
        num_heads: if given, the mask is repeated for every attention head and
            returned with shape (B * num_heads, L, L).

    Returns:
        Boolean tensor of shape (B, L, L), or (B * num_heads, L, L) when
        ``num_heads`` is set. ``True`` marks a (query, key) pair that must NOT
        be attended to.
    """
    B, L = causal_mask_offset.shape

    # Out-of-place fill: the caller's tensor keeps its -1 markers instead of
    # being silently rewritten as a side effect of building the mask.
    offset = causal_mask_offset.masked_fill(causal_mask_offset == -1, L)

    query_offset = offset.unsqueeze(2)  # Shape (B, L, 1)
    key_offset = offset.unsqueeze(1)  # Shape (B, 1, L)

    # A query may look at every key whose offset is not larger than its own;
    # the returned mask flags the complement.
    causal_mask = ~(query_offset >= key_offset)

    if num_heads is not None:
        causal_mask = causal_mask.unsqueeze(1).expand(B, num_heads, L, L).reshape(B * num_heads, L, L)
    return causal_mask
class SceneEncoderSceneStreamer(nn.Module):
    """Encode map polylines, agent history and traffic lights into scene tokens.

    Each modality is embedded to ``d_model`` channels, the three token sets are
    concatenated in the order (map, agent, traffic light) and passed through a
    stack of self-attention layers. The result is written back into the input
    dictionary under the ``encoder/scenario_*`` keys.
    """

    def __init__(self, config):
        """Build the per-modality embedders and the self-attention stack.

        Args:
            config: project config; reads ``MODEL.D_MODEL``,
                ``MODEL.NUM_ATTN_LAYERS``, ``MODEL.NUM_ATTN_HEAD``,
                ``MODEL.DROPOUT_OF_ATTN`` and the optional ``MODEL`` entries
                ``PRE_PROJECTION``, ``RELATIVE_PE`` and ``KNN``.
        """
        super().__init__()

        # TODO: Pass this from config or datasource
        SCENE_INPUT_TIME_STEPS = 11
        self.total_time_steps = SCENE_INPUT_TIME_STEPS
        self.config = config
        self.d_model = self.config.MODEL.D_MODEL
        self.num_layers = self.config.MODEL.NUM_ATTN_LAYERS

        self.map_polyline_encoder = polyline_encoder.PointNetPolylineEncoder(
            in_channels=constants.MAP_FEATURE_STATE_DIM,
            hidden_dim=64,
            num_layers=2,
            num_pre_layers=1,
            out_channels=self.d_model
        )
        # Agent / traffic-light features are flattened over time before the MLP,
        # hence the "* SCENE_INPUT_TIME_STEPS" input width.
        self.agent_mlps = common_layers.build_mlps(
            c_in=constants.AGENT_STATE_DIM * SCENE_INPUT_TIME_STEPS,
            mlp_channels=[self.d_model] * 3,
            ret_before_act=True,
        )
        self.light_mlps = common_layers.build_mlps(
            c_in=constants.TRAFFIC_LIGHT_STATE_DIM * SCENE_INPUT_TIME_STEPS,
            mlp_channels=[self.d_model] * 3,
            ret_before_act=True,
        )

        dropout = self.config.MODEL.DROPOUT_OF_ATTN
        self.self_attn_layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model=self.d_model,
                    nhead=self.config.MODEL.NUM_ATTN_HEAD,
                    dim_feedforward=self.d_model * 4,
                    dropout=dropout,
                    batch_first=True,
                    pre_projection=self.config.MODEL.get('PRE_PROJECTION', False),
                    relative_pe=self.config.MODEL.get('RELATIVE_PE', False),
                ) for _ in range(self.num_layers)
            ]
        )

        self.out = common_layers.build_mlps(
            c_in=self.d_model,
            mlp_channels=[self.d_model],
            ret_before_act=True,
        )

        self.relative_pe = self.config.MODEL.get('RELATIVE_PE', False)

    def forward(self, input_dict):
        """Embed the scene and store the scene tokens in ``input_dict``.

        Reads the ``encoder/*`` feature, mask, position, heading and ``*_pe``
        entries of ``input_dict``.

        Writes ``encoder/scenario_token``, ``encoder/scenario_valid_mask`` and
        ``encoder/map_pe``; with relative positional encoding enabled it also
        writes ``encoder/scenario_position`` and ``encoder/scenario_heading``.

        Returns:
            The same ``input_dict`` object, updated in place.
        """
        # The module header only has ``import torch.nn as nn``, which does not
        # bind the name ``torch``; import it here so this method also works
        # when the file is imported rather than run as a script.
        import torch

        # ===== Get shape =====
        B, _, N, _ = input_dict["encoder/agent_feature"].shape
        _, M, _, _ = input_dict["encoder/map_feature"].shape
        _, _, L, _ = input_dict["encoder/traffic_light_feature"].shape

        # ===== Embed agent feature =====
        agent_feature = input_dict["encoder/agent_feature"]
        agent_valid_mask = input_dict["encoder/agent_valid_mask"]
        agent_position = input_dict["encoder/agent_position"]
        agent_heading = input_dict["encoder/agent_heading"]
        assert agent_feature.shape[:3] == agent_position.shape[:3] == agent_valid_mask.shape[:3]
        # Zero out invalid entries of the first `total_time_steps` steps, then
        # merge the time axis into the feature axis.
        agent_feature = (
            agent_feature[:, :self.total_time_steps] * agent_valid_mask[:, :self.total_time_steps, ..., None]
        )
        agent_feature = collapse_time(agent_feature)
        agent_token = self.agent_mlps(agent_feature)  # (B, N, D)

        agent_pe = input_dict["encoder/agent_pe"]
        agent_token += agent_pe

        assert agent_token.shape == (B, N, self.d_model)

        # ===== Embed map feature =====
        map_feature = input_dict["encoder/map_feature"]
        map_valid_mask = input_dict["encoder/map_feature_valid_mask"]
        map_position = input_dict["encoder/map_position"]
        map_heading = input_dict["encoder/map_heading"]
        map_token_valid_mask = input_dict["encoder/map_valid_mask"]
        map_token = self.map_polyline_encoder(map_feature, map_valid_mask)

        map_pe = input_dict["encoder/map_pe"]
        map_token += map_pe

        assert map_token.shape == (B, M, self.d_model)

        # ===== Embed traffic light =====
        traffic_light_feature = input_dict["encoder/traffic_light_feature"]
        traffic_light_position = input_dict["encoder/traffic_light_position"]
        traffic_light_heading = input_dict["encoder/traffic_light_heading"]
        traffic_light_valid_mask = input_dict["encoder/traffic_light_valid_mask"]
        if L != 0:
            traffic_light_feature = (
                traffic_light_feature[:, :self.total_time_steps] *
                traffic_light_valid_mask[:, :self.total_time_steps, ..., None]
            )
            traffic_light_feature = collapse_time(traffic_light_feature)
            traffic_light_token = self.light_mlps(traffic_light_feature)
        else:
            # No traffic lights in this batch: keep an empty (B, 0, D) slot so
            # the concatenation below still works.
            traffic_light_token = traffic_light_feature.new_zeros([B, L, self.d_model])
        assert traffic_light_token.shape == (B, L, self.d_model)

        # ===== Call transformer layers =====
        # NOTE(review): features use steps [0, total_time_steps) while
        # position / heading / validity are read at index `total_time_steps`
        # itself -- confirm this is the intended "current" step.
        x = torch.concatenate([map_token, agent_token, traffic_light_token], dim=1)
        x_pos = torch.concatenate(
            [map_position, agent_position[:, self.total_time_steps], traffic_light_position], dim=1
        )

        x_mask = torch.concatenate(
            [
                map_token_valid_mask, agent_valid_mask[:, self.total_time_steps],
                traffic_light_valid_mask[:, self.total_time_steps]
            ],
            dim=1
        )
        # Every sample needs at least one valid token.
        assert torch.all(x_mask.sum(dim=-1) > 0)

        if self.relative_pe:
            x_heading = torch.concatenate(
                [map_heading, agent_heading[:, self.total_time_steps], traffic_light_heading], dim=1
            )
            relation, rel_mask, indices = compute_relation(
                pos=x_pos,
                heading=x_heading,
                mask=x_mask,
                hidden_dim=self.d_model,
                knn=self.config.MODEL.get('KNN', 128)
            )
            pos_embedding = None
        else:
            # Without relative PE there is no relation tensor at all; the mask
            # and indices must still be defined because they are handed to
            # every layer below (they used to be unbound on this path).
            relation = None
            rel_mask = None
            indices = None
            pos_embedding = position_encoding_utils.gen_sineembed_for_position(x_pos[..., 0:2], hidden_dim=self.d_model)

        for layer in self.self_attn_layers:
            x = layer(
                tgt=x,
                pos=pos_embedding,
                tgt_key_padding_mask=~x_mask,
                relation=relation,
                relation_mask=rel_mask,
                relation_indices=indices,
            )

        # Apply the output MLP token-wise, then restore the leading dimensions.
        x = self.out(x.reshape(-1, x.shape[-1])).reshape(list(x.shape[:-1]) + [self.d_model])

        if pos_embedding is not None:
            x = x + pos_embedding

        input_dict["encoder/scenario_token"] = x
        if self.relative_pe:
            input_dict["encoder/scenario_position"] = x_pos
            input_dict["encoder/scenario_heading"] = x_heading
        input_dict["encoder/scenario_valid_mask"] = x_mask

        input_dict["encoder/map_pe"] = map_pe

        return input_dict
+ self.step_pe = nn.Embedding(num_pred_steps, d_model) + + intra_step_tokens = 1000 + self.intra_step_pe = nn.Embedding(intra_step_tokens, d_model) + + def forward(self, input_dict, use_cache=False): + # === Process scene embedding === + scene_token = input_dict["encoder/scenario_token"] + scenario_valid_mask = input_dict["encoder/scenario_valid_mask"] + modeled_agent_pe = input_dict["encoder/modeled_agent_pe"] + scene_padding_mask = ~scenario_valid_mask + + # === Process action embedding === + input_token = input_dict["decoder/input_token"] + + step_pe = self.step_pe(input_dict["decoder/input_step"]) + input_token += step_pe + + intra_step_pe = self.intra_step_pe(input_dict["decoder/input_intra_step"]) + input_token += intra_step_pe + + input_token_valid_mask = input_dict["decoder/input_token_valid_mask"] + + # assert action_token.shape == (B, T_skipped, N, self.d_model) + # assert modeled_agent_pe.shape == (B, N, self.d_model), modeled_agent_pe.shape + # action_token += modeled_agent_pe[:, None] + + casual_mask = create_causal_mask(input_dict["decoder/causal_mask_offset"], num_heads=self.num_heads) + + action_padding_mask = ~input_token_valid_mask # (T_skipped, N) + # Flatten action token from (B, T_skipped, N, D) to (B, T_skipped*N, D) + # action_token = action_token.flatten(1, 2) + # Flatten action token from (B, T_skipped, N) to (B, T_skipped*N) + # action_padding_mask = action_padding_mask.flatten(1, 2) + + # Cache from last rollout + past_key_value = None + if "decoder/cache" in input_dict: + past_key_value = input_dict["decoder/cache"] + + # === Call models === + decoded_tokens = self.decoder( + tgt=input_token.swapaxes(0, 1), + tgt_mask=casual_mask, # swapaxes(0, 1), + tgt_key_padding_mask=action_padding_mask, + tgt_is_causal=True, + memory=scene_token.swapaxes(0, 1), + memory_mask=None, # The casual mask for memory + memory_key_padding_mask=scene_padding_mask, + memory_is_causal=False, + past_key_value=past_key_value, + use_cache=use_cache + ) + + if 
class GenModel(MotionLM):
    """Token-based generative motion model: scene encoder + causal motion decoder."""

    def __init__(self, config):
        """Create the scene encoder, the token embedding table and the decoder.

        Args:
            config: project config; the number of token ids is obtained from
                ``SceneStreamerTokenizer.get_num_actions(config)``.
        """
        super().__init__(config)
        self.config = config
        self.d_model = self.config.MODEL.D_MODEL

        self.scene_encoder = SceneEncoderSceneStreamer(config=self.config)

        num_actions = SceneStreamerTokenizer.get_num_actions(config)

        self.tokenizer = Tokenizer(num_actions=num_actions, d_model=self.d_model)
        self.motion_decoder = MotionDecoder(config=self.config, num_actions=num_actions)

    def encode_scene(self, input_dict):
        """Attach id embeddings to ``input_dict`` and run the scene encoder.

        Writes ``encoder/map_pe``, ``encoder/agent_pe`` and
        ``encoder/modeled_agent_pe`` and returns the scene encoder's output.
        """
        # The module header only has ``import torch.nn as nn``, which does not
        # bind the name ``torch``.
        import torch

        B, M, _, _ = input_dict["encoder/map_feature"].shape

        map_id = torch.arange(M).to(input_dict["encoder/map_feature"].device).reshape(1, M).repeat(B, 1)
        # NOTE(review): this zeroes the ids at positions where map_valid_mask is
        # True -- confirm the mask polarity is intended here.
        map_id.masked_fill_(input_dict["encoder/map_valid_mask"], 0)
        map_id = SceneStreamerTokenizer.get_map_id(map_id, self.config)
        input_dict["encoder/map_pe"] = self.tokenizer(map_id)

        agent_id = SceneStreamerTokenizer.get_agent_id(input_dict["encoder/agent_id"], self.config, allow_invalid=False)
        input_dict["encoder/agent_pe"] = self.tokenizer(agent_id)

        modeled_agent_id = SceneStreamerTokenizer.get_agent_id(
            input_dict["encoder/modeled_agent_id"], self.config, allow_invalid=False
        )
        input_dict["encoder/modeled_agent_pe"] = self.tokenizer(modeled_agent_id)

        return self.scene_encoder(input_dict)

    def decode_motion(self, data_dict, use_cache=False, in_evaluation=False):
        """Run the decoder and scatter the agent-token logits into (B, T, N, D).

        The decoder produces logits for the flattened token sequence. Only the
        logits at valid agent-token positions are kept; they are written to the
        valid entries of a zero tensor shaped like ``decoder/input_action`` plus
        a trailing logit axis, which replaces ``decoder/output_logit``.

        Args:
            data_dict: batch dictionary, updated in place and returned.
            use_cache: forwarded to the motion decoder.
            in_evaluation: if True every step of ``decoder/input_action`` is
                filled; otherwise the last step is left at zero.
        """
        import torch  # see encode_scene: `torch` is not bound at module level

        data_dict["decoder/input_token"] = self.tokenizer(data_dict["decoder/input_token_id"], allow_invalid=True)
        data_dict = self.motion_decoder(data_dict, use_cache=use_cache)

        is_agent_tokens = GenTokenizer.is_agent_tokens(data_dict["decoder/input_token_id"], self.config)
        motion_token_valid_mask = torch.logical_and(data_dict["decoder/input_token_valid_mask"], is_agent_tokens)

        flat_logits = data_dict["decoder/output_logit"]
        B, _, D = flat_logits.shape
        _, T_plus_1, N = data_dict["decoder/input_action"].shape

        motion_logits = flat_logits.new_zeros(B, T_plus_1, N, D)
        action_valid_mask = data_dict["decoder/input_action_valid_mask"]

        if in_evaluation:
            motion_logits[action_valid_mask] = flat_logits[motion_token_valid_mask]
        else:
            motion_logits[:, :-1][action_valid_mask[:, :-1]] = flat_logits[motion_token_valid_mask]

        # TODO: FIXME: Do we want to implement the "masking out invalid actions" here?
        #  Should we implement the "masking out invalid actions" in loss?
        data_dict["decoder/output_logit"] = motion_logits

        return data_dict

    def autoregressive_rollout(
        self,
        input_dict,
        num_decode_steps,
        num_prev_steps=1,
        use_cache=True,
        sampling_method="softmax",
        temperature=None,
        topp=None,
        num_modes_for_eval=None
    ):
        """Sample ``num_decode_steps`` actions per agent, one step at a time.

        Token layout per step:
            step 0 input: the tokens already present in ``input_dict``
            predict 0:    (ACTION_ID * N)
            step k input: (ACTION_ID * N), UPDATE_END, STEP_END, followed by a
                          copy of the step-0 tokens

        Args:
            input_dict: batch dictionary, updated in place and returned.
            num_decode_steps: number of steps to decode, must be >= 1.
            num_prev_steps: unused, kept for interface compatibility.
            use_cache: must be True; the rollout relies on the decoder's
                key/value cache.
            sampling_method: "argmax", "softmax" or "topp".
            temperature: logit temperature, defaults to ``SAMPLING.TEMPERATURE``.
            topp: nucleus threshold, defaults to ``SAMPLING.TOPP``.
            num_modes_for_eval: unused, kept for interface compatibility.

        Returns:
            ``input_dict`` with ``decoder/output_logit`` (unmasked logits),
            ``decoder/output_action`` (ids shifted to start at 0) and
            ``decoder/output_score`` added.

        Raises:
            ValueError: if ``use_cache`` is False or ``sampling_method`` is unknown.
        """
        import torch  # see encode_scene: `torch` is not bound at module level

        if temperature is None:
            temperature = self.config.SAMPLING.TEMPERATURE
        if topp is None:
            topp = self.config.SAMPLING.TOPP

        assert num_decode_steps >= 1
        if not use_cache:
            # The non-cached path was never implemented.
            raise ValueError("autoregressive_rollout only supports use_cache=True")

        B, _, N = input_dict["decoder/input_action"].shape

        # Get scene embedding
        input_dict = self.encode_scene(input_dict)
        output_logit_list = []
        output_logit_masked_list = []
        output_action_list = []

        device = input_dict["decoder/input_action"].device

        # === prepare those reusable tokens that will be appended at the end of the sequence at each step ===
        # [STEP_START, UPDATE_START, (AGENT_ID * N), ]
        pre_action_tokens = Tokens.create(
            ids=input_dict["decoder/input_token_id"].clone(),
            mask=input_dict["decoder/input_token_valid_mask"].clone(),
            causal_mask_offset=input_dict["decoder/causal_mask_offset"].clone(),
            length=input_dict["decoder/input_token_id"].shape[1]
        )

        # [UPDATE_END, STEP_END, ]
        update_end_tokens = Tokens.concatenate(
            [
                GenTokenizer.get_update_end_tokens(),
                GenTokenizer.get_step_end_tokens(),
            ]
        ).to_tensor(
            batch_size=B, device=device
        )

        # [UPDATE_END, STEP_END, STEP_START, UPDATE_START, (AGENT_ID * N),]
        post_action_tokens = Tokens.concatenate([update_end_tokens, pre_action_tokens])

        pre_intra_steps = input_dict["decoder/input_token_id"].shape[1]

        # Intra-step index of the tokens fed at every step after the first:
        # actions and end markers continue the numbering of the current step,
        # the repeated step-0 tokens restart from 0.
        intra_steps = torch.cat(
            [
                torch.arange(pre_intra_steps, pre_intra_steps + N + update_end_tokens.length),
                torch.arange(pre_intra_steps),
            ]
        ).to(device).reshape(1, -1).expand(B, -1)

        action_id_min, action_id_max = GenTokenizer.get_action_id_range(self.config)
        # You can't select the noop action. So we force:
        action_id_max = action_id_max - 1

        for decode_step in range(num_decode_steps):
            logger.debug(f"======================= STEP {decode_step=} =======================")

            # Decode motion ids
            input_dict = self.decode_motion(input_dict, use_cache=use_cache, in_evaluation=True)

            output_token = input_dict["decoder/output_logit"]
            assert output_token.shape[:3] == (B, 1, N)
            output_token = output_token[:, -1:]  # -> output_token.shape == (B, 1, N, #actions)

            # Keep an untouched copy before the in-place masking below.
            output_logit_list.append(output_token.clone())

            # mask out invalid actions
            output_token[..., :action_id_min].fill_(-1e9)
            output_token[..., action_id_max:].fill_(-1e9)
            output_logit_masked_list.append(output_token)

            # Sample the action
            if sampling_method == "argmax":
                selected_action = output_token.argmax(-1)
            elif sampling_method == "softmax":
                selected_action = torch.distributions.Categorical(logits=output_token / temperature).sample()
            elif sampling_method == "topp":
                selected_action = nucleus_sampling(logits=output_token / temperature, p=topp)
            else:
                raise ValueError("Unknown sampling method: {}".format(sampling_method))

            assert selected_action.max() < action_id_max
            assert selected_action.min() >= action_id_min

            output_action_list.append(selected_action)

            action_tokens = Tokens.create(
                ids=selected_action.reshape(B, N),
                mask=input_dict["decoder/input_action_valid_mask"].reshape(B, N),
                causal_mask_offset=selected_action.new_ones(B, N).fill_(N).int(),
                length=N
            )
            input_tokens = Tokens.concatenate([action_tokens, post_action_tokens])

            # Only the new tokens are fed next; keys/values of earlier tokens
            # live in the decoder cache.
            input_dict["decoder/input_token_id"] = input_tokens.ids
            input_dict["decoder/input_token_valid_mask"] = input_tokens.mask
            input_dict["decoder/input_step"] = torch.ones_like(input_tokens.ids).fill_(decode_step + 1)
            input_dict["decoder/input_intra_step"] = intra_steps
            input_dict["decoder/causal_mask_offset"] = input_tokens.causal_mask_offset

            assert input_dict["decoder/input_action"].shape == input_dict["decoder/input_action_valid_mask"].shape

        output_action_list = torch.concatenate(output_action_list, dim=1)
        assert output_action_list.shape == (B, num_decode_steps, N)

        output_logit_list = torch.concatenate(output_logit_list, dim=1)
        output_logit_masked_list = torch.concatenate(output_logit_masked_list, dim=1)
        input_dict["decoder/output_logit"] = output_logit_list

        # Need to translate back to normal action range
        input_dict["decoder/output_action"] = output_action_list - action_id_min
        assert input_dict["decoder/output_action"].min() >= 0
        # Sampled ids lie in [action_id_min, action_id_max) with the noop id
        # already excluded, so the shifted ids are below the width of that range.
        # (The previous check referenced an undefined NUM_ACTIONS constant.)
        assert input_dict["decoder/output_action"].max() < action_id_max - action_id_min

        # TODO: Study whether masked or unmasked logits give the better score.
        input_dict["decoder/output_score"] = calculate_trajectory_probabilities(
            output_logit_masked_list, output_action_list, mask=input_dict["decoder/input_action_valid_mask"]
        )  # (B, N)

        return input_dict
+ + SceneStreamerModel: + step 0 input: STEP_START, ADD_START, (some tokens for add), ADD_END, UPDATE_START, (AGENT_ID * N) + predict 0: (ACTION_ID * N) + step 1 input: (ACTION_ID * N), UPDATE_END, STEP_END, STEP_START, UPDATE_START, (AGENT_ID * N) + (that is, we just pretent the model will never remove or add new agents) + """ + if temperature is None: + temperature = self.config.SAMPLING.TEMPERATURE + if topp is None: + topp = self.config.SAMPLING.TOPP + + assert num_decode_steps >= 1 + + B, _, N = input_dict["decoder/input_action"].shape + + # Get scene embedding + input_dict = self.encode_scene(input_dict) + output_logit_list = [] + output_logit_masked_list = [] + output_action_list = [] + + device = input_dict["decoder/input_action"].device + + # === prepare those reusable tokens that will be appended at the end of the sequence at each step === + + B = input_dict["decoder/input_token_valid_mask"].shape[0] + + # (B,) + seq_start_indices = input_dict["decoder/input_token_valid_mask"].new_zeros(B).long() + seq_end_indices = input_dict["decoder/input_token_valid_mask"].sum(-1) + # (B,) + num_motions = input_dict["eval/should_predict_motion"].sum(-1) + + from scenestreamer.utils.autoregressive_rollout import ARRollout + + map_ids = [input_dict["encoder/map_valid_mask"][i].nonzero()[:, 0] for i in range(B)] + agent_ids = [ + input_dict["encoder/agent_id"][i][input_dict["decoder/input_action_valid_mask"][i, 0]] for i in range(B) + ] + rollout = ARRollout( + init_tokens=input_dict["decoder/input_token_id"], + init_valid_mask=input_dict["decoder/input_token_valid_mask"], + causal_mask_offset=input_dict["decoder/causal_mask_offset"], + config=self.config, + map_ids=map_ids, + agent_ids=input_dict["encoder/agent_id"], + ) + + # + # pre_action_tokens = Tokens.create( + # ids=input_dict["decoder/input_token_id"].clone(), + # mask=input_dict["decoder/input_token_valid_mask"].clone(), + # causal_mask_offset=input_dict["decoder/causal_mask_offset"].clone(), + # 
length=input_dict["decoder/input_token_id"].shape[1] + # ) + # + # # [UPDATE_END, STEP_END, ] + # update_end_tokens = Tokens.concatenate( + # [ + # GenTokenizer.get_update_end_tokens(), + # GenTokenizer.get_step_end_tokens(), + # ] + # ).to_tensor(batch_size=B, device=device) + # + # # [UPDATE_END, STEP_END, STEP_START, UPDATE_START, (AGENT_ID * N),] + # post_action_tokens = Tokens.concatenate([update_end_tokens, pre_action_tokens]) + # + # + # pre_intra_steps = input_dict["decoder/input_token_id"].shape[1] # 0, 1, ..., 129 (130 steps) + # + # # intra step for new tokens + # intra_steps = torch.cat([ + # torch.arange(pre_intra_steps, pre_intra_steps + N + update_end_tokens.length), + # torch.arange(pre_intra_steps), + # ]).to(device).reshape(1, -1).expand(B, -1) + # + action_id_min, action_id_max = SceneStreamerTokenizer.get_action_id_range(self.config) + # # You can't select the noop action. So we force: + # action_id_max = action_id_max - 1 + # + # assert len(np.unique(input_dict["scenario_id"])) == 1 + + for decode_step in range(num_decode_steps): + logger.debug(f"======================= STEP {decode_step=} =======================") + + if not use_cache: + raise ValueError() + input_dict["decoder/input_step"] = input_step[:decode_step + 1] + + # input_tokens = rollout.get_tokens() + + # Decode motion ids + input_dict = self.decode_motion(input_dict, use_cache=use_cache, in_evaluation=True) + + output_token = input_dict["decoder/output_logit"] + + # No matter what is it, just treat the output tokens as they are motion tokens. 
+ # output_token[..., :action_id_min].fill_(-1e9) + # output_token[..., action_id_max:].fill_(-1e9) + # output_logit_masked_list.append(output_token) # TODO: Deal with scores + # + # # Sample the action + # if sampling_method == "argmax": + # selected_action = output_token.argmax(-1) + # elif sampling_method == "softmax": + # selected_action = torch.distributions.Categorical(logits=output_token / temperature).sample() + # elif sampling_method == "topp": + # selected_action = nucleus_sampling(logits=output_token / temperature, p=topp) + # else: + # raise ValueError("Unknown sampling method: {}".format(sampling_method)) + # + # assert selected_action.max() < action_id_max + # assert selected_action.min() >= action_id_min + + # Have a for loop here.... + # new_tokens = [] + # for b in range(B): + # motion_start = seq_end_indices[b] - num_motions[b] + # motion_end = seq_end_indices[b] + # a = selected_action[b,motion_start: motion_end] + # new_tokens.append(a) + + rollout.update(output_token) + input_tokens = rollout.get_tokens() + # print(1111) + + # output_action_list.append(selected_action) + # + # action_tokens = Tokens.create( + # ids=selected_action.reshape(B, N), + # mask=input_dict["decoder/input_action_valid_mask"].reshape(B, N), + # causal_mask_offset=selected_action.new_ones(B, N).fill_(N).int(), + # length=N + # ) + # input_tokens = Tokens.concatenate([action_tokens, post_action_tokens]) + + if use_cache: + # Discard the previous ids whose key/value are cached. 
+ input_dict["decoder/input_token_id"] = input_tokens.ids + input_dict["decoder/input_token_valid_mask"] = input_tokens.mask + input_dict["decoder/input_step"] = torch.ones_like(input_tokens.ids).fill_(decode_step + 1) + input_dict["decoder/input_intra_step"] = intra_steps + input_dict["decoder/causal_mask_offset"] = input_tokens.causal_mask_offset + + else: + raise ValueError() + input_dict["decoder/input_token_id"] = torch.cat( + [input_dict["decoder/input_token_id"], new_tokens], dim=1 + ) + input_dict["decoder/input_action_valid_mask"] = torch.cat( + [input_dict["decoder/input_action_valid_mask"], step_valid_mask], dim=1 + ) + + assert input_dict["decoder/input_action"].shape == input_dict["decoder/input_action_valid_mask"].shape + + output_action_list = torch.concatenate(output_action_list, dim=1) + assert output_action_list.shape == (B, num_decode_steps, N) + + output_logit_list = torch.concatenate(output_logit_list, dim=1) + output_logit_masked_list = torch.concatenate(output_logit_masked_list, dim=1) + input_dict["decoder/output_logit"] = output_logit_list + + # Need to translate back to normal action range + input_dict["decoder/output_action"] = output_action_list - action_id_min + assert input_dict["decoder/output_action"].min() >= 0 + assert input_dict["decoder/output_action"].max() < NUM_ACTIONS + 1 # There is also a noop action. 
+ + # TODO: Study which one is better + # input_dict["decoder/output_score"] = calculate_trajectory_probabilities( + # output_logit_list, output_action_list, mask=input_dict["decoder/input_action_valid_mask"] + # ) # (B, N) + input_dict["decoder/output_score"] = calculate_trajectory_probabilities( + output_logit_masked_list, output_action_list, mask=input_dict["decoder/input_action_valid_mask"] + ) # (B, N) + + return input_dict + + +if __name__ == '__main__': + + import torch + from tqdm import tqdm + + from scenestreamer.dataset.datamodule import SceneStreamerDataModule + from scenestreamer.utils import debug_tools + + cfg_file = "cfgs/motion_debug_2_local_train.yaml" + config = debug_tools.get_debug_config(cfg_file=cfg_file) + + config.MODEL.update(dict( + D_MODEL=512, + NUM_ATTN_LAYERS=6, + NUM_ATTN_HEAD=8, + NUM_DECODER_LAYERS=6, + )) + + config.MODEL.NAME = "gen" + + datamodule = SceneStreamerDataModule( + config, + train_batch_size=2, + train_num_workers=0, + val_batch_size=2, + val_num_workers=0, + train_prefetch_factor=2, + val_prefetch_factor=1 + ) + datamodule.setup("fit") + dataloader = datamodule.val_dataloader() + + model = GenModel(config) + model.eval() + # model.cuda() + + for data_dict in tqdm(dataloader): + # GenTokenizer.get_token_names(data_dict["decoder/input_token_id"], config) + + # data_dict = model(data_dict) + + model.autoregressive_rollout(data_dict, num_decode_steps=16, sampling_method="topp") + + # gt = data_dict["decoder/target_action"] + # gt_mask = data_dict["decoder/target_action_valid_mask"] + + # gt = data_dict["decoder/input_token_id"][is_motion_token] + # logits = data_dict["decoder/output_logit"][is_motion_token] + + # we need to reconstruct the "output_logit" in shape (B, T, N, D) so we can match decoder/target_action + + print(1) diff --git a/scenestreamer/models/gpt_scene_encoder.py b/scenestreamer/models/gpt_scene_encoder.py new file mode 100644 index 
class SceneEncoderGPT(nn.Module):
    """Encode map polylines, traffic lights and (optionally) agent history into scene tokens.

    Every scene element becomes one token of width ``MODEL.D_MODEL``. The tokens are
    concatenated along dim 1 in the order ``[map, traffic light, agent]`` and passed
    through a stack of relation-aware self-attention layers. Results are written back
    into the input dictionary under ``encoder/scenario_*`` and ``encoder/map_token``.
    """

    def __init__(self, config):
        """Build the per-modality embedders, the relation embedding(s) and the encoder stack.

        Args:
            config: project config object. The fields read here are
                ``MODEL.{D_MODEL, NUM_ATTN_LAYERS, NUM_ATTN_HEAD, DROPOUT, IS_V7,
                ADD_RELATION_TO_V, NAME}``, ``PREPROCESSING.{REMOVE_TRAFFIC_LIGHT_STATE,
                MAX_AGENTS}``, ``SIMPLE_RELATION``, ``SIMPLE_RELATION_FACTOR``,
                ``UPDATE_RELATION``, ``REMOVE_REL_NORM`` and
                ``REMOVE_AGENT_FROM_SCENE_ENCODER``.
        """
        super().__init__()
        # TODO: Pass this from config or datasource
        SCENE_INPUT_TIME_STEPS = 11
        self.history_steps = SCENE_INPUT_TIME_STEPS
        self.config = config
        self.d_model = self.config.MODEL.D_MODEL
        self.num_layers = self.config.MODEL.NUM_ATTN_LAYERS
        self.num_heads = self.config.MODEL.NUM_ATTN_HEAD

        dropout = self.config.MODEL.DROPOUT

        is_v7 = self.config.MODEL.IS_V7
        self.is_v7 = is_v7

        self.map_polyline_encoder = polyline_encoder.PointNetPolylineEncoder(
            in_channels=constants.MAP_FEATURE_STATE_DIM,
            hidden_dim=64,
            num_layers=2,
            num_pre_layers=1,
            out_channels=self.d_model,
            is_v7=is_v7
        )

        if self.config.PREPROCESSING.REMOVE_TRAFFIC_LIGHT_STATE:
            # The input is all zeros, so we can just use a single layer MLP.
            self.light_mlps = common_layers.build_mlps(
                c_in=constants.TRAFFIC_LIGHT_STATE_DIM, mlp_channels=[self.d_model], ret_before_act=True, is_v7=is_v7
            )
        else:
            self.light_mlps = common_layers.build_mlps(
                c_in=constants.TRAFFIC_LIGHT_STATE_DIM * SCENE_INPUT_TIME_STEPS,
                mlp_channels=[self.d_model] * 3,
                ret_before_act=True,
                is_v7=is_v7
            )

        simple_relation_factor = self.config.SIMPLE_RELATION_FACTOR
        simple_relation = self.config.SIMPLE_RELATION
        add_relation_to_v = self.config.MODEL.ADD_RELATION_TO_V
        if simple_relation:
            relation_in_dim = 3
            relation_d_model = self.d_model // simple_relation_factor
        else:
            relation_in_dim = 4
            relation_d_model = self.d_model
        self.relation_embed = fourier_embedding.FourierEmbedding(
            input_dim=relation_in_dim, hidden_dim=relation_d_model, num_freq_bands=64, is_v7=is_v7
        )
        # Always define the attribute. It used to exist only for SIMPLE_RELATION, so forward()
        # raised AttributeError for the full-relation + ADD_RELATION_TO_V combination.
        # NOTE(review): the value-side embedding mirrors the key-side one in both modes;
        # confirm this is what the encoder layer expects for the full-relation case.
        if add_relation_to_v:
            self.relation_embed_v = fourier_embedding.FourierEmbedding(
                input_dim=relation_in_dim, hidden_dim=relation_d_model, num_freq_bands=64, is_v7=is_v7
            )
        else:
            self.relation_embed_v = None

        assert self.config.MODEL.NAME in ['gpt']
        self.encoder = SelfAttTransformerEncoder(
            decoder_layer=SelfAttTransformerEncoderLayer(
                d_model=self.d_model,
                nhead=self.num_heads,
                simple_relation=simple_relation,
                simple_relation_factor=simple_relation_factor,
                dropout=dropout,
                is_v7=is_v7,
                update_relation=self.config.UPDATE_RELATION,
                add_relation_to_v=add_relation_to_v,
                remove_rel_norm=self.config.REMOVE_REL_NORM,
            ),
            num_layers=self.num_layers,
        )
        self.out = common_layers.build_mlps(
            c_in=self.d_model, mlp_channels=[self.d_model], ret_before_act=True, is_v7=is_v7
        )
        self.out_prenorm = nn.LayerNorm(self.d_model)

        self.use_agent_history = not self.config.REMOVE_AGENT_FROM_SCENE_ENCODER
        if self.use_agent_history:
            self.agent_pe = common_layers.Tokenizer(
                num_actions=self.config.PREPROCESSING.MAX_AGENTS, d_model=self.d_model, add_one_more_action=False
            )
            # NOTE(review): unlike the other MLPs this one is built without ``is_v7``; confirm intended.
            self.agent_mlps = common_layers.build_mlps(
                c_in=constants.AGENT_STATE_DIM * SCENE_INPUT_TIME_STEPS,
                mlp_channels=[self.d_model] * 3,
                ret_before_act=True,
            )

    def encode_agent_history(self, input_dict):
        """Embed the first ``history_steps`` frames of every agent into one token per agent.

        Reads ``encoder/agent_{feature,valid_mask,position,heading,id}``,
        ``encoder/modeled_agent_id`` and ``in_evaluation`` from ``input_dict``.
        Writes ``encoder/randomized_agent_id``, ``encoder/randomized_modeled_agent_id``
        and ``encoder/modeled_agent_pe`` back into it.

        Returns:
            Tuple ``(agent_token, agent_pos, agent_mask, agent_heading)``, where
            ``agent_token`` has shape ``(B, N, d_model)``. Position and heading come from
            ``find_last_valid`` over the history window.

        Raises:
            ValueError: if ``MODEL.RANDOMIZE_AGENT_ID`` is disabled.
        """
        B, _, N, _ = input_dict["encoder/agent_feature"].shape
        in_evaluation = input_dict["in_evaluation"][0].item()
        max_agents = self.config.PREPROCESSING.MAX_AGENTS

        # ===== Embed agent feature =====
        agent_feature = input_dict["encoder/agent_feature"]
        agent_valid_mask = input_dict["encoder/agent_valid_mask"]
        agent_position = input_dict["encoder/agent_position"]
        agent_heading = input_dict["encoder/agent_heading"]
        agent_id = input_dict["encoder/agent_id"].clone()
        assert agent_feature.shape[:3] == agent_position.shape[:3] == agent_valid_mask.shape[:3]
        # Zero out invalid frames before flattening the time axis into the feature axis.
        agent_feature = (agent_feature[:, :self.history_steps] * agent_valid_mask[:, :self.history_steps, ..., None])
        agent_feature = collapse_time(agent_feature)
        agent_token = self.agent_mlps(agent_feature)  # (B, N, D)
        if in_evaluation:
            # Evaluation skips the max-agent filtering, so ids might be out of bound.
            agent_id = mode_agent_id(agent_id, max_agents, fill_negative_1=True)
            modeled_agent_id = mode_agent_id(
                input_dict["encoder/modeled_agent_id"].clone(), max_agents, fill_negative_1=True
            )
        else:
            modeled_agent_id = input_dict["encoder/modeled_agent_id"].clone()

        if not self.config.MODEL.RANDOMIZE_AGENT_ID:
            raise ValueError("Please turn on MODEL.RANDOMIZE_AGENT_ID=True")

        # Draw a random permutation of identity slots per sample (sampling without replacement).
        weights = torch.ones(max_agents).expand(B, -1)
        if N > max_agents:
            new_encoder_agent_id = torch.full_like(agent_id, -1)
            num_samples = max_agents
            new_encoder_agent_id[:, :num_samples] = torch.multinomial(
                weights, num_samples=num_samples, replacement=False
            ).to(agent_id)
            assert (agent_id[:, num_samples:] == -1).all()
        else:
            num_samples = N
            new_encoder_agent_id = torch.multinomial(weights, num_samples=num_samples, replacement=False).to(agent_id)
        new_encoder_agent_id[agent_id == -1] = -1
        input_dict["encoder/randomized_agent_id"] = new_encoder_agent_id
        agent_id = new_encoder_agent_id

        # Map modeled-agent indices through the same permutation; -1 entries stay -1.
        modeled_agent_mask = modeled_agent_id == -1
        modeled_agent_id[modeled_agent_mask] = N - 1  # Quick workaround: any in-range index, overwritten below.
        new_modeled_agent_id = torch.gather(new_encoder_agent_id, dim=1, index=modeled_agent_id)
        new_modeled_agent_id[modeled_agent_mask] = -1
        input_dict["encoder/randomized_modeled_agent_id"] = new_modeled_agent_id
        modeled_agent_id = new_modeled_agent_id

        agent_pe = self.agent_pe(agent_id)  # (B, N, D)
        agent_token += agent_pe
        assert agent_token.shape == (B, N, self.d_model)

        history_mask = agent_valid_mask[:, :self.history_steps]
        agent_pos = find_last_valid(agent_position[:, :self.history_steps], history_mask)[:, 0]
        agent_mask = history_mask.any(dim=1)
        agent_heading = find_last_valid(agent_heading[:, :self.history_steps, ..., None], history_mask)[:, 0, :, 0]

        input_dict["encoder/modeled_agent_pe"] = self.agent_pe(modeled_agent_id)

        return agent_token, agent_pos, agent_mask, agent_heading

    def forward(self, input_dict):
        """Encode the whole scene and store the resulting tokens in ``input_dict``.

        Reads the ``encoder/map_*`` and ``encoder/traffic_light_*`` entries (plus the agent
        entries used by :meth:`encode_agent_history` when agent history is enabled).
        Writes ``encoder/scenario_token``, ``encoder/map_token``,
        ``encoder/scenario_position``, ``encoder/scenario_heading`` and
        ``encoder/scenario_valid_mask``.

        Returns:
            The same ``input_dict`` with the keys above added.
        """
        # ===== Get shape =====
        # The 4-way unpack also checks that the map feature is a 4-D tensor.
        B, M, _, _ = input_dict["encoder/map_feature"].shape
        L, _ = input_dict["encoder/traffic_light_feature"].shape[-2:]
        remove_light_state = self.config.PREPROCESSING.REMOVE_TRAFFIC_LIGHT_STATE

        # ===== Embed map feature =====
        map_feature = input_dict["encoder/map_feature"]
        map_valid_mask = input_dict["encoder/map_feature_valid_mask"]
        map_position = input_dict["encoder/map_position"]
        map_heading = input_dict["encoder/map_heading"]
        map_token_valid_mask = input_dict["encoder/map_valid_mask"]
        map_token = self.map_polyline_encoder(map_feature, map_valid_mask)
        assert map_token.shape == (B, M, self.d_model)

        # ===== Embed traffic light =====
        traffic_light_feature = input_dict["encoder/traffic_light_feature"]
        traffic_light_position = input_dict["encoder/traffic_light_position"]
        traffic_light_heading = input_dict["encoder/traffic_light_heading"]
        traffic_light_valid_mask = input_dict["encoder/traffic_light_valid_mask"]
        if L != 0:
            if remove_light_state:
                traffic_light_feature = traffic_light_feature * traffic_light_valid_mask[..., None]
            else:
                traffic_light_feature = (
                    traffic_light_feature[:, :self.history_steps] *
                    traffic_light_valid_mask[:, :self.history_steps, ..., None]
                )
                traffic_light_feature = collapse_time(traffic_light_feature)
            traffic_light_token = self.light_mlps(traffic_light_feature)
        else:
            # No traffic lights in this batch: emit an empty token block instead of calling the MLP.
            traffic_light_token = traffic_light_feature.new_zeros([B, L, self.d_model])
        assert traffic_light_token.shape == (B, L, self.d_model), (traffic_light_token.shape, B, L, self.d_model)
        if remove_light_state:
            assert traffic_light_valid_mask.ndim == 2
            tlmask = traffic_light_valid_mask
        else:
            tlmask = traffic_light_valid_mask[:, :self.history_steps].any(dim=1)

        x = [map_token, traffic_light_token]
        x_pos = [map_position, traffic_light_position]
        x_heading = [map_heading, traffic_light_heading]
        x_mask = [map_token_valid_mask, tlmask]
        if self.use_agent_history:
            agent_token, agent_pos, agent_mask, agent_heading = self.encode_agent_history(input_dict=input_dict)
            x.append(agent_token)
            x_pos.append(agent_pos)
            x_mask.append(agent_mask)
            x_heading.append(agent_heading)

        # ===== Call transformer layers =====
        x = torch.concatenate(x, dim=1)
        x_pos = torch.concatenate(x_pos, dim=1)
        x_heading = torch.concatenate(x_heading, dim=1)
        x_mask = torch.concatenate(x_mask, dim=1)

        # There something wrong in waymo test set:
        # https://github.com/waymo-research/waymo-open-dataset/issues/772
        # And the line below might cause issue if we don't skip scenario before entering here.
        assert torch.all(x_mask.sum(dim=-1) > 0)

        if self.config.SIMPLE_RELATION:
            relation_func = relation.compute_relation_simple_relation
        else:
            relation_func = relation.compute_relation
        # Scene-to-scene relation: every token is both query and key.
        rel_feat, rel_mask, _indices = relation_func(
            query_pos=x_pos,
            query_heading=x_heading,
            query_valid_mask=x_mask,
            key_pos=x_pos,
            key_heading=x_heading,
            key_valid_mask=x_mask,
            causal_valid_mask=None,
            knn=self.config.MODEL.KNN,
            max_distance=self.config.MODEL.S2S_DISTANCE,
            gather=False,
            non_agent_relation=True,
            per_contour_point_relation=self.config.MODEL.PER_CONTOUR_POINT_RELATION,
        )

        scene_info = get_edge_info_new(
            q_k_valid_mask=rel_mask,
            q_k_relation=rel_feat,
            relation_model=self.relation_embed,
            relation_model_v=self.relation_embed_v,  # None when ADD_RELATION_TO_V is off
        )

        x = self.encoder(
            scene_tokens=x,
            scene_info=scene_info,
            edge_features=scene_info["edge_features"],
            edge_features_v=scene_info["edge_features_v"]
        )
        # Normalize and project valid tokens only, then restore the padded layout with utils.unwrap.
        x = self.out_prenorm(x[x_mask])
        x = self.out(x)
        x = utils.unwrap(x, x_mask)
        input_dict["encoder/scenario_token"] = x
        input_dict["encoder/map_token"] = x[:, :M]
        input_dict["encoder/scenario_position"] = x_pos
        input_dict["encoder/scenario_heading"] = x_heading
        input_dict["encoder/scenario_valid_mask"] = x_mask
        return input_dict
def nucleus_sampling(logits, p=None, epsilon=1e-8):
    """Sample token indices with nucleus (top-p) sampling over the last dimension.

    Tokens are sorted by probability. The shortest prefix whose cumulative mass
    exceeds ``p`` is kept (the token that crosses the threshold is included), and
    everything else gets probability zero before sampling.

    Args:
        logits: float tensor of shape ``(..., num_actions)``.
        p: cumulative probability threshold. ``None`` means 0.9. ``0.0`` is honoured
            and degenerates to greedy (argmax) selection.
        epsilon: unused; kept so existing callers keep working.

    Returns:
        Integer tensor of shape ``logits.shape[:-1]`` with the sampled indices.
    """
    if p is None:
        # ``p or 0.9`` would also override an explicit ``p=0.0``.
        p = 0.9

    # Sanitize non-finite logits. NaN and -inf become "impossible"; +inf becomes
    # "certain" rather than being thrown away with the NaNs.
    logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

    probs = torch.softmax(logits, dim=-1)

    # Sort the probabilities to identify the top-p cutoff.
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Mark tokens past the threshold, shifted right by one so that the first token
    # that crosses the threshold is still kept.
    cutoff_mask = cumulative_probs > p
    cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone()
    cutoff_mask[..., 0] = False
    sorted_probs = sorted_probs.masked_fill(cutoff_mask, 0)

    # Scatter back to the original token order. Categorical renormalizes the kept mass.
    original_probs = torch.zeros_like(probs)
    original_probs.scatter_(dim=-1, index=sorted_indices, src=sorted_probs)

    return torch.distributions.Categorical(probs=original_probs).sample()


class LanguageMotionLM(nn.Module):
    """Motion language model whose rollout is conditioned on a BERT-encoded text prompt."""

    def __init__(self, config):
        """Build the scene encoder, the motion decoder and the BERT prompt encoder.

        Args:
            config: project config. ``FINE_TUNE_BERT`` decides whether the BERT weights
                are trainable; the rest is forwarded to the encoder/decoder.
        """
        super().__init__()
        self.config = config
        self.scene_encoder = SceneEncoder(config=self.config)
        self.motion_decoder = MotionDecoder(config=self.config)

        # Imported lazily so `transformers` is only needed when this model is used.
        from transformers import BertModel, BertTokenizer

        self.bert_model = BertModel.from_pretrained("bert-base-uncased")
        self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        if not config.FINE_TUNE_BERT:  # if not finetuning BERT encoder for our prompt
            for param in self.bert_model.parameters():
                param.requires_grad = False

        # NOTE(review): the projection width is hard-coded to 512; presumably it should
        # follow MODEL.D_MODEL. Confirm against the decoder before changing it, because
        # it alters checkpoint shapes.
        self.bert_projection = nn.Linear(self.bert_model.config.hidden_size, 512)

    def encode_prompt(self, prompts):
        """Encode a batch of text prompts into one embedding per prompt.

        Args:
            prompts: a string or list of strings accepted by the BERT tokenizer.

        Returns:
            Tensor of shape ``(num_prompts, 512)``: the projected ``[CLS]`` embedding.
        """
        # Tokenize the batch of prompts
        encoded_input = self.bert_tokenizer(prompts, return_tensors='pt', padding=True, truncation=True)

        # Ensure the input tensors are on the same device as the model
        device = next(self.bert_model.parameters()).device
        encoded_input = {key: value.to(device) for key, value in encoded_input.items()}

        # Only track gradients through BERT when it is being fine-tuned. An unconditional
        # no_grad() made FINE_TUNE_BERT=True a silent no-op. An outer no_grad() is respected.
        track_grad = torch.is_grad_enabled() and bool(self.config.FINE_TUNE_BERT)
        with torch.set_grad_enabled(track_grad):
            bert_output = self.bert_model(**encoded_input)

        # Use the output of the [CLS] token for each prompt in the batch
        cls_embedding = bert_output.last_hidden_state[:, 0, :]

        # Project BERT embedding to the desired dimensionality
        return self.bert_projection(cls_embedding)

    def forward(self, input_dict):
        """Run the scene encoder followed by the motion decoder (no prompt conditioning)."""
        # Was `print(..., input_dict.key())`, which raised AttributeError on every call.
        logger.debug("LanguageMotionLM.forward input keys: %s", list(input_dict.keys()))

        input_dict = self.encode_scene(input_dict)
        input_dict = self.decode_motion(input_dict)
        return input_dict

    def encode_scene(self, input_dict):
        """Run the scene encoder on ``input_dict`` and return its output."""
        return self.scene_encoder(input_dict)

    def decode_motion(self, input_dict, use_cache=False):
        """Run the motion decoder on ``input_dict`` and return its output."""
        input_dict = self.motion_decoder(input_dict, use_cache=use_cache)
        return input_dict

    def autoregressive_rollout(
        self,
        input_dict,
        num_decode_steps,
        num_prev_steps=1,
        use_cache=True,
        sampling_method="softmax",
        temperature=None,
        topp=None,
        num_modes_for_eval=None
    ):
        """Autoregressively sample ``num_decode_steps`` actions, conditioned on a text prompt.

        Args:
            input_dict: batch dictionary. Must contain ``decoder/text_label`` plus the
                usual encoder/decoder inputs. It is modified in place and returned.
            num_decode_steps: number of actions to generate (>= 1).
            num_prev_steps: number of ground-truth steps kept as context.
            use_cache: feed only the newest action each step and rely on the decoder's
                key/value cache; otherwise re-feed the growing sequence.
            sampling_method: ``"argmax"``, ``"softmax"`` or ``"topp"``.
            temperature: logit temperature; defaults to ``SAMPLING.TEMPERATURE``.
            topp: nucleus threshold; defaults to ``SAMPLING.TOPP``.
            num_modes_for_eval: unused; kept for interface compatibility.

        Returns:
            ``input_dict`` with ``decoder/output_logit``, ``decoder/output_action``
            (shape ``(B, num_decode_steps, N)``) and ``decoder/output_score`` set.

        Raises:
            KeyError: if ``decoder/text_label`` is missing.
            ValueError: if ``sampling_method`` is unknown.
        """
        if temperature is None:
            temperature = self.config.SAMPLING.TEMPERATURE
        if topp is None:
            topp = self.config.SAMPLING.TOPP
        relative_pe = self.config.MODEL.RELATIVE_PE_DECODER

        B, T_input, N = input_dict["decoder/input_action"].shape
        assert num_decode_steps >= 1
        assert input_dict["decoder/input_action_valid_mask"].shape == (B, T_input, N)
        assert T_input >= num_prev_steps

        # Record "current" valid mask of input actions, we'll repeat it for each decoding step.
        input_action_valid_mask = torch.clone(
            input_dict["decoder/input_action_valid_mask"][:, num_prev_steps - 1:num_prev_steps]
        )

        # Discard future actions / mask
        input_dict["decoder/input_action"] = input_dict["decoder/input_action"][:, :num_prev_steps]
        input_dict["decoder/input_action_valid_mask"] = \
            input_dict["decoder/input_action_valid_mask"][:, :num_prev_steps]

        if relative_pe:
            input_dict["decoder/modeled_agent_heading"] = \
                input_dict["decoder/modeled_agent_heading"][:, :num_prev_steps]
            input_dict["decoder/modeled_agent_position"] = \
                input_dict["decoder/modeled_agent_position"][:, :num_prev_steps]

            # The tokenizer and the state snapshot are only used on the relative-PE path:
            # the rollout overwrites the "current" agent state and restores it at the end.
            tokenizer = get_tokenizer(config=self.config)
            original_data = {
                "decoder/current_agent_position": input_dict["decoder/current_agent_position"].clone(),
                "decoder/current_agent_heading": input_dict["decoder/current_agent_heading"].clone(),
                "decoder/current_agent_velocity": input_dict["decoder/current_agent_velocity"].clone(),
            }

        input_step = torch.arange(num_decode_steps).to(input_dict["encoder/agent_position"].device)

        # ===== Language condition =====
        # Fail with a real error instead of dropping into pdb and calling exit().
        if "decoder/text_label" not in input_dict:
            raise KeyError("autoregressive_rollout requires 'decoder/text_label' in input_dict")
        prompts = input_dict['decoder/text_label']
        if isinstance(prompts, np.ndarray):
            prompts = prompts.tolist()
        logger.debug("text_label: %s", prompts)
        input_dict['decoder/prompt_embedding'] = self.encode_prompt(prompts)

        # Get scene embedding
        input_dict = self.encode_scene(input_dict)
        output_logit_list = []
        output_action_list = []
        input_dict["decoder/input_step"] = input_step[:1]
        for decode_step in range(num_decode_steps):
            logger.debug(f"======================= STEP {decode_step=} =======================")

            if not use_cache:
                input_dict["decoder/input_step"] = input_step[:decode_step + 1]

            # Decode motion tokens
            input_dict = self.decode_motion(input_dict, use_cache=use_cache)

            output_token = input_dict["decoder/output_logit"]

            if use_cache:
                assert output_token.shape[:3] == (B, 1, N)
            else:
                assert output_token.shape[:3] == (B, decode_step + 1, N)
                output_token = output_token[:, -1:]  # -> output_token.shape == (B, 1, N, #actions)

            output_logit_list.append(output_token)

            # Sample the action
            if sampling_method == "argmax":
                selected_action = output_token.argmax(-1)
            elif sampling_method == "softmax":
                selected_action = torch.distributions.Categorical(logits=output_token / temperature).sample()
            elif sampling_method == "topp":
                selected_action = nucleus_sampling(logits=output_token / temperature, p=topp)
            else:
                raise ValueError("Unknown sampling method: {}".format(sampling_method))

            output_action_list.append(selected_action)

            if use_cache:
                # Discard the previous tokens whose key/value are cached.
                input_dict["decoder/input_action"] = selected_action
                input_dict["decoder/input_action_valid_mask"] = input_action_valid_mask
                input_dict["decoder/input_step"].fill_(decode_step + 1)

                if relative_pe:
                    reconstructed_pos, reconstructed_heading, reconstructed_vel = tokenizer.detokenize_for_step(
                        data_dict=input_dict,
                        action=selected_action,
                    )
                    input_dict["decoder/current_agent_position"] = reconstructed_pos
                    input_dict["decoder/current_agent_heading"] = reconstructed_heading
                    input_dict["decoder/current_agent_velocity"] = reconstructed_vel
                    input_dict["decoder/modeled_agent_heading"] = reconstructed_heading.reshape(B, 1, N)
                    input_dict["decoder/modeled_agent_position"] = reconstructed_pos.reshape(B, 1, N, 2)

            else:
                input_dict["decoder/input_action"] = torch.cat(
                    [input_dict["decoder/input_action"], selected_action], dim=1
                )
                input_dict["decoder/input_action_valid_mask"] = torch.cat(
                    [input_dict["decoder/input_action_valid_mask"], input_action_valid_mask], dim=1
                )

            assert input_dict["decoder/input_action"].shape == input_dict["decoder/input_action_valid_mask"].shape

        output_action_list = torch.concatenate(output_action_list, dim=1)
        assert output_action_list.shape == (B, num_decode_steps, N)

        output_logit_list = torch.concatenate(output_logit_list, dim=1)
        input_dict["decoder/output_logit"] = output_logit_list
        input_dict["decoder/output_action"] = output_action_list
        # NOTE(review): the mask passed here is the single-step (B, 1, N) mask while the
        # logits span all decode steps; presumably the helper broadcasts it. Confirm.
        input_dict["decoder/output_score"] = calculate_trajectory_probabilities(
            output_logit_list, output_action_list, mask=input_action_valid_mask
        )  # (B, N)

        if relative_pe:
            input_dict["decoder/current_agent_position"] = original_data["decoder/current_agent_position"]
            input_dict["decoder/current_agent_heading"] = original_data["decoder/current_agent_heading"]
            input_dict["decoder/current_agent_velocity"] = original_data["decoder/current_agent_velocity"]

        return input_dict
class InitializerPredictor(nn.Module):
    """Per-token prediction heads for initial agent states (deprecated module).

    For every valid start token the heads emit, per agent type and per mode, the raw
    parameters for velocity (6), heading (3), position (6) and size (7). They also emit
    a per-type score and a 5-way actor-type logit vector.
    """

    def __init__(self, d_model, num_modes):
        """Create the heads.

        Args:
            d_model: width of the incoming tokens.
            num_modes: number of modes predicted per agent type.
        """
        super().__init__()
        self.d_model = d_model
        self.num_modes = num_modes
        # forward() reshapes with constants.NUM_TYPES, so the head widths must use the
        # same constant. They used a literal 3, which only works while NUM_TYPES == 3.
        num_types = constants.NUM_TYPES
        self.vel_head = common_layers.build_mlps(
            c_in=self.d_model,
            mlp_channels=[self.d_model, self.d_model, self.num_modes * 6 * num_types],
            ret_before_act=True,
        )
        self.heading_head = common_layers.build_mlps(
            c_in=self.d_model,
            mlp_channels=[self.d_model, self.d_model, self.num_modes * 3 * num_types],
            ret_before_act=True,
        )
        self.pos_head = common_layers.build_mlps(
            c_in=self.d_model,
            mlp_channels=[self.d_model, self.d_model, self.num_modes * 6 * num_types],
            ret_before_act=True,
        )
        self.size_head = common_layers.build_mlps(
            c_in=self.d_model,
            mlp_channels=[self.d_model, self.d_model, self.num_modes * 7 * num_types],
            ret_before_act=True,
        )
        self.score_head = common_layers.build_mlps(
            c_in=self.d_model,
            mlp_channels=[self.d_model, self.d_model, 1 * num_types],
            ret_before_act=True,
        )

        # NOTE(review): 5 actor-type logits are hard-coded here; confirm which constant
        # this should follow.
        self.type_head = common_layers.build_mlps(
            c_in=self.d_model,
            mlp_channels=[self.d_model, self.d_model, 5],
            ret_before_act=True,
        )

    def _get_dist(self, p):
        """Wrap raw head outputs in a distribution via ``utils.get_distribution``."""
        return utils.get_distribution(p)

    def _predict(self, head, feat, valid_mask, shape):
        """Apply ``head`` to the valid tokens, restore padding with ``utils.unwrap`` and reshape."""
        return utils.unwrap(head(feat), valid_mask).reshape(*shape)

    def forward(self, start_tokens, start_token_valid_mask):
        """Predict initial-state parameters for every start token.

        Args:
            start_tokens: tensor of shape ``(B, N, token_dim)``.
            start_token_valid_mask: boolean mask selecting valid tokens in ``start_tokens``.

        Returns:
            Tuple ``(pred_pos, pred_vel, pred_head, pred_size, map_feat_score, actor_type,
            pos_dist, vel_dist, head_dist, size_dist)``. The ``pred_*`` tensors have shape
            ``(B, N, NUM_TYPES, num_modes, K)`` with K = 6, 6, 3, 7 respectively.
            ``map_feat_score`` is ``(B, N, NUM_TYPES, 1)`` and ``actor_type`` is
            ``(B, N, 5)``; both are filled with ``-inf`` at invalid tokens.
        """
        B, N, _ = start_tokens.shape
        feat = start_tokens[start_token_valid_mask]
        num_modes = self.num_modes
        num_types = constants.NUM_TYPES

        pred_vel = self._predict(self.vel_head, feat, start_token_valid_mask, (B, N, num_types, num_modes, 6))
        vel_dist = self._get_dist(pred_vel)

        pred_head = self._predict(self.heading_head, feat, start_token_valid_mask, (B, N, num_types, num_modes, 3))
        head_dist = self._get_dist(pred_head)

        pred_pos = self._predict(self.pos_head, feat, start_token_valid_mask, (B, N, num_types, num_modes, 6))
        pos_dist = self._get_dist(pred_pos)

        pred_size = self._predict(self.size_head, feat, start_token_valid_mask, (B, N, num_types, num_modes, 7))
        size_dist = self._get_dist(pred_size)

        map_feat_score = utils.unwrap(
            self.score_head(feat), start_token_valid_mask, fill=float("-inf")
        ).reshape(B, N, num_types, 1)

        actor_type = utils.unwrap(self.type_head(feat), start_token_valid_mask, fill=float("-inf")).reshape(B, N, 5)

        return (
            pred_pos, pred_vel, pred_head, pred_size, map_feat_score, actor_type, pos_dist, vel_dist, head_dist,
            size_dist
        )
class MultiheadAttentionLocal(nn.Module):
    r"""Multi-head attention over a sparse, index-based neighbourhood.

    Instead of attending to every key, each query attends only to the keys listed for
    it in ``index_pair`` (``-1`` marks an unused slot). Queries and keys of all samples
    are packed along dim 0; ``query_batch_cnt`` / ``key_batch_cnt`` give the per-sample
    counts. See "Attention Is All You Need" for the underlying formulation:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total feature dimension of query/key; must be divisible by ``num_heads``.
        num_heads: number of parallel attention heads.
        dropout: dropout probability applied to the attention weights. Default: 0.0.
        without_weight: if True, no input projection is applied (``in_proj_weight`` and
            ``in_proj_bias`` are ``None``); only the output projection remains.
        vdim: feature dimension of the output projection. Default: ``embed_dim``.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.0, without_weight=False, vdim=None):
        super(MultiheadAttentionLocal, self).__init__()
        self.embed_dim = embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"

        # Packed [q; k; v] input projection, sliced by _proj_qkv.
        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))

        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        self.out_proj = Linear(self.vdim, self.vdim, bias=True)

        self.without_weight = without_weight
        if self.without_weight:
            self.in_proj_weight = self.in_proj_bias = None
            constant_(self.out_proj.bias, 0.0)
        else:
            self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialize the packed input projection and zero all biases."""
        xavier_uniform_(self.in_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

    def _proj_qkv(self, t, start, end):
        """Project ``t`` with rows ``start:end`` of the packed input projection."""
        _w = self.in_proj_weight[start:end, :]
        _b = self.in_proj_bias[start:end]
        t = F.linear(t, _w, _b)
        return t

    def forward(
        self,
        query,  # total_q_num, c
        key,  # total_k_num, c
        value,  # total_k_num, c
        index_pair,  # total_q_num, max_memory_num
        query_batch_cnt,  # bs: query_amount of each batch
        key_batch_cnt,  # bs: key_amount of each batch.
        index_pair_batch,  # total_q_num, batch_index of each query.
        attn_mask=None,  # total_q_num, max_memory_num
        vdim=None,

        # positional encoding setting.
        relative_atten_weights=None,  # total_q_num, max_memory_num, nhead

        # crpe module.
        ctx_rpe_query=None,
        ctx_rpe_key=None,
        ctx_rpe_value=None,
        rpe_distance=None,
        **kwargs
    ):
        r""" To reduce memory cost in attention computation, use index to indicate attention pair.
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            index_pair: the associated key indices of each query for computing attention.
            query_batch_cnt: indicate the query_amount in each batch.
            key_batch_cnt: indicate the key / value amount in each batch.
            index_pair_batch: the batch index of each query.
            attn_mask: boolean-like mask; a true entry excludes that (query, slot) pair from
                attention. It is OR-ed with the ``index_pair == -1`` mask.
            vdim: feature dimension of ``value`` when it differs from ``embed_dim``.
            relative_atten_weights: Add relative positional encoding.
            ctx_rpe_query / ctx_rpe_key / ctx_rpe_value: nn.Module for providing contextual relative positional
                encoding given rpe_distance between query and keys.
        Shape:
            - Inputs:
            - query: :math:`(N, C)` where N is the total query tokens length, C is
                the embedding dimension.
            - key: :math:`(M, C)`, where M is the total key tokens length, C is
                the embedding dimension.
            - value: :math:`(M, C)` where M is the total value tokens length (equals to ``key''), C is
                the embedding dimension.
            - index_pair: :math:`(N, L)` where N is the total query tokens length (equals to ``query''),
                L is max_key_num for computing attention.
            - query_batch_cnt: :math:`(B)` where B indicate batch_size.
            - key_batch_cnt: :math:`(B)` where B indicate batch_size.
            - index_pair_batch: :math:`(N)` where N is the total query tokens length (equals to ``query'')
            - attn_mask: :math:`(N, L)` where N is the total query tokens length (equals to ``query''),
                L is max_key_num for computing attention.
            - relative_atten_weights: :math:`(N, L, H)` where N is the total query tokens length (equals to ``query''),
                L is max_key_num for computing attention, H is head_num for computing attention.
            - rpe_distance: :math:`(N, L, 3)` where N is the total query tokens length (equals to ``query''),
                L is max_key_num for computing attention.
            - Outputs:
            - attn_output: :math:`(N, vdim)` where N is the total query tokens length.
            - attn_output_weights: :math:`(N, L)`, the attention weights averaged over heads.
        """
        total_query_len, embed_dim = query.size()
        max_memory_len = index_pair.shape[1]

        if vdim is None:
            assert key.size() == value.size()
            vdim = embed_dim
            v_head_dim = self.head_dim
        else:
            v_head_dim = vdim // self.num_heads
            assert v_head_dim * self.num_heads == vdim

        scaling = float(self.head_dim)**-0.5

        # generate qkv features.
        if not self.without_weight:
            q = self._proj_qkv(query, 0, embed_dim)
            q = q * scaling
            k = self._proj_qkv(key, embed_dim, embed_dim * 2)
            v = self._proj_qkv(value, embed_dim * 2, embed_dim * 3)
        else:
            q = query * scaling
            k, v = key, value

        # -1 in index_pair means this key not joining attention computation.
        used_attn_mask = (index_pair == -1)  # Ignore the -1 pair.
        if attn_mask is not None:
            # attn_mask should have a shape as [total_query_size, max_memory_size]
            attn_mask = attn_mask.to(torch.bool)
            used_attn_mask = torch.logical_or(used_attn_mask, attn_mask)

        q = q.contiguous().view(total_query_len, self.num_heads, self.head_dim)
        k = k.contiguous().view(-1, self.num_heads, self.head_dim)
        v = v.contiguous().view(-1, self.num_heads, v_head_dim)

        # compute attention weight.
        attn_output_weights = attention_utils_v2.attention_weight_computation(
            query_batch_cnt, key_batch_cnt, index_pair_batch, index_pair, q, k
        )  # total_query_len, max_memory_len, num_heads
        assert list(attn_output_weights.size()) == [total_query_len, max_memory_len, self.num_heads]

        if ctx_rpe_key is not None:
            rpe_attn_weight = ctx_rpe_key(
                rpe_distance, k, scaling, query_batch_cnt, key_batch_cnt, index_pair_batch, index_pair
            )
            attn_output_weights = attn_output_weights + rpe_attn_weight
        if ctx_rpe_query is not None:
            rpe_attn_weight = ctx_rpe_query(rpe_distance, q, 1.0, query_batch_cnt)
            attn_output_weights = attn_output_weights + rpe_attn_weight

        if relative_atten_weights is not None:
            # relative_atten_weights: A float tensor with shape [total_query_num, max_memory_num, nhead]
            attn_output_weights = attn_output_weights + relative_atten_weights

        # attn_output_weights: [total_query_num, max_memory_num, nhead]
        used_attn_mask = used_attn_mask.unsqueeze(-1).repeat(1, 1, self.num_heads).contiguous()
        attn_output_weights.masked_fill_(used_attn_mask, float("-inf"))

        # Softmax over the key slots (dim=1).
        # NOTE(review): a query whose slots are all masked yields a NaN row here.
        attn_output_weights = F.softmax(attn_output_weights, dim=1)
        attn_output_weights = attn_output_weights.to(q)

        attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)

        if ctx_rpe_value is not None:
            attn_output = ctx_rpe_value(
                rpe_distance, attn_output_weights, v, query_batch_cnt, key_batch_cnt, index_pair_batch, index_pair
            )
        else:
            # Use the imported alias, consistent with attention_weight_computation above.
            attn_output = attention_utils_v2.attention_value_computation(
                query_batch_cnt, key_batch_cnt, index_pair_batch, index_pair, attn_output_weights, v
            )
        assert list(attn_output.size()) == [total_query_len, self.num_heads, v_head_dim]

        attn_output = attn_output.view(total_query_len, vdim)

        if self.out_proj is not None:
            attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)

        # Second output: attention weights averaged over heads.
        return attn_output, attn_output_weights.sum(dim=-1) / self.num_heads
class RelativePE(nn.Module):
    """Learned per-head attention bias indexed by a quantised 2-D relative offset.

    Credit: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L122

    Each axis of the offset is floored and clamped to ``[-MAX_OFFSET, MAX_OFFSET]``,
    giving ``NUM_BINS`` bins per axis and ``NUM_BINS * NUM_BINS`` rows in
    ``relative_position_bias_table``. Pairs marked ``-1`` in ``index_pair`` do not
    use the table; they receive the separately learned ``invalid`` bias.
    """

    # Offsets are clamped to [-MAX_OFFSET, MAX_OFFSET], i.e. NUM_BINS = 2 * MAX_OFFSET + 1 bins per axis.
    MAX_OFFSET = 200
    NUM_BINS = 401

    def __init__(self, nhead):
        super().__init__()
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.NUM_BINS * self.NUM_BINS, nhead)
        )  # 2*Wh-1 * 2*Ww-1, nH
        self.invalid = nn.Parameter(torch.zeros(nhead))
        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
        torch.nn.init.trunc_normal_(self.invalid, std=0.02)

    def forward(self, pos, index_pair):
        """Look up the bias for every (query, selected key) pair.

        Args:
            pos: 3-D tensor of relative offsets; it is gathered along dim 1 with
                ``index_pair`` and its last dimension supplies the two offset axes.
            index_pair: 2-D integer tensor of key indices per query; ``-1`` marks an
                unused slot. The caller's tensor is not modified (a clone is edited).

        Returns:
            Tensor of shape ``(*index_pair.shape, nhead)``.
        """
        assert pos.ndim == 3
        shape = index_pair.shape
        index_pair = index_pair.clone()
        mask = index_pair == -1
        # -1 is not a valid gather index; point it at slot 0 and overwrite the result below.
        index_pair[mask] = 0
        index_pair = index_pair.unsqueeze(-1).repeat(1, 1, 2)
        pos = torch.gather(pos, index=index_pair.long(), dim=1)
        ind = torch.floor(pos).clamp(-self.MAX_OFFSET, self.MAX_OFFSET).int() + self.MAX_OFFSET
        ind = ind[..., 0] + ind[..., 1] * self.NUM_BINS
        # ``max()`` raises on an empty tensor, so only run the sanity check when there is data.
        if ind.numel() > 0:
            assert ind.max() < self.NUM_BINS * self.NUM_BINS
        ret = self.relative_position_bias_table[ind]
        ret = ret.reshape(*shape, ret.shape[-1])
        ret[mask] = self.invalid
        return ret
d_model) + + self.ca_kcontent_proj = nn.Linear(d_model, d_model) + self.ca_kpos_proj = nn.Linear(d_model, d_model) + self.ca_v_proj = nn.Linear(d_model, d_model) + self.ca_qpos_sine_proj = nn.Linear(d_model, d_model) + + self.use_local_attn = use_local_attn + + self.cross_attn = MultiheadAttentionLocal(d_model, nhead, dropout=dropout, vdim=d_model, without_weight=True) + + self.nhead = nhead + # self.rm_self_attn_decoder = rm_self_attn_decoder + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + self.keep_query_pos = keep_query_pos + self.d_model = d_model + self.relative_pe = relative_pe + + if relative_pe: + self.pe = RelativePE(self.nhead) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward( + self, + *, + tgt, + query_pos: Optional[Tensor] = None, + query_sine_embed=None, + memory, + memory_pos_emb, + memory_pos, + is_first=False, + key_batch_cnt=None, + index_pair=None, + index_pair_batch=None + ): + assert index_pair_batch.max() + 1 == key_batch_cnt.shape[0] + + cross_query = self.ca_qcontent_proj(tgt) + + k_content_valid = self.ca_kcontent_proj(memory) + cross_key = k_content_valid + + valid_pos = memory_pos_emb + k_pos_valid = self.ca_kpos_proj(valid_pos) + cross_key_position = k_pos_valid + + v_valid = self.ca_v_proj(memory) + cross_value = v_valid + + # TODO: I remove the query pos here. Double check. + # For the first decoder layer, we concatenate the positional embedding predicted from + # the object query (the positional embedding) into the original query (key) in DETR. 
+ # if is_first or self.keep_query_pos: + # cross_query_position = self.ca_qpos_proj(self.sa_qpos_proj(query_pos)) + # cross_query = cross_query + cross_query_position + # cross_key = cross_key + cross_key_position + + assert self.use_local_attn + num_q_all, n_model = cross_query.shape + + cross_query = cross_query.view(num_q_all, self.nhead, n_model // self.nhead) + + # TODO: Query PE is removed now. Double check. + # query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed) + # query_sine_embed = query_sine_embed.view(num_q_all, self.nhead, n_model // self.nhead) + # cross_query += query_sine_embed + + cross_query = cross_query.view(num_q_all, n_model) + + num_valid_key = cross_key.shape[0] + cross_key = cross_key.view(num_valid_key, self.nhead, n_model // self.nhead) + cross_key_position = cross_key_position.view(num_valid_key, self.nhead, n_model // self.nhead) + cross_key += cross_key_position + cross_key = cross_key.view(num_valid_key, n_model) + + # TODO: Relative PE is not working. 
+ # [num Q, num K, 2] + # if self.relative_pe: + # raise ValueError() + # relative_pos = (memory_pos.unsqueeze(0) - query_pos.unsqueeze(1)) + # relative_pe = self.pe(relative_pos, index_pair) + # else: + relative_pe = None + + tgt2 = self.cross_attn( + query=cross_query, # [num valid objects, 2 * d_model] + key=cross_key, + value=cross_value, + index_pair=index_pair, + query_batch_cnt=key_batch_cnt, + key_batch_cnt=key_batch_cnt, + index_pair_batch=index_pair_batch, + attn_mask=None, + relative_atten_weights=relative_pe, + vdim=n_model + )[0] + + # ========== End of Cross-Attention ============= + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + + return tgt diff --git a/scenestreamer/models/layers/__init__.py b/scenestreamer/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenestreamer/models/layers/common_layers.py b/scenestreamer/models/layers/common_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..24a9db3c18cde2549dfd9098dd52cc175a609df8 --- /dev/null +++ b/scenestreamer/models/layers/common_layers.py @@ -0,0 +1,169 @@ +import collections.abc +from functools import partial +from itertools import repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) + + +class Tokenizer(nn.Module): + def __init__(self, num_actions, d_model, add_one_more_action=True): + super(Tokenizer, self).__init__() + if add_one_more_action: + self.tokens = nn.Embedding( + num_actions + 1, d_model + ) # The last token is used for the dummy token at step=0. 
def build_mlps(c_in, mlp_channels, ret_before_act=False, without_norm=False, is_v7=None, zero_init=False):
    """Stack ``Linear`` stages into an ``nn.Sequential``.

    Args:
        c_in: input feature width of the first stage.
        mlp_channels: output width of each stage, in order.
        ret_before_act: when true, the last stage is a bare ``Linear`` (with bias)
            and no norm/activation follows it.
        without_norm: when true, stages are ``Linear -> ReLU``; otherwise they are
            ``Linear(bias=False) -> LayerNorm -> ReLU``.
        is_v7: accepted for call compatibility; not used here.
        zero_init: accepted for call compatibility; not used here.

    Returns:
        ``nn.Sequential`` holding all stages.
    """
    final_idx = len(mlp_channels) - 1
    modules = []
    width_in = c_in

    for idx, width_out in enumerate(mlp_channels):
        if ret_before_act and idx == final_idx:
            # Raw projection: the caller wants the pre-activation output.
            modules.append(nn.Linear(width_in, width_out))
        elif without_norm:
            modules += [nn.Linear(width_in, width_out), nn.ReLU()]
        else:
            modules += [
                nn.Linear(width_in, width_out, bias=False),
                nn.LayerNorm(width_out),
                nn.ReLU(),
            ]
        width_in = width_out

    return nn.Sequential(*modules)
class AdaLayerNorm(nn.Module):
    """Layer normalisation whose scale and shift are predicted from a conditioning vector.

    The input is normalised without learned affine parameters, then modulated as
    ``normalized * (1 + gamma(z)) + beta(z)``.
    """
    def __init__(self, hidden_size, conditioning_dim, batch_first, eps=1e-5):
        super().__init__()
        self.batch_first = batch_first
        self.eps = eps
        # Per-feature modulation (scale, then shift) computed from the conditioning input.
        self.gamma_proj = nn.Linear(conditioning_dim, hidden_size)
        self.beta_proj = nn.Linear(conditioning_dim, hidden_size)
        # No affine parameters here: gamma/beta above take their place.
        self.ln = nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=False)

    def forward(self, x, z):
        """Normalise ``x`` over its last dimension and modulate it with ``z``.

        Args:
            x: tensor whose last dimension is ``hidden_size``.
            z: conditioning tensor whose last dimension is ``conditioning_dim``.
                If its projection has a different rank than ``x``, ``x`` must be
                3-D and the projection 2-D; a sequence axis is then inserted at
                dim 1 (``batch_first``) or dim 0 (otherwise) so the two broadcast.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_hat = self.ln(x)
        scale = self.gamma_proj(z)  # [B, hidden_size]
        shift = self.beta_proj(z)
        if x_hat.ndim != scale.ndim:
            assert x_hat.ndim == 3
            assert scale.ndim == 2
            seq_axis = 1 if self.batch_first else 0
            scale = scale.unsqueeze(seq_axis)
            shift = shift.unsqueeze(seq_axis)
        return x_hat * (1 + scale) + shift
class LayerNorm(Module):
    r"""Layer normalisation over the trailing dimensions of the input.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The statistics are taken over the last ``len(normalized_shape)`` dimensions;
    :math:`\gamma` and :math:`\beta` are the optional learnable ``weight`` and ``bias``.

    Args:
        normalized_shape (int or list or torch.Size): shape of the trailing dimensions
            to normalise over. A single integer is treated as a one-element shape.
        eps: value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: when ``True`` the module owns a learnable ``weight``
            initialised to ones (and, unless disabled, a ``bias`` initialised to
            zeros). Default: ``True``.
        bias: when ``False`` no additive bias is learned (only relevant if
            :attr:`elementwise_affine` is ``True``). Default: ``True``.

    Attributes:
        weight: parameter of shape ``normalized_shape``, or ``None`` when
            :attr:`elementwise_affine` is ``False``.
        bias: parameter of shape ``normalized_shape``, or ``None`` when it is disabled.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)
    """

    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        device=None,
        dtype=None
    ) -> None:
        super().__init__()
        tensor_opts = {'device': device, 'dtype': dtype}
        # A bare integer means "normalise over one trailing dimension of that size".
        shape = (normalized_shape, ) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        self.normalized_shape = tuple(shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if not self.elementwise_affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            self.weight = Parameter(torch.empty(self.normalized_shape, **tensor_opts))
            if bias:
                self.bias = Parameter(torch.empty(self.normalized_shape, **tensor_opts))
            else:
                self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset ``weight`` to ones and ``bias`` (if present) to zeros."""
        if not self.elementwise_affine:
            return
        init.ones_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` normalised over the trailing ``normalized_shape`` dimensions."""
        normalized = F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
        return normalized

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module's ``repr``."""
        template = '{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}'
        return template.format(**self.__dict__)
+ tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. + Default: ``None``; try to detect a causal mask. + Warning: + ``tgt_is_causal`` provides a hint that ``tgt_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + memory_is_causal: If specified, applies a causal mask as + ``memory mask``. + Default: ``False``. + Warning: + ``memory_is_causal`` provides a hint that + ``memory_mask`` is the causal mask. Providing incorrect + hints can result in incorrect execution, including + forward and backward compatibility. + + Shape: + see the docs in Transformer class. + """ + rel_pe_self, rel_mask_self, rel_indices_self = None, None, None + rel_pe_cross, rel_mask_cross, rel_indices_cross = None, None, None + if self.relative_pe: + rel_pe_cross, rel_mask_cross, rel_indices_cross = relation.compute_relation( + query_pos=tgt_pos, + query_heading=tgt_heading, + query_mask=tgt_key_padding_mask, + key_pos=memory_pos, + key_heading=memory_heading, + key_mask=memory_key_padding_mask, + hidden_dim=self.d_model, + causal_mask=None, + knn=self.cross_attention_knn + ) + + # PZH: This is very vulnerable to bugs as the query should attend to the q at this moment as well as all + # history queries. + if full_tgt_pos is None: + causal_mask = tgt_mask + else: + # No need to consider causal mask if in autoregressive decoding. 
+ causal_mask = None + rel_pe_self, rel_mask_self, rel_indices_self = relation.compute_relation( + query_pos=tgt_pos, + query_heading=tgt_heading, + query_mask=tgt_key_padding_mask, + key_pos=full_tgt_pos if full_tgt_pos is not None else tgt_pos, + key_heading=full_tgt_heading if full_tgt_heading is not None else tgt_heading, + key_mask=full_tgt_mask if full_tgt_mask is not None else tgt_key_padding_mask, + hidden_dim=self.d_model, + causal_mask=causal_mask, + knn=self.self_attention_knn + ) + # print("RELATION PE SIZE: ", rel_pe_self.shape) + + output = tgt + + seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first) + tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len) + + new_past_key_value = None + if use_cache: + new_past_key_value = () + tgt_is_causal = False + assert memory_is_causal is False + + for layer_idx, mod in enumerate(self.layers): + output = mod( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, + memory_is_causal=memory_is_causal, + # past_key_value=past_key_value[layer_idx] if past_key_value else None + past_key_value=past_key_value[layer_idx] if past_key_value else None, + use_cache=use_cache, + rel_pe_self=rel_pe_self, + rel_mask_self=rel_mask_self, + rel_indices_self=rel_indices_self, + rel_pe_cross=rel_pe_cross, + rel_mask_cross=rel_mask_cross, + rel_indices_cross=rel_indices_cross + ) + + if use_cache: + output, new_past_key_value_layer = output + new_past_key_value += (new_past_key_value_layer, ) + + if self.norm is not None: + output = self.norm(output) + + if use_cache: + return output, new_past_key_value + return output + + +class TransformerDecoderLayer(Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + + This standard decoder layer is based on the paper "Attention Is All You Need". 
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + norm_first: if ``True``, layer norm is done prior to self attention, multihead + attention and feedforward operations, respectively. Otherwise it's done after. + Default: ``False`` (after). + bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive + bias. Default: ``True``. 
+ """ + + __constants__ = ['norm_first'] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, + pre_projection: bool = False, + batch_first: bool = False, + norm_first: bool = False, + bias: bool = True, + device=None, + dtype=None, + relative_pe=False, + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + disable_projection=pre_projection, + batch_first=batch_first, + bias=bias, + **factory_kwargs + ) + self.multihead_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + disable_projection=pre_projection, + batch_first=batch_first, + bias=bias, + **factory_kwargs + ) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) + + self.norm_first = norm_first + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + + self.pre_projection = pre_projection + if pre_projection: + self.sa_qcontent_proj = nn.Linear(d_model, d_model) + # self.sa_qpos_proj = nn.Linear(d_model, d_model) + self.sa_kcontent_proj = nn.Linear(d_model, d_model) + # self.sa_kpos_proj = nn.Linear(d_model, d_model) + self.sa_v_proj = nn.Linear(d_model, d_model) + + self.ca_qcontent_proj = nn.Linear(d_model, d_model) + # self.ca_qpos_proj = nn.Linear(d_model, d_model) + self.ca_kcontent_proj = nn.Linear(d_model, d_model) + # self.ca_kpos_proj = 
nn.Linear(d_model, d_model) + self.ca_v_proj = nn.Linear(d_model, d_model) + # self.ca_qpos_sine_proj = nn.Linear(d_model, d_model) + + # Legacy string support for activation function. + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + self.relative_pe = relative_pe + if relative_pe: + assert pre_projection is False, "Relative positional encoding is not supported with pre_projection" + self.sa_relation_k = nn.Linear(3 * d_model, d_model) + self.sa_relation_v = nn.Linear(3 * d_model, d_model) + self.ca_relation_k = nn.Linear(3 * d_model, d_model) + self.ca_relation_v = nn.Linear(3 * d_model, d_model) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super().__setstate__(state) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + tgt_is_causal: bool = False, + memory_is_causal: bool = False, + past_key_value=None, + use_cache=False, + rel_pe_self=None, + rel_mask_self=None, + rel_indices_self=None, + rel_pe_cross=None, + rel_mask_cross=None, + rel_indices_cross=None + ) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. + Default: ``False``. + Warning: + ``tgt_is_causal`` provides a hint that ``tgt_mask`` is + the causal mask. 
Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + memory_is_causal: If specified, applies a causal mask as + ``memory mask``. + Default: ``False``. + Warning: + ``memory_is_causal`` provides a hint that + ``memory_mask`` is the causal mask. Providing incorrect + hints can result in incorrect execution, including + forward and backward compatibility. + + Shape: + see the docs in Transformer class. + """ + # Split past key and value states for self-attention and multi-head attention + # past_self_key_value = past_key_value[0] if past_key_value is not None else None + # past_cross_key_value = past_key_value[1] if past_key_value is not None else None + + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + + x = tgt + if self.norm_first: + x = x + self._sa_block( + self.norm1(x), + tgt_mask, + tgt_key_padding_mask, + tgt_is_causal, + past_key_value=past_key_value, + use_cache=use_cache, + rel_pe_self=rel_pe_self, + rel_mask_self=rel_mask_self, + rel_indices_self=rel_indices_self, + ) + x = x + self._mha_block( + self.norm2(x), + memory, + memory_mask, + memory_key_padding_mask, + memory_is_causal, + past_key_value=None, + use_cache=False, + rel_pe_cross=rel_pe_cross, + rel_mask_cross=rel_mask_cross, + rel_indices_cross=rel_indices_cross + ) + x = x + self._ff_block(self.norm3(x)) + else: + sa_out = self._sa_block( + x, + tgt_mask, + tgt_key_padding_mask, + tgt_is_causal, + past_key_value=past_key_value, + use_cache=use_cache, + rel_pe_self=rel_pe_self, + rel_mask_self=rel_mask_self, + rel_indices_self=rel_indices_self, + ) + x = self.norm1(x + sa_out) + x = self.norm2( + x + self._mha_block( + x, + memory, + memory_mask, + memory_key_padding_mask, + memory_is_causal, + past_key_value=None, + use_cache=False, + rel_pe_cross=rel_pe_cross, + rel_mask_cross=rel_mask_cross, + rel_indices_cross=rel_indices_cross + ) + ) + x = self.norm3(x + self._ff_block(x)) + + if use_cache: + return x, 
self._new_self_key_value # , self._new_cross_key_value) + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + pos: Optional[Tensor] = None, + is_causal: bool = False, + past_key_value=None, + use_cache=False, + rel_pe_self=None, + rel_mask_self=None, + rel_indices_self=None, + ) -> Tensor: + + if self.pre_projection: + q = self.sa_qcontent_proj(x) + # qpos = self.sa_qpos_proj(pos) + # q = torch.cat([qcontent, qpos], dim=-1) + + k = self.sa_kcontent_proj(x) + # kpos = self.sa_kpos_proj(pos) + # k = torch.cat([kcontent, kpos], dim=-1) + + v = self.sa_v_proj(x) + + else: + q = x + k = x + v = x + + if self.relative_pe: + assert self.pre_projection is False + relation_k = self.sa_relation_k(rel_pe_self) + relation_v = self.sa_relation_v(rel_pe_self) + else: + relation_k = None + relation_v = None + + x, _, new_key_value = self.self_attn( + q, + k, + v, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=False, + past_key_value=past_key_value, + use_cache=use_cache, + disable_projection=self.pre_projection, + relation_k=relation_k if self.relative_pe else None, + relation_v=relation_v if self.relative_pe else None, + relation_mask=rel_mask_self if self.relative_pe else None, + relation_indices=rel_indices_self if self.relative_pe else None, + ) + self._new_self_key_value = new_key_value + return self.dropout1(x) + + # multihead attention block + def _mha_block( + self, + x: Tensor, + mem: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + pos: Optional[Tensor] = None, + is_causal: bool = False, + past_key_value=None, + use_cache=False, + rel_pe_cross=None, + rel_mask_cross=None, + rel_indices_cross=None + ) -> Tensor: + + if self.pre_projection: + q = self.ca_qcontent_proj(x) + # qpos = self.sa_qpos_proj(pos) + # q = torch.cat([qcontent, qpos], dim=-1) + + k = self.ca_kcontent_proj(mem) + # kpos = 
class TransformerEncoderLayer(Module):
    """Encoder layer: one self-attention sub-layer followed by a feed-forward sub-layer.

    Two optional extensions are supported and are mutually exclusive:

    * ``pre_projection``: query/key/value and their positional counterparts are
      projected by this layer and the attention module is told to skip its own
      projection (``disable_projection=True``).
    * ``relative_pe``: a relation tensor whose last dimension is ``3 * d_model``
      is projected to ``d_model`` and handed to the attention module as
      ``relation_k`` / ``relation_v``.

    ``norm_first`` selects pre-norm (``x + f(norm(x))``) instead of post-norm
    (``norm(x + f(x))``) residual blocks.
    """
    __constants__ = ['norm_first']

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        pre_projection: bool = False,
        relative_pe: bool = False,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """Build the attention, feed-forward, normalisation and optional projection modules.

        Args:
            d_model: Feature width of the layer input and output.
            nhead: Number of attention heads.
            dim_feedforward: Hidden width of the feed-forward sub-layer.
            dropout: Dropout probability used by attention and both sub-layers.
            activation: Feed-forward activation, as a callable or a legacy
                string resolved through ``_get_activation_fn``.
            layer_norm_eps: Epsilon of both layer norms.
            pre_projection: Project q/k/v (and positions) in this layer.
            relative_pe: Enable relation projections; requires
                ``pre_projection`` to be ``False``.
            batch_first: Forwarded to the attention module.
            norm_first: Use pre-norm residual blocks.
            bias: Whether attention, feed-forward and norm layers use a bias.
            device: Factory argument forwarded to the created modules.
            dtype: Factory argument forwarded to the created modules.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        # Module creation order is part of the contract: it fixes seeded
        # initialisation and the order of entries in the state dict.
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            bias=bias,
            kdim=d_model,
            vdim=d_model,
            disable_projection=pre_projection,
            **factory_kwargs
        )
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.pre_projection = pre_projection
        if pre_projection:
            self.sa_qcontent_proj = nn.Linear(d_model, d_model)
            self.sa_qpos_proj = nn.Linear(d_model, d_model)
            self.sa_kcontent_proj = nn.Linear(d_model, d_model)
            self.sa_kpos_proj = nn.Linear(d_model, d_model)
            self.sa_v_proj = nn.Linear(d_model, d_model)
            self.sa_vpos_proj = nn.Linear(d_model, d_model)

        self.relative_pe = relative_pe
        if relative_pe:
            assert pre_projection is False, "Relative positional encoding is not supported with pre_projection"
            self.sa_relation_k = nn.Linear(3 * d_model, d_model)
            self.sa_relation_v = nn.Linear(3 * d_model, d_model)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        # States pickled without an activation fall back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        tgt: Tensor,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        pos: Optional[Tensor] = None,
        relation: Optional[Tensor] = None,
        relation_mask: Optional[Tensor] = None,
        relation_indices: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply self-attention and feed-forward with residual connections.

        Args:
            tgt: Input sequence.
            tgt_mask: Attention mask forwarded to the attention module.
            tgt_key_padding_mask: Key padding mask forwarded to the attention module.
            tgt_is_causal: Causal hint forwarded as ``is_causal``.
            pos: Positional features; projected only when ``pre_projection`` is set.
            relation: Relation features; projected only when ``relative_pe`` is set.
            relation_mask: Forwarded only when ``relative_pe`` is set.
            relation_indices: Forwarded only when ``relative_pe`` is set.

        Returns:
            The transformed sequence.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        #
        # Every argument after the input is passed by keyword. The pre-norm
        # branch used to pass ``tgt_is_causal`` positionally into the ``pos``
        # slot while also passing ``pos=pos``, which raised ``TypeError`` on
        # every call, and it dropped ``relation_indices``.
        attn_kwargs = dict(
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
            is_causal=tgt_is_causal,
            pos=pos,
            relation=relation,
            relation_mask=relation_mask,
            relation_indices=relation_indices,
        )

        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), **attn_kwargs)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, **attn_kwargs))
            x = self.norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        pos: Optional[Tensor] = None,
        relation: Optional[Tensor] = None,
        relation_mask: Optional[Tensor] = None,
        relation_indices: Optional[Tensor] = None,
        is_causal: bool = False,
    ) -> Tensor:
        """Run self-attention on ``x`` and apply ``dropout1`` to its output."""
        if self.pre_projection:
            q = self.sa_qcontent_proj(x)
            k = self.sa_kcontent_proj(x)
            v = self.sa_v_proj(x)
            # ``pos`` is required in this mode: it goes straight into the
            # positional projections.
            qpos = self.sa_qpos_proj(pos)
            kpos = self.sa_kpos_proj(pos)
            vpos = self.sa_vpos_proj(pos)
        else:
            q = k = v = x
            qpos = kpos = vpos = None

        if self.relative_pe:
            assert self.pre_projection is False
            relation_k = self.sa_relation_k(relation)
            relation_v = self.sa_relation_v(relation)
        else:
            # Relation inputs are ignored unless relative PE is enabled.
            relation_k = relation_v = None
            relation_mask = None
            relation_indices = None

        attn_out, _, _ = self.self_attn(
            q,
            k,
            v,
            query_pos=qpos,
            key_pos=kpos,
            value_pos=vpos,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            relation_k=relation_k,
            relation_v=relation_v,
            relation_mask=relation_mask,
            relation_indices=relation_indices,
            # Weights are requested whenever a relation tensor is supplied,
            # even if relative PE is disabled.
            need_weights=relation is not None,
            disable_projection=True if self.pre_projection else False,
        )
        return self.dropout1(attn_out)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        """Apply linear -> activation -> dropout -> linear -> ``dropout2``."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
class SquaredReLU(nn.Module):
    """Activation that squares the ReLU of its input element-wise."""

    def forward(self, x):
        rectified = nn.functional.relu(x)
        return torch.square(rectified)


class FourierEmbedding(nn.Module):
    """Embed continuous features with learnable sinusoidal frequencies.

    Adapted from SMART's ``fourier_embedding.py``. Each of the ``input_dim``
    continuous features owns one row of learnable frequencies and one small
    MLP; the per-feature embeddings are summed, optionally combined with
    categorical embeddings, and passed through an output head.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_freq_bands: int, is_v7=None) -> None:
        """Create the frequency table, the per-feature MLPs and the output head.

        Args:
            input_dim: Number of continuous features; ``0`` disables the
                frequency table.
            hidden_dim: Width of every embedding produced by this module.
            num_freq_bands: Number of learnable frequencies per feature.
            is_v7: Accepted but not used.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Sub-modules are created in the order freqs -> mlps -> to_out.
        if input_dim != 0:
            self.freqs = nn.Embedding(input_dim, num_freq_bands)
        else:
            self.freqs = None

        # cos + sin per band, plus the raw value itself.
        feature_width = num_freq_bands * 2 + 1
        branches = []
        for _ in range(input_dim):
            branches.append(
                nn.Sequential(
                    nn.Linear(feature_width, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.ReLU(inplace=True),
                    nn.Linear(hidden_dim, hidden_dim),
                )
            )
        self.mlps = nn.ModuleList(branches)

        self.to_out = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(
        self,
        continuous_inputs: Optional[torch.Tensor] = None,
        categorical_embs: Optional[List[torch.Tensor]] = None
    ) -> torch.Tensor:
        """Compute the embedding.

        Args:
            continuous_inputs: Continuous features; feature ``i`` is read as
                ``[:, i]`` after the sinusoidal expansion.
            categorical_embs: Optional embeddings that are summed and added.

        Returns:
            The output head applied to the combined embedding.

        Raises:
            ValueError: If both arguments are ``None``.
        """
        if continuous_inputs is None:
            if categorical_embs is None:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
            return self.to_out(sum(categorical_embs))

        raw = continuous_inputs.unsqueeze(-1)
        phase = raw * self.freqs.weight * 2 * math.pi
        # Warning: if your data are noisy, don't use learnable sinusoidal embedding
        features = torch.cat([phase.cos(), phase.sin(), raw], dim=-1)

        # Accumulate left to right so the summation order is fixed.
        combined = None
        for feature_idx in range(self.input_dim):
            branch_out = self.mlps[feature_idx](features[:, feature_idx])
            combined = branch_out if combined is None else combined + branch_out

        if categorical_embs is not None:
            combined = combined + sum(categorical_embs)
        return self.to_out(combined)
class MultiheadAttentionLayer(MessagePassing):
    """Edge-list multi-head attention built on ``MessagePassing``.

    Attention is evaluated only on the (key, query) pairs listed in
    ``edge_index``. Per-edge features contribute to the attention logits and
    to the values. ``message`` and ``aggregate`` are overridden to carry a
    ``(messages, new_edge_features)`` tuple.
    """

    def __init__(
        self,
        d_model,
        n_heads,
        dropout=0.0,
        simple_relation=False,
        simple_relation_factor=2,
        is_v7=False,
        update_relation=False,
        add_relation_to_v=None
    ):
        """Create the projections.

        Args:
            d_model: Feature width of queries, keys and the output.
            n_heads: Number of heads; ``head_dim`` is ``d_model // n_heads``.
            dropout: Must be ``0.0`` (asserted).
            simple_relation: Adds a second query projection that is scored
                against the projected edge features.
            simple_relation_factor: Incoming edge features are expected to be
                ``d_model // simple_relation_factor`` wide.
            is_v7: Must be truthy; any other value raises ``ValueError``.
            update_relation: Must be ``False`` (asserted).
            add_relation_to_v: Required (asserted not ``None``); selects
                whether separate value-side edge features are supplied.
        """
        super(MultiheadAttentionLayer, self).__init__(aggr='add', node_dim=0)  # Aggregation method 'add'
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert dropout == 0.0, "dropout is not supported"
        self.dropout = nn.Dropout(dropout)
        self.simple_relation = simple_relation
        if is_v7:
            if simple_relation:
                self.relation_head_dim = self.head_dim // simple_relation_factor
                self.to_q_relation = nn.Linear(d_model, d_model)
            # Edge features are lifted to d_model for the key and value sides.
            self.to_k_r = nn.Linear(d_model // simple_relation_factor, d_model)
            self.to_v_r = nn.Linear(d_model // simple_relation_factor, d_model)
            self.to_k = nn.Linear(d_model, d_model)
            self.to_q = nn.Linear(d_model, d_model)
            self.to_v = nn.Linear(d_model, d_model)
            self.out = nn.Linear(d_model, d_model)
        else:
            raise ValueError()
            # NOTE(review): the rest of this branch is unreachable.
            if simple_relation:
                self.relation_head_dim = self.head_dim // simple_relation_factor
                self.to_q_relation = nn.Linear(d_model, d_model // simple_relation_factor)
            self.to_k = nn.Linear(d_model, d_model)
            self.to_q = nn.Linear(d_model, d_model)
            self.to_v = nn.Linear(d_model, d_model)
        self.is_v7 = is_v7
        self.update_relation = update_relation
        assert update_relation is False
        assert add_relation_to_v is not None, "add_relation_to_v is required."
        self.add_relation_to_v = add_relation_to_v

    def forward(
        self,
        q,
        k,
        edge_index,
        edge_features,
        edge_features_v=None,
        use_cache=False,
        cache=None,
    ):
        """Attend from ``q`` to ``k`` along ``edge_index``.

        Args:
            q: Query tensor, unpacked as ``(B, Lq, D)``.
            k: Key/value source, unpacked as ``(_, Lk, _)``.
            edge_index: ``edge_index[0]`` indexes flattened key rows and
                ``edge_index[1]`` indexes flattened query rows (asserted).
            edge_features: Per-edge features for the key side.
            edge_features_v: Per-edge features for the value side; must be
                given iff ``add_relation_to_v`` is truthy.
            use_cache: When truthy, the flattened keys and values are
                returned as ``[K, V]``.
            cache: Optional ``(past_key, past_value, (key_B, key_T))``; past
                entries are concatenated before the new ones along dim 1.

        Returns:
            ``(out, new_cache, new_edge_features)``, where ``out`` has shape
            ``(B, Lq, D)`` and ``new_cache`` is ``None`` unless ``use_cache``.
        """
        B, Lq, D = q.shape
        _, Lk, _ = k.shape

        # Compute linear projections, flattened to one row per token.
        x_dst = q
        x_src = k
        Q = self.to_q(x_dst).reshape(-1, self.n_heads * self.head_dim)
        K = self.to_k(x_src).reshape(-1, self.n_heads * self.head_dim)
        V = self.to_v(x_src).reshape(-1, self.n_heads * self.head_dim)

        if cache is not None:
            past_key = cache[0]
            past_value = cache[1]
            key_B, key_T = cache[2]

            # Restore the batch axis, prepend the cached steps, flatten again.
            K = K.reshape(key_B, -1, self.n_heads * self.head_dim)
            past_key = past_key.reshape(key_B, key_T, self.n_heads * self.head_dim)
            K = torch.cat((past_key, K), dim=1)
            K = K.reshape(-1, self.n_heads * self.head_dim)

            V = V.reshape(key_B, -1, self.n_heads * self.head_dim)
            past_value = past_value.reshape(key_B, key_T, self.n_heads * self.head_dim)
            V = torch.cat((past_value, V), dim=1)
            V = V.reshape(-1, self.n_heads * self.head_dim)

        # NOTE(review): ``.max()`` on an empty edge_index would raise here.
        assert edge_index[0].max() < K.shape[0], f"{edge_index[0].max()} >= {K.shape[0]}"
        assert edge_index[1].max() < Q.shape[0], f"{edge_index[1].max()} >= {Q.shape[0]}"

        if use_cache:
            # The (key_B, key_T) pair is not included in the returned cache.
            new_cache = [K, V]
        else:
            new_cache = None

        if self.simple_relation:
            # The relation query is appended to Q and split again in message().
            Q_relation = self.to_q_relation(x_dst).reshape(-1, self.n_heads * self.head_dim)
            Q = torch.cat([Q, Q_relation], dim=-1)

        if self.is_v7:
            if self.add_relation_to_v:
                assert edge_features_v is not None
            else:
                assert edge_features_v is None
                # Value side reuses the key-side edge features.
                edge_features_v = edge_features
            edge_features = self.to_k_r(edge_features)
            edge_features_v = self.to_v_r(edge_features_v)

        # Propagate messages using edge_index
        # NOTE(review): unpacking two values relies on MessagePassing handing
        # the tuple from aggregate() back unchanged -- confirm.
        out, new_edge_features = self.propagate(
            edge_index=edge_index,
            q=Q,
            k=K,
            v=V,
            edge_features=edge_features,
            edge_features_v=edge_features_v,
        )

        # Project the output back to original dimension
        out = out.reshape(B, Lq, D)
        if new_edge_features is not None:
            new_edge_features = new_edge_features.reshape(-1, D)
        if self.is_v7:
            out = self.out(out)
            return out, new_cache, new_edge_features

        # NOTE(review): the constructor rejects non-v7, so this two-value
        # return looks unreachable for constructed instances.
        return out, new_cache

    def message(
        self, q_i, k_j, v_j, edge_features, edge_features_v, index, ptr, edge_index, edge_index_i, edge_index_j,
        relation
    ):
        """Compute per-edge weighted values.

        Returns:
            ``(v_j * attention_weight, new_edge_features)``;
            ``new_edge_features`` is ``None`` unless ``update_relation``.
        """
        # NOTE(review): ``relation`` is never passed to propagate() by
        # forward(); how MessagePassing fills it is not visible here.
        k_j = k_j.reshape(-1, self.n_heads, self.head_dim)
        v_j = v_j.reshape(-1, self.n_heads, self.head_dim)

        if edge_features is not None and not self.simple_relation:
            raise ValueError()
            # NOTE(review): the rest of this branch is unreachable.
            edge_features = edge_features.reshape(-1, self.n_heads, self.head_dim)

            if self.is_v7:
                raise ValueError()

            # Compute relative positional encoding if enabled
            k_j = k_j + edge_features  # Add relative position embedding to Key

        if self.simple_relation:
            # Split content and relation queries concatenated in forward().
            q_i, q_relation = q_i[:, :self.n_heads * self.head_dim], q_i[:, self.n_heads * self.head_dim:]
            # Compute attention scores
            q_i = q_i.reshape(-1, self.n_heads, self.head_dim)
            q_relation = q_relation.reshape(-1, self.n_heads, self.head_dim)

            edge_features = edge_features.reshape(-1, self.n_heads, self.head_dim)

            # Logit = content score + relation score, each scaled by sqrt(head_dim).
            attn_scores = (q_i * k_j).sum(dim=-1) / self.head_dim**0.5  # Scaled dot-product
            attn_scores_relation = (q_relation * edge_features).sum(dim=-1) / self.head_dim**0.5
            attn_scores = attn_scores + attn_scores_relation

        else:
            q_i = q_i.reshape(-1, self.n_heads, self.head_dim)
            # Compute attention scores
            attn_scores = (q_i * k_j).sum(dim=-1) / self.head_dim**0.5  # Scaled dot-product

        # Softmax over the edges grouped by ``index``.
        attn_weights = softmax(attn_scores, index=index, ptr=ptr)
        attn_weights = self.dropout(attn_weights)  # Apply dropout to attention weights

        if edge_features_v is not None:
            edge_features_v = edge_features_v.reshape(-1, self.n_heads, self.head_dim)

            v_j = v_j + edge_features_v

        if self.update_relation:
            new_edge_features = edge_features + edge_features_v
        else:
            new_edge_features = None

        # NOTE(review): dropout is applied a second time; it is a no-op only
        # because the constructor asserts dropout == 0.0.
        attn_weights = self.dropout(attn_weights)  # Apply dropout to attention weights

        return v_j * attn_weights.unsqueeze(-1), new_edge_features

    def aggregate(
        self,
        inputs: Tensor,
        index: Tensor,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
    ) -> Tensor:
        """Aggregate the message part of ``inputs`` and pass edge features through.

        ``inputs`` is the ``(messages, new_edge_features)`` tuple returned by
        ``message``, and the return value is also a tuple, despite the
        ``Tensor`` annotations.
        """
        raw_inputs, new_edge_features = inputs
        inputs = super().aggregate(raw_inputs, index, ptr, dim_size)
        if new_edge_features is not None:
            # Per-edge (un-aggregated) messages are added to the edge features.
            new_edge_features = new_edge_features + raw_inputs
        return inputs, new_edge_features
d_model // n_heads + self.dropout = nn.Dropout(dropout) + self.relation_head_dim = self.head_dim // simple_relation_factor + self.to_q_relation = nn.Linear(d_model, d_model, bias=False) + self.to_k_r = nn.Linear(d_model // simple_relation_factor, d_model, bias=False) + self.to_v_r = nn.Linear(d_model // simple_relation_factor, d_model, bias=False) + self.to_k = nn.Linear(d_model, d_model, bias=False) + self.to_q = nn.Linear(d_model, d_model, bias=False) + self.to_v = nn.Linear(d_model, d_model, bias=False) + self.out = nn.Linear(d_model, d_model, bias=False) + self.out.weight.data.zero_() + + def forward(self, q, k, qk_valid_mask, relation, block_mask=None, relation_v=None, use_cache=False, cache=None): + + # TODO: assert relation shape + + B, Lq, D = q.shape + _, Lk, _ = k.shape + + # Compute linear projections + x_dst = q + x_src = k + Q = self.to_q(x_dst).reshape(-1, self.n_heads * self.head_dim) + K = self.to_k(x_src).reshape(-1, self.n_heads * self.head_dim) + V = self.to_v(x_src).reshape(-1, self.n_heads * self.head_dim) + + if cache is not None: + past_key = cache[0] + past_value = cache[1] + key_B, key_T = cache[2] + + K = K.reshape(key_B, -1, self.n_heads * self.head_dim) + past_key = past_key.reshape(key_B, key_T, self.n_heads * self.head_dim) + K = torch.cat((past_key, K), dim=1) + K = K.reshape(-1, self.n_heads * self.head_dim) + + V = V.reshape(key_B, -1, self.n_heads * self.head_dim) + past_value = past_value.reshape(key_B, key_T, self.n_heads * self.head_dim) + V = torch.cat((past_value, V), dim=1) + V = V.reshape(-1, self.n_heads * self.head_dim) + + # assert edge_index[0].max() < K.shape[0], f"{edge_index[0].max()} >= {K.shape[0]}" + # assert edge_index[1].max() < Q.shape[0], f"{edge_index[1].max()} >= {Q.shape[0]}" + + if use_cache: + new_cache = [K, V] + else: + new_cache = None + + # newB, newLq, _ = x.shape + # qk_valid_mask = a2t_info["attn_valid_mask"] + + _, _, newLk = qk_valid_mask.shape + + # key = out + # value = out + + K = K.reshape(B, 
1, Lk, D) + # rel = a2t_info['relation'] # newB, newLq, newLk, 128 + K = K + relation + + V = V.reshape(B, 1, newLk, D) + # rel_v = a2t_info['relation_v'] # newB, newLq, newLk, 128 + # value = value + rel_v + assert relation_v is None + relation_v = relation + V = V + relation_v + + Lk_new = Lk * Lq + + # TODO: in future update the swapaxes. + K = K.reshape(B, Lk_new, self.n_heads, self.head_dim).swapaxes(1, 2) + V = V.reshape(B, Lk_new, self.n_heads, self.head_dim).swapaxes(1, 2) + Q = Q.reshape(B, Lq, self.n_heads, self.head_dim).swapaxes(1, 2) + + if block_mask is None: + qk_valid_mask = qk_valid_mask.reshape(B, Lq, 1, newLk).expand(B, Lq, Lq, newLk).reshape(B, Lq, Lk_new) + + # TODO: How to select? + block_size = _DEFAULT_SPARSE_BLOCK_SIZE + # block_size = 4 + + Lq_padded = _round_up_to_multiple(Lq, block_size) + Lk_padded = _round_up_to_multiple(Lk_new, block_size) + new_valid_mask = qk_valid_mask.new_zeros(B, Lq_padded, Lk_padded) + new_valid_mask[:, :Lq, :Lk_new] = qk_valid_mask + + # TODO: Make the mask before! + # TODO: Can implement the sliding window here. 
+ def mask_mod(b, h, q_idx, kv_idx): + realq = kv_idx // newLk + m1 = q_idx == realq + # FIXME + # FIXME + m3 = new_valid_mask[b, q_idx, kv_idx] + return m1 & m3 + # return m1 + + # res = [] + # import numpy as np + # for q in range(Lq_padded): + # res.append([mask_mod(0, 0, q, v).item() for v in range(Lk_padded)]) + # res = np.array(res) + + block_mask = create_block_mask( + mask_mod=mask_mod, + B=B, + H=self.n_heads, + Q_LEN=Lq_padded, + KV_LEN=Lk_padded, + device=Q.device, + BLOCK_SIZE=block_size, + # _compile=True + ) + + flex_out = flex_attention( + query=Q, + key=K, + value=V, + block_mask=block_mask, + ) + + # # if self.simple_relation: + # Q_relation = self.to_q_relation(x_dst).reshape(-1, self.n_heads * self.head_dim) + # Q = torch.cat([Q, Q_relation], dim=-1) + # + # + # # Propagate messages using edge_index + # out, new_edge_features = self.propagate( + # edge_index=edge_index, + # x_dst=x_dst.reshape(-1, self.n_heads * self.head_dim), + # q=Q, + # k=K, + # v=V, + # edge_features=edge_features, + # edge_features_v=edge_features_v, + # ) + + # Project the output back to original dimension + out = flex_out.reshape(B, Lq, D) + # new_edge_features = new_edge_features.reshape(-1, D) + # if self.is_v7: + out = self.out(out) + # new_edge_features = self.out_rel(new_edge_features) + new_edge_features = None + return out, new_cache, new_edge_features, block_mask #, edge_features, edge_features_v + + # return out, new_cache + + # def message( + # self, q_i, k_j, v_j, edge_features, edge_features_v, index, ptr, edge_index, edge_index_i, edge_index_j, + # relation + # ): + # k_j = k_j.reshape(-1, self.n_heads, self.head_dim) + # v_j = v_j.reshape(-1, self.n_heads, self.head_dim) + # + # q_i, q_relation = q_i[:, :self.n_heads * self.head_dim], q_i[:, self.n_heads * self.head_dim:] + # # Compute attention scores + # q_i = q_i.reshape(-1, self.n_heads, self.head_dim) + # q_relation = q_relation.reshape(-1, self.n_heads, self.head_dim) + # + # + # edge_features = 
class MultiCrossAttTransformerDecoderLayer(Module):
    """Decoder layer with three edge-list attention stages and a feed-forward block.

    Each stage is a pre-norm residual block. In order: agent-temporal
    (``cross_a2t``), agent-agent (``cross_a2a``), agent-scene (``cross_a2s``),
    then the MLP. Only the agent-temporal stage receives and returns a
    key/value cache.
    """
    __constants__ = ['norm_first']

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dropout: float = 0.1,
        use_adaln=False,
        simple_relation=False,
        simple_relation_factor=None,
        is_v7=False,
        update_relation=False,
        add_relation_to_v=None,
        remove_rel_norm=None,
    ) -> None:
        """Build the three attention modules, the MLP and the norms.

        Args:
            d_model: Token feature width.
            nhead: Number of attention heads.
            dropout: Passed to all three attention modules and to the MLP.
            use_adaln: Use ``common_layers.AdaLayerNorm`` conditioned on
                ``condition_token`` instead of ``nn.LayerNorm``.
            simple_relation: Forwarded to the attention modules.
            simple_relation_factor: Forwarded to the attention modules.
            is_v7: Forwarded to the attention modules and the MLP.
            update_relation: Forwarded to the attention modules; when truthy
                ``add_relation_to_v`` must be ``False`` (asserted).
            add_relation_to_v: Required (asserted not ``None``); whether
                ``'edge_features_v'`` is read from the info dicts.
            remove_rel_norm: Required (asserted not ``None``); when falsy,
                edge features are layer-normalised before attention.
        """
        # NOTE(review): MultiheadAttentionLayer asserts dropout == 0.0, so
        # the default dropout=0.1 fails at construction -- confirm callers
        # always pass dropout=0.0.
        super().__init__()
        self.cross_a2t = MultiheadAttentionLayer(
            d_model=d_model,
            n_heads=nhead,
            dropout=dropout,
            simple_relation=simple_relation,
            simple_relation_factor=simple_relation_factor,
            is_v7=is_v7,
            update_relation=update_relation,
            add_relation_to_v=add_relation_to_v,
        )
        self.cross_a2a = MultiheadAttentionLayer(
            d_model=d_model,
            n_heads=nhead,
            dropout=dropout,
            simple_relation=simple_relation,
            simple_relation_factor=simple_relation_factor,
            is_v7=is_v7,
            update_relation=update_relation,
            add_relation_to_v=add_relation_to_v,
        )
        self.cross_a2s = MultiheadAttentionLayer(
            d_model=d_model,
            n_heads=nhead,
            dropout=dropout,
            simple_relation=simple_relation,
            simple_relation_factor=simple_relation_factor,
            is_v7=is_v7,
            update_relation=update_relation,
            add_relation_to_v=add_relation_to_v,
        )
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = common_layers.Mlp(
            in_features=d_model, hidden_features=4 * d_model, act_layer=approx_gelu, drop=dropout, is_v7=is_v7
        )

        self.use_adaln = use_adaln
        if use_adaln:
            self.a2s_norm = common_layers.AdaLayerNorm(d_model, conditioning_dim=d_model, batch_first=True)
            self.a2t_norm = common_layers.AdaLayerNorm(d_model, conditioning_dim=d_model, batch_first=True)
            self.a2a_norm = common_layers.AdaLayerNorm(d_model, conditioning_dim=d_model, batch_first=True)
            self.mlp_prenorm = common_layers.AdaLayerNorm(d_model, conditioning_dim=d_model, batch_first=True)
        else:
            self.a2s_norm = nn.LayerNorm(d_model)
            self.a2t_norm = nn.LayerNorm(d_model)
            self.a2a_norm = nn.LayerNorm(d_model)
            self.mlp_prenorm = nn.LayerNorm(d_model)

        self.remove_rel_norm = remove_rel_norm
        assert remove_rel_norm is not None
        if not remove_rel_norm:
            # NOTE(review): these norms are d_model wide while the attention
            # modules take edge features of width
            # d_model // simple_relation_factor -- confirm for factor != 1.
            self.a2t_norm_rel = nn.LayerNorm(d_model)
            self.a2a_norm_rel = nn.LayerNorm(d_model)
            self.a2s_norm_rel = nn.LayerNorm(d_model)

        self.update_relation = update_relation
        self.add_relation_to_v = add_relation_to_v
        assert add_relation_to_v is not None
        if add_relation_to_v and (not remove_rel_norm):
            assert update_relation is False
            self.a2t_norm_rel_v = nn.LayerNorm(d_model)
            self.a2a_norm_rel_v = nn.LayerNorm(d_model)
            self.a2s_norm_rel_v = nn.LayerNorm(d_model)
        if update_relation:
            assert add_relation_to_v is False

    def forward(
        self,
        *,
        agent_token,
        scene_token,
        a2a_info,
        a2t_info,
        a2s_info,
        condition_token,
        use_cache=False,
        past_key_value=None
    ):
        """Run the three attention stages and the feed-forward block.

        Args:
            agent_token: Tensor unpacked as ``(B, T, N, D)``; presumably
                time steps and agents -- confirm with the caller.
            scene_token: Keys/values of the agent-scene stage.
            a2a_info: Dict with ``'edge_index'``, ``'edge_features'`` and,
                when ``add_relation_to_v``, ``'edge_features_v'``. A
                ``'block_mask'`` key is added if missing.
            a2t_info: Same keys, for the agent-temporal stage.
            a2s_info: Same keys, for the agent-scene stage; ``'block_mask'``
                is added if missing.
            condition_token: Conditioning input of the adaptive norms; only
                read when ``use_adaln`` is set, and reshaped like the agent
                tokens at each stage.
            use_cache: Forwarded to the agent-temporal attention.
            past_key_value: Cache forwarded to the agent-temporal attention.

        Returns:
            ``(x, past_key_value_a2t)``: the updated tokens reshaped to
            ``(B, T, N, D)``, and the second value returned by the
            agent-temporal attention.
        """
        B, T, N, D = agent_token.shape
        x = agent_token

        # === agent-temporal attention ===
        # B,T,N,D -> BN, T, D
        x = x.swapaxes(1, 2).flatten(0, 1)
        out = x
        if self.use_adaln:
            out = self.a2t_norm(out, z=condition_token.swapaxes(1, 2).flatten(0, 1))
        else:
            out = self.a2t_norm(out)

        a2t_rel = a2t_info['edge_features']
        if self.remove_rel_norm:
            a2t_rel_out = a2t_rel
            a2t_rel_out_v = a2t_info['edge_features_v'] if self.add_relation_to_v else None
        else:
            a2t_rel_out = self.a2t_norm_rel(a2t_rel)
            a2t_rel_out_v = self.a2t_norm_rel_v(a2t_info['edge_features_v']) if self.add_relation_to_v else None
        # Self-attention over each (B*N) row; the only stage that uses the cache.
        out, past_key_value_a2t, a2t_rel_out = self.cross_a2t(
            q=out,
            k=out,
            edge_features=a2t_rel_out,
            edge_features_v=a2t_rel_out_v,
            edge_index=a2t_info['edge_index'],
            use_cache=use_cache,
            cache=past_key_value,
        )
        assert out.shape == (B * N, T, D)
        x = x + out
        # BN, T, D -> B, T, N, D
        x = x.reshape(B, N, T, D).swapaxes(1, 2)
        if self.update_relation:
            # Residual update, written back into the caller's dict.
            a2t_rel_out = a2t_rel_out + a2t_rel
            a2t_info['edge_features'] = a2t_rel_out
            assert self.add_relation_to_v is False

        # === agent-agent attention ===
        x = x.reshape(B * T, N, D)
        out = x
        if self.use_adaln:
            out = self.a2a_norm(out, z=condition_token.reshape(B * T, N, D))
        else:
            out = self.a2a_norm(out)

        a2a_rel = a2a_info['edge_features']
        if self.remove_rel_norm:
            a2a_rel_out = a2a_rel
            a2a_rel_out_v = a2a_info['edge_features_v'] if self.add_relation_to_v else None
        else:
            a2a_rel_out = self.a2a_norm_rel(a2a_rel)
            a2a_rel_out_v = self.a2a_norm_rel_v(a2a_info['edge_features_v']) if self.add_relation_to_v else None

        # NOTE(review): 'block_mask' is written into the caller's dict but
        # never read in this block.
        if "block_mask" not in a2a_info:
            a2a_info["block_mask"] = None
        out, _, a2a_rel_out = self.cross_a2a(
            q=out,
            k=out,
            edge_features=a2a_rel_out,
            edge_features_v=a2a_rel_out_v,
            edge_index=a2a_info['edge_index'],
        )
        x = x + out
        x = x.reshape(B, T, N, D)
        if self.update_relation:
            a2a_rel_out = a2a_rel_out + a2a_rel
            a2a_info['edge_features'] = a2a_rel_out

        # === agent-scene attention ===
        x = x.reshape(B, T * N, D)
        out = x
        if self.use_adaln:
            out = self.a2s_norm(out, z=condition_token.reshape(B, T*N, D))
        else:
            out = self.a2s_norm(out)

        a2s_rel = a2s_info['edge_features']
        if self.remove_rel_norm:
            a2s_rel_out = a2s_rel
            a2s_rel_out_v = a2s_info['edge_features_v'] if self.add_relation_to_v else None
        else:
            a2s_rel_out = self.a2s_norm_rel(a2s_rel)
            a2s_rel_out_v = self.a2s_norm_rel_v(a2s_info['edge_features_v']) if self.add_relation_to_v else None

        if "block_mask" not in a2s_info:
            a2s_info["block_mask"] = None
        # Queries are agent tokens; keys/values are the scene tokens.
        out, _, a2s_rel_out = self.cross_a2s(
            q=out,
            k=scene_token,
            edge_features=a2s_rel_out,
            edge_features_v=a2s_rel_out_v,
            edge_index=a2s_info['edge_index'],
        )
        x = x + out
        x = x.reshape(B, T, N, D)
        if self.update_relation:
            a2s_rel_out = a2s_rel_out + a2s_rel
            a2s_info['edge_features'] = a2s_rel_out

        # === Feed-forward layer ===
        out = x
        if self.use_adaln:
            out = self.mlp_prenorm(out, z=condition_token)
        else:
            out = self.mlp_prenorm(out)
        out = self.mlp(out)
        x = x + out

        return x, past_key_value_a2t
add_relation_to_v: + self.s2s_norm_rel_v = nn.LayerNorm(d_model) + + self.cross_s2s = MultiheadAttentionLayer( + d_model=d_model, + n_heads=nhead, + dropout=dropout, + simple_relation=simple_relation, + simple_relation_factor=simple_relation_factor, + is_v7=True, + add_relation_to_v=add_relation_to_v, + ) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = common_layers.Mlp( + in_features=d_model, + hidden_features=4 * d_model, + act_layer=approx_gelu, + drop=dropout, + ) + self.mlp_prenorm = nn.LayerNorm(d_model) + self.update_relation = update_relation + self.add_relation_to_v = add_relation_to_v + + # self.mlp_rel = common_layers.Mlp(in_features=d_model, hidden_features=4 * d_model, act_layer=approx_gelu, drop=0, is_v7=is_v7) + # self.mlp_rel_prenorm = nn.LayerNorm(d_model) + + def forward(self, scene_tokens, scene_info, edge_features, edge_features_v=None, block_mask=None): + x = self.s2s_norm(scene_tokens) + out, cache, edge_features_out = self.cross_s2s( + q=x, + k=x, + edge_index=scene_info['edge_index'], + edge_features=self.s2s_norm_rel(edge_features) if not self.remove_rel_norm else edge_features, + edge_features_v=(self.s2s_norm_rel_v(edge_features_v) if not self.remove_rel_norm else edge_features_v) + if edge_features_v is not None else None, + ) + scene_tokens = scene_tokens + out + out = self.mlp(self.mlp_prenorm(scene_tokens)) + scene_tokens = scene_tokens + out + + if self.update_relation: + assert self.add_relation_to_v is False + edge_features_out = edge_features_out + edge_features + else: + edge_features_out = edge_features + return scene_tokens, cache, edge_features_out, block_mask # , edge_features_v_out diff --git a/scenestreamer/models/layers/multi_head_attention.py b/scenestreamer/models/layers/multi_head_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0f11a39a22bdd8b00302bc74fd58790cf7eac77a --- /dev/null +++ b/scenestreamer/models/layers/multi_head_attention.py @@ -0,0 +1,1133 @@ +import math 
class MultiheadAttention(Module):
    r"""Multi-head attention with relation inputs, positional inputs and a key/value cache.

    The constructor arguments and parameter layout follow ``torch.nn.MultiheadAttention``.
    The differences visible in this class are:

    - ``forward`` accepts ``query_pos`` / ``key_pos`` / ``value_pos`` and
      ``relation_k`` / ``relation_v`` / ``relation_mask`` / ``relation_indices``,
      which are forwarded unchanged to ``multi_head_attention_forward``;
    - ``forward`` can prepend cached (un-projected) keys/values through
      ``past_key_value`` and returns a cache as a third output;
    - ``disable_projection=True`` creates the module without input projection
      parameters;
    - the native "fast path" is never taken; attention always goes through
      ``multi_head_attention_forward``, so nested tensors are not supported.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout probability passed to the attention function. Default: ``0.0``.
        bias: If ``True``, adds bias to the input / output projections. Default: ``True``.
        add_bias_kv: If ``True``, creates learnable ``bias_k`` / ``bias_v``. Default: ``False``.
        add_zero_attn: Forwarded to the attention function. Default: ``False``.
        kdim: Feature size of keys. Default: ``None`` (uses ``embed_dim``).
        vdim: Feature size of values. Default: ``None`` (uses ``embed_dim``).
        batch_first: If ``True``, batched inputs/outputs are ``(batch, seq, feature)``.
            Default: ``False`` (``(seq, batch, feature)``).
        device, dtype: Factory arguments for the created parameters.
        disable_projection: If ``True``, no input projection weights/bias are created.
    """

    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
        disable_projection=False,
    ) -> None:
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"embed_dim and num_heads must be greater than 0,"
                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
            )
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if disable_projection:
            # Plain ``None`` attributes: nothing is registered, so the state dict
            # carries no input-projection entries.
            self.q_proj_weight = self.in_proj_weight = self.k_proj_weight = self.v_proj_weight = self.in_proj_bias = None
        else:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
                self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
                self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
                self.register_parameter('in_proj_weight', None)
            else:
                self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
                self.register_parameter('q_proj_weight', None)
                self.register_parameter('k_proj_weight', None)
                self.register_parameter('v_proj_weight', None)

            if bias:
                self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
            else:
                self.register_parameter('in_proj_bias', None)

        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise projection weights (Xavier uniform) and biases.

        The separate q/k/v weights exist exactly when ``in_proj_weight`` is ``None``,
        so the choice of what to initialise is made on ``_qkv_same_embed_dim`` and
        not on the presence of ``in_proj_weight``. Every weight is ``None``-checked
        because ``disable_projection=True`` creates none of them.
        """
        if self._qkv_same_embed_dim:
            if self.in_proj_weight is not None:
                xavier_uniform_(self.in_proj_weight)
        else:
            for weight in (self.q_proj_weight, self.k_proj_weight, self.v_proj_weight):
                if weight is not None:
                    xavier_uniform_(weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super().__setstate__(state)

    @staticmethod
    def _merge_cache(past_key_value, key, value, key_padding_mask, attn_mask):
        """Prepend cached keys/values and widen both masks to the new source length.

        ``key`` / ``value`` are sequence-first here, so the cache is concatenated on
        dim 0. Cached positions are always attendable: the columns added to
        ``attn_mask`` (and to a missing padding mask) are zeros in the additive-mask
        convention. Either mask may be ``None``.
        """
        past_key, past_value, past_padding, _ = past_key_value
        new_len = key.shape[0]
        past_len = past_key.shape[0] if past_padding is None else past_padding.shape[-1]

        key = torch.cat([past_key, key], dim=0)
        value = torch.cat([past_value, value], dim=0)

        # If only one of the two padding masks exists, synthesise an all-visible
        # mask for the other side so the concatenation stays well-formed.
        if past_padding is None and key_padding_mask is not None:
            past_padding = key_padding_mask.new_zeros((*key_padding_mask.shape[:-1], past_len))
        elif key_padding_mask is None and past_padding is not None:
            key_padding_mask = past_padding.new_zeros((*past_padding.shape[:-1], new_len))
        if key_padding_mask is not None:
            key_padding_mask = torch.cat([past_padding, key_padding_mask], dim=-1)

        if attn_mask is not None:
            # One row per *query* position: keep every leading dim of the mask and
            # only add ``past_len`` columns.
            visible = attn_mask.new_zeros((*attn_mask.shape[:-1], past_len))
            attn_mask = torch.cat([visible, attn_mask], dim=-1)

        return key, value, key_padding_mask, attn_mask

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        query_pos=None,
        key_pos=None,
        value_pos=None,
        relation_k=None,
        relation_v=None,
        relation_mask=None,
        relation_indices=None,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
        past_key_value=None,
        use_cache=False,
        disable_projection=False,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor, Tensor, Optional[Tensor], None]]]:
        r"""Compute attention and optionally return a key/value cache.

        Args:
            query, key, value: 3-D (batched) or 2-D (unbatched) tensors. Batched
                inputs are ``(seq, batch, feature)``, or ``(batch, seq, feature)``
                when ``batch_first=True``.
            query_pos, key_pos, value_pos: forwarded unchanged to
                ``multi_head_attention_forward``.
            relation_k, relation_v, relation_mask, relation_indices: forwarded
                unchanged to ``multi_head_attention_forward``.
            key_padding_mask: bool or float mask over key positions. Bool masks are
                converted to additive float masks (``True`` -> ``-inf``).
            need_weights: forwarded to the attention function.
            attn_mask: 2-D or 3-D bool/float attention mask, canonicalised like
                ``key_padding_mask``. A 3-D mask whose leading dim is not
                ``batch * num_heads`` is repeated once per head.
            average_attn_weights: forwarded to the attention function.
            is_causal: forwarded to the attention function.
            past_key_value: cache tuple ``(key, value, key_padding_mask, None)``
                returned by a previous call; its keys/values are prepended to the
                current ones and both masks are widened accordingly.
            use_cache: if ``True``, return a cache as the third output.
            disable_projection: forwarded to the attention function.

        Returns:
            ``(attn_output, attn_output_weights, cache)``. ``cache`` is
            ``(key, value, key_padding_mask, None)`` holding the sequence-first,
            un-projected keys/values (cached entries included), or ``None``. It is
            returned when ``use_cache`` is set; for batched ``batch_first`` inputs
            it is additionally returned whenever ``past_key_value`` was given.

        Raises:
            AssertionError: if any of ``query`` / ``key`` / ``value`` is a nested tensor.
        """
        is_batched = query.dim() == 3

        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        # The native fast path is not used by this implementation, so its
        # eligibility checks (and the discarded ``merge_masks`` call they
        # triggered) are not evaluated; nested tensors are rejected outright.
        if query.is_nested or key.is_nested or value.is_nested:
            raise AssertionError(
                "MultiheadAttention does not support NestedTensor outside of its fast path. "
                "The fast path was not hit because this implementation never uses it"
            )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        # A per-sample 3-D mask is repeated for every head. Target and source
        # lengths are kept separate so non-square masks are handled as well.
        if attn_mask is not None and attn_mask.ndim == 3:
            if attn_mask.shape[0] != query.shape[1] * self.num_heads:
                bsz, tgt_len, src_len = attn_mask.shape
                attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
                attn_mask = attn_mask.reshape(bsz * self.num_heads, tgt_len, src_len)

        if past_key_value is not None:
            key, value, key_padding_mask, attn_mask = self._merge_cache(
                past_key_value, key, value, key_padding_mask, attn_mask
            )

        attn_kwargs = dict(
            relation_k=relation_k,
            relation_v=relation_v,
            relation_mask=relation_mask,
            relation_indices=relation_indices,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
            is_causal=is_causal,
            query_pos=query_pos,
            key_pos=key_pos,
            disable_projection=disable_projection,
            value_pos=value_pos,
        )
        if not self._qkv_same_embed_dim:
            attn_kwargs.update(
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
            )

        attn_output, attn_output_weights = multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            **attn_kwargs,
        )

        cache = (key, value, key_padding_mask, None)
        if self.batch_first and is_batched:
            # ``use_cache`` must be honoured here too, otherwise a cache could never
            # be started. Returning it when a past cache was supplied keeps the
            # previous behaviour of this branch.
            want_cache = use_cache or past_key_value is not None
            return attn_output.transpose(1, 0), attn_output_weights, cache if want_cache else None
        return attn_output, attn_output_weights, cache if use_cache else None

    def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                    query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
        r"""
        Determine mask type and combine masks if necessary. If only one mask is provided, that mask
        and the corresponding mask type will be returned. If both masks are provided, they will be both
        expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined by addition
        and mask type 2 will be returned
        Args:
            attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
            key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
            query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
        Returns:
            merged_mask: merged mask
            mask_type: merged mask type (1 or 2), or ``None`` when no mask is given
        """
        mask_type: Optional[int] = None
        merged_mask: Optional[Tensor] = None

        if key_padding_mask is not None:
            mask_type = 1
            merged_mask = key_padding_mask

        if attn_mask is not None:
            # In this branch query can't be a nested tensor, so it has a shape
            batch_size, seq_len, _ = query.shape
            mask_type = 2

            # Always expands attn_mask to 4D
            if attn_mask.dim() == 3:
                attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
            else:  # attn_mask.dim() == 2:
                attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1)
            merged_mask = attn_mask_expanded

            if key_padding_mask is not None:
                key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1,
                                                                  seq_len).expand(-1, self.num_heads, -1, -1)
                merged_mask = attn_mask_expanded + key_padding_mask_expanded

        # no attn_mask and no key_padding_mask, returns None, None
        return merged_mask, mask_type
+ if query.dim() == 3: + # Batched Inputs + is_batched = True + assert key.dim() == 3 and value.dim() == 3, \ + ("For batched (3-D) `query`, expected `key` and `value` to be 3-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively") + if key_padding_mask is not None: + assert key_padding_mask.dim() == 2, \ + ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" + f" but found {key_padding_mask.dim()}-D tensor instead") + if attn_mask is not None: + assert attn_mask.dim() in (2, 3), \ + ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead") + elif query.dim() == 2: + # Unbatched Inputs + is_batched = False + assert key.dim() == 2 and value.dim() == 2, \ + ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively") + + if key_padding_mask is not None: + assert key_padding_mask.dim() == 1, \ + ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D" + f" but found {key_padding_mask.dim()}-D tensor instead") + + if attn_mask is not None: + assert attn_mask.dim() in (2, 3), \ + ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead") + if attn_mask.dim() == 3: + expected_shape = (num_heads, query.shape[0], key.shape[0]) + assert attn_mask.shape == expected_shape, \ + (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}") + else: + raise AssertionError( + f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor" + ) + + return is_batched + + +def _canonical_mask( + mask: Optional[Tensor], + mask_name: str, + other_type: Optional[DType], + other_name: str, + target_type: DType, + check_other: bool = True, +) -> Optional[Tensor]: + + if mask is not None: + _mask_dtype = mask.dtype + _mask_is_float = 
torch.is_floating_point(mask) + if _mask_dtype != torch.bool and not _mask_is_float: + raise AssertionError(f"only bool and floating types of {mask_name} are supported") + if check_other and other_type is not None: + if _mask_dtype != other_type: + warnings.warn( + f"Support for mismatched {mask_name} and {other_name} " + "is deprecated. Use same type for both instead." + ) + if not _mask_is_float: + mask = (torch.zeros_like(mask, dtype=target_type).masked_fill_(mask, float("-inf"))) + return mask + + +def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]: + if input is None: + return None + elif isinstance(input, torch.Tensor): + return input.dtype + raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor") + + +def _in_projection_packed( + q: Tensor, + k: Tensor, + v: Tensor, + w: Tensor, + b: Optional[Tensor] = None, +) -> List[Tensor]: + r"""Perform the in-projection step of the attention operation, using packed weights. + + Output is a triple containing projection tensors for query, key and value. + + Args: + q, k, v: query, key and value tensors to be projected. For self-attention, + these are typically the same tensor; for encoder-decoder attention, + k and v are typically the same tensor. (We take advantage of these + identities for performance if they are present.) Regardless, q, k and v + must share a common embedding dimension; otherwise their shapes may vary. + w: projection weights for q, k and v, packed into a single tensor. Weights + are packed along dimension 0, in q, k, v order. + b: optional projection biases for q, k and v, packed into a single tensor + in q, k, v order. 
+ + Shape: + Inputs: + - q: :math:`(..., E)` where E is the embedding dimension + - k: :math:`(..., E)` where E is the embedding dimension + - v: :math:`(..., E)` where E is the embedding dimension + - w: :math:`(E * 3, E)` where E is the embedding dimension + - b: :math:`E * 3` where E is the embedding dimension + + Output: + - in output list :math:`[q', k', v']`, each output tensor will have the + same shape as the corresponding input tensor. + """ + E = q.size(-1) + if k is v: + if q is k: + # self-attention + proj = linear(q, w, b) + # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() + proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + return proj[0], proj[1], proj[2] + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + q_proj = linear(q, w_q, b_q) + kv_proj = linear(k, w_kv, b_kv) + # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() + kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + return (q_proj, kv_proj[0], kv_proj[1]) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +def _in_projection( + q: Tensor, + k: Tensor, + v: Tensor, + w_q: Tensor, + w_k: Tensor, + w_v: Tensor, + b_q: Optional[Tensor] = None, + b_k: Optional[Tensor] = None, + b_v: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + r"""Perform the in-projection step of the attention operation. + + This is simply a triple of linear projections, + with shape constraints on the weights which + ensure embedding dimension uniformity in the projected outputs. + Output is a triple containing projection tensors for query, key and value. 
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    query_pos: Optional[Tensor] = None,
    key_pos: Optional[Tensor] = None,
    value_pos: Optional[Tensor] = None,
    relation_k: Optional[Tensor] = None,
    relation_v: Optional[Tensor] = None,
    relation_mask: Optional[Tensor] = None,
    relation_indices: Optional[Tensor] = None,
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
    disable_projection: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Multi-head attention with optional positional and pairwise-relation terms.

    The overall structure follows PyTorch's functional multi-head attention;
    comments tagged ``PZH`` mark the local extensions (positional tensors,
    relation tensors, ``disable_projection``) and the features that are
    deliberately rejected (unbatched input, ``static_k``/``static_v``,
    ``add_zero_attn``, ``__torch_function__`` overrides).

    Args:
        query: ``(tgt_len, bsz, embed_dim)``.
        key, value: ``(src_len, bsz, ...)``; must have identical shapes unless
            ``use_separate_proj_weight`` is set.
        embed_dim_to_check: expected ``embed_dim`` of ``query``.
        num_heads: number of heads; must divide ``embed_dim``.
        in_proj_weight, in_proj_bias: packed q/k/v projection parameters.
        bias_k, bias_v: optional extra key/value row appended along the
            sequence dimension; both or neither must be given.
        add_zero_attn: must be falsy, otherwise ``ValueError`` is raised.
        dropout_p: attention dropout probability (forced to 0 when
            ``training`` is False).
        out_proj_weight, out_proj_bias: output projection parameters.
        query_pos, key_pos: optional tensors viewed as
            ``(len, bsz * num_heads, head_dim)`` and concatenated to the
            per-head q / k along the feature dimension.
        value_pos: optional tensor viewed the same way and added to v.
        relation_k: optional 4-D tensor; its presence enables the relative
            attention path. It is reshaped to
            ``(bsz * num_heads, numq, numk, head_dim)``.
        relation_v: optional tensor reshaped like ``relation_k``; when absent
            ``relation_k`` is added to the gathered values instead.
        relation_mask: viewable as ``(bsz, 1, numq, numk)``; positions where it
            is False get zero attention weight after the softmax. Required
            whenever ``relation_k`` is given.
        relation_indices: viewable as ``(bsz, 1, numq, numk)``; used as the
            gather index for keys and values. Required whenever
            ``relation_k`` is given.
        training: whether dropout is active.
        key_padding_mask: optional ``(bsz, src_len)`` additive mask; must be
            float32 or bfloat16 when provided.
        need_weights: if True the attention weights are computed explicitly
            and returned; otherwise ``scaled_dot_product_attention`` is used.
        attn_mask: optional additive mask, float32 or bfloat16, shaped
            ``(tgt_len, src_len)`` or ``(bsz * num_heads, tgt_len, src_len)``.
        use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight:
            use three separate projection matrices instead of the packed one.
        static_k, static_v: not supported; passing either raises ``ValueError``.
        average_attn_weights: average returned weights over heads.
        is_causal: causal hint; requires ``attn_mask``.
        disable_projection: skip the input projection and use
            ``query``/``key``/``value`` as q/k/v directly.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` is
        ``(tgt_len, bsz, out_features)`` and ``attn_weights`` is ``None`` when
        ``need_weights`` is False.

    Raises:
        ValueError: for the unsupported features listed above.
        RuntimeError: if ``is_causal`` is set without ``attn_mask`` or the
            ``attn_mask`` shape is wrong.
        AssertionError: on dtype / shape contract violations.

    NOTE(review): the relation tensors are only consumed on the
    ``need_weights=True`` path; with ``need_weights=False`` they are reshaped
    but not used by the SDPA call. On the relative path the merged
    ``attn_mask`` is not added to the scores. Confirm both are intended.
    """
    # Masks must already be additive float masks. ``key_padding_mask`` is optional
    # (its default is None), so only validate its dtype when one was supplied.
    if key_padding_mask is not None:
        assert key_padding_mask.dtype in [torch.float32, torch.bfloat16]
    if attn_mask is not None:
        assert attn_mask.dtype in [torch.float32, torch.bfloat16]

    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
    if has_torch_function(tens_ops):
        # PZH: __torch_function__ dispatch is intentionally unsupported.
        raise ValueError("__torch_function__ overrides are not supported by this attention implementation")

    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)

    # PZH: the relative attention path is enabled by passing relation_k.
    relative_pe = relation_k is not None

    if not is_batched:
        raise ValueError("unbatched inputs are not supported; expected query of shape (tgt_len, bsz, embed_dim)")

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=_none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype
    )

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        # when we have a kpm or need weights, we need attn_mask
        # Otherwise, we use the is_causal hint go as is_causal
        # indicator to SDPA.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge kpm into it.
            # Turn off use of is_causal hint, as the merged mask is no
            # longer causal.
            is_causal = False

    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    # PZH: Our modification - the projection can be skipped entirely.
    if disable_projection:
        q = query
        k = key
        v = value
    else:
        if not use_separate_proj_weight:
            assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
            q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
        else:
            assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
            assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
            assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
            if in_proj_bias is None:
                b_q = b_k = b_v = None
            else:
                b_q, b_k, b_v = in_proj_bias.chunk(3)
            q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

    # prep attention mask
    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    if query_pos is None:
        q = q.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # The positional part is concatenated per head, so q's feature size grows.
        q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
        q_pos = query_pos.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
        q = torch.cat([q, q_pos], dim=-1)

    if static_k is None:
        if key_pos is None:
            k = k.reshape(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
        else:
            k = k.contiguous().view(src_len, bsz * num_heads, head_dim).transpose(0, 1)
            k_pos = key_pos.contiguous().view(src_len, bsz * num_heads, head_dim).transpose(0, 1)
            k = torch.cat([k, k_pos], dim=-1)
    else:
        # PZH: static keys are intentionally unsupported.
        raise ValueError("static_k is not supported by this attention implementation")

    if static_v is None:
        if value_pos is None:
            v = v.reshape(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
        else:
            # Unlike q/k, the positional part is added to (not concatenated with) v.
            v = v.contiguous().view(src_len, bsz * num_heads, head_dim).transpose(0, 1)
            v_pos = value_pos.contiguous().view(src_len, bsz * num_heads, head_dim).transpose(0, 1)
            v = v + v_pos
    else:
        # PZH: static values are intentionally unsupported.
        raise ValueError("static_v is not supported by this attention implementation")

    if add_zero_attn:
        raise ValueError("add_zero_attn is not supported by this attention implementation")

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    if relative_pe:
        assert relation_k.ndim == 4
        _, numq, numk, _ = relation_k.shape
        # Split the feature dimension into heads and fold the heads into the batch dimension.
        relation_k = relation_k.permute(1, 2, 0, 3)
        relation_k = relation_k.reshape(numq, numk, bsz * num_heads, head_dim)
        relation_k = relation_k.permute(2, 0, 1, 3)

        if relation_v is not None:
            relation_v = relation_v.permute(1, 2, 0, 3)
            relation_v = relation_v.reshape(numq, numk, bsz * num_heads, head_dim)
            relation_v = relation_v.permute(2, 0, 1, 3)

        # Repeat the per-sample mask / indices for every head.
        relation_mask = relation_mask.view(bsz, 1, numq, numk).expand(-1, num_heads, -1,
                                                                      -1).reshape(bsz * num_heads, numq, numk)
        relation_indices = relation_indices.view(bsz, 1, numq, numk).expand(-1, num_heads, -1,
                                                                            -1).reshape(bsz * num_heads, numq, numk)

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #
    if need_weights:
        _, _, E = q.shape
        q_scaled = q / math.sqrt(E)

        assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"

        if relative_pe:
            # Gather, for every query, the keys selected by relation_indices, then add the relation term.
            k = k.reshape(bsz * num_heads, src_len, 1, head_dim).expand(-1, -1, numq, -1)
            k = torch.gather(k, -2, relation_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim))
            k += relation_k
            # Scores now have shape (bsz * num_heads, numq, numk).
            attn_output_weights = torch.einsum("bnd,bnmd->bnm", q_scaled, k)
        else:
            if attn_mask is not None:
                attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
            else:
                attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))

        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)

        if relative_pe:
            # Invalid relations are zeroed after the softmax (weights are not renormalised).
            attn_output_weights = attn_output_weights.masked_fill(~relation_mask, 0)
            v = v.reshape(bsz * num_heads, src_len, 1, head_dim).expand(-1, -1, numq, -1)
            if relation_v is None:
                v = torch.gather(v, -2, relation_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)) + relation_k
            else:
                v = torch.gather(v, -2, relation_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)) + relation_v
            attn_output = torch.einsum("bij,bijd->bid", attn_output_weights, v)
            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, numk)
        else:
            attn_output = torch.bmm(attn_output_weights, v)
            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)

        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        # optionally average attention weights over heads
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        return attn_output, attn_output_weights

    # attn_mask can be either (L,S) or (N*num_heads, L, S)
    # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
    # in order to match the input for SDPA of (N, num_heads, L, S)
    if attn_mask is not None:
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

    q = q.view(bsz, num_heads, tgt_len, q.shape[-1])
    k = k.view(bsz, num_heads, src_len, k.shape[-1])
    v = v.view(bsz, num_heads, src_len, v.shape[-1])

    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

    return attn_output, None
def gen_sineembed_for_position(pos_tensor, hidden_dim=256):
    """Sinusoidal embedding of 2-D or 4-D positions.

    Mostly copy-paste from https://github.com/IDEA-opensource/DAB-DETR/

    Args:
        pos_tensor: 3-D tensor whose last dimension holds either ``(x, y)`` or
            ``(x, y, w, h)``.
        hidden_dim: every coordinate is embedded into ``hidden_dim // 2``
            channels.

    Returns:
        Tensor with the same two leading dimensions as ``pos_tensor`` and a
        last dimension of ``(hidden_dim // 2) * pos_tensor.size(-1)``. The
        channel order is ``y, x`` (then ``w, h`` for 4-D input).

    Raises:
        ValueError: if the last dimension is neither 2 nor 4.
    """
    assert pos_tensor.ndim == 3

    channels_per_coord = hidden_dim // 2
    two_pi = 2 * math.pi
    channel_index = torch.arange(channels_per_coord, dtype=pos_tensor.dtype, device=pos_tensor.device)
    # Consecutive channel pairs share one frequency (the sin and cos halves of a pair).
    temperature = 10000**(2 * (channel_index // 2) / channels_per_coord)

    def _embed(coord):
        """Embed one coordinate: sin on even channels, cos on odd ones, interleaved."""
        angles = (pos_tensor[:, :, coord] * two_pi)[:, :, None] / temperature
        return torch.stack((angles[:, :, 0::2].sin(), angles[:, :, 1::2].cos()), dim=3).flatten(2)

    emb_x = _embed(0)
    emb_y = _embed(1)

    num_coords = pos_tensor.size(-1)
    if num_coords == 2:
        return torch.cat((emb_y, emb_x), dim=2)
    if num_coords == 4:
        emb_w = _embed(2)
        emb_h = _embed(3)
        return torch.cat((emb_y, emb_x, emb_w, emb_h), dim=2)
    raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
def create_causal_mask(T, N, is_valid_mask=False):
    """Build the causal mask for a step-major flattened token sequence.

    The sequence is laid out as ``T`` consecutive groups of ``N`` agent tokens.
    A token may attend to every token of its own step and of all earlier
    steps, but to none of a later step. Rows index queries and columns index
    keys, so later rows can see more columns than earlier ones.

    Args:
        T: Number of steps.
        N: Number of agents (padded to fit different batches).
        is_valid_mask: If True, return the mask with True meaning "may
            attend". If False (default), return its negation, where True marks
            the positions to be ignored.

    Returns:
        Boolean tensor of shape ``(T * N, T * N)``.
    """
    step_visibility = torch.tril(torch.ones(T, T, dtype=torch.bool))
    same_step_block = torch.ones(N, N, dtype=torch.bool)
    # Expand every step-level entry into an N x N block of agent pairs.
    may_attend = torch.kron(step_visibility, same_step_block)
    return may_attend if is_valid_mask else ~may_attend
    def forward(self, input_dict, use_cache=False):
        """Decode per-agent action logits from the encoded scene and the input actions.

        Reads from ``input_dict``:
            - "encoder/scenario_token", "encoder/scenario_valid_mask",
              "encoder/modeled_agent_pe", "encoder/scenario_position",
              "encoder/scenario_heading"
            - "decoder/input_action" of shape (B, T_skipped, N) and
              "decoder/input_action_valid_mask" of the same shape
            - "decoder/modeled_agent_position", "decoder/modeled_agent_heading"
            - optionally "decoder/input_step", "decoder/cache", the label keys
              ("decoder/label_turning", "decoder/label_acceleration",
              "decoder/label_safety") and the ``*_history`` keys

        Writes to ``input_dict``:
            - "decoder/output_logit" of shape (B, T_skipped, N, num_actions)
            - "decoder/input_step" when it was missing
            - "decoder/cache" when ``use_cache`` is True
            - the ``decoder/modeled_agent_*_history`` keys when both
              ``self.relative_pe`` and ``use_cache`` are set

        Args:
            input_dict: dictionary of tensors, updated in place.
            use_cache: when True the decoder also returns its key/value cache,
                which is stored back into ``input_dict`` for the next call.

        Returns:
            The same ``input_dict`` object.
        """
        # === Process scene embedding ===
        scene_token = input_dict["encoder/scenario_token"]
        scenario_valid_mask = input_dict["encoder/scenario_valid_mask"]
        modeled_agent_pe = input_dict["encoder/modeled_agent_pe"]
        scene_padding_mask = ~scenario_valid_mask

        # === Process action embedding ===
        input_action = input_dict["decoder/input_action"]
        B, T_skipped, N = input_action.shape
        action_valid_mask = input_dict["decoder/input_action_valid_mask"]
        assert action_valid_mask.shape == (B, T_skipped, N)

        # NOTE(review): this writes into the tensor stored in input_dict (no clone), so the
        # caller's "decoder/input_action" is modified too. Presumably the tokenizer treats
        # index -1 as the start token - confirm against common_layers.Tokenizer.
        input_action[input_action == START_ACTION] = -1

        action_token = self.tokenizer(input_action)  # (B, T_skipped, N, D)
        if self.add_pe_for_token:
            action_token = self.tokenizer_mlp(action_token)
        assert action_token.shape == (B, T_skipped, N, self.d_model)

        # Add the step positional embedding to the input actions. When the caller does not
        # provide step indices, steps 0..T_skipped-1 are assumed.
        if "decoder/input_step" not in input_dict:
            input_dict["decoder/input_step"] = torch.arange(T_skipped).to(action_token.device)
        assert input_dict["decoder/input_step"].ndim == 1
        step_pe = self.step_pe(input_dict["decoder/input_step"])
        # print('input_dict["decoder/input_step"]', input_dict["decoder/input_step"])
        action_token += step_pe.reshape(1, T_skipped, 1, self.d_model)

        # Add the per-agent embedding, broadcast over the step dimension.
        assert action_token.shape == (B, T_skipped, N, self.d_model)
        assert modeled_agent_pe.shape == (B, N, self.d_model), modeled_agent_pe.shape
        action_token += modeled_agent_pe[:, None]

        if self.add_pe_for_static_features:
            action_token += input_dict["encoder/modeled_agent_type_pe"][:, None]

        # Optional conditioning labels, each broadcast over the step dimension.
        if self.use_action_label:
            action_label_turn = self.action_label_tokenizer_turn(input_dict["decoder/label_turning"])
            action_label_accel = self.action_label_tokenizer_accel(input_dict["decoder/label_acceleration"])
            action_token += action_label_turn[:, None]
            action_token += action_label_accel[:, None]

        if self.use_safety_label:
            action_label_safety = self.action_label_tokenizer_safety(input_dict["decoder/label_safety"])
            action_token += action_label_safety[:, None]

        action_casual_mask = create_causal_mask(
            T=T_skipped, N=N
        ).to(action_token.device)  # (T_skipped*N, T_skipped*N), shared by every batch element

        # Zero out the embeddings of invalid actions
        action_token = action_token * action_valid_mask[..., None]

        action_padding_mask = ~action_valid_mask  # (B, T_skipped, N)
        # Flatten action token from (B, T_skipped, N, D) to (B, T_skipped*N, D)
        action_token = action_token.flatten(1, 2)
        # Flatten the padding mask from (B, T_skipped, N) to (B, T_skipped*N)
        action_padding_mask = action_padding_mask.flatten(1, 2)

        # Cache from last rollout
        past_key_value = None
        if "decoder/cache" in input_dict:
            past_key_value = input_dict["decoder/cache"]

        if self.relative_pe:
            if use_cache:
                # With caching, positions / headings / masks of all steps seen so far are
                # accumulated in input_dict: initialised on the first call, appended afterwards.
                if "decoder/modeled_agent_position_history" not in input_dict:
                    input_dict["decoder/modeled_agent_position_history"] = input_dict["decoder/modeled_agent_position"
                                                                                      ].flatten(1, 2)
                    input_dict["decoder/modeled_agent_heading_history"] = input_dict["decoder/modeled_agent_heading"
                                                                                     ].flatten(1, 2)
                    # NOTE(review): despite the "valid_mask" key name, the value stored here is
                    # the padding mask (True = invalid) - confirm this is what the decoder expects.
                    input_dict["decoder/modeled_agent_valid_mask_history"] = action_padding_mask
                else:
                    input_dict["decoder/modeled_agent_position_history"] = torch.cat(
                        [
                            input_dict["decoder/modeled_agent_position_history"],
                            input_dict["decoder/modeled_agent_position"].flatten(1, 2),
                        ],
                        dim=1
                    )
                    input_dict["decoder/modeled_agent_heading_history"] = torch.cat(
                        [
                            input_dict["decoder/modeled_agent_heading_history"],
                            input_dict["decoder/modeled_agent_heading"].flatten(1, 2),
                        ],
                        dim=1
                    )
                    input_dict["decoder/modeled_agent_valid_mask_history"] = torch.cat(
                        [
                            input_dict["decoder/modeled_agent_valid_mask_history"],
                            action_padding_mask,
                        ], dim=1
                    )
                full_tgt_pos = input_dict["decoder/modeled_agent_position_history"]
                full_tgt_heading = input_dict["decoder/modeled_agent_heading_history"]
                full_tgt_mask = input_dict["decoder/modeled_agent_valid_mask_history"]
                # full_tgt_causal_mask = input_dict["decoder/modeled_agent_causal_mask_history"]
            else:
                # Without caching, the "full" target is just the current input.
                full_tgt_pos = input_dict["decoder/modeled_agent_position"].flatten(1, 2)
                full_tgt_heading = input_dict["decoder/modeled_agent_heading"].flatten(1, 2)
                full_tgt_mask = action_padding_mask
                # full_tgt_causal_mask = action_casual_mask
        else:
            full_tgt_pos = None
            full_tgt_heading = None
            full_tgt_mask = None
            # full_tgt_causal_mask = None

        # === Call models ===
        # The decoder is fed sequence-first tensors, hence the swapaxes(0, 1) on tgt and memory.
        decoded_tokens = self.decoder(
            tgt=action_token.swapaxes(0, 1),
            tgt_mask=action_casual_mask,  # swapaxes(0, 1),
            tgt_key_padding_mask=action_padding_mask,
            tgt_is_causal=True,
            tgt_pos=input_dict["decoder/modeled_agent_position"].flatten(1, 2),
            tgt_heading=input_dict["decoder/modeled_agent_heading"].flatten(1, 2),
            full_tgt_pos=full_tgt_pos,
            full_tgt_heading=full_tgt_heading,
            full_tgt_mask=full_tgt_mask,
            # full_tgt_causal_mask=full_tgt_causal_mask,
            memory=scene_token.swapaxes(0, 1),
            memory_mask=None,  # The casual mask for memory
            memory_key_padding_mask=scene_padding_mask,
            memory_is_causal=False,
            memory_pos=input_dict["encoder/scenario_position"],
            memory_heading=input_dict["encoder/scenario_heading"],
            past_key_value=past_key_value,
            use_cache=use_cache
        )

        if use_cache:
            # In cache mode the decoder returns (tokens, cache); keep the cache for the next call.
            decoded_tokens, past_key_value = decoded_tokens
            input_dict["decoder/cache"] = past_key_value

        decoded_tokens = decoded_tokens.swapaxes(0, 1)
        # Run the head on valid tokens only; `unwrap` presumably scatters the result back
        # into the dense (B, T_skipped*N, ...) layout - verify against scenestreamer.utils.
        logits = unwrap(self.prediction_head(decoded_tokens[~action_padding_mask]), ~action_padding_mask)
        logits = logits.reshape(B, T_skipped, N, self.num_actions)

        # print("Input", input_action.shape)
        # # print("DECODE : ", decoded_tokens[~action_padding_mask][-62:].mean(0)[:5])
        # print("DECODED0", decoded_tokens.shape, decoded_tokens[0, :62].mean(-1)[:5])
        # print("DECODED-1", decoded_tokens.shape, decoded_tokens[0, -62:].mean(-1)[:5])
        # print("LOGIT:", logits[0, -1].mean(-1)[:5])
        # print("====")

        input_dict["decoder/output_logit"] = logits

        return input_dict
def get_edge_info_new(*, q_k_valid_mask, q_k_relation, relation_model, relation_model_v, require_relation_pairwise=None):
    """Turn a dense query/key validity mask into a sparse edge list with edge features.

    Args:
        q_k_valid_mask: 3-D mask unpacked as (batch, num_query, num_key); truthy
            entries become edges. At least one entry must be set.
        q_k_relation: Dense relation tensor indexed as [batch, query, key].
        relation_model: Optional callable applied to the gathered edge features.
        relation_model_v: Optional callable applied to the *raw* gathered edge
            features to produce a second feature set.
        require_relation_pairwise: Optional mask indexed as [batch, query, key].
            When given, ``relation_model`` only sees the selected edges and the
            result is passed through ``utils.unwrap`` with that selection.

    Returns:
        dict with keys ``"edge_index"``, ``"edge_features"`` and
        ``"edge_features_v"`` (``None`` when ``relation_model_v`` is ``None``).
    """
    num_batch, num_query, num_key = q_k_valid_mask.shape

    # The mask is handed over key-major, so row 0 of the edge index addresses
    # (batch, key) and row 1 addresses (batch, query) -- see the bound checks.
    key_major_mask = q_k_valid_mask.swapaxes(1, 2).contiguous()
    edge_index = dense_to_sparse(key_major_mask)[0]
    assert edge_index.numel() > 0, (edge_index.shape, q_k_valid_mask.sum())

    flat_key = edge_index[0]
    flat_query = edge_index[1]
    assert flat_key.max() < num_batch * num_key, f"{flat_key.max()} >= {num_batch * num_key}"
    assert flat_query.max() < num_batch * num_query, f"{flat_query.max()} >= {num_batch * num_query}"

    # Split the flattened indices back into (batch, position) pairs.
    query_batch, query_pos = flat_query // num_query, flat_query % num_query
    key_batch, key_pos = flat_key // num_key, flat_key % num_key
    assert torch.all(query_batch == key_batch)

    gathered = q_k_relation[query_batch, query_pos, key_pos]

    # The value-side features are always derived from the untransformed relation.
    value_features = None if relation_model_v is None else relation_model_v(gathered)

    if relation_model is None:
        key_features = gathered
    elif require_relation_pairwise is None:
        key_features = relation_model(gathered)
    else:
        selected = require_relation_pairwise[query_batch, query_pos, key_pos]
        key_features = utils.unwrap(relation_model(gathered[selected]), selected)

    return {
        "edge_index": edge_index,
        "edge_features": key_features,
        "edge_features_v": value_features,
    }
def get_edge_info(attn_valid_mask, rel_pe_cross, rel_pe_cross_v=None):
    """Build a sparse edge list from a dense attention mask and gather per-edge embeddings.

    Args:
        attn_valid_mask: 3-D mask unpacked as (batch, num_query, num_key); at
            least one entry must be set.
        rel_pe_cross: Optional dense tensor indexed as [batch, query, key] from
            which ``"edge_features"`` are gathered.
        rel_pe_cross_v: Optional second dense tensor gathered with the same
            indices; requires ``rel_pe_cross`` to be given as well.

    Returns:
        dict with keys ``"edge_index"``, ``"edge_features"`` and
        ``"edge_features_v"``; the feature entries are ``None`` when the
        corresponding input is ``None``.
    """
    num_batch, num_query, num_key = attn_valid_mask.shape

    key_major_mask = attn_valid_mask.swapaxes(1, 2).contiguous()
    edge_index = dense_to_sparse(key_major_mask)[0]
    assert edge_index.numel() > 0, (edge_index.shape, attn_valid_mask.sum())
    assert edge_index[0].max() < num_batch * num_key, f"{edge_index[0].max()} >= {num_batch * num_key}"
    assert edge_index[1].max() < num_batch * num_query, f"{edge_index[1].max()} >= {num_batch * num_query}"

    gathered = None
    gathered_v = None

    if rel_pe_cross is not None:
        # Row 1 addresses (batch, query), row 0 addresses (batch, key).
        query_batch, query_pos = edge_index[1] // num_query, edge_index[1] % num_query
        key_batch, key_pos = edge_index[0] // num_key, edge_index[0] % num_key
        assert torch.all(query_batch == key_batch)
        gathered = rel_pe_cross[query_batch, query_pos, key_pos]

    if rel_pe_cross_v is not None:
        # The gather indices only exist when rel_pe_cross was supplied.
        assert rel_pe_cross is not None
        gathered_v = rel_pe_cross_v[query_batch, query_pos, key_pos]

    return {
        "edge_index": edge_index,
        "edge_features": gathered,
        "edge_features_v": gathered_v,
    }
class MotionDecoderGPT(nn.Module):
    """GPT-style motion decoder that predicts per-agent action logits.

    Agent action tokens attend to each other (a2a), to their own history
    (a2t) and to scene tokens (a2s) through ``MultiCrossAttTransformerDecoder``.
    All tensors are exchanged through a single ``input_dict`` keyed by
    ``"encoder/..."`` and ``"decoder/..."`` strings.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Raises:
            ValueError: if ``config.ACTION_LABEL.USE_ACTION_LABEL`` is set
                (not implemented).
        """
        super().__init__()
        self.config = config
        self.d_model = d_model = self.config.MODEL.D_MODEL
        num_decoder_layers = self.config.MODEL.NUM_DECODER_LAYERS
        self.num_actions = get_action_dim(self.config)
        dropout = self.config.MODEL.DROPOUT
        self.num_heads = self.config.MODEL.NUM_ATTN_HEAD
        assert self.config.MODEL.NAME in ['gpt']
        self.add_pe_for_token = self.config.MODEL.get('ADD_PE_FOR_TOKEN', False)
        assert self.add_pe_for_token

        self.use_destination = self.config.USE_DESTINATION

        simple_relation = self.config.SIMPLE_RELATION
        simple_relation_factor = self.config.SIMPLE_RELATION_FACTOR
        is_v7 = self.config.MODEL.IS_V7
        self.is_v7 = is_v7

        # NOTE: sub-modules are created in a fixed order; keep it stable so
        # parameter initialisation and state-dict layout do not change.
        self.decoder = MultiCrossAttTransformerDecoder(
            decoder_layer=MultiCrossAttTransformerDecoderLayer(
                d_model=d_model,
                nhead=self.num_heads,
                dropout=dropout,
                use_adaln=self.use_destination,
                simple_relation=simple_relation,
                simple_relation_factor=simple_relation_factor,
                is_v7=is_v7,
                update_relation=self.config.UPDATE_RELATION,
                add_relation_to_v=self.config.MODEL.ADD_RELATION_TO_V,
                remove_rel_norm=self.config.REMOVE_REL_NORM
            ),
            num_layers=num_decoder_layers,
            d_model=d_model,
        )
        self.prediction_head = common_layers.build_mlps(
            c_in=d_model, mlp_channels=[d_model, self.num_actions], ret_before_act=True, is_v7=is_v7, zero_init=is_v7
        )
        if self.use_destination:
            # Destination conditioning: the pre-norm is modulated by the
            # destination token and map features get an id embedding.
            self.prediction_prenorm = common_layers.AdaLayerNorm(d_model, conditioning_dim=d_model, batch_first=True)
            self.map_id_embed = common_layers.Tokenizer(
                num_actions=self.config.PREPROCESSING.MAX_MAP_FEATURES, d_model=d_model, add_one_more_action=True
            )
        else:
            self.prediction_prenorm = nn.LayerNorm(d_model)

        def make_relation_embed(input_dim, hidden_dim):
            return fourier_embedding.FourierEmbedding(
                input_dim=input_dim, hidden_dim=hidden_dim, num_freq_bands=64, is_v7=is_v7
            )

        if self.config.ADD_CONTOUR_RELATION:
            if self.config.SIMPLE_RELATION:
                relation_d_model = d_model // simple_relation_factor
                self.relation_embed_a2a = make_relation_embed(12, relation_d_model)
                self.relation_embed_a2t = make_relation_embed(12, relation_d_model)
                self.relation_embed_a2s = make_relation_embed(3, relation_d_model)

                # NOTE(review): the value-side embeddings only exist in this
                # branch, while forward() reads them whenever
                # MODEL.ADD_RELATION_TO_V is set -- confirm other relation
                # configurations are never combined with that flag.
                if self.config.MODEL.ADD_RELATION_TO_V:
                    self.relation_embed_a2a_v = make_relation_embed(12, relation_d_model)
                    self.relation_embed_a2t_v = make_relation_embed(12, relation_d_model)
                    self.relation_embed_a2s_v = make_relation_embed(3, relation_d_model)
            else:
                relation_d_model = d_model
                self.relation_embed_a2a = make_relation_embed(13, relation_d_model)
                self.relation_embed_a2t = make_relation_embed(13, relation_d_model)
                self.relation_embed_a2s = make_relation_embed(13, relation_d_model)
        else:
            assert self.config.SIMPLE_RELATION is False
            self.relation_embed_a2a = make_relation_embed(5, d_model)
            self.relation_embed_a2t = make_relation_embed(5, d_model)
            self.relation_embed_a2s = make_relation_embed(5, d_model)

        self.type_embed = common_layers.Tokenizer(
            num_actions=constants.NUM_TYPES, d_model=d_model, add_one_more_action=False
        )
        self.action_embed = common_layers.Tokenizer(
            num_actions=self.num_actions, d_model=d_model, add_one_more_action=True
        )
        self.shape_embed = common_layers.build_mlps(
            c_in=3, mlp_channels=[d_model, d_model], ret_before_act=True, is_v7=is_v7
        )

        if self.config.REMOVE_AGENT_FROM_SCENE_ENCODER:
            self.agent_id_embed = common_layers.Tokenizer(
                num_actions=self.config.PREPROCESSING.MAX_AGENTS, d_model=self.d_model, add_one_more_action=False
            )

        self.motion_embed = fourier_embedding.FourierEmbedding(
            input_dim=6, hidden_dim=d_model, num_freq_bands=64, is_v7=is_v7
        )

        tokenizer = get_tokenizer(self.config)
        motion_features = tokenizer.get_motion_feature()
        # Append one all-zero row; forward() maps special (negative) action ids
        # to index ``num_actions`` before gathering from this table.
        if tokenizer.use_type_specific_bins:
            motion_features = torch.cat([motion_features, torch.zeros(1, 3, 4)], dim=0)
        else:
            motion_features = torch.cat([motion_features, torch.zeros(1, 4)], dim=0)
        self.tokenizer = tokenizer
        self.register_buffer("motion_features", motion_features)

        # is start token? is end token (if any)? is padding token? is masked token?
        self.special_token_embed = common_layers.Tokenizer(
            num_actions=4, d_model=self.d_model, add_one_more_action=False
        )

        if self.config.BACKWARD_PREDICTION:
            self.in_backward_prediction_embed = common_layers.Tokenizer(
                num_actions=2, d_model=self.d_model, add_one_more_action=False
            )

        if config.ACTION_LABEL.USE_ACTION_LABEL:
            raise ValueError("Not implemented")

    def randomize_modeled_agent_id(self, data_dict, clip_agent_id=False):
        """Return the agent ids used to look up the agent-id embedding.

        Args:
            data_dict: dict holding ``"decoder/agent_id"``, unpacked as a 2-D
                (batch, num_agents) tensor where ``-1`` marks an empty slot.
            clip_agent_id: when True the ids are first passed through
                ``mode_agent_id`` with ``PREPROCESSING.MAX_AGENTS``.

        Returns:
            Long tensor of the same shape. If ``MODEL.RANDOMIZE_AGENT_ID`` is
            off, the (optionally clipped) ids are returned unchanged. Otherwise
            each row receives ids sampled without replacement from
            ``[0, MAX_AGENTS)``; slots that were ``-1`` stay ``-1``.
        """
        max_agents = self.config.PREPROCESSING.MAX_AGENTS
        modeled_agent_id = data_dict["decoder/agent_id"]
        if clip_agent_id:
            modeled_agent_id = mode_agent_id(modeled_agent_id, max_agents, fill_negative_1=True)

        if not self.config.MODEL.RANDOMIZE_AGENT_ID:
            return modeled_agent_id.long()

        B, N = modeled_agent_id.shape
        weights = torch.ones(max_agents).expand(B, -1)
        if N > max_agents:
            # More slots than distinct ids: the first ``max_agents`` slots get
            # a random permutation, the remainder share the last id.
            num_samples = max_agents
            new_modeled_agent_id = torch.full_like(modeled_agent_id, num_samples - 1)
            new_modeled_agent_id[:, :num_samples] = torch.multinomial(
                weights, num_samples=num_samples, replacement=False
            ).to(modeled_agent_id)
        else:
            new_modeled_agent_id = torch.multinomial(
                weights, num_samples=N, replacement=False
            ).to(modeled_agent_id)
        new_modeled_agent_id[modeled_agent_id == -1] = -1
        return new_modeled_agent_id.long()

    @staticmethod
    def _get_backward_flags(input_dict, reference, B, T_skipped, N):
        """Return the per-sample backward-prediction flag broadcast to (B, T_skipped, N).

        When ``"in_backward_prediction"`` is absent, an all-zero flag with one
        entry per sample is stored in ``input_dict`` (dtype/device taken from
        ``reference``). The stored tensor must hold exactly ``B`` elements,
        because it is reshaped to ``(B, 1, 1)`` before being expanded.
        """
        if "in_backward_prediction" not in input_dict:
            # A (B, T_skipped, N) default cannot be reshaped to (B, 1, 1).
            input_dict["in_backward_prediction"] = reference.new_zeros(B)
        return input_dict["in_backward_prediction"].reshape(B, 1, 1).expand(B, T_skipped, N)

    def forward(self, input_dict, use_cache=False, a2a_knn=None, a2t_knn=None, a2s_knn=None):
        """Decode action logits for every valid (batch, step, agent) slot.

        Args:
            input_dict: dict of tensors. Reads the ``"in_evaluation"``,
                ``"encoder/scenario_*"`` and ``"decoder/*"`` entries used
                below and is updated in place.
            use_cache: when True, position/heading/mask/step histories and the
                decoder key/value cache are kept in ``input_dict`` across calls.
            a2a_knn: optional override for ``MODEL.A2A_KNN``.
            a2t_knn: accepted for interface symmetry; currently unused (the
                agent-temporal relation is always built with ``knn=None``).
            a2s_knn: optional override for ``MODEL.A2S_KNN``.

        Returns:
            ``input_dict`` with ``"decoder/output_logit"`` of shape
            (B, T_skipped, N, num_actions), or, when ``USE_DIFFUSION`` is set,
            with ``"decoder/decoded_tokens"`` instead.
        """
        in_evaluation = input_dict["in_evaluation"][0].item()

        # === Process scene embedding ===
        scene_token = input_dict["encoder/scenario_token"]
        scenario_valid_mask = input_dict["encoder/scenario_valid_mask"]

        # === Process action embedding ===
        input_action = input_dict["decoder/input_action"]
        modeled_agent_delta = input_dict["decoder/modeled_agent_delta"]
        B, T_skipped, N = input_action.shape[:3]

        if self.config.REMOVE_AGENT_FROM_SCENE_ENCODER:
            if in_evaluation:
                # Ids must stay fixed across rollout steps, so the caller provides them.
                assert "decoder/randomized_modeled_agent_id" in input_dict, "Need to provide randomized modeled agent id for evaluation! Please call randomize_modeled_agent_id()"
                new_modeled_agent_id = input_dict["decoder/randomized_modeled_agent_id"]
            else:
                new_modeled_agent_id = self.randomize_modeled_agent_id(input_dict, clip_agent_id=False)
            modeled_agent_pe = self.agent_id_embed(new_modeled_agent_id)
        else:
            modeled_agent_pe = input_dict["encoder/modeled_agent_pe"]

        assert modeled_agent_pe.shape == (B, N, self.d_model), (B, N, self.d_model, modeled_agent_pe.shape)
        modeled_agent_pe = modeled_agent_pe[:, None].expand(B, T_skipped, N, self.d_model)

        action_valid_mask = input_dict["decoder/input_action_valid_mask"]
        assert action_valid_mask.shape == (B, T_skipped, N), (action_valid_mask.shape, (B, T_skipped, N))
        agent_pos = input_dict["decoder/modeled_agent_position"]
        agent_heading = input_dict["decoder/modeled_agent_heading"]

        # ===== Prepare input tokens =====
        if "decoder/input_step" not in input_dict:
            input_dict["decoder/input_step"] = torch.arange(T_skipped).to(input_action.device)
        agent_step = input_dict["decoder/input_step"].reshape(1, T_skipped, 1).expand(B, T_skipped, N)

        # Shape embedding and type embedding
        type_emb = self.type_embed(input_dict["decoder/agent_type"])[:, None].expand(B, T_skipped, N, self.d_model)
        shape_emb = self.shape_embed(input_dict["decoder/current_agent_shape"]
                                     )[:, None].expand(B, T_skipped, N, self.d_model)

        # Boolean-mask indexing copies, so the in-place edits below do not
        # touch ``input_action``.
        valid_actions = input_action[action_valid_mask]
        is_start_actions = valid_actions == START_ACTION
        special_tok = torch.full_like(valid_actions, 0).int()
        special_tok[is_start_actions] = 1
        valid_actions[is_start_actions] = -1
        if self.config.BACKWARD_PREDICTION:
            is_end_actions = valid_actions == END_ACTION
            special_tok[is_end_actions] = 2
            valid_actions[is_end_actions] = -1
        special_tok_emb = self.special_token_embed(special_tok)
        if self.config.BACKWARD_PREDICTION:
            in_backward_full = self._get_backward_flags(input_dict, valid_actions, B, T_skipped, N)
            in_backward = in_backward_full[action_valid_mask]
            in_backward = in_backward.int()
            in_backward_prediction_embed = self.in_backward_prediction_embed(in_backward)
            special_tok_emb = special_tok_emb + in_backward_prediction_embed
        action_emb = self.action_embed(valid_actions)

        if self.tokenizer.use_type_specific_bins:
            # Select the per-type slice of the motion table for every valid token.
            agent_type = input_dict["decoder/agent_type"]
            agent_type = agent_type - 1
            agent_type[agent_type < 0] = 0
            agent_type = agent_type.reshape(B, 1, N).expand(B, T_skipped, N)
            agent_type = agent_type[action_valid_mask]  # Already flattened
            agent_type = agent_type.reshape(-1, 1, 1, 1).expand(-1, self.motion_features.shape[0], 1, 4)
            motion_feat = self.motion_features.reshape(1, -1, 3, 4).expand(agent_type.shape[0], -1, 3, 4)
            motion_feat = torch.gather(motion_feat, dim=-2, index=agent_type).squeeze(-2)
        else:
            motion_feat = self.motion_features.reshape(1, -1, 4).expand(valid_actions.shape[0], -1, 4)
        # Special tokens (marked -1 above) read the all-zero row appended in __init__.
        valid_actions[valid_actions < 0] = self.num_actions
        valid_actions = valid_actions.reshape(-1, 1, 1).expand(-1, 1, 4)
        motion_feat = torch.gather(motion_feat, dim=-2, index=valid_actions).squeeze(-2)

        motion_feat = torch.cat([motion_feat, modeled_agent_delta[action_valid_mask]], dim=-1)

        action_token = self.motion_embed(
            continuous_inputs=motion_feat,
            categorical_embs=[
                special_tok_emb, modeled_agent_pe[action_valid_mask], type_emb[action_valid_mask],
                shape_emb[action_valid_mask], action_emb
            ]
        )
        action_token = utils.unwrap(action_token, action_valid_mask)
        assert action_token.shape == (B, T_skipped, N, self.d_model)
        assert action_valid_mask.shape == (B, T_skipped, N)

        # ===== Get agent-condition relation =====
        condition_token = None
        if self.use_destination:
            condition_token = utils.unwrap(
                self.map_id_embed(input_dict["decoder/dest_map_index"][action_valid_mask]),
                action_valid_mask
            )
            B, M, _ = input_dict["encoder/map_position"].shape
            S = scene_token.shape[1]
            map_id = torch.zeros([B, S], dtype=torch.long, device=scene_token.device)
            map_id[:, :M] = torch.arange(M, device=map_id.device).unsqueeze(0)
            map_id[~scenario_valid_mask] = -1
            map_id_pe = self.map_id_embed(map_id)
            # We don't add map feat ID pe in SceneEncoder, so we add it here. (same code in TG dec too)
            scene_token = scene_token + map_id_pe

        # ===== Get agent-temporal relation =====
        # BTND -> BNTD
        agent_pos_bntd = torch.permute(agent_pos, [0, 2, 1, 3])
        agent_heading_bnt = torch.permute(agent_heading, [0, 2, 1])
        agent_mask_bnt = torch.permute(action_valid_mask, [0, 2, 1])
        agent_step_bnt = torch.permute(agent_step, [0, 2, 1])
        if use_cache:
            # Keys span the whole accumulated history; queries are only the new steps.
            self.update_cache(input_dict)

            agent_pos_with_history = input_dict["decoder/modeled_agent_position_history"]
            agent_heading_with_history = input_dict["decoder/modeled_agent_heading_history"]
            agent_mask_with_history = input_dict["decoder/modeled_agent_valid_mask_history"]
            agent_step_with_history = input_dict["decoder/modeled_agent_step_history"]
            real_T = agent_mask_with_history.shape[1]
            key_pos = torch.permute(agent_pos_with_history, [0, 2, 1, 3]).flatten(0, 1)
            key_heading = torch.permute(agent_heading_with_history, [0, 2, 1]).flatten(0, 1)
            key_mask = torch.permute(agent_mask_with_history, [0, 2, 1]).flatten(0, 1)
            causal_valid_mask = None
            key_step = agent_step_with_history.reshape(1, 1, -1).expand(B, N, -1).flatten(0, 1)
        else:
            real_T = T_skipped
            key_pos = agent_pos_bntd.flatten(0, 1)
            key_heading = agent_heading_bnt.flatten(0, 1)
            key_mask = agent_mask_bnt.flatten(0, 1)
            key_step = agent_step_bnt.flatten(0, 1)
            causal_valid_mask = create_causal_mask(T=real_T, N=1, is_valid_mask=True).to(action_token.device)

        assert agent_pos_bntd.shape == (B, N, T_skipped, 2)

        a2t_kwargs = {}
        if self.config.ADD_CONTOUR_RELATION:
            agent_shape_no_time = input_dict["decoder/current_agent_shape"]
            agent_length = agent_shape_no_time[..., 0]
            agent_width = agent_shape_no_time[..., 1]
            a2t_kwargs = dict(
                query_width=agent_width.flatten(0, 1).unsqueeze(1).expand(-1, T_skipped),
                query_length=agent_length.flatten(0, 1).unsqueeze(1).expand(-1, T_skipped),
                key_width=agent_width.flatten(0, 1).unsqueeze(1).expand(-1, real_T),
                key_length=agent_length.flatten(0, 1).unsqueeze(1).expand(-1, real_T),
                non_agent_relation=False,
                per_contour_point_relation=self.config.MODEL.PER_CONTOUR_POINT_RELATION
            )

        if self.config.SIMPLE_RELATION:
            relation_func = relation.compute_relation_simple_relation
        else:
            relation_func = relation.compute_relation

        add_relation_to_v = self.config.MODEL.ADD_RELATION_TO_V

        a2t_rel_feat, a2t_mask, _ = relation_func(
            query_pos=agent_pos_bntd.flatten(0, 1),  # BN, T, D
            query_heading=agent_heading_bnt.flatten(0, 1),
            query_valid_mask=agent_mask_bnt.flatten(0, 1),
            query_step=agent_step_bnt.flatten(0, 1),
            key_pos=key_pos,  # BN, T_full, D
            key_heading=key_heading,
            key_valid_mask=key_mask,
            key_step=key_step,
            causal_valid_mask=causal_valid_mask,
            knn=None,
            max_distance=None,
            **a2t_kwargs
        )
        a2t_info = get_edge_info_new(
            q_k_valid_mask=a2t_mask,
            q_k_relation=a2t_rel_feat,
            relation_model=self.relation_embed_a2t,
            relation_model_v=self.relation_embed_a2t_v if add_relation_to_v else None
        )

        # ===== Get agent-agent relation =====
        a2a_kwargs = {}
        if self.config.ADD_CONTOUR_RELATION:
            width = agent_width.unsqueeze(1).expand(B, T_skipped, N).flatten(0, 1)
            length = agent_length.unsqueeze(1).expand(B, T_skipped, N).flatten(0, 1)
            a2a_kwargs = dict(
                query_width=width,
                query_length=length,
                key_width=width,
                key_length=length,
                non_agent_relation=False,
                per_contour_point_relation=self.config.MODEL.PER_CONTOUR_POINT_RELATION
            )
        a2a_rel_feat, a2a_mask, _ = relation_func(
            query_pos=agent_pos.flatten(0, 1),  # BT, N, D
            query_heading=agent_heading.flatten(0, 1),
            query_valid_mask=action_valid_mask.flatten(0, 1),
            query_step=agent_step.flatten(0, 1),
            key_pos=agent_pos.flatten(0, 1),
            key_heading=agent_heading.flatten(0, 1),
            key_valid_mask=action_valid_mask.flatten(0, 1),
            key_step=agent_step.flatten(0, 1),
            causal_valid_mask=None,
            knn=a2a_knn if a2a_knn is not None else self.config.MODEL.A2A_KNN,
            max_distance=self.config.MODEL.A2A_DISTANCE,
            **a2a_kwargs
        )
        a2a_info = get_edge_info_new(
            q_k_valid_mask=a2a_mask,
            q_k_relation=a2a_rel_feat,
            relation_model=self.relation_embed_a2a,
            relation_model_v=self.relation_embed_a2a_v if add_relation_to_v else None
        )

        # ===== Get agent-scene relation =====
        a2s_kwargs = {}
        if self.config.ADD_CONTOUR_RELATION:
            width = agent_width.unsqueeze(1).expand(B, T_skipped, N).flatten(1, 2)
            length = agent_length.unsqueeze(1).expand(B, T_skipped, N).flatten(1, 2)
            # Scene elements are passed with zero width and length.
            key_extent = torch.zeros_like(input_dict["encoder/scenario_position"][..., 0])
            a2s_kwargs = dict(
                query_width=width,
                query_length=length,
                key_width=key_extent,
                key_length=key_extent,
                non_agent_relation=True,
                per_contour_point_relation=self.config.MODEL.PER_CONTOUR_POINT_RELATION
            )
        a2s_rel_feat, a2s_mask, _ = relation_func(
            query_pos=agent_pos.flatten(1, 2),  # B, TN, D
            query_heading=agent_heading.flatten(1, 2),
            query_valid_mask=action_valid_mask.flatten(1, 2),
            query_step=agent_step.flatten(1, 2),
            key_pos=input_dict["encoder/scenario_position"],
            key_heading=input_dict["encoder/scenario_heading"],
            key_valid_mask=scenario_valid_mask,
            key_step=agent_pos.new_zeros(B, input_dict["encoder/scenario_position"].shape[1]),
            causal_valid_mask=None,
            knn=a2s_knn if a2s_knn is not None else self.config.MODEL.A2S_KNN,
            max_distance=self.config.MODEL.A2S_DISTANCE,
            gather=False,
            **a2s_kwargs
        )
        a2s_info = get_edge_info_new(
            q_k_valid_mask=a2s_mask,
            q_k_relation=a2s_rel_feat,
            relation_model=self.relation_embed_a2s,
            relation_model_v=self.relation_embed_a2s_v if add_relation_to_v else None
        )

        # === Call models ===
        past_key_value_list = None
        if use_cache and "decoder/cache" in input_dict:
            # Cache from last rollout
            past_key_value_list = input_dict["decoder/cache"]

        decoded_tokens = self.decoder(
            agent_token=action_token,
            scene_token=scene_token,
            a2a_info=a2a_info,
            a2t_info=a2t_info,
            a2s_info=a2s_info,
            condition_token=condition_token,
            use_cache=use_cache,
            past_key_value_list=past_key_value_list
        )

        if use_cache:
            decoded_tokens, past_key_value_list = decoded_tokens
            # Tag every non-empty per-layer cache with the current key extent.
            for layer_cache in past_key_value_list:
                if layer_cache:
                    layer_cache.append((B * N, real_T))
            input_dict["decoder/cache"] = past_key_value_list

        if self.config.USE_DIFFUSION:
            input_dict["decoder/decoded_tokens"] = decoded_tokens
            return input_dict

        if self.use_destination:
            output_tokens = self.prediction_prenorm(
                decoded_tokens[action_valid_mask], z=condition_token[action_valid_mask]
            )
        else:
            output_tokens = self.prediction_prenorm(decoded_tokens[action_valid_mask])
        logits = utils.unwrap(self.prediction_head(output_tokens), action_valid_mask)

        assert logits.shape == (B, T_skipped, N, self.num_actions)
        input_dict["decoder/output_logit"] = logits

        return input_dict

    def update_cache(self, input_dict):
        """Append the current step's agent state to the rollout history in ``input_dict``.

        On the first call (no ``"decoder/modeled_agent_position_history"`` key)
        the histories are initialised with clones of the current tensors. On
        later calls position, heading and valid mask are concatenated along
        dim 1 and the step indices along dim 0.
        """
        assert self.config.EVALUATION.USE_CACHE
        # (history key, source key, concat dim) -- order fixes dict insertion order.
        tracked = (
            ("decoder/modeled_agent_position_history", "decoder/modeled_agent_position", 1),
            ("decoder/modeled_agent_heading_history", "decoder/modeled_agent_heading", 1),
            ("decoder/modeled_agent_valid_mask_history", "decoder/input_action_valid_mask", 1),
            ("decoder/modeled_agent_step_history", "decoder/input_step", 0),
        )
        first_call = "decoder/modeled_agent_position_history" not in input_dict
        for history_key, source_key, dim in tracked:
            if first_call:
                input_dict[history_key] = input_dict[source_key].clone()
            else:
                input_dict[history_key] = torch.cat([input_dict[history_key], input_dict[source_key]], dim=dim)
def get_edge_info_new(*, q_k_valid_mask, q_k_relation, relation_model, relation_model_v):
    """Convert a dense query/key validity mask into sparse edges plus edge features.

    Args:
        q_k_valid_mask: 3-D mask unpacked as (batch, num_query, num_key); at
            least one entry must be set.
        q_k_relation: Dense relation tensor indexed as [batch, query, key].
        relation_model: Optional callable applied to the gathered edge features.
        relation_model_v: Optional callable applied to the raw gathered edge
            features to produce ``"edge_features_v"``.

    Returns:
        dict with keys ``"edge_index"``, ``"edge_features"`` and
        ``"edge_features_v"`` (``None`` when ``relation_model_v`` is ``None``).
    """
    n_batch, n_query, n_key = q_k_valid_mask.shape

    sparse = dense_to_sparse(q_k_valid_mask.swapaxes(1, 2).contiguous())
    edge_index = sparse[0]
    assert edge_index.numel() > 0, (edge_index.shape, q_k_valid_mask.sum())

    key_side, query_side = edge_index[0], edge_index[1]
    assert key_side.max() < n_batch * n_key, f"{key_side.max()} >= {n_batch * n_key}"
    assert query_side.max() < n_batch * n_query, f"{query_side.max()} >= {n_batch * n_query}"

    # Undo the batch flattening on both sides; an edge never crosses samples.
    sample_of_query = query_side // n_query
    sample_of_key = key_side // n_key
    assert torch.all(sample_of_query == sample_of_key)
    raw_features = q_k_relation[sample_of_query, query_side % n_query, key_side % n_key]

    result = {"edge_index": edge_index}
    # Value-side features are computed first, from the untransformed relation.
    features_v = relation_model_v(raw_features) if relation_model_v is not None else None
    result["edge_features"] = relation_model(raw_features) if relation_model is not None else raw_features
    result["edge_features_v"] = features_v
    return result
@dataclass
class DecoderInput:
    """Typed container for the tensors consumed by the motion decoder.

    ``input_step`` defaults to ``arange(T)`` repeated for every sample when it
    is not supplied. Call :meth:`sanity_check` to validate tensor shapes
    against ``input_action``.
    """

    # --- Encoder outputs ---
    map_token: torch.Tensor
    map_valid_mask: torch.Tensor
    map_position: torch.Tensor
    map_heading: torch.Tensor
    map_pe: torch.Tensor
    tl_token: torch.Tensor
    tl_position: torch.Tensor
    tl_valid_mask: torch.Tensor

    # --- Decoder-side inputs ---
    input_action: torch.Tensor  # (B, T, N)
    input_action_valid_mask: torch.Tensor  # (B, T, N)
    agent_delta: torch.Tensor  # (B, T, N, D_delta)
    agent_position: torch.Tensor  # (B, N, D_pos) or (B, T, N, D_pos)
    agent_heading: torch.Tensor  # (B, N) or (B, T, N)
    agent_velocity: torch.Tensor  # (B, N, D_vel) or (B, T, N, D_vel)  # TODO: Remove this?
    agent_type: torch.Tensor  # (B, N) or (B, T, N)

    agent_shape: torch.Tensor  # (B, N, D_shape)
    agent_pe: torch.Tensor  # (B, N, D_model)

    # --- Optional fields ---
    input_step: Optional[torch.Tensor] = None  # (B, T)
    in_backward_prediction: Optional[torch.Tensor] = None  # (B, )

    def __post_init__(self):
        """Fill in ``input_step`` with per-sample step indices when it was omitted."""
        if self.input_step is not None:
            return
        batch, steps, _ = self.input_action.shape
        step_ids = torch.arange(steps, device=self.input_action.device)
        self.input_step = step_ids.unsqueeze(0).expand(batch, steps)

    def sanity_check(self):
        """Assert that every field's shape agrees with ``input_action``'s (B, T, N).

        Raises:
            AssertionError: on the first field whose shape does not match.
        """
        B, T, N = self.input_action.shape

        mask_shape = self.input_action_valid_mask.shape
        assert mask_shape == (B, T, N), f"Expected shape {(B, T, N)}, got {mask_shape}"

        # The trailing feature dimension is free; only the leading dims are pinned.
        delta_dim = self.agent_delta.shape[-1]
        assert self.agent_delta.shape == (
            B, T, N, delta_dim
        ), f"Expected shape {(B, T, N, delta_dim)}, got {self.agent_delta.shape}"

        pos_dim = self.agent_position.shape[-1]
        assert self.agent_position.shape in [
            (B, N, pos_dim), (B, T, N, pos_dim)
        ], f"Unexpected shape {self.agent_position.shape}"

        assert self.agent_heading.shape in [(B, N), (B, T, N)], f"Unexpected shape {self.agent_heading.shape}"

        vel_dim = self.agent_velocity.shape[-1]
        assert self.agent_velocity.shape in [
            (B, N, vel_dim), (B, T, N, vel_dim)
        ], f"Unexpected shape {self.agent_velocity.shape}"

        assert self.agent_type.shape in [(B, N), (B, T, N)], f"Unexpected shape {self.agent_type.shape}"

        shape_dim = self.agent_shape.shape[-1]
        assert self.agent_shape.shape == (
            B, N, shape_dim
        ), f"Expected shape {(B, N, shape_dim)}, got {self.agent_shape.shape}"

        pe_dim = self.agent_pe.shape[-1]
        assert self.agent_pe.shape == (
            B, N, pe_dim
        ), f"Expected shape {(B, N, pe_dim)}, got {self.agent_pe.shape}"

        # Optional fields are only validated when present.
        if self.input_step is not None:
            assert self.input_step.shape == (B, T), f"Expected shape {(B, T)}, got {self.input_step.shape}"
        if self.in_backward_prediction is not None:
            assert self.in_backward_prediction.shape == (
                B,
            ), f"Expected shape {(B,)}, got {self.in_backward_prediction.shape}"
assert self.config.SIMPLE_RELATION is True + assert self.config.SIMPLE_RELATION_FACTOR == 1 + self.add_pe_for_token = self.config.MODEL.get('ADD_PE_FOR_TOKEN', False) + assert self.add_pe_for_token is True + assert self.config.MODEL.NAME in ['gpt'] + assert self.config.ADD_CONTOUR_RELATION + assert self.config.SIMPLE_RELATION is True + assert self.config.MODEL.ADD_RELATION_TO_V is False + assert self.config.REMOVE_AGENT_FROM_SCENE_ENCODER is True + + simple_relation = self.config.SIMPLE_RELATION + simple_relation_factor = self.config.SIMPLE_RELATION_FACTOR + self.decoder = MultiCrossAttTransformerDecoder( + decoder_layer=MultiCrossAttTransformerDecoderLayer( + d_model=d_model, + nhead=self.num_heads, + dropout=dropout, + use_adaln=use_adaln, + simple_relation=simple_relation, + simple_relation_factor=simple_relation_factor, + is_v7=True, + update_relation=self.config.UPDATE_RELATION, + add_relation_to_v=self.config.MODEL.ADD_RELATION_TO_V, + remove_rel_norm=self.config.REMOVE_REL_NORM + ), + num_layers=num_decoder_layers, + d_model=d_model, + ) + self.prediction_head = common_layers.build_mlps( + c_in=d_model, + mlp_channels=[d_model, self.num_actions], + ret_before_act=True, + ) + if self.use_adaln: + self.prediction_adaln_norm = nn.LayerNorm(d_model, elementwise_affine=False, eps=1e-6) + self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(d_model, 2 * d_model, bias=True)) + else: + self.prediction_prenorm = nn.LayerNorm(d_model) + + relation_d_model = d_model // simple_relation_factor + self.relation_embed_a2a = fourier_embedding.FourierEmbedding( + input_dim=12, + hidden_dim=relation_d_model, + num_freq_bands=64, + ) + self.relation_embed_a2t = fourier_embedding.FourierEmbedding( + input_dim=12, + hidden_dim=relation_d_model, + num_freq_bands=64, + ) + self.relation_embed_a2s = fourier_embedding.FourierEmbedding( + input_dim=3, + hidden_dim=relation_d_model, + num_freq_bands=64, + ) + + self.type_embed = common_layers.Tokenizer( + 
num_actions=constants.NUM_TYPES, d_model=d_model, add_one_more_action=False + ) + self.action_embed = common_layers.Tokenizer( + num_actions=self.num_actions, d_model=d_model, add_one_more_action=True + ) + self.shape_embed = common_layers.build_mlps( + c_in=3, + mlp_channels=[d_model, d_model], + ret_before_act=True, + ) + self.agent_id_embed = common_layers.Tokenizer( + num_actions=self.config.PREPROCESSING.MAX_AGENTS, d_model=self.d_model, add_one_more_action=False + ) + + self.motion_embed = fourier_embedding.FourierEmbedding( + input_dim=6, + hidden_dim=d_model, + num_freq_bands=64, + ) + + tokenizer = get_tokenizer(self.config) + motion_features = tokenizer.get_motion_feature() + if tokenizer.use_type_specific_bins: + motion_features = torch.cat([motion_features, torch.zeros(1, 3, 4)], dim=0) + else: + motion_features = torch.cat([motion_features, torch.zeros(1, 4)], dim=0) + self.tokenizer = tokenizer + self.register_buffer("motion_features", motion_features) + + self.special_token_embed = common_layers.Tokenizer( + num_actions=4, d_model=self.d_model, add_one_more_action=False + ) + + if self.config.BACKWARD_PREDICTION: + self.in_backward_prediction_embed = common_layers.Tokenizer( + num_actions=2, d_model=self.d_model, add_one_more_action=False + ) + if self.use_adaln: + self.initialize_weights_for_adaln() + + def initialize_weights_for_adaln(self): + for block in self.decoder.layers: + nn.init.constant_(block.adaln_modulation[-1].weight, 0) + nn.init.constant_(block.adaln_modulation[-1].bias, 0) + nn.init.constant_(self.adaln_modulation[-1].weight, 0) + nn.init.constant_(self.adaln_modulation[-1].bias, 0) + + def randomize_modeled_agent_id(self, data: MotionDecoderData, clip_agent_id=False): + modeled_agent_id = data.decoder.agent_id # Was: input_dict["decoder/agent_id"] + if not self.config.MODEL.RANDOMIZE_AGENT_ID: + if clip_agent_id: + modeled_agent_id = mode_agent_id( + modeled_agent_id, self.config.PREPROCESSING.MAX_AGENTS, fill_negative_1=True + ) + 
return modeled_agent_id + + if clip_agent_id: + modeled_agent_id = mode_agent_id( + modeled_agent_id, self.config.PREPROCESSING.MAX_AGENTS, fill_negative_1=True + ) + B, N = modeled_agent_id.shape + weights = torch.ones(self.config.PREPROCESSING.MAX_AGENTS).expand(B, -1) + if N > self.config.PREPROCESSING.MAX_AGENTS: + num_samples = self.config.PREPROCESSING.MAX_AGENTS + new_modeled_agent_id = torch.full_like(modeled_agent_id, num_samples - 1) + new_modeled_agent_id[:, :num_samples] = torch.multinomial( + weights, num_samples=num_samples, replacement=False + ).to(modeled_agent_id) + new_modeled_agent_id[modeled_agent_id == -1] = -1 + else: + num_samples = N + new_modeled_agent_id = torch.multinomial( + weights, num_samples=num_samples, replacement=False + ).to(modeled_agent_id) + new_modeled_agent_id[modeled_agent_id == -1] = -1 + return new_modeled_agent_id + + def _legacy_data_dict_to_new_data_structure(self, data_dict): + encoder = EncoderOutput( + scenario_token=data_dict["encoder/scenario_token"], + scenario_valid_mask=data_dict["encoder/scenario_valid_mask"], + scenario_position=data_dict["encoder/scenario_position"], + scenario_heading=data_dict["encoder/scenario_heading"], + modeled_agent_pe=data_dict["encoder/modeled_agent_pe"] if "encoder/modeled_agent_pe" in data_dict else None, + ) + decoder = DecoderInput( + input_action=data_dict["decoder/input_action"], + modeled_agent_delta=data_dict["decoder/modeled_agent_delta"], + input_action_valid_mask=data_dict["decoder/input_action_valid_mask"], + modeled_agent_position=data_dict["decoder/modeled_agent_position"], + modeled_agent_heading=data_dict["decoder/modeled_agent_heading"], + modeled_agent_velocity=data_dict["decoder/modeled_agent_velocity"], + agent_type=data_dict["decoder/agent_type"], + current_agent_shape=data_dict["decoder/current_agent_shape"], + agent_id=data_dict["decoder/agent_id"], + input_step=data_dict["decoder/input_step"] if "decoder/input_step" in data_dict else None, + # 
label_safety=data_dict["decoder/label_safety"], + randomized_modeled_agent_id=data_dict["decoder/randomized_modeled_agent_id"] + if "decoder/randomized_modeled_agent_id" in data_dict else None, + in_backward_prediction=data_dict["decoder/in_backward_prediction"] + if "decoder/in_backward_prediction" in data_dict else None, + ) + # TODO: Better handle history and cache. + history = HistoryData( + modeled_agent_position_history=data_dict["history/modeled_agent_position_history"] + if "history/modeled_agent_position_history" in data_dict else None, + modeled_agent_velocity_history=data_dict["history/modeled_agent_velocity_history"] + if "history/modeled_agent_velocity_history" in data_dict else None, + modeled_agent_heading_history=data_dict["history/modeled_agent_heading_history"] + if "history/modeled_agent_heading_history" in data_dict else None, + modeled_agent_valid_mask_history=data_dict["history/modeled_agent_valid_mask_history"] + if "history/modeled_agent_valid_mask_history" in data_dict else None, + modeled_agent_step_history=data_dict["history/modeled_agent_step_history"] + if "history/modeled_agent_step_history" in data_dict else None, + ) + return MotionDecoderData( + in_evaluation=data_dict["in_evaluation"], + encoder=encoder, + decoder=decoder, + history=history, + cache=data_dict["cache"] if "cache" in data_dict else None, + ) + + def forward(self, data: MotionDecoderData, use_cache=False, a2a_knn=None, a2t_knn=None, a2s_knn=None): + + if isinstance(data, dict): + data = self._legacy_data_dict_to_new_data_structure(data) + + in_evaluation = data.in_evaluation + + # Process scene (encoder) embedding + scene_token = data.encoder.scenario_token + scenario_valid_mask = data.encoder.scenario_valid_mask + + # Process action (decoder) embedding + input_action = data.decoder.input_action + modeled_agent_delta = data.decoder.modeled_agent_delta + B, T_skipped, N = input_action.shape + + if in_evaluation: + assert data.decoder.randomized_modeled_agent_id is not 
None, \ + "Need to provide randomized modeled agent id for evaluation! Please call randomize_modeled_agent_id()" + new_modeled_agent_id = data.decoder.randomized_modeled_agent_id + else: + new_modeled_agent_id = self.randomize_modeled_agent_id(data, clip_agent_id=False) + modeled_agent_pe = self.agent_id_embed(new_modeled_agent_id) + + assert modeled_agent_pe.shape == (B, N, self.d_model), modeled_agent_pe.shape + modeled_agent_pe = modeled_agent_pe[:, None].expand(B, T_skipped, N, self.d_model) + + action_valid_mask = data.decoder.input_action_valid_mask + agent_pos = data.decoder.modeled_agent_position + agent_heading = data.decoder.modeled_agent_heading + + # input_step is set in __post_init__ if not provided. + agent_step = data.decoder.input_step.reshape(1, T_skipped, 1).expand(B, T_skipped, N) + + # Shape and type embeddings + type_emb = self.type_embed(data.decoder.agent_type)[:, None].expand(B, T_skipped, N, self.d_model) + shape_emb = self.shape_embed(data.decoder.current_agent_shape)[:, None].expand(B, T_skipped, N, self.d_model) + + valid_actions = input_action[action_valid_mask] + is_start_actions = valid_actions == START_ACTION + special_tok = torch.full_like(valid_actions, 0).int() + special_tok[is_start_actions] = 1 + valid_actions[is_start_actions] = -1 + if self.config.BACKWARD_PREDICTION: + is_end_actions = valid_actions == END_ACTION + special_tok[is_end_actions] = 2 + valid_actions[is_end_actions] = -1 + special_tok_emb = self.special_token_embed(special_tok) + if self.config.BACKWARD_PREDICTION: + if data.decoder.in_backward_prediction is None: + data.decoder.in_backward_prediction = valid_actions.new_zeros(B, T_skipped, N) + in_backward_full = data.decoder.in_backward_prediction.reshape(B, 1, 1).expand(B, T_skipped, N) + in_backward = in_backward_full[action_valid_mask].int() + in_backward_prediction_embed = self.in_backward_prediction_embed(in_backward) + special_tok_emb += in_backward_prediction_embed + action_emb = 
self.action_embed(valid_actions) + + if self.tokenizer.use_type_specific_bins: + agent_type = data.decoder.agent_type + agent_type = agent_type - 1 + agent_type[agent_type < 0] = 0 + agent_type = agent_type.reshape(B, 1, N).expand(B, T_skipped, N) + agent_type = agent_type[action_valid_mask] + agent_type = agent_type.reshape(-1, 1, 1, 1).expand(-1, self.motion_features.shape[0], 1, 4) + motion_feat = self.motion_features.reshape(1, -1, 3, 4).expand(agent_type.shape[0], -1, 3, 4) + motion_feat = torch.gather(motion_feat, dim=-2, index=agent_type).squeeze(-2) + else: + motion_feat = self.motion_features.reshape(1, -1, 4).expand(valid_actions.shape[0], -1, 4) + valid_actions[valid_actions < 0] = self.num_actions + valid_actions = valid_actions.reshape(-1, 1, 1).expand(-1, 1, 4) + motion_feat = torch.gather(motion_feat, dim=-2, index=valid_actions).squeeze(-2) + + motion_feat = torch.cat([motion_feat, modeled_agent_delta[action_valid_mask]], dim=-1) + + action_token = self.motion_embed( + continuous_inputs=motion_feat, + categorical_embs=[ + special_tok_emb, modeled_agent_pe[action_valid_mask], type_emb[action_valid_mask], + shape_emb[action_valid_mask], action_emb + ] + ) + action_token = utils.unwrap(action_token, action_valid_mask) + assert action_token.shape == (B, T_skipped, N, self.d_model) + assert action_valid_mask.shape == (B, T_skipped, N) + + condition_token = None + if self.config.ACTION_LABEL.USE_SAFETY_LABEL: + action_label_safety = self.action_label_tokenizer_safety(data.decoder.label_safety) + condition_token = action_label_safety[:, None] + if self.use_adaln: + pass + else: + action_token += condition_token + + # Prepare agent-temporal relation data (permute BTND -> BNTD etc.) 
+ agent_pos_bntd = torch.permute(agent_pos, [0, 2, 1, 3]) + agent_heading_bnt = torch.permute(agent_heading, [0, 2, 1]) + agent_mask_bnt = torch.permute(action_valid_mask, [0, 2, 1]) + agent_step_bnt = torch.permute(agent_step, [0, 2, 1]) + + if use_cache: + self.update_cache(data) + agent_pos_with_history = data.history.modeled_agent_position_history + agent_heading_with_history = data.history.modeled_agent_heading_history + agent_mask_with_history = data.history.modeled_agent_valid_mask_history + agent_step_with_history = data.history.modeled_agent_step_history + real_T = agent_mask_with_history.shape[1] + key_pos = torch.permute(agent_pos_with_history, [0, 2, 1, 3]).flatten(0, 1) + key_heading = torch.permute(agent_heading_with_history, [0, 2, 1]).flatten(0, 1) + key_mask = torch.permute(agent_mask_with_history, [0, 2, 1]).flatten(0, 1) + causal_valid_mask = None + key_step = agent_step_with_history.reshape(1, 1, -1).expand(B, N, -1).flatten(0, 1) + else: + real_T = T_skipped + key_pos = agent_pos_bntd.flatten(0, 1) + key_heading = agent_heading_bnt.flatten(0, 1) + key_mask = agent_mask_bnt.flatten(0, 1) + key_step = agent_step_bnt.flatten(0, 1) + causal_valid_mask = create_causal_mask(T=real_T, N=1, is_valid_mask=True).to(action_token.device) + + assert agent_pos_bntd.shape == (B, N, T_skipped, 2) + + a2t_kwargs = {} + if self.config.ADD_CONTOUR_RELATION: + agent_shape_no_time = data.decoder.current_agent_shape + agent_length = agent_shape_no_time[..., 0] + agent_width = agent_shape_no_time[..., 1] + a2t_kwargs = dict( + include_contour=True, + query_width=agent_width.flatten(0, 1).unsqueeze(1).expand(-1, T_skipped), + query_length=agent_length.flatten(0, 1).unsqueeze(1).expand(-1, T_skipped), + key_width=agent_width.flatten(0, 1).unsqueeze(1).expand(-1, real_T), + key_length=agent_length.flatten(0, 1).unsqueeze(1).expand(-1, real_T), + non_agent_relation=False, + per_contour_point_relation=self.config.MODEL.PER_CONTOUR_POINT_RELATION + ) + + if 
self.config.SIMPLE_RELATION: + relation_func = relation.compute_relation_simple_relation + else: + relation_func = relation.compute_relation + + a2t_rel_feat, a2t_mask, _ = relation_func( + query_pos=agent_pos_bntd.flatten(0, 1), + query_heading=agent_heading_bnt.flatten(0, 1), + query_valid_mask=agent_mask_bnt.flatten(0, 1), + query_step=agent_step_bnt.flatten(0, 1), + key_pos=key_pos, + key_heading=key_heading, + key_valid_mask=key_mask, + key_step=key_step, + hidden_dim=self.d_model, + causal_valid_mask=causal_valid_mask, + knn=None, + max_distance=None, + return_pe=False, + **a2t_kwargs + ) + a2t_info = get_edge_info_new( + q_k_valid_mask=a2t_mask, + q_k_relation=a2t_rel_feat, + relation_model=self.relation_embed_a2t, + relation_model_v=self.relation_embed_a2t_v if self.config.MODEL.ADD_RELATION_TO_V else None + ) + + a2a_kwargs = {} + if self.config.ADD_CONTOUR_RELATION: + w = agent_width.unsqueeze(1).expand(B, T_skipped, N).flatten(0, 1) + l = agent_length.unsqueeze(1).expand(B, T_skipped, N).flatten(0, 1) + a2a_kwargs = dict( + include_contour=True, + query_width=w, + query_length=l, + key_width=w, + key_length=l, + non_agent_relation=False, + per_contour_point_relation=self.config.MODEL.PER_CONTOUR_POINT_RELATION + ) + a2a_rel_feat, a2a_mask, _ = relation_func( + query_pos=agent_pos.flatten(0, 1), + query_heading=agent_heading.flatten(0, 1), + query_valid_mask=action_valid_mask.flatten(0, 1), + query_step=agent_step.flatten(0, 1), + key_pos=agent_pos.flatten(0, 1), + key_heading=agent_heading.flatten(0, 1), + key_valid_mask=action_valid_mask.flatten(0, 1), + key_step=agent_step.flatten(0, 1), + hidden_dim=self.d_model, + causal_valid_mask=None, + knn=a2a_knn if a2a_knn is not None else self.config.MODEL.A2A_KNN, + max_distance=self.config.MODEL.A2A_DISTANCE, + return_pe=False, + **a2a_kwargs + ) + a2a_info = get_edge_info_new( + q_k_valid_mask=a2a_mask, + q_k_relation=a2a_rel_feat, + relation_model=self.relation_embed_a2a, + 
relation_model_v=self.relation_embed_a2a_v if self.config.MODEL.ADD_RELATION_TO_V else None + ) + + a2s_kwargs = {} + if self.config.ADD_CONTOUR_RELATION: + w = agent_width.unsqueeze(1).expand(B, T_skipped, N).flatten(1, 2) + l = agent_length.unsqueeze(1).expand(B, T_skipped, N).flatten(1, 2) + kw = torch.zeros_like(data.encoder.scenario_position[..., 0]) + a2s_kwargs = dict( + include_contour=True, + query_width=w, + query_length=l, + key_width=kw, + key_length=kw, + non_agent_relation=True, + per_contour_point_relation=self.config.MODEL.PER_CONTOUR_POINT_RELATION + ) + a2s_rel_feat, a2s_mask, a2s_indices = relation_func( + query_pos=agent_pos.flatten(1, 2), + query_heading=agent_heading.flatten(1, 2), + query_valid_mask=action_valid_mask.flatten(1, 2), + query_step=agent_step.flatten(1, 2), + key_pos=data.encoder.scenario_position, + key_heading=data.encoder.scenario_heading, + key_valid_mask=scenario_valid_mask, + key_step=agent_pos.new_zeros(B, data.encoder.scenario_position.shape[1]), + hidden_dim=self.d_model, + causal_valid_mask=None, + knn=a2s_knn if a2s_knn is not None else self.config.MODEL.A2S_KNN, + max_distance=self.config.MODEL.A2S_DISTANCE, + gather=False, + return_pe=False, + **a2s_kwargs + ) + a2s_info = get_edge_info_new( + q_k_valid_mask=a2s_mask, + q_k_relation=a2s_rel_feat, + relation_model=self.relation_embed_a2s, + relation_model_v=self.relation_embed_a2s_v if self.config.MODEL.ADD_RELATION_TO_V else None + ) + + past_key_value_list = None + if use_cache: + past_key_value_list = data.cache + + decoded_tokens = self.decoder( + agent_token=action_token, + scene_token=scene_token, + a2a_info=a2a_info, + a2t_info=a2t_info, + a2s_info=a2s_info, + condition_token=condition_token if self.use_adaln else None, + use_cache=use_cache, + past_key_value_list=past_key_value_list + ) + + if use_cache: + decoded_tokens, past_key_value_list = decoded_tokens + for l in past_key_value_list: + if l: + l.append((B * N, real_T)) + data.cache = past_key_value_list 
+ + if self.config.USE_DIFFUSION: + # Attach decoded tokens to the decoder part of our data structure. + data.decoder.decoded_tokens = decoded_tokens + return data + + if self.use_adaln: + output_tokens = self.prediction_adaln_norm(decoded_tokens[action_valid_mask]) + shift, scale = self.adaln_modulation(output_tokens).chunk(2, dim=-1) + output_tokens = utils.modulate(output_tokens, shift, scale) + else: + output_tokens = self.prediction_prenorm(decoded_tokens[action_valid_mask]) + logits = utils.unwrap(self.prediction_head(output_tokens), action_valid_mask) + + assert logits.shape == (B, T_skipped, N, self.num_actions) + data.decoder.output_logit = logits + + return data + + def update_cache(self, data: MotionDecoderData): + assert self.config.EVALUATION.USE_CACHE + if data.history is None: + data.history = HistoryData( + modeled_agent_position_history=data.decoder.modeled_agent_position.clone(), + modeled_agent_velocity_history=data.decoder.modeled_agent_velocity.clone(), + modeled_agent_heading_history=data.decoder.modeled_agent_heading.clone(), + modeled_agent_valid_mask_history=data.decoder.input_action_valid_mask.clone(), + modeled_agent_step_history=data.decoder.input_step.clone() + ) + else: + data.history.modeled_agent_position_history = torch.cat( + [data.history.modeled_agent_position_history, data.decoder.modeled_agent_position], dim=1 + ) + data.history.modeled_agent_velocity_history = torch.cat( + [data.history.modeled_agent_velocity_history, data.decoder.modeled_agent_velocity], dim=1 + ) + data.history.modeled_agent_heading_history = torch.cat( + [data.history.modeled_agent_heading_history, data.decoder.modeled_agent_heading], dim=1 + ) + data.history.modeled_agent_valid_mask_history = torch.cat( + [data.history.modeled_agent_valid_mask_history, data.decoder.input_action_valid_mask], dim=1 + ) + data.history.modeled_agent_step_history = torch.cat( + [data.history.modeled_agent_step_history, data.decoder.input_step], dim=0 + ) diff --git 
a/scenestreamer/models/motion_decoder_gpt_diffusion.py b/scenestreamer/models/motion_decoder_gpt_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b8aea99ed4b002ee34fba2285831ddd74e6644a6 --- /dev/null +++ b/scenestreamer/models/motion_decoder_gpt_diffusion.py @@ -0,0 +1,518 @@ +import torch +import torch.nn as nn + +from scenestreamer.dataset import constants +from scenestreamer.models import relation +from scenestreamer.models.layers import common_layers, fourier_embedding +from scenestreamer.models.layers.gpt_decoder_layer import MultiCrossAttTransformerDecoderLayer, MultiCrossAttTransformerDecoder +from scenestreamer.models.motion_decoder import create_causal_mask +from scenestreamer.models.motion_decoder_gpt import get_edge_info +from scenestreamer.models.scene_encoder import mode_agent_id +from scenestreamer.tokenization import get_tokenizer +from scenestreamer.utils import utils + + +class MotionDecoderGPTDiffusion(nn.Module): + def __init__(self, config): + + # TODO: ADD_RELATION_TO_V is not implemented! 
+ print("config.MODEL.ADD_RELATION_TO_V", config.MODEL.ADD_RELATION_TO_V) + print("config.MODEL.ADD_RELATION_TO_V", config.MODEL.ADD_RELATION_TO_V) + print("config.MODEL.ADD_RELATION_TO_V", config.MODEL.ADD_RELATION_TO_V) + print("config.MODEL.ADD_RELATION_TO_V", config.MODEL.ADD_RELATION_TO_V) + print("config.MODEL.ADD_RELATION_TO_V", config.MODEL.ADD_RELATION_TO_V) + print("config.MODEL.ADD_RELATION_TO_V", config.MODEL.ADD_RELATION_TO_V) + + super().__init__() + self.config = config + self.d_model = d_model = self.config.MODEL.D_MODEL + num_decoder_layers = self.config.MODEL.NUM_DECODER_LAYERS + # self.num_actions = get_action_dim(self.config) + dropout = self.config.MODEL['DROPOUT_OF_ATTN'] + self.num_heads = self.config.MODEL.NUM_ATTN_HEAD + # use_condition = self.config.ACTION_LABEL.USE_ACTION_LABEL or self.config.ACTION_LABEL.USE_SAFETY_LABEL + # self.use_condition = use_condition + assert self.config.MODEL.NAME in ['gpt'] + self.add_pe_for_token = self.config.MODEL.get('ADD_PE_FOR_TOKEN', False) + assert self.add_pe_for_token + use_adaln = self.config.USE_ADALN + self.use_adaln = use_adaln + + simple_relation = self.config.SIMPLE_RELATION + simple_relation_factor = 1 + self.decoder = MultiCrossAttTransformerDecoder( + decoder_layer=MultiCrossAttTransformerDecoderLayer( + d_model=d_model, + nhead=self.num_heads, + dropout=dropout, + use_adaln=use_adaln, + simple_relation=simple_relation, + simple_relation_factor=simple_relation_factor + ), + num_layers=num_decoder_layers, + d_model=d_model, + self_attention_knn=self.config.MODEL['SELF_ATTN_KNN'], + cross_attention_knn=self.config.MODEL['CROSS_ATTN_KNN'], + ) + + assert self.config.BACKWARD_PREDICTION is False + + assert self.config.ADD_CONTOUR_RELATION is True + + assert self.config.SIMPLE_RELATION is True + relation_d_model = d_model // simple_relation_factor + + self.relation_embed_a2a = fourier_embedding.FourierEmbedding( + input_dim=12, hidden_dim=relation_d_model, num_freq_bands=64 + ) + 
self.relation_embed_a2t = fourier_embedding.FourierEmbedding( + input_dim=12, hidden_dim=relation_d_model, num_freq_bands=64 + ) + self.relation_embed_a2s = fourier_embedding.FourierEmbedding( + input_dim=3, hidden_dim=relation_d_model, num_freq_bands=64 + ) + + self.type_embed = common_layers.Tokenizer( + num_actions=constants.NUM_TYPES, d_model=d_model, add_one_more_action=False + ) + # self.action_embed = common_layers.Tokenizer( + # num_actions=self.num_actions, d_model=d_model, add_one_more_action=True + # ) + self.shape_embed = common_layers.build_mlps(c_in=3, mlp_channels=[d_model, d_model], ret_before_act=True) + + if self.config.REMOVE_AGENT_FROM_SCENE_ENCODER: + self.agent_id_embed = common_layers.Tokenizer( + num_actions=self.config.PREPROCESSING.MAX_AGENTS, d_model=self.d_model, add_one_more_action=False + ) + + tokenizer = get_tokenizer(self.config) + # motion_features = tokenizer.get_motion_feature() + # if tokenizer.use_type_specific_bins: + # motion_features = torch.cat([motion_features, torch.zeros(1, 3, 4)], dim=0) + # else: + # motion_features = torch.cat([motion_features, torch.zeros(1, 4)], dim=0) + self.tokenizer = tokenizer + # self.register_buffer("motion_features", motion_features) + + if self.tokenizer.use_delta_delta: + agent_motion_embed_dim = 5 * 2 + 2 + else: + agent_motion_embed_dim = 5 * 3 + 2 + + self.agent_motion_embed = fourier_embedding.FourierEmbedding( + input_dim=agent_motion_embed_dim, hidden_dim=d_model, num_freq_bands=64 + ) + + # Special tokens: Invalid, Valid, Start, Masked, Unused. 
+ self.special_token_embed = common_layers.Tokenizer( + num_actions=5, d_model=self.d_model, add_one_more_action=False + ) + + # if self.config.BACKWARD_PREDICTION: + # self.in_backward_prediction_embed = common_layers.Tokenizer( + # num_actions=2, d_model=self.d_model, add_one_more_action=False + # ) + + from scenestreamer.diffusion.diffusion_loss import DiffLoss + # diffloss_w = 4 * self.config.MODEL.D_MODEL + diffloss_w = self.config.MODEL.D_MODEL + diffloss_d = 3 + grad_checkpointing = False + decoder_embed_dim = self.config.MODEL.D_MODEL + + if self.tokenizer.use_delta_delta: + token_embed_dim = 5 * 2 + else: + token_embed_dim = 5 * 3 + + predict_xstart = False + diffusion_steps = 100 + num_sampling_steps = '100' + self.diffusion_loss = DiffLoss( + target_channels=token_embed_dim, + z_channels=decoder_embed_dim, + width=diffloss_w, + depth=diffloss_d, + num_sampling_steps=num_sampling_steps, + grad_checkpointing=grad_checkpointing, + predict_xstart=predict_xstart, + diffusion_steps=diffusion_steps, + ) + + def randomize_modeled_agent_id(self, data_dict, clip_agent_id=False): + modeled_agent_id = data_dict["decoder/agent_id"] + # batch_index = data_dict.get("batch_idx", None) + if not self.config.MODEL.RANDOMIZE_AGENT_ID: + if clip_agent_id: + modeled_agent_id = mode_agent_id( + modeled_agent_id, self.config.PREPROCESSING.MAX_AGENTS, fill_negative_1=True + ) + return modeled_agent_id + + # assert batch_index is not None, "Need batch index to randomize agent id!" 
+ # batch_to_unique = {} + # for i, b in enumerate(batch_index): + # b = b.item() + # if b not in batch_to_unique: + # batch_to_unique[b] = len(batch_to_unique) + + if clip_agent_id: + modeled_agent_id = mode_agent_id( + modeled_agent_id, self.config.PREPROCESSING.MAX_AGENTS, fill_negative_1=True + ) + B, N = modeled_agent_id.shape + weights = torch.ones(self.config.PREPROCESSING.MAX_AGENTS).expand(B, -1) + if N > self.config.PREPROCESSING.MAX_AGENTS: + num_samples = self.config.PREPROCESSING.MAX_AGENTS + new_modeled_agent_id = torch.full_like(modeled_agent_id, num_samples - 1) + new_modeled_agent_id[:, :num_samples] = torch.multinomial( + weights, num_samples=num_samples, replacement=False + ).to(modeled_agent_id) + new_modeled_agent_id[modeled_agent_id == -1] = -1 + else: + num_samples = N + new_modeled_agent_id = torch.multinomial( + weights, num_samples=num_samples, replacement=False + ).to(modeled_agent_id) + new_modeled_agent_id[modeled_agent_id == -1] = -1 + + # Allocate same agent id to the same batch + # return_modeled_agent_id = torch.full_like(modeled_agent_id, -1) + # for i, b in enumerate(batch_index): + # b = b.item() + # return_modeled_agent_id[i] = new_modeled_agent_id[batch_to_unique[b]] + # return return_modeled_agent_id + return new_modeled_agent_id + + def forward(self, input_dict, use_cache=False, a2a_knn=None, a2t_knn=None, a2s_knn=None): + in_evaluation = input_dict["in_evaluation"][0].item() + + # === Process scene embedding === + scene_token = input_dict["encoder/scenario_token"] + scenario_valid_mask = input_dict["encoder/scenario_valid_mask"] + + # === Process action embedding === + # input_action = input_dict["decoder/input_action"] + modeled_agent_delta = input_dict["decoder/modeled_agent_delta"] + # B, T_skipped, N = input_action.shape + input_special_token = input_dict["decoder/input_action"] + B, T_skipped, N = input_special_token.shape + + if self.config.REMOVE_AGENT_FROM_SCENE_ENCODER: + if in_evaluation: + assert 
"decoder/randomized_modeled_agent_id" in input_dict, "Need to provide randomized modeled agent id for evaluation! Please call randomize_modeled_agent_id()" + new_modeled_agent_id = input_dict["decoder/randomized_modeled_agent_id"] + else: + new_modeled_agent_id = self.randomize_modeled_agent_id(input_dict, clip_agent_id=False) + modeled_agent_pe = self.agent_id_embed(new_modeled_agent_id) + + # print("modeled_agent_pe", new_modeled_agent_id[0]) + else: + modeled_agent_pe = input_dict["encoder/modeled_agent_pe"] + assert modeled_agent_pe.shape == (B, N, self.d_model), modeled_agent_pe.shape + modeled_agent_pe = modeled_agent_pe[:, None].expand(B, T_skipped, N, self.d_model) + + action_valid_mask = input_dict["decoder/input_action_valid_mask"] + assert action_valid_mask.shape == (B, T_skipped, N), (action_valid_mask.shape, (B, T_skipped, N)) + agent_pos = input_dict["decoder/modeled_agent_position"][..., :2] + agent_heading = input_dict["decoder/modeled_agent_heading"] + + # ===== Prepare input tokens ===== + if "decoder/input_step" not in input_dict: + input_dict["decoder/input_step"] = torch.arange(T_skipped).to(scene_token.device) + agent_step = input_dict["decoder/input_step"].reshape(1, T_skipped, 1).expand(B, T_skipped, N) + + # Shape embedding and type embedding + type_emb = self.type_embed(input_dict["decoder/agent_type"])[:, None].expand(B, T_skipped, N, self.d_model) + shape_emb = self.shape_embed(input_dict["decoder/current_agent_shape"] + )[:, None].expand(B, T_skipped, N, self.d_model) + special_tok_emb = self.special_token_embed(input_special_token) + + # The input token contains: + # 1. Special token (start, end, padding, masked) + # 2. Modeled agent id + # 3. Type embedding + # 4. Shape embedding + # 5. Last action (15-dim) + # 6. modeled_agent_delta + # No need to add modeled_agent_delta as the model will take care. 
+ input_agent_motion = input_dict["decoder/input_agent_motion"] + cont_input = torch.cat([input_agent_motion[action_valid_mask], modeled_agent_delta[action_valid_mask]], dim=-1) + action_token = self.agent_motion_embed( + continuous_inputs=cont_input, + categorical_embs=[ + special_tok_emb[action_valid_mask], + modeled_agent_pe[action_valid_mask], + type_emb[action_valid_mask], + shape_emb[action_valid_mask], + ] + ) + action_token = utils.unwrap(action_token, action_valid_mask) + assert action_token.shape == (B, T_skipped, N, self.d_model) + assert action_valid_mask.shape == (B, T_skipped, N) + + # ===== Get agent-condition relation ===== + condition_token = None + # if self.config.ACTION_LABEL.USE_SAFETY_LABEL: + # action_label_safety = self.action_label_tokenizer_safety(input_dict["decoder/label_safety"]) + # condition_token = action_label_safety[:, None] + # if self.use_adaln: + # pass + # else: + # action_token += condition_token + + # ===== Get agent-temporal relation ===== + # BTND -> BNTD + agent_pos_bntd = torch.permute(agent_pos, [0, 2, 1, 3]) + agent_heading_bnt = torch.permute(agent_heading, [0, 2, 1]) + agent_mask_bnt = torch.permute(action_valid_mask, [0, 2, 1]) + agent_step_bnt = torch.permute(agent_step, [0, 2, 1]) + # agent_vel_bnt = torch.permute(agent_vel, [0, 2, 1, 3]) + if use_cache: + self.update_cache(input_dict) + + agent_pos_with_history = input_dict["decoder/modeled_agent_position_history"] + agent_heading_with_history = input_dict["decoder/modeled_agent_heading_history"] + agent_mask_with_history = input_dict["decoder/modeled_agent_valid_mask_history"] + agent_step_with_history = input_dict["decoder/modeled_agent_step_history"] + # agent_vel_with_history = input_dict["decoder/modeled_agent_velocity_history"] + real_T = agent_mask_with_history.shape[1] + key_pos = torch.permute(agent_pos_with_history, [0, 2, 1, 3]).flatten(0, 1) + # key_vel = torch.permute(agent_vel_with_history, [0, 2, 1, 3]).flatten(0, 1) + key_heading = 
torch.permute(agent_heading_with_history, [0, 2, 1]).flatten(0, 1) + key_mask = torch.permute(agent_mask_with_history, [0, 2, 1]).flatten(0, 1) + causal_valid_mask = None + key_step = agent_step_with_history.reshape(1, 1, -1).expand(B, N, -1).flatten(0, 1) + else: + real_T = T_skipped + # key_vel = agent_vel_bnt.flatten(0, 1) + key_pos = agent_pos_bntd.flatten(0, 1) + key_heading = agent_heading_bnt.flatten(0, 1) + key_mask = agent_mask_bnt.flatten(0, 1) + key_step = agent_step_bnt.flatten(0, 1) + causal_valid_mask = create_causal_mask(T=real_T, N=1, is_valid_mask=True).to(action_token.device) + + assert agent_pos_bntd.shape == (B, N, T_skipped, 2) + + a2t_kwargs = {} + if self.config.ADD_CONTOUR_RELATION: + agent_shape_no_time = input_dict["decoder/current_agent_shape" + ] # .reshape(B, 1, N, 3).expand(B, real_T, N, 3) + agent_length = agent_shape_no_time[..., 0] + agent_width = agent_shape_no_time[..., 1] + a2t_kwargs = dict( + include_contour=True, + query_width=agent_width.flatten(0, 1).unsqueeze(1).expand(-1, T_skipped), + query_length=agent_length.flatten(0, 1).unsqueeze(1).expand(-1, T_skipped), + key_width=agent_width.flatten(0, 1).unsqueeze(1).expand(-1, real_T), + key_length=agent_length.flatten(0, 1).unsqueeze(1).expand(-1, real_T), + ) + + if self.config.SIMPLE_RELATION: + relation_func = relation.compute_relation_simple_relation + else: + relation_func = relation.compute_relation + + a2t_rel_feat, a2t_mask, _ = relation_func( + query_pos=agent_pos_bntd.flatten(0, 1), # BN, T, D + query_heading=agent_heading_bnt.flatten(0, 1), + query_valid_mask=agent_mask_bnt.flatten(0, 1), + query_step=agent_step_bnt.flatten(0, 1), + key_pos=key_pos, # BN, T_full, D + key_heading=key_heading, + key_valid_mask=key_mask, + key_step=key_step, + hidden_dim=self.d_model, + causal_valid_mask=causal_valid_mask, + knn=None, + return_pe=False, + # key_vel=key_vel, + # query_vel=agent_vel_bnt.flatten(0, 1), + **a2t_kwargs + ) + a2t_rel_pe = 
utils.unwrap(self.relation_embed_a2t(a2t_rel_feat[a2t_mask]), a2t_mask) + a2t_info = get_edge_info(attn_valid_mask=a2t_mask, rel_pe_cross=a2t_rel_pe) + + # ===== Get agent-agent relation ===== + a2a_kwargs = {} + if self.config.ADD_CONTOUR_RELATION: + w = agent_width.unsqueeze(1).expand(B, T_skipped, N).flatten(0, 1) + l = agent_length.unsqueeze(1).expand(B, T_skipped, N).flatten(0, 1) + a2a_kwargs = dict( + include_contour=True, + query_width=w, + query_length=l, + key_width=w, + key_length=l, + ) + a2a_rel_feat, a2a_mask, _ = relation_func( + query_pos=agent_pos.flatten(0, 1), # BT, N, D + query_heading=agent_heading.flatten(0, 1), + query_valid_mask=action_valid_mask.flatten(0, 1), + query_step=agent_step.flatten(0, 1), + key_pos=agent_pos.flatten(0, 1), + key_heading=agent_heading.flatten(0, 1), + key_valid_mask=action_valid_mask.flatten(0, 1), + key_step=agent_step.flatten(0, 1), + hidden_dim=self.d_model, + causal_valid_mask=None, + knn=a2a_knn if a2a_knn is not None else self.config.MODEL.A2A_KNN, + return_pe=False, + # query_vel=agent_vel.flatten(0, 1), + # key_vel=agent_vel.flatten(0, 1), + **a2a_kwargs + ) + a2a_rel_pe = utils.unwrap(self.relation_embed_a2a(a2a_rel_feat[a2a_mask]), a2a_mask) + a2a_info = get_edge_info(attn_valid_mask=a2a_mask, rel_pe_cross=a2a_rel_pe) + + # ===== Get agent-scene relation ===== + a2a_kwargs = {} + if self.config.ADD_CONTOUR_RELATION: + w = agent_width.unsqueeze(1).expand(B, T_skipped, N).flatten(1, 2) + l = agent_length.unsqueeze(1).expand(B, T_skipped, N).flatten(1, 2) + kw = torch.zeros_like(input_dict["encoder/scenario_position"][..., 0]) + a2a_kwargs = dict( + include_contour=True, + query_width=w, + query_length=l, + key_width=kw, + key_length=kw, + non_agent_relation=True + ) + a2s_rel_feat, a2s_mask, a2s_indices = relation_func( + query_pos=agent_pos.flatten(1, 2), # B, TN, D + query_heading=agent_heading.flatten(1, 2), + query_valid_mask=action_valid_mask.flatten(1, 2), + query_step=agent_step.flatten(1, 2), + 
key_pos=input_dict["encoder/scenario_position"], # [..., :2], + key_heading=input_dict["encoder/scenario_heading"], + key_valid_mask=scenario_valid_mask, + key_step=agent_pos.new_zeros(B, input_dict["encoder/scenario_position"].shape[1]), + hidden_dim=self.d_model, + causal_valid_mask=None, + knn=a2s_knn if a2s_knn is not None else self.config.MODEL.A2S_KNN, + gather=False, + return_pe=False, + **a2a_kwargs + ) + a2s_rel_pe = utils.unwrap(self.relation_embed_a2s(a2s_rel_feat[a2s_mask]), a2s_mask) + a2s_info = get_edge_info(attn_valid_mask=a2s_mask, rel_pe_cross=a2s_rel_pe) + + # === Call models === + past_key_value_list = None + if use_cache: + # Cache from last rollout + if "decoder/cache" in input_dict: + past_key_value_list = input_dict["decoder/cache"] + + decoded_tokens = self.decoder( + agent_token=action_token, + scene_token=scene_token, + a2a_info=a2a_info, + a2t_info=a2t_info, + a2s_info=a2s_info, + condition_token=condition_token if self.use_adaln else None, + use_cache=use_cache, # We don't need decoder to take care cache. + past_key_value_list=past_key_value_list + ) + + if use_cache: + decoded_tokens, past_key_value_list = decoded_tokens + for l in past_key_value_list: + if l: + l.append((B * N, real_T)) + input_dict["decoder/cache"] = past_key_value_list + + input_dict["decoder/decoded_tokens"] = decoded_tokens + return input_dict + + def update_cache(self, input_dict): + # TODO: Do we have cache for diffusion? 
+ + assert self.config.EVALUATION.USE_CACHE + if "decoder/modeled_agent_position_history" not in input_dict: + input_dict["decoder/modeled_agent_position_history"] = input_dict["decoder/modeled_agent_position"].clone() + input_dict["decoder/modeled_agent_velocity_history"] = input_dict["decoder/modeled_agent_velocity"].clone() + input_dict["decoder/modeled_agent_heading_history"] = input_dict["decoder/modeled_agent_heading"].clone() + input_dict["decoder/modeled_agent_valid_mask_history"] = input_dict["decoder/input_action_valid_mask" + ].clone() + input_dict["decoder/modeled_agent_step_history"] = input_dict["decoder/input_step"].clone() + else: + input_dict["decoder/modeled_agent_position_history"] = torch.cat( + [input_dict["decoder/modeled_agent_position_history"], input_dict["decoder/modeled_agent_position"]], + dim=1 + ) + input_dict["decoder/modeled_agent_velocity_history"] = torch.cat( + [input_dict["decoder/modeled_agent_velocity_history"], input_dict["decoder/modeled_agent_velocity"]], + dim=1 + ) + input_dict["decoder/modeled_agent_heading_history"] = torch.cat( + [input_dict["decoder/modeled_agent_heading_history"], input_dict["decoder/modeled_agent_heading"]], + dim=1 + ) + input_dict["decoder/modeled_agent_valid_mask_history"] = torch.cat( + [ + input_dict["decoder/modeled_agent_valid_mask_history"], + input_dict["decoder/input_action_valid_mask"], + ], + dim=1 + ) + input_dict["decoder/modeled_agent_step_history"] = torch.cat( + [input_dict["decoder/modeled_agent_step_history"], input_dict["decoder/input_step"]], dim=0 + ) + + def get_diffusion_loss(self, data_dict): + + target = data_dict['decoder/target_agent_motion'] + target_valid_mask = data_dict['decoder/target_action_valid_mask'] + t = target[target_valid_mask] + + out = data_dict["decoder/decoded_tokens"] + z = out[target_valid_mask] + + # TODO: Consider getting this back? 
+ # self.diffusion_batch_mul = 4 + # z = z.repeat(self.diffusion_batch_mul, 1) + # t = t.repeat(self.diffusion_batch_mul, 1) + + loss_dict = self.diffusion_loss(z=z, target=t) + + mean = t.mean(0) + std = t.std(0) + # assert mean.shape[0] == 15 + for i in range(mean.shape[0]): + loss_dict[f"motion_stat/target_mean_{i}"] = mean[i] # .item() + loss_dict[f"motion_stat/target_std_{i}"] = std[i] # .item() + loss_dict[f"motion_stat/pred_mean_{i}"] = loss_dict["model_output"][i] # .item() + + loss_dict["toks"] = z.shape[0] + + # TODO: Remove stat in formal experiment. + print("\n==== Diffusion Loss ====") + print("TARG", [round(v.item(), 4) for v in mean.cpu().detach().numpy()]) + print("TARG_MAX", [round(v.item(), 4) for v in t.max(0).values.cpu().detach().numpy()]) + print("TARG_MIN", [round(v.item(), 4) for v in t.min(0).values.cpu().detach().numpy()]) + print("PRED", [round(v.item(), 4) for v in loss_dict["model_output"]]) + print("MSE", loss_dict["mse"].mean().item()) + print("==== Diffusion Loss END ====") + + loss_dict.pop("model_output") + return loss_dict + + def sample_diffusion(self, input_dict, use_cache): + + # TODO: Do we need to introduce the EMA model here? Or we can make another copy. 
+ + out = self.forward(input_dict, use_cache) + tok = out["decoder/decoded_tokens"] + m = out["decoder/input_action_valid_mask"] + z = tok[m] + + temperature = 1.0 + cfg = 1.0 + predicted = self.diffusion_loss.sample(z, temperature, cfg) + + predicted = utils.unwrap(predicted, m) + + out["decoder/output_action"] = predicted + return out diff --git a/scenestreamer/models/motion_decoder_gpt_fast.py b/scenestreamer/models/motion_decoder_gpt_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..b10d704d17dc7046859efca0b9b8b15f6e786eeb --- /dev/null +++ b/scenestreamer/models/motion_decoder_gpt_fast.py @@ -0,0 +1,695 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from scenestreamer.dataset import constants +from scenestreamer.models import relation +from scenestreamer.models.layers import common_layers, fourier_embedding +from scenestreamer.models.layers.gpt_decoder_layer import MultiCrossAttTransformerDecoderLayer, MultiCrossAttTransformerDecoder +from scenestreamer.models.motion_decoder import create_causal_mask +from scenestreamer.models.motion_decoder_gpt import MotionDecoderGPT as MotionDecoderGPTBase, get_edge_info_new +from scenestreamer.tokenization import get_action_dim, get_tokenizer, START_ACTION as MOTION_START_ACTION, END_ACTION as MOTION_END_ACTION +from scenestreamer.utils import utils + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization. + Applies standard layer normalization and then conditions the normalized output + on a latent vector z via learned affine parameters. + """ + def __init__(self, hidden_size, conditioning_dim, eps=1e-5): + super().__init__() + self.eps = eps + # These projections output a modulation (scale and bias) for each feature. + self.gamma_proj = nn.Linear(conditioning_dim, hidden_size) + self.beta_proj = nn.Linear(conditioning_dim, hidden_size) + # We disable affine parameters inside the LayerNorm since they will be provided by z. 
+ self.ln = nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=False) + + def forward(self, x, z): + """ + x: Tensor of shape (..., hidden_size) to be normalized. + z: Conditioning tensor of shape (B, conditioning_dim) if x is [B, seq_len, hidden_size] + or shape (B, conditioning_dim) when x is [B, hidden_size]. + """ + normalized = self.ln(x) + # If x is 3D (B, seq_len, hidden_size), unsqueeze z along seq_len dimension. + if normalized.dim() == 3: + assert z.dim() == 2 + gamma = self.gamma_proj(z).unsqueeze(0) # Note that input x is NOT batch first. + beta = self.beta_proj(z).unsqueeze(0) + elif normalized.dim() == 2: + gamma = self.gamma_proj(z) # [B, hidden_size] + beta = self.beta_proj(z) + else: + raise ValueError("Unsupported input tensor shape for AdaLayerNorm") + # Modulate normalized activations. + return normalized * (1 + gamma) + beta + + +class TransformerBlock(nn.Module): + """ + A single transformer block that uses adaptive layer normalization. + It includes a self-attention layer and a feed-forward network. + """ + def __init__(self, hidden_size, num_heads, conditioning_dim, dropout=0.1): + super().__init__() + self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout, batch_first=False) + self.adaln1 = AdaLayerNorm(hidden_size, conditioning_dim) + self.adaln2 = AdaLayerNorm(hidden_size, conditioning_dim) + + # Simple feed-forward network. + self.ff = nn.Sequential( + nn.Linear(hidden_size, hidden_size * 4), nn.ReLU(), nn.Linear(hidden_size * 4, hidden_size), + nn.Dropout(dropout) + ) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, z, attn_mask=None, key_padding_mask=None): + """ + x: Tensor of shape [seq_len, B, hidden_size] + z: Conditioning tensor of shape [B, conditioning_dim] + attn_mask: Optional attention mask for self-attention. + key_padding_mask: Optional mask for padded positions. + """ + # Self-attention with pre-normalization using AdaLN. + # We apply AdaLN to x before attention. 
+ x_norm = self.adaln1(x, z) + assert attn_mask.dtype == key_padding_mask.dtype + attn_output, _ = self.self_attn( + x_norm, x_norm, x_norm, attn_mask=attn_mask, key_padding_mask=key_padding_mask, is_causal=True + ) + x = x + self.dropout(attn_output) + + # Feed-forward network with pre-normalization. + x_norm = self.adaln2(x, z) + ff_output = self.ff(x_norm) + x = x + self.dropout(ff_output) + return x + + +class TransformerPredictionHead(nn.Module): + """ + Transformer-based prediction head with adaptive layer normalization. + + The forward function takes: + - x: LongTensor of shape [B, seq_len]. It may contain -1, which will be replaced by the pad token. + - z: FloatTensor of shape [B, conditioning_dim] used to condition the AdaLN layers. + + The autoregressive generate function handles generation with proper and tokens. + During training x is assumed not to include these special tokens. + """ + def __init__( + self, + vocab_size, + hidden_size, + num_heads, + num_layers, + conditioning_dim, + max_seq_len=512, + dropout=0.1, + pad_token=0, + sos_token=1, + eos_token=2 + ): + super().__init__() + self.vocab_size = vocab_size + self.hidden_size = hidden_size + + self.pad_token = pad_token + self.sos_token = sos_token + self.eos_token = eos_token + self.num_actions = self.vocab_size + self.max_seq_len = max_seq_len + + # Token embedding. + self.token_embedding = common_layers.Tokenizer(vocab_size, hidden_size, add_one_more_action=False) + # Positional encoding: learnable embeddings. + self.pos_embedding = nn.Parameter(torch.zeros(1, max_seq_len, hidden_size)) + + # Stack of transformer blocks. + self.layers = nn.ModuleList( + [TransformerBlock(hidden_size, num_heads, conditioning_dim, dropout) for _ in range(num_layers)] + ) + # Final normalization (can be standard LN). + self.ln_final = nn.LayerNorm(hidden_size) + # Project back to vocabulary logits. 
+ self.output_layer = nn.Linear(hidden_size, vocab_size) + + def get_action_embedding(self, action): + action_emb = None + for i in range(action.shape[-1]): + v = self.token_embedding(action[..., i]) + v = torch.where((action[..., i] == -1)[..., None], 0, v) + if action_emb is None: + action_emb = v + else: + action_emb += v + mask = (action == -1).all(dim=-1) + action_emb[mask] = 0 + return action_emb + + def prepare_for_training(self, x): + """ + This function append start token before the sequence and it looks for the first -1 token and replace it with end token. + """ + assert x.ndim == 2 + B, T = x.shape + + newx = x.clone() + + #### DEBUG CODE: + # newx = torch.full_like(newx, -1) + # newx[:, 0] = 777 + # x = torch.full_like(x, -1) + # x[:, 0] = 777 + # newx = torch.where(newx != -1, 777, newx) + + first_neg1_ind = (x == -1).float().argmax(dim=-1) + all_invalid_mask = (x == -1).all(dim=-1) + all_valid_mask = (x != -1).all(dim=-1) + newx[torch.arange(B), first_neg1_ind] = self.eos_token + newx = torch.where(all_valid_mask[:, None], x, newx) + + start_token = torch.full((B, 1), self.sos_token, dtype=torch.long, device=x.device) + end_token = torch.full((B, 1), -1, dtype=torch.long, device=x.device) + newx = torch.cat([start_token, newx, end_token], dim=1) + newx[all_valid_mask, -1] = self.eos_token + + key_padding_valid_mask = newx != -1 + key_padding_valid_mask[all_invalid_mask] = False + newx[~key_padding_valid_mask] = self.pad_token + + # In padding_mask, True means the token will be ignored. + key_padding_mask = ~key_padding_valid_mask + + # assert (x[:, 0] == 1025).all() + assert (newx[newx[:, -1] != self.pad_token][:, -1] == self.eos_token).all() + + return newx, key_padding_mask + + def forward(self, x, z, key_padding_mask=None, prepare_for_training=True): + """ + x: LongTensor of shape [B, seq_len]. May contain -1 (which will be replaced by pad_token). + z: FloatTensor of shape [B, conditioning_dim]. + Returns logits of shape [B, seq_len, vocab_size]. 
+ """ + assert x.dim() == 2, "Input tensor must have shape [B, seq_len]" + assert z.dim() == 2, "Conditioning tensor must have shape [B, conditioning_dim]" + + info = {} + if prepare_for_training: + x, key_padding_mask = self.prepare_for_training(x) + info["fast_input_token"] = x + info["fast_sos_token"] = self.sos_token + info["fast_eos_token"] = self.eos_token + info["fast_pad_token"] = self.pad_token + + # Compute token embeddings. + emb = self.token_embedding(x) # [B, seq_len, hidden_size] + seq_len = emb.size(1) + + # Add positional embeddings. + emb = emb + self.pos_embedding[:, :seq_len, :] + + # Transpose to [seq_len, B, hidden_size] for the PyTorch attention module. + h = emb.transpose(0, 1) + + # Optionally, one may create an attention mask to prevent attending to pad positions. + # key_padding_mask is assumed to be provided (or could be computed here based on x == pad_token) + attn_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(h.device) + attn_mask = attn_mask < 0 + for layer in self.layers: + h = layer(h, z, attn_mask=attn_mask, key_padding_mask=key_padding_mask) + + # Final normalization. + h = self.ln_final(h) + # Transpose back to [B, seq_len, hidden_size] + h = h.transpose(0, 1) + # Compute logits. + logits = self.output_layer(h) + + + # # Compute the probability for GT tokens + # # TODO: remove this + # if x.shape[1] > 1: + # tmp = torch.where(x[:, 1:] == -1, self.pad_token, x[:, 1:]) + # log_prob = utils.masked_average( + # torch.distributions.Categorical(logits=logits[:, :-1]).log_prob(tmp),~key_padding_mask[:, :-1], dim=1).mean() + # print("log_prob:", log_prob.item()) + logits[(key_padding_mask == True).all(-1)] = 0 + return logits, info + + @torch.no_grad() + def generate(self, z, greedy=False): + """ + Autoregressively generate a sequence conditioned on latent vector z. + + z: FloatTensor of shape [B, conditioning_dim] + max_length: Maximum length to generate (including and tokens). 
+ greedy: If True, use argmax sampling; otherwise, sample from the distribution. + + Returns: + generated: LongTensor of shape [B, generated_seq_len] (including the starting ). + """ + assert z.ndim == 2 + + B = z.size(0) + max_length = self.max_seq_len + device = z.device + # Start each sequence with the token. + generated = torch.full((B, 1), self.sos_token, dtype=torch.long, device=device) + + # key_padding_mask = torch.zeros(B, 1, dtype=torch.float32, device=device) + key_padding_valid_mask_bool = torch.ones(B, 1, dtype=torch.bool, device=device) + key_padding_mask = (~key_padding_valid_mask_bool).clone() + + for step in range(max_length - 1): # already have one token + # Compute logits for the current sequence. + logits, _ = self.forward( + generated, z, prepare_for_training=False, key_padding_mask=key_padding_mask + ) # [B, seq_len, vocab_size] + # Focus on the last time step. + last_logits = logits[:, -1, :] # [B, vocab_size] + + # Can't select sos token.. + last_logits[:, self.sos_token] = -float("inf") + + if step == 0: + # Don't allow you to select the end token... + last_logits[:, self.eos_token] = -float("inf") + + if greedy: + next_token = last_logits.argmax(dim=-1, keepdim=True) # [B, 1] + else: + probs = F.softmax(last_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + + # Mask out the padding tokens + next_token[~key_padding_valid_mask_bool] = -1 + + # Append the predicted token. + generated = torch.cat([generated, next_token], dim=1) + + key_padding_valid_mask_bool = (next_token != self.pad_token) & key_padding_valid_mask_bool & ( + next_token != self.eos_token + ) & (next_token != -1) + + key_padding_mask = torch.cat([key_padding_mask, ~key_padding_valid_mask_bool], dim=1) + + # Check if all sequences have produced an token. 
+ if not key_padding_valid_mask_bool.any(): + break + + out = generated[:, 1:] + assert (out != self.sos_token).all(), "Generated sequence should not contain start token" + return out + + +class MotionDecoderGPT(MotionDecoderGPTBase): + def __init__(self, config): + nn.Module.__init__(self) + self.config = config + self.d_model = d_model = self.config.MODEL.D_MODEL + num_decoder_layers = self.config.MODEL.NUM_DECODER_LAYERS + + # self.num_actions = get_action_dim(self.config) + + dropout = self.config.MODEL.DROPOUT + self.num_heads = self.config.MODEL.NUM_ATTN_HEAD + assert self.config.MODEL.NAME in ['gpt'] + self.add_pe_for_token = self.config.MODEL.get('ADD_PE_FOR_TOKEN', False) + assert self.add_pe_for_token + + # TODO: Implement this + use_adaln = False + self.use_adaln = use_adaln + + simple_relation = self.config.SIMPLE_RELATION + simple_relation_factor = self.config.SIMPLE_RELATION_FACTOR + is_v7 = self.config.MODEL.IS_V7 + self.is_v7 = is_v7 + self.decoder = MultiCrossAttTransformerDecoder( + decoder_layer=MultiCrossAttTransformerDecoderLayer( + d_model=d_model, + nhead=self.num_heads, + dropout=dropout, + use_adaln=use_adaln, + simple_relation=simple_relation, + simple_relation_factor=simple_relation_factor, + is_v7=is_v7, + update_relation=self.config.UPDATE_RELATION, + add_relation_to_v=self.config.MODEL.ADD_RELATION_TO_V, + remove_rel_norm=self.config.REMOVE_REL_NORM + ), + num_layers=num_decoder_layers, + d_model=d_model, + ) + # self.prediction_head = common_layers.build_mlps( + # c_in=d_model, mlp_channels=[d_model, self.num_actions], ret_before_act=True, is_v7=is_v7, zero_init=is_v7 + # ) + # if self.use_adaln: + # self.prediction_adaln_norm = nn.LayerNorm(d_model, elementwise_affine=False, eps=1e-6) + # self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(d_model, 2 * d_model, bias=True)) + # else: + self.prediction_prenorm = nn.LayerNorm(d_model) + + relation_d_model = d_model // simple_relation_factor + self.relation_embed_a2a = 
fourier_embedding.FourierEmbedding( + input_dim=12, hidden_dim=relation_d_model, num_freq_bands=64, is_v7=is_v7 + ) + self.relation_embed_a2t = fourier_embedding.FourierEmbedding( + input_dim=12, hidden_dim=relation_d_model, num_freq_bands=64, is_v7=is_v7 + ) + self.relation_embed_a2s = fourier_embedding.FourierEmbedding( + input_dim=3, hidden_dim=relation_d_model, num_freq_bands=64, is_v7=is_v7 + ) + + self.type_embed = common_layers.Tokenizer( + num_actions=constants.NUM_TYPES, d_model=d_model, add_one_more_action=False + ) + # self.action_embed = common_layers.Tokenizer( + # num_actions=self.num_actions, d_model=d_model, add_one_more_action=True + # ) + self.shape_embed = common_layers.build_mlps( + c_in=3, mlp_channels=[d_model, d_model], ret_before_act=True, is_v7=is_v7 + ) + + # if self.config.REMOVE_AGENT_FROM_SCENE_ENCODER: + self.agent_id_embed = common_layers.Tokenizer( + num_actions=self.config.PREPROCESSING.MAX_AGENTS, d_model=self.d_model, add_one_more_action=False + ) + + self.motion_embed = fourier_embedding.FourierEmbedding( + input_dim=2, hidden_dim=d_model, num_freq_bands=64, is_v7=is_v7 + ) + + tokenizer = get_tokenizer(self.config) + motion_features = tokenizer.get_motion_feature() + if tokenizer.use_type_specific_bins: + motion_features = torch.cat([motion_features, torch.zeros(1, 3, 4)], dim=0) + else: + motion_features = torch.cat([motion_features, torch.zeros(1, 4)], dim=0) + self.tokenizer = tokenizer + self.register_buffer("motion_features", motion_features) + + self.special_token_embed = common_layers.Tokenizer( + num_actions=4, d_model=self.d_model, add_one_more_action=False + ) + self.prediction_head = TransformerPredictionHead( + vocab_size=self.tokenizer.fast_tokenizer.vocab_size + 3, + hidden_size=self.d_model, + num_heads=4, + num_layers=3, + conditioning_dim=self.d_model, + max_seq_len=20, + pad_token=self.tokenizer.fast_tokenizer.vocab_size, + sos_token=self.tokenizer.fast_tokenizer.vocab_size + 1, + 
eos_token=self.tokenizer.fast_tokenizer.vocab_size + 2, + dropout=0.0 + ) + + def forward(self, input_dict, use_cache=False, a2a_knn=None, a2t_knn=None, a2s_knn=None): + in_evaluation = input_dict["in_evaluation"][0].item() + + # num_heads = self.num_heads + # === Process scene embedding === + scene_token = input_dict["encoder/scenario_token"] + scenario_valid_mask = input_dict["encoder/scenario_valid_mask"] + + # === Process action embedding === + input_action = input_dict["decoder/input_action"] + modeled_agent_delta = input_dict["decoder/modeled_agent_delta"] + B, T_skipped, N = input_action.shape[:3] + + if in_evaluation: + assert "decoder/randomized_modeled_agent_id" in input_dict, "Need to provide randomized modeled agent id for evaluation! Please call randomize_modeled_agent_id()" + new_modeled_agent_id = input_dict["decoder/randomized_modeled_agent_id"] + else: + new_modeled_agent_id = self.randomize_modeled_agent_id(input_dict, clip_agent_id=False) + modeled_agent_pe = self.agent_id_embed(new_modeled_agent_id) + + assert modeled_agent_pe.shape == (B, N, self.d_model), (B, N, self.d_model, modeled_agent_pe.shape) + modeled_agent_pe = modeled_agent_pe[:, None].expand(B, T_skipped, N, self.d_model) + + action_valid_mask = input_dict["decoder/input_action_valid_mask"] + assert action_valid_mask.shape == (B, T_skipped, N), (action_valid_mask.shape, (B, T_skipped, N)) + agent_pos = input_dict["decoder/modeled_agent_position"] + agent_heading = input_dict["decoder/modeled_agent_heading"] + + # ===== Prepare input tokens ===== + if "decoder/input_step" not in input_dict: + input_dict["decoder/input_step"] = torch.arange(T_skipped).to(input_action.device) + agent_step = input_dict["decoder/input_step"].reshape(1, T_skipped, 1).expand(B, T_skipped, N) + + # Shape embedding and type embedding + type_emb = self.type_embed(input_dict["decoder/agent_type"])[:, None].expand(B, T_skipped, N, self.d_model) + shape_emb = 
self.shape_embed(input_dict["decoder/current_agent_shape"] + )[:, None].expand(B, T_skipped, N, self.d_model) + + valid_actions = input_action[action_valid_mask] + + is_start_actions = valid_actions[..., 0] == MOTION_START_ACTION + + special_tok = torch.full([ + valid_actions.shape[0], + ], 0, device=valid_actions.device, dtype=torch.long) + special_tok[is_start_actions] = 1 + if self.config.BACKWARD_PREDICTION: + is_end_actions = valid_actions == MOTION_END_ACTION + special_tok[is_end_actions] = 2 + valid_actions[is_end_actions] = -1 + special_tok_emb = self.special_token_embed(special_tok) + # TODO: Can add more special tokens in future + + valid_actions[is_start_actions] = -1 + action_emb = self.prediction_head.get_action_embedding(valid_actions) + + motion_feat = torch.cat([modeled_agent_delta[action_valid_mask]], dim=-1) + + action_token = self.motion_embed( + continuous_inputs=motion_feat, + categorical_embs=[ + special_tok_emb, modeled_agent_pe[action_valid_mask], type_emb[action_valid_mask], + shape_emb[action_valid_mask], action_emb + ] + ) + action_token = utils.unwrap(action_token, action_valid_mask) + assert action_token.shape == (B, T_skipped, N, self.d_model) + assert action_valid_mask.shape == (B, T_skipped, N) + + # ===== Get agent-temporal relation ===== + # BTND -> BNTD + agent_pos_bntd = torch.permute(agent_pos, [0, 2, 1, 3]) + agent_heading_bnt = torch.permute(agent_heading, [0, 2, 1]) + agent_mask_bnt = torch.permute(action_valid_mask, [0, 2, 1]) + agent_step_bnt = torch.permute(agent_step, [0, 2, 1]) + # agent_vel_bnt = torch.permute(agent_vel, [0, 2, 1, 3]) + if use_cache: + self.update_cache(input_dict) + + agent_pos_with_history = input_dict["decoder/modeled_agent_position_history"] + agent_heading_with_history = input_dict["decoder/modeled_agent_heading_history"] + agent_mask_with_history = input_dict["decoder/modeled_agent_valid_mask_history"] + agent_step_with_history = input_dict["decoder/modeled_agent_step_history"] + # 
agent_vel_with_history = input_dict["decoder/modeled_agent_velocity_history"] + real_T = agent_mask_with_history.shape[1] + key_pos = torch.permute(agent_pos_with_history, [0, 2, 1, 3]).flatten(0, 1) + # key_vel = torch.permute(agent_vel_with_history, [0, 2, 1, 3]).flatten(0, 1) + key_heading = torch.permute(agent_heading_with_history, [0, 2, 1]).flatten(0, 1) + key_mask = torch.permute(agent_mask_with_history, [0, 2, 1]).flatten(0, 1) + causal_valid_mask = None + key_step = agent_step_with_history.reshape(1, 1, -1).expand(B, N, -1).flatten(0, 1) + else: + real_T = T_skipped + # key_vel = agent_vel_bnt.flatten(0, 1) + key_pos = agent_pos_bntd.flatten(0, 1) + key_heading = agent_heading_bnt.flatten(0, 1) + key_mask = agent_mask_bnt.flatten(0, 1) + key_step = agent_step_bnt.flatten(0, 1) + causal_valid_mask = create_causal_mask(T=real_T, N=1, is_valid_mask=True).to(action_token.device) + + assert agent_pos_bntd.shape == (B, N, T_skipped, 2) + + a2t_kwargs = {} + agent_shape_no_time = input_dict["decoder/current_agent_shape"] # .reshape(B, 1, N, 3).expand(B, real_T, N, 3) + agent_length = agent_shape_no_time[..., 0] + agent_width = agent_shape_no_time[..., 1] + a2t_kwargs = dict( + include_contour=True, + query_width=agent_width.flatten(0, 1).unsqueeze(1).expand(-1, T_skipped), + query_length=agent_length.flatten(0, 1).unsqueeze(1).expand(-1, T_skipped), + key_width=agent_width.flatten(0, 1).unsqueeze(1).expand(-1, real_T), + key_length=agent_length.flatten(0, 1).unsqueeze(1).expand(-1, real_T), + non_agent_relation=False, + per_contour_point_relation=self.config.MODEL.PER_CONTOUR_POINT_RELATION + ) + + relation_func = relation.compute_relation_simple_relation + a2t_rel_feat, a2t_mask, _ = relation_func( + query_pos=agent_pos_bntd.flatten(0, 1), # BN, T, D + query_heading=agent_heading_bnt.flatten(0, 1), + query_valid_mask=agent_mask_bnt.flatten(0, 1), + query_step=agent_step_bnt.flatten(0, 1), + key_pos=key_pos, # BN, T_full, D + key_heading=key_heading, + 
key_valid_mask=key_mask, + key_step=key_step, + hidden_dim=self.d_model, + causal_valid_mask=causal_valid_mask, + knn=None, + max_distance=None, + return_pe=False, + # key_vel=key_vel, + # query_vel=agent_vel_bnt.flatten(0, 1), + **a2t_kwargs + ) + a2t_info = get_edge_info_new( + q_k_valid_mask=a2t_mask, + q_k_relation=a2t_rel_feat, + relation_model=self.relation_embed_a2t, + relation_model_v=self.relation_embed_a2t_v if self.config.MODEL.ADD_RELATION_TO_V else None + ) + + # ===== Get agent-agent relation ===== + a2a_kwargs = {} + if self.config.ADD_CONTOUR_RELATION: + w = agent_width.unsqueeze(1).expand(B, T_skipped, N).flatten(0, 1) + l = agent_length.unsqueeze(1).expand(B, T_skipped, N).flatten(0, 1) + a2a_kwargs = dict( + include_contour=True, + query_width=w, + query_length=l, + key_width=w, + key_length=l, + non_agent_relation=False, + per_contour_point_relation=self.config.MODEL.PER_CONTOUR_POINT_RELATION + ) + a2a_rel_feat, a2a_mask, _ = relation_func( + query_pos=agent_pos.flatten(0, 1), # BT, N, D + query_heading=agent_heading.flatten(0, 1), + query_valid_mask=action_valid_mask.flatten(0, 1), + query_step=agent_step.flatten(0, 1), + key_pos=agent_pos.flatten(0, 1), + key_heading=agent_heading.flatten(0, 1), + key_valid_mask=action_valid_mask.flatten(0, 1), + key_step=agent_step.flatten(0, 1), + hidden_dim=self.d_model, + causal_valid_mask=None, + knn=a2a_knn if a2a_knn is not None else self.config.MODEL.A2A_KNN, + max_distance=self.config.MODEL.A2A_DISTANCE, + return_pe=False, + # query_vel=agent_vel.flatten(0, 1), + # key_vel=agent_vel.flatten(0, 1), + **a2a_kwargs + ) + a2a_info = get_edge_info_new( + q_k_valid_mask=a2a_mask, + q_k_relation=a2a_rel_feat, + relation_model=self.relation_embed_a2a, + relation_model_v=self.relation_embed_a2a_v if self.config.MODEL.ADD_RELATION_TO_V else None + ) + + # ===== Get agent-scene relation ===== + a2s_kwargs = {} + if self.config.ADD_CONTOUR_RELATION: + w = agent_width.unsqueeze(1).expand(B, T_skipped, 
N).flatten(1, 2) + l = agent_length.unsqueeze(1).expand(B, T_skipped, N).flatten(1, 2) + kw = torch.zeros_like(input_dict["encoder/scenario_position"][..., 0]) + a2s_kwargs = dict( + include_contour=True, + query_width=w, + query_length=l, + key_width=kw, + key_length=kw, + non_agent_relation=True, + per_contour_point_relation=self.config.MODEL.PER_CONTOUR_POINT_RELATION + ) + a2s_rel_feat, a2s_mask, a2s_indices = relation_func( + query_pos=agent_pos.flatten(1, 2), # B, TN, D + query_heading=agent_heading.flatten(1, 2), + query_valid_mask=action_valid_mask.flatten(1, 2), + query_step=agent_step.flatten(1, 2), + key_pos=input_dict["encoder/scenario_position"], # [..., :2], + key_heading=input_dict["encoder/scenario_heading"], + key_valid_mask=scenario_valid_mask, + key_step=agent_pos.new_zeros(B, input_dict["encoder/scenario_position"].shape[1]), + hidden_dim=self.d_model, + causal_valid_mask=None, + knn=a2s_knn if a2s_knn is not None else self.config.MODEL.A2S_KNN, + max_distance=self.config.MODEL.A2S_DISTANCE, + gather=False, + return_pe=False, + **a2s_kwargs + ) + a2s_info = get_edge_info_new( + q_k_valid_mask=a2s_mask, + q_k_relation=a2s_rel_feat, + relation_model=self.relation_embed_a2s, + relation_model_v=self.relation_embed_a2s_v if self.config.MODEL.ADD_RELATION_TO_V else None + ) + + # === Call models === + past_key_value_list = None + if use_cache: + # Cache from last rollout + if "decoder/cache" in input_dict: + past_key_value_list = input_dict["decoder/cache"] + + decoded_tokens = self.decoder( + agent_token=action_token, + scene_token=scene_token, + a2a_info=a2a_info, + a2t_info=a2t_info, + a2s_info=a2s_info, + condition_token=None, # TODO: Add condition token + use_cache=use_cache, # We don't need decoder to take care cache. 
+ past_key_value_list=past_key_value_list + ) + + if use_cache: + decoded_tokens, past_key_value_list = decoded_tokens + for l in past_key_value_list: + if l: + l.append((B * N, real_T)) + input_dict["decoder/cache"] = past_key_value_list + + output_tokens = self.prediction_prenorm(decoded_tokens[action_valid_mask]) + + if in_evaluation: + pred_out = self.prediction_head.generate(z=output_tokens) + + pred_out = utils.unwrap(pred_out, action_valid_mask, fill=-1) + + input_dict["decoder/output_token"] = pred_out + input_dict["decoder/output_logit"] = None + + else: + + target_actions = input_dict["decoder/target_action"] + masked_target_actions = target_actions[action_valid_mask] + + valid_target_mask = (masked_target_actions != -1).any(dim=-1) + + pred_out, fast_info = self.prediction_head(masked_target_actions[valid_target_mask], z=output_tokens[valid_target_mask]) + + pred_out_new = pred_out.new_zeros(valid_target_mask.shape[0], pred_out.shape[1], pred_out.shape[2]) + pred_out_new[valid_target_mask] = pred_out + pred_out = pred_out_new + + fast_tok = fast_info["fast_input_token"].new_zeros(valid_target_mask.shape[0], pred_out.shape[1]) + fast_tok[valid_target_mask] = fast_info["fast_input_token"] + fast_info["fast_input_token"] = fast_tok + + input_dict.update(fast_info) + + logits = utils.unwrap(pred_out.flatten(1, 2), + action_valid_mask).reshape(B, T_skipped, N, -1, pred_out.shape[-1]) + + assert logits.shape == (B, T_skipped, N, pred_out.shape[1], self.prediction_head.num_actions), ( + logits.shape, (B, T_skipped, N, pred_out.shape[1], self.prediction_head.num_actions) + ) + input_dict["decoder/output_logit"] = logits + + return input_dict diff --git a/scenestreamer/models/motionlm.py b/scenestreamer/models/motionlm.py new file mode 100644 index 0000000000000000000000000000000000000000..775247e4ed0f2838dbf4d83e0d0d06ceb41f3baf --- /dev/null +++ b/scenestreamer/models/motionlm.py @@ -0,0 +1,1519 @@ +import copy +import logging + +import torch +import torch.nn 
def get_relative_velocity(vel, heading):
    """Rotate the two leading velocity components by ``-heading``.

    Passes ``vel[..., 0]`` and ``vel[..., 1]`` to ``utils.rotate`` with
    ``angle=-heading`` and returns whatever that helper returns.
    NOTE(review): presumably this expresses a world-frame velocity in the
    agent's own heading frame -- confirm against ``utils.rotate``.
    """
    return utils.rotate(vel[..., 0], vel[..., 1], angle=-heading)


def _reconstruct_delta_pos_from_abs_vel(vel, heading, dt):
    """Return the velocity rotated by ``-heading`` and scaled by ``dt``.

    The rotation is the same one ``get_relative_velocity`` performs; the
    result is multiplied by ``dt`` to turn a velocity into a displacement.
    """
    rotated_vel = utils.rotate(vel[..., 0], vel[..., 1], angle=-heading)
    return rotated_vel * dt


def nucleus_sampling(logits, p=None, epsilon=1e-8):
    """Sample one index per distribution with top-p (nucleus) filtering.

    Args:
        logits: Unnormalised scores; the last dimension enumerates the
            candidates, all leading dimensions are batch dimensions.
        p: Cumulative-probability threshold. ``None`` selects the default
            of 0.9. An explicit ``0.0`` is honoured (it keeps only the single
            most likely candidate) instead of being silently replaced by the
            default, which is what ``p or 0.9`` used to do.
        epsilon: Unused. Kept so existing callers that pass it keep working.

    Returns:
        A tuple ``(sampled_index, info)``. ``sampled_index`` has the shape of
        ``logits`` without its last dimension. ``info["cutoff_index"]`` is the
        boolean mask of discarded candidates, expressed in *sorted*
        (descending-probability) order, not in the original candidate order.
    """
    if p is None:
        p = 0.9

    # NaN, +inf and -inf are all mapped to a very low score so that softmax
    # and Categorical never see non-finite values.
    logits = torch.where(torch.isfinite(logits), logits, torch.full_like(logits, -1e9))

    probs = torch.softmax(logits, dim=-1)

    # Sort descending so the cumulative sum identifies the smallest prefix
    # whose mass exceeds p.
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Shift the mask right by one: the first candidate that crosses the
    # threshold is still kept, and the top candidate is never dropped.
    cutoff_index = cumulative_probs > p
    cutoff_index[..., 1:] = cutoff_index[..., :-1].clone()
    cutoff_index[..., 0] = False

    kept_sorted_probs = sorted_probs.masked_fill(cutoff_index, 0)

    # Undo the sort so probabilities line up with the original candidate ids.
    original_probs = torch.zeros_like(probs)
    original_probs.scatter_(dim=-1, index=sorted_indices, src=kept_sorted_probs)

    # Categorical renormalises the surviving probabilities internally.
    sampled_token_index = torch.distributions.Categorical(probs=original_probs).sample()

    return sampled_token_index, {"cutoff_index": cutoff_index}


def sample_action(logits, sampling_method, temperature, topp, topk=5):
    """Pick one action per distribution from ``logits``.

    Args:
        logits: Unnormalised scores; the last dimension enumerates actions.
        sampling_method: One of ``"argmax"``, ``"softmax"``, ``"topp"`` or
            ``"topk"``.
        temperature: Divisor applied to ``logits`` for ``"softmax"`` and
            ``"topp"``; ignored by ``"argmax"`` and ``"topk"``.
        topp: Threshold forwarded to ``nucleus_sampling`` for ``"topp"``.
        topk: Number of best-scoring candidates ``"topk"`` draws uniformly
            from. Defaults to the previously hard-coded 5 and is clamped to
            the number of available actions, so small action spaces no longer
            make ``Tensor.topk`` raise.

    Returns:
        A tuple ``(selected_action, info)``. ``info`` is the dict returned by
        ``nucleus_sampling`` for ``"topp"`` and an empty dict otherwise.

    Raises:
        ValueError: If ``sampling_method`` is not one of the known names.
    """
    info = {}
    if sampling_method == "argmax":
        selected_action = logits.argmax(-1)
    elif sampling_method == "softmax":
        selected_action = torch.distributions.Categorical(logits=logits / temperature).sample()
    elif sampling_method == "topp":
        selected_action, info = nucleus_sampling(logits=logits / temperature, p=topp)
    elif sampling_method == "topk":
        k = min(topk, logits.shape[-1])
        candidates = logits.topk(k, dim=-1).indices
        # Uniformly choose one of the k best candidates per distribution.
        choice = torch.randint(0, k, size=candidates.shape[:-1])[..., None].to(candidates)
        selected_action = torch.gather(candidates, index=choice, dim=-1).squeeze(-1)
    else:
        raise ValueError("Unknown sampling method: {}".format(sampling_method))
    return selected_action, info
def encode_scene(self, input_dict):
    """Run the scene encoder on ``input_dict`` and return its output."""
    encoded = self.scene_encoder(input_dict)
    return encoded

def decode_motion(self, *args, **kwargs):
    """Forward every argument to the motion decoder and return its output."""
    return self.motion_decoder(*args, **kwargs)

# def decode_trafficgen(self, *args, **kwargs):
#     input_dict = self.trafficgen_decoder(*args, **kwargs)
#     return input_dict
#
# def decode_trafficgen_offset(self, *args, **kwargs):
#     input_dict = self.trafficgen_decoder.forward_offset(*args, **kwargs)
#     return input_dict

def forward(self, input_dict):
    """Encode the scene, then decode motion when ``config.USE_MOTION`` is set.

    Returns the encoder output directly when motion decoding is disabled,
    otherwise the result of ``decode_motion`` applied to that output.
    """
    encoded = self.encode_scene(input_dict)
    if not self.config.USE_MOTION:
        return encoded
    return self.decode_motion(encoded)
+ + if "backward_prediction" in kwargs and kwargs["backward_prediction"]: + return self.autoregressive_rollout_backward_prediction( + data_dict=data_dict, + num_decode_steps=num_decode_steps, + use_cache=use_cache, + sampling_method=sampling_method, + temperature=temperature, + topp=topp, + num_modes_for_eval=num_modes_for_eval, + flip_heading_accordingly=kwargs.get("flip_heading_accordingly", True), + ) + + if self.config.USE_DIFFUSION: + return self.autoregressive_rollout_diffusion( + data_dict, + num_decode_steps=num_decode_steps, + use_cache=use_cache, + sampling_method=sampling_method, + temperature=temperature, + topp=topp, + num_modes_for_eval=num_modes_for_eval, + **kwargs + ) + + raw_data = data_dict + # To avoid those overwriting operation. + data_dict = copy.deepcopy(data_dict) + + tokenizer = self.tokenizer + + if temperature is None: + temperature = self.config.SAMPLING.TEMPERATURE + if topp is None: + topp = self.config.SAMPLING.TOPP + + B, T_input, N = data_dict["decoder/input_action"].shape + + if self.config.GPT_STYLE: + start_action_step = 0 + # assert T_input == 19 # Might not be True in waymo test set. + else: + start_action_step = 2 + assert T_input == 17 + + if num_decode_steps is None: + num_decode_steps = 19 + # assert start_action_step + T_input == num_decode_steps # Might not be True in waymo test set. + assert num_decode_steps == 19 + assert data_dict["decoder/input_action_valid_mask"].shape == (B, T_input, N) + else: + print("WARNING: You are freely generating future trajectory! 
num_decode_steps (was 19) =", num_decode_steps) + + # ===== Get initial data ===== + agent_pos = data_dict["decoder/agent_position"] #.clone() + agent_heading = data_dict["decoder/agent_heading"] #.clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"] #.clone() + agent_velocity = data_dict["decoder/agent_velocity"] #.clone() + agent_shape = data_dict["decoder/current_agent_shape"] #.clone() + B, T_full, N, _ = agent_pos.shape + # TODO: hardcoded + # assert T_full == 91 # Might not be True in waymo test set. + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + agent_pos = agent_pos[:, ::tokenizer.num_skipped_steps] + agent_heading = agent_heading[:, ::tokenizer.num_skipped_steps] + agent_valid_mask = agent_valid_mask[:, ::tokenizer.num_skipped_steps] + agent_velocity = agent_velocity[:, ::tokenizer.num_skipped_steps] + gt_agent_delta = data_dict["decoder/modeled_agent_delta"].clone() + # T_chunks = agent_pos.shape[1] + # assert T_chunks == 19 # Might not be True in waymo test set. + + # ===== Build up some variables ===== + # Should note that the modeled_agent_* is starting from t=0 (GPT) and t=10 (non-GPT). So using 0:1 to get the + # first step for decoder is correct. 
+ current_pos = data_dict["decoder/modeled_agent_position"][:, :1].clone() + current_heading = data_dict["decoder/modeled_agent_heading"][:, :1].clone() + current_vel = data_dict["decoder/modeled_agent_velocity"][:, :1].clone() + current_valid_mask = data_dict["decoder/input_action_valid_mask"][:, :1].clone() + current_delta = data_dict["decoder/modeled_agent_delta"][:, :1].clone() + current_model_step = torch.arange(1).to(current_pos.device) # it's 0 + gt_input_action = data_dict["decoder/input_action"].clone() + if autoregressive_start_step > 0: + gt_target_action = data_dict["decoder/target_action"].clone() + current_input_action = gt_input_action[:, :1].clone() + + output_logit_list = [] + output_action_list = [] + input_action_valid_mask_list = [] + assert use_cache + + pos = [] + head = [] + vel = [] + + # Select correct bins: + agent_type = data_dict["decoder/agent_type"] + bin_centers = tokenizer.get_bin_centers(agent_type) + + if "encoder/scenario_token" not in data_dict: + data_dict = self.encode_scene(data_dict) + + data_dict["decoder/randomized_modeled_agent_id"] = self.motion_decoder.randomize_modeled_agent_id( + data_dict, clip_agent_id=True + ) + + detokenization_state = None + for decode_step in range(num_decode_steps): + logger.debug(f"======================= STEP {decode_step=} =======================") + + if decode_step < start_action_step: + # For non-gpt model, skip first 2 steps. 
+ pos.append(agent_pos[:, decode_step:decode_step + 1, ..., :2]) + head.append(agent_heading[:, decode_step:decode_step + 1]) + vel.append(agent_velocity[:, decode_step:decode_step + 1]) + continue + + if decode_step == autoregressive_start_step: + assert ( + current_valid_mask == agent_valid_mask[:, autoregressive_start_step:autoregressive_start_step + 1] + ).all() + assert (current_valid_mask == data_dict["decoder/current_agent_valid_mask"][:, None]).all() + + # ===== Fill a lot of stuff ===== + # Overwrite all necessary data: + data_dict["decoder/modeled_agent_position"] = current_pos + data_dict["decoder/modeled_agent_heading"] = current_heading + data_dict["decoder/modeled_agent_velocity"] = current_vel + data_dict["decoder/modeled_agent_valid_mask"] = current_valid_mask + data_dict["decoder/modeled_agent_delta"] = current_delta + data_dict["decoder/input_step"] = current_model_step + data_dict["decoder/input_action"] = current_input_action + data_dict["decoder/input_action_valid_mask"] = current_valid_mask + input_action_valid_mask_list.append(current_valid_mask.clone()) + + assert not (current_input_action == END_ACTION).any() + + use_mcts = self.config.MCTS.USE_MCTS + if use_mcts: + from scenestreamer.mcts import mcts_search + selected_action, mcts_info = mcts_search( + self, + data_dict, + self.config, + start_steps=decode_step, + num_search_steps=self.config.MCTS.MCTS_DEPTH, # D + num_search_width=self.config.MCTS.MCTS_WIDTH, # W + bin_centers=bin_centers, + ) + + # Decode motion tokens + data_dict = self.decode_motion(data_dict, use_cache=use_cache) + + if "decoder/modeled_agent_position_history" in data_dict: + assert data_dict["decoder/modeled_agent_position_history"].shape[ + 1] == decode_step + 1 - start_action_step + + output_token = data_dict["decoder/output_logit"] + if use_cache: + assert output_token.shape[:3] == (B, 1, N) + else: + assert output_token.shape[:3] == (B, decode_step + 1, N) + output_token = output_token[:, -1:] # -> 
output_token.shape == (B, 1, N, #actions) + + if use_mcts: + pass + else: + selected_action, sampling_info = sample_action( + logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp + ) + + # avg_left_rate = utils.masked_average((~sampling_info["cutoff_index"]).float().mean(-1), current_valid_mask, dim=-1).mean() + # avg_left_num = avg_left_rate*output_token.shape[-1] + + # print("With TOPP {:.2f}, TEMPERATURE {:.2f}, AVG_LEFT_RATE {:.2f}, AVG_LEFT_NUM {:.2f}".format( + # topp, temperature, avg_left_rate, avg_left_num + # )) + + if decode_step < autoregressive_start_step: + # Overwrite the action by GT action + selected_action = gt_target_action[:, decode_step:decode_step + 1] + + # if self.config.MODEL.RELATIVE_PE_DECODER: + res = tokenizer.detokenize_step( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=current_valid_mask, + current_vel=current_vel, + action=selected_action, + agent_shape=agent_shape, + bin_centers=bin_centers, + dt=tokenizer.dt, + flip_wrong_heading=self.config.TOKENIZATION.FLIP_WRONG_HEADING, + detokenization_state=detokenization_state + ) + detokenization_state = res + recon_next_pos, recon_next_heading, recon_next_vel, relative_delta_pos = res["pos"], res["heading"], res[ + "vel"], res["delta_pos"] + + # break + + # Just for fun, detect collision: + # + # + # from shapely.geometry import Polygon + # + # contours = utils.cal_polygon_contour_torch( + # x=recon_next_pos[..., 0], + # y=recon_next_pos[..., 1], + # theta=recon_next_heading, + # width=data_dict["decoder/current_agent_shape"][..., 1], + # length=data_dict["decoder/current_agent_shape"][..., 0], + # ) + # def detect_collision(contour_list, mask): + # collision_detected = torch.zeros_like(mask, dtype=torch.bool) + # for i in range(len(contour_list)): + # for j in range(i + 1, len(contour_list)): + # if mask[i] and mask[j]: + # poly1 = Polygon(contour_list[i].cpu().numpy()) + # poly2 = 
Polygon(contour_list[j].cpu().numpy()) + # if poly1.intersects(poly2): + # collision_detected[i] = True + # collision_detected[j] = True + # return collision_detected + # + # collision_detected = torch.stack( + # [detect_collision(contour_list=contours[b], mask=current_valid_mask.squeeze(1)[b]) for b in range(B)], + # dim=0 + # ) + # print("Iter:", iteration) + # if collision_detected.any(): + # print("Collision detected!") + # B, T, N, D_actions = output_token.shape + # + # # Create a one hot where selected action is 1 + # should_mask = torch.nn.functional.one_hot(selected_action, num_classes=output_token.shape[-1]).bool() + # should_mask = should_mask & collision_detected.reshape(B, 1, N, 1).bool() + # output_token = torch.where(should_mask, float("-inf") * torch.ones_like(output_token), output_token) + # continue + # else: + # break + + current_pos = recon_next_pos.reshape(B, 1, N, 2) + current_heading = recon_next_heading.reshape(B, 1, N) + current_vel = recon_next_vel.reshape(B, 1, N, 2) + current_delta = relative_delta_pos.reshape(B, 1, N, 2) + current_model_step = torch.full_like(current_model_step, decode_step + 1 - start_action_step) + current_input_action = selected_action + + # Overwrite the data FOR NEXT STEP by the GT data: + if decode_step < autoregressive_start_step: + newly_added = agent_valid_mask[:, decode_step + 1:decode_step + 2] & (~current_valid_mask) + if newly_added.any(): + current_pos[newly_added] = agent_pos[:, decode_step + 1:decode_step + 2, ..., :2][newly_added] + current_heading[newly_added] = agent_heading[:, decode_step + 1:decode_step + 2][newly_added] + current_vel[newly_added] = agent_velocity[:, decode_step + 1:decode_step + 2][newly_added] + current_valid_mask[newly_added] = agent_valid_mask[:, decode_step + 1:decode_step + 2][newly_added] + current_delta[newly_added] = gt_agent_delta[:, decode_step + 1:decode_step + 2][newly_added] + + # Overwrite the input action by GT action + current_input_action = gt_input_action[:, 
decode_step + 1:decode_step + 2] + output_token = torch.zeros_like(output_token) + + pos.append(current_pos.clone()) + head.append(current_heading.clone()) + vel.append(current_vel.clone()) + output_logit_list.append(output_token.clone()) + output_action_list.append(current_input_action.clone()) + + output_action_list = torch.concatenate(output_action_list, dim=1) + assert output_action_list.shape == (B, num_decode_steps - start_action_step, N) + + output_logit_list = torch.concatenate(output_logit_list, dim=1) + data_dict["decoder/output_logit"] = output_logit_list + data_dict["decoder/output_action"] = output_action_list + + # FIXME + # FIXME + # FIXME What is the score? + data_dict["decoder/output_score"] = utils.calculate_trajectory_probabilities( + output_logit_list, output_action_list, mask=current_valid_mask + ) # (B, N) + + input_action_valid_mask = torch.cat(input_action_valid_mask_list, dim=1) + data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask + + data_dict["decoder/debug_ar_pos"] = torch.cat(pos, dim=1) + data_dict["decoder/debug_ar_head"] = torch.cat(head, dim=1) + data_dict["decoder/debug_ar_vel"] = torch.cat(vel, dim=1) + + valid_output_action = output_action_list[input_action_valid_mask] + assert valid_output_action.max() < tokenizer.num_actions + assert valid_output_action.min() >= 0 + + # ===== Debug! 
def autoregressive_rollout_with_replay(
    self,
    data_dict,
    # num_decode_steps,
    num_decode_steps=None,
    use_cache=True,
    sampling_method="softmax",
    temperature=None,
    topp=None,
    num_modes_for_eval=None,
    teacher_forcing_ids=None,
    **kwargs
):
    """Roll out agents autoregressively while replaying ground truth for some.

    Agents whose indices are listed in ``teacher_forcing_ids`` are teacher
    forced at every step: their sampled action is overwritten with the
    ground-truth input action and their next state (position, heading,
    velocity, validity, delta) is overwritten with the logged state. All other
    agents follow the actions sampled from the decoder.

    Args:
        data_dict: Batch dictionary. It is deep-copied, so the caller's dict
            is not modified. Only batch size 1 is supported (asserted).
        num_decode_steps: Number of decode steps; ``None`` selects 19 and
            enables extra shape assertions.
        use_cache: Must be truthy (asserted); forwarded to ``decode_motion``.
        sampling_method: Forwarded to ``sample_action``.
        temperature: Sampling temperature; ``None`` falls back to
            ``config.SAMPLING.TEMPERATURE``.
        topp: Nucleus threshold; ``None`` falls back to
            ``config.SAMPLING.TOPP``.
        num_modes_for_eval: Accepted for signature parity; unused here.
        teacher_forcing_ids: Iterable of agent indices to teacher force.
            Required (asserted non-``None``).
        **kwargs: Only ``backward_prediction`` is inspected.

    Returns:
        The working dictionary with ``decoder/output_logit``,
        ``decoder/output_action``, ``decoder/input_action_valid_mask`` and the
        ``decoder/debug_ar_{pos,head,vel}`` trajectories filled in.

    Raises:
        ValueError: If ``backward_prediction`` is requested or
            ``config.USE_DIFFUSION`` is set (neither is implemented here).
        AssertionError: If any of the shape/horizon assumptions is violated.
    """
    if "backward_prediction" in kwargs and kwargs["backward_prediction"]:
        raise ValueError("Not implemented yet!")

    if self.config.USE_DIFFUSION:
        raise ValueError("Not implemented yet!")

    # Work on a copy: many entries of the dict are overwritten at every step.
    data_dict = copy.deepcopy(data_dict)

    tokenizer = self.tokenizer

    if temperature is None:
        temperature = self.config.SAMPLING.TEMPERATURE
    if topp is None:
        topp = self.config.SAMPLING.TOPP

    B, T_input, N = data_dict["decoder/input_action"].shape

    if self.config.GPT_STYLE:
        start_action_step = 0
        assert T_input == 19
    else:
        start_action_step = 2
        assert T_input == 17
    # Bound on both paths: the decode loop reads it unconditionally.
    autoregressive_start_step = 2

    if num_decode_steps is None:
        num_decode_steps = 19
        assert start_action_step + T_input == num_decode_steps
        assert num_decode_steps == 19
        assert data_dict["decoder/input_action_valid_mask"].shape == (B, T_input, N)
    else:
        print("WARNING: You are freely generating future trajectory! num_decode_steps (was 19) =", num_decode_steps)

    # ===== Build the teacher-forcing mask =====
    assert teacher_forcing_ids is not None
    sdc_adv_ids = torch.zeros_like(data_dict["decoder/agent_id"], dtype=torch.bool)
    for i in teacher_forcing_ids:
        sdc_adv_ids[:, i] = True
    sdc_adv_ids = sdc_adv_ids.reshape(B, 1, N)
    assert B == 1, "To avoid the confusion, we only support B=1"

    # ===== Get initial data =====
    agent_pos = data_dict["decoder/agent_position"]
    agent_heading = data_dict["decoder/agent_heading"]
    agent_valid_mask = data_dict["decoder/agent_valid_mask"].clone()
    agent_velocity = data_dict["decoder/agent_velocity"]
    agent_shape = data_dict["decoder/current_agent_shape"]
    B, T_full, N, _ = agent_pos.shape
    # TODO: hardcoded
    assert T_full == 91
    assert agent_pos.ndim == 4

    # ===== Skip some steps =====
    agent_pos = agent_pos[:, ::tokenizer.num_skipped_steps]
    agent_heading = agent_heading[:, ::tokenizer.num_skipped_steps]
    agent_valid_mask = agent_valid_mask[:, ::tokenizer.num_skipped_steps]
    agent_velocity = agent_velocity[:, ::tokenizer.num_skipped_steps]
    gt_agent_delta = data_dict["decoder/modeled_agent_delta"].clone()
    T_chunks = agent_pos.shape[1]
    assert T_chunks == 19

    # ===== Build up some variables =====
    # Should note that the modeled_agent_* is starting from t=0 (GPT) and t=10 (non-GPT). So using 0:1 to get the
    # first step for decoder is correct.
    current_pos = data_dict["decoder/modeled_agent_position"][:, :1].clone()
    current_heading = data_dict["decoder/modeled_agent_heading"][:, :1].clone()
    current_vel = data_dict["decoder/modeled_agent_velocity"][:, :1].clone()
    current_valid_mask = data_dict["decoder/input_action_valid_mask"][:, :1].clone()
    current_delta = data_dict["decoder/modeled_agent_delta"][:, :1].clone()
    current_model_step = torch.arange(1).to(current_pos.device)  # it's 0
    gt_input_action = data_dict["decoder/input_action"].clone()
    current_input_action = gt_input_action[:, :1].clone()

    output_logit_list = []
    output_action_list = []
    input_action_valid_mask_list = []
    assert use_cache

    pos = []
    head = []
    vel = []

    # Select correct bins:
    agent_type = data_dict["decoder/agent_type"]
    bin_centers = tokenizer.get_bin_centers(agent_type)

    data_dict = self.encode_scene(data_dict)
    data_dict["decoder/randomized_modeled_agent_id"] = self.motion_decoder.randomize_modeled_agent_id(
        data_dict, clip_agent_id=True
    )
    for decode_step in range(num_decode_steps):
        logger.debug(f"======================= STEP {decode_step=} =======================")

        if decode_step == autoregressive_start_step:
            assert (
                current_valid_mask == agent_valid_mask[:, autoregressive_start_step:autoregressive_start_step + 1]
            ).all()
            assert (current_valid_mask == data_dict["decoder/current_agent_valid_mask"][:, None]).all()

        # ===== Fill a lot of stuff =====
        # Overwrite all necessary data:
        data_dict["decoder/modeled_agent_position"] = current_pos
        data_dict["decoder/modeled_agent_heading"] = current_heading
        data_dict["decoder/modeled_agent_velocity"] = current_vel
        data_dict["decoder/modeled_agent_valid_mask"] = current_valid_mask
        data_dict["decoder/modeled_agent_delta"] = current_delta
        data_dict["decoder/input_step"] = current_model_step
        data_dict["decoder/input_action"] = current_input_action
        data_dict["decoder/input_action_valid_mask"] = current_valid_mask
        input_action_valid_mask_list.append(current_valid_mask.clone())

        assert not (current_input_action == END_ACTION).any()

        # Decode motion tokens
        data_dict = self.decode_motion(data_dict, use_cache=use_cache)

        if "decoder/modeled_agent_position_history" in data_dict:
            assert data_dict["decoder/modeled_agent_position_history"].shape[
                1] == decode_step + 1 - start_action_step

        output_token = data_dict["decoder/output_logit"]
        if use_cache:
            assert output_token.shape[:3] == (B, 1, N)
        else:
            assert output_token.shape[:3] == (B, decode_step + 1, N)
            output_token = output_token[:, -1:]  # -> output_token.shape == (B, 1, N, #actions)

        # sample_action returns (action, info). The tuple used to be bound to
        # selected_action as a whole, so the masked assignment below raised
        # TypeError on the very first step.
        selected_action, _ = sample_action(
            logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp
        )

        # We only overwrite SDC and ADV in Forward.
        # TODO: We might want to allow SDC to be free.
        selected_action[sdc_adv_ids] = gt_input_action[:, decode_step:decode_step + 1][sdc_adv_ids]

        res = tokenizer.detokenize_step(
            current_pos=current_pos,
            current_heading=current_heading,
            current_valid_mask=current_valid_mask,
            current_vel=current_vel,
            action=selected_action,
            agent_shape=agent_shape,
            bin_centers=bin_centers,
            dt=tokenizer.dt,
            flip_wrong_heading=True,  # TODO: Dirty workaround only used in AR with Replay!
        )
        recon_next_pos, recon_next_heading, recon_next_vel, relative_delta_pos = res["pos"], res["heading"], res[
            "vel"], res["delta_pos"]

        current_pos = recon_next_pos.reshape(B, 1, N, 2)
        current_heading = recon_next_heading.reshape(B, 1, N)
        current_vel = recon_next_vel.reshape(B, 1, N, 2)
        current_delta = relative_delta_pos.reshape(B, 1, N, 2)
        current_model_step = torch.full_like(current_model_step, decode_step + 1 - start_action_step)
        current_input_action = selected_action

        # Overwrite the data FOR NEXT STEP by the GT data:
        if decode_step < autoregressive_start_step:
            newly_added = agent_valid_mask[:, decode_step + 1:decode_step + 2] & (~current_valid_mask)
            if newly_added.any():
                current_pos[newly_added] = agent_pos[:, decode_step + 1:decode_step + 2, ..., :2][newly_added]
                current_heading[newly_added] = agent_heading[:, decode_step + 1:decode_step + 2][newly_added]
                current_vel[newly_added] = agent_velocity[:, decode_step + 1:decode_step + 2][newly_added]
                current_valid_mask[newly_added] = agent_valid_mask[:, decode_step + 1:decode_step + 2][newly_added]
                current_delta[newly_added] = gt_agent_delta[:, decode_step + 1:decode_step + 2][newly_added]

            # Overwrite the input action by GT action
            current_input_action = gt_input_action[:, decode_step + 1:decode_step + 2]
            output_token = torch.zeros_like(output_token)

        # ===== Teacher Forcing =====
        # Replay the logged state for the forced agents while GT is available.
        if decode_step < T_chunks - 1:
            current_pos[sdc_adv_ids] = agent_pos[:, decode_step + 1:decode_step + 2, ..., :2][sdc_adv_ids]
            current_heading[sdc_adv_ids] = agent_heading[:, decode_step + 1:decode_step + 2][sdc_adv_ids]
            current_vel[sdc_adv_ids] = agent_velocity[:, decode_step + 1:decode_step + 2][sdc_adv_ids]
            current_valid_mask[sdc_adv_ids] = agent_valid_mask[:, decode_step + 1:decode_step + 2][sdc_adv_ids]
            current_delta[sdc_adv_ids] = gt_agent_delta[:, decode_step + 1:decode_step + 2][sdc_adv_ids]

        pos.append(current_pos.clone())
        head.append(current_heading.clone())
        vel.append(current_vel.clone())
        output_logit_list.append(output_token.clone())
        output_action_list.append(current_input_action.clone())

    output_action_list = torch.concatenate(output_action_list, dim=1)
    assert output_action_list.shape == (B, num_decode_steps - start_action_step, N)

    output_logit_list = torch.concatenate(output_logit_list, dim=1)
    data_dict["decoder/output_logit"] = output_logit_list
    data_dict["decoder/output_action"] = output_action_list

    # FIXME What is the score? "decoder/output_score" is intentionally not
    # written by this method:
    # data_dict["decoder/output_score"] = calculate_trajectory_probabilities(
    #     output_logit_list, output_action_list, mask=current_valid_mask
    # )  # (B, N)

    input_action_valid_mask = torch.cat(input_action_valid_mask_list, dim=1)
    data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask

    data_dict["decoder/debug_ar_pos"] = torch.cat(pos, dim=1)
    data_dict["decoder/debug_ar_head"] = torch.cat(head, dim=1)
    data_dict["decoder/debug_ar_vel"] = torch.cat(vel, dim=1)

    # Sanity check on the actions that were fed while the agent was valid.
    valid_output_action = output_action_list[input_action_valid_mask]

    assert valid_output_action.max() <= START_ACTION
    assert valid_output_action.min() >= 0

    return data_dict
+ data_dict = copy.deepcopy(data_dict) + + tokenizer = self.tokenizer + + if temperature is None: + temperature = self.config.SAMPLING.TEMPERATURE + if topp is None: + topp = self.config.SAMPLING.TOPP + + B, T_input, N = data_dict["decoder/input_action"].shape + + assert self.config.GPT_STYLE + start_action_step = 0 + assert T_input == 19 + # else: + # start_action_step = 2 + # assert T_input == 17 + # autoregressive_start_step = 2 + + if num_decode_steps is None: + num_decode_steps = 19 + assert start_action_step + T_input == num_decode_steps + assert num_decode_steps == 19 + assert data_dict["decoder/input_action_valid_mask"].shape == (B, T_input, N) + else: + print("WARNING: You are freely generating future trajectory! num_decode_steps (was 19) =", num_decode_steps) + + # ===== Get initial data ===== + agent_pos = data_dict["decoder/agent_position"] #.clone() + agent_heading = data_dict["decoder/agent_heading"] #.clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"] #.clone() + agent_velocity = data_dict["decoder/agent_velocity"] #.clone() + agent_shape = data_dict["decoder/current_agent_shape"] #.clone() + B, T_full, N, _ = agent_pos.shape + # TODO: hardcoded + assert T_full == 91 + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + agent_pos = agent_pos[:, ::tokenizer.num_skipped_steps] + agent_heading = agent_heading[:, ::tokenizer.num_skipped_steps] + agent_valid_mask = agent_valid_mask[:, ::tokenizer.num_skipped_steps] + agent_velocity = agent_velocity[:, ::tokenizer.num_skipped_steps] + T_chunks = agent_pos.shape[1] + assert T_chunks == 19 + + # ===== Build up some variables ===== + # Should note that the modeled_agent_* is starting from t=0 (GPT) and t=10 (non-GPT). So using 0:1 to get the + # first step for decoder is correct. 
+ current_pos = data_dict["decoder/modeled_agent_position"][:, :1].clone() + current_heading = data_dict["decoder/modeled_agent_heading"][:, :1].clone() + current_vel = data_dict["decoder/modeled_agent_velocity"][:, :1].clone() + current_valid_mask = data_dict["decoder/input_action_valid_mask"][:, :1].clone() + current_delta = data_dict["decoder/modeled_agent_delta"][:, :1].clone() + current_model_step = torch.arange(1).to(current_pos.device) # it's 0 + gt_input_action = data_dict["decoder/input_action"].clone() + gt_target_action = data_dict["decoder/target_action"].clone() + current_input_action = gt_input_action[:, :1].clone() + + output_logit_list = [] + output_action_list = [] + input_action_valid_mask_list = [] + assert use_cache + + import numpy as np + + pos = [] + head = [] + vel = [] + + # Select correct bins: + agent_type = data_dict["decoder/agent_type"] + bin_centers = tokenizer.get_bin_centers(agent_type) + + data_dict = self.encode_scene(data_dict) + data_dict["decoder/randomized_modeled_agent_id"] = self.motion_decoder.randomize_modeled_agent_id( + data_dict, clip_agent_id=True + ) + for decode_step in range(num_decode_steps): + logger.debug(f"======================= STEP {decode_step=} =======================") + + # TODO: put back the following code + # if decode_step == autoregressive_start_step: + # assert ( + # current_valid_mask == agent_valid_mask[:, autoregressive_start_step:autoregressive_start_step + 1] + # ).all() + # assert (current_valid_mask == data_dict["decoder/current_agent_valid_mask"][:, None]).all() + + # ===== Fill a lot of stuff ===== + # Overwrite all necessary data: + data_dict["decoder/modeled_agent_position"] = current_pos + data_dict["decoder/modeled_agent_heading"] = current_heading + data_dict["decoder/modeled_agent_velocity"] = current_vel + data_dict["decoder/modeled_agent_valid_mask"] = current_valid_mask + data_dict["decoder/modeled_agent_delta"] = current_delta + data_dict["decoder/input_step"] = current_model_step 
+ data_dict["decoder/input_action"] = current_input_action + data_dict["decoder/input_action_valid_mask"] = current_valid_mask + input_action_valid_mask_list.append(current_valid_mask.clone()) + + assert not (current_input_action == START_ACTION).any() + + # Decode motion tokens + data_dict = self.decode_motion(data_dict, use_cache=use_cache) + + if "decoder/modeled_agent_position_history" in data_dict: + assert data_dict["decoder/modeled_agent_position_history"].shape[ + 1] == decode_step + 1 - start_action_step + + output_token = data_dict["decoder/output_logit"] + if use_cache: + assert output_token.shape[:3] == (B, 1, N) + else: + assert output_token.shape[:3] == (B, decode_step + 1, N) + output_token = output_token[:, -1:] # -> output_token.shape == (B, 1, N, #actions) + selected_action, info = sample_action( + logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp + ) + + # if decode_step < autoregressive_start_step: + # # Overwrite the action by GT action + # selected_action = gt_target_action[:, decode_step:decode_step + 1] + + # if self.config.MODEL.RELATIVE_PE_DECODER: + res = tokenizer.detokenize_step( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=current_valid_mask, + current_vel=current_vel, + action=selected_action, + agent_shape=agent_shape, + bin_centers=bin_centers, + + # dt=tokenizer.dt, + dt=-tokenizer.dt, + flip_heading_accordingly=flip_heading_accordingly, + flip_wrong_heading=True + ) + recon_next_pos, recon_next_heading, recon_next_vel, relative_delta_pos = res["pos"], res["heading"], res[ + "vel"], res["delta_pos"] + + current_pos = recon_next_pos.reshape(B, 1, N, 2) + current_heading = recon_next_heading.reshape(B, 1, N) + current_vel = recon_next_vel.reshape(B, 1, N, 2) + current_delta = relative_delta_pos.reshape(B, 1, N, 2) + # current_model_step.fill_(decode_step + 1 - start_action_step) + current_model_step = torch.full_like(current_model_step, decode_step + 1 - 
start_action_step) + current_input_action = selected_action + + # Overwrite the data FOR NEXT STEP by the GT data: + # if decode_step < autoregressive_start_step: + # Always adding new agents + if True: + # decode_step = 0, ..., 18 + forward_current_step = T_chunks - decode_step - 1 + # forward_current_step = 18, ..., 0 + forward_next_step = forward_current_step - 1 + # forward_next_step = 17, ..., 0 + + newly_added = agent_valid_mask[:, forward_next_step:forward_next_step + 1] & (~current_valid_mask) + if newly_added.any(): + current_pos[newly_added] = agent_pos[:, forward_next_step:forward_next_step + 1, + ..., :2][newly_added] + current_heading[newly_added] = agent_heading[:, + forward_next_step:forward_next_step + 1][newly_added] + current_vel[newly_added] = agent_velocity[:, forward_next_step:forward_next_step + 1][newly_added] + current_valid_mask[newly_added] = agent_valid_mask[:, forward_next_step:forward_next_step + + 1][newly_added] + + if self.config.DELTA_POS_IS_VELOCITY: + + current_delta[newly_added] = get_relative_velocity( + agent_velocity[:, forward_next_step:forward_next_step + 1][newly_added], + agent_heading[:, forward_next_step:forward_next_step + 1][newly_added] + ) + + else: + current_delta[newly_added] = _reconstruct_delta_pos_from_abs_vel( + vel=current_vel[newly_added], + + # heading=current_heading[newly_added], + heading=current_heading[newly_added] + np.pi, + dt=tokenizer.dt + ) + + # Overwrite the input action by GT action + current_input_action[newly_added] = gt_input_action[:, decode_step + 1:decode_step + 2][newly_added] + output_token[newly_added] = 0.0 + + pos.append(current_pos.clone()) + head.append(current_heading.clone()) + vel.append(current_vel.clone()) + output_logit_list.append(output_token.clone()) + output_action_list.append(current_input_action.clone()) + + output_action_list = torch.concatenate(output_action_list, dim=1) + assert output_action_list.shape == (B, num_decode_steps - start_action_step, N) + + 
output_logit_list = torch.concatenate(output_logit_list, dim=1) + data_dict["decoder/output_logit"] = output_logit_list + data_dict["decoder/output_action"] = output_action_list + + # FIXME + # FIXME + # FIXME What is the score? + data_dict["decoder/output_score"] = calculate_trajectory_probabilities( + output_logit_list, output_action_list, mask=current_valid_mask + ) # (B, N) + + input_action_valid_mask = torch.cat(input_action_valid_mask_list, dim=1) + data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask + + data_dict["decoder/debug_ar_pos"] = torch.cat(pos, dim=1) + data_dict["decoder/debug_ar_head"] = torch.cat(head, dim=1) + data_dict["decoder/debug_ar_vel"] = torch.cat(vel, dim=1) + + valid_output_action = output_action_list[input_action_valid_mask] + assert valid_output_action.max() < tokenizer.num_actions + assert valid_output_action.min() >= 0 + + # ===== Debug! rewrite output action by GT ===== + # input_dict["decoder/output_action"] = input_dict["decoder/target_action"].clone() + # fill_zero = ((input_dict["decoder/output_action"] == -1) & input_dict["decoder/input_action_valid_mask"]) + # input_dict["decoder/output_action"][fill_zero] = tokenizer.default_action + + return data_dict + + def autoregressive_rollout_backward_prediction_with_replay( + self, + data_dict, + # num_decode_steps, + num_decode_steps=None, + use_cache=True, + sampling_method="softmax", + temperature=None, + topp=None, + flip_heading_accordingly=True, + num_modes_for_eval=None, + not_teacher_forcing_ids=None, + **kwargs + ): + + if self.config.USE_DIFFUSION: + raise ValueError() + + raw_data = data_dict + # To avoid those overwriting operation. 
+ data_dict = copy.deepcopy(data_dict) + + tokenizer = self.tokenizer + + if temperature is None: + temperature = self.config.SAMPLING.TEMPERATURE + if topp is None: + topp = self.config.SAMPLING.TOPP + + B, T_input, N = data_dict["decoder/input_action"].shape + + assert self.config.GPT_STYLE + start_action_step = 0 + assert T_input == 19 + # else: + # start_action_step = 2 + # assert T_input == 17 + # autoregressive_start_step = 2 + + if num_decode_steps is None: + num_decode_steps = 19 + assert start_action_step + T_input == num_decode_steps + assert num_decode_steps == 19 + assert data_dict["decoder/input_action_valid_mask"].shape == (B, T_input, N) + else: + print("WARNING: You are freely generating future trajectory! num_decode_steps (was 19) =", num_decode_steps) + + # ===== Get initial data ===== + + agent_pos = data_dict["decoder/agent_position"][..., :2] #.clone() + agent_heading = data_dict["decoder/agent_heading"] #.clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"] #.clone() + agent_velocity = data_dict["decoder/agent_velocity"] #.clone() + agent_shape = data_dict["decoder/current_agent_shape"] #.clone() + B, T_full, N, _ = agent_pos.shape + # TODO: hardcoded + assert T_full == 91 + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + + agent_pos = agent_pos[:, ::tokenizer.num_skipped_steps] # (B, 19, N, 3) + agent_heading = agent_heading[:, ::tokenizer.num_skipped_steps] # (B, 19, N) + agent_valid_mask = agent_valid_mask[:, ::tokenizer.num_skipped_steps] # (B, 19, N) + agent_velocity = agent_velocity[:, ::tokenizer.num_skipped_steps] # (B, 19, N, 2) + + T_chunks = agent_pos.shape[1] + assert T_chunks == 19 + + # ===== Build up some variables ===== + # Should note that the modeled_agent_* is starting from t=0 (GPT) and t=10 (non-GPT). So using 0:1 to get the + # first step for decoder is correct. 
+ assert not_teacher_forcing_ids is not None + sdc_adv_ids = torch.zeros_like(data_dict["decoder/agent_id"], dtype=torch.bool) + for aid in not_teacher_forcing_ids: + sdc_adv_ids[:, aid] = True + sdc_adv_ids = sdc_adv_ids.reshape(B, 1, N) + assert B == 1, "To avoid the confusion, we only support B=1" + + # ===== Build up some variables ===== + # Should note that the modeled_agent_* is starting from t=0 (GPT) and t=10 (non-GPT). So using 0:1 to get the + # first step for decoder is correct. + current_pos = data_dict["decoder/modeled_agent_position"][:, :1].clone() + current_heading = data_dict["decoder/modeled_agent_heading"][:, :1].clone() + current_vel = data_dict["decoder/modeled_agent_velocity"][:, :1].clone() + current_valid_mask = data_dict["decoder/input_action_valid_mask"][:, :1].clone() + current_delta = data_dict["decoder/modeled_agent_delta"][:, :1].clone() + current_model_step = torch.arange(1).to(current_pos.device) # it's 0 + gt_input_action = data_dict["decoder/input_action"].clone() + gt_target_action = data_dict["decoder/target_action"].clone() + gt_target_valid_mask = data_dict["decoder/target_action_valid_mask"].clone() + current_input_action = gt_input_action[:, :1].clone() + + output_logit_list = [] + output_action_list = [] + output_action_valid_mask = [] + assert use_cache + + import numpy as np + + pos = [] + head = [] + vel = [] + + # Select correct bins: + agent_type = data_dict["decoder/agent_type"] + bin_centers = tokenizer.get_bin_centers(agent_type) + + data_dict = self.encode_scene(data_dict) + data_dict["decoder/randomized_modeled_agent_id"] = self.motion_decoder.randomize_modeled_agent_id( + data_dict, clip_agent_id=True + ) + for decode_step in range(num_decode_steps): + logger.debug(f"======================= STEP {decode_step=} =======================") + + # TODO: put back the following code + # if decode_step == autoregressive_start_step: + # assert ( + # current_valid_mask == agent_valid_mask[:, 
autoregressive_start_step:autoregressive_start_step + 1] + # ).all() + # assert (current_valid_mask == data_dict["decoder/current_agent_valid_mask"][:, None]).all() + + # ===== Fill a lot of stuff ===== + # Overwrite all necessary data: + data_dict["decoder/modeled_agent_position"] = current_pos + data_dict["decoder/modeled_agent_heading"] = current_heading + data_dict["decoder/modeled_agent_velocity"] = current_vel + data_dict["decoder/modeled_agent_valid_mask"] = current_valid_mask + data_dict["decoder/modeled_agent_delta"] = current_delta + data_dict["decoder/input_step"] = current_model_step + data_dict["decoder/input_action"] = current_input_action + data_dict["decoder/input_action_valid_mask"] = current_valid_mask + + assert not (current_input_action == START_ACTION).any() + + # Decode motion tokens + data_dict = self.decode_motion(data_dict, use_cache=use_cache) + + if "decoder/modeled_agent_position_history" in data_dict: + assert data_dict["decoder/modeled_agent_position_history"].shape[ + 1] == decode_step + 1 - start_action_step + + output_token = data_dict["decoder/output_logit"] + if use_cache: + assert output_token.shape[:3] == (B, 1, N) + else: + assert output_token.shape[:3] == (B, decode_step + 1, N) + output_token = output_token[:, -1:] # -> output_token.shape == (B, 1, N, #actions) + selected_action = sample_action( + logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp + ) + + # ===== Teacher-forcing ===== + selected_action[~sdc_adv_ids] = gt_target_action[:, decode_step:decode_step + 1][~sdc_adv_ids] + current_valid_mask[~sdc_adv_ids] = gt_target_valid_mask[:, decode_step:decode_step + 1][~sdc_adv_ids] + + # if self.config.MODEL.RELATIVE_PE_DECODER: + res = tokenizer.detokenize_step( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=current_valid_mask, + current_vel=current_vel, + action=selected_action, + agent_shape=data_dict["decoder/current_agent_shape"], + 
bin_centers=bin_centers, + # dt=tokenizer.dt, + dt=-tokenizer.dt, + flip_heading_accordingly=flip_heading_accordingly, + flip_wrong_heading=True, # TODO: This is a dirty workaround! + ) + recon_next_pos, recon_next_heading, recon_next_vel, relative_delta_pos = res["pos"], res["heading"], res[ + "vel"], res["delta_pos"] + + current_pos = recon_next_pos.reshape(B, 1, N, 2) + current_heading = recon_next_heading.reshape(B, 1, N) + current_vel = recon_next_vel.reshape(B, 1, N, 2) + current_delta = relative_delta_pos.reshape(B, 1, N, 2) + current_model_step = torch.full_like(current_model_step, decode_step + 1 - start_action_step) + current_input_action = selected_action + + # current_model_step.fill_(decode_step + 1 - start_action_step) + # current_input_action = tf_action + + # Overwrite the data FOR NEXT STEP by the GT data: + # if decode_step < autoregressive_start_step: + # Always adding new agents + + # decode_step = 0, ..., 18 + forward_current_step = T_chunks - decode_step - 1 + # forward_current_step = 18, ..., 0 + forward_next_step = forward_current_step - 1 + # forward_next_step = 17, ..., 0 + + if forward_next_step >= 0: + # ===== Teacher-forcing ===== + overwrite_mask = ~sdc_adv_ids + current_pos[overwrite_mask] = agent_pos[:, + forward_next_step:forward_next_step + 1, :, :2][overwrite_mask] + current_heading[overwrite_mask] = agent_heading[:, + forward_next_step:forward_next_step + 1][overwrite_mask] + current_vel[overwrite_mask] = agent_velocity[:, forward_next_step:forward_next_step + 1][overwrite_mask] + + if self.config.DELTA_POS_IS_VELOCITY: + current_delta[overwrite_mask] = tokenizer.get_relative_velocity( + vel=agent_velocity[:, forward_next_step:forward_next_step + 1][overwrite_mask], + heading=agent_heading[:, forward_next_step:forward_next_step + 1][overwrite_mask], + ) + + else: + current_delta[overwrite_mask] = _reconstruct_delta_pos_from_abs_vel( + vel=current_vel[overwrite_mask], + heading=current_heading[overwrite_mask] + np.pi, + 
dt=tokenizer.dt + ) + current_input_action[overwrite_mask] = gt_target_action[:, decode_step:decode_step + 1][overwrite_mask] + current_valid_mask[overwrite_mask] = gt_target_valid_mask[:, + decode_step:decode_step + 1][overwrite_mask] + output_token[overwrite_mask] = 0.0 + + # The output action valid mask should before adding new agents. + output_action_valid_mask.append(current_valid_mask.clone()) + + newly_added = agent_valid_mask[:, forward_next_step:forward_next_step + 1] & (~current_valid_mask) + if newly_added.any(): + current_pos[newly_added] = agent_pos[:, forward_next_step:forward_next_step + 1, ..., :2][newly_added] + current_heading[newly_added] = agent_heading[:, forward_next_step:forward_next_step + 1][newly_added] + current_vel[newly_added] = agent_velocity[:, forward_next_step:forward_next_step + 1][newly_added] + current_valid_mask[newly_added] = agent_valid_mask[:, + forward_next_step:forward_next_step + 1][newly_added] + + if self.config.DELTA_POS_IS_VELOCITY: + + current_delta[newly_added] = get_relative_velocity( + vel=agent_velocity[:, forward_next_step:forward_next_step + 1][newly_added], + heading=agent_heading[:, forward_next_step:forward_next_step + 1][newly_added], + ) + + else: + current_delta[newly_added] = _reconstruct_delta_pos_from_abs_vel( + vel=current_vel[newly_added], + + # heading=current_heading[newly_added], + heading=current_heading[newly_added] + np.pi, + dt=tokenizer.dt + ) + + # Overwrite the input action by GT action + current_input_action[newly_added] = gt_input_action[:, decode_step + 1:decode_step + 2][newly_added] + output_token[newly_added] = 0.0 + + pos.append(current_pos.clone()) + head.append(current_heading.clone()) + vel.append(current_vel.clone()) + output_logit_list.append(output_token.clone()) + output_action_list.append(current_input_action.clone()) + + output_action_list = torch.concatenate(output_action_list, dim=1) + assert output_action_list.shape == (B, num_decode_steps - start_action_step, N) + + 
output_logit_list = torch.concatenate(output_logit_list, dim=1) + data_dict["decoder/output_logit"] = output_logit_list + data_dict["decoder/output_action"] = output_action_list + + # data_dict["decoder/output_score"] = + + # action_valid_mask = data_dict["decoder/target_action_valid_maks"].clone() + # print(111) + output_action_valid_mask = torch.cat(output_action_valid_mask, dim=1) + data_dict["decoder/input_action_valid_mask"] = output_action_valid_mask + + data_dict["decoder/debug_ar_pos"] = torch.cat(pos, dim=1) + data_dict["decoder/debug_ar_head"] = torch.cat(head, dim=1) + data_dict["decoder/debug_ar_vel"] = torch.cat(vel, dim=1) + + # output_action_valid_mask = torch.cat(input_action_valid_mask_list, dim=1) + + valid_output_action = output_action_list[output_action_valid_mask] + assert valid_output_action.max() < tokenizer.num_actions, valid_output_action.max() + assert valid_output_action.min() >= 0, valid_output_action.min() + + assert valid_output_action.max() < END_ACTION, valid_output_action.max() + assert valid_output_action.min() >= 0, valid_output_action.min() + + # ===== Debug! rewrite output action by GT ===== + # input_dict["decoder/output_action"] = input_dict["decoder/target_action"].clone() + # fill_zero = ((input_dict["decoder/output_action"] == -1) & input_dict["decoder/input_action_valid_mask"]) + # input_dict["decoder/output_action"][fill_zero] = tokenizer.default_action + + return data_dict + + def autoregressive_rollout_diffusion( + self, + data_dict, + # num_decode_steps, + num_decode_steps=None, + use_cache=True, + sampling_method="softmax", + temperature=None, + topp=None, + num_modes_for_eval=None, + **kwargs + ): + + assert not ("backward_prediction" in kwargs and kwargs["backward_prediction"]) + assert self.config.USE_DIFFUSION + + # raw_data = data_dict + # To avoid those overwriting operation. 
+ data_dict = copy.deepcopy(data_dict) + + tokenizer = self.tokenizer + + if temperature is None: + temperature = self.config.SAMPLING.TEMPERATURE + if topp is None: + topp = self.config.SAMPLING.TOPP + + B, T_input, N = data_dict["decoder/input_action"].shape + + if self.config.GPT_STYLE: + start_action_step = 0 + assert T_input == 19 + else: + start_action_step = 2 + assert T_input == 17 + autoregressive_start_step = 2 + + if num_decode_steps is None: + num_decode_steps = 19 + assert start_action_step + T_input == num_decode_steps + assert num_decode_steps == 19 + assert data_dict["decoder/input_action_valid_mask"].shape == (B, T_input, N) + else: + print("WARNING: You are freely generating future trajectory! num_decode_steps (was 19) =", num_decode_steps) + + # ===== Get initial data ===== + agent_pos = data_dict["decoder/agent_position"] #.clone() + agent_heading = data_dict["decoder/agent_heading"] #.clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"] #.clone() + agent_velocity = data_dict["decoder/agent_velocity"] #.clone() + agent_shape = data_dict["decoder/current_agent_shape"] #.clone() + B, T_full, N, _ = agent_pos.shape + # TODO: hardcoded + assert T_full == 91 + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + agent_pos = agent_pos[:, ::tokenizer.num_skipped_steps] + agent_heading = agent_heading[:, ::tokenizer.num_skipped_steps] + agent_valid_mask = agent_valid_mask[:, ::tokenizer.num_skipped_steps] + agent_velocity = agent_velocity[:, ::tokenizer.num_skipped_steps] + # gt_agent_delta = data_dict["decoder/modeled_agent_delta"].clone() + T_chunks = agent_pos.shape[1] + assert T_chunks == 19 + + # ===== Build up some variables ===== + # Should note that the modeled_agent_* is starting from t=0 (GPT) and t=10 (non-GPT). So using 0:1 to get the + # first step for decoder is correct. 
+ current_pos = data_dict["decoder/modeled_agent_position"][:, :1, ..., :2].clone() + current_heading = data_dict["decoder/modeled_agent_heading"][:, :1].clone() + current_vel = data_dict["decoder/modeled_agent_velocity"][:, :1].clone() + current_valid_mask = data_dict["decoder/input_action_valid_mask"][:, :1].clone() + current_delta = data_dict["decoder/modeled_agent_delta"][:, :1].clone() + current_model_step = torch.arange(1).to(current_pos.device) # it's 0 + + current_input_agent_motion = data_dict["decoder/input_agent_motion"][:, :1].clone() + + gt_input_action = data_dict["decoder/input_action"].clone() + gt_input_agent_motion = data_dict["decoder/input_agent_motion"].clone() + gt_target_agent_motion = data_dict["decoder/target_agent_motion"].clone() + + current_input_action = gt_input_action[:, :1].clone() + + output_logit_list = [] + output_action_list = [] + output_motion_list = [] + input_action_valid_mask_list = [] + assert use_cache + + pos = [current_pos.clone()] + head = [current_heading.clone()] + vel = [current_vel.clone()] + + # Select correct bins: + agent_type = data_dict["decoder/agent_type"] + # bin_centers = tokenizer.get_bin_centers(agent_type) + + data_dict = self.encode_scene(data_dict) + data_dict["decoder/randomized_modeled_agent_id"] = self.motion_decoder.randomize_modeled_agent_id( + data_dict, clip_agent_id=True + ) + for decode_step in range(0, num_decode_steps): + logger.debug(f"======================= STEP {decode_step=} =======================") + + # if decode_step < start_action_step: + # # For non-gpt model, skip first 2 steps. 
+ # pos.append(agent_pos[:, decode_step:decode_step + 1, ..., :2]) + # head.append(agent_heading[:, decode_step:decode_step + 1]) + # vel.append(agent_velocity[:, decode_step:decode_step + 1]) + # continue + + if decode_step == autoregressive_start_step: + assert ( + current_valid_mask == agent_valid_mask[:, autoregressive_start_step:autoregressive_start_step + 1] + ).all() + assert (current_valid_mask == data_dict["decoder/current_agent_valid_mask"][:, None]).all() + if decode_step == autoregressive_start_step + 1: + current_input_action[current_input_action == SPECIAL_START] = SPECIAL_VALID + + # ===== Fill a lot of stuff ===== + # Overwrite all necessary data: + data_dict["decoder/modeled_agent_position"] = current_pos + data_dict["decoder/modeled_agent_heading"] = current_heading + data_dict["decoder/modeled_agent_velocity"] = current_vel + data_dict["decoder/modeled_agent_valid_mask"] = current_valid_mask + data_dict["decoder/modeled_agent_delta"] = current_delta + data_dict["decoder/input_step"] = current_model_step + data_dict["decoder/input_action"] = current_input_action + data_dict["decoder/input_action_valid_mask"] = current_valid_mask + data_dict["decoder/input_agent_motion"] = current_input_agent_motion + + input_action_valid_mask_list.append(current_valid_mask.clone()) + + assert not (current_input_action == END_ACTION).any() + + assert not self.config.MCTS.USE_MCTS + + # Decode motion tokens + # data_dict = self.decode_motion(data_dict, use_cache=use_cache) + # + # if "decoder/modeled_agent_position_history" in data_dict: + # assert data_dict["decoder/modeled_agent_position_history"].shape[ + # 1] == decode_step + 1 - start_action_step + + # output_token = data_dict["decoder/decoded_tokens"] + # assert use_cache + # assert output_token.shape[:3] == (B, 1, N) + # else: + # assert output_token.shape[:3] == (B, decode_step + 1, N) + # output_token = output_token[:, -1:] # -> output_token.shape == (B, 1, N, #actions) + + # selected_action = 
sample_action( + # logits=output_token, sampling_method=sampling_method, temperature=temperature, topp=topp + # ) + + data_dict = self.motion_decoder.sample_diffusion(data_dict, use_cache=use_cache) + selected_action = data_dict["decoder/output_action"] + if decode_step < autoregressive_start_step: + # Overwrite the action by GT action + selected_action = gt_target_agent_motion[:, decode_step:decode_step + 1] + + # FIXME: DEBUG + # FIXME: DEBUG + # FIXME: DEBUG + # FIXME: DEBUG + # FIXME: DEBUG + # DEBUG + # selected_action = gt_target_agent_motion[:, decode_step:decode_step + 1] + + res = tokenizer.detokenize_step( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=current_valid_mask, + current_vel=current_vel, + action=selected_action, + agent_shape=data_dict["decoder/current_agent_shape"], + # bin_centers=bin_centers, + dt=tokenizer.dt, + flip_wrong_heading=self.config.TOKENIZATION.FLIP_WRONG_HEADING, + agent_type=agent_type, + ) + recon_next_pos = res["pos"] + recon_next_heading = res["heading"] + recon_traj = res["reconstructed_pos"] + recon_traj_heading = res["reconstructed_heading"] + recon_next_vel = res["vel"] + relative_delta_pos = res["delta_pos"] + + current_pos = recon_next_pos.reshape(B, 1, N, 2) + current_heading = recon_next_heading.reshape(B, 1, N) + current_vel = recon_next_vel.reshape(B, 1, N, 2) + current_delta = relative_delta_pos.reshape(B, 1, N, 2) + current_model_step = torch.full_like(current_model_step, decode_step + 1 - start_action_step) + # current_input_action = selected_action + + agent_motion = data_dict["decoder/output_action"] + assert current_input_agent_motion.shape == agent_motion.shape, ( + current_input_agent_motion.shape, agent_motion.shape + ) + current_input_agent_motion = agent_motion + + # Overwrite the data FOR NEXT STEP by the GT data: + if decode_step < autoregressive_start_step: + # current_input_action[current_input_action == SPECIAL_START] = SPECIAL_VALID + newly_added = 
agent_valid_mask[:, decode_step + 1:decode_step + 2] & (~current_valid_mask) + if newly_added.any(): + current_pos[newly_added] = agent_pos[:, decode_step + 1:decode_step + 2, ..., :2][newly_added] + current_heading[newly_added] = agent_heading[:, decode_step + 1:decode_step + 2][newly_added] + current_vel[newly_added] = agent_velocity[:, decode_step + 1:decode_step + 2][newly_added] + current_valid_mask[newly_added] = agent_valid_mask[:, decode_step + 1:decode_step + 2][newly_added] + + current_input_action = gt_input_action[:, decode_step + 1:decode_step + 2] + current_input_agent_motion = gt_input_agent_motion[:, decode_step + 1:decode_step + 2].clone() + + pos.append(recon_traj.clone().permute(0, 1, 3, 2, 4).squeeze(1)) + head.append(recon_traj_heading.clone().permute(0, 1, 3, 2).squeeze(1)) + vel.append(current_vel.clone()) + # output_logit_list.append(output_token.clone()) + output_action_list.append(current_input_action.clone()) + output_motion_list.append(current_input_agent_motion.clone()) + + output_action_list = torch.concatenate(output_action_list, dim=1) + output_motion_list = torch.concatenate(output_motion_list, dim=1) + assert output_action_list.shape == (B, num_decode_steps - start_action_step, N) + + # output_logit_list = torch.concatenate(output_logit_list, dim=1) + # data_dict["decoder/output_logit"] = output_logit_list + # data_dict["decoder/output_action"] = output_action_list + + # data_dict["decoder/output_score"] = calculate_trajectory_probabilities( + # output_logit_list, output_action_list, mask=current_valid_mask + # ) # (B, N) + + input_action_valid_mask = torch.cat(input_action_valid_mask_list, dim=1) + data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask + + data_dict["decoder/reconstructed_position"] = torch.cat(pos, dim=1) + data_dict["decoder/reconstructed_heading"] = torch.cat(head, dim=1) + data_dict["decoder/reconstructed_velocity"] = torch.cat(vel, dim=1) + + valid = input_action_valid_mask.reshape(B, -1, 1, 
N).expand(-1, -1, self.tokenizer.num_skipped_steps, + -1).reshape(B, -1, N) + valid = torch.cat([valid, input_action_valid_mask[:, -1:]], dim=1) + data_dict["decoder/reconstructed_valid_mask"] = valid + + # valid_output_action = output_action_list[input_action_valid_mask] + # assert valid_output_action.max() < tokenizer.num_actions + # assert valid_output_action.min() >= 0 + + return data_dict diff --git a/scenestreamer/models/motionlm_lightning.py b/scenestreamer/models/motionlm_lightning.py new file mode 100644 index 0000000000000000000000000000000000000000..7394379a646ecf429c9f45c693b6c9b9aed3e4f0 --- /dev/null +++ b/scenestreamer/models/motionlm_lightning.py @@ -0,0 +1,863 @@ +import functools +import logging + +import lightning.pytorch as pl +import numpy as np +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf + +from scenestreamer.infer.initial_state import generate_initial_state +from scenestreamer.models.gen_model import GenModel, SceneStreamerModel +from scenestreamer.dataset.preprocessor import NUM_TG_MULTI +from scenestreamer.models.language_motionlm import LanguageMotionLM +from scenestreamer.models.motionlm import MotionLM +from scenestreamer.tokenization import get_tokenizer +from scenestreamer.tokenization.trafficgen_tokenizers import TrafficGenTokenizer +from scenestreamer.utils import lr_schedule +from scenestreamer.utils import utils +from scenestreamer.dataset.preprocessor import slice_trafficgen_data +from scenestreamer.models.scenestreamer_model import get_edge_info_for_scenestreamer + +logger = logging.getLogger(__file__) + + +def update_ema(target_params, source_params, rate=0.99): + """ + PZH: From https://github.com/LTH14/mar/blob/fe470ac24afbee924668d8c5c83e9fec60af3a73/engine_mar.py#L19 + + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. 
+ :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def safe_entropy(logits, epsilon=1e-5): + """ + Computes the entropy of the given logits safely by replacing NaN and Inf values. + :param logits: Input logits tensor. + :param epsilon: A small value to add to the logits to avoid log(0) which results in NaN. + :return: Mean entropy of the logits. + """ + # Replace NaN and Inf values in logits to avoid errors in entropy computation + logits = torch.where(torch.isnan(logits), torch.zeros_like(logits), logits) + logits = torch.where(torch.isinf(logits), torch.zeros_like(logits), logits) + + # Adding a small epsilon to logits to avoid log(0) + logits = logits + epsilon + + # Compute softmax to get probabilities + probs = F.softmax(logits, dim=-1) + + # Compute entropy + entropy = -(probs * torch.log(probs)).sum(-1) + + # Return the mean entropy + return entropy.mean() + + +class MotionLMLightning(pl.LightningModule): + def __init__(self, config): + if "SEED" in config: + pl.seed_everything(config.SEED) + print("Everything is seeded to: ", config.SEED) + super().__init__() + self.config = config + + if config.MODEL.NAME in ["motionlm", "gpt"]: + self.model = MotionLM(config=self.config) + elif config.MODEL.NAME == "gen": + self.model = GenModel(config=self.config) + elif config.MODEL.NAME == "scenestreamer": + # self.model = SceneStreamerModel(config=self.config) + from scenestreamer.models.scenestreamer_model import SceneStreamer + self.model = SceneStreamer(config=self.config) + + elif config.MODEL.NAME == "language_motionlm": + self.model = LanguageMotionLM(config=self.config) + else: + raise ValueError(f"Unknown model name: {config.MODEL.NAME}") + + if config.EVALUATION.NAME in ["waymo_motion_prediction", "waymo_prediction", "womd"]: + from scenestreamer.eval.waymo_motion_prediction_evaluator import WaymoMotionPredictionEvaluator + self.evaluator 
= WaymoMotionPredictionEvaluator(config=config) + # if self.config.EVALUATION.PREDICT_ALL_AGENTS is False: + # assert self.config.PREPROCESSING.ADD_SDC_TO_OBJECT_OF_INTEREST is False + if self.config.SUBMISSION.GENERATE_SUBMISSION: + assert self.config.EVALUATION.PREDICT_ALL_AGENTS is False + assert self.config.PREPROCESSING.ADD_SDC_TO_OBJECT_OF_INTEREST is False + + elif config.EVALUATION.NAME in ["wosac2023", "wosac2024", "sgen"]: + + # Let's overwrite some configs here + # Note that the WOSAC eval code will take care of tracks_to_predict + assert config.EVALUATION.PREDICT_ALL_AGENTS is True + # assert config.PREPROCESSING.ADD_SDC_TO_OBJECT_OF_INTEREST is True + # assert config.EVALUATION.NUM_MODES == 32 + # config.EVALUATION.MAXIMUM_BATCH_SIZE = min(config.EVALUATION.MAXIMUM_BATCH_SIZE, 16) + assert config.DATA.SD_PASSTHROUGH + # config.DATA.SD_PASSTHROUGH = True + + from scenestreamer.eval.waymo_sim_agent_evaluator import WaymoSimAgentEvaluator + self.evaluator = WaymoSimAgentEvaluator(config=config) + elif config.EVALUATION.NAME in ["lmdb"]: + from scenestreamer.eval.lmdb_evaluator import LMDBEvaluator + self.evaluator = LMDBEvaluator(config=config) + + elif config.EVALUATION.NAME in ["peng"]: + from scenestreamer.eval.peng_evaluator import PengEvaluator + self.evaluator = PengEvaluator(config=config) + + else: + raise ValueError(f"Unknown evaluation name: {config.EVALUATION.NAME}") + + self.save_hyperparameters(OmegaConf.to_container(self.config)) + + self._tokenizer = get_tokenizer(self.config) + # self.validation_outputs = [] + # self.validation_ground_truth = [] + + self.exp_name = None + + if self.config.USE_TRAFFICGEN: + self._trafficgen_tokenizer = TrafficGenTokenizer(config) + from scenestreamer.eval.test_trafficgen_eval import TrafficGenEvaluator + self._trafficgen_evaluator = TrafficGenEvaluator(config, device=self.device) + + self.rl_finetuner = None + + def forward(self, batch_dict): + return self.model(batch_dict) + + def get_diffusion_loss(self, 
data_dict): + loss_dict = self.model.motion_decoder.get_diffusion_loss(data_dict) + loss_dict = {k: (v.mean() if isinstance(v, torch.Tensor) else v) for k, v in loss_dict.items()} + loss = loss_dict["loss"] + try: + loss_dict["lr"] = self.trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0] + except RuntimeError: + # When debugging, the model might not be attached to a trainer. + pass + + return loss, loss_dict + + def get_loss(self, data_dict): + if self.config.USE_DIFFUSION: + return self.get_diffusion_loss(data_dict) + + loss_stat = {} + loss = 0.0 + + if self.config.USE_MOTION: + + # Get the decoder's output + output_logit = data_dict["decoder/output_logit"] # (B, T_skipped + 1, N, num_actions) + + # Get the GT actions + target_action = data_dict["decoder/target_action"] # (B, T_skipped, N) + target_action_valid_mask = data_dict["decoder/target_action_valid_mask"] + assert output_logit.shape[:3] == target_action.shape[:3], (output_logit.shape, target_action.shape) + + + # Get loss + if self.config.OPTIMIZATION.USE_FOCAL_LOSS: + raise ValueError + from torchvision.ops import sigmoid_focal_loss + # Compute Focal Loss + alpha = 0.25 + gamma = 2 + target_onehot = F.one_hot(target_action, output_logit.shape[-1]).float() + loss = sigmoid_focal_loss( + inputs=output_logit, targets=target_onehot, alpha=alpha, gamma=gamma, reduction="none" + ) + else: + + if self.config.TOKENIZATION.TOKENIZATION_METHOD == "fast": + B, T_full, N, max_len, num_toks = output_logit.shape + + input_mask = data_dict["decoder/input_action_valid_mask"] + output_mask = data_dict["decoder/target_action_valid_mask"] + assert input_mask.shape == output_mask.shape + assert input_mask.shape == (B, T_full, N) + valid_gt_mask = input_mask & output_mask + + fast_input_token = ( + utils.unwrap(data_dict["fast_input_token"], data_dict["decoder/input_action_valid_mask"], fill=44444) + ) + fast_input_token[fast_input_token == data_dict["fast_pad_token"]] = -1 + # fast_input_valid_mask = 
data_dict["decoder/input_action_valid_mask"] + + # The input tokens should be in shape (B, T, N, max_len) + # At time t, it should already be the sequence of target actions! (this point is easily missed) + target_action = fast_input_token[valid_gt_mask] + masked_logit = output_logit[valid_gt_mask] + + # In fast tokenization, the first token is always the SOS token so remove them. + # Note that the target_action and output_logit are already in 2D/3D. + target_action = target_action[:, 1:] + masked_logit = masked_logit[:, :-1] + + assert not (target_action == 44444).any() + assert not (masked_logit == 0).all(-1).any() + + target_action_neg1_mask = target_action != -1 + target_action = target_action[target_action_neg1_mask] + + assert (target_action == data_dict["fast_eos_token"]).any() + + masked_logit = masked_logit[target_action_neg1_mask] + + # rate_777 = (masked_logit.argmax(dim=-1) == 777).float().mean() + # print("RATE 777: ", rate_777) + + loss = torch.nn.functional.cross_entropy(input=masked_logit, target=target_action, reduction="none") + + output_logit = masked_logit + + else: + + output_logit = output_logit[target_action_valid_mask] + target_action = target_action[target_action_valid_mask] + + loss = torch.nn.functional.cross_entropy(input=output_logit, target=target_action, reduction="none") + + original_loss = loss + loss = loss.mean() + + assert not np.isnan(loss.item()) + assert not np.isinf(loss.item()) + + with torch.no_grad(): + encodings = F.one_hot(output_logit.argmax(-1), + output_logit.shape[-1]).float().reshape(-1, output_logit.shape[-1]) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + + gt_onehot = F.one_hot(target_action, output_logit.shape[-1]).float() + gt_encodings = gt_onehot.reshape(-1, output_logit.shape[-1]) + gt_avg_probs = gt_encodings.mean(0) + gt_perplexity = (-(gt_avg_probs * torch.log(gt_avg_probs + 1e-10)).sum()).exp() + 
gt_cluster_use = torch.sum(gt_avg_probs > 0) + debug_gt_c_use = (gt_encodings.sum(0) > 0).sum() # .mean() + + pred_act = output_logit.argmax(-1) + acc = torch.sum(pred_act == target_action) / target_action.shape[0] + entropy = safe_entropy(output_logit) + pred_act = pred_act.float() + + rate_default_pred = (pred_act == self._tokenizer.default_action).float().mean() + rate_default_gt = (target_action == self._tokenizer.default_action).float().mean() + + num_trained_tokens = len(target_action) + num_trained_tokens_sum = self.trainer.world_size * num_trained_tokens + + # print("ACCURACY: ", acc, "ENTROPY: ", entropy.mean()) + + loss_stat.update( + { + "original_loss": original_loss.mean(), + "accuracy": acc, + "entropy": entropy.mean(), + "avg_action": pred_act.mean(), + "max_action": pred_act.max(), + "min_action": pred_act.min(), + "perplexity": perplexity, + "gt_perplexity": gt_perplexity, + "cluster_use": cluster_use, + "gt_cluster_use": gt_cluster_use, + "rate_84": rate_default_gt, + "rate_default_gt": rate_default_gt, + "rate_default_pred": rate_default_pred, + "num_trained_tokens": num_trained_tokens, + "num_trained_tokens_sum": num_trained_tokens_sum, + "toks": num_trained_tokens_sum, + } + ) + + if self.config.BACKWARD_PREDICTION: + in_back_mask = data_dict["in_backward_prediction"] + in_back_mask = in_back_mask.reshape(-1, 1, 1).expand(*target_action_valid_mask.shape) + in_back_mask = in_back_mask[target_action_valid_mask] + acc2 = (pred_act == target_action) + acc_in_back = (acc2 & in_back_mask).sum() / in_back_mask.sum() + acc_in_forward = (acc2 & ~in_back_mask).sum() / (~in_back_mask).sum() + loss_in_back = original_loss[in_back_mask].mean() + loss_in_forward = original_loss[~in_back_mask].mean() + entropy_in_back = safe_entropy(output_logit[in_back_mask]).mean() + entropy_in_forward = safe_entropy(output_logit[~in_back_mask]).mean() + loss_stat.update( + { + "accuracy_in_backward": acc_in_back, + "accuracy_in_forward": acc_in_forward, + 
"loss_in_backward": loss_in_back, + "loss_in_forward": loss_in_forward, + "entropy_in_backward": entropy_in_back, + "entropy_in_forward": entropy_in_forward, + "backward_ratio": in_back_mask.float().mean(), + } + ) + + if self.config.RECONSTRUCT_MAP: + gt_map_feat = data_dict["encoder/map_feature"] + map_feat_valid_mask = data_dict["encoder/map_valid_mask"] + polypoint_valid_mask = data_dict["encoder/map_feature_valid_mask"] + polypoint_valid_mask = polypoint_valid_mask[map_feat_valid_mask] # (valid points, 128) + map_feat = gt_map_feat[map_feat_valid_mask] # (num_valid_map_features, 128, 27) + polypoint = map_feat[:, :, :2] # (valid map feat, 128, 2) + num_points = polypoint.shape[1] + gt_valid_mask = polypoint_valid_mask.unsqueeze(-1).expand_as(polypoint) + gt = torch.where(gt_valid_mask, polypoint, torch.zeros_like(polypoint)) + gt_valid_mask = gt_valid_mask.reshape(-1, num_points * 2) + gt = gt.reshape(-1, num_points * 2) + map_token = data_dict["encoder/map_token"] + out = self.model.map_recon_head(self.model.map_recon_head_prenorm(map_token[map_feat_valid_mask])) + + # out.shape = (num_valid_map_features, 128 * 2) + map_recon_loss = torch.nn.functional.mse_loss(out, gt, reduction="none") + map_recon_loss = map_recon_loss[gt_valid_mask] + map_recon_loss = map_recon_loss.mean() + + loss += map_recon_loss + loss_stat["map_recon_loss"] = map_recon_loss + loss_stat["map_recon_mask_rate"] = gt_valid_mask.float().mean() + + # DEBUG CODE to find unused parameters: + # gs = torch.autograd.grad(loss.mean(), self.parameters(), allow_unused=True, retain_graph=True) + # ns = [n for n, v in self.named_parameters()] + # printed = False + # for c, g in enumerate(gs): + # if g is None: + # print(ns[c]) + # printed = True + # if not printed: + # print("No unused parameters found.") + + if (self.config.USE_TRAFFICGEN and (self.config.TRAIN_TRAFFICGEN is True or self.config.TRAIN_TRAFFICGEN is None)): + data_dict = self.model.trafficgen_decoder.forward(data_dict) + + 
tg_gt_action = data_dict["decoder/input_action_for_trafficgen"][:, 1:] + + tg_gt_mask = data_dict["decoder/input_action_valid_mask_for_trafficgen"][:, 1:] + tg_gt = tg_gt_action[tg_gt_mask] + tg_logit = data_dict["decoder/output_logit_for_trafficgen"][:, :-1][tg_gt_mask] + tg_loss = torch.nn.functional.cross_entropy(input=tg_logit, target=tg_gt, reduction="none") + + tg_loss = tg_loss.mean() + loss += tg_loss + tg_accuracy = torch.sum(tg_logit.argmax(-1) == tg_gt) / tg_gt.shape[0] + loss_stat.update( + { + "trafficgen_loss": tg_loss, + "trafficgen_accuracy": tg_accuracy, + "trafficgen_entropy": safe_entropy(tg_logit).mean(), + } + ) + + # current_input_action[:, 0] is the START_ACTION, so need to skip it. + tg_gt_offset_mask = tg_gt_mask & (tg_gt_action != self._trafficgen_tokenizer.start_action_id + ) & (tg_gt_action != self._trafficgen_tokenizer.end_action_id) + + gt_agent_type = data_dict["decoder/agent_type_for_trafficgen"][:, 1:] + agent_type_output = self.model.trafficgen_decoder.forward_agent_type(data_dict, action=tg_gt_action) + agent_type_loss = torch.nn.functional.cross_entropy( + input=agent_type_output[tg_gt_offset_mask], target=gt_agent_type[tg_gt_offset_mask], reduction="mean" + ) + loss += agent_type_loss + + offset_output = self.model.trafficgen_decoder.forward_offset( + data_dict, action=tg_gt_action, agent_type=gt_agent_type + ) + + for kid, k in enumerate(["position_x", "position_y", "heading", "velocity_x", "velocity_y", "length", + "width", "height"]): + tg_gt_offset = data_dict["decoder/target_offset_for_trafficgen"][:, :, kid] + tg_logit = offset_output[k] + + tg_logit = tg_logit[tg_gt_offset_mask] + tg_gt_offset = tg_gt_offset[tg_gt_offset_mask] + tg_offset_loss = torch.nn.functional.cross_entropy( + input=tg_logit, target=tg_gt_offset, reduction="mean" + ) + loss += tg_offset_loss + tg_accuracy = torch.sum(tg_logit.argmax(-1) == tg_gt_offset) / tg_gt_offset.shape[0] + loss_stat.update({ + f"trafficgen_loss_{k}": tg_loss, + 
f"trafficgen_accuracy_{k}": tg_accuracy, + }) + + loss_stat["total_loss"] = loss + try: + loss_stat["lr"] = self.trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0] + except RuntimeError: + # When debugging, the model might not be attached to a trainer. + pass + return loss, loss_stat + + def get_loss_for_scenestreamer(self, data_dict): + + def _safe_cross_entropy(input, target, reduction="mean"): + assert input.shape[:-1] == target.shape, (input.shape, target.shape) + assert target.min() >= 0, (target.min(), target.max()) + assert input.shape[-1] > target.max(), (input.shape, target.shape) + return torch.nn.functional.cross_entropy(input=input, target=target, reduction=reduction) + + def _accuracy(input, target): + assert input.ndim == 2 + assert target.ndim == 1 + pred_act = input.argmax(-1) + acc = torch.sum(pred_act == target) / target.shape[0] + return acc + + loss_stat = {} + loss = 0.0 + + + # ===== motion loss ===== + output_logit = data_dict["model/motion_logit"] # (B, T_skipped + 1, N, num_actions) + # Get the GT actions + target_action = data_dict["decoder/target_action"] # (B, T_skipped, N) + target_action_valid_mask = data_dict["decoder/target_action_valid_mask"] + assert output_logit.shape[:3] == target_action.shape[:3], (output_logit.shape, target_action.shape) + output_logit = output_logit[target_action_valid_mask] + target_action = target_action[target_action_valid_mask] + motion_loss = _safe_cross_entropy(input=output_logit, target=target_action) + loss += motion_loss + loss_stat["motion_loss"] = motion_loss + motion_accuracy = _accuracy(input=output_logit, target=target_action) + loss_stat["motion_accuracy"] = motion_accuracy + + # ===== trafficgen loss ===== + tg_gt_action = data_dict["decoder/input_action_for_trafficgen"] + # tg_gt_mask = data_dict["decoder/input_action_valid_mask_for_trafficgen"] + B, T, N = data_dict["decoder/target_action"].shape + tg_gt_action = slice_trafficgen_data(tg_gt_action[:, :, 1:-1].reshape(B, T, N, 
NUM_TG_MULTI), dim=1) + agent_valid_mask = data_dict["decoder/input_action_valid_mask"] + + if self.model.no_tg: + pass + + else: + tg_agent_valid_mask = slice_trafficgen_data(agent_valid_mask, dim=1) + agent_type_logit = data_dict["model/trafficgen_agent_type_logit"] + agent_type_gt = tg_gt_action[..., 1] + agent_type_gt = agent_type_gt[tg_agent_valid_mask] + assert agent_type_gt.min() != -1 + agent_type_gt[agent_type_gt == self.model.veh_id] = 0 + agent_type_gt[agent_type_gt == self.model.ped_id] = 1 + agent_type_gt[agent_type_gt == self.model.cyc_id] = 2 + agent_type_input = agent_type_logit[tg_agent_valid_mask] + agent_type_loss = _safe_cross_entropy( + input=agent_type_input, + target=agent_type_gt, + ) + loss += agent_type_loss + loss_stat["trafficgen_agent_type_loss"] = agent_type_loss + loss_stat["trafficgen_agent_type_accuracy"] = _accuracy(input=agent_type_input, target=agent_type_gt) + + map_id_logit = data_dict["model/trafficgen_map_id_logit"] + map_id_gt = tg_gt_action[:, :, :, 2] + assert map_id_gt.shape[:3] == map_id_logit.shape[:3] + assert map_id_gt[tg_agent_valid_mask].min()>=0 + map_id_input = map_id_logit[tg_agent_valid_mask] + map_id_target = map_id_gt[tg_agent_valid_mask] + map_id_loss = _safe_cross_entropy( + input=map_id_input, + target=map_id_target + ) + loss += map_id_loss + loss_stat["trafficgen_map_id_loss"] = map_id_loss + loss_stat["trafficgen_map_id_accuracy"] = _accuracy(input=map_id_input, target=map_id_target) + + agent_state_logit = data_dict["model/trafficgen_agent_state_logit"][..., :-1, :] + agent_state_gt = slice_trafficgen_data(data_dict["decoder/target_offset_for_trafficgen"], dim=1) + assert agent_state_logit.shape[:4] == agent_state_gt.shape + agent_state_loss = _safe_cross_entropy( + input=agent_state_logit[tg_agent_valid_mask].flatten(0, 1), + target=agent_state_gt[tg_agent_valid_mask].flatten(), + ) + loss += agent_state_loss + loss_stat["trafficgen_agent_state_loss"] = agent_state_loss + + # dest_id_logit = 
data_dict["model/trafficgen_dest_id_logit"] + # dest_valid_mask = slice_trafficgen_data(data_dict["decoder/dest_map_index_valid_mask"], dim=1) + # # raise error if some agent_valid_mask is False but dest_valid_mask is True + # assert (dest_valid_mask & ~tg_agent_valid_mask).sum() == 0 + # dest_valid_mask = dest_valid_mask & tg_agent_valid_mask + # dest_id_gt = slice_trafficgen_data(data_dict["decoder/dest_map_index_gt"], dim=1) + # # assert (dest_id_gt == tg_gt_action[..., 4]).all() # It's normal that they are not aligned. + # assert dest_id_gt.shape[:3] == dest_id_logit.shape[:3] + # dest_id_gt = dest_id_gt[dest_valid_mask] + # assert dest_id_gt.min() >= 0 + # dest_id_logit_input = dest_id_logit[dest_valid_mask] + # dest_id_loss = _safe_cross_entropy( + # input=dest_id_logit_input, + # target=dest_id_gt, + # ) + # if self.config.PREPROCESSING.DEST_DROPOUT >= 1.0: + # loss += dest_id_loss * 0.0 + # else: + # loss += dest_id_loss + # loss_stat["trafficgen_dest_id_loss"] = dest_id_loss + # loss_stat["trafficgen_dest_id_accuracy"] = _accuracy(input=dest_id_logit_input, target=dest_id_gt) + # no_pad_mask = dest_id_gt != self.model.trafficgen_sequence_pad_id + # loss_stat["trafficgen_dest_id_accuracy_no_pad"] = _accuracy(input=dest_id_logit_input[no_pad_mask], target=dest_id_gt[no_pad_mask]) + + # ===== traffic light loss ===== + traffic_light_logit = data_dict["model/traffic_light_logit"] + traffic_light_gt = data_dict["encoder/traffic_light_state"] + traffic_light_mask = data_dict["encoder/traffic_light_valid_mask"] + if traffic_light_mask.any(): + traffic_light_loss = _safe_cross_entropy( + input=traffic_light_logit[traffic_light_mask], + target=traffic_light_gt[traffic_light_mask], + ) + loss += traffic_light_loss + traffic_light_accuracy = _accuracy( + input=traffic_light_logit[traffic_light_mask], + target=traffic_light_gt[traffic_light_mask], + ) + loss_stat["traffic_light_accuracy"] = traffic_light_accuracy + else: + traffic_light_loss = _safe_cross_entropy( + 
input=traffic_light_logit.flatten(0, 2)[:1], + target=traffic_light_gt.flatten()[:1], + reduction="mean" + ) * 0.0 + loss += traffic_light_loss + loss_stat["traffic_light_accuracy"] = np.nan + loss_stat["traffic_light_loss"] = traffic_light_loss + + # DEBUG CODE to find unused parameters: + # gs = torch.autograd.grad(loss.mean(), self.parameters(), allow_unused=True, retain_graph=True) + # ns = [n for n, v in self.named_parameters()] + # printed = False + # for c, g in enumerate(gs): + # if g is None: + # print(ns[c]) + # printed = True + # if not printed: + # print("No unused parameters found.") + + # List parameter name and the gradient: + # gs = torch.autograd.grad(loss.mean(), self.parameters(), allow_unused=True, retain_graph=True) + # gs = {name: g for (name, _), g in zip(self.named_parameters(), gs)} + + loss_stat["total_loss"] = loss + try: + loss_stat["lr"] = self.trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0] + except RuntimeError: + # When debugging, the model might not be attached to a trainer. + pass + return loss, loss_stat + + def training_step_rl_finetuning(self, data_dict, batch_idx): + assert self.config.SCENESTREAMER_NO_TG is True + assert self.config.USE_RL_FINETUNING is True + + if self.rl_finetuner is None: + from scenestreamer.rl_finetuning import RLFinetuner + self.rl_finetuner = RLFinetuner(model=self.model, all_gather=self.all_gather) + + loss, loss_stat = self.rl_finetuner.get_loss(data_dict) + try: + loss_stat["lr"] = self.trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0] + except RuntimeError: + # When debugging, the model might not be attached to a trainer. 
+ pass + + pbar_keys = ("total_loss", "toks", "lr") + motion_stat = {k: v for k, v in loss_stat.items() if k.startswith("motion_stat")} + loss_stat = {k: v for k, v in loss_stat.items() if not k.startswith("motion_stat")} + self.log_dict( + {f"{k}": float(v) + for k, v in loss_stat.items() if k in pbar_keys}, + batch_size=data_dict["encoder/map_feature"].shape[0], + prog_bar=True, + ) + if motion_stat: + self.log_dict( + {f"{k}": float(v) + for k, v in motion_stat.items()}, + batch_size=data_dict["encoder/map_feature"].shape[0], + prog_bar=False, + ) + self.log_dict( + {f"train/{k}": float(v) + for k, v in loss_stat.items()}, + batch_size=data_dict["encoder/map_feature"].shape[0], + # on_epoch=True, + prog_bar=False, + ) + self.log('monitoring_step', float(self.global_step)) + return loss + + def training_step(self, data_dict, batch_idx): + + if self.config.USE_RL_FINETUNING: + return self.training_step_rl_finetuning(data_dict, batch_idx) + + # For profiling GPU usage. + # torch.cuda.empty_cache() + + # print("RANK {} SCENARIO ID {} START".format(self.global_rank, data_dict["scenario_id"])) + + data_dict = self(data_dict) + + if self.config.MODEL.NAME == "scenestreamer": + loss, loss_stat = self.get_loss_for_scenestreamer(data_dict) + + else: + loss, loss_stat = self.get_loss(data_dict) + + pbar_keys = ("total_loss", "toks", "lr") + + motion_stat = {k: v for k, v in loss_stat.items() if k.startswith("motion_stat")} + loss_stat = {k: v for k, v in loss_stat.items() if not k.startswith("motion_stat")} + + self.log_dict( + {f"{k}": float(v) + for k, v in loss_stat.items() if k in pbar_keys}, + batch_size=data_dict["encoder/map_feature"].shape[0], + prog_bar=True, + ) + if motion_stat: + self.log_dict( + {f"{k}": float(v) + for k, v in motion_stat.items()}, + batch_size=data_dict["encoder/map_feature"].shape[0], + prog_bar=False, + ) + self.log_dict( + {f"train/{k}": float(v) + for k, v in loss_stat.items()}, + batch_size=data_dict["encoder/map_feature"].shape[0], + # 
on_epoch=True, + prog_bar=False, + ) + self.log('monitoring_step', float(self.global_step)) + + # print("RANK {} SCENARIO ID {} END".format(self.global_rank, data_dict["scenario_id"])) + + return loss + + def optimizer_step(self, *args, **kwargs): + super().optimizer_step(*args, **kwargs) + # do something on_after_optimizer_step + + # if self.config.USE_DIFFUSION: + # self.model.motion_decoder.update_diffusion_step() + + def on_validation_start(self): + torch.cuda.empty_cache() + + def validation_step(self, data_dict, batch_idx): + + if self.config.EVAL_MOTION: + + if data_dict["encoder/map_valid_mask"].shape[1] == 0: + sid = data_dict["scenario_id"] + print("Warning: Empty map_valid_mask found for scenario: ", sid) + logger.error(f"Empty map_valid_mask found for scenario: {sid}") + return + + try: + self.evaluator.validation_step( + data_dict, + batch_idx, + model=self.model, + global_rank=self.global_rank, + trainer=self.trainer, + logger=self.logger, + log_func=self.log, + log_dict_func=self.log_dict, + print_func=self.print, + lightning_model=self, + ) + except Exception as error: + scenario_ids = data_dict["scenario_id"] + rank = self.global_rank + msg = f"Error in validation_step: {batch_idx=}, {scenario_ids=}, {rank=}, {error=}" + print(msg) + raise RuntimeError(msg) from error + + if self.config.EVAL_TRAFFICGEN: + + if self.config.MODEL.NAME == "scenestreamer": + if not hasattr(self, "scenestreamer_generator"): + from scenestreamer.infer.scenestreamer_generator import SceneStreamerGenerator + self.scenestreamer_generator = SceneStreamerGenerator(model=self.model, device=self.device) + with torch.no_grad(): + self.scenestreamer_generator.reset(new_data_dict=data_dict) + output_data_dict = self.scenestreamer_generator.generate_scenestreamer_initial_state(progress_bar=True) + data_dict.update(output_data_dict) + stat = {} + + else: + assert self.config.USE_TRAFFICGEN + + # data_dict = self.model.encode_scene(data_dict) + # data_dict, stat = 
self.model.trafficgen_decoder.autoregressive_rollout_trafficgen(data_dict) + + data_dict, stat = generate_initial_state(model=self.model, data_dict=data_dict, force_add=True) + + # import matplotlib.pyplot as plt + # pos_pred = data_dict["decoder/modeled_agent_position_for_trafficgen"][0][1:] + # pred_mask = data_dict["decoder/input_action_valid_mask_for_trafficgen"][0][1:] + # pos_target = data_dict["decoder/agent_position"][0, 0] + # gt_mask = data_dict["decoder/agent_valid_mask"][0, 0] + # plt.figure() + # plt.scatter(pos_pred[pred_mask][:, 0].cpu().numpy(), pos_pred[pred_mask][:, 1].cpu().numpy(), c='r') + # plt.scatter(pos_target[gt_mask][:, 0].cpu().numpy(), pos_target[gt_mask][:, 1].cpu().numpy(), c='b') + # from scenestreamer.gradio_ui.plot import _plot_map + # _plot_map({k: v[0].cpu().numpy() for k, v in data_dict.items() if isinstance(v, torch.Tensor)}, plt.gca()) + # plt.gca().set_aspect('equal', adjustable='box') + # # plt.title(f"mmd_pos={mmd_pos.item()}") + # plt.show() + + self._trafficgen_evaluator.validation_step( + data_dict, + stat, + model=self.model, + global_rank=self.global_rank, + trainer=self.trainer, + logger=self.logger, + log_func=functools.partial(self.log, sync_dist=False), + log_dict_func=self.log_dict, + print_func=self.print, + lightning_model=self, + ) + + def on_validation_epoch_end(self): + """ + This function gathers intermediate evaluation result and pass them to the Waymo + evaluation pipeline together and log the final results. 
+ """ + if self.config.EVAL_MOTION: + self.log("monitoring_step", float(self.global_step)) + self.evaluator.on_validation_epoch_end( + global_rank=self.global_rank, + trainer=self.trainer, + logger=self.logger, + log_func=self.log, + log_dict_func=self.log_dict, + print_func=self.print, + exp_name=self.exp_name, + ) + + # if self.config.EVAL_TRAFFICGEN: + # import functools + # self._trafficgen_evaluator.on_validation_epoch_end( + # global_rank=self.global_rank, + # trainer=self.trainer, + # logger=self.logger, + # log_func=functools.partial(self.log, sync_dist=True), + # log_dict_func=self.log_dict, + # print_func=self.print, + # exp_name=self.exp_name, + # ) + + def configure_optimizers(self): + """Required by Lightning.""" + opt_cfg = self.config.OPTIMIZATION + + if opt_cfg.OPTIMIZER == 'Adam': + # optimizer = torch.optim.Adam( + # [each[1] for each in self.named_parameters()], + # lr=opt_cfg.LR, + # weight_decay=opt_cfg.get('WEIGHT_DECAY', 0) + # ) + raise ValueError() + elif opt_cfg.OPTIMIZER == 'AdamW': + optimizer = torch.optim.AdamW( + self.parameters(), + lr=opt_cfg.LR, + weight_decay=opt_cfg.get('WEIGHT_DECAY', 0), + betas=(0.9, 0.95), + eps=1e-5 + ) + else: + assert False + + if opt_cfg.get('SCHEDULER', None) == 'cosine': + + utils.rank_zero_print("=====================================") + if self.trainer.train_dataloader is not None: + num_steps_per_epoch = len(self.trainer.train_dataloader) + elif self.trainer.datamodule is not None and self.trainer.datamodule.train_dataset is not None: + utils.rank_zero_print( + "Finding num_steps_per_epoch from datamodule...", len(self.trainer.datamodule.train_dataset), + self.trainer.datamodule.train_batch_size, self.trainer.world_size + ) + num_steps_per_epoch = len(self.trainer.datamodule.train_dataset + ) // (self.trainer.datamodule.train_batch_size * self.trainer.world_size) + else: + raise ValueError("Can't find num_steps_per_epoch") + + num_epochs = self.config.epochs + total_steps = num_steps_per_epoch * 
num_epochs + utils.rank_zero_print("Configuring cosine scheduler") + utils.rank_zero_print("Num Steps per epoch: ", num_steps_per_epoch) + utils.rank_zero_print("Num Epochs: ", num_epochs) + utils.rank_zero_print("Total Steps: ", total_steps) + utils.rank_zero_print("=====================================") + + scheduler = lr_schedule.get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=opt_cfg.WARMUP_STEPS, + num_training_steps=total_steps, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step" + }, + } + + elif opt_cfg.get('SCHEDULER', None) == 'lambdaLR': + raise ValueError() + # def lr_lbmd(cur_epoch): + # cur_decay = 1 + # for decay_step in opt_cfg.get('DECAY_STEP_LIST', [5, 10, 15, 20]): + # if cur_epoch >= decay_step: + # cur_decay = cur_decay * opt_cfg.LR_DECAY + # return max(cur_decay, opt_cfg.LR_CLIP / opt_cfg.LR) + # + # scheduler = LambdaLR(optimizer, lr_lbmd) + + elif opt_cfg.get('SCHEDULER', None) == 'linear': + raise ValueError() + scheduler = lr_schedule.get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=opt_cfg.WARMUP_STEPS, + num_training_steps=opt_cfg.TRAINING_STEPS, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step" + }, + } + + elif opt_cfg.get('SCHEDULER', None) == 'inverse_sqrt': + scheduler = lr_schedule.get_inverse_sqrt_schedule( + optimizer, + num_warmup_steps=opt_cfg.WARMUP_STEPS, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step" + }, + } + + else: + raise ValueError() diff --git a/scenestreamer/models/ops/__init__.py b/scenestreamer/models/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenestreamer/models/ops/attention/__init__.py b/scenestreamer/models/ops/attention/__init__.py new file mode 100644 index 
class AttentionWeightComputation(Function):
    """Autograd function computing local attention weights with a CUDA kernel.

    Inputs:
        * attention pair index (total_query_num, local_size);
        * query features (total_query_num, nhead, hdim);
        * key features (total_key_num, nhead, hdim).
    Output:
        * attention weights (total_query_num, local_size, nhead), float32.

    Only the float32 kernel of ``attention_cuda`` is invoked. float16 and
    bfloat16 features are up-cast to float32 before the call, so the output
    and the returned gradients are always float32 tensors.
    """
    @staticmethod
    def forward(
        ctx, query_batch_cnt: torch.Tensor, key_batch_cnt: torch.Tensor, index_pair_batch: torch.Tensor,
        index_pair: torch.Tensor, query_features: torch.Tensor, key_features: torch.Tensor
    ):
        """
        :param ctx: autograd context; forward arguments are stashed on it for backward.
        :param query_batch_cnt: A integer tensor with shape [bs], indicating the query amount for each batch.
        :param key_batch_cnt: A integer tensor with shape [bs], indicating the key amount of each batch.
        :param index_pair_batch: A integer tensor with shape [total_query_num], indicating the batch
            index of each query.
        :param index_pair: A integer tensor with shape [total_query_num, local_size]
            We ignore those index whose value is -1.
        :param query_features: A float tensor with shape [total_query_num, nhead, hdim]
        :param key_features: A float tensor with shape [total_key_num, nhead, hdim]
        :return:
            output: A float32 tensor with shape [total_query_num, local_size, nhead]
        """
        assert query_batch_cnt.is_contiguous()
        assert key_batch_cnt.is_contiguous()
        assert index_pair_batch.is_contiguous()
        assert index_pair.is_contiguous()
        assert query_features.is_contiguous()
        assert key_features.is_contiguous()

        b = query_batch_cnt.shape[0]
        total_query_num, local_size = index_pair.size()
        total_key_num, nhead, hdim = key_features.size()

        # Need to ensure that every tensor in query features have an output.
        assert total_query_num == query_features.shape[0]

        output = torch.zeros([total_query_num, local_size, nhead], dtype=torch.float32, device=query_features.device)

        # The kernel reads raw float32 memory, so half-precision inputs are
        # up-cast first. Other dtypes are passed through untouched.
        if query_features.dtype in (torch.float16, torch.bfloat16):
            kernel_query = query_features.type(torch.float32)
            kernel_key = key_features.type(torch.float32)
        else:
            kernel_query, kernel_key = query_features, key_features

        attention_cuda.attention_weight_computation_wrapper_v2(
            b, total_query_num, local_size, total_key_num, nhead, hdim, query_batch_cnt, key_batch_cnt,
            index_pair_batch, index_pair, kernel_query, kernel_key, output
        )

        # The original (un-cast) features are kept so backward can re-detect
        # the input dtype.
        ctx.for_backwards = (
            b, total_query_num, local_size, total_key_num, nhead, hdim, query_batch_cnt, key_batch_cnt,
            index_pair_batch, index_pair, query_features, key_features
        )
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """
        Args:
            ctx: autograd context populated by ``forward``.
            grad_out: [total_query_num, local_size, nhead]
        Returns:
            One entry per forward input; the four index/count tensors get
            ``None``, followed by
            grad_query_features: [total_query_num, nhead, hdim] (float32)
            grad_key_features: [total_key_num, nhead, hdim] (float32)
        """
        (
            b, total_query_num, local_size, total_key_num, nhead, hdim, query_batch_cnt, key_batch_cnt,
            index_pair_batch, index_pair, query_features, key_features
        ) = ctx.for_backwards

        # Allocate on the inputs' device. The legacy torch.cuda.FloatTensor
        # constructor used here before always targeted the *current* CUDA
        # device, which is wrong when the features live on another GPU.
        grad_query_features = torch.zeros(
            (total_query_num, nhead, hdim), dtype=torch.float32, device=query_features.device
        )
        grad_key_features = torch.zeros(
            (total_key_num, nhead, hdim), dtype=torch.float32, device=query_features.device
        )

        if query_features.dtype in (torch.float16, torch.bfloat16):
            # Mirror forward: run the float32 kernel on up-cast copies.
            grad_out = grad_out.type(torch.float32)
            kernel_query = query_features.type(torch.float32)
            kernel_key = key_features.type(torch.float32)
        else:
            kernel_query, kernel_key = query_features, key_features

        grad_out_data = grad_out.detach().contiguous()
        # The kernel writes into the two gradient buffers in place.
        attention_cuda.attention_weight_computation_grad_wrapper_v2(
            b, total_query_num, local_size, total_key_num, nhead, hdim, query_batch_cnt, key_batch_cnt,
            index_pair_batch, index_pair, kernel_query, kernel_key, grad_out_data, grad_query_features,
            grad_key_features
        )

        return None, None, None, None, grad_query_features, grad_key_features


attention_weight_computation = AttentionWeightComputation.apply
class AttentionValueComputation(Function):
    """Autograd function aggregating value features with attention weights.

    Inputs:
        * attention pair index (total_query_num, local_size);
        * value features (total_key_num, nhead, hdim);
        * attn_weight (total_query_num, local_size, nhead).
    Output:
        * attention result (total_query_num, nhead, hdim), float32.

    Only the float32 kernel of ``attention_cuda`` is called; float16 and
    bfloat16 inputs are up-cast to float32 first.
    """
    @staticmethod
    def forward(
        ctx, query_batch_cnt: torch.Tensor, key_batch_cnt: torch.Tensor, index_pair_batch: torch.Tensor,
        index_pair: torch.Tensor, attn_weight: torch.Tensor, value_features: torch.Tensor
    ):
        """
        :param ctx: autograd context; forward arguments are stashed on it for backward.
        :param query_batch_cnt: A integer tensor with shape [bs], indicating the query amount for each batch.
        :param key_batch_cnt: A integer tensor with shape [bs], indicating the key amount of each batch.
        :param index_pair_batch: A integer tensor with shape [total_query_num], indicating the batch
            index of each query.
        :param index_pair: A integer tensor with shape [total_query_num, local_size]
            We ignore those index whose value is -1.
        :param attn_weight: A float tensor with shape [total_query_num, local_size, nhead]
        :param value_features: A float tensor with shape [total_key_num, nhead, hdim]
        :return:
            output: A float32 tensor with shape [total_query_num, nhead, hdim]
        """
        for tensor in (query_batch_cnt, key_batch_cnt, index_pair_batch, index_pair, attn_weight, value_features):
            assert tensor.is_contiguous()

        b = query_batch_cnt.shape[0]
        total_query_num, local_size = index_pair.size()
        total_key_num, nhead, hdim = value_features.size()

        # Need to ensure that every tensor in query features have an output.
        assert total_query_num == attn_weight.shape[0]

        output = torch.zeros([total_query_num, nhead, hdim], dtype=torch.float32, device=value_features.device)

        # Half-precision inputs go through the float32 kernel on up-cast copies.
        needs_upcast = value_features.dtype in (torch.bfloat16, torch.float16)
        kernel_weight = attn_weight.type(torch.float32) if needs_upcast else attn_weight
        kernel_value = value_features.type(torch.float32) if needs_upcast else value_features

        attention_cuda.attention_value_computation_wrapper_v2(
            b, total_query_num, local_size, total_key_num, nhead, hdim, query_batch_cnt, key_batch_cnt,
            index_pair_batch, index_pair, kernel_weight, kernel_value, output
        )

        ctx.for_backwards = (
            b, total_query_num, local_size, total_key_num, nhead, hdim, query_batch_cnt, key_batch_cnt,
            index_pair_batch, index_pair, attn_weight, value_features
        )
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """
        Args:
            ctx: autograd context populated by ``forward``.
            grad_out: [total_query_num, nhead, hdim]
        Returns:
            One entry per forward input; the four index/count tensors get
            ``None``, followed by
            grad_attn_weight: [total_query_num, local_size, nhead] (float32)
            grad_value_features: [total_key_num, nhead, hdim] (float32)
        """
        (
            b, total_query_num, local_size, total_key_num, nhead, hdim, query_batch_cnt, key_batch_cnt,
            index_pair_batch, index_pair, attn_weight, value_features
        ) = ctx.for_backwards

        # float32 buffers on the inputs' device; the kernel fills them in place.
        grad_attn_weight = torch.zeros(
            (total_query_num, local_size, nhead), dtype=torch.float32, device=attn_weight.device
        )
        grad_value_features = torch.zeros(
            (total_key_num, nhead, hdim), dtype=torch.float32, device=attn_weight.device
        )

        if value_features.dtype in (torch.bfloat16, torch.float16):
            grad_out = grad_out.type(torch.float32)
            kernel_weight = attn_weight.type(torch.float32)
            kernel_value = value_features.type(torch.float32)
        else:
            kernel_weight = attn_weight
            kernel_value = value_features

        grad_out_data = grad_out.data.contiguous()
        attention_cuda.attention_value_computation_grad_wrapper_v2(
            b, total_query_num, local_size, total_key_num, nhead, hdim, query_batch_cnt, key_batch_cnt,
            index_pair_batch, index_pair, kernel_weight, kernel_value, grad_out_data, grad_attn_weight.data,
            grad_value_features.data
        )

        return None, None, None, None, grad_attn_weight, grad_value_features


attention_value_computation = AttentionValueComputation.apply
&attention_weight_computation_wrapper_v2_half, + "attention weight computation forward."); + m.def("attention_weight_computation_grad_wrapper_v2", &attention_weight_computation_grad_wrapper_v2, + "attention weight computation backward."); + m.def("attention_weight_computation_grad_wrapper_v2_half", &attention_weight_computation_grad_wrapper_v2_half, + "attention weight computation backward."); + m.def("attention_value_computation_wrapper_v2", &attention_value_computation_wrapper_v2, + "attention result computation forward."); + m.def("attention_value_computation_wrapper_v2_half", &attention_value_computation_wrapper_v2_half, + "attention result computation forward."); + m.def("attention_value_computation_grad_wrapper_v2", &attention_value_computation_grad_wrapper_v2, + "attention result computation backward."); + m.def("attention_value_computation_grad_wrapper_v2_half", &attention_value_computation_grad_wrapper_v2_half, + "attention result computation backward."); + m.def("attention_value_computation_grad_wrapper_v2_fp16", &attention_value_computation_grad_wrapper_v2_fp16, + "attention result computation backward."); + m.def("attention_value_computation_wrapper_v2_fp16", &attention_value_computation_wrapper_v2_fp16, + "attention result computation forward."); + m.def("attention_weight_computation_grad_wrapper_v2_fp16", &attention_weight_computation_grad_wrapper_v2_fp16, + "attention weight computation backward."); + m.def("attention_weight_computation_wrapper_v2_fp16", &attention_weight_computation_wrapper_v2_fp16, + "attention weight computation forward."); +} \ No newline at end of file diff --git a/scenestreamer/models/ops/attention/src/attention_func_v2.cpp b/scenestreamer/models/ops/attention/src/attention_func_v2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5de79be66693bc5c0cfba3fa37c46864fb770b8a --- /dev/null +++ b/scenestreamer/models/ops/attention/src/attention_func_v2.cpp @@ -0,0 +1,549 @@ +#include +#include +#include 
+#include + +#include "attention_func_v2.h" + +#define CHECK_CUDA(x) do { \ + if (!x.type().is_cuda()) { \ + fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ + exit(-1); \ + } \ +} while (0) +#define CHECK_CONTIGUOUS(x) do { \ + if (!x.is_contiguous()) { \ + fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \ + exit(-1); \ + } \ +} while (0) +#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) + + +int attention_weight_computation_wrapper_v2( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, local_size, nhead] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(query_features); + CHECK_INPUT(key_features); + + CHECK_INPUT(output); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const float *query_features_data = query_features.data_ptr(); + const float *key_features_data = key_features.data_ptr(); + + float *output_data = output.data_ptr(); + + attention_weight_computation_launcher_v2( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, query_features_data, key_features_data, + output_data); + 
+ return 1; +} + + + +int attention_weight_computation_wrapper_v2_half( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, local_size, nhead] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(query_features); + CHECK_INPUT(key_features); + + CHECK_INPUT(output); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const at::BFloat16 *query_features_data = query_features.data_ptr(); + const at::BFloat16 *key_features_data = key_features.data_ptr(); + + at::BFloat16 *output_data = output.data_ptr(); + + attention_weight_computation_launcher_v2_half( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, query_features_data, key_features_data, + output_data); + + return 1; +} + + +int attention_weight_computation_wrapper_v2_fp16( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params 
index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, local_size, nhead] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(query_features); + CHECK_INPUT(key_features); + + CHECK_INPUT(output); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const at::Half *query_features_data = query_features.data_ptr(); + const at::Half *key_features_data = key_features.data_ptr(); + + at::Half *output_data = output.data_ptr(); + + attention_weight_computation_launcher_v2_fp16( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, query_features_data, key_features_data, + output_data); + + return 1; +} + + +int attention_weight_computation_grad_wrapper_v2( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor grad_out, at::Tensor grad_query_features, at::Tensor grad_key_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, local_size, nhead] + // params grad_query_features: [total_query_num, nhead, hdim] + // params grad_key_features: 
[total_key_num, nhead, hdim] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(query_features); + CHECK_INPUT(key_features); + + CHECK_INPUT(grad_out); + CHECK_INPUT(grad_query_features); + CHECK_INPUT(grad_key_features); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const float *query_features_data = query_features.data_ptr(); + const float *key_features_data = key_features.data_ptr(); + + float *grad_out_data = grad_out.data_ptr(); + float *grad_query_features_data = grad_query_features.data_ptr(); + float *grad_key_features_data = grad_key_features.data_ptr(); + + attention_weight_computation_grad_launcher_v2( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, query_features_data, key_features_data, + grad_out_data, grad_query_features_data, grad_key_features_data); + + return 1; +} + + +int attention_weight_computation_grad_wrapper_v2_half( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor grad_out, at::Tensor grad_query_features, at::Tensor grad_key_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, local_size, nhead] + // params grad_query_features: [total_query_num, nhead, hdim] + // params 
grad_key_features: [total_key_num, nhead, hdim] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(query_features); + CHECK_INPUT(key_features); + + CHECK_INPUT(grad_out); + CHECK_INPUT(grad_query_features); + CHECK_INPUT(grad_key_features); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const at::BFloat16 *query_features_data = query_features.data_ptr(); + const at::BFloat16 *key_features_data = key_features.data_ptr(); + + at::BFloat16 *grad_out_data = grad_out.data_ptr(); + at::BFloat16 *grad_query_features_data = grad_query_features.data_ptr(); + at::BFloat16 *grad_key_features_data = grad_key_features.data_ptr(); + + attention_weight_computation_grad_launcher_v2_half( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, query_features_data, key_features_data, + grad_out_data, grad_query_features_data, grad_key_features_data); + + return 1; +} + + +int attention_weight_computation_grad_wrapper_v2_fp16( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor grad_out, at::Tensor grad_query_features, at::Tensor grad_key_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, local_size, nhead] + // params 
grad_query_features: [total_query_num, nhead, hdim] + // params grad_key_features: [total_key_num, nhead, hdim] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(query_features); + CHECK_INPUT(key_features); + + CHECK_INPUT(grad_out); + CHECK_INPUT(grad_query_features); + CHECK_INPUT(grad_key_features); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const at::Half *query_features_data = query_features.data_ptr(); + const at::Half *key_features_data = key_features.data_ptr(); + + at::Half *grad_out_data = grad_out.data_ptr(); + at::Half *grad_query_features_data = grad_query_features.data_ptr(); + at::Half *grad_key_features_data = grad_key_features.data_ptr(); + + attention_weight_computation_grad_launcher_v2_fp16( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, query_features_data, key_features_data, + grad_out_data, grad_query_features_data, grad_key_features_data); + + return 1; +} + + +int attention_value_computation_wrapper_v2( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, nhead, hdim] + CHECK_INPUT(query_batch_cnt); + 
CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(attn_weight); + CHECK_INPUT(value_features); + + CHECK_INPUT(output); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const float *attn_weight_data = attn_weight.data_ptr(); + const float *value_features_data = value_features.data_ptr(); + + float *output_data = output.data_ptr(); + + attention_value_computation_launcher_v2( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, attn_weight_data, value_features_data, + output_data); + + return 1; +} + + +int attention_value_computation_wrapper_v2_half( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, nhead, hdim] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(attn_weight); + CHECK_INPUT(value_features); + + CHECK_INPUT(output); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const at::BFloat16 
*attn_weight_data = attn_weight.data_ptr(); + const at::BFloat16 *value_features_data = value_features.data_ptr(); + + at::BFloat16 *output_data = output.data_ptr(); + + attention_value_computation_launcher_v2_half( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, attn_weight_data, value_features_data, + output_data); + + return 1; +} + + + +int attention_value_computation_wrapper_v2_fp16( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, nhead, hdim] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(attn_weight); + CHECK_INPUT(value_features); + + CHECK_INPUT(output); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const at::Half *attn_weight_data = attn_weight.data_ptr(); + const at::Half *value_features_data = value_features.data_ptr(); + + at::Half *output_data = output.data_ptr(); + + attention_value_computation_launcher_v2_fp16( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, attn_weight_data, value_features_data, + output_data); + + return 1; +} + + 
+int attention_value_computation_grad_wrapper_v2( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor grad_out, at::Tensor grad_attn_weight, at::Tensor grad_value_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, nhead, hdim] + // params grad_attn_weight: [total_query_num, local_size, nhead] + // params grad_value_features: [total_key_num, nhead, hdim] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(attn_weight); + CHECK_INPUT(value_features); + + CHECK_INPUT(grad_out); + CHECK_INPUT(grad_attn_weight); + CHECK_INPUT(grad_value_features); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const float *attn_weight_data = attn_weight.data_ptr(); + const float *value_features_data = value_features.data_ptr(); + + float *grad_out_data = grad_out.data_ptr(); + float *grad_attn_weight_data = grad_attn_weight.data_ptr(); + float *grad_value_features_data = grad_value_features.data_ptr(); + + attention_value_computation_grad_launcher_v2( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, attn_weight_data, value_features_data, + grad_out_data, grad_attn_weight_data, grad_value_features_data); + + return 1; +} + 
+int attention_value_computation_grad_wrapper_v2_half( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor grad_out, at::Tensor grad_attn_weight, at::Tensor grad_value_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, nhead, hdim] + // params grad_attn_weight: [total_query_num, local_size, nhead] + // params grad_value_features: [total_key_num, nhead, hdim] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(attn_weight); + CHECK_INPUT(value_features); + + CHECK_INPUT(grad_out); + CHECK_INPUT(grad_attn_weight); + CHECK_INPUT(grad_value_features); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const at::BFloat16 *attn_weight_data = attn_weight.data_ptr(); + const at::BFloat16 *value_features_data = value_features.data_ptr(); + + at::BFloat16 *grad_out_data = grad_out.data_ptr(); + at::BFloat16 *grad_attn_weight_data = grad_attn_weight.data_ptr(); + at::BFloat16 *grad_value_features_data = grad_value_features.data_ptr(); + + attention_value_computation_grad_launcher_v2_half( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, attn_weight_data, value_features_data, + grad_out_data, grad_attn_weight_data, 
grad_value_features_data); + + return 1; +} + + +int attention_value_computation_grad_wrapper_v2_fp16( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor grad_out, at::Tensor grad_attn_weight, at::Tensor grad_value_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, nhead, hdim] + // params grad_attn_weight: [total_query_num, local_size, nhead] + // params grad_value_features: [total_key_num, nhead, hdim] + CHECK_INPUT(query_batch_cnt); + CHECK_INPUT(key_batch_cnt); + CHECK_INPUT(index_pair_batch); + CHECK_INPUT(index_pair); + CHECK_INPUT(attn_weight); + CHECK_INPUT(value_features); + + CHECK_INPUT(grad_out); + CHECK_INPUT(grad_attn_weight); + CHECK_INPUT(grad_value_features); + + const int *query_batch_cnt_data = query_batch_cnt.data_ptr(); + const int *key_batch_cnt_data = key_batch_cnt.data_ptr(); + const int *index_pair_batch_data = index_pair_batch.data_ptr(); + const int *index_pair_data = index_pair.data_ptr(); + + const at::Half *attn_weight_data = attn_weight.data_ptr(); + const at::Half *value_features_data = value_features.data_ptr(); + + at::Half *grad_out_data = grad_out.data_ptr(); + at::Half *grad_attn_weight_data = grad_attn_weight.data_ptr(); + at::Half *grad_value_features_data = grad_value_features.data_ptr(); + + attention_value_computation_grad_launcher_v2_fp16( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, + index_pair_data, attn_weight_data, value_features_data, + 
grad_out_data, grad_attn_weight_data, grad_value_features_data); + + return 1; +} \ No newline at end of file diff --git a/scenestreamer/models/ops/attention/src/attention_func_v2.h b/scenestreamer/models/ops/attention/src/attention_func_v2.h new file mode 100644 index 0000000000000000000000000000000000000000..fd96b7aab153a5e436990e9360588c2ec9827c8e --- /dev/null +++ b/scenestreamer/models/ops/attention/src/attention_func_v2.h @@ -0,0 +1,199 @@ +#ifndef _ATTENTION_FUNC_V2_H +#define _ATTENTION_FUNC_V2_H + +#include +#include +#include +#include +#include +#include +#include +#include + +void attention_weight_computation_launcher_v2( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *query_features, const float* key_features, + float *output); + + +int attention_weight_computation_wrapper_v2( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor output); + +int attention_weight_computation_wrapper_v2_half( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor output); + + +int attention_weight_computation_wrapper_v2_fp16( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor output); + + +void attention_weight_computation_grad_launcher_v2( + int b, int total_query_num, int local_size, + int 
total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *query_features, const float* key_features, + float *grad_out, float* grad_query_features, float* grad_key_features); + + +int attention_weight_computation_grad_wrapper_v2( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor grad_out, at::Tensor grad_query_features, at::Tensor grad_key_features); + +int attention_weight_computation_grad_wrapper_v2_half( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor grad_out, at::Tensor grad_query_features, at::Tensor grad_key_features); + +int attention_weight_computation_grad_wrapper_v2_fp16( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, + at::Tensor grad_out, at::Tensor grad_query_features, at::Tensor grad_key_features); + + +void attention_value_computation_launcher_v2( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *attn_weight, const float* value_features, + float *output); + + +int attention_value_computation_wrapper_v2( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + 
at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor output); + +int attention_value_computation_wrapper_v2_half( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor output); + + +void attention_value_computation_grad_launcher_v2( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *attn_weight, const float* value_features, + float *grad_out, float* grad_attn_weight, float* grad_value_features); + + +int attention_value_computation_grad_wrapper_v2( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor grad_out, at::Tensor grad_attn_weight, at::Tensor grad_value_features); + + +int attention_value_computation_grad_wrapper_v2_half( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor grad_out, at::Tensor grad_attn_weight, at::Tensor grad_value_features); + + +void attention_weight_computation_launcher_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *query_features, const at::BFloat16* key_features, + at::BFloat16 *output); + + +void attention_weight_computation_grad_launcher_v2_half( + int 
b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *query_features, const at::BFloat16* key_features, + at::BFloat16 *grad_out, at::BFloat16* grad_query_features, at::BFloat16* grad_key_features); + + +void attention_value_computation_launcher_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *attn_weight, const at::BFloat16* value_features, + at::BFloat16 *output); + +void attention_value_computation_grad_launcher_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *attn_weight, const at::BFloat16* value_features, + at::BFloat16 *grad_out, at::BFloat16* grad_attn_weight, at::BFloat16* grad_value_features); + + + +int attention_value_computation_grad_wrapper_v2_fp16( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor grad_out, at::Tensor grad_attn_weight, at::Tensor grad_value_features); + + +void attention_weight_computation_launcher_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *query_features, const at::Half* key_features, + at::Half *output); + + +void attention_weight_computation_grad_launcher_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int 
nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *query_features, const at::Half* key_features, + at::Half *grad_out, at::Half* grad_query_features, at::Half* grad_key_features); + + +void attention_value_computation_launcher_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *attn_weight, const at::Half* value_features, + at::Half *output); + +void attention_value_computation_grad_launcher_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *attn_weight, const at::Half* value_features, + at::Half *grad_out, at::Half* grad_attn_weight, at::Half* grad_value_features); + + +int attention_value_computation_wrapper_v2_fp16( + int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, + at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, + at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, + at::Tensor output); + +#endif \ No newline at end of file diff --git a/scenestreamer/models/ops/attention/src/attention_value_computation_kernel_v2.cu b/scenestreamer/models/ops/attention/src/attention_value_computation_kernel_v2.cu new file mode 100644 index 0000000000000000000000000000000000000000..d8277e901a00c557fe15b8342e7e9b77585eb558 --- /dev/null +++ b/scenestreamer/models/ops/attention/src/attention_value_computation_kernel_v2.cu @@ -0,0 +1,292 @@ +/* +Transformer function helper function. 
+Written by tomztyang, +2021/08/23 +*/ + +#include +#include + +#define THREADS_PER_BLOCK 256 +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) +// #define DEBUG + + +template +__global__ void attention_value_computation_forward_v2( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *attn_weight, const float* value_features, + float *output) { + // dim3 blocks(total_query_num, nhead); dim3 threads(hdim); + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, nhead, hdim] + + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int hdim_idx = threadIdx.x; + if (query_idx >= total_query_num || + head_idx >= nhead || + hdim_idx >= hdim) return; + + // get key_start_idx. + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + int cur_key_idx; + + // get shared variables. 
+ __shared__ float shared_attn_weight[d]; // d == local_size + __shared__ int shared_value_indices[d]; + for (int i = hdim_idx; i < local_size; i += blockDim.x){ + shared_attn_weight[i] = attn_weight[ + query_idx * local_size * nhead + i * nhead + head_idx]; + + cur_key_idx = index_pair[query_idx * local_size + i]; + if (cur_key_idx == -1){ + shared_value_indices[i] = -1; + continue; + } + cur_key_idx += key_start_idx; + shared_value_indices[i] = cur_key_idx; + } + __syncthreads(); + + output += query_idx * nhead * hdim + head_idx * hdim + hdim_idx; + + float attn_result = 0; + for (int i = 0; i < local_size; i++){ + if (shared_value_indices[i] == -1) continue; + attn_result += shared_attn_weight[i] * value_features[ + shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx]; + } + output[0] = attn_result; +} + + +void attention_value_computation_launcher_v2( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *attn_weight, const float* value_features, + float *output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, nhead, hdim] + dim3 blocks(total_query_num, nhead); + dim3 threads(hdim); + if (local_size > 512){ + throw "local_size should be <= 512."; + } + + switch (local_size){ + case 16: + attention_value_computation_forward_v2<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 32: + attention_value_computation_forward_v2<32><<>>( + b, total_query_num, local_size, total_key_num, 
nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 64: + attention_value_computation_forward_v2<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 128: + attention_value_computation_forward_v2<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 320: + attention_value_computation_forward_v2<320><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 384: + attention_value_computation_forward_v2<384><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + default: + attention_value_computation_forward_v2<512><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + } +} + + +template // d == local_size +__global__ void attention_value_computation_backward_v2( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *attn_weight, const float* value_features, + float *grad_out, float * grad_attn_weight, float * grad_value_features) { + // dim3 blocks(total_query_num, nhead); dim3 threads(hdim); + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params 
attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, nhead, hdim] + // params grad_attn_weight: [total_query_num, local_size, nhead] + // params grad_value_features: [total_key_num, nhead, hdim] + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int hdim_idx = threadIdx.x; + if (query_idx >= total_query_num || + head_idx >= nhead || + hdim_idx >= hdim) return; + + // get key_start_idx. + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + int cur_key_idx; + + // get shared variables. + __shared__ float shared_attn_weight[d], shared_grad_attn_weight[d]; // d == local_size + __shared__ int shared_value_indices[d]; + for (int i = hdim_idx; i < local_size; i += blockDim.x){ + shared_attn_weight[i] = attn_weight[ + query_idx * local_size * nhead + i * nhead + head_idx]; + shared_grad_attn_weight[i] = 0; + + cur_key_idx = index_pair[query_idx * local_size + i]; + if (cur_key_idx == -1){ + shared_value_indices[i] = -1; + continue; + } + cur_key_idx += key_start_idx; + shared_value_indices[i] = cur_key_idx; + } + __syncthreads(); + + float gradient = grad_out[query_idx * nhead * hdim + head_idx * hdim + hdim_idx]; + for (int i = 0; i < local_size; i++){ + if (shared_value_indices[i] == -1) continue; + atomicAdd( + shared_grad_attn_weight + i, + gradient * value_features[shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx]); + atomicAdd( + grad_value_features + shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx, + gradient * shared_attn_weight[i]); + } + __syncthreads(); + + for (int i = hdim_idx; i < local_size; i+=blockDim.x){ + grad_attn_weight[query_idx * local_size * nhead + i * nhead + head_idx] = shared_grad_attn_weight[i]; + } +} + + +void attention_value_computation_grad_launcher_v2( + int b, int total_query_num, int local_size, + 
int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *attn_weight, const float* value_features, + float *grad_out, float* grad_attn_weight, float* grad_value_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, nhead, hdim] + // params grad_attn_weight: [total_query_num, local_size, nhead] + // params grad_value_features: [total_key_num, nhead, hdim] + dim3 blocks(total_query_num, nhead); + dim3 threads(hdim); + if (local_size > 512){ + throw "local_size should be <= 512."; + } + + switch(local_size){ + case 16: + attention_value_computation_backward_v2<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 32: + attention_value_computation_backward_v2<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 64: + attention_value_computation_backward_v2<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 128: + attention_value_computation_backward_v2<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, 
grad_value_features); + break; + case 320: + attention_value_computation_backward_v2<320><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 384: + attention_value_computation_backward_v2<384><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + default: + attention_value_computation_backward_v2<512><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + } +} \ No newline at end of file diff --git a/scenestreamer/models/ops/attention/src/attention_value_computation_kernel_v2_half.cu b/scenestreamer/models/ops/attention/src/attention_value_computation_kernel_v2_half.cu new file mode 100644 index 0000000000000000000000000000000000000000..08467bdb5033b6d39e3d1d2115dbb1c8c97eb09e --- /dev/null +++ b/scenestreamer/models/ops/attention/src/attention_value_computation_kernel_v2_half.cu @@ -0,0 +1,315 @@ +/* +Transformer function helper function. 
+Written by tomztyang, +2021/08/23 +*/ + +#include +#include +#include +#include +#include + +#define THREADS_PER_BLOCK 256 +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) +// #define DEBUG + + +template +__global__ void attention_value_computation_forward_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *attn_weight, const at::BFloat16* value_features, + at::BFloat16 *output) { + // dim3 blocks(total_query_num, nhead); dim3 threads(hdim); + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, nhead, hdim] + + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int hdim_idx = threadIdx.x; + if (query_idx >= total_query_num || + head_idx >= nhead || + hdim_idx >= hdim) return; + + // get key_start_idx. + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + int cur_key_idx; + + // get shared variables. 
+ __shared__ at::BFloat16 shared_attn_weight[d]; // d == local_size + __shared__ int shared_value_indices[d]; + for (int i = hdim_idx; i < local_size; i += blockDim.x){ + shared_attn_weight[i] = attn_weight[ + query_idx * local_size * nhead + i * nhead + head_idx]; + + cur_key_idx = index_pair[query_idx * local_size + i]; + if (cur_key_idx == -1){ + shared_value_indices[i] = -1; + continue; + } + cur_key_idx += key_start_idx; + shared_value_indices[i] = cur_key_idx; + } + __syncthreads(); + + output += query_idx * nhead * hdim + head_idx * hdim + hdim_idx; + + at::BFloat16 attn_result = __int2bfloat16_rn(0); + for (int i = 0; i < local_size; i++){ + if (shared_value_indices[i] == -1) continue; + attn_result += shared_attn_weight[i] * value_features[ + shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx]; + } + output[0] = attn_result; +} + + +void attention_value_computation_launcher_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *attn_weight, const at::BFloat16* value_features, + at::BFloat16 *output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, nhead, hdim] + dim3 blocks(total_query_num, nhead); + dim3 threads(hdim); + if (local_size > 512){ + throw "local_size should be <= 512."; + } + + switch (local_size){ + case 16: + attention_value_computation_forward_v2_half<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 32: + 
attention_value_computation_forward_v2_half<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 64: + attention_value_computation_forward_v2_half<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 128: + attention_value_computation_forward_v2_half<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 320: + attention_value_computation_forward_v2_half<320><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 384: + attention_value_computation_forward_v2_half<384><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + default: + attention_value_computation_forward_v2_half<512><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + } +} + + +template // d == local_size +__global__ void attention_value_computation_backward_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *attn_weight, const at::BFloat16* value_features, + at::BFloat16 *grad_out, at::BFloat16 * grad_attn_weight, at::BFloat16 * grad_value_features) { + // dim3 blocks(total_query_num, nhead); dim3 threads(hdim); + // params 
query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, nhead, hdim] + // params grad_attn_weight: [total_query_num, local_size, nhead] + // params grad_value_features: [total_key_num, nhead, hdim] + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int hdim_idx = threadIdx.x; + if (query_idx >= total_query_num || + head_idx >= nhead || + hdim_idx >= hdim) return; + + // get key_start_idx. + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + int cur_key_idx; + + // get shared variables. + __shared__ at::BFloat16 shared_attn_weight[d], shared_grad_attn_weight[d]; // d == local_size + __shared__ int shared_value_indices[d]; + for (int i = hdim_idx; i < local_size; i += blockDim.x){ + shared_attn_weight[i] = attn_weight[ + query_idx * local_size * nhead + i * nhead + head_idx]; + shared_grad_attn_weight[i] = __int2bfloat16_rn(0); + + cur_key_idx = index_pair[query_idx * local_size + i]; + if (cur_key_idx == -1){ + shared_value_indices[i] = -1; + continue; + } + cur_key_idx += key_start_idx; + shared_value_indices[i] = cur_key_idx; + } + __syncthreads(); + + at::BFloat16 gradient = grad_out[query_idx * nhead * hdim + head_idx * hdim + hdim_idx]; + for (int i = 0; i < local_size; i++){ + if (shared_value_indices[i] == -1) continue; + +// atomicAdd( +// shared_grad_attn_weight + i, +// __hmul(gradient, value_features[shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx]) +// ); +// atomicAdd( +// grad_value_features + shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx, +// __hmul(gradient, shared_attn_weight[i]) +// ); + + // PZH NOTE: AtomicAdd is extremely slow for FP16 or 
BF16. We can accelerate code by: + // PZH NOTE: According to https://github.com/BBuf/how-to-optim-algorithm-in-cuda + at::native::fastAtomicAdd( + shared_grad_attn_weight + i, + 0, 0, + gradient * value_features[shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx], + true + ); + at::native::fastAtomicAdd( + grad_value_features + shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx, + 0, 0, + gradient * shared_attn_weight[i], + true + ); + +// fastAtomicAdd(gradInput.data(), index, gradInput_numel, val, true); + } + __syncthreads(); + + for (int i = hdim_idx; i < local_size; i+=blockDim.x){ + grad_attn_weight[query_idx * local_size * nhead + i * nhead + head_idx] = shared_grad_attn_weight[i]; + } +} + + +void attention_value_computation_grad_launcher_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *attn_weight, const at::BFloat16* value_features, + at::BFloat16 *grad_out, at::BFloat16* grad_attn_weight, at::BFloat16* grad_value_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, nhead, hdim] + // params grad_attn_weight: [total_query_num, local_size, nhead] + // params grad_value_features: [total_key_num, nhead, hdim] + dim3 blocks(total_query_num, nhead); + dim3 threads(hdim); + if (local_size > 512){ + throw "local_size should be <= 512."; + } + + switch(local_size){ + case 16: + attention_value_computation_backward_v2_half<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, 
value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 32: + attention_value_computation_backward_v2_half<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 64: + attention_value_computation_backward_v2_half<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 128: + attention_value_computation_backward_v2_half<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 320: + attention_value_computation_backward_v2_half<320><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 384: + attention_value_computation_backward_v2_half<384><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + default: + attention_value_computation_backward_v2_half<512><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + } +} diff --git a/scenestreamer/models/ops/attention/src/attention_value_computation_kernel_v2_half_fp16.cu b/scenestreamer/models/ops/attention/src/attention_value_computation_kernel_v2_half_fp16.cu new 
file mode 100644 index 0000000000000000000000000000000000000000..a15c2a42ebf89ba184168db263190a06870161db --- /dev/null +++ b/scenestreamer/models/ops/attention/src/attention_value_computation_kernel_v2_half_fp16.cu @@ -0,0 +1,315 @@ +/* +Transformer function helper function. +Written by tomztyang, +2021/08/23 +*/ + +#include +#include +#include +#include +#include + +#define THREADS_PER_BLOCK 256 +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) +// #define DEBUG + + +template +__global__ void attention_value_computation_forward_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *attn_weight, const at::Half* value_features, + at::Half *output) { + // dim3 blocks(total_query_num, nhead); dim3 threads(hdim); + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, nhead, hdim] + + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int hdim_idx = threadIdx.x; + if (query_idx >= total_query_num || + head_idx >= nhead || + hdim_idx >= hdim) return; + + // get key_start_idx. + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + int cur_key_idx; + + // get shared variables. 
+ __shared__ at::Half shared_attn_weight[d]; // d == local_size + __shared__ int shared_value_indices[d]; + for (int i = hdim_idx; i < local_size; i += blockDim.x){ + shared_attn_weight[i] = attn_weight[ + query_idx * local_size * nhead + i * nhead + head_idx]; + + cur_key_idx = index_pair[query_idx * local_size + i]; + if (cur_key_idx == -1){ + shared_value_indices[i] = -1; + continue; + } + cur_key_idx += key_start_idx; + shared_value_indices[i] = cur_key_idx; + } + __syncthreads(); + + output += query_idx * nhead * hdim + head_idx * hdim + hdim_idx; + + at::Half attn_result = __int2half_rn(0); + for (int i = 0; i < local_size; i++){ + if (shared_value_indices[i] == -1) continue; + attn_result = __hadd(attn_result, __hmul(shared_attn_weight[i], value_features[ + shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx])); + } + output[0] = attn_result; +} + + +void attention_value_computation_launcher_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *attn_weight, const at::Half* value_features, + at::Half *output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, nhead, hdim] + dim3 blocks(total_query_num, nhead); + dim3 threads(hdim); + if (local_size > 512){ + throw "local_size should be <= 512."; + } + + switch (local_size){ + case 16: + attention_value_computation_forward_v2_fp16<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 32: + 
attention_value_computation_forward_v2_fp16<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 64: + attention_value_computation_forward_v2_fp16<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 128: + attention_value_computation_forward_v2_fp16<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 320: + attention_value_computation_forward_v2_fp16<320><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + case 384: + attention_value_computation_forward_v2_fp16<384><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + default: + attention_value_computation_forward_v2_fp16<512><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + output); + break; + } +} + + +template // d == local_size +__global__ void attention_value_computation_backward_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *attn_weight, const at::Half* value_features, + at::Half *grad_out, at::Half * grad_attn_weight, at::Half * grad_value_features) { + // dim3 blocks(total_query_num, nhead); dim3 threads(hdim); + // params query_batch_cnt: [b] + 
// params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, nhead, hdim] + // params grad_attn_weight: [total_query_num, local_size, nhead] + // params grad_value_features: [total_key_num, nhead, hdim] + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int hdim_idx = threadIdx.x; + if (query_idx >= total_query_num || + head_idx >= nhead || + hdim_idx >= hdim) return; + + // get key_start_idx. + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + int cur_key_idx; + + // get shared variables. + __shared__ at::Half shared_attn_weight[d], shared_grad_attn_weight[d]; // d == local_size + __shared__ int shared_value_indices[d]; + for (int i = hdim_idx; i < local_size; i += blockDim.x){ + shared_attn_weight[i] = attn_weight[ + query_idx * local_size * nhead + i * nhead + head_idx]; + shared_grad_attn_weight[i] = __int2half_rn(0); + + cur_key_idx = index_pair[query_idx * local_size + i]; + if (cur_key_idx == -1){ + shared_value_indices[i] = -1; + continue; + } + cur_key_idx += key_start_idx; + shared_value_indices[i] = cur_key_idx; + } + __syncthreads(); + + at::Half gradient = grad_out[query_idx * nhead * hdim + head_idx * hdim + hdim_idx]; + for (int i = 0; i < local_size; i++){ + if (shared_value_indices[i] == -1) continue; + +// atomicAdd( +// shared_grad_attn_weight + i, +// __hmul(gradient, value_features[shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx]) +// ); +// atomicAdd( +// grad_value_features + shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx, +// __hmul(gradient, shared_attn_weight[i]) +// ); + + // PZH NOTE: AtomicAdd is extremely slow for FP16 or BF16. 
We can accelerate code by: + // PZH NOTE: According to https://github.com/BBuf/how-to-optim-algorithm-in-cuda + at::native::fastAtomicAdd( + shared_grad_attn_weight + i, + 0, 0, + gradient * value_features[shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx], + true + ); + at::native::fastAtomicAdd( + grad_value_features + shared_value_indices[i] * nhead * hdim + head_idx * hdim + hdim_idx, + 0, 0, + gradient * shared_attn_weight[i], + true + ); + +// fastAtomicAdd(gradInput.data(), index, gradInput_numel, val, true); + } + __syncthreads(); + + for (int i = hdim_idx; i < local_size; i+=blockDim.x){ + grad_attn_weight[query_idx * local_size * nhead + i * nhead + head_idx] = shared_grad_attn_weight[i]; + } +} + + +void attention_value_computation_grad_launcher_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *attn_weight, const at::Half* value_features, + at::Half *grad_out, at::Half* grad_attn_weight, at::Half* grad_value_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params attn_weight: [total_query_num, local_size, nhead] + // params value_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, nhead, hdim] + // params grad_attn_weight: [total_query_num, local_size, nhead] + // params grad_value_features: [total_key_num, nhead, hdim] + dim3 blocks(total_query_num, nhead); + dim3 threads(hdim); + if (local_size > 512){ + throw "local_size should be <= 512."; + } + + switch(local_size){ + case 16: + attention_value_computation_backward_v2_fp16<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, 
grad_attn_weight, grad_value_features); + break; + case 32: + attention_value_computation_backward_v2_fp16<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 64: + attention_value_computation_backward_v2_fp16<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 128: + attention_value_computation_backward_v2_fp16<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 320: + attention_value_computation_backward_v2_fp16<320><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + case 384: + attention_value_computation_backward_v2_fp16<384><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + default: + attention_value_computation_backward_v2_fp16<512><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, attn_weight, value_features, + grad_out, grad_attn_weight, grad_value_features); + break; + } +} diff --git a/scenestreamer/models/ops/attention/src/attention_weight_computation_kernel_v2.cu b/scenestreamer/models/ops/attention/src/attention_weight_computation_kernel_v2.cu new file mode 100644 index 
0000000000000000000000000000000000000000..d411fd466964fbba76a99ff80bd94c0bcadec80d --- /dev/null +++ b/scenestreamer/models/ops/attention/src/attention_weight_computation_kernel_v2.cu @@ -0,0 +1,290 @@ +/* +Transformer function helper function. +Written by tomztyang, +2021/08/23 +*/ + +#include +#include + +#define THREADS_PER_BLOCK 256 +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) +// #define DEBUG + + +template +__global__ void attention_weight_computation_forward_v2( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *query_features, const float* key_features, + float *output) { + // dim3 blocks(total_query_num, nhead); dim3 threads(local_size); + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, local_size, nhead] + + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int local_key_idx = threadIdx.x; + + int index = query_idx * local_size + local_key_idx; + if (query_idx >= total_query_num || + head_idx >= nhead || + local_key_idx >= local_size) return; + + // build shared query features. + __shared__ float shared_query_features[d]; + for (int i = local_key_idx; i < hdim; i += blockDim.x){ + shared_query_features[i] = query_features[ + query_idx * nhead * hdim + head_idx * hdim + i]; + } + __syncthreads(); + + if (index_pair[index] == -1){ + // Ignore index. + return; + } + + // get real key_idx. + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + key_start_idx += index_pair[index]; + + // get key features. 
+ key_features += key_start_idx * nhead * hdim + head_idx * hdim; + output += index * nhead + head_idx; + + float attn_weight = 0; + for (int i = 0; i < hdim; i++){ + attn_weight += key_features[i] * shared_query_features[i]; + } + output[0] = attn_weight; +} + + +void attention_weight_computation_launcher_v2( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *query_features, const float* key_features, + float *output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, local_size, nhead] + if (hdim > 150){ + throw "hdim should be <= 150."; + } + + dim3 blocks(total_query_num, nhead); + dim3 threads(local_size); + switch(hdim){ // switch hdim for utilizing different shared vectors. 
+ case 16: + attention_weight_computation_forward_v2<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 24: + attention_weight_computation_forward_v2<24><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 32: + attention_weight_computation_forward_v2<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 48: + attention_weight_computation_forward_v2<48><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 64: + attention_weight_computation_forward_v2<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 128: + attention_weight_computation_forward_v2<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + default: + attention_weight_computation_forward_v2<150><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + } +} + + +template +__global__ void attention_weight_computation_backward_v2( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *query_features, const 
float* key_features, + float *grad_out, float * grad_query_features, float * grad_key_features) { + // dim3 blocks(total_query_num, nhead); dim3 threads(local_size); + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, local_size, nhead] + // params grad_query_features: [total_query_num, nhead, hdim] + // params grad_key_features: [total_key_num, nhead, hdim] + + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int local_key_idx = threadIdx.x; + int index = query_idx * local_size + local_key_idx; + + if (query_idx >= total_query_num || + head_idx >= nhead || + local_key_idx >= local_size) return; + + // build shared query features. + __shared__ float shared_query_features[d]; + __shared__ float shared_grad_query_features[d]; + for (int i = local_key_idx; i < hdim; i += blockDim.x){ + shared_query_features[i] = query_features[ + query_idx * nhead * hdim + head_idx * hdim + i]; + shared_grad_query_features[i] = 0; + } + __syncthreads(); + + if (index_pair[index] != -1){ + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + key_start_idx += index_pair[index]; + + key_features += key_start_idx * nhead * hdim + head_idx * hdim; + grad_key_features += key_start_idx * nhead * hdim + head_idx * hdim; + + float gradient = grad_out[index * nhead + head_idx]; + for (int i = 0; i < hdim; i++){ + atomicAdd( + shared_grad_query_features + i, + gradient * key_features[i]); + atomicAdd( + grad_key_features + i, + gradient * shared_query_features[i]); + } + } + __syncthreads(); + + grad_query_features += query_idx * nhead * hdim + head_idx * hdim; + for (int i = local_key_idx; i < hdim; i += blockDim.x){ + 
grad_query_features[i] = shared_grad_query_features[i]; + } +} + + +void attention_weight_computation_grad_launcher_v2( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const float *query_features, const float* key_features, + float *grad_out, float* grad_query_features, float* grad_key_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, local_size, nhead] + // params grad_query_features: [total_query_num, nhead, hdim] + // params grad_key_features: [total_key_num, nhead, hdim] + if (hdim > 150){ + throw "hdim should be <= 150."; + } + + dim3 blocks(total_query_num, nhead); + dim3 threads(local_size); + + switch(hdim){ + case 16: + attention_weight_computation_backward_v2<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 24: + attention_weight_computation_backward_v2<24><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 32: + attention_weight_computation_backward_v2<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 48: + attention_weight_computation_backward_v2<48><<>>( + b, total_query_num, 
local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 64: + attention_weight_computation_backward_v2<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 128: + attention_weight_computation_backward_v2<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + default: + attention_weight_computation_backward_v2<150><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + } +} \ No newline at end of file diff --git a/scenestreamer/models/ops/attention/src/attention_weight_computation_kernel_v2_half.cu b/scenestreamer/models/ops/attention/src/attention_weight_computation_kernel_v2_half.cu new file mode 100644 index 0000000000000000000000000000000000000000..c7179d7ce213559a495cf01856daa27c0297c7b0 --- /dev/null +++ b/scenestreamer/models/ops/attention/src/attention_weight_computation_kernel_v2_half.cu @@ -0,0 +1,300 @@ +/* +Transformer function helper function. 
+Written by tomztyang, +2021/08/23 +*/ + +#include +#include +#include +#include +#include + +#define THREADS_PER_BLOCK 256 +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) +// #define DEBUG + + +template +__global__ void attention_weight_computation_forward_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *query_features, const at::BFloat16* key_features, + at::BFloat16 *output) { + // dim3 blocks(total_query_num, nhead); dim3 threads(local_size); + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, local_size, nhead] + + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int local_key_idx = threadIdx.x; + + int index = query_idx * local_size + local_key_idx; + if (query_idx >= total_query_num || + head_idx >= nhead || + local_key_idx >= local_size) return; + + // build shared query features. + __shared__ at::BFloat16 shared_query_features[d]; + for (int i = local_key_idx; i < hdim; i += blockDim.x){ + shared_query_features[i] = query_features[ + query_idx * nhead * hdim + head_idx * hdim + i]; + } + __syncthreads(); + + if (index_pair[index] == -1){ + // Ignore index. + return; + } + + // get real key_idx. + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + key_start_idx += index_pair[index]; + + // get key features. 
+ key_features += key_start_idx * nhead * hdim + head_idx * hdim; + output += index * nhead + head_idx; + + at::BFloat16 attn_weight = __int2bfloat16_rn(0); + for (int i = 0; i < hdim; i++){ + attn_weight += key_features[i] * shared_query_features[i]; + } + output[0] = attn_weight; +} + + +void attention_weight_computation_launcher_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *query_features, const at::BFloat16* key_features, + at::BFloat16 *output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, local_size, nhead] + if (hdim > 150){ + throw "hdim should be <= 150."; + } + + dim3 blocks(total_query_num, nhead); + dim3 threads(local_size); + switch(hdim){ // switch hdim for utilizing different shared vectors. 
+ case 16: + attention_weight_computation_forward_v2_half<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 24: + attention_weight_computation_forward_v2_half<24><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 32: + attention_weight_computation_forward_v2_half<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 48: + attention_weight_computation_forward_v2_half<48><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 64: + attention_weight_computation_forward_v2_half<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 128: + attention_weight_computation_forward_v2_half<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + default: + attention_weight_computation_forward_v2_half<150><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + } +} + + +template +__global__ void attention_weight_computation_backward_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int 
*index_pair, + const at::BFloat16 *query_features, const at::BFloat16* key_features, + at::BFloat16 *grad_out, at::BFloat16 * grad_query_features, at::BFloat16 * grad_key_features) { + // dim3 blocks(total_query_num, nhead); dim3 threads(local_size); + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, local_size, nhead] + // params grad_query_features: [total_query_num, nhead, hdim] + // params grad_key_features: [total_key_num, nhead, hdim] + + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int local_key_idx = threadIdx.x; + int index = query_idx * local_size + local_key_idx; + + if (query_idx >= total_query_num || + head_idx >= nhead || + local_key_idx >= local_size) return; + + // build shared query features. 
+ __shared__ at::BFloat16 shared_query_features[d]; + __shared__ at::BFloat16 shared_grad_query_features[d]; + for (int i = local_key_idx; i < hdim; i += blockDim.x){ + shared_query_features[i] = query_features[ + query_idx * nhead * hdim + head_idx * hdim + i]; + shared_grad_query_features[i] = __int2bfloat16_rn(0); + } + __syncthreads(); + + if (index_pair[index] != -1){ + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + key_start_idx += index_pair[index]; + + key_features += key_start_idx * nhead * hdim + head_idx * hdim; + grad_key_features += key_start_idx * nhead * hdim + head_idx * hdim; + + at::BFloat16 gradient = grad_out[index * nhead + head_idx]; + for (int i = 0; i < hdim; i++){ +// atomicAdd( +// shared_grad_query_features + i, +// gradient * key_features[i]); +// atomicAdd( +// grad_key_features + i, +// gradient * shared_query_features[i]); + at::native::fastAtomicAdd( + shared_grad_query_features + i, + 0, 0, + gradient * key_features[i], true); + at::native::fastAtomicAdd( + grad_key_features + i, 0, 0, + gradient * shared_query_features[i], true); + } + } + __syncthreads(); + + grad_query_features += query_idx * nhead * hdim + head_idx * hdim; + for (int i = local_key_idx; i < hdim; i += blockDim.x){ + grad_query_features[i] = shared_grad_query_features[i]; + } +} + + +void attention_weight_computation_grad_launcher_v2_half( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::BFloat16 *query_features, const at::BFloat16* key_features, + at::BFloat16 *grad_out, at::BFloat16* grad_query_features, at::BFloat16* grad_key_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] 
+ // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, local_size, nhead] + // params grad_query_features: [total_query_num, nhead, hdim] + // params grad_key_features: [total_key_num, nhead, hdim] + if (hdim > 150){ + throw "hdim should be <= 150."; + } + + dim3 blocks(total_query_num, nhead); + dim3 threads(local_size); + + switch(hdim){ + case 16: + attention_weight_computation_backward_v2_half<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 24: + attention_weight_computation_backward_v2_half<24><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 32: + attention_weight_computation_backward_v2_half<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 48: + attention_weight_computation_backward_v2_half<48><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 64: + attention_weight_computation_backward_v2_half<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 128: + attention_weight_computation_backward_v2_half<128><<>>( + b, total_query_num, local_size, total_key_num, 
nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + default: + attention_weight_computation_backward_v2_half<150><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + } +} \ No newline at end of file diff --git a/scenestreamer/models/ops/attention/src/attention_weight_computation_kernel_v2_half_fp16.cu b/scenestreamer/models/ops/attention/src/attention_weight_computation_kernel_v2_half_fp16.cu new file mode 100644 index 0000000000000000000000000000000000000000..6848d1c76efcc9207cab6e624cc46d25a46fa125 --- /dev/null +++ b/scenestreamer/models/ops/attention/src/attention_weight_computation_kernel_v2_half_fp16.cu @@ -0,0 +1,300 @@ +/* +Transformer function helper function. +Written by tomztyang, +2021/08/23 +*/ + +#include +#include +#include +#include +#include + +#define THREADS_PER_BLOCK 256 +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) +// #define DEBUG + + +template +__global__ void attention_weight_computation_forward_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *query_features, const at::Half* key_features, + at::Half *output) { + // dim3 blocks(total_query_num, nhead); dim3 threads(local_size); + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, local_size, nhead] + + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int 
local_key_idx = threadIdx.x; + + int index = query_idx * local_size + local_key_idx; + if (query_idx >= total_query_num || + head_idx >= nhead || + local_key_idx >= local_size) return; + + // build shared query features. + __shared__ at::Half shared_query_features[d]; + for (int i = local_key_idx; i < hdim; i += blockDim.x){ + shared_query_features[i] = query_features[ + query_idx * nhead * hdim + head_idx * hdim + i]; + } + __syncthreads(); + + if (index_pair[index] == -1){ + // Ignore index. + return; + } + + // get real key_idx. + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + key_start_idx += index_pair[index]; + + // get key features. + key_features += key_start_idx * nhead * hdim + head_idx * hdim; + output += index * nhead + head_idx; + + at::Half attn_weight = __int2half_rn(0); + for (int i = 0; i < hdim; i++){ + attn_weight = __hadd(attn_weight, __hmul(key_features[i], shared_query_features[i])); + } + output[0] = attn_weight; +} + + +void attention_weight_computation_launcher_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *query_features, const at::Half* key_features, + at::Half *output){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params output: [total_query_num, local_size, nhead] + if (hdim > 150){ + throw "hdim should be <= 150."; + } + + dim3 blocks(total_query_num, nhead); + dim3 threads(local_size); + switch(hdim){ // switch hdim for utilizing different shared vectors. 
+ case 16: + attention_weight_computation_forward_v2_fp16<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 24: + attention_weight_computation_forward_v2_fp16<24><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 32: + attention_weight_computation_forward_v2_fp16<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 48: + attention_weight_computation_forward_v2_fp16<48><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 64: + attention_weight_computation_forward_v2_fp16<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + case 128: + attention_weight_computation_forward_v2_fp16<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + default: + attention_weight_computation_forward_v2_fp16<150><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + output); + break; + } +} + + +template +__global__ void attention_weight_computation_backward_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int 
*index_pair, + const at::Half *query_features, const at::Half* key_features, + at::Half *grad_out, at::Half * grad_query_features, at::Half * grad_key_features) { + // dim3 blocks(total_query_num, nhead); dim3 threads(local_size); + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, local_size, nhead] + // params grad_query_features: [total_query_num, nhead, hdim] + // params grad_key_features: [total_key_num, nhead, hdim] + + int query_idx = blockIdx.x; + int head_idx = blockIdx.y; + int local_key_idx = threadIdx.x; + int index = query_idx * local_size + local_key_idx; + + if (query_idx >= total_query_num || + head_idx >= nhead || + local_key_idx >= local_size) return; + + // build shared query features. + __shared__ at::Half shared_query_features[d]; + __shared__ at::Half shared_grad_query_features[d]; + for (int i = local_key_idx; i < hdim; i += blockDim.x){ + shared_query_features[i] = query_features[ + query_idx * nhead * hdim + head_idx * hdim + i]; + shared_grad_query_features[i] = __int2half_rn(0); + } + __syncthreads(); + + if (index_pair[index] != -1){ + int batch_idx = index_pair_batch[query_idx]; + int key_start_idx = 0; + for (int i = 0; i < batch_idx; i++){ + key_start_idx += key_batch_cnt[i]; + } + key_start_idx += index_pair[index]; + + key_features += key_start_idx * nhead * hdim + head_idx * hdim; + grad_key_features += key_start_idx * nhead * hdim + head_idx * hdim; + + at::Half gradient = grad_out[index * nhead + head_idx]; + for (int i = 0; i < hdim; i++){ +// atomicAdd( +// shared_grad_query_features + i, +// gradient * key_features[i]); +// atomicAdd( +// grad_key_features + i, +// gradient * shared_query_features[i]); + at::native::fastAtomicAdd( + shared_grad_query_features + i, 
+ 0, 0, + gradient * key_features[i], true); + at::native::fastAtomicAdd( + grad_key_features + i, 0, 0, + gradient * shared_query_features[i], true); + } + } + __syncthreads(); + + grad_query_features += query_idx * nhead * hdim + head_idx * hdim; + for (int i = local_key_idx; i < hdim; i += blockDim.x){ + grad_query_features[i] = shared_grad_query_features[i]; + } +} + + +void attention_weight_computation_grad_launcher_v2_fp16( + int b, int total_query_num, int local_size, + int total_key_num, int nhead, int hdim, + const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, + const int *index_pair, + const at::Half *query_features, const at::Half* key_features, + at::Half *grad_out, at::Half* grad_query_features, at::Half* grad_key_features){ + // params query_batch_cnt: [b] + // params key_batch_cnt: [b] + // params index_pair_batch: [total_query_num] + // params index_pair: [total_query_num, local_size] + // params query_features: [total_query_num, nhead, hdim] + // params key_features: [total_key_num, nhead, hdim] + // params grad_out: [total_query_num, local_size, nhead] + // params grad_query_features: [total_query_num, nhead, hdim] + // params grad_key_features: [total_key_num, nhead, hdim] + if (hdim > 150){ + throw "hdim should be <= 150."; + } + + dim3 blocks(total_query_num, nhead); + dim3 threads(local_size); + + switch(hdim){ + case 16: + attention_weight_computation_backward_v2_fp16<16><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 24: + attention_weight_computation_backward_v2_fp16<24><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 32: + 
attention_weight_computation_backward_v2_fp16<32><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 48: + attention_weight_computation_backward_v2_fp16<48><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 64: + attention_weight_computation_backward_v2_fp16<64><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + case 128: + attention_weight_computation_backward_v2_fp16<128><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + default: + attention_weight_computation_backward_v2_fp16<150><<>>( + b, total_query_num, local_size, total_key_num, nhead, hdim, + query_batch_cnt, key_batch_cnt, index_pair_batch, + index_pair, query_features, key_features, + grad_out, grad_query_features, grad_key_features); + break; + } +} \ No newline at end of file diff --git a/scenestreamer/models/ops/collapse_time.py b/scenestreamer/models/ops/collapse_time.py new file mode 100644 index 0000000000000000000000000000000000000000..46d5e5b56a334d2c313b11d28b6889dab652a1b3 --- /dev/null +++ b/scenestreamer/models/ops/collapse_time.py @@ -0,0 +1,7 @@ +def collapse_time(tensor): + if tensor.ndim == 4: + B, T, N, D = tensor.shape + tensor = tensor.swapaxes(1, 2).reshape(B, N, T * D) + else: + raise ValueError(f"Unknown tensor shape: {tensor.shape}") + return tensor diff --git 
a/scenestreamer/models/ops/knn/__init__.py b/scenestreamer/models/ops/knn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scenestreamer/models/ops/knn/knn_utils.py b/scenestreamer/models/ops/knn/knn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e99730ddb9951d4cbed23bb562b11af2f19127c8 --- /dev/null +++ b/scenestreamer/models/ops/knn/knn_utils.py @@ -0,0 +1,102 @@ +# Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 +# Published at NeurIPS 2022 +# Written by Li Jiang, Shaoshuai Shi +# All Rights Reserved + +import torch +from torch.autograd import Function + +from . import knn_cuda + + +class KNNBatch(Function): + @staticmethod + def forward(ctx, xyz, query_xyz, batch_idxs, query_batch_offsets, k): + ''' + :param ctx: + :param xyz: (n, 3) float + :param query_xyz: (m, 3), float + :param batch_idxs: (n) int + :param query_batch_offsets: (B+1) int, offsets[-1] = m + :param k: int + :return: idx (n, k) + ''' + + n = xyz.size(0) + m = query_xyz.size(0) + assert k <= m + assert xyz.is_contiguous() and xyz.is_cuda, (xyz.is_contiguous(), xyz.is_cuda) + assert query_xyz.is_contiguous() and query_xyz.is_cuda, (query_xyz.is_contiguous(), query_xyz.is_cuda) + assert batch_idxs.is_contiguous() and batch_idxs.is_cuda, (batch_idxs.is_contiguous(), batch_idxs.is_cuda) + assert query_batch_offsets.is_contiguous() and query_batch_offsets.is_cuda, \ + (query_batch_offsets.is_contiguous(), query_batch_offsets.is_cuda) + + # idx = torch.cuda.IntTensor(n, k).zero_() + idx = torch.zeros([n, k], device=xyz.device, dtype=torch.int) + + knn_cuda.knn_batch(xyz, query_xyz, batch_idxs, query_batch_offsets, idx, n, m, k) + + return idx + + @staticmethod + def backward(ctx, a=None): + return None, None, None, None, None + + +knn_batch = KNNBatch.apply + + +class KNNBatchMlogK(Function): + @staticmethod + def forward(ctx, xyz, query_xyz, batch_idxs, query_batch_offsets, 
k): + ''' + :param ctx: + :param xyz: (n, 3) float + :param query_xyz: (m, 3), float + :param batch_idxs: (n) int + :param query_batch_offsets: (B+1) int, offsets[-1] = m + :param k: int + :return: idx (n, k) + ''' + assert xyz.shape[-1] == 3 + assert query_xyz.shape[-1] == 3 + + n = xyz.size(0) + m = query_xyz.size(0) + # assert k <= m + assert xyz.is_contiguous() and xyz.is_cuda, (xyz.is_contiguous(), xyz.is_cuda) + assert query_xyz.is_contiguous() and query_xyz.is_cuda, (query_xyz.is_contiguous(), query_xyz.is_cuda) + assert batch_idxs.is_contiguous() and batch_idxs.is_cuda, (batch_idxs.is_contiguous(), batch_idxs.is_cuda) + assert query_batch_offsets.is_contiguous() and query_batch_offsets.is_cuda, \ + (query_batch_offsets.is_contiguous(), query_batch_offsets.is_cuda) + assert k <= 128 + + assert query_batch_offsets.max() == query_batch_offsets[-1] == query_xyz.shape[0] + assert query_batch_offsets[0] == 0 + assert query_batch_offsets.shape[0] == batch_idxs.max() + 1 + 1 + assert batch_idxs.shape[0] == n + + # idx = torch.cuda.IntTensor(n, k).zero_() + idx = torch.zeros([n, k], device=xyz.device, dtype=torch.int) + + # half_precision = int(query_xyz.dtype == torch.bfloat16) + + if query_xyz.dtype == torch.bfloat16: + knn_cuda.knn_batch_mlogk( + xyz.type(torch.float32), query_xyz.type(torch.float32), batch_idxs, query_batch_offsets, idx, n, m, k + ) + elif query_xyz.dtype == torch.float16: + knn_cuda.knn_batch_mlogk( + xyz.type(torch.float32), query_xyz.type(torch.float32), batch_idxs, query_batch_offsets, idx, n, m, k + ) + else: + knn_cuda.knn_batch_mlogk(xyz, query_xyz, batch_idxs, query_batch_offsets, idx, n, m, k) + + return idx + + @staticmethod + def backward(ctx, a=None): + return None, None, None, None, None + + +knn_batch_mlogk = KNNBatchMlogK.apply diff --git a/scenestreamer/models/ops/knn/src/knn.cpp b/scenestreamer/models/ops/knn/src/knn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..03e9581d6133f45a83e64db89e480dc759b7f5fb 
--- /dev/null +++ b/scenestreamer/models/ops/knn/src/knn.cpp @@ -0,0 +1,68 @@ +// Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 +// Published at NeurIPS 2022 +// Written by Li Jiang, Shaoshuai Shi +// All Rights Reserved + + +#include "knn_gpu.h" + +// input xyz: (n, 3), float +// input query_xyz: (m, 3), float +// input batch_idxs: (n), int +// input query_batch_offsets: (B + 1), int, offsets[-1] = m +// output idx: (n, k), int +void knn_batch(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k){ + + const float *query_xyz = query_xyz_tensor.data_ptr(); + const float *xyz = xyz_tensor.data_ptr(); + const int *batch_idxs = batch_idxs_tensor.data_ptr(); + const int *query_batch_offsets = query_batch_offsets_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + knn_batch_cuda(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx, stream); +} + + +void knn_batch_mlogk(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k){ + + const float *query_xyz = query_xyz_tensor.data_ptr(); + const float *xyz = xyz_tensor.data_ptr(); + const int *batch_idxs = batch_idxs_tensor.data_ptr(); + const int *query_batch_offsets = query_batch_offsets_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + knn_batch_mlogk_cuda(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx, stream); +} + + + +void knn_batch_mlogk_half(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k){ + + const at::BFloat16 *query_xyz = query_xyz_tensor.data_ptr(); + const at::BFloat16 *xyz = xyz_tensor.data_ptr(); + const int 
*batch_idxs = batch_idxs_tensor.data_ptr(); + const int *query_batch_offsets = query_batch_offsets_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + knn_batch_mlogk_cuda_half(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx, stream); +} + + +void knn_batch_mlogk_half_fp16(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k){ + + const at::Half *query_xyz = query_xyz_tensor.data_ptr(); + const at::Half *xyz = xyz_tensor.data_ptr(); + const int *batch_idxs = batch_idxs_tensor.data_ptr(); + const int *query_batch_offsets = query_batch_offsets_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + knn_batch_mlogk_cuda_half_fp16(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx, stream); +} diff --git a/scenestreamer/models/ops/knn/src/knn_api.cpp b/scenestreamer/models/ops/knn/src/knn_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1851034bda4131138da3fe14909a7cfc98460708 --- /dev/null +++ b/scenestreamer/models/ops/knn/src/knn_api.cpp @@ -0,0 +1,17 @@ +// Motion Transformer (MTR): Motion Forecasting Transformer with Global Intention Localization and Local Movement Refinement +// Written by Shaoshuai Shi +// All Rights Reserved + + +#include +#include + +#include "knn_gpu.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("knn_batch", &knn_batch, "knn_batch"); + m.def("knn_batch_mlogk", &knn_batch_mlogk, "knn_batch_mlogk"); + m.def("knn_batch_mlogk_half", &knn_batch_mlogk_half, "knn_batch_mlogk_half"); + m.def("knn_batch_mlogk_half_fp16", &knn_batch_mlogk_half_fp16, "knn_batch_mlogk_half_fp16"); +} diff --git a/scenestreamer/models/ops/knn/src/knn_gpu.cu b/scenestreamer/models/ops/knn/src/knn_gpu.cu new file mode 100644 index 
0000000000000000000000000000000000000000..503cae012ee54df9256f76236793411214513533 --- /dev/null +++ b/scenestreamer/models/ops/knn/src/knn_gpu.cu @@ -0,0 +1,392 @@ +// Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 +// Published at NeurIPS 2022 +// Written by Li Jiang, Shaoshuai Shi +// All Rights Reserved + + +#include "knn_gpu.h" + +#include +#include +#include +#include +#include + +#define THREADS_PER_BLOCK 256 +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) + +__global__ void knn_batch_cuda_(int n, int m, int k, const float *__restrict__ xyz, const float *__restrict__ query_xyz, const int *__restrict__ batch_idxs, const int *__restrict__ query_batch_offsets, int *__restrict__ idx) { + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (pt_idx >= n) return; + + xyz += pt_idx * 3; + idx += pt_idx * k; + + float ox = xyz[0]; + float oy = xyz[1]; + float oz = xyz[2]; + + float best[100]; + int besti[100]; + for(int i = 0; i < k; i++){ + best[i] = 1e20; + besti[i] = -1; + } + + int batch_idx = batch_idxs[pt_idx]; + int start = query_batch_offsets[batch_idx]; + int end = query_batch_offsets[batch_idx + 1]; + + for (int i = start; i < end; ++i) { + float x = query_xyz[i * 3 + 0]; + float y = query_xyz[i * 3 + 1]; + float z = query_xyz[i * 3 + 2]; + float d2 = (ox - x) * (ox - x) + (oy - y) * (oy - y) + (oz - z) * (oz - z); + for(int p = 0; p < k; p++){ + if(d2 < best[p]){ + for(int q = k - 1; q > p; q--){ + best[q] = best[q - 1]; + besti[q] = besti[q - 1]; + } + best[p] = d2; + besti[p] = i - start; + break; + } + } + } + + for(int i = 0; i < k; i++){ + idx[i] = besti[i]; + } +} + + +__global__ void knn_batch_mlogk_cuda_(int n, int m, int k, const float *__restrict__ xyz, const float *__restrict__ query_xyz, const int *__restrict__ batch_idxs, const int *__restrict__ query_batch_offsets, int *__restrict__ idx) { + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (pt_idx >= n) return; + + xyz += pt_idx * 3; + idx += pt_idx * k; + + float 
ox = xyz[0]; + float oy = xyz[1]; + float oz = xyz[2]; + + float best[150]; + int besti[150]; + + int heap_len = 0; + + for(int i = 0; i <= k; i++){ + best[i] = std::numeric_limits::infinity(); + besti[i] = -1; + } + + int batch_idx = batch_idxs[pt_idx]; + int start = query_batch_offsets[batch_idx]; + int end = query_batch_offsets[batch_idx + 1]; + int temp_i; + float temp_f; + + for (int i = start; i < end; ++i) { + float x = query_xyz[i * 3 + 0]; + float y = query_xyz[i * 3 + 1]; + float z = query_xyz[i * 3 + 2]; + float d2 = (ox - x) * (ox - x) + (oy - y) * (oy - y) + (oz - z) * (oz - z); + + if (heap_len < k){ + heap_len++; + best[heap_len] = d2; + besti[heap_len] = i - start; + int cur_idx = heap_len, fa_idx = cur_idx >> 1; + + while (fa_idx > 0){ + if (best[cur_idx] < best[fa_idx]) break; + + temp_i = besti[cur_idx]; besti[cur_idx] = besti[fa_idx]; besti[fa_idx] = temp_i; + temp_f = best[cur_idx]; best[cur_idx] = best[fa_idx]; best[fa_idx] = temp_f; + cur_idx = fa_idx; + fa_idx = cur_idx >> 1; + } + } + else{ + if (d2 > best[1]) continue; + best[1] = d2; besti[1] = i - start; + + int cur_idx = 1, son_idx; + while (cur_idx <= k){ + son_idx = cur_idx << 1; + if (son_idx > k) break; + if (son_idx + 1 <= k && best[son_idx] < best[son_idx + 1]){ + son_idx++; + } + + if (son_idx <= k && best[cur_idx] < best[son_idx]){ + temp_i = besti[cur_idx]; besti[cur_idx] = besti[son_idx]; besti[son_idx] = temp_i; + temp_f = best[cur_idx]; best[cur_idx] = best[son_idx]; best[son_idx] = temp_f; + } + else break; + cur_idx = son_idx; + } + } + } + + for(int i = 1; i <= k; i++){ + idx[i - 1] = besti[i]; + } + // delete [] best; + // delete [] besti; +} + + + +__global__ void knn_batch_mlogk_cuda_half_(int n, int m, int k, const at::BFloat16 *__restrict__ xyz, const at::BFloat16 *__restrict__ query_xyz, const int *__restrict__ batch_idxs, const int *__restrict__ query_batch_offsets, int *__restrict__ idx) { + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (pt_idx >= n) 
return; + + xyz += pt_idx * 3; + idx += pt_idx * k; + + at::BFloat16 ox = xyz[0]; + at::BFloat16 oy = xyz[1]; + at::BFloat16 oz = xyz[2]; + + at::BFloat16 best[150]; + int besti[150]; + + int heap_len = 0; + + for(int i = 0; i <= k; i++){ + best[i] = __float2bfloat16(std::numeric_limits::infinity()); + besti[i] = -1; + } + + int batch_idx = batch_idxs[pt_idx]; + int start = query_batch_offsets[batch_idx]; + int end = query_batch_offsets[batch_idx + 1]; + int temp_i; + at::BFloat16 temp_f; + + for (int i = start; i < end; ++i) { + at::BFloat16 x = query_xyz[i * 3 + 0]; + at::BFloat16 y = query_xyz[i * 3 + 1]; + at::BFloat16 z = query_xyz[i * 3 + 2]; + at::BFloat16 d2 = (ox - x) * (ox - x) + (oy - y) * (oy - y) + (oz - z) * (oz - z); + + if (heap_len < k){ + heap_len++; + best[heap_len] = d2; + besti[heap_len] = i - start; + int cur_idx = heap_len, fa_idx = cur_idx >> 1; + + while (fa_idx > 0){ + if (best[cur_idx] < best[fa_idx]) break; + + temp_i = besti[cur_idx]; besti[cur_idx] = besti[fa_idx]; besti[fa_idx] = temp_i; + temp_f = best[cur_idx]; best[cur_idx] = best[fa_idx]; best[fa_idx] = temp_f; + cur_idx = fa_idx; + fa_idx = cur_idx >> 1; + } + } + else{ + if (d2 > best[1]) continue; + best[1] = d2; besti[1] = i - start; + + int cur_idx = 1, son_idx; + while (cur_idx <= k){ + son_idx = cur_idx << 1; + if (son_idx > k) break; + if (son_idx + 1 <= k && best[son_idx] < best[son_idx + 1]){ + son_idx++; + } + + if (son_idx <= k && best[cur_idx] < best[son_idx]){ + temp_i = besti[cur_idx]; besti[cur_idx] = besti[son_idx]; besti[son_idx] = temp_i; + temp_f = best[cur_idx]; best[cur_idx] = best[son_idx]; best[son_idx] = temp_f; + } + else break; + cur_idx = son_idx; + } + } + } + + for(int i = 1; i <= k; i++){ + idx[i - 1] = besti[i]; + } + // delete [] best; + // delete [] besti; +} + + + + +__global__ void knn_batch_mlogk_cuda_half_fp16_(int n, int m, int k, const at::Half *__restrict__ xyz, const at::Half *__restrict__ query_xyz, const int *__restrict__ batch_idxs, 
const int *__restrict__ query_batch_offsets, int *__restrict__ idx) { + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (pt_idx >= n) return; + + xyz += pt_idx * 3; + idx += pt_idx * k; + + at::Half ox = xyz[0]; + at::Half oy = xyz[1]; + at::Half oz = xyz[2]; + + at::Half best[150]; + int besti[150]; + + int heap_len = 0; + + for(int i = 0; i <= k; i++){ + best[i] = __float2half(std::numeric_limits::infinity()); + besti[i] = -1; + } + + int batch_idx = batch_idxs[pt_idx]; + int start = query_batch_offsets[batch_idx]; + int end = query_batch_offsets[batch_idx + 1]; + int temp_i; + at::Half temp_f; + + for (int i = start; i < end; ++i) { + at::Half x = query_xyz[i * 3 + 0]; + at::Half y = query_xyz[i * 3 + 1]; + at::Half z = query_xyz[i * 3 + 2]; + at::Half d2 = (ox - x) * (ox - x) + (oy - y) * (oy - y) + (oz - z) * (oz - z); + + if (heap_len < k){ + heap_len++; + best[heap_len] = d2; + besti[heap_len] = i - start; + int cur_idx = heap_len, fa_idx = cur_idx >> 1; + + while (fa_idx > 0){ + if (best[cur_idx] < best[fa_idx]) break; + + temp_i = besti[cur_idx]; besti[cur_idx] = besti[fa_idx]; besti[fa_idx] = temp_i; + temp_f = best[cur_idx]; best[cur_idx] = best[fa_idx]; best[fa_idx] = temp_f; + cur_idx = fa_idx; + fa_idx = cur_idx >> 1; + } + } + else{ + if (d2 > best[1]) continue; + best[1] = d2; besti[1] = i - start; + + int cur_idx = 1, son_idx; + while (cur_idx <= k){ + son_idx = cur_idx << 1; + if (son_idx > k) break; + if (son_idx + 1 <= k && best[son_idx] < best[son_idx + 1]){ + son_idx++; + } + + if (son_idx <= k && best[cur_idx] < best[son_idx]){ + temp_i = besti[cur_idx]; besti[cur_idx] = besti[son_idx]; besti[son_idx] = temp_i; + temp_f = best[cur_idx]; best[cur_idx] = best[son_idx]; best[son_idx] = temp_f; + } + else break; + cur_idx = son_idx; + } + } + } + + for(int i = 1; i <= k; i++){ + idx[i - 1] = besti[i]; + } + // delete [] best; + // delete [] besti; +} + + + + +void knn_batch_cuda(int n, int m, int k, const float *xyz, const float 
*query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream) { + // param xyz: (n, 3), float + // param query_xyz: (m, 3), float + // param batch_idxs: (n), int + // param query_batch_offsets: (B + 1), int, offsets[-1] = m + // param idx: (n, k), int + + cudaError_t err; + + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + knn_batch_cuda_<<>>(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx); + // cudaDeviceSynchronize(); // for using printf in kernel function + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +void knn_batch_mlogk_cuda(int n, int m, int k, const float *xyz, const float *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream) { + // param xyz: (n, 3), float + // param query_xyz: (m, 3), float + // param batch_idxs: (n), int + // param query_batch_offsets: (B + 1), int, offsets[-1] = m + // param idx: (n, k), int + + cudaError_t err; + + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + knn_batch_mlogk_cuda_<<>>(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx); + // cudaDeviceSynchronize(); // for using printf in kernel function + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +void knn_batch_mlogk_cuda_half(int n, int m, int k, const at::BFloat16 *xyz, const at::BFloat16 *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream) { + // param xyz: (n, 3), at::BFloat16 + // param query_xyz: (m, 3), at::BFloat16 + // param batch_idxs: (n), int + // param query_batch_offsets: (B + 1), int, offsets[-1] = m + // param idx: (n, k), int + + cudaError_t err; + + dim3 
blocks(DIVUP(n, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + knn_batch_mlogk_cuda_half_<<>>(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx); + // cudaDeviceSynchronize(); // for using printf in kernel function + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +void knn_batch_mlogk_cuda_half_fp16(int n, int m, int k, const at::Half *xyz, const at::Half *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream) { + // param xyz: (n, 3), at::Half + // param query_xyz: (m, 3), at::Half + // param batch_idxs: (n), int + // param query_batch_offsets: (B + 1), int, offsets[-1] = m + // param idx: (n, k), int + + cudaError_t err; + + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + knn_batch_mlogk_cuda_half_fp16_<<>>(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx); + // cudaDeviceSynchronize(); // for using printf in kernel function + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/scenestreamer/models/ops/knn/src/knn_gpu.h b/scenestreamer/models/ops/knn/src/knn_gpu.h new file mode 100644 index 0000000000000000000000000000000000000000..c758972e978b6a50da73ac52c9f03b955257819d --- /dev/null +++ b/scenestreamer/models/ops/knn/src/knn_gpu.h @@ -0,0 +1,32 @@ +// Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 +// Published at NeurIPS 2022 +// Written by Li Jiang, Shaoshuai Shi +// All Rights Reserved + + +#ifndef KNN_H +#define KNN_H +#include +#include +#include +#include +#include +#include +#include +#include +// #include + + +void knn_batch(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, 
at::Tensor idx_tensor, int n, int m, int k); +void knn_batch_cuda(int n, int m, int k, const float *xyz, const float *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream); + +void knn_batch_mlogk(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k); +void knn_batch_mlogk_cuda(int n, int m, int k, const float *xyz, const float *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream); + +void knn_batch_mlogk_half(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k); +void knn_batch_mlogk_cuda_half(int n, int m, int k, const at::BFloat16 *xyz, const at::BFloat16 *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream); + +void knn_batch_mlogk_half_fp16(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k); +void knn_batch_mlogk_cuda_half_fp16(int n, int m, int k, const at::Half *xyz, const at::Half *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream); + +#endif diff --git a/scenestreamer/models/ops/search_knn_indices.py b/scenestreamer/models/ops/search_knn_indices.py new file mode 100644 index 0000000000000000000000000000000000000000..04b61496bda86a722cd9ecf12c7f12ffe0c05591 --- /dev/null +++ b/scenestreamer/models/ops/search_knn_indices.py @@ -0,0 +1,215 @@ +import torch +import torch.nn.functional as F + +from scenestreamer.models.ops.knn import knn_utils + + +def search_k_nearest_object_indices( + ego_position_full, # The "ego object" position + ego_valid_mask, + neighbor_position_full, # The "other object" position + neighbor_valid_mask, + num_neighbors +): + """ + 
ego_position_full: [B, T, max_ego_objects, 2] + ego_valid_mask: [B, T, max_ego_objects] + neighbor_position_full: [B, T, max_neighbor_objects, 2] + neighbor_valid_mask: [B, T, max_neighbor_objects] + num_neighbors: int + """ + assert ego_position_full.ndim == 4 + B, T, max_ego_objects, pos_dim = ego_position_full.shape + # assert pos_dim == 3 + assert ego_valid_mask.shape == (B, T, max_ego_objects) + assert neighbor_position_full.shape[:2] == (B, T) + _, _, max_neighbor_objects, pos_dim = neighbor_position_full.shape + assert neighbor_valid_mask.shape == (B, T, max_neighbor_objects) + + ego_position_full = ego_position_full.clone() + neighbor_position_full = neighbor_position_full.clone() + + effective_batch_size = B * T + + valid_ego_position = ego_position_full[ego_valid_mask] # [num valid ego, 2] + if valid_ego_position.shape[-1] == 2: + valid_ego_position = F.pad(valid_ego_position, (0, 1)) # [num valid ego, 3] + + # Build a lookup table that translate the index of a valid ego object to the "batch index". + # The batch index should in range [0, B*T]. See below for more discussion. 
+ batch_index = torch.arange(0, effective_batch_size, device=ego_position_full.device, dtype=torch.int) # [B*T,] + batch_index = batch_index.reshape(B, T, 1) # [B, T, 1] + batch_index = batch_index.repeat(1, 1, max_ego_objects) # [B, T, max_ego_objects] + valid_batch_index = batch_index[ego_valid_mask] # [num valid ego,] + + neighbor_position_full[~neighbor_valid_mask] = 100000 + neighbor_position_flat = neighbor_position_full.flatten(start_dim=0, end_dim=2) # [B*T*max_neighbors, 2] + # neighbor_position_flat = neighbor_position_full[neighbor_valid_mask] # [num valid neighbor, 2] + if neighbor_position_flat.shape[-1] == 2: + neighbor_position_flat = F.pad(neighbor_position_flat, (0, 1)) # [num valid ego, 3] + + # neighbor_valid_mask is in [B, T, N] + # neighbor_batch_index = neighbor_valid_mask.sum(-1) # [B, T] + # neighbor_batch_index = neighbor_batch_index.reshape(-1) # [B * T] + # neighbor_batch_index = neighbor_batch_index.cumsum(-1).int() # [B * T] + # neighbor_batch_index = F.pad(neighbor_batch_index, (1, 0)) # [1 + B*T] + + # traffic_light_offsets is in shape [] + neighbor_offsets = \ + max_neighbor_objects * torch.arange(0, effective_batch_size + 1, device=ego_position_full.device, + dtype=torch.int) + + assert len(neighbor_offsets) - 1 == valid_batch_index.max() + 1 + assert neighbor_position_flat.shape[0] == neighbor_offsets.max() + assert neighbor_position_flat.shape[-1] == valid_ego_position.shape[-1] == 3 + + # Output is in range: [0, max_neighbors] + # (an alternative is to fall into [0, num valids neighbor], which is deprecated) + # print(f"Searching near {num_neighbors}, {neighbor_offsets}, {valid_batch_index}") + k_nearest_neighbor_index = knn_utils.knn_batch_mlogk( + valid_ego_position, # position of "ego" object + neighbor_position_flat, # position of "other objects" + valid_batch_index, # the batch index of each ego object, telling ego belongs to which batch + neighbor_offsets, # the index offset. 
For ego objects in (b, t), the offset will be (b*t-1)*max_neighbors + # neighbor_batch_index, + num_neighbors + ) + # print("Finish searching.") + + # It is possible that at some (batch: b, time: t), there are no valid neighbor objects at all! + # We will do postprocessing to tell that for those ego objects in (b, t), they have no neighbors since + # all neighbors at that (b, t) are invalid. + this_batch_has_no_neighbor = (~neighbor_valid_mask).all(dim=-1).flatten(0, 1) # after flatten the shape is [B*T] + + i_have_no_neighbor = this_batch_has_no_neighbor[valid_batch_index] # [num_valid,] + + k_nearest_neighbor_index[i_have_no_neighbor] = -1 + + # return k_nearest_neighbor_index + + ret = torch.empty((B, T, max_ego_objects, num_neighbors)).to(batch_index) + ret.fill_(-1) + ret[ego_valid_mask] = k_nearest_neighbor_index + + return ret + + +def search_k_nearest_map_feature_indicies( + ego_position_full, # The "ego object" position + ego_valid_mask, + neighbor_position_full, # The "other object" position + neighbor_valid_mask, + num_neighbors +): + assert ego_position_full.ndim == 4 + B, T, max_ego_objects, pos_dim = ego_position_full.shape + assert ego_valid_mask.shape == (B, T, max_ego_objects) + # assert neighbor_position_full.shape[:2] == (B, T) + _, max_map_feats, pos_dim = neighbor_position_full.shape + assert neighbor_valid_mask.shape == (B, max_map_feats) + + # effective_batch_size = B * T + + valid_ego_position = ego_position_full[ego_valid_mask] # [num valid ego, 2] + if valid_ego_position.shape[-1] == 2: + valid_ego_position = F.pad(valid_ego_position, (0, 1)) # [num valid ego, 3] + + # Build a lookup table that translate the index of a valid ego object to the "batch index". + # The batch index should in range [0, B]. 
+ batch_index = torch.arange(0, B, device=ego_position_full.device, dtype=torch.int) # [B,] + batch_index = batch_index.reshape(B, 1, 1) # [B, 1, 1] + batch_index = batch_index.repeat(1, T, max_ego_objects) # [B, T, max_ego_objects] + valid_batch_index = batch_index[ego_valid_mask] # [num valid ego,] + + # neighbor_position_full[~neighbor_valid_mask] = float("+inf") + # neighbor_position_flat = neighbor_position_full.flatten(start_dim=0, end_dim=1) # [B*T*max_neighbors, 2] + + # traffic_light_offsets is in shape [] + # neighbor_offsets = \ + # max_map_feats * torch.arange(0, B + 1, device=ego_position_full.device, dtype=torch.int) + + neighbor_position_flat = neighbor_position_full[neighbor_valid_mask] # [num valid neighbor, 2] + if neighbor_position_flat.shape[-1] == 2: + neighbor_position_flat = F.pad(neighbor_position_flat, (0, 1)) # [num valid ego, 3] + + # neighbor_valid_mask is in [B, M] + neighbor_batch_index = neighbor_valid_mask.sum(-1) # [B,] + neighbor_batch_index = neighbor_batch_index.cumsum(-1).int() # [B,] + neighbor_batch_index = F.pad(neighbor_batch_index, (1, 0)) # [1+B] + + assert len(neighbor_batch_index) - 1 == valid_batch_index.max() + 1 + assert neighbor_position_flat.shape[0] == neighbor_batch_index.max() + assert neighbor_position_flat.shape[-1] == valid_ego_position.shape[-1] == 3 + + # Output will be in shape [num valid ego objects, K] + k_nearest_neighbor_index = knn_utils.knn_batch_mlogk( + valid_ego_position, # position of "ego" object + neighbor_position_flat, # position of "other objects" + valid_batch_index, # the batch index of each ego object, telling ego belongs to which batch + # neighbor_offsets, # the index offset. 
For ego objects in (b, t), the offset will be (b*t-1)*max_neighbors + neighbor_batch_index, + num_neighbors + ) + + # return k_nearest_neighbor_index + + ret = torch.empty((B, T, max_ego_objects, num_neighbors)).to(batch_index) + ret.fill_(-1) + ret[ego_valid_mask] = k_nearest_neighbor_index + + return ret + + +def search_k_nearest_map_feature_indicies_for_map( + ego_position_full, # The "ego object" position + ego_valid_mask, + num_neighbors +): + B, max_map_feats, pos_dim = ego_position_full.shape + assert ego_valid_mask.shape == (B, max_map_feats) + + valid_ego_position = ego_position_full[ego_valid_mask] # [num valid ego, 2] + if valid_ego_position.shape[-1] == 2: + valid_ego_position = F.pad(valid_ego_position, (0, 1)) # [num valid ego, 3] + + # Build a lookup table that translate the index of a valid map feat to the "batch index". + # The batch index should in range [0, B]. + batch_index = torch.arange(0, B, device=ego_position_full.device, dtype=torch.int) # [B,] + batch_index = batch_index.reshape(B, 1) # [B, 1, 1] + batch_index = batch_index.repeat(1, max_map_feats) # [B, max_ego_objects] + valid_batch_index = batch_index[ego_valid_mask] # [num valid map feat,] + + # ego_position_full[~ego_valid_mask] = float("+inf") + # neighbor_position_flat = ego_position_full.flatten(start_dim=0, end_dim=1) # [B*max_neighbors, 2] + neighbor_position_flat = ego_position_full[ego_valid_mask] + if neighbor_position_flat.shape[-1] == 2: + neighbor_position_flat = F.pad(neighbor_position_flat, (0, 1)) # [num valid ego, 3] + + # traffic_light_offsets is in shape [] + # neighbor_offsets = max_map_feats * torch.arange(0, B + 1, device=ego_position_full.device, dtype=torch.int) + + # neighbor_valid_mask is in [B, M] + neighbor_batch_index = ego_valid_mask.sum(-1) # [B,] + neighbor_batch_index = neighbor_batch_index.cumsum(-1).int() # [B,] + neighbor_batch_index = F.pad(neighbor_batch_index, (1, 0)) # [1+B] + + assert len(neighbor_batch_index) - 1 == valid_batch_index.max() + 1 
+ assert neighbor_position_flat.shape[0] == neighbor_batch_index.max() + assert neighbor_position_flat.shape[-1] == valid_ego_position.shape[-1] == 3 + + # Output will be in shape [num valid ego objects, K] + k_nearest_neighbor_index = knn_utils.knn_batch_mlogk( + valid_ego_position, # position of "ego" object + neighbor_position_flat, # position of "other objects" + valid_batch_index, # the batch index of each ego object, telling ego belongs to which batch + # neighbor_offsets, # the index offset. For ego objects in (b, t), the offset will be (b*t-1)*max_neighbors + neighbor_batch_index, # the index offset. For ego objects in (b, t), the offset will be (b*t-1)*max_neighbors + num_neighbors + ) + + # return k_nearest_neighbor_index + + ret = torch.empty((B, max_map_feats, num_neighbors)).to(batch_index) + ret.fill_(-1) + ret[ego_valid_mask] = k_nearest_neighbor_index + return ret diff --git a/scenestreamer/models/relation.py b/scenestreamer/models/relation.py new file mode 100644 index 0000000000000000000000000000000000000000..8cce5a1ac1d1e7e33822f9a3fa9ff1a1a82cb3bb --- /dev/null +++ b/scenestreamer/models/relation.py @@ -0,0 +1,855 @@ +import numpy as np +import torch + +# from torch.nn.modules.transformer import TransformerEncoderLayer as NativeTransformerEncoderLayer +from scenestreamer.dataset import constants +from scenestreamer.models.layers import position_encoding_utils +from scenestreamer.utils import rotate, utils + +# def pairwise_mask(mask): +# """ +# input mask is in shape (B, N), we need to prepare a pairwise mask in shape (B, N, N). +# It's not correct to naively expand the mask. We need to maintain the symmetry of the mask. 
+# """ +# B, N = mask.shape +# mask = mask.unsqueeze(1).expand(B, N, N) +# mask = mask & mask.transpose(1, 2) +# return mask + + +def pairwise_mask(mask_a, mask_b): + assert mask_a.ndim == mask_b.ndim == 2 + mask_a = mask_a.unsqueeze(-1) + mask_b = mask_b.unsqueeze(-2) + mask = torch.logical_and(mask_a, mask_b) + return mask + + +def pairwise_relative_diff(positions_a, positions_b): + """ + Compute pairwise relative diffs for a batch of objects. + For the ouput [b, i, j, :], it means the relative differences of [b, j] - [b, i], + which is the pos of j in i's coordinate system. + + Parameters: + - positions: A PyTorch tensor of shape (B, N, 2) + + Returns: + - A PyTorch tensor of shape (B, N, N, 2) containing pairwise relative positions. + """ + assert positions_a.ndim == positions_b.ndim + # assert positions_a.ndim == 3 or positions_a.ndim == 2 + # Expand dimensions to get tensors of shapes (B, N, 1, ...) and (B, 1, N, ...) + positions_expanded_a = positions_a.unsqueeze(2) # Shape: (B, N, 1, ...) + positions_expanded_b = positions_b.unsqueeze(1) # Shape: (B, 1, N, ...) + + # Compute the pairwise relative positions by subtraction + relative_positions = positions_expanded_b - positions_expanded_a # Shape: (B, N, N, ...) + + return relative_positions + + +def compute_relation( + query_pos, + query_heading, + query_valid_mask, + key_pos, + key_heading, + key_valid_mask, + hidden_dim, + causal_valid_mask, + knn=128, + max_distance=None, + gather=True, + return_pe=True, + query_step=None, + query_vel=None, + key_step=None, + key_vel=None, + include_contour=False, + query_width=None, + query_length=None, + key_width=None, + key_length=None, + non_agent_relation=False, +): + """ + Compute the relation encoding for the transformer encoder. 
+ """ + assert max_distance is None, "Not implemented" + assert query_pos.ndim == key_pos.ndim == 3 + assert query_heading.ndim == key_heading.ndim == 2 + assert query_valid_mask.ndim == key_valid_mask.ndim == 2 + + pairwise_heading = pairwise_relative_diff(query_heading, key_heading) + + heading_fill_0_mask = pairwise_mask( + query_heading == constants.HEADING_PLACEHOLDER, key_heading == constants.HEADING_PLACEHOLDER + ) + pairwise_heading[heading_fill_0_mask] = 0 + + rel_pos = pairwise_relative_diff(query_pos[..., :2], key_pos[..., :2]) + + rel_vel = None + if query_vel is not None: + assert key_vel is not None + rel_vel = pairwise_relative_diff(query_vel, key_vel) + + B, Q = query_heading.shape + K = key_heading.shape[1] + + # i's local coordinate's y-axis (the heading) in the global coordinate + i_local_y_wrt_global = query_heading.reshape(B, Q, 1).expand(B, Q, K) + i_local_x_wrt_global = i_local_y_wrt_global - np.pi / 2 + rotated_pos = rotate(rel_pos[..., 0], rel_pos[..., 1], angle=-i_local_x_wrt_global) + + if rel_vel is not None: + rotated_vel = rotate(rel_vel[..., 0], rel_vel[..., 1], angle=-i_local_x_wrt_global) + + valid_mask = pairwise_mask(query_valid_mask, key_valid_mask) + + if include_contour: + contour_q = utils.cal_polygon_contour_torch( + x=query_pos[..., 0], y=query_pos[..., 1], theta=query_heading, width=query_width, length=query_length + ) + contour_k = utils.cal_polygon_contour_torch( + x=key_pos[..., 0], y=key_pos[..., 1], theta=key_heading, width=key_width, length=key_length + ) + contour_diff = pairwise_relative_diff(contour_q, contour_k) + contour_diff = rotate( + contour_diff[..., 0], contour_diff[..., 1], angle=-i_local_x_wrt_global.unsqueeze(-1).expand(-1, -1, -1, 4) + ) + + # THRESHOLD = 100 + dist = rel_pos.norm(dim=-1) + if causal_valid_mask is not None: + if causal_valid_mask.ndim == 2: + # the causal mask is not batched + causal_valid_mask = causal_valid_mask.unsqueeze(0).expand(B, -1, -1) + dist = 
dist.masked_fill_(~causal_valid_mask, float("+inf")) + valid_mask = valid_mask & causal_valid_mask + + else: + raise ValueError() # TODO + + # dist_mask = dist < THRESHOLD + # rel_mask = torch.logical_and(mask, dist_mask) + rel_mask = valid_mask + + if query_step is not None: + step_diff = pairwise_relative_diff(query_step, key_step) + + indices = None + if knn: + dist = dist.masked_fill_(~valid_mask, float("+inf")) + indices = dist.argsort(dim=-1)[..., :knn] + + if gather: + rotated_pos = torch.gather(rotated_pos, dim=-2, index=indices.unsqueeze(-1).expand(-1, -1, -1, 2)) + if rel_vel is not None: + rotated_vel = torch.gather(rotated_vel, dim=-2, index=indices.unsqueeze(-1).expand(-1, -1, -1, 2)) + pairwise_heading = torch.gather(pairwise_heading, dim=-1, index=indices) + rel_mask = torch.gather(rel_mask, dim=-1, index=indices) + if query_step is not None: + step_diff = torch.gather(step_diff, dim=-1, index=indices) + if include_contour: + contour_diff = torch.gather( + contour_diff, dim=-3, index=indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, 4, 2) + ) + + else: + + # Create a new mask with the same shape as rel_mask, initially set to False + original_valid_mask = torch.zeros_like(valid_mask, dtype=torch.bool) + + # Use advanced indexing to set True for indices selected by KNN + batch_indices = torch.arange(B).view(B, 1, 1) + query_indices = torch.arange(Q).view(1, Q, 1) + assert original_valid_mask.shape[0] == B, ( + original_valid_mask.shape, indices.shape, valid_mask.shape, dist.shape, B + ) + assert original_valid_mask.shape[1] == Q, ( + original_valid_mask.shape, indices.shape, valid_mask.shape, dist.shape, Q + ) + assert original_valid_mask.shape[2] == K, ( + original_valid_mask.shape, indices.shape, valid_mask.shape, dist.shape, K + ) + original_valid_mask[batch_indices, query_indices, indices] = True + + # Update rel_mask to only include indices selected by KNN + rel_mask = rel_mask & original_valid_mask + + # Just pass them: + # rotated_pos + # 
pairwise_heading + if return_pe: + pos_pe = position_encoding_utils.gen_sineembed_for_relation( + rotated_pos[rel_mask], pairwise_heading[rel_mask], hidden_dim=hidden_dim + ) + pos_pe = utils.unwrap(pos_pe, rel_mask) + assert query_step is None + assert query_vel is None + assert not include_contour + return pos_pe, rel_mask, indices + else: + distance = torch.norm(rotated_pos, p=2, dim=-1) + ret = [rotated_pos, distance[..., None], pairwise_heading[..., None]] + if query_step is not None: + assert key_step is not None + ret.append(step_diff[..., None]) + if query_vel is not None: + ret.append(rotated_vel) + if include_contour: + ret.append(contour_diff.flatten(-2, -1)) + ret = torch.cat(ret, dim=-1) + ret[~rel_mask] = 0 + return ret, rel_mask, indices + + +def compute_relation_simple_relation( + *, + query_pos, + query_heading, + query_valid_mask, + key_pos, + key_heading, + key_valid_mask, + causal_valid_mask, + knn=128, + max_distance=None, + gather=True, + query_step=None, + key_step=None, + query_width=None, + query_length=None, + key_width=None, + key_length=None, + non_agent_relation=False, + per_contour_point_relation=None, + hidden_dim=None, # Useless + return_pe=None, # Useless +): + """ + Compute the relation encoding for the transformer encoder. 
+ """ + assert per_contour_point_relation is not None, "Not implemented" + assert query_pos.ndim == key_pos.ndim == 3 + assert query_heading.ndim == key_heading.ndim == 2 + assert query_valid_mask.ndim == key_valid_mask.ndim == 2 + + pairwise_heading = pairwise_relative_diff(query_heading, key_heading) + + heading_fill_0_mask = pairwise_mask( + query_heading == constants.HEADING_PLACEHOLDER, key_heading == constants.HEADING_PLACEHOLDER + ) + pairwise_heading[heading_fill_0_mask] = 0 + + rel_pos = pairwise_relative_diff(query_pos[..., :2], key_pos[..., :2]) + + B, Q = query_heading.shape + K = key_heading.shape[1] + + # i's local coordinate's y-axis (the heading) in the global coordinate + i_local_y_wrt_global = query_heading.reshape(B, Q, 1).expand(B, Q, K) + i_local_x_wrt_global = i_local_y_wrt_global - np.pi / 2 + + rotated_pos = rotate(rel_pos[..., 0], rel_pos[..., 1], angle=-i_local_x_wrt_global) + + # if rel_vel is not None: + # rotated_vel = rotate(rel_vel[..., 0], rel_vel[..., 1], angle=-i_local_x_wrt_global) + + valid_mask = pairwise_mask(query_valid_mask, key_valid_mask) + + if not non_agent_relation: + + if per_contour_point_relation: + contour_q_center = utils.cal_polygon_contour_torch( + x=query_pos[..., 0], + y=query_pos[..., 1], + theta=query_heading, + + # Note that set width and length to zeros so that the contour is a point. + # There is no need to compute per-contour-point relation. 
+ width=torch.zeros_like(query_pos[..., 0]), + length=torch.zeros_like(query_pos[..., 0]) + ) + contour_k = utils.cal_polygon_contour_torch( + x=key_pos[..., 0], + y=key_pos[..., 1], + theta=key_heading, + width=key_width if key_width is not None else torch.zeros_like(key_pos[..., 0]), + length=key_length if key_length is not None else torch.zeros_like(key_pos[..., 0]) + ) + contour_q = utils.cal_polygon_contour_torch( + x=query_pos[..., 0], + y=query_pos[..., 1], + theta=query_heading, + width=query_width if query_width is not None else torch.zeros_like(query_pos[..., 0]), + length=query_length if query_length is not None else torch.zeros_like(query_pos[..., 0]) + ) + contour_k_center = utils.cal_polygon_contour_torch( + x=key_pos[..., 0], + y=key_pos[..., 1], + theta=key_heading, + width=torch.zeros_like(key_pos[..., 0]), + length=torch.zeros_like(key_pos[..., 0]) + ) + contour_diff_in_q = pairwise_relative_diff(contour_q_center, contour_k) + # contour_diff_in_q = rotate( + # contour_diff_in_q[..., 0], + # contour_diff_in_q[..., 1], + # angle=-i_local_x_wrt_global.unsqueeze(-1).expand(-1, -1, -1, 4) + # ) + # contour_info = contour_diff_in_q.reshape(B, Q, K, 8) + contour_diff_in_q_min = contour_diff_in_q.min(dim=-2).values + contour_diff_in_q_max = contour_diff_in_q.max(dim=-2).values + contour_diff_in_k = pairwise_relative_diff(contour_k_center, contour_q) + + # i's local coordinate's y-axis (the heading) in the global coordinate + i_local_y_wrt_global_key = key_heading.reshape(B, K, 1).expand(B, K, Q) + i_local_x_wrt_global_key = i_local_y_wrt_global_key - np.pi / 2 + contour_diff_in_k = rotate( + contour_diff_in_k[..., 0], + contour_diff_in_k[..., 1], + angle=-i_local_x_wrt_global_key.unsqueeze(-1).expand(-1, -1, -1, 4) + ) + contour_diff_in_k_min = contour_diff_in_k.min(dim=-2).values + contour_diff_in_k_max = contour_diff_in_k.max(dim=-2).values + contour_diff_in_k_min = contour_diff_in_k_min.permute(0, 2, 1, 3) + contour_diff_in_k_max = 
contour_diff_in_k_max.permute(0, 2, 1, 3) + + contour_info = torch.cat( + [contour_diff_in_q_min, contour_diff_in_q_max, contour_diff_in_k_min, contour_diff_in_k_max], dim=-1 + ) + + else: + contour_q_center = utils.cal_polygon_contour_torch( + x=query_pos[..., 0], + y=query_pos[..., 1], + theta=query_heading, + + # Note that set width and length to zeros so that the contour is a point. + # There is no need to compute per-contour-point relation. + width=torch.zeros_like(query_pos[..., 0]), + length=torch.zeros_like(query_pos[..., 0]) + ) + contour_k = utils.cal_polygon_contour_torch( + x=key_pos[..., 0], + y=key_pos[..., 1], + theta=key_heading, + width=key_width if key_width is not None else torch.zeros_like(key_pos[..., 0]), + length=key_length if key_length is not None else torch.zeros_like(key_pos[..., 0]) + ) + contour_diff_in_q = pairwise_relative_diff(contour_q_center, contour_k) + contour_diff_in_q = rotate( + contour_diff_in_q[..., 0], + contour_diff_in_q[..., 1], + angle=-i_local_x_wrt_global.unsqueeze(-1).expand(-1, -1, -1, 4) + ) + contour_info = contour_diff_in_q.reshape(B, Q, K, 8) + + # THRESHOLD = 100 + dist = rel_pos.norm(dim=-1) + if causal_valid_mask is not None: + if causal_valid_mask.ndim == 2: + # the causal mask is not batched + causal_valid_mask = causal_valid_mask.unsqueeze(0).expand(B, -1, -1) + dist = dist.masked_fill_(~causal_valid_mask, float("+inf")) + valid_mask = valid_mask & causal_valid_mask + + else: + raise ValueError() # TODO + + # dist_mask = dist < THRESHOLD + # rel_mask = torch.logical_and(mask, dist_mask) + dist_argsort = dist.argsort(dim=-1) + if max_distance is not None: + within_dist = dist < max_distance # Shape (B, Q, K) + # Allow at least 8 neighbors... 
+ closest = dist_argsort[..., :8] + # fill in True for these 8 neighbors + within_dist[torch.arange(B).view(B, 1, 1), torch.arange(Q).view(1, Q, 1), closest] = True + + valid_mask = valid_mask & within_dist + + # rel_mask = valid_mask + + if not non_agent_relation and query_step is not None: + step_diff = pairwise_relative_diff(query_step, key_step) + else: + step_diff = None + + indices = None + if knn: + dist = dist.masked_fill_(~valid_mask, float("+inf")) + indices = dist_argsort[..., :knn] + + if gather: + rotated_pos = torch.gather(rotated_pos, dim=-2, index=indices.unsqueeze(-1).expand(-1, -1, -1, 2)) + # if rel_vel is not None: + # rotated_vel = torch.gather(rotated_vel, dim=-2, index=indices.unsqueeze(-1).expand(-1, -1, -1, 2)) + pairwise_heading = torch.gather(pairwise_heading, dim=-1, index=indices) + valid_mask = torch.gather(valid_mask, dim=-1, index=indices) + if query_step is not None: + step_diff = torch.gather(step_diff, dim=-1, index=indices) + # if include_contour: + + if not non_agent_relation: + contour_info = torch.gather(contour_info, dim=-2, index=indices.unsqueeze(-1).expand(-1, -1, -1, 8)) + + else: + + # Create a new mask with the same shape as rel_mask, initially set to False + original_valid_mask = torch.zeros_like(valid_mask, dtype=torch.bool) + + # Use advanced indexing to set True for indices selected by KNN + batch_indices = torch.arange(B).view(B, 1, 1) + query_indices = torch.arange(Q).view(1, Q, 1) + assert original_valid_mask.shape[0] == B, ( + original_valid_mask.shape, indices.shape, valid_mask.shape, dist.shape, B + ) + assert original_valid_mask.shape[1] == Q, ( + original_valid_mask.shape, indices.shape, valid_mask.shape, dist.shape, Q + ) + assert original_valid_mask.shape[2] == K, ( + original_valid_mask.shape, indices.shape, valid_mask.shape, dist.shape, K + ) + original_valid_mask[batch_indices, query_indices, indices] = True + + # Update rel_mask to only include indices selected by KNN + valid_mask = valid_mask & 
def compute_relation_for_scenestreamer(
    *,
    query_pos,
    query_heading,
    query_valid_mask,
    key_pos,
    key_heading,
    key_valid_mask,
    causal_valid_mask,
    require_relation,
    require_relation_for_key=None,
    knn,
    max_distance,
    gather=True,
    query_step=None,
    key_step=None,
    query_width=None,
    query_length=None,
    key_width=None,
    key_length=None,
    non_agent_relation=False,
    force_attention_mask=None,
):
    """
    Compute dense pairwise relation features and the attention mask between queries and keys.

    All relations are expressed in each query's local frame (the relative position is rotated by
    ``-(query_heading - pi / 2)``).

    Args:
        query_pos, key_pos: 3-D tensors; only the first two channels of the last dim (x, y) are used.
        query_heading, key_heading: 2-D tensors of shape (B, Q) / (B, K). Pairs where both headings
            equal ``constants.HEADING_PLACEHOLDER`` (per ``pairwise_mask``) get a heading diff of 0.
        query_valid_mask, key_valid_mask: 2-D boolean masks.
        causal_valid_mask: None, a (Q, K) mask shared by the batch, or a (B, Q, K) mask.
        require_relation: None, or a per-query boolean mask. Pairs for which the pairwise
            "require relation" mask is False bypass the distance and KNN filters.
        require_relation_for_key: optional per-key counterpart of ``require_relation``;
            ``require_relation`` is reused for the keys when it is None.
        knn: either an ``int`` (same k for every query; ``gather`` must then be False), or a tensor
            holding a per-query k (queries are grouped by their k value).
        max_distance: None, or a distance threshold. The 8 closest keys are always kept.
        gather: must be False when ``knn`` is an int; otherwise unused.
        query_step, key_step: optional step indices; their pairwise diff is appended as a feature.
        query_width, query_length, key_width, key_length: optional box sizes for the contour
            features; zeros are used when omitted.
        non_agent_relation: if True, the 8 contour features are not computed/appended.
        force_attention_mask: optional (B, Q, K) mask of pairs that are re-enabled at the end,
            provided both tokens are valid.

    Returns:
        (ret, valid_mask, require_relation_pairwise) where ``ret`` is (B, Q, K, F) with the features
        [relative direction, distance, heading diff, (step diff), (8 contour values)] and zeros at
        masked pairs, ``valid_mask`` is the (B, Q, K) boolean mask, and
        ``require_relation_pairwise`` is the pairwise require-relation mask (or None).
    """
    assert query_pos.ndim == key_pos.ndim == 3
    assert query_heading.ndim == key_heading.ndim == 2
    assert query_valid_mask.ndim == key_valid_mask.ndim == 2

    pairwise_heading = pairwise_relative_diff(query_heading, key_heading)

    heading_fill_0_mask = pairwise_mask(
        query_heading == constants.HEADING_PLACEHOLDER, key_heading == constants.HEADING_PLACEHOLDER
    )
    pairwise_heading[heading_fill_0_mask] = 0

    rel_pos = pairwise_relative_diff(query_pos[..., :2], key_pos[..., :2])

    B, Q = query_heading.shape
    K = key_heading.shape[1]

    # i's local coordinate's y-axis (the heading) in the global coordinate
    i_local_y_wrt_global = query_heading.reshape(B, Q, 1).expand(B, Q, K)
    i_local_x_wrt_global = i_local_y_wrt_global - np.pi / 2

    rotated_pos = rotate(rel_pos[..., 0], rel_pos[..., 1], angle=-i_local_x_wrt_global)

    valid_mask = pairwise_mask(query_valid_mask, key_valid_mask)

    # Kept to verify at the end that the final mask never enables an invalid token pair.
    raw_valid_mask = valid_mask.clone()

    if force_attention_mask is not None:
        if causal_valid_mask is not None:
            # Guarded so that a missing causal mask does not crash with an AttributeError.
            assert force_attention_mask.shape == causal_valid_mask.shape
        assert force_attention_mask.shape == valid_mask.shape
        # First remove impossible relations. The forced mask is intentionally NOT intersected
        # with the causal mask.
        force_attention_mask = force_attention_mask & valid_mask

    if require_relation is not None:
        key_require_relation = require_relation if require_relation_for_key is None else require_relation_for_key
        require_relation_pairwise = pairwise_mask(require_relation, key_require_relation)
    else:
        require_relation_pairwise = None

    if not non_agent_relation:
        contour_q_center = utils.cal_polygon_contour_torch(
            x=query_pos[..., 0],
            y=query_pos[..., 1],
            theta=query_heading,
            width=query_width if query_width is not None else torch.zeros_like(query_pos[..., 0]),
            length=query_length if query_length is not None else torch.zeros_like(query_pos[..., 0])
        )
        contour_k = utils.cal_polygon_contour_torch(
            x=key_pos[..., 0],
            y=key_pos[..., 1],
            theta=key_heading,
            width=key_width if key_width is not None else torch.zeros_like(key_pos[..., 0]),
            length=key_length if key_length is not None else torch.zeros_like(key_pos[..., 0])
        )
        contour_diff_in_q = pairwise_relative_diff(contour_q_center, contour_k)
        contour_diff_in_q = rotate(
            contour_diff_in_q[..., 0],
            contour_diff_in_q[..., 1],
            angle=-i_local_x_wrt_global.unsqueeze(-1).expand(-1, -1, -1, 4)
        )
        contour_info = contour_diff_in_q.reshape(B, Q, K, 8)

    dist = rel_pos.norm(dim=-1)
    if causal_valid_mask is not None:
        if causal_valid_mask.ndim == 2:
            # the causal mask is not batched
            causal_valid_mask = causal_valid_mask.unsqueeze(0).expand(B, -1, -1)
        elif causal_valid_mask.ndim == 3:
            assert valid_mask.shape == causal_valid_mask.shape, (valid_mask.shape, causal_valid_mask.shape)
        else:
            raise ValueError(
                f"causal_valid_mask must be 2-D (Q, K) or 3-D (B, Q, K), got ndim={causal_valid_mask.ndim}"
            )
        # Causally masked keys are pushed to the end of the distance ranking.
        dist = dist.masked_fill_(~causal_valid_mask, float("+inf"))
        valid_mask = valid_mask & causal_valid_mask

    # Index helpers live on the same device as the tensors they index.
    batch_indices = torch.arange(B, device=dist.device).view(B, 1, 1)

    dist_argsort = dist.argsort(dim=-1)
    if max_distance is not None:
        within_dist = dist < max_distance  # Shape (B, Q, K)
        # Allow at least 8 neighbors: force the 8 closest keys to count as "within distance".
        closest = dist_argsort[..., :8]
        within_dist[batch_indices, torch.arange(Q, device=dist.device).view(1, Q, 1), closest] = True

        if require_relation is not None:
            # Only these pairs pass the distance filter:
            # 1) both Q and K require relation and they are close, or
            # 2) at least one of Q or K does not require relation.
            within_dist = torch.logical_or(within_dist & require_relation_pairwise, ~require_relation_pairwise)

        valid_mask = valid_mask & within_dist

    if query_step is not None:
        step_diff = pairwise_relative_diff(query_step, key_step)
    else:
        step_diff = None

    assert knn is not None
    # NOTE(review): this fill happens after `dist_argsort` was computed, so it does not change
    # the neighbor ranking used below. Kept as-is to preserve the existing behavior.
    dist = dist.masked_fill_(~valid_mask, float("+inf"))

    new_valid_mask = torch.zeros_like(valid_mask, dtype=torch.bool)
    if isinstance(knn, int):
        assert gather is False
        indices = dist_argsort[..., :knn]

        query_indices = torch.arange(Q, device=dist.device).view(1, Q, 1)
        assert new_valid_mask.shape == (B, Q, K), (
            new_valid_mask.shape, indices.shape, valid_mask.shape, dist.shape, (B, Q, K)
        )
        # Mark the k nearest keys of every query.
        new_valid_mask[batch_indices, query_indices, indices] = True
    else:
        for knn_group in knn.unique():
            # Never ask for more neighbors than there are keys; slicing in the int branch
            # clamps implicitly, explicit indexing here would raise an IndexError.
            num_neighbors = min(int(knn_group), K)
            query_indices = (knn == knn_group).nonzero(as_tuple=True)[1].view(1, -1, 1)
            neighbor_rank = torch.arange(num_neighbors, device=dist.device).view(1, 1, num_neighbors)
            indices_subgroup = dist_argsort[batch_indices, query_indices, neighbor_rank]
            new_valid_mask[batch_indices, query_indices, indices_subgroup] = True

    if require_relation is not None:
        # Pairs that do not require a relation are exempt from the KNN filter.
        new_valid_mask = torch.logical_or(new_valid_mask & require_relation_pairwise, ~require_relation_pairwise)
    valid_mask = valid_mask & new_valid_mask

    # Now, put back the forced pairs.
    if force_attention_mask is not None:
        valid_mask = valid_mask | force_attention_mask

    distance = torch.norm(rotated_pos, p=2, dim=-1)
    relative_direction = torch.atan2(rotated_pos[..., 1], rotated_pos[..., 0])
    ret = [relative_direction[..., None], distance[..., None], pairwise_heading[..., None]]

    if step_diff is not None:
        ret.append(step_diff[..., None])

    if not non_agent_relation:
        ret.append(contour_info)

    ret = torch.cat(ret, dim=-1)
    ret[~valid_mask] = 0

    # The final mask must be a subset of the raw (token validity) mask.
    assert (raw_valid_mask.float() >= valid_mask.float()).all(), (raw_valid_mask.shape, valid_mask.shape)

    return ret, valid_mask, require_relation_pairwise
+ batch_causal_valid_mask = torch.zeros(B, T * raw_Q, T * raw_K, dtype=torch.bool) + for t in range(1, T): + batch_causal_valid_mask[:, t * raw_Q : (t + 1) * raw_Q, (t - 1) * raw_K : t * raw_K] = True + + pairwise_heading = pairwise_relative_diff(query_heading, key_heading) + + heading_fill_0_mask = pairwise_mask( + query_heading == constants.HEADING_PLACEHOLDER, key_heading == constants.HEADING_PLACEHOLDER + ) + pairwise_heading[heading_fill_0_mask] = 0 + + rel_pos = pairwise_relative_diff(query_pos[..., :2], key_pos[..., :2]) + + + # i's local coordinate's y-axis (the heading) in the global coordinate + i_local_y_wrt_global = query_heading.reshape(B, Q, 1).expand(B, Q, K) + i_local_x_wrt_global = i_local_y_wrt_global - np.pi / 2 + + rotated_pos = rotate(rel_pos[..., 0], rel_pos[..., 1], angle=-i_local_x_wrt_global) + + # if rel_vel is not None: + # rotated_vel = rotate(rel_vel[..., 0], rel_vel[..., 1], angle=-i_local_x_wrt_global) + + valid_mask = pairwise_mask(query_valid_mask, key_valid_mask) + + if not non_agent_relation: + contour_q_center = utils.cal_polygon_contour_torch( + x=query_pos[..., 0], + y=query_pos[..., 1], + theta=query_heading, + + # Note that set width and length to zeros so that the contour is a point. + # There is no need to compute per-contour-point relation. 
+ width=torch.zeros_like(query_pos[..., 0]), + length=torch.zeros_like(query_pos[..., 0]) + ) + contour_k = utils.cal_polygon_contour_torch( + x=key_pos[..., 0], + y=key_pos[..., 1], + theta=key_heading, + width=key_width if key_width is not None else torch.zeros_like(key_pos[..., 0]), + length=key_length if key_length is not None else torch.zeros_like(key_pos[..., 0]) + ) + contour_diff_in_q = pairwise_relative_diff(contour_q_center, contour_k) + contour_diff_in_q = rotate( + contour_diff_in_q[..., 0], + contour_diff_in_q[..., 1], + angle=-i_local_x_wrt_global.unsqueeze(-1).expand(-1, -1, -1, 4) + ) + contour_info = contour_diff_in_q.reshape(B, Q, K, 8) + + # THRESHOLD = 100 + dist = rel_pos.norm(dim=-1) + if causal_valid_mask is not None: + if causal_valid_mask.ndim == 2: + # the causal mask is not batched + causal_valid_mask = causal_valid_mask.unsqueeze(0).expand(B, -1, -1) + dist = dist.masked_fill_(~causal_valid_mask, float("+inf")) + valid_mask = valid_mask & causal_valid_mask + + else: + raise ValueError() # TODO + + # PZH: Apply the block mask + assert valid_mask.shape == batch_causal_valid_mask.shape + valid_mask = valid_mask & batch_causal_valid_mask + + # dist_mask = dist < THRESHOLD + # rel_mask = torch.logical_and(mask, dist_mask) + dist_argsort = dist.argsort(dim=-1) + if max_distance is not None: + within_dist = dist < max_distance # Shape (B, Q, K) + # Allow at least 8 neighbors... 
+ closest = dist_argsort[..., :8] + # fill in True for these 8 neighbors + within_dist[torch.arange(B).view(B, 1, 1), torch.arange(Q).view(1, Q, 1), closest] = True + + valid_mask = valid_mask & within_dist + + if query_step is not None: + step_diff = pairwise_relative_diff(query_step, key_step) + else: + step_diff = None + + indices = None + assert knn + if knn: + dist = dist.masked_fill_(~valid_mask, float("+inf")) + indices = dist_argsort[..., :knn] + + assert gather is False + + # Create a new mask with the same shape as rel_mask, initially set to False + original_valid_mask = torch.zeros_like(valid_mask, dtype=torch.bool) + + # Use advanced indexing to set True for indices selected by KNN + batch_indices = torch.arange(B).view(B, 1, 1) + query_indices = torch.arange(Q).view(1, Q, 1) + assert original_valid_mask.shape[0] == B, ( + original_valid_mask.shape, indices.shape, valid_mask.shape, dist.shape, B + ) + assert original_valid_mask.shape[1] == Q, ( + original_valid_mask.shape, indices.shape, valid_mask.shape, dist.shape, Q + ) + assert original_valid_mask.shape[2] == K, ( + original_valid_mask.shape, indices.shape, valid_mask.shape, dist.shape, K + ) + original_valid_mask[batch_indices, query_indices, indices] = True + + # Update rel_mask to only include indices selected by KNN + valid_mask = valid_mask & original_valid_mask + + distance = torch.norm(rotated_pos, p=2, dim=-1) + relative_direction = torch.atan2(rotated_pos[..., 1], rotated_pos[..., 0]) + ret = [relative_direction[..., None], distance[..., None], pairwise_heading[..., None]] + + if step_diff is not None: + ret.append(step_diff[..., None]) + + if non_agent_relation: + pass + else: + ret.append(contour_info) + + ret = torch.cat(ret, dim=-1) + ret[~valid_mask] = 0 + return ret, valid_mask, indices diff --git a/scenestreamer/models/scene_encoder.py b/scenestreamer/models/scene_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..74937f2777bfa8ae584adc34c3f7a83161c3280f 
def find_last_valid(array, mask):
    """
    Select, for every (batch, entity) pair, the feature vector at the last valid time step.

    Parameters:
    - array: tensor of shape (B, T, N, D)
    - mask: boolean tensor of shape (B, T, N), True where the entry of `array` is valid

    Returns:
    - tensor of shape (B, 1, N, D). Entities that are never valid along T are filled with 0.
    """
    assert mask.ndim + 1 == array.ndim
    assert mask.shape == array.shape[:-1]
    assert array.ndim == 4
    batch, steps, entities, feat = array.shape

    # Weight every valid slot by its time index; the arg-max over time is then the last valid step.
    time_ids = torch.arange(steps, device=mask.device).reshape(1, steps, 1)
    weighted_steps = mask * time_ids
    last_step = weighted_steps.argmax(1, keepdims=True)  # [B, 1, N]

    gather_index = last_step.unsqueeze(-1).expand(batch, 1, entities, feat)
    picked = torch.gather(array, index=gather_index, dim=1)  # [B, 1, N, D]

    # Entities without any valid step would otherwise report the value at step 0.
    never_valid = ~mask.any(1, keepdims=True)
    picked[never_valid] = 0
    return picked
+ For the ouput [b, i, j, :], it means the relative differences of [b, j] - [b, i], + which is the pos of j in i's coordinate system. + + Parameters: + - positions: A PyTorch tensor of shape (B, N, 2) + + Returns: + - A PyTorch tensor of shape (B, N, N, 2) containing pairwise relative positions. + """ + + # Expand dimensions to get tensors of shapes (B, N, 1, ...) and (B, 1, N, ...) + positions_expanded_a = positions.unsqueeze(2) # Shape: (B, N, 1, ...) + positions_expanded_b = positions.unsqueeze(1) # Shape: (B, 1, N, ...) + + # Compute the pairwise relative positions by subtraction + relative_positions = positions_expanded_b - positions_expanded_a # Shape: (B, N, N, ...) + + return relative_positions + + +def compute_relation(pos, heading, mask, hidden_dim, knn=128): + """ + Compute the relation encoding for the transformer encoder. + """ + assert heading.ndim == 2 + assert pos.ndim == 3 + pairwise_heading = pairwise_relative_diff(heading) + heading_fill_0_mask = pairwise_mask(heading == constants.HEADING_PLACEHOLDER) + pairwise_heading[heading_fill_0_mask] = 0 + + rel_pos = pairwise_relative_diff(pos[..., :2]) + + B, N = heading.shape + # i's local coordinate's y-axis (the heading) in the global coordinate + i_local_y_wrt_global = heading.reshape(B, N, 1).expand(B, N, N) + i_local_x_wrt_global = i_local_y_wrt_global - np.pi / 2 + rotated_pos = rotate(rel_pos[..., 0], rel_pos[..., 1], angle=-i_local_x_wrt_global) + + mask = pairwise_mask(mask) + + THRESHOLD = 100 + dist = rel_pos.norm(dim=-1) + dist_mask = dist < THRESHOLD + rel_mask = torch.logical_and(mask, dist_mask) + + indices = None + if knn: + dist = dist.masked_fill(~mask, float("+inf")) + indices = dist.argsort(dim=-1)[..., :knn] + + rotated_pos = torch.gather(rotated_pos, dim=-2, index=indices.unsqueeze(-1).expand(-1, -1, -1, 2)) + pairwise_heading = torch.gather(pairwise_heading, dim=-1, index=indices) + rel_mask = torch.gather(rel_mask, dim=-1, index=indices) + + pos_pe = 
class SceneEncoder(nn.Module):
    """
    Encode map polylines, agent histories and traffic lights into one sequence of scenario tokens.

    The tokens are concatenated in the order [map, agent, traffic light] and processed by a stack
    of self-attention layers. Depending on ``config.MODEL.RELATIVE_PE`` the layers receive either
    pairwise relation encodings (``compute_relation``) or an absolute sine position embedding.
    """
    def __init__(self, config):
        super().__init__()

        # TODO: Pass this from config or datasource
        SCENE_INPUT_TIME_STEPS = 11
        self.history_steps = SCENE_INPUT_TIME_STEPS
        self.config = config
        self.d_model = self.config.MODEL.D_MODEL
        self.num_layers = self.config.MODEL.NUM_ATTN_LAYERS

        self.map_polyline_encoder = polyline_encoder.PointNetPolylineEncoder(
            in_channels=constants.MAP_FEATURE_STATE_DIM,
            hidden_dim=64,
            num_layers=2,
            num_pre_layers=1,
            out_channels=self.d_model
        )
        # Agent / traffic light histories are collapsed over time before entering the MLPs,
        # hence the input width of `state_dim * time_steps`.
        self.agent_mlps = common_layers.build_mlps(
            c_in=constants.AGENT_STATE_DIM * SCENE_INPUT_TIME_STEPS,
            mlp_channels=[self.d_model] * 3,
            ret_before_act=True,
        )
        self.light_mlps = common_layers.build_mlps(
            c_in=constants.TRAFFIC_LIGHT_STATE_DIM * SCENE_INPUT_TIME_STEPS,
            mlp_channels=[self.d_model] * 3,
            ret_before_act=True,
        )

        dropout = self.config.MODEL.DROPOUT_OF_ATTN
        self.num_heads = self.config.MODEL.NUM_ATTN_HEAD
        self_attn_layers = []
        for _ in range(self.num_layers):
            self_attn_layers.append(
                TransformerEncoderLayer(
                    d_model=self.d_model,
                    nhead=self.num_heads,
                    dim_feedforward=self.d_model * 4,
                    dropout=dropout,
                    batch_first=True,
                    pre_projection=self.config.MODEL.get('PRE_PROJECTION', False),
                    relative_pe=self.config.MODEL.get('RELATIVE_PE', False),
                )
            )

        self.self_attn_layers = nn.ModuleList(self_attn_layers)
        self.agent_pe = nn.Embedding(self.config.PREPROCESSING.MAX_AGENTS, self.d_model)

        self.out = common_layers.build_mlps(
            c_in=self.d_model,
            mlp_channels=[self.d_model],
            ret_before_act=True,
        )

        self.relative_pe = self.config.MODEL.get('RELATIVE_PE', False)

        self.add_pe_for_static_features = self.config.MODEL.get('ADD_PE_FOR_STATIC_FEATURE', False)
        if self.add_pe_for_static_features:
            self.type_pe = common_layers.Tokenizer(num_actions=constants.NUM_TYPES, d_model=self.d_model)

    def forward(self, input_dict):
        """
        Encode the scenario stored in ``input_dict`` and write the results back into it.

        Reads the ``encoder/*`` agent, map and traffic light entries plus ``in_evaluation``.
        Writes ``encoder/scenario_token``, ``encoder/scenario_valid_mask``,
        ``encoder/modeled_agent_pe``, the randomized agent ids, and (depending on the config)
        ``encoder/scenario_position``, ``encoder/scenario_heading`` and
        ``encoder/modeled_agent_type_pe``.

        Returns:
            The same ``input_dict`` with the entries above added.

        Raises:
            ValueError: if ``config.MODEL.RANDOMIZE_AGENT_ID`` is off.
            KeyError: if language conditioning is on but ``decoder/prompt_embedding`` is missing.
        """

        # ===== Get shape =====
        # The unpacking also checks that the three features are 4-D.
        B, T, N, _D_agent = input_dict["encoder/agent_feature"].shape
        _, M, _num_vector, _D_vector = input_dict["encoder/map_feature"].shape
        _, _, L, _D_light = input_dict["encoder/traffic_light_feature"].shape
        in_evaluation = input_dict["in_evaluation"][0].item()

        # ===== Embed agent feature =====
        agent_feature = input_dict["encoder/agent_feature"]
        agent_valid_mask = input_dict["encoder/agent_valid_mask"]
        agent_position = input_dict["encoder/agent_position"]
        agent_heading = input_dict["encoder/agent_heading"]
        agent_id = input_dict["encoder/agent_id"]
        assert agent_feature.shape[:3] == agent_position.shape[:3] == agent_valid_mask.shape[:3]
        # Zero-out invalid history steps before collapsing the time dimension.
        agent_feature = (agent_feature[:, :self.history_steps] * agent_valid_mask[:, :self.history_steps, ..., None])
        agent_feature = collapse_time(agent_feature)
        agent_token = self.agent_mlps(agent_feature)  # (B, N, D)

        if in_evaluation:
            # Exempt filtering for maximum number of agents, so agent_id might be out of bound.
            agent_id = mode_agent_id(agent_id, self.config.PREPROCESSING.MAX_AGENTS)
            modeled_agent_id = mode_agent_id(
                input_dict["encoder/modeled_agent_id"], self.config.PREPROCESSING.MAX_AGENTS
            )
        else:
            modeled_agent_id = input_dict["encoder/modeled_agent_id"]

        if self.config.MODEL.RANDOMIZE_AGENT_ID:
            # Uniform weights: every id slot is equally likely, sampled without replacement.
            weights = torch.ones(self.config.PREPROCESSING.MAX_AGENTS).expand(B, -1)
            if N > self.config.PREPROCESSING.MAX_AGENTS:
                new_encoder_agent_id = torch.full_like(agent_id, -1)
                num_samples = self.config.PREPROCESSING.MAX_AGENTS
                new_encoder_agent_id[:, :num_samples] = torch.multinomial(
                    weights, num_samples=num_samples, replacement=False
                ).to(agent_id)
                assert (agent_id[:, num_samples:] == self.config.PREPROCESSING.MAX_AGENTS - 1).all()
            else:
                num_samples = N
                new_encoder_agent_id = torch.multinomial(
                    weights, num_samples=num_samples, replacement=False
                ).to(agent_id)
            new_encoder_agent_id[agent_id == -1] = N
            input_dict["encoder/randomized_agent_id"] = new_encoder_agent_id
            agent_id = new_encoder_agent_id

            modeled_agent_mask = torch.logical_or(modeled_agent_id == -1, modeled_agent_id >= N)
            # NOTE(review): outside evaluation this writes in place into
            # input_dict["encoder/modeled_agent_id"]; kept as-is in case later stages rely on it.
            modeled_agent_id[modeled_agent_mask] = N - 1  # Quick workaround
            new_modeled_agent_id = torch.gather(new_encoder_agent_id, dim=1, index=modeled_agent_id)
            input_dict["encoder/randomized_modeled_agent_id"] = new_modeled_agent_id
            modeled_agent_id = new_modeled_agent_id

        else:
            raise ValueError("Please turn on MODEL.RANDOMIZE_AGENT_ID=True")

        agent_id = mode_agent_id(agent_id, self.config.PREPROCESSING.MAX_AGENTS, fill_negative_1=False)
        modeled_agent_id = mode_agent_id(modeled_agent_id, self.config.PREPROCESSING.MAX_AGENTS, fill_negative_1=False)

        agent_pe = self.agent_pe(agent_id)  # (B, N, D)
        agent_token += agent_pe
        assert agent_token.shape == (B, N, self.d_model)

        if self.add_pe_for_static_features:
            type_pe = self.type_pe(input_dict["encoder/agent_type"])
            agent_token += type_pe

        # ===== Embed map feature =====
        map_feature = input_dict["encoder/map_feature"]
        map_valid_mask = input_dict["encoder/map_feature_valid_mask"]
        map_position = input_dict["encoder/map_position"]
        map_heading = input_dict["encoder/map_heading"]
        map_token_valid_mask = input_dict["encoder/map_valid_mask"]
        map_token = self.map_polyline_encoder(map_feature, map_valid_mask)
        assert map_token.shape == (B, M, self.d_model)

        # ===== Embed traffic light =====
        traffic_light_feature = input_dict["encoder/traffic_light_feature"]
        traffic_light_position = input_dict["encoder/traffic_light_position"]
        traffic_light_heading = input_dict["encoder/traffic_light_heading"]
        traffic_light_valid_mask = input_dict["encoder/traffic_light_valid_mask"]
        if L != 0:
            traffic_light_feature = (
                traffic_light_feature[:, :self.history_steps] *
                traffic_light_valid_mask[:, :self.history_steps, ..., None]
            )
            traffic_light_feature = collapse_time(traffic_light_feature)
            traffic_light_token = self.light_mlps(traffic_light_feature)
        else:
            # No traffic light in the scene: use an empty token tensor.
            traffic_light_token = traffic_light_feature.new_zeros([B, L, self.d_model])
        assert traffic_light_token.shape == (B, L, self.d_model)

        # ===== Call transformer layers =====
        x = torch.concatenate([map_token, agent_token, traffic_light_token], dim=1)

        # ======== Include the language embedding into the scenario encoding features
        if self.config.LANGUAGE_CONDITION:
            if 'decoder/prompt_embedding' not in input_dict:
                raise KeyError(
                    "LANGUAGE_CONDITION is enabled but 'decoder/prompt_embedding' is missing from input_dict"
                )
            prompt_embedding = input_dict['decoder/prompt_embedding']
            # Repeat every prompt scalar along a new feature axis so that the prompt tokens have
            # the same width as the scenario tokens (required by the concatenation below).
            expanded_embedding = prompt_embedding.unsqueeze(-1).repeat(1, 1, self.d_model)
            # NOTE(review): x_pos / x_mask below are not extended for the prompt tokens, so the
            # attention layers presumably see mismatching lengths in this mode — confirm.
            x = torch.cat([x, expanded_embedding], dim=1)
        # ========

        x_pos = torch.concatenate(
            [
                map_position,
                find_last_valid(agent_position[:, :self.history_steps], agent_valid_mask[:, :self.history_steps])[:, 0],
                traffic_light_position
            ],
            dim=1
        )

        x_mask = torch.concatenate(
            [
                map_token_valid_mask, agent_valid_mask[:, :self.history_steps].any(dim=1),
                traffic_light_valid_mask[:, :self.history_steps].any(dim=1)
            ],
            dim=1
        )
        assert torch.all(x_mask.sum(dim=-1) > 0)

        if self.relative_pe:
            x_heading = torch.concatenate(
                [
                    map_heading,
                    find_last_valid(
                        agent_heading[:, :self.history_steps, ..., None], agent_valid_mask[:, :self.history_steps]
                    )[:, 0, :, 0], traffic_light_heading
                ],
                dim=1
            )
            relation, rel_mask, indices = compute_relation(
                pos=x_pos,
                heading=x_heading,
                mask=x_mask,
                hidden_dim=self.d_model,
                knn=self.config.MODEL.get('KNN', 128)
            )
            pos_embedding = None
        else:
            # Without relative PE there is no relation input at all; these must still be defined
            # because they are passed to every attention layer below.
            relation = None
            rel_mask = None
            indices = None
            pos_embedding = position_encoding_utils.gen_sineembed_for_position(x_pos[..., 0:2], hidden_dim=self.d_model)

        for layer in self.self_attn_layers:
            x = layer(
                tgt=x,
                pos=pos_embedding,
                tgt_key_padding_mask=~x_mask,
                relation=relation,
                relation_mask=rel_mask,
                relation_indices=indices,
            )

        x = self.out(x.reshape(-1, x.shape[-1])).reshape(list(x.shape[:-1]) + [self.d_model])

        if pos_embedding is not None:
            x = x + pos_embedding

        input_dict["encoder/scenario_token"] = x
        if self.relative_pe:
            input_dict["encoder/scenario_position"] = x_pos
            input_dict["encoder/scenario_heading"] = x_heading
        input_dict["encoder/scenario_valid_mask"] = x_mask

        input_dict["encoder/modeled_agent_pe"] = self.agent_pe(modeled_agent_id)
        if self.add_pe_for_static_features:
            input_dict["encoder/modeled_agent_type_pe"] = self.type_pe(input_dict["encoder/modeled_agent_type"])
        return input_dict
def get_edge_info_for_scenestreamer(*, q_k_valid_mask, q_k_relation, relation_model, relation_model_1d=None, require_relation_pairwise=None):
    """
    Turn a dense query-key attention mask into a sparse edge list with embedded edge features.

    Args:
        q_k_valid_mask: boolean tensor of shape (B, Lq, Lk); True marks an edge key -> query.
        q_k_relation: dense relation tensor indexed as [batch, query, key].
        relation_model: module embedding the relation of an edge (required).
        relation_model_1d: optional module applied to the last relation channel of the edges that
            do not require a full relation.
        require_relation_pairwise: optional boolean tensor indexed as [batch, query, key]; when
            given, `relation_model` is only applied to the edges where it is True.

    Returns:
        dict with "edge_index" (row 0 holds flattened key indices, row 1 flattened query indices),
        "edge_features", and "edge_features_v" (always None).
    """
    B, Lq, Lk = q_k_valid_mask.shape

    # Swap to (B, Lk, Lq) so that sources are keys and destinations are queries.
    key_major_mask = q_k_valid_mask.swapaxes(1, 2).contiguous()
    edge_index, _ = dense_to_sparse(key_major_mask)
    assert edge_index.numel() > 0, (edge_index.shape, q_k_valid_mask.sum())
    assert edge_index[0].max() < B * Lk, f"{edge_index[0].max()} >= {B * Lk}"
    assert edge_index[1].max() < B * Lq, f"{edge_index[1].max()} >= {B * Lq}"

    # Recover (batch, query, key) coordinates from the flattened node indices.
    flat_key, flat_query = edge_index[0], edge_index[1]
    batch_of_query = flat_query // Lq
    query_pos = flat_query % Lq
    batch_of_key = flat_key // Lk
    key_pos = flat_key % Lk
    assert torch.all(batch_of_query == batch_of_key)
    per_edge_relation = q_k_relation[batch_of_query, query_pos, key_pos]

    assert relation_model is not None
    if require_relation_pairwise is None:
        # Every edge gets the full relation embedding.
        embedded = relation_model(per_edge_relation)
    else:
        needs_full = require_relation_pairwise[batch_of_query, query_pos, key_pos]

        full_embedding = relation_model(per_edge_relation[needs_full])
        embedded = utils.unwrap(full_embedding, needs_full)

        if relation_model_1d is not None:
            assert per_edge_relation.shape[-1] == 4
            # The remaining edges are embedded from the last relation channel only.
            reduced_input = per_edge_relation[~needs_full][:, -1:]
            reduced_embedding = relation_model_1d(reduced_input)
            embedded = utils.unwrap(reduced_embedding, ~needs_full, existing=embedded)

    return {
        "edge_index": edge_index,
        "edge_features": embedded,
        "edge_features_v": None,
    }
update_relation=False, + add_relation_to_v=None + ): + super(MultiheadAttentionLayer, self).__init__(aggr='add', node_dim=0) # Aggregation method 'add' + self.n_heads = n_heads + self.head_dim = d_model // n_heads + assert dropout == 0.0, "dropout is not supported" + self.dropout = nn.Dropout(dropout) + self.relation_head_dim = self.head_dim // simple_relation_factor + self.to_q_relation = nn.Linear(d_model, d_model) + self.to_k_r = nn.Linear(d_model // simple_relation_factor, d_model) + self.to_v_r = nn.Linear(d_model // simple_relation_factor, d_model) + self.to_k = nn.Linear(d_model, d_model) + self.to_q = nn.Linear(d_model, d_model) + self.to_v = nn.Linear(d_model, d_model) + self.out = nn.Linear(d_model, d_model) + + def forward( + self, + q, + k, + edge_index, + edge_features, + edge_features_v=None, + use_cache=False, + cache=None, #Relation=None + ): + B, Lq, D = q.shape + _, Lk, _ = k.shape + + # Compute linear projections + x_dst = q + x_src = k + Q = self.to_q(x_dst).reshape(-1, self.n_heads * self.head_dim) + K = self.to_k(x_src).reshape(-1, self.n_heads * self.head_dim) + V = self.to_v(x_src).reshape(-1, self.n_heads * self.head_dim) + + if cache is not None: + past_key = cache[0] + past_value = cache[1] + key_B, key_T = cache[2] + + K = K.reshape(key_B, -1, self.n_heads * self.head_dim) + past_key = past_key.reshape(key_B, key_T, self.n_heads * self.head_dim) + K = torch.cat((past_key, K), dim=1) + K = K.reshape(-1, self.n_heads * self.head_dim) + + V = V.reshape(key_B, -1, self.n_heads * self.head_dim) + past_value = past_value.reshape(key_B, key_T, self.n_heads * self.head_dim) + V = torch.cat((past_value, V), dim=1) + V = V.reshape(-1, self.n_heads * self.head_dim) + + assert edge_index[0].max() < K.shape[0], f"{edge_index[0].max()} >= {K.shape[0]}" + assert edge_index[1].max() < Q.shape[0], f"{edge_index[1].max()} >= {Q.shape[0]}" + + if use_cache: + new_cache = [K, V] + else: + new_cache = None + + Q_relation = 
self.to_q_relation(x_dst).reshape(-1, self.n_heads * self.head_dim) + Q = torch.cat([Q, Q_relation], dim=-1) + + assert edge_features_v is None + edge_features_v = edge_features + edge_features = self.to_k_r(edge_features) + edge_features_v = self.to_v_r(edge_features_v) + + # Propagate messages using edge_index + out, new_edge_features = self.propagate( + edge_index=edge_index, + # x_dst=x_dst.reshape(-1, self.n_heads * self.head_dim), + q=Q, + k=K, + v=V, + edge_features=edge_features, + edge_features_v=edge_features_v, + ) + + # Project the output back to original dimension + out = out.reshape(B, Lq, D) + if new_edge_features is not None: + new_edge_features = new_edge_features.reshape(-1, D) + out = self.out(out) + return out, new_cache, new_edge_features #, edge_features, edge_features_v + + def message( + self, q_i, k_j, v_j, edge_features, edge_features_v, index, ptr, edge_index, edge_index_i, edge_index_j, + relation + ): + k_j = k_j.reshape(-1, self.n_heads, self.head_dim) + v_j = v_j.reshape(-1, self.n_heads, self.head_dim) + q_i, q_relation = q_i[:, :self.n_heads * self.head_dim], q_i[:, self.n_heads * self.head_dim:] + # Compute attention scores + q_i = q_i.reshape(-1, self.n_heads, self.head_dim) + q_relation = q_relation.reshape(-1, self.n_heads, self.head_dim) + edge_features = edge_features.reshape(-1, self.n_heads, self.head_dim) + attn_scores = (q_i * k_j).sum(dim=-1) / self.head_dim**0.5 # Scaled dot-product + attn_scores_relation = (q_relation * edge_features).sum(dim=-1) / self.head_dim**0.5 + attn_scores = attn_scores + attn_scores_relation + attn_weights = softmax(attn_scores, index=index, ptr=ptr) + attn_weights = self.dropout(attn_weights) # Apply dropout to attention weights + if edge_features_v is not None: + edge_features_v = edge_features_v.reshape(-1, self.n_heads, self.head_dim) + v_j = v_j + edge_features_v + attn_weights = self.dropout(attn_weights) # Apply dropout to attention weights + return v_j * attn_weights.unsqueeze(-1), None 
+ + def aggregate( + self, + inputs: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + ) -> Tensor: + raw_inputs, new_edge_features = inputs + inputs = super().aggregate(raw_inputs, index, ptr, dim_size) + if new_edge_features is not None: + new_edge_features = new_edge_features + raw_inputs + return inputs, new_edge_features + + +class SceneEncoderGPT(nn.Module): + def __init__(self, config, relation_embed): + super().__init__() + self.config = config + self.d_model = self.config.MODEL.D_MODEL + self.num_layers = self.config.MODEL.NUM_ATTN_LAYERS + self.num_heads = self.config.MODEL.NUM_ATTN_HEAD + dropout = self.config.MODEL.DROPOUT + self.map_polyline_encoder = polyline_encoder.PointNetPolylineEncoder( + in_channels=constants.MAP_FEATURE_STATE_DIM, + hidden_dim=64, + num_layers=2, + num_pre_layers=1, + out_channels=self.d_model, + ) + simple_relation_factor = self.config.SIMPLE_RELATION_FACTOR + simple_relation = self.config.SIMPLE_RELATION + self.relation_embed = relation_embed + self.encoder = SelfAttTransformerEncoder( + decoder_layer=SelfAttTransformerEncoderLayer( + d_model=self.d_model, + nhead=self.num_heads, + simple_relation=simple_relation, + simple_relation_factor=simple_relation_factor, + dropout=dropout, + update_relation=self.config.UPDATE_RELATION, + add_relation_to_v=self.config.MODEL.ADD_RELATION_TO_V, + remove_rel_norm=self.config.REMOVE_REL_NORM, + ), + num_layers=self.num_layers, + ) + self.out = common_layers.build_mlps( + c_in=self.d_model, mlp_channels=[self.d_model], ret_before_act=True, + ) + self.out_prenorm = nn.LayerNorm(self.d_model) + + def forward(self, input_dict): + # ===== Get shape ===== + B, M, num_vector, D_vector = input_dict["encoder/map_feature"].shape + # ===== Embed map feature ===== + map_feature = input_dict["encoder/map_feature"] + map_valid_mask = input_dict["encoder/map_feature_valid_mask"] + map_position = input_dict["encoder/map_position"] + map_heading = 
class TransformerBlock(nn.Module):
    """
    A single pre-norm transformer block that uses adaptive layer normalization.

    The block applies AdaLN -> multi-head self-attention -> residual, followed
    by AdaLN -> feed-forward network -> residual. Every sub-module is built
    with ``batch_first=False``, i.e. tensors are sequence-first.
    """

    def __init__(self, hidden_size, num_heads, conditioning_dim, dropout=0.1):
        """
        hidden_size: Width of the token features.
        num_heads: Number of attention heads.
        conditioning_dim: Width of the conditioning vector fed to both AdaLN layers.
        dropout: Dropout rate used in attention, in the feed-forward network and
            on both residual branches.
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout, batch_first=False)
        self.adaln1 = common_layers.AdaLayerNorm(hidden_size, conditioning_dim, batch_first=False)
        self.adaln2 = common_layers.AdaLayerNorm(hidden_size, conditioning_dim, batch_first=False)

        # Simple feed-forward network.
        self.ff = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4), nn.ReLU(), nn.Linear(hidden_size * 4, hidden_size),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, z, attn_mask=None, key_padding_mask=None):
        """
        x: Tensor of shape [seq_len, B, hidden_size]
        z: Conditioning tensor of shape [B, conditioning_dim]
        attn_mask: Optional attention mask for self-attention. When given it is
            forwarded together with the ``is_causal`` hint, exactly as before.
        key_padding_mask: Optional mask for padded positions. When both masks
            are given they must share the same dtype.

        Returns a tensor with the same shape as ``x``.
        """
        # Self-attention with pre-normalization using AdaLN.
        x_norm = self.adaln1(x, z)
        # The dtype check only makes sense when both masks exist; without the
        # guard a missing attn_mask crashed on `None.dtype`.
        if attn_mask is not None and key_padding_mask is not None:
            assert attn_mask.dtype == key_padding_mask.dtype
        # PyTorch rejects `is_causal=True` when no attn_mask is supplied, so the
        # hint is only raised when a mask is actually passed.
        attn_output, _ = self.self_attn(
            x_norm,
            x_norm,
            x_norm,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            is_causal=attn_mask is not None,
        )
        x = x + self.dropout(attn_output)

        # Feed-forward network with pre-normalization.
        x_norm = self.adaln2(x, z)
        ff_output = self.ff(x_norm)
        x = x + self.dropout(ff_output)
        return x
+ self.output_layer = nn.Linear(hidden_size, vocab_size) + self.type_output_layer = nn.Linear(hidden_size, type_size) + + # ===== map id prediction head ===== + d_model = hidden_size + self.map_id_head = common_layers.build_mlps( + c_in=d_model, mlp_channels=[d_model, map_id_size], ret_before_act=True, + ) + # self.dest_id_head = common_layers.build_mlps( + # c_in=d_model, mlp_channels=[d_model, map_id_size], ret_before_act=True, + # ) + + def forward(self, offset, trafficgen_token, key_padding_mask=None): + B, T, G, D = trafficgen_token.shape + assert offset.shape[0] == B + assert offset.shape[1] == T + assert offset.shape[3] == 9 + N = offset.shape[2] + + # Remove last agent dest id and sequence_eos: + trafficgen_token = trafficgen_token[:, :, 1:-1] + trafficgen_token = trafficgen_token.reshape(B, T, N, NUM_TG_MULTI, D) + + # The input tokens sequence: action_sos, map_id, agent_state, dest_map_id, action_eos + + # ===== sequence_sos/last_agent_dest_id -> agent_type ===== + agent_type_token = trafficgen_token[:, :, :, 0] + agent_type_logits = self.type_output_layer(agent_type_token) + assert agent_type_logits.shape == (B, T, N, 3) + + # ===== agent type -> map_id ===== + # To process the first token: + map_id_token = trafficgen_token[:, :, :, 1] + map_id_logit = self.map_id_head(map_id_token) + assert map_id_logit.shape[:3] == (B, T, N) + + # ===== map_id -> agent_state ===== + # To process the second token whose output is agent_state: + z = trafficgen_token[:, :, :, 2].flatten(0, 2) # B*T*N, D + offset_flattened = offset.flatten(0, 2) # B*T*N, 8 + + assert offset_flattened.dim() == 2, "Input tensor must have shape [B, seq_len]" + assert z.dim() == 2, "Conditioning tensor must have shape [B, conditioning_dim]" + + # Compute token embeddings. 
+ emb = self.offset_token_embedding(offset_flattened) + # emb = torch.cat([torch.zeros_like(type_emb), type_emb, emb], dim=1) + seq_len = emb.size(1) + assert seq_len == 9 + emb = emb + self.pos_embedding[:, :seq_len, :] + h = emb.transpose(0, 1) + attn_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(h.device) + attn_mask = attn_mask < 0 + for layer in self.layers: + h = layer(h, z, attn_mask=attn_mask, key_padding_mask=key_padding_mask) + # Final normalization. + h = self.ln_final(h) + # Transpose back to [B, seq_len, hidden_size] + h = h.transpose(0, 1) + agent_state_logits = self.output_layer(h) + agent_state_logits = agent_state_logits.reshape(B, T, N, 9, -1) + + # ===== agent_state -> dest_id ===== + # To process the third token whose output is dest map id: + # dest_id_token = trafficgen_token[:, :, :, 3] + # dest_id_logit = self.dest_id_head(dest_id_token) + # assert dest_id_logit.shape[:3] == (B, T, N) + + return map_id_logit, agent_type_logits, agent_state_logits, None + + @torch.no_grad() + def generate(self, z, greedy=False): + """ + Autoregressively generate a sequence conditioned on latent vector z. + + z: FloatTensor of shape [B, conditioning_dim] + max_length: Maximum length to generate (including and tokens). + greedy: If True, use argmax sampling; otherwise, sample from the distribution. + + Returns: + generated: LongTensor of shape [B, generated_seq_len] (including the starting ). + """ + assert z.ndim == 2 + + B = z.size(0) + max_length = self.max_seq_len + device = z.device + # Start each sequence with the token. 
+ input_seq = torch.full((B, 1), -1, dtype=torch.long, device=device) + + # key_padding_mask = torch.zeros(B, 1, dtype=torch.float32, device=device) + # key_padding_valid_mask_bool = torch.ones(B, 1, dtype=torch.bool, device=device) + # key_padding_mask = (~key_padding_valid_mask_bool).clone() + + for step in range(max_length - 1): # already have one token + + assert input_seq.dim() == 2, "Input tensor must have shape [B, seq_len]" + + # Compute token embeddings. + emb = self.offset_token_embedding(input_seq) + # emb = torch.cat([torch.zeros_like(type_emb), type_emb, emb], dim=1) + seq_len = emb.size(1) + # assert seq_len == 9 + emb = emb + self.pos_embedding[:, :seq_len, :] + h = emb.transpose(0, 1) + attn_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(h.device) + attn_mask = attn_mask < 0 + for layer in self.layers: + h = layer(h, z, attn_mask=attn_mask, key_padding_mask=None) + # Final normalization. + h = self.ln_final(h) + # Transpose back to [B, seq_len, hidden_size] + h = h.transpose(0, 1) + agent_state_logits = self.output_layer(h) + last_logits = agent_state_logits[:, -1:] + + from scenestreamer.infer.scenestreamer_motion import sample_action + next_token, _ = sample_action(last_logits, sampling_method="softmax") + + # Append the predicted token. 
+ input_seq = torch.cat([input_seq, next_token], dim=1) + return input_seq + + +class SceneStreamer(nn.Module): + def __init__(self, config): + super().__init__() + + # ===== A bunch of hyper-parameters and assertions ===== + self.config = config + self.config = config + self.d_model = d_model = self.config.MODEL.D_MODEL + num_decoder_layers = self.config.MODEL.NUM_DECODER_LAYERS + self.num_actions = get_action_dim(self.config) + dropout = self.config.MODEL.DROPOUT + self.num_heads = self.config.MODEL.NUM_ATTN_HEAD + self.add_pe_for_token = self.config.MODEL.get('ADD_PE_FOR_TOKEN', False) + assert self.config.MODEL.NAME == "scenestreamer" + assert self.add_pe_for_token + self.use_destination = self.config.USE_DESTINATION + simple_relation = self.config.SIMPLE_RELATION + simple_relation_factor = self.config.SIMPLE_RELATION_FACTOR + is_v7 = self.config.MODEL.IS_V7 + self.is_v7 = is_v7 + assert is_v7 is True + assert simple_relation is True + assert self.config.PREPROCESSING.REMOVE_TRAFFIC_LIGHT_STATE is False + self.start_action_id = config.PREPROCESSING.MAX_MAP_FEATURES + self.end_action_id = config.PREPROCESSING.MAX_MAP_FEATURES + 1 + self.no_tg = config.get("SCENESTREAMER_NO_TG", False) + + # ===== Build tokenizer ===== + tokenizer = get_tokenizer(self.config) + motion_features = tokenizer.get_motion_feature() + if tokenizer.use_type_specific_bins: + motion_features = torch.cat([motion_features, torch.zeros(1, 3, 4)], dim=0) + else: + motion_features = torch.cat([motion_features, torch.zeros(1, 4)], dim=0) + self.motion_tokenizer = tokenizer + + # ===== Build the relative continuous embedding ===== + relation_d_model = d_model // simple_relation_factor + self.relation_embed_4d = fourier_embedding.FourierEmbedding( + input_dim=4, hidden_dim=relation_d_model, num_freq_bands=64, + ) + self.relation_embed_3d = fourier_embedding.FourierEmbedding( + input_dim=3, hidden_dim=relation_d_model, num_freq_bands=64, + ) + self.relation_embed_1d = 
fourier_embedding.FourierEmbedding( + input_dim=1, hidden_dim=relation_d_model, num_freq_bands=64, + ) + + # ===== Build map features embedding ===== + self.map_encoder = SceneEncoderGPT(config=self.config, relation_embed=self.relation_embed_3d) + + # ===== Build the egocentric discrete embedding ===== + # Adding 2 tokens for trafficgen + num_total_map_actions = self.config.PREPROCESSING.MAX_MAP_FEATURES + 7 + self.map_id_embed = common_layers.Tokenizer( + num_actions=num_total_map_actions, d_model=d_model, add_one_more_action=True + ) + + trafficgen_sequence_sos_id = config.PREPROCESSING.MAX_MAP_FEATURES + trafficgen_sequence_eos_id = config.PREPROCESSING.MAX_MAP_FEATURES + 1 + trafficgen_sequence_pad_id = config.PREPROCESSING.MAX_MAP_FEATURES + 2 + veh_id = config.PREPROCESSING.MAX_MAP_FEATURES + 3 + ped_id = config.PREPROCESSING.MAX_MAP_FEATURES + 4 + cyc_id = config.PREPROCESSING.MAX_MAP_FEATURES + 5 + trafficgen_agent_sos_id = config.PREPROCESSING.MAX_MAP_FEATURES + 6 + self.trafficgen_sequence_sos_id = trafficgen_sequence_sos_id + self.trafficgen_sequence_eos_id = trafficgen_sequence_eos_id + self.trafficgen_sequence_pad_id = trafficgen_sequence_pad_id + self.trafficgen_agent_sos_id = trafficgen_agent_sos_id + self.veh_id = veh_id + self.ped_id = ped_id + self.cyc_id = cyc_id + + N = 128 + G = get_num_tg(N) + + self.traffic_light_id_embed = common_layers.Tokenizer( + num_actions=self.config.PREPROCESSING.MAX_TRAFFIC_LIGHTS, d_model=self.d_model, add_one_more_action=True + ) + self.agent_id_embed = common_layers.Tokenizer( + num_actions=N, d_model=self.d_model, add_one_more_action=True + ) + self.action_embed = common_layers.Tokenizer( + num_actions=self.num_actions, d_model=d_model, add_one_more_action=True + ) + self.traffic_light_state_embed = common_layers.Tokenizer( + num_actions=4, d_model=self.d_model, add_one_more_action=True + ) + + # ===== Build the egocentric continuous embedding ===== + self.shape_embed = common_layers.build_mlps( + c_in=3, 
def forward(self, input_dict):
    """Run the model on one batch and compute the per-modality logits.

    The method prepares map, traffic-light, (optionally) trafficgen and motion
    tokens, calls ``self.decoder`` and then cuts the flattened output sequence
    ``model/all_token`` back into per-step groups. Within a step the tokens are
    sliced as ``L`` traffic-light tokens, then ``G`` trafficgen tokens (only on
    steps with ``t % TG_SKIP_STEP == 0`` and only when ``self.no_tg`` is false),
    then ``N`` motion tokens.

    Logits are stored in the dict returned by the decoder under
    ``model/traffic_light_logit``, ``model/motion_logit`` and, in trafficgen
    mode, ``model/trafficgen_{map_id,agent_type,agent_state,dest_id}_logit``.

    Returns ``input_dict``.
    """
    # ===== Build up some variables =====
    _, T, N = input_dict["decoder/input_action"].shape
    _, _, L = input_dict["encoder/traffic_light_state"].shape
    _, _, G = input_dict["decoder/input_action_for_trafficgen"].shape

    # ===== Prepare the input tokens and their relations =====
    input_dict = self.prepare_map_tokens(input_dict)
    input_dict = self.prepare_traffic_light_tokens(input_dict)
    if self.no_tg:
        input_dict = self.prepare_motion_tokens(input_dict)
        input_dict = self.prepare_dynamic_relation_notg(input_dict)
    else:
        input_dict = self.prepare_trafficgen_tokens(input_dict)
        input_dict = self.prepare_motion_tokens(input_dict)
        input_dict = self.prepare_dynamic_relation(input_dict)

    # ===== Call the decoder =====
    # TODO: dest is not conditioning on anyone.
    output_dict = self.decoder(input_dict=input_dict)
    all_token = output_dict["model/all_token"]

    # ===== Locate every step inside the flattened token sequence =====
    # Each entry is (start offset of the step, whether it carries trafficgen tokens).
    step_layout = []
    pointer = 0
    for t in range(T):
        has_tg = (not self.no_tg) and t % TG_SKIP_STEP == 0
        step_layout.append((pointer, has_tg))
        pointer += L + N + (G if has_tg else 0)
    assert pointer == all_token.shape[1], (pointer, all_token.shape)

    # ===== Traffic light head =====
    traffic_light_token = torch.stack([all_token[:, start:start + L] for start, _ in step_layout], dim=1)
    traffic_light_token = self.traffic_light_prenorm(traffic_light_token)
    traffic_light_token = self.traffic_light_head(traffic_light_token)
    output_dict["model/traffic_light_logit"] = traffic_light_token

    # ===== Trafficgen head =====
    if not self.no_tg:
        trafficgen_output_token = torch.stack(
            [all_token[:, start + L:start + L + G] for start, has_tg in step_layout if has_tg], dim=1
        )
        trafficgen_output_token = self.trafficgen_prenorm(trafficgen_output_token)

        from scenestreamer.dataset.preprocessor import slice_trafficgen_data
        map_id_logit, agent_type_logits, agent_state_logits, dest_id_logit = self.trafficgen_head(
            offset=slice_trafficgen_data(input_dict["decoder/input_offset_for_trafficgen"], dim=1),
            trafficgen_token=trafficgen_output_token,
        )
        output_dict["model/trafficgen_map_id_logit"] = map_id_logit
        output_dict["model/trafficgen_agent_type_logit"] = agent_type_logits
        output_dict["model/trafficgen_agent_state_logit"] = agent_state_logits
        output_dict["model/trafficgen_dest_id_logit"] = dest_id_logit

    # ===== Motion head =====
    # Motion tokens follow the trafficgen tokens only on steps that have them.
    # A step without trafficgen tokens is just L + N long, so using the L + G
    # offset there (as the code used to) reads into the following step.
    motion_token = []
    for start, has_tg in step_layout:
        motion_start = start + L + (G if has_tg else 0)
        motion_token.append(all_token[:, motion_start:motion_start + N])
    motion_token = torch.stack(motion_token, dim=1)

    # TODO: dest is not conditioning on anyone.
    if self.motion_prenorm is not None:
        motion_token = self.motion_prenorm(motion_token)
    motion_token = self.motion_head(motion_token)
    output_dict["model/motion_logit"] = motion_token

    # NOTE(review): the logits are written to `output_dict` but `input_dict` is
    # returned; this only works if the decoder returns the dict it was given.
    # Confirm against SceneStreamerDecoder before changing the return value.
    return input_dict
def prepare_traffic_light_tokens(self, input_dict):
    """Build the traffic-light tokens and store them, with their geometry, in ``input_dict``.

    Each token is the sum of three embeddings: the light state, the id of the
    map feature the light belongs to, and the light's own index (``-1`` for
    invalid lights). Position, heading and map id come without a time axis and
    are broadcast over ``T``.

    Adds the ``model/traffic_light_token*`` entries and
    ``model/traffic_light_require_relation`` and returns ``input_dict``.
    """
    state = input_dict["encoder/traffic_light_state"]
    map_id = input_dict["encoder/traffic_light_map_id"]
    position = input_dict["encoder/traffic_light_position"]
    heading = input_dict["encoder/traffic_light_heading"]
    valid_mask = input_dict["encoder/traffic_light_valid_mask"]

    B, T, L = state.shape

    # Index of every light inside its step; invalid lights are marked with -1.
    light_index = torch.arange(L, device=state.device).expand(B, T, L).masked_fill(~valid_mask, -1)
    index_emb = self.traffic_light_id_embed(light_index)

    # The map id has no time axis, so its embedding is repeated for every step.
    map_id_emb = self.map_id_embed(map_id).unsqueeze(1).expand(B, T, L, self.d_model)

    state_emb = self.traffic_light_state_embed(state)
    tokens = state_emb + map_id_emb + index_emb

    input_dict["model/traffic_light_token"] = tokens
    input_dict["model/traffic_light_token_position"] = position.unsqueeze(1).expand(B, T, L, position.shape[-1])
    input_dict["model/traffic_light_token_heading"] = heading.unsqueeze(1).expand(B, T, L)
    input_dict["model/traffic_light_token_valid_mask"] = valid_mask
    # Every valid light takes part in relation computation; stored as a copy of the mask.
    input_dict["model/traffic_light_require_relation"] = valid_mask.clone()

    return input_dict
def prepare_trafficgen_tokens(self, input_dict):
    """Embed the trafficgen input sequence and record its geometry in ``input_dict``.

    Every trafficgen token is the sum of five embeddings: position inside the
    per-step sequence, agent type, agent id, input action and the continuous
    action feature. The sequence length is ``G = get_num_tg(N)`` where ``N`` is
    taken from ``decoder/modeled_agent_position``.

    Adds ``model/trafficgen_token``, ``model/trafficgen_{position,heading,
    valid_mask,width,length}`` and ``model/trafficgen_require_relation`` and
    returns ``input_dict``.
    """
    # ===== Agent Tokens =====
    B, T, N, _ = input_dict["decoder/modeled_agent_position"].shape

    # tg_start, agent_start, agent_map_id, agent_state, agent_dest, agent_end, ..., tg_end
    G = get_num_tg(N)
    token_shape = (B, T, G, self.d_model)

    tg_input_action = input_dict["decoder/input_action_for_trafficgen"]
    tg_feature = input_dict["decoder/input_action_feature_for_trafficgen"]
    tg_agent_type = input_dict["decoder/agent_type_for_trafficgen"]
    device = tg_input_action.device

    assert tg_input_action.shape == (B, T, G), (B, T, G, tg_input_action.shape)
    # The message now reports the bound that is actually being checked.
    assert tg_input_action.max() <= self.trafficgen_agent_sos_id, (
        tg_input_action.max(), self.trafficgen_agent_sos_id
    )

    # Position of each token inside its step. Indices the embedding table cannot
    # hold are mapped to -1, using the same `>= num_actions` rule as
    # prepare_trafficgen_single_token (the old `>` let index == num_actions through).
    tg_intra_step = torch.arange(G, device=device).reshape(1, 1, G).expand(B, T, G)
    tg_intra_step = mode_agent_id(
        tg_intra_step, max_agents=self.trafficgen_intra_step.num_actions, fill_negative_1=True
    )
    tg_intra_step_emb = self.trafficgen_intra_step(tg_intra_step)

    if tg_agent_type.max().item() not in [self.veh_id, self.cyc_id, self.ped_id]:
        print("WARNING: agent type is not veh, cyc, ped, it is: ", tg_agent_type.max().item(), input_dict["scenario_id"], self.veh_id, self.cyc_id, self.ped_id)
    type_emb = self.map_id_embed(tg_agent_type)
    assert type_emb.shape == token_shape

    modeled_agent_id = mode_agent_id(input_dict["decoder/agent_id_for_trafficgen"], 128, fill_negative_1=True)
    agent_id_emb = self.agent_id_embed(modeled_agent_id)
    assert agent_id_emb.shape == token_shape, (B, G, self.d_model, agent_id_emb.shape)

    tg_action_emb = self.map_id_embed(tg_input_action)
    assert tg_action_emb.shape == token_shape, (B, T, G, self.d_model, tg_action_emb.shape)

    tg_feat_emb = self.trafficgen_feat_embed(tg_feature)
    # This assert used to re-check tg_action_emb, so the feature embedding was never validated.
    assert tg_feat_emb.shape == token_shape, (B, T, G, self.d_model, tg_feat_emb.shape)

    tg_tokens = tg_intra_step_emb + type_emb + agent_id_emb + tg_action_emb + tg_feat_emb
    input_dict["model/trafficgen_token"] = tg_tokens

    # TODO: hardcoded 5, 6
    tg_length = tg_feature[..., 5]
    tg_width = tg_feature[..., 6]

    input_dict["model/trafficgen_position"] = input_dict["decoder/trafficgen_position"]
    input_dict["model/trafficgen_heading"] = input_dict["decoder/trafficgen_heading"]
    input_dict["model/trafficgen_valid_mask"] = input_dict["decoder/input_action_valid_mask_for_trafficgen"]
    input_dict["model/trafficgen_width"] = tg_width
    input_dict["model/trafficgen_length"] = tg_length

    # Per-agent slots: only those that carry a position take part in relations.
    require_relation = torch.ones(B, T, N, NUM_TG_MULTI, device=device, dtype=torch.bool)
    require_relation[:, :, :, 0] = 0  # sos
    require_relation[:, :, :, 1] = 0  # agent type
    require_relation[:, :, :, 2] = 1  # map_id (map pos)
    require_relation[:, :, :, 3] = 1  # agent_state (agent pos)
    # One extra slot is padded at each end of the sequence; neither needs a relation.
    sequence_pad = torch.zeros(B, T, 1, device=device, dtype=torch.bool)
    require_relation = torch.cat([sequence_pad, require_relation.flatten(2, 3), sequence_pad], dim=2)

    input_dict["model/trafficgen_require_relation"] = require_relation & input_dict["model/trafficgen_valid_mask"]
    return input_dict
def prepare_motion_tokens(self, input_dict):
    """Embed the per-agent motion actions into decoder tokens.

    For every valid ``(batch, step, agent)`` entry, the continuous motion
    feature of the input action (four table values concatenated with
    ``decoder/modeled_agent_delta``) is embedded by ``self.motion_embed``
    together with the agent-id, agent-type, agent-shape and action embeddings.
    The result is scattered back to a dense tensor with ``utils.unwrap``.

    Adds ``model/motion_token``, ``model/motion_token_{valid_mask,position,
    heading,width,length}`` and ``model/motion_require_relation`` and returns
    ``input_dict``.
    """
    input_action = input_dict["decoder/input_action"]
    agent_delta = input_dict["decoder/modeled_agent_delta"]
    action_valid_mask = input_dict["decoder/input_action_valid_mask"]
    agent_shape = input_dict["decoder/current_agent_shape"]
    B, T_skipped, N = input_action.shape[:3]
    token_shape = (B, T_skipped, N, self.d_model)

    # === Categorical embeddings: agent id, type and shape ===
    raw_agent_id = input_dict["encoder/modeled_agent_id"].reshape(B, 1, N).expand(B, T_skipped, N)
    agent_id_emb = self.agent_id_embed(mode_agent_id(raw_agent_id, 128, fill_negative_1=True))
    assert agent_id_emb.shape == token_shape, (
        B, T_skipped, N, self.d_model, agent_id_emb.shape)
    assert action_valid_mask.shape == (B, T_skipped, N), (action_valid_mask.shape, (B, T_skipped, N))

    # Type and shape have no time axis; they are repeated for every step.
    type_emb = self.map_id_embed(input_dict["decoder/agent_type"])[:, None].expand(*token_shape)
    shape_emb = self.shape_embed(agent_shape)[:, None].expand(*token_shape)

    # === Action embedding (valid entries only) ===
    # Boolean indexing returns a copy, so the in-place edits below stay local.
    valid_action = input_action[action_valid_mask]
    valid_action[valid_action == START_ACTION] = -1
    valid_action_emb = self.action_embed(valid_action)

    # === Continuous motion feature of every valid action ===
    # Negative ids are redirected to row `num_actions` of the feature table,
    # presumably the all-zero row appended in __init__ (verify against the tokenizer).
    valid_action[valid_action < 0] = self.num_actions
    feature_table = self.motion_features.reshape(-1, 4)
    assert feature_table.shape[0] > valid_action.max()
    assert valid_action.min() >= 0
    motion_feat = torch.cat([feature_table[valid_action], agent_delta[action_valid_mask]], dim=-1)

    categorical_embs = [
        agent_id_emb[action_valid_mask],
        type_emb[action_valid_mask],
        shape_emb[action_valid_mask],
        valid_action_emb,
    ]
    action_token = self.motion_embed(continuous_inputs=motion_feat, categorical_embs=categorical_embs)
    action_token = utils.unwrap(action_token, action_valid_mask)
    assert action_token.shape == token_shape
    assert action_valid_mask.shape == (B, T_skipped, N)

    input_dict["model/motion_token"] = action_token
    input_dict["model/motion_token_valid_mask"] = action_valid_mask
    input_dict["model/motion_token_position"] = input_dict["decoder/modeled_agent_position"]
    input_dict["model/motion_token_heading"] = input_dict["decoder/modeled_agent_heading"]

    # Agent shape is repeated over time; column 0 is used as length, column 1 as width.
    shape_per_step = agent_shape.unsqueeze(1).expand(B, T_skipped, N, 3)
    input_dict["model/motion_token_width"] = shape_per_step[:, :, :, 1]
    input_dict["model/motion_token_length"] = shape_per_step[:, :, :, 0]

    # Every valid action takes part in relation computation; stored as a copy of the mask.
    input_dict["model/motion_require_relation"] = action_valid_mask.clone()
    return input_dict
+ """ + total_tokens_per_step = num_tl + num_tg + num_motion + tl_mask = [] + for t in range(T): + mask = torch.zeros(B, num_tl, T, total_tokens_per_step, dtype=torch.bool) + mask[:, :, t, :num_tl] = True # self-attention + + # previous traffic light tokens + mask[:, :, :t, :num_tl] = torch.diag(torch.ones(num_tl)).bool().unsqueeze(1) + + if t > 0: + mask[:, :, t - 1, num_tl + num_tg:] = True # last step motion tokens + tl_mask.append(mask) + return tl_mask # .flatten(1, 2) + + def _build_all_tokens_mask_for_tg(self, B, T, num_tl, num_tg, num_motion): + """ + Recall that in our design, traffic light tokens attend to: + 1) map (not considered here), + 2) itself (self-attention and WITH CASUAL MASK!!), + 3) current step traffic light tokens, + 4) last step motion tokens. + 5) all previous step trafficgen tokens. + + The ultimate output should be a mask with shape: B, Q, K, where + Q = T*num_tg, K = T*(num_tl + num_tg + num_motion). + """ + total_tokens_per_step = num_tl + num_tg + num_motion + tg_mask = [] + intra_step_causal_mask = create_causal_mask(T=num_tg, N=1, is_valid_mask=True) + diag = torch.diag(torch.ones(num_motion)).bool() + diag_tg = torch.diag(torch.ones(num_tg)).bool() + diag_rep = diag[:, None, :, None].repeat(1, NUM_TG_MULTI, 1, NUM_TG_MULTI).flatten(-2, -1).flatten(0, 1) + for t in range(T): + mask = torch.zeros(B, num_tg, T, total_tokens_per_step, dtype=torch.bool) + mask[:, :, t, num_tl:num_tl + num_tg] = intra_step_causal_mask # self-attention + mask[:, :, t, :num_tl] = True # current step traffic light tokens + if t > 0: + mask[:, :, t - 1, num_tl + num_tg:] = True # last step motion tokens + mask[:, 1:-1, :t, num_tl+1:num_tl + num_tg-1] = diag_rep[:, None] # all previous step trafficgen tokens + mask[:, :, :t, num_tl:num_tl + num_tg] = diag_tg.unsqueeze(1) + mask[:, :, :t, num_tl + num_tg-1:num_tl + num_tg] = True # all token attend to previous eos token + mask[:, :, :t, num_tl:num_tl + 1] = True # all token attend to previous sos token + 
tg_mask.append(mask) + return tg_mask # .flatten(1, 2) + + def _build_all_tokens_mask_for_motion(self, B, T, num_tl, num_tg, num_motion): + """ + Recall that in our design, traffic light tokens attend to: + 1) map (not considered here), + 2) itself (self-attention), + 3) current step traffic light tokens, + 4) current step trafficgen tokens, + 5) all previous step motion tokens. + + The ultimate output should be a mask with shape: B, Q, K, where + Q = T*num_motion, K = T*(num_tl + num_tg + num_motion). + """ + total_tokens_per_step = num_tl + num_tg + num_motion + motion_mask = [] + diag = torch.diag(torch.ones(num_motion)).bool() + diag_rep = diag[..., None].repeat(1, 1, NUM_TG_MULTI).flatten(1, 2) + for t in range(T): + mask = torch.zeros(B, num_motion, T, total_tokens_per_step, dtype=torch.bool) + mask[:, :, t, num_tl + num_tg:] = True # self-attention + mask[:, :, t, :num_tl] = True # current step traffic light tokens + if num_tg > 0: + mask[:, :, :t+1, num_tl + 1:num_tl + num_tg - 1] = diag_rep.unsqueeze(1) # current step trafficgen tokens + mask[:, :, :t+1, num_tl:num_tl + 1] = True # current step trafficgen tokens + mask[:, :, :t+1, num_tl+num_tg-1:num_tl + num_tg] = True # current step trafficgen tokens + # all previous step motion tokens FOR EACH AGENT + mask[:, :, :t, num_tl + num_tg:] = diag.unsqueeze(1) + motion_mask.append(mask) + return motion_mask # .flatten(1, 2) + + def _build_force_mask_for_tl(self, B, T, num_tl, num_tg, num_motion): + """ + You must attend to your own history + """ + total_tokens_per_step = num_tl + num_tg + num_motion + tl_mask = [] + for t in range(T): + mask = torch.zeros(B, num_tl, T, total_tokens_per_step, dtype=torch.bool) + mask[:, :, :t, :num_tl] = torch.diag(torch.ones(num_tl)).bool().unsqueeze(1) + tl_mask.append(mask) + return tl_mask + + def _build_force_mask_for_tg(self, B, T, num_tl, num_tg, num_motion): + assert num_tg > 0 + total_tokens_per_step = num_tl + num_tg + num_motion + tg_mask = [] + diag = 
torch.diag(torch.ones(num_motion)).bool() + diag_rep = diag[:, None, :, None].repeat(1, NUM_TG_MULTI, 1, NUM_TG_MULTI).flatten(-2, -1).flatten(0, 1) + diag = diag[:, None, ].repeat(1, NUM_TG_MULTI, 1).flatten(0, 1) + for t in range(T): + mask = torch.zeros(B, num_tg, T, total_tokens_per_step, dtype=torch.bool) + if t > 0: + mask[:, 1:-1, :t, num_tl + num_tg:] = diag[:, None] # history motion tokens + mask[:, 1:-1, :t, num_tl + 1:num_tl + num_tg - 1] = diag_rep[:, None] # history tg tokens (only the same agent) + mask[:, :1, t - 1, num_tl + num_tg:] = True # sos attends history motion token + tg_mask.append(mask) + return tg_mask + + def _build_force_mask_for_motion(self, B, T, num_tl, num_tg, num_motion): + """ + You must attend to your own history. + """ + total_tokens_per_step = num_tl + num_tg + num_motion + motion_mask = [] + diag = torch.diag(torch.ones(num_motion)).bool() + diag_rep = diag[..., None].repeat(1, 1, NUM_TG_MULTI).flatten(1, 2) + for t in range(T): + mask = torch.zeros(B, num_motion, T, total_tokens_per_step, dtype=torch.bool) + if num_tg > 0: + mask[:, :, :t+1, num_tl + 1:num_tl + num_tg - 1] = diag_rep[:, None] # attend to all prev ego TG tokens. 
+ mask[:, :, :t+1, num_tl:num_tl + 1] = True # current step trafficgen tokens sos + mask[:, :, :t+1, num_tl+num_tg-1:num_tl + num_tg] = True # current step trafficgen tokens eos + mask[:, :, :t, num_tl + num_tg:] = diag.unsqueeze(1) + motion_mask.append(mask) + return motion_mask # .flatten(1, 2) + + def _build_all_tokens_mask(self, B, T, num_tl, num_tg, num_motion): + + if self.no_tg: + tl_mask = self._build_all_tokens_mask_for_tl(B, T, num_tl, 0, num_motion) + tl_mask = torch.stack(tl_mask, dim=1).flatten(3, 4) + motion_mask = self._build_all_tokens_mask_for_motion(B, T, num_tl, 0, num_motion) + motion_mask = torch.stack(motion_mask, dim=1).flatten(3, 4) + all_mask = torch.cat([tl_mask, motion_mask], dim=2) + notg_all_mask = all_mask.flatten(1, 2) + assert notg_all_mask.shape[1] == (num_tl + num_motion) * T, (notg_all_mask.shape[1], num_tl, num_motion, T) + return notg_all_mask + + assert self.no_tg is False, "This function is only for no_tg = False" + with_tg = num_tl + num_tg + num_motion + without_tg = num_tl + num_motion + + total_tokens_so_far = with_tg + + tl_mask = self._build_all_tokens_mask_for_tl(B, T, num_tl, num_tg, num_motion) + tg_mask = self._build_all_tokens_mask_for_tg(B, T, num_tl, num_tg, num_motion) + motion_mask = self._build_all_tokens_mask_for_motion(B, T, num_tl, num_tg, num_motion) + + tl_mask = torch.stack(tl_mask, dim=1).flatten(3, 4) + tg_mask = torch.stack(tg_mask, dim=1).flatten(3, 4) + motion_mask = torch.stack(motion_mask, dim=1).flatten(3, 4) + + all_mask = torch.cat([tl_mask, tg_mask, motion_mask], dim=2) + + # all_mask shape is [B, T, L+G+N, T*(L+G+N)] + + new_mask = [] + + # [B, T, L+G+N, T*(L+G+N)] -> [B, correct_num_of_tokens, T*(L+G+N)] + for t in range(T): + if t % TG_SKIP_STEP == 0: + new_mask.append(all_mask[:, t, :, :]) + else: + cat = torch.cat([all_mask[:, t, :num_tl, :], all_mask[:, t, num_tl + num_tg:, :]], dim=1) + new_mask.append(cat) + new_mask = torch.cat(new_mask, dim=1) + + + full_mask = [] + for t in range(T): 
+ if t % TG_SKIP_STEP == 0: + full_mask.append(new_mask[:, :, t*with_tg: (t+1)*with_tg]) + else: + full_mask.append(new_mask[:, :, t*with_tg:t*with_tg+num_tl]) + full_mask.append(new_mask[:, :, t*with_tg+num_tl+num_tg:(t+1)*with_tg]) + full_mask = torch.cat(full_mask, dim=2) + assert full_mask.shape[1] == full_mask.shape[2], (full_mask.shape[1], full_mask.shape[2]) + + # import matplotlib.pyplot as plt + # vis = full_mask[0].cpu().numpy() + # plt.imshow(vis) + + return full_mask + + def _build_all_force_mask(self, B, T, num_tl, num_tg, num_motion): + if self.no_tg: + tl_mask = self._build_force_mask_for_tl(B, T, num_tl, 0, num_motion) + tl_mask = torch.stack(tl_mask, dim=1).flatten(3, 4) + motion_mask = self._build_force_mask_for_motion(B, T, num_tl, 0, num_motion) + motion_mask = torch.stack(motion_mask, dim=1).flatten(3, 4) + all_mask = torch.cat([tl_mask, motion_mask], dim=2) + notg_all_mask = all_mask.flatten(1, 2) + assert notg_all_mask.shape[1] == (num_tl + num_motion) * T, (notg_all_mask.shape[1], num_tl, num_motion, T) + return notg_all_mask + + assert self.no_tg is False, "This function is only for no_tg = False" + with_tg = num_tl + num_tg + num_motion + tl_mask = self._build_force_mask_for_tl(B, T, num_tl, num_tg, num_motion) + tg_mask = self._build_force_mask_for_tg(B, T, num_tl, num_tg, num_motion) + motion_mask = self._build_force_mask_for_motion(B, T, num_tl, num_tg, num_motion) + tl_mask = torch.stack(tl_mask, dim=1).flatten(3, 4) + tg_mask = torch.stack(tg_mask, dim=1).flatten(3, 4) + motion_mask = torch.stack(motion_mask, dim=1).flatten(3, 4) + all_mask = torch.cat([tl_mask, tg_mask, motion_mask], dim=2) + new_mask = [] + for t in range(T): + if t % TG_SKIP_STEP == 0: + new_mask.append(all_mask[:, t, :, :]) + else: + cat = torch.cat([all_mask[:, t, :num_tl, :], all_mask[:, t, num_tl + num_tg:, :]], dim=1) + new_mask.append(cat) + new_mask = torch.cat(new_mask, dim=1) + full_mask = [] + for t in range(T): + if t % TG_SKIP_STEP == 0: + 
full_mask.append(new_mask[:, :, t*with_tg: (t+1)*with_tg]) + else: + full_mask.append(new_mask[:, :, t*with_tg:t*with_tg+num_tl]) + full_mask.append(new_mask[:, :, t*with_tg+num_tl+num_tg:(t+1)*with_tg]) + full_mask = torch.cat(full_mask, dim=2) + assert full_mask.shape[1] == full_mask.shape[2], (full_mask.shape[1], full_mask.shape[2]) + return full_mask + + def prepare_dynamic_relation_notg(self, input_dict): + + map_position = input_dict["model/map_token_position"] + map_heading = input_dict["model/map_token_heading"] + map_token_valid_mask = input_dict["model/map_token_valid_mask"] + + # traffic light tokens + traffic_light_position = input_dict["model/traffic_light_token_position"] + traffic_light_heading = input_dict["model/traffic_light_token_heading"] + traffic_light_token = input_dict["model/traffic_light_token"] + traffic_light_valid_mask = input_dict["encoder/traffic_light_valid_mask"] + traffic_light_width = torch.zeros_like(traffic_light_position[..., 0]) + traffic_light_length = torch.zeros_like(traffic_light_position[..., 0]) + traffic_light_require_relation = input_dict["model/traffic_light_require_relation"] + B, T, L, _ = traffic_light_token.shape + + # motion tokens + motion_token = input_dict["model/motion_token"] + motion_token_valid_mask = input_dict["model/motion_token_valid_mask"] + motion_token_position = input_dict["model/motion_token_position"] + motion_token_heading = input_dict["model/motion_token_heading"] + motion_token_width = input_dict["model/motion_token_width"] + motion_token_length = input_dict["model/motion_token_length"] + motion_token_require_relation = input_dict["model/motion_require_relation"] + N = motion_token.shape[2] + + # build giant tensors for all "dynamic" tokens to serve as the key/value + all_tokens = torch.cat([ + traffic_light_token, + motion_token, + ], dim=2) + all_positions = torch.cat([ + traffic_light_position[..., :2], + motion_token_position[..., :2], + ], dim=2) + all_headings = torch.cat([ + 
traffic_light_heading, + motion_token_heading, + ], dim=2) + all_valid_masks = torch.cat([ + traffic_light_valid_mask, + motion_token_valid_mask, + ], dim=2) + all_widths = torch.cat([ + traffic_light_width, + motion_token_width, + ], dim=2) + all_lengths = torch.cat([ + traffic_light_length, + motion_token_length, + ], dim=2) + all_require_relation = torch.cat([ + traffic_light_require_relation, + motion_token_require_relation, + ], dim=2) + # all_require_relation = all_require_relation & all_valid_masks + + giant_N = all_tokens.shape[2] + all_steps = torch.arange(T).to(traffic_light_position.device).reshape(1, T, 1) # .expand(B, T, giant_N) + + # ===== Build causal mask for traffic light tokens ===== + tl_causal_mask = torch.stack( + self._build_all_tokens_mask_for_tl(B=B, T=T, num_tl=L, num_tg=0, num_motion=N), dim=1 + ).flatten(-2, -1).to(traffic_light_position.device) + + # ===== Build causal mask for motion tokens ===== + motion_causal_mask = torch.stack( + self._build_all_tokens_mask_for_motion(B=B, T=T, num_tl=L, num_tg=0, num_motion=N), dim=1 + ).flatten(-2, -1).to(traffic_light_position.device) + + all_causal_mask = torch.cat([tl_causal_mask, motion_causal_mask], dim=2) + + # ===== Build causal mask for traffic light tokens ===== + tl_force_mask = torch.stack( + self._build_force_mask_for_tl(B=B, T=T, num_tl=L, num_tg=0, num_motion=N), dim=1 + ).flatten(-2, -1).to(traffic_light_position.device) + motion_force_mask = torch.stack( + self._build_force_mask_for_motion(B=B, T=T, num_tl=L, num_tg=0, num_motion=N), dim=1 + ).flatten(-2, -1).to(traffic_light_position.device) + all_force_mask = torch.cat([tl_force_mask, motion_force_mask], dim=2) + + # import matplotlib.pyplot as plt + # vis = all_causal_mask[0].flatten(0, 1).cpu().numpy() + # plt.imshow(vis) + + relation_all_to_all, relation_valid_mask, require_relation_pairwise = relation.compute_relation_for_scenestreamer( + query_pos=all_positions.flatten(1, 2), + query_heading=all_headings.flatten(1, 2), + 
query_valid_mask=all_valid_masks.flatten(1, 2), + query_step=all_steps.expand(B, T, L + N).flatten(1, 2), + key_pos=all_positions.flatten(1, 2), + key_heading=all_headings.flatten(1, 2), + key_valid_mask=all_valid_masks.flatten(1, 2), + key_step=all_steps.expand(B, T, L + N).flatten(1, 2), + causal_valid_mask=all_causal_mask.flatten(1, 2), + + force_attention_mask=all_force_mask.flatten(1, 2), + + require_relation=all_require_relation.flatten(1, 2), + + knn=self.config.SCENESTREAMER_ATTENTION_KNN, + max_distance=self.config.SCENESTREAMER_ATTENTION_MAX_DISTANCE, + + gather=False, + query_width=None, # set query's w/l to 0 so that we get the rel of contour of key w.r.t. center of query + query_length=None, + key_width=all_widths.flatten(1, 2), + key_length=all_lengths.flatten(1, 2), + non_agent_relation=True, + ) + relation_all_to_all = get_edge_info_for_scenestreamer( + q_k_relation=relation_all_to_all, + q_k_valid_mask=relation_valid_mask, + relation_model=self.relation_embed_4d, + relation_model_1d=self.relation_embed_1d, + require_relation_pairwise=require_relation_pairwise, + ) + relation_all_to_map = self._get_relation_for_4d_token_vs_map_token( + token_4d_pos=all_positions, + token_4d_heading=all_headings, + token_4d_valid_mask=all_valid_masks, + token_4d_step=all_steps.expand(B, T, L + N), + # token4d_length=all_lengths, + # token4d_width=all_widths, + map_pos=map_position, + map_heading=map_heading, + map_valid_mask=map_token_valid_mask, + knn=self.config.SCENESTREAMER_ATTENTION_KNN, + max_distance=self.config.SCENESTREAMER_ATTENTION_MAX_DISTANCE, + require_relation=all_require_relation, + ) + + # import matplotlib.pyplot as plt + # vis = all_causal_mask.flatten(1, 2)[0].cpu().numpy() + # plt.imshow(vis) + + input_dict["model/all_token"] = all_tokens.flatten(1, 2) + input_dict["model/all_to_all_info"] = relation_all_to_all + input_dict["model/all_to_map_info"] = relation_all_to_map + return input_dict + + + + def prepare_dynamic_relation(self, input_dict): + 
+ map_position = input_dict["model/map_token_position"] + map_heading = input_dict["model/map_token_heading"] + map_token_valid_mask = input_dict["model/map_token_valid_mask"] + + # traffic light tokens + traffic_light_position = input_dict["model/traffic_light_token_position"] + traffic_light_heading = input_dict["model/traffic_light_token_heading"] + traffic_light_token = input_dict["model/traffic_light_token"] + traffic_light_valid_mask = input_dict["encoder/traffic_light_valid_mask"] + traffic_light_width = torch.zeros_like(traffic_light_position[..., 0]) + traffic_light_length = torch.zeros_like(traffic_light_position[..., 0]) + traffic_light_require_relation = input_dict["model/traffic_light_require_relation"] + B, T, L, _ = traffic_light_token.shape + + # trafficgen tokens + tg_tokens = input_dict["model/trafficgen_token"] + tg_position = input_dict["model/trafficgen_position"] + tg_heading = input_dict["model/trafficgen_heading"] + tg_valid_mask = input_dict["model/trafficgen_valid_mask"] + tg_width = input_dict["model/trafficgen_width"] + tg_length = input_dict["model/trafficgen_length"] + tg_require_relation = input_dict["model/trafficgen_require_relation"] + _, _, G, _ = tg_tokens.shape + + # motion tokens + motion_token = input_dict["model/motion_token"] + motion_token_valid_mask = input_dict["model/motion_token_valid_mask"] + motion_token_position = input_dict["model/motion_token_position"] + motion_token_heading = input_dict["model/motion_token_heading"] + motion_token_width = input_dict["model/motion_token_width"] + motion_token_length = input_dict["model/motion_token_length"] + motion_token_require_relation = input_dict["model/motion_require_relation"] + N = motion_token.shape[2] + + # build giant tensors for all "dynamic" tokens to serve as the key/value + def _concat(tl_tokens, tg_tokens, motion_tokens): + assert tl_tokens.shape[:3] == (B, T, L), (tl_tokens.shape, (B, T, L)) + assert tg_tokens.shape[:3] == (B, T, G), (tg_tokens.shape, (B, T, G)) + 
assert motion_tokens.shape[:3] == (B, T, N), (motion_tokens.shape, (B, T, N)) + ret = [] + for t in range(T): + ret.append(tl_tokens[:, t]) + if t % TG_SKIP_STEP == 0: + ret.append(tg_tokens[:, t]) + ret.append(motion_tokens[:, t]) + ret = torch.cat(ret, dim=1) + return ret + + all_knn = self.config.SCENESTREAMER_ATTENTION_KNN + # tl_knn = torch.full((B, T, L), knn // 2, device=traffic_light_position.device) + # tg_knn = torch.full((B, T, G), knn // 2, device=traffic_light_position.device) + # motion_knn = torch.full((B, T, N), knn, device=traffic_light_position.device) + # all_knn = _concat(tl_knn, tg_knn, motion_knn) + + all_tokens = _concat(traffic_light_token, tg_tokens, motion_token) + all_positions = _concat(traffic_light_position[..., :2], tg_position[..., :2], motion_token_position[..., :2]) + all_headings = _concat(traffic_light_heading, tg_heading, motion_token_heading) + all_valid_masks = _concat(traffic_light_valid_mask, tg_valid_mask, motion_token_valid_mask) + all_widths = _concat(traffic_light_width, tg_width, motion_token_width) + all_lengths = _concat(traffic_light_length, tg_length, motion_token_length) + all_require_relation = _concat(traffic_light_require_relation, tg_require_relation, motion_token_require_relation) + # all_require_relation = all_require_relation & all_valid_masks + + all_steps = torch.arange(T).to(traffic_light_position.device).reshape(1, T, 1) # .expand(B, T, giant_N) + all_steps = _concat(all_steps.expand(B, T, L), all_steps.expand(B, T, G), all_steps.expand(B, T, N)) + + all_causal_mask = self._build_all_tokens_mask(B=B, T=T, num_tl=L, num_tg=G, num_motion=N).to( + all_require_relation.device) + + all_force_mask = self._build_all_force_mask(B=B, T=T, num_tl=L, num_tg=G, num_motion=N).to( + all_require_relation.device) + + # import matplotlib.pyplot as plt + # vis = all_causal_mask[0].cpu().numpy() + # plt.imshow(vis) + relation_all_to_all, relation_valid_mask, require_relation_pairwise = 
relation.compute_relation_for_scenestreamer( + query_pos=all_positions, + query_heading=all_headings, + query_valid_mask=all_valid_masks, + query_step=all_steps, + key_pos=all_positions, + key_heading=all_headings, + key_valid_mask=all_valid_masks, + key_step=all_steps, + causal_valid_mask=all_causal_mask, + force_attention_mask=all_force_mask, + require_relation=all_require_relation, + + knn=all_knn, + max_distance=self.config.SCENESTREAMER_ATTENTION_MAX_DISTANCE, + + gather=False, + query_width=None, # set query's w/l to 0 so that we get the rel of contour of key w.r.t. center of query + query_length=None, + key_width=all_widths, + key_length=all_lengths, + non_agent_relation=True, + ) + relation_all_to_all = get_edge_info_for_scenestreamer( + q_k_relation=relation_all_to_all, + q_k_valid_mask=relation_valid_mask, + relation_model=self.relation_embed_4d, + relation_model_1d=self.relation_embed_1d, + require_relation_pairwise=require_relation_pairwise, + ) + relation_all_to_map = self._get_relation_for_3d_token_vs_map_token( + token_3d_pos=all_positions, + token_3d_heading=all_headings, + token_3d_valid_mask=all_valid_masks, + token_3d_step=all_steps, + # token4d_length=all_lengths, + # token4d_width=all_widths, + map_pos=map_position, + map_heading=map_heading, + map_valid_mask=map_token_valid_mask, + knn=self.config.SCENESTREAMER_ATTENTION_KNN, + max_distance=self.config.SCENESTREAMER_ATTENTION_MAX_DISTANCE, + require_relation=all_require_relation, + ) + + # import matplotlib.pyplot as plt + # vis = all_causal_mask.flatten(1, 2)[0].cpu().numpy() + # plt.imshow(vis) + + input_dict["model/all_token"] = all_tokens + input_dict["model/all_to_all_info"] = relation_all_to_all + input_dict["model/all_to_map_info"] = relation_all_to_map + return input_dict + + def _get_relation_for_4d_token_vs_map_token( + self, *, token_4d_pos, token_4d_heading, token_4d_valid_mask, token_4d_step, + map_pos, map_heading, map_valid_mask, + knn, max_distance, require_relation + ): + + 
return self._get_relation_for_3d_token_vs_map_token( + token_3d_pos=token_4d_pos.flatten(1, 2), + token_3d_heading=token_4d_heading.flatten(1, 2), + token_3d_valid_mask=token_4d_valid_mask.flatten(1, 2), + token_3d_step=token_4d_step.flatten(1, 2), + map_pos=map_pos, + map_heading=map_heading, + map_valid_mask=map_valid_mask, + knn=knn, + max_distance=max_distance, + require_relation=require_relation.flatten(1, 2), + ) + + def _get_relation_for_3d_token_vs_map_token( + self, *, token_3d_pos, token_3d_heading, token_3d_valid_mask, token_3d_step, + map_pos, map_heading, map_valid_mask, + knn, max_distance, require_relation, + token3d_width=None, token3d_length=None + ): + if token3d_width is not None: + token3d_width = token3d_width.flatten(1, 2) + token3d_length = token3d_length.flatten(1, 2) + non_agent_relation = False + raise ValueError + else: + non_agent_relation = True + a2m_3d = self.config.MODEL.ALL_TO_MAP_3D + q_k_relation, q_k_valid_mask, require_relation_output = relation.compute_relation_for_scenestreamer( + query_pos=token_3d_pos, # B, TN, D + query_heading=token_3d_heading, + query_valid_mask=token_3d_valid_mask, + query_step=None if a2m_3d else token_3d_step, + query_width=token3d_width, + query_length=token3d_length, + key_pos=map_pos, + key_heading=map_heading, + key_valid_mask=map_valid_mask, + key_step=None if a2m_3d else torch.zeros_like(map_heading, dtype=torch.int64), + key_width=None, + key_length=None, + causal_valid_mask=None, + knn=knn, + max_distance=max_distance, + gather=False, + non_agent_relation=non_agent_relation, + require_relation=require_relation, + require_relation_for_key=map_valid_mask, + ) + relation_info = get_edge_info_for_scenestreamer( + q_k_valid_mask=q_k_valid_mask, + q_k_relation=q_k_relation, + relation_model=self.relation_embed_3d if a2m_3d else self.relation_embed_4d, + relation_model_1d=self.relation_embed_1d, + require_relation_pairwise=require_relation_output, + ) + return relation_info + + +class 
SceneStreamerDecoder(Module): + def __init__(self, decoder_layer, num_layers, d_model, ): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.d_model = d_model + + def forward(self, *, input_dict, use_cache=None, cache=None): + new_past_key_value_list = [] + output_dict = input_dict + for layer_idx, mod in enumerate(self.layers): + layer_input_cache = cache[layer_idx] if cache is not None else None + output_dict, layer_cache = mod(input_dict=output_dict, use_cache=use_cache, cache=layer_input_cache) + if use_cache: + new_past_key_value_list.append(layer_cache) + if use_cache: + return output_dict, new_past_key_value_list + return output_dict + + +class SceneStreamerDecoderLayer(Module): + __constants__ = ['norm_first'] + + def __init__(self, d_model: int, nhead: int, dropout: float = 0.1, ): + super().__init__() + + # ===== map tokens ===== + # self.map_to_map_attention = MultiheadAttentionLayer( + # d_model=d_model, + # n_heads=nhead, + # dropout=dropout, + # simple_relation=True, + # simple_relation_factor=1, + # is_v7=True, + # update_relation=False, + # add_relation_to_v=False, + # ) + # self.map_norm = nn.LayerNorm(d_model) + # self.map_to_map_rel_norm = nn.LayerNorm(d_model) + + # ===== all tokens ===== + self.all_to_map_attention = MultiheadAttentionLayer( + d_model=d_model, + n_heads=nhead, + dropout=dropout, + simple_relation=True, + simple_relation_factor=1, + is_v7=True, + update_relation=False, + add_relation_to_v=False, + ) + self.all_to_map_norm = nn.LayerNorm(d_model) + self.all_to_map_rel_norm = nn.LayerNorm(d_model) + + self.all_to_all_attention = MultiheadAttentionLayer( + d_model=d_model, + n_heads=nhead, + dropout=dropout, + simple_relation=True, + simple_relation_factor=1, + is_v7=True, + update_relation=False, + add_relation_to_v=False, + ) + self.all_to_all_norm = nn.LayerNorm(d_model) + self.all_to_all_rel_norm = nn.LayerNorm(d_model) + + # ===== feed forward ===== + 
self.mlp_prenorm = nn.LayerNorm(d_model) + self.mlp = common_layers.build_mlps( + c_in=d_model, mlp_channels=[4 * d_model, d_model], ret_before_act=True, without_norm=True + ) + + def forward(self, *, input_dict, use_cache=None, cache=None): + map_token = input_dict["model/map_token"] + + # ===== all token to map cross-attention ===== + input_all_token = input_dict["model/all_token"] + output_all_token = self.all_to_map_norm(input_all_token) + all_rel = input_dict["model/all_to_map_info"]["edge_features"] + all_rel = self.all_to_map_rel_norm(all_rel) + output_all_token, _, _ = self.all_to_map_attention( + q=output_all_token, + k=map_token, + edge_features=all_rel, + edge_features_v=None, + edge_index=input_dict["model/all_to_map_info"]["edge_index"], + use_cache=False, + cache=None, + ) + output_all_token = input_all_token + output_all_token + + # ===== all token self-attention ===== + output_all_token = self.all_to_all_norm(output_all_token) + all_to_all_rel = input_dict["model/all_to_all_info"]["edge_features"] + all_to_all_rel = self.all_to_all_rel_norm(all_to_all_rel) + output_all_token, new_cache, _ = self.all_to_all_attention( + q=output_all_token, + k=output_all_token, + edge_features=all_to_all_rel, + edge_features_v=None, + edge_index=input_dict["model/all_to_all_info"]["edge_index"], + use_cache=use_cache, + cache=cache, + ) + output_all_token = input_all_token + output_all_token + + # === Feed-forward layer === + output_all_token = self.mlp_prenorm(output_all_token) + output_all_token = self.mlp(output_all_token) + all_token = input_all_token + output_all_token + input_dict["model/all_token"] = all_token + + return input_dict, new_cache diff --git a/scenestreamer/models/test_autoregressive.py b/scenestreamer/models/test_autoregressive.py new file mode 100644 index 0000000000000000000000000000000000000000..2dbb0ee4223de7bb927e6f475f7b192a6de58142 --- /dev/null +++ b/scenestreamer/models/test_autoregressive.py @@ -0,0 +1,102 @@ +import collections +import 
copy +import time + +import numpy as np +import torch +from tqdm import tqdm + +from scenestreamer.dataset.datamodule import SceneStreamerDataModule +from scenestreamer.models.motionlm import MotionLM +from scenestreamer.utils import debug_tools + + +def toy_test(): + cfg_file = "cfgs/motion_debug_2_local_train.yaml" + config = debug_tools.get_debug_config(cfg_file=cfg_file) + + config.MODEL.update(dict( + D_MODEL=512, + NUM_ATTN_LAYERS=6, + NUM_ATTN_HEAD=8, + NUM_DECODER_LAYERS=6, + )) + datamodule = SceneStreamerDataModule( + config, + train_batch_size=1, + train_num_workers=0, + val_batch_size=2, + val_num_workers=8, + train_prefetch_factor=2, + val_prefetch_factor=1 + ) + datamodule.setup("fit") + dataloader = datamodule.train_dataloader() + + model = MotionLM(config) + model.eval() + model.cuda() + + time_no_cache = 0.0 + time_with_cache = 0.0 + stat_dict = collections.defaultdict(list) + for input_dict in tqdm(dataloader): + for k, v in input_dict.items(): + if isinstance(v, torch.Tensor): + input_dict[k] = v.cuda() + + num_modes_for_eval = 6 + + def _repeat_for_modes(v): + d = v.ndim + if d > 1: + v = v[:, None] + v = v.repeat(1, num_modes_for_eval, *((1, ) * (d - 1))) + v = v.flatten(0, 1) + else: + v = v.repeat(num_modes_for_eval) + return v + + input_dict = { + k: _repeat_for_modes(input_dict[k]) + for k in input_dict.keys() if ( + k.startswith("encoder/") or k.startswith("decoder/") or k.startswith("metadata/") + or k.startswith("eval/") + ) + } + + input_dict2 = copy.deepcopy(input_dict) + input_dict1 = input_dict + + s = time.time() + with torch.no_grad(): + input_dict1 = model.autoregressive_rollout( + input_dict1, num_decode_steps=16, use_cache=False, sampling_method="argmax" + ) + time_no_cache += time.time() - s + + s = time.time() + with torch.no_grad(): + input_dict2 = model.autoregressive_rollout( + input_dict2, num_decode_steps=16, use_cache=True, sampling_method="argmax" + ) + time_with_cache += time.time() - s + + diff_dict = { + k: 
(input_dict1[k].float() - input_dict2[k].float()).abs().mean().item() + for k in input_dict1 if isinstance(input_dict1[k], torch.Tensor) + } + action_mismatch = (input_dict1["decoder/output_action"] != + input_dict2["decoder/output_action"]).float().mean().item() + diff_dict = {k: v for k, v in diff_dict.items() if v != 0.0} + diff_dict["action_mismatch"] = action_mismatch + + for k, v in diff_dict.items(): + stat_dict[k].append(v) + + stat_dict = {k: np.mean(v) for k, v in stat_dict.items()} + print(f"FINISHED. TIME without CACHE: {time_no_cache}, TIME with CACHE: {time_with_cache}.\nDIFF:{stat_dict}") + + +if __name__ == '__main__': + toy_test() diff --git a/scenestreamer/models/test_gpu_memory.py b/scenestreamer/models/test_gpu_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..51613a08cd75b8a5d0c5b9c69535edbcbffbea01 --- /dev/null +++ b/scenestreamer/models/test_gpu_memory.py @@ -0,0 +1,93 @@ +import lightning.pytorch as pl +import torch +import tqdm + +from scenestreamer.dataset.datamodule import SceneStreamerDataModule +from scenestreamer.models.motionlm_lightning import MotionLMLightning +from scenestreamer.utils import debug_tools + + +def toy_test(bs, in_eval): + cfg_file = "cfgs/motion_default.yaml" + config = debug_tools.get_debug_config(cfg_file=cfg_file) + config.DATA.TRAINING_DATA_DIR = 'data/metadrive_processed_waymo/validation' + config.DATA.TEST_DATA_DIR = 'data/metadrive_processed_waymo/validation' + + # config.MODEL.update(dict( + # D_MODEL=512, + # NUM_ATTN_LAYERS=6, + # NUM_ATTN_HEAD=8, + # NUM_DECODER_LAYERS=6, + # )) + datamodule = SceneStreamerDataModule( + config, + train_batch_size=bs, + train_num_workers=0, + val_batch_size=bs * 6, + val_num_workers=8, + train_prefetch_factor=2, + val_prefetch_factor=1 + ) + datamodule.setup("fit") + if in_eval: + dataloader = datamodule.val_dataloader() + else: + dataloader = datamodule.train_dataloader() + + model: pl.LightningModule = MotionLMLightning(config) + 
model.train() + model.cuda() + + for input_dict in dataloader: + break + + # ===== Fill in some fake data ===== + N = config.PREPROCESSING.MAX_AGENTS + M = config.PREPROCESSING.MAX_MAP_FEATURES + V = config.PREPROCESSING.MAX_VECTORS + + def _extend_3rd_dim(key): + tensor = input_dict[key] + new_shape = list(tensor.shape) + if len(new_shape) < 3: + new_shape[1] = N + input_dict[key] = tensor.new_ones(*new_shape) #+ torch.randint(0, 1, size=new_shape) + else: + new_shape[2] = N + input_dict[key] = tensor.new_ones(*new_shape) #+ torch.randint(0, 1, size=new_shape) + + def _extend_map(key): + tensor = input_dict[key] + new_shape = list(tensor.shape) + new_shape[1] = M + if len(new_shape) > 2 and new_shape[2] > 3: + new_shape[2] = V + input_dict[key] = tensor.new_ones(*new_shape) #+ torch.randint(0, 1, size=new_shape) + + for k in input_dict.keys(): + if k.startswith("encoder/agent") or k.startswith("decoder"): + _extend_3rd_dim(k) + if k.startswith("encoder/map"): + _extend_map(k) + # ===== Fill in some fake data END ===== + for k, v in input_dict.items(): + if isinstance(v, torch.Tensor): + input_dict[k] = v.cuda() + + optimizer = model.configure_optimizers()['optimizer'] + for _ in tqdm.trange(10000): + if in_eval: + with torch.no_grad(): + out = model.forward(input_dict) + loss = out["decoder/output_logit"].mean() + else: + out = model.forward(input_dict) + + optimizer.zero_grad() + loss = out["decoder/output_logit"].mean() + loss.backward() + optimizer.step() + + +if __name__ == '__main__': + toy_test(bs=8, in_eval=False) diff --git a/scenestreamer/models/trafficgen_decoder.py b/scenestreamer/models/trafficgen_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..984b45db662b137a78b861767435619c7565040b --- /dev/null +++ b/scenestreamer/models/trafficgen_decoder.py @@ -0,0 +1,677 @@ +import copy + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import Module + +from scenestreamer.dataset import constants +from 
scenestreamer.models import relation +from scenestreamer.models.layers import common_layers +from scenestreamer.models.layers import fourier_embedding +from scenestreamer.models.layers.gpt_decoder_layer import MultiCrossAttTransformerDecoder +from scenestreamer.models.layers.gpt_decoder_layer import MultiheadAttentionLayer +from scenestreamer.models.motion_decoder import create_causal_mask +from scenestreamer.models.motion_decoder_gpt import get_edge_info_new +from scenestreamer.tokenization.trafficgen_tokenizers import TrafficGenTokenizer +from scenestreamer.utils import utils + + +class MultiCrossAttTransformerDecoderLayerForTrafficGen(Module): + def __init__(self, d_model: int, nhead: int, dropout: float = 0.0, use_adaln=False) -> None: + super().__init__() + assert dropout == 0.0 + self.cross_a2a = MultiheadAttentionLayer( + d_model=d_model, + n_heads=nhead, + dropout=dropout, + simple_relation=True, + simple_relation_factor=1, + is_v7=True, + update_relation=False, + add_relation_to_v=False, + ) + self.cross_a2s = MultiheadAttentionLayer( + d_model=d_model, + n_heads=nhead, + dropout=dropout, + simple_relation=True, + simple_relation_factor=1, + is_v7=True, + update_relation=False, + add_relation_to_v=False, + ) + + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = common_layers.Mlp(in_features=d_model, hidden_features=4 * d_model, act_layer=approx_gelu, drop=0) + + # self.cross_a2s = MultiheadAttentionLayer(d_model=d_model, n_heads=nhead, dropout=dropout, simple_relation=False) + # approx_gelu = lambda: nn.GELU(approximate="tanh") + # self.mlp = common_layers.Mlp(in_features=d_model, hidden_features=4 * d_model, act_layer=approx_gelu, drop=0) + + self.a2s_norm = nn.LayerNorm(d_model) + self.a2a_norm = nn.LayerNorm(d_model) + self.mlp_prenorm = nn.LayerNorm(d_model) + + self.a2a_norm_rel = nn.LayerNorm(d_model) + self.a2s_norm_rel = nn.LayerNorm(d_model) + + def forward(self, *, agent_token, scene_token, a2a_info, a2s_info, use_cache=False, 
past_key_value=None, **kwargs): + B, N, D = agent_token.shape + x = agent_token + + # === agent-agent attention === + out = x + out = self.a2a_norm(out) + out, past_key_value_a2t, _ = self.cross_a2a( + q=out, + k=out, + edge_features=self.a2a_norm_rel(a2a_info['edge_features']), + edge_index=a2a_info['edge_index'], + use_cache=use_cache, + cache=past_key_value, + ) + x = x + out + + # === agent-scene attention === + out = x + out = self.a2s_norm(out) + out, _, _ = self.cross_a2s( + q=out, + k=scene_token, + edge_features=self.a2s_norm_rel(a2s_info['edge_features']), + edge_index=a2s_info['edge_index'], + ) + x = x + out + + # === Feed-forward layer === + out = x + out = self.mlp_prenorm(out) + out = self.mlp(out) + x = x + out + return x, past_key_value_a2t + + +class OffsetHead(nn.Module): + def __init__(self, input_dim, d_model): + super().__init__() + self.prenorm = nn.LayerNorm(input_dim) + self.mlp = common_layers.build_mlps(c_in=input_dim, mlp_channels=[4 * d_model, d_model], ret_before_act=True) + self.norm = nn.LayerNorm(d_model) + self.position_x = nn.Linear(d_model, TrafficGenTokenizer.num_bins["position_x"]) + self.position_y = nn.Linear(d_model, TrafficGenTokenizer.num_bins["position_y"]) + self.velocity_x = nn.Linear(d_model, TrafficGenTokenizer.num_bins["velocity_x"]) + self.velocity_y = nn.Linear(d_model, TrafficGenTokenizer.num_bins["velocity_y"]) + self.heading = nn.Linear(d_model, TrafficGenTokenizer.num_bins["heading"]) + self.length = nn.Linear(d_model, TrafficGenTokenizer.num_bins["length"]) + self.width = nn.Linear(d_model, TrafficGenTokenizer.num_bins["width"]) + self.height = nn.Linear(d_model, TrafficGenTokenizer.num_bins["height"]) + + def forward(self, x): + x = self.prenorm(x) + x = self.mlp(x) + x = self.norm(x) + return { + "position_x": self.position_x(x), + "position_y": self.position_y(x), + "velocity_x": self.velocity_x(x), + "velocity_y": self.velocity_y(x), + "heading": self.heading(x), + "length": self.length(x), + "width": 
self.width(x), + "height": self.height(x), + } + + +class TrafficGenDecoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.d_model = d_model = self.config.MODEL.D_MODEL + num_decoder_layers = self.config.MODEL.NUM_DECODER_LAYERS + self.num_heads = self.config.MODEL.NUM_ATTN_HEAD + self.max_agents = self.config.PREPROCESSING.MAX_AGENTS + assert self.config.MODEL.NAME in ['gpt'] + num_tg_actions = self.config.PREPROCESSING.MAX_MAP_FEATURES + 2 + self.start_action_id = config.PREPROCESSING.MAX_MAP_FEATURES + self.end_action_id = config.PREPROCESSING.MAX_MAP_FEATURES + 1 + + # === Input embedding === + self.action_embed = fourier_embedding.FourierEmbedding(input_dim=5, hidden_dim=d_model, num_freq_bands=64) + self.relation_embed_a2a = fourier_embedding.FourierEmbedding( + input_dim=11, hidden_dim=d_model, num_freq_bands=64 + ) + self.relation_embed_a2s = fourier_embedding.FourierEmbedding(input_dim=3, hidden_dim=d_model, num_freq_bands=64) + self.type_embed = common_layers.Tokenizer( + num_actions=constants.NUM_TYPES, d_model=d_model, add_one_more_action=False + ) + self.shape_embed = common_layers.build_mlps(c_in=3, mlp_channels=[d_model, d_model], ret_before_act=True) + self.map_embed = common_layers.Tokenizer(num_actions=num_tg_actions, d_model=d_model, add_one_more_action=False) + self.step_embed = common_layers.Tokenizer( + num_actions=self.max_agents + 2, d_model=d_model, add_one_more_action=False + ) + + # === Transformer === + self.decoder = MultiCrossAttTransformerDecoder( + decoder_layer=MultiCrossAttTransformerDecoderLayerForTrafficGen( + d_model=d_model, nhead=self.num_heads, dropout=0.0, use_adaln=False + ), + num_layers=num_decoder_layers, + d_model=d_model, + ) + + # === Action head === + self.action_head = common_layers.build_mlps( + c_in=d_model, mlp_channels=[d_model, num_tg_actions], ret_before_act=True + ) + self.action_prenorm = nn.LayerNorm(d_model) + + # === Offset heads === + offset_head_input = 2 * 
d_model + self.agent_type_head = common_layers.build_mlps( + c_in=offset_head_input, + mlp_channels=[d_model, TrafficGenTokenizer.num_bins['agent_type']], + ret_before_act=True + ) + self.offset_head = OffsetHead(input_dim=offset_head_input, d_model=d_model) + self.trafficgen_tokenizer = TrafficGenTokenizer(self.config) + + def autoregressive_rollout_trafficgen(self, data_dict): + + START_ACTION = self.config.PREPROCESSING.MAX_MAP_FEATURES + END_ACTION = self.config.PREPROCESSING.MAX_MAP_FEATURES + 1 + + # Prepare running variables + current_input_action = data_dict["decoder/input_action_for_trafficgen"].clone()[:, :1] + current_input_action_valid_mask = data_dict["decoder/input_action_valid_mask_for_trafficgen"].clone()[:, :1] + num_valid_gt = data_dict["decoder/input_action_valid_mask_for_trafficgen"].sum().item() - 1 + + accumulative_action_valid_mask = current_input_action_valid_mask.clone() + + current_agent_position = data_dict["decoder/modeled_agent_position_for_trafficgen"].clone()[:, :1] + current_agent_velocity = data_dict["decoder/modeled_agent_velocity_for_trafficgen"].clone()[:, :1] + current_agent_heading = data_dict["decoder/modeled_agent_heading_for_trafficgen"].clone()[:, :1] + current_agent_type = data_dict["decoder/agent_type_for_trafficgen"].clone()[:, :1] + current_agent_shape = data_dict["decoder/current_agent_shape_for_trafficgen"].clone()[:, :1] + current_agent_feature = data_dict["decoder/input_action_feature_for_trafficgen"].clone()[:, :1] + + B = current_input_action.shape[0] + + use_gt = False + + # data_dict = model.model.encode_scene(data_dict) + + assert "encoder/scenario_token" in data_dict, "You must call encode_scene first." 
+ + num_decode_steps = 128 + num_collisions = 0 + num_violations = 0 + assert B == 1 + decode_step = 0 + # for decode_step in range(num_decode_steps): + for _ in range(500): + input_dict = { + # Static features + "encoder/scenario_token": data_dict["encoder/scenario_token"], + "encoder/scenario_heading": data_dict["encoder/scenario_heading"], + "encoder/scenario_position": data_dict["encoder/scenario_position"], + "encoder/scenario_valid_mask": data_dict["encoder/scenario_valid_mask"], + "encoder/map_position": data_dict["encoder/map_position"], + "encoder/map_feature": data_dict["encoder/map_feature"], + "encoder/map_valid_mask": data_dict["encoder/map_valid_mask"], + "in_evaluation": torch.ones([ + B, + ], dtype=torch.bool), + + # Actions + "decoder/input_action_for_trafficgen": current_input_action, + "decoder/input_action_valid_mask_for_trafficgen": current_input_action_valid_mask, + + # Agent features + "decoder/modeled_agent_position_for_trafficgen": current_agent_position, + "decoder/modeled_agent_heading_for_trafficgen": current_agent_heading, + "decoder/modeled_agent_velocity_for_trafficgen": current_agent_velocity, + "decoder/agent_type_for_trafficgen": current_agent_type, + "decoder/current_agent_shape_for_trafficgen": current_agent_shape, + "decoder/input_action_feature_for_trafficgen": current_agent_feature, + } + + temperature = 1.0 + + input_dict = copy.deepcopy(input_dict) + + output_dict = self.forward(input_dict) + + # Force model to predict at least the same amount of agents in GT. + if decode_step < num_decode_steps: + force_no_end = True + else: + force_no_end = False + + sampled_action = self.sample_action(output_dict, force_no_end=force_no_end, temperature=temperature) + sampled_action = sampled_action[:, -1:] + + if decode_step == 0 and self.config.FORCE_SDC_FOR_TRAFFICGEN: + # In LCTGen, the map feature is cropped around SDC's position. + # The agent will always pick the map feature that is closest to the center of the map. 
+ # That is the map feature whose (x, y) is closest to the (0, 0). + # To ensure fair comparison, we need to do the same here and force the first selected + # map feature to be the one closest to (0, 0). + + # TODO: Add a flag here. + + # TODO hardcode + if data_dict["decoder/agent_position"].shape[1] > 150: + current_t = 0 + else: + current_t = 10 + assert B == 1 + sdc_index = data_dict["decoder/sdc_index"][0].item() + sdc_center = data_dict["decoder/agent_position"][:, current_t, sdc_index] + + map_to_sdc_dist = (data_dict["encoder/map_position"][0][..., :2] - sdc_center[0, :2]).norm(dim=-1) + + map_to_sdc_dist_valid_mask = data_dict["encoder/map_valid_mask"].clone() + map_to_sdc_dist_valid_mask = ( + map_to_sdc_dist_valid_mask & (data_dict["encoder/map_feature"][:, :, 0, 13] == 1) + ) + + map_to_sdc_dist[~map_to_sdc_dist_valid_mask[0]] = 1e6 + + map_argmin = map_to_sdc_dist.argmin() + map_min = map_to_sdc_dist.min() + + # print("Original select action: {}, new action: {}, min dist: {}".format( + # sampled_action.item(), map_argmin.item(), map_min.item() + # )) + sampled_action = map_argmin.unsqueeze(0).unsqueeze(-1) + + if use_gt: + sampled_action = data_dict["decoder/input_action_for_trafficgen"][:, decode_step + 1:decode_step + 2] + new_current_input_action = torch.cat([current_input_action, sampled_action], dim=1) + is_end = sampled_action == END_ACTION + sampled_action[is_end] = 0 + accumulative_action_valid_mask[is_end] = False + + # Use last action to predict next position + # The first action is START_ACTION so we need to skip it. 
+ agent_type_output = self.forward_agent_type(output_dict, action=new_current_input_action[:, 1:]) + agent_type = self.sample_agent_type(agent_type_output, temperature=temperature) + offset_output = self.forward_offset( + output_dict, action=new_current_input_action[:, 1:], agent_type=agent_type + ) + offset_action = self.sample_offset(offset_output=offset_output, temperature=temperature) + offset_action = {k: v[:, -1:] for k, v in offset_action.items()} + agent_type = agent_type[:, -1:] + + if use_gt: + offset_action = data_dict["decoder/target_offset_for_trafficgen"][:, decode_step:decode_step + 1] + # gt_position_x, gt_position_y, gt_heading, gt_vel_x, gt_vel_y, gt_shape_l, gt_shape_w, gt_shape_h, gt_agent_type + offset_action = { + "position_x": offset_action[..., 0], + "position_y": offset_action[..., 1], + "heading": offset_action[..., 2], + "velocity_x": offset_action[..., 3], + "velocity_y": offset_action[..., 4], + "length": offset_action[..., 5], + "width": offset_action[..., 6], + "height": offset_action[..., 7], + "agent_type": offset_action[..., 8], + } + + predicted_values = self.trafficgen_tokenizer.detokenize( + data_dict, new_current_input_action[:, -1:], agent_type=agent_type, offset_action=offset_action + ) + pos = predicted_values["position"] + head = predicted_values["heading"] + vel = predicted_values["velocity"] + agent_type = predicted_values["agent_type"] # in 0,1,2 + agent_shape = predicted_values["shape"] + agent_feature = predicted_values["feature"] + pos = pos * accumulative_action_valid_mask.unsqueeze(-1) + head = head * accumulative_action_valid_mask + vel = vel * accumulative_action_valid_mask.unsqueeze(-1) + agent_type = agent_type * accumulative_action_valid_mask + agent_shape = agent_shape * accumulative_action_valid_mask.unsqueeze(-1) + agent_feature = agent_feature * accumulative_action_valid_mask.unsqueeze(-1) + + # BID = 0 + # print("=== Step: {} ===".format(decode_step)) + # print( + # "New agent type: {}, length: {:.2f}, 
width: {:.2f}, height: {:.2f}".format( + # agent_type[BID, 0].item(), agent_shape[BID, 0, 0].item(), agent_shape[BID, 0, 1].item(), + # agent_shape[BID, 0, 2].item() + # ) + # ) + + # Check if collision happens: + from scenestreamer.dataset.preprocess_action_label import cal_polygon_contour, detect_collision + assert current_agent_position.shape[0] == 1 + existing_contours = cal_polygon_contour( + x=current_agent_position[0, :, 0].cpu().numpy(), + y=current_agent_position[0, :, 1].cpu().numpy(), + theta=current_agent_heading[0, :].cpu().numpy(), + width=current_agent_shape[0, :, 1].cpu().numpy(), + length=current_agent_shape[0, :, 0].cpu().numpy() + ) # (N, 4, 2) + new_contour = cal_polygon_contour( + x=pos[0, :, 0].cpu().numpy(), + y=pos[0, :, 1].cpu().numpy(), + theta=head[0, :].cpu().numpy(), + width=agent_shape[0, :, 1].cpu().numpy(), + length=agent_shape[0, :, 0].cpu().numpy() + ) + if existing_contours.shape[0] == 1: + no_coll = True # Skip first one (it's the START_ACTION) + else: + no_coll = True + for existing_id in range(1, existing_contours.shape[0]): + collision_detected = detect_collision( + [existing_contours[existing_id]], # (N, 4, 2) + [current_input_action_valid_mask[0][existing_id]], # (N,) + new_contour, + accumulative_action_valid_mask[0] + ) + if collision_detected[0]: + # print("Collision detected!") + num_collisions += 1 + no_coll = False + break + + # ===== Additional postprocessing to comply with LCTGen ===== + offset_values = predicted_values["offset_values"] + vel_valid_mask = abs(torch.atan2(offset_values["velocity_y"], offset_values["velocity_x"])) < np.pi / 6 + dir_valid_mask = abs(offset_values["heading"]) < np.pi / 4 + sdc_index = data_dict["decoder/sdc_index"][0].item() + + # TODO hardcode + if data_dict["decoder/agent_position"].shape[1] > 150: + current_t = 0 + else: + current_t = 10 + + sdc_center = data_dict["decoder/agent_position"][0, current_t, sdc_index] + distance_mask = ((abs(pos[..., 0] - sdc_center[0]) < 50) & 
(abs(pos[..., 1] - sdc_center[1]) < 50)) + if existing_contours.shape[0] == 1: + no_violation = True # Skip first one (it's the START_ACTION) + else: + assert vel_valid_mask.numel() == 1 + assert dir_valid_mask.numel() == 1 + assert distance_mask.numel() == 1 + no_violation = (vel_valid_mask & dir_valid_mask & distance_mask).item() + if not no_violation: + num_violations += 1 + + if no_coll: + # Overwrite + current_agent_position = torch.cat([current_agent_position, pos], dim=1) + current_agent_velocity = torch.cat([current_agent_velocity, vel], dim=1) + current_agent_heading = torch.cat([current_agent_heading, head], dim=1) + current_agent_type = torch.cat([current_agent_type, agent_type], dim=1) + current_agent_shape = torch.cat([current_agent_shape, agent_shape], dim=1) + current_agent_feature = torch.cat([current_agent_feature, agent_feature], dim=1) + current_input_action_valid_mask = torch.cat( + [current_input_action_valid_mask, accumulative_action_valid_mask], dim=1 + ) + current_input_action = new_current_input_action.clone() + + decode_step += 1 + + if not accumulative_action_valid_mask.any(): + break + + if decode_step > num_decode_steps: + break + + # Remove batch dim + + data_dict.update(input_dict) + return data_dict, {"num_collisions": num_collisions, "num_violations": num_violations} + + def forward(self, input_dict, use_cache=False): + assert self.config.REMOVE_AGENT_FROM_SCENE_ENCODER + assert self.config.PREPROCESSING.REMOVE_TRAFFIC_LIGHT_STATE + assert use_cache is False + + # TrafficGen decoder takes two inputs: + # 1. The map features, which is the scene tokens below (from the SceneEncoder) + # 2. The agent features of a frame, which is the agent tokens below (from the AgentEncoder) + # Should note that the agent features will start with token and end with token, + # just like in language task. 
+ # ===== Scene Tokens ===== + scene_token = input_dict["encoder/scenario_token"] + scenario_valid_mask = input_dict["encoder/scenario_valid_mask"] + B, M, _ = input_dict["encoder/map_position"].shape + S = scene_token.shape[1] + map_id = torch.zeros([B, S], dtype=torch.long, device=scene_token.device) + map_id[:, :M] = torch.arange(M, device=map_id.device).unsqueeze(0) + map_id[~scenario_valid_mask] = -1 + map_id_pe = self.map_embed(map_id) + # We don't add map feat ID pe in SceneEncoder, so we add it here. + scene_token = scene_token + map_id_pe + + # ===== Agent Tokens ===== + input_action = input_dict["decoder/input_action_for_trafficgen"] + B, seq_len = input_action.shape + input_action_valid_mask = input_dict["decoder/input_action_valid_mask_for_trafficgen"] + agent_pos = input_dict["decoder/modeled_agent_position_for_trafficgen"] + agent_heading = input_dict["decoder/modeled_agent_heading_for_trafficgen"] + # agent_vel = input_dict["decoder/modeled_agent_velocity_for_trafficgen"] + + # Shape embedding and type embedding + type_emb = self.type_embed(input_dict["decoder/agent_type_for_trafficgen"]) + shape_emb = self.shape_embed(input_dict["decoder/current_agent_shape_for_trafficgen"]) + + if "decoder/input_step_for_trafficgen" not in input_dict: + input_step = torch.arange(seq_len).to(input_action.device).unsqueeze(0).repeat(B, 1) + input_dict["decoder/input_step_for_trafficgen"] = input_step + input_step = input_dict["decoder/input_step_for_trafficgen"] + assert input_step.shape == (B, seq_len), (B, seq_len, input_step.shape) + input_step[input_step >= self.max_agents] = self.max_agents + step_emb = self.step_embed(input_step) + + # Here we reuse the map_embedding to embed the action! 
+ action_emb = self.map_embed(input_action) + action_feat = input_dict["decoder/input_action_feature_for_trafficgen"] + action_token = self.action_embed( + continuous_inputs=action_feat[input_action_valid_mask], + categorical_embs=[ + type_emb[input_action_valid_mask], shape_emb[input_action_valid_mask], + action_emb[input_action_valid_mask], step_emb[input_action_valid_mask] + ] + ) + action_token = utils.unwrap(action_token, valid_mask=input_action_valid_mask) + + # The T here is the number of agents, not the real temporal length. + causal_valid_mask = create_causal_mask(T=seq_len, N=1, is_valid_mask=True).to(action_token.device) + + # ===== Get agent-agent relation ===== + a2a_rel_feat, a2a_mask, _ = relation.compute_relation_simple_relation( + query_pos=agent_pos, + query_heading=agent_heading, + query_valid_mask=input_action_valid_mask, + query_step=None, + key_pos=agent_pos, + key_heading=agent_heading, + key_valid_mask=input_action_valid_mask, + key_step=None, + hidden_dim=self.d_model, + causal_valid_mask=causal_valid_mask, + knn=self.config.MODEL.A2A_KNN, + return_pe=False, + query_width=input_dict["decoder/current_agent_shape_for_trafficgen"][:, :, 1], + query_length=input_dict["decoder/current_agent_shape_for_trafficgen"][:, :, 0], + key_width=input_dict["decoder/current_agent_shape_for_trafficgen"][:, :, 1], + key_length=input_dict["decoder/current_agent_shape_for_trafficgen"][:, :, 0], + non_agent_relation=False, + per_contour_point_relation=False, + ) + # a2a_rel_pe = utils.unwrap(self.relation_embed_a2a(a2a_rel_feat[a2a_mask]), a2a_mask) + # a2a_info = get_edge_info(attn_valid_mask=a2a_mask, rel_pe_cross=a2a_rel_pe) + a2a_info = get_edge_info_new( + q_k_valid_mask=a2a_mask, + q_k_relation=a2a_rel_feat, + relation_model=self.relation_embed_a2a, + relation_model_v=None + ) + + # ===== Get agent-scene relation ===== + a2s_rel_feat, a2s_mask, a2s_indices = relation.compute_relation_simple_relation( + query_pos=agent_pos, + query_heading=agent_heading, + 
query_valid_mask=input_action_valid_mask, + query_step=None, + key_pos=input_dict["encoder/scenario_position"], # [..., :2], + key_heading=input_dict["encoder/scenario_heading"], + key_valid_mask=scenario_valid_mask, + key_step=None, + hidden_dim=self.d_model, + causal_valid_mask=None, + knn=self.config.MODEL.A2S_KNN, + gather=False, + return_pe=False, + query_width=input_dict["decoder/current_agent_shape_for_trafficgen"][:, :, 1], + query_length=input_dict["decoder/current_agent_shape_for_trafficgen"][:, :, 0], + key_width=torch.zeros([B, S], device=agent_pos.device), + key_length=torch.zeros([B, S], device=agent_pos.device), + non_agent_relation=True, + per_contour_point_relation=False, + ) + # a2s_rel_pe = utils.unwrap(self.relation_embed_a2s(a2s_rel_feat[a2s_mask]), a2s_mask) + # a2s_info = get_edge_info(attn_valid_mask=a2s_mask, rel_pe_cross=a2s_rel_pe) + a2s_info = get_edge_info_new( + q_k_valid_mask=a2s_mask, + q_k_relation=a2s_rel_feat, + relation_model=self.relation_embed_a2s, + relation_model_v=None + ) + + # === Call models === + past_key_value_list = None + if use_cache: + # Cache from last rollout + if "decoder/cache" in input_dict: + past_key_value_list = input_dict["decoder/cache"] + + decoded_tokens = self.decoder( + agent_token=action_token, + scene_token=scene_token, + a2a_info=a2a_info, + # a2t_info=None, + a2s_info=a2s_info, + # condition_token=None, #condition_token if self.use_adaln else None, + use_cache=use_cache, # We don't need decoder to take care cache. 
+ past_key_value_list=past_key_value_list + ) + + # if use_cache: + # decoded_tokens, past_key_value_list = decoded_tokens + # for l in past_key_value_list: + # if l: + # l.append((B * N, real_T)) + # input_dict["decoder/cache"] = past_key_value_list + + output_token = self.action_head(self.action_prenorm(decoded_tokens[input_action_valid_mask])) + output_token = utils.unwrap(output_token, valid_mask=input_action_valid_mask) + + input_dict["decoder/output_logit_for_trafficgen"] = output_token + input_dict["decoder/output_token_for_trafficgen"] = decoded_tokens + return input_dict + + def sample_action(self, data_dict, force_no_end=False, temperature=1.0): + raw_output_logit = data_dict['decoder/output_logit_for_trafficgen'] # [:, -1, :] + output_logit = raw_output_logit.new_full(raw_output_logit.shape, float('-inf')) + + # mask out invalid actions + # scenario_valid_mask = data_dict["encoder/scenario_valid_mask"] + B, M, _ = data_dict["encoder/map_position"].shape + # map_mask = scenario_valid_mask[:, :M] + # assert (map_mask == data_dict["encoder/map_valid_mask"]).all() + map_mask = data_dict["encoder/map_valid_mask"] + only_lane = self.config.ONLY_LANE_FOR_TRAFFICGEN + if only_lane: + map_mask = map_mask & (data_dict["encoder/map_feature"][:, :, 0, 13] == 1) + + T = raw_output_logit.shape[1] + + output_logit[:, :, :M] = torch.where( + map_mask.unsqueeze(1).expand(-1, T, -1), raw_output_logit[:, :, :M], output_logit[:, :, :M] + ) + if force_no_end: + output_logit[:, :, -1] = float('-inf') + else: + output_logit[:, :, -1] = raw_output_logit[:, :, -1] # The prob for "End Action" + + # Just do the softmax sampling + sampled_action = torch.distributions.Categorical(logits=output_logit / temperature).sample() + return sampled_action + + def forward_agent_type(self, input_dict, action): + # Get scene token: + in_evaluation = input_dict["in_evaluation"][0].item() + scene_token = input_dict["encoder/scenario_token"] + B, M, _ = input_dict["encoder/map_position"].shape + 
action = action.clone() + + is_valid_action = (action < self.start_action_id) & (action >= 0) + action[~is_valid_action] = 0 + + selected_scene_token = torch.gather( + scene_token, dim=1, index=action.unsqueeze(-1).expand(-1, -1, scene_token.shape[-1]) + ) + + # Get the input + output_token = input_dict["decoder/output_token_for_trafficgen"].clone() + if in_evaluation: + assert selected_scene_token.shape == output_token.shape + else: + # output token contains value for the END_ACTION, which is not in the selected_scene_token. + B, T_minus_1, D = selected_scene_token.shape + assert output_token.shape == (B, T_minus_1 + 1, D) + output_token = output_token[:, :-1, :] + + agent_type_offset = self.agent_type_head(torch.cat([output_token, selected_scene_token], dim=-1)) + return agent_type_offset + + def sample_agent_type(self, agent_type_output, temperature=1.0): + return torch.distributions.Categorical(logits=agent_type_output / temperature).sample() + + def forward_offset(self, input_dict, action, agent_type): + + # Get scene token: + in_evaluation = input_dict["in_evaluation"][0].item() + scene_token = input_dict["encoder/scenario_token"] + B, M, _ = input_dict["encoder/map_position"].shape + action = action.clone() + + is_valid_action = (action < self.start_action_id) & (action >= 0) + action[~is_valid_action] = 0 + + selected_scene_token = torch.gather( + scene_token, dim=1, index=action.unsqueeze(-1).expand(-1, -1, scene_token.shape[-1]) + ) + + # Get the input + output_token = input_dict["decoder/output_token_for_trafficgen"].clone() + if in_evaluation: + assert selected_scene_token.shape == output_token.shape + else: + # output token contains value for the END_ACTION, which is not in the selected_scene_token. 
+ B, T_minus_1, D = selected_scene_token.shape + assert output_token.shape == (B, T_minus_1 + 1, D) + output_token = output_token[:, :-1, :] + + agent_type_emb = self.type_embed(agent_type) + output_token = agent_type_emb + output_token + + offset_output = self.offset_head(torch.cat([output_token, selected_scene_token], dim=-1)) + + return offset_output + + def sample_offset(self, offset_output, temperature=1.0): + def _sample(v): + return torch.distributions.Categorical(logits=v / temperature).sample() + + offset_action = {k: _sample(v) for k, v in offset_output.items()} + + return offset_action diff --git a/scenestreamer/paper/__init__.py b/scenestreamer/paper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a47eb36b7f6caab1ee32ae680122060d5942656 --- /dev/null +++ b/scenestreamer/paper/__init__.py @@ -0,0 +1,2 @@ +"""Paper reproduction entrypoints (Table 1/2 + demos).""" + diff --git a/scenestreamer/paper/densify_demo.py b/scenestreamer/paper/densify_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..062cba434f66fa744805fc3150cf4344761fd76b --- /dev/null +++ b/scenestreamer/paper/densify_demo.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import json +import pathlib +import random +from typing import Any + +import numpy as np +import torch + +from scenestreamer.dataset.dataset import SceneStreamerDataset +from scenestreamer.infer.infinite import generate_scenestreamer_motion +from scenestreamer.utils import utils + + +def _seed_everything(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def run_densify_demo( + *, + pl_model, + dataset_dir: str, + split: str, + scenario_index: int, + max_agents: int, + force_no_end: bool, + artifacts_dir: str, + run_id: str | None, + seed: int, +) -> pathlib.Path: + _seed_everything(seed) + + config = pl_model.config + config.DATA.TRAINING_DATA_DIR = dataset_dir + 
config.DATA.TEST_DATA_DIR = dataset_dir + config.DATA.USE_CACHE = True + config.PREPROCESSING.keep_all_data = True + + ds = SceneStreamerDataset(config, split) + raw = ds[scenario_index] + + device = pl_model.device + batched = utils.batch_data(utils.numpy_to_torch(raw, device=device)) + + # Densify + motion rollout. `force_add=True` disables the "end of agent states" token. + out = generate_scenestreamer_motion( + data_dict=batched, + model=pl_model.model, + force_add=force_no_end, + num_decode_steps=19, + ) + + # Save a lightweight artifact; visualization can be added later. + pred_pos = out.get("decoder/reconstructed_position") + pred_valid = out.get("decoder/reconstructed_valid_mask") + if pred_pos is None or pred_valid is None: + raise ValueError("Model output missing decoder/reconstructed_position or decoder/reconstructed_valid_mask") + + pred_pos_np = pred_pos.detach().cpu().numpy() + pred_valid_np = pred_valid.detach().cpu().numpy().astype(bool) + + base = pathlib.Path(artifacts_dir) + base.mkdir(parents=True, exist_ok=True) + if run_id is None: + run_id = f"densify-{seed}-{int(torch.randint(0, 1_000_000, (1,)).item())}" + out_dir = base / run_id + out_dir.mkdir(parents=True, exist_ok=False) + + # Save a small summary to avoid huge artifacts by default. 
+ summary: dict[str, Any] = { + "scenario_id": raw.get("scenario_id", None), + "pred_shape": list(pred_pos_np.shape), + "num_agents_final": int(pred_valid_np.any(axis=1).any(axis=0).sum()) if pred_valid_np.ndim == 3 else None, + "max_agents_target": max_agents, + "force_no_end": force_no_end, + "seed": seed, + } + + with open(out_dir / "metrics.json", "w") as f: + json.dump(summary, f, indent=2) + + return out_dir diff --git a/scenestreamer/paper/table1_mmd.py b/scenestreamer/paper/table1_mmd.py new file mode 100644 index 0000000000000000000000000000000000000000..56a328185c05eee80f2dcee3f75cbf8ed8c68d70 --- /dev/null +++ b/scenestreamer/paper/table1_mmd.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import json +import pathlib +import random +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from scenestreamer.dataset.dataset import SceneStreamerDataset +from scenestreamer.eval.test_trafficgen_eval import TrafficGenEvaluator +from scenestreamer.infer.initial_state import generate_initial_state +from scenestreamer.utils import utils + + +def _seed_everything(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def _to_jsonable(v: Any): + if isinstance(v, (np.floating, np.integer)): + return v.item() + if isinstance(v, torch.Tensor): + if v.numel() == 1: + return v.detach().cpu().item() + return v.detach().cpu().tolist() + return v + + +@dataclass +class _MetricAgg: + sums: dict[str, float] + counts: dict[str, int] + + @classmethod + def create(cls) -> "_MetricAgg": + return cls(sums={}, counts={}) + + def add(self, k: str, v: Any) -> None: + vv = _to_jsonable(v) + if isinstance(vv, list): + return + if vv is None: + return + self.sums[k] = self.sums.get(k, 0.0) + float(vv) + self.counts[k] = self.counts.get(k, 0) + 1 + + def mean(self) -> dict[str, float]: + out = {} + for k, s in self.sums.items(): + c = self.counts.get(k, 0) + if 
c: + out[k] = s / c + return out + + +def run_table1_mmd( + *, + pl_model, + dataset_dir: str, + split: str, + limit: int | None, + artifacts_dir: str, + run_id: str | None, + seed: int, +) -> pathlib.Path: + _seed_everything(seed) + + config = pl_model.config + config.DATA.TRAINING_DATA_DIR = dataset_dir + config.DATA.TEST_DATA_DIR = dataset_dir + config.DATA.SD_PASSTHROUGH = True + config.DATA.USE_CACHE = True + config.PREPROCESSING.keep_all_data = True + + # Required by TrafficGenEvaluator (kept for backward-compat with existing evaluator code). + config.EVALUATION.USE_TG_AS_GT = 1111 + + ds = SceneStreamerDataset(config, split) + + evaluator = TrafficGenEvaluator(config) + agg = _MetricAgg.create() + + device = pl_model.device + + for idx in range(len(ds)): + raw = ds[idx] + batched = utils.batch_data(utils.numpy_to_torch(raw, device=device)) + + # Generate initial agent states (TrafficGen-style). + densified, _ = generate_initial_state( + data_dict=batched, + model=pl_model.model, + force_add=False, + ) + + def log_func(name: str, value: Any) -> None: + agg.add(name, value) + + evaluator.validation_step(densified, stat=None, log_func=log_func) + + if limit is not None and (idx + 1) >= limit: + break + + metrics = agg.mean() + + base = pathlib.Path(artifacts_dir) + base.mkdir(parents=True, exist_ok=True) + if run_id is None: + run_id = f"table1-{seed}-{int(torch.randint(0, 1_000_000, (1,)).item())}" + out_dir = base / run_id + out_dir.mkdir(parents=True, exist_ok=False) + + with open(out_dir / "metrics.json", "w") as f: + json.dump( + { + "table": "table1", + "dataset_dir": dataset_dir, + "split": split, + "limit": limit, + "seed": seed, + "metrics": metrics, + }, + f, + indent=2, + ) + + return out_dir + diff --git a/scenestreamer/paper/table2_motion.py b/scenestreamer/paper/table2_motion.py new file mode 100644 index 0000000000000000000000000000000000000000..3242104e7da841ef5951e546e50ad92eeb1b0822 --- /dev/null +++ b/scenestreamer/paper/table2_motion.py @@ 
-0,0 +1,300 @@ +from __future__ import annotations + +import json +import math +import pathlib +import random +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from scenestreamer.dataset.dataset import SceneStreamerDataset +from scenestreamer.infer.scenestreamer_generator import SceneStreamerGenerator +from scenestreamer.utils import utils + + +def _seed_everything(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def _nanmean(x: np.ndarray) -> float: + if np.isnan(x).all(): + return float("nan") + return float(np.nanmean(x)) + + +def _pick_ooi_index(sample: dict[str, Any]) -> int | None: + # Common keys across variants; keep this conservative. + for key in ("decoder/object_of_interest_id", "decoder/labeled_agent_id"): + if key in sample: + v = sample[key] + if isinstance(v, np.ndarray): + flat = v.reshape(-1) + flat = flat[flat != -1] + if flat.size: + return int(flat[0]) + elif np.isscalar(v): + return int(v) + return None + + +@dataclass +class MotionMetrics: + ade_avg: float + ade_min: float + fde_avg: float + fde_min: float + add: float + fdd: float + + +def _compute_metrics( + *, + gt_pos: np.ndarray, # (T, N, 2) + gt_valid: np.ndarray, # (T, N) + pred_pos: np.ndarray, # (K, T, N, 2) + pred_valid: np.ndarray, # (K, T, N) + agent_mask: np.ndarray, # (N,) +) -> MotionMetrics: + # Evaluate on the standard Waymo horizon: future steps [11:91] => 80 steps + t0 = 11 if gt_pos.shape[0] >= 91 else 0 + t1 = min(gt_pos.shape[0], pred_pos.shape[1]) + + gt_pos = gt_pos[t0:t1] + gt_valid = gt_valid[t0:t1] + pred_pos = pred_pos[:, t0:t1] + pred_valid = pred_valid[:, t0:t1] + + T = gt_pos.shape[0] + N = gt_pos.shape[1] + K = pred_pos.shape[0] + + # Mask invalids. 
+ valid = (gt_valid[None] & pred_valid) # (K, T, N) + valid &= agent_mask[None, None, :] + + # L2 errors (K, T, N) + err = np.linalg.norm(pred_pos - gt_pos[None], axis=-1) + err[~valid] = np.nan + + # ADE per (K, N) + ade_kn = np.nanmean(err, axis=1) + # FDE per (K, N): last valid timestep per (K, N) + fde_kn = np.full((K, N), np.nan, dtype=np.float64) + for k in range(K): + for n in range(N): + vv = valid[k, :, n] + if not vv.any(): + continue + last = int(np.where(vv)[0][-1]) + fde_kn[k, n] = err[k, last, n] + + # Average over agents first, then modes. + ade_k = np.nanmean(ade_kn, axis=1) + fde_k = np.nanmean(fde_kn, axis=1) + + ade_avg = _nanmean(ade_k) + fde_avg = _nanmean(fde_k) + + ade_min = _nanmean(np.nanmin(ade_kn, axis=0)) + fde_min = _nanmean(np.nanmin(fde_kn, axis=0)) + + # FDD: max pairwise distance between mode final positions (per agent). + final_pos = np.full((K, N, 2), np.nan, dtype=np.float64) + for k in range(K): + for n in range(N): + vv = valid[k, :, n] + if not vv.any(): + continue + last = int(np.where(vv)[0][-1]) + final_pos[k, n] = pred_pos[k, last, n] + + fdd_n = np.full((N,), np.nan, dtype=np.float64) + for n in range(N): + if not agent_mask[n]: + continue + pts = final_pos[:, n, :] + if np.isnan(pts).any(): + # require all modes to be valid for this agent for diversity metrics + continue + # max pairwise L2 + dmax = 0.0 + for i in range(K): + for j in range(K): + d = float(np.linalg.norm(pts[i] - pts[j])) + dmax = max(dmax, d) + fdd_n[n] = dmax + + # ADD: for each (t,n), max pairwise distance across modes at that timestep; then mean over t per agent. 
+ add_n = np.full((N,), np.nan, dtype=np.float64) + for n in range(N): + if not agent_mask[n]: + continue + per_t = [] + for t in range(T): + vv = valid[:, t, n] + if not vv.all(): + continue + pts = pred_pos[:, t, n, :] + dmax = 0.0 + for i in range(K): + for j in range(K): + d = float(np.linalg.norm(pts[i] - pts[j])) + dmax = max(dmax, d) + per_t.append(dmax) + if per_t: + add_n[n] = float(np.mean(per_t)) + + return MotionMetrics( + ade_avg=float(ade_avg), + ade_min=float(ade_min), + fde_avg=float(fde_avg), + fde_min=float(fde_min), + add=_nanmean(add_n), + fdd=_nanmean(fdd_n), + ) + + +def run_table2_motion( + *, + pl_model, + dataset_dir: str, + split: str, + mode: str, + num_modes: int, + limit: int | None, + artifacts_dir: str, + run_id: str | None, + seed: int, +) -> pathlib.Path: + _seed_everything(seed) + + config = pl_model.config + config.DATA.TRAINING_DATA_DIR = dataset_dir + config.DATA.TEST_DATA_DIR = dataset_dir + config.DATA.USE_CACHE = True + config.PREPROCESSING.keep_all_data = True + + ds = SceneStreamerDataset(config, split) + + device = pl_model.device + generator = SceneStreamerGenerator(model=pl_model.model, device=device) + + all_rows: list[dict[str, Any]] = [] + + for idx in range(len(ds)): + raw = ds[idx] + batched = utils.batch_data(utils.numpy_to_torch(raw, device=device)) + + expanded = {k: utils.repeat_for_modes(v, num_modes=num_modes) for k, v in batched.items()} + generator.reset(new_data_dict=expanded) + + if mode == "motion": + out = generator.generate_scenestreamer_motion(progress_bar=False, teacher_forcing_sdc=True) + elif mode == "full": + out = generator.generate_scenestreamer_initial_state_and_motion(progress_bar=False, teacher_forcing_sdc=True) + else: + raise ValueError(f"Unknown mode: {mode}") + + # Prefer reconstructed outputs; fall back to agent_position if needed. 
+ pred_pos = out.get("decoder/reconstructed_position") + pred_valid = out.get("decoder/reconstructed_valid_mask") + if pred_pos is None or pred_valid is None: + raise ValueError("Model output missing decoder/reconstructed_position or decoder/reconstructed_valid_mask") + + gt_pos = expanded.get("decoder/agent_position") + gt_valid = expanded.get("decoder/agent_valid_mask") + if gt_pos is None or gt_valid is None: + raise ValueError("Input missing decoder/agent_position or decoder/agent_valid_mask") + + # Convert to numpy + pred_pos_np = pred_pos.detach().cpu().numpy() + pred_valid_np = pred_valid.detach().cpu().numpy().astype(bool) + gt_pos_np = gt_pos[0].detach().cpu().numpy()[..., :2] if gt_pos.ndim == 4 else gt_pos.detach().cpu().numpy() + gt_valid_np = gt_valid[0].detach().cpu().numpy().astype(bool) if gt_valid.ndim == 3 else gt_valid.detach().cpu().numpy().astype(bool) + + # Align shapes to (K,T,N,2) and (K,T,N) + if pred_pos_np.ndim != 4: + raise ValueError(f"Unexpected pred_pos shape: {pred_pos_np.shape}") + K, T, N, _ = pred_pos_np.shape + if pred_valid_np.shape != (K, T, N): + raise ValueError(f"Unexpected pred_valid shape: {pred_valid_np.shape} vs {(K, T, N)}") + + if gt_pos_np.ndim != 3: + raise ValueError(f"Unexpected gt_pos shape: {gt_pos_np.shape}") + if gt_valid_np.shape != gt_pos_np.shape[:2]: + raise ValueError(f"Unexpected gt_valid shape: {gt_valid_np.shape} vs {gt_pos_np.shape[:2]}") + + # Per-agent masks + all_agent_mask = np.ones((N,), dtype=bool) + ooi_idx = _pick_ooi_index(raw) + ooi_mask = np.zeros((N,), dtype=bool) + if ooi_idx is not None and 0 <= ooi_idx < N: + ooi_mask[ooi_idx] = True + + metrics_all = _compute_metrics( + gt_pos=gt_pos_np, + gt_valid=gt_valid_np, + pred_pos=pred_pos_np, + pred_valid=pred_valid_np, + agent_mask=all_agent_mask, + ) + metrics_ooi = _compute_metrics( + gt_pos=gt_pos_np, + gt_valid=gt_valid_np, + pred_pos=pred_pos_np, + pred_valid=pred_valid_np, + agent_mask=ooi_mask if ooi_mask.any() else all_agent_mask, + ) 
+ + row = { + "index": idx, + "scenario_id": raw.get("scenario_id", None), + "all": metrics_all.__dict__, + "ooi": metrics_ooi.__dict__, + } + all_rows.append(row) + + if limit is not None and (idx + 1) >= limit: + break + + def _avg(key: str, group: str) -> float: + vals = [r[group][key] for r in all_rows if not math.isnan(r[group][key])] + return float(np.mean(vals)) if vals else float("nan") + + summary = { + "all": {k: _avg(k, "all") for k in MotionMetrics.__annotations__.keys()}, + "ooi": {k: _avg(k, "ooi") for k in MotionMetrics.__annotations__.keys()}, + } + + base = pathlib.Path(artifacts_dir) + base.mkdir(parents=True, exist_ok=True) + if run_id is None: + run_id = f"table2-{mode}-{seed}-{int(torch.randint(0, 1_000_000, (1,)).item())}" + out_dir = base / run_id + out_dir.mkdir(parents=True, exist_ok=False) + + with open(out_dir / "metrics.json", "w") as f: + json.dump( + { + "table": "table2", + "dataset_dir": dataset_dir, + "split": split, + "mode": mode, + "num_modes": num_modes, + "limit": limit, + "seed": seed, + "summary": summary, + "rows": all_rows, + }, + f, + indent=2, + ) + + return out_dir + diff --git a/scenestreamer/rl_finetuning.py b/scenestreamer/rl_finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..b944e8eeefa1cf51e31f25e37d288e19e196079b --- /dev/null +++ b/scenestreamer/rl_finetuning.py @@ -0,0 +1,309 @@ +import torch + +from scenestreamer import utils + + +def _all_reduce(valid_returns, device, all_gather_func): + with torch.no_grad(): + + if valid_returns is not None: + local_count = torch.tensor([valid_returns.numel()], device=device, dtype=torch.float32) + local_sum = torch.tensor([valid_returns.sum()], device=device) + local_sq_sum = torch.tensor([(valid_returns ** 2).sum()], device=device) + else: + local_count = torch.tensor([0], device=device, dtype=torch.float32) + local_sum = torch.tensor([0], device=device) + local_sq_sum = torch.tensor([0], device=device) + + # Reduce across all ranks + 
def _all_reduce(valid_returns, device, all_gather_func):
    """Return the mean and std of ``valid_returns`` pooled over all ranks.

    Every rank must call this the same number of times, including ranks with
    nothing to contribute (pass ``None`` or an empty tensor), because each
    call performs three gathers.

    Args:
        valid_returns: Tensor of local values, or ``None``.
        device: Device on which the local statistics are created.
        all_gather_func: Callable returning a tensor that holds every rank's
            copy of its argument; only the ``.sum()`` of the result is used.

    Returns:
        ``(global_mean, global_std)`` as 0-dim float32 tensors. With no
        samples on any rank the mean is 0 and the std is the clamp floor.
    """
    with torch.no_grad():
        # All three statistics are float32 on every rank: ranks without data
        # previously produced int64 tensors, a dtype mismatch in the gather.
        if valid_returns is not None and valid_returns.numel() > 0:
            values = valid_returns.detach().to(device=device, dtype=torch.float32)
            local_count = torch.tensor([float(values.numel())], device=device, dtype=torch.float32)
            local_sum = values.sum().reshape(1)
            local_sq_sum = (values ** 2).sum().reshape(1)
        else:
            local_count = torch.zeros(1, device=device, dtype=torch.float32)
            local_sum = torch.zeros(1, device=device, dtype=torch.float32)
            local_sq_sum = torch.zeros(1, device=device, dtype=torch.float32)

        # Reduce across all ranks
        global_count = all_gather_func(local_count).sum()
        global_sum = all_gather_func(local_sum).sum()
        global_sq_sum = all_gather_func(local_sq_sum).sum()

        # Compute global mean and std; the count clamp avoids 0/0 = NaN.
        safe_count = global_count.clamp(min=1.0)
        global_mean = global_sum / safe_count
        global_var = (global_sq_sum / safe_count) - (global_mean ** 2)
        global_std = torch.sqrt(global_var.clamp(min=1e-6))  # avoid nan
    return global_mean, global_std


class RLFinetuner:
    """Policy-gradient fine-tuning of the motion and traffic-light heads.

    ``rollout`` samples a trajectory from the model and scores it against the
    logged ground truth; ``get_loss`` turns the result into a clipped
    surrogate loss.
    """

    def __init__(self, model, all_gather):
        """Store the model and the cross-rank gather callable."""
        self.model = model
        self.replay_buffer = None
        self.replay_count = 0
        self.all_gather = all_gather

    def rollout(self, data_dict):
        """Sample one rollout and compute log-probs, rewards and advantages.

        The motion reward is the negative mean corner distance between the
        predicted and ground-truth agent contours; the traffic-light reward
        is 1 where the sampled state equals the ground truth. Returns are
        reward-to-go sums over time and are normalised with statistics pooled
        across ranks.

        Args:
            data_dict: Batched sample dict with the ``encoder/*`` and
                ``decoder/*`` entries read below.

        Returns:
            Dict with the tensors consumed by ``get_loss``.
        """
        original_B = data_dict["encoder/map_feature"].shape[0]

        # TODO
        num_modes_for_eval = 1

        # Autoregressive rollout
        from scenestreamer.eval.waymo_motion_prediction_evaluator import _repeat_for_modes
        kept_prefixes = ("encoder/", "decoder/", "metadata/", "eval/")
        kept_names = ("batch_idx", "in_evaluation", "in_backward_prediction")
        expanded_data_dict = {
            k: _repeat_for_modes(data_dict[k], num_modes=num_modes_for_eval)
            for k in data_dict
            if k.startswith(kept_prefixes) or k in kept_names
        }

        from scenestreamer.infer.scenestreamer_motion import motion_prediction_task
        expanded_data_dict = motion_prediction_task(
            model=self.model,
            data_dict=expanded_data_dict,
            progress_bar=False,
            use_cache=True,
            keep_output_token=True,
            sampling_method="softmax",
            temperature=1.05,
            teacher_forcing_dest=False,  # TODO
        )

        output_action = expanded_data_dict["model/output_action"][:, :-1]

        # Compute reward
        # Every 5th reconstructed step is kept.
        # NOTE(review): presumably 5 reconstructed steps per action token; confirm.
        pred_pos = expanded_data_dict["decoder/reconstructed_position"][:, ::5]
        B, T_pred, N, _ = pred_pos.shape
        pred_head = expanded_data_dict["decoder/reconstructed_heading"][:, ::5]
        pred_valid_mask = expanded_data_dict["decoder/reconstructed_valid_mask"][:, ::5]
        agent_shape = expanded_data_dict["decoder/current_agent_shape"][:, None]
        pred_contour = utils.cal_polygon_contour_torch(
            x=pred_pos[..., 0],
            y=pred_pos[..., 1],
            theta=pred_head,
            width=agent_shape[..., 1].expand(B, T_pred, N),
            length=agent_shape[..., 0].expand(B, T_pred, N)
        )

        gt_pos = _repeat_for_modes(data_dict["decoder/modeled_agent_position"], num_modes=num_modes_for_eval)
        T_gt = gt_pos.shape[1]
        gt_head = _repeat_for_modes(data_dict["decoder/modeled_agent_heading"], num_modes=num_modes_for_eval)
        gt_valid_mask = _repeat_for_modes(data_dict["decoder/input_action_valid_mask"], num_modes=num_modes_for_eval)
        gt_contour = utils.cal_polygon_contour_torch(
            x=gt_pos[..., 0],
            y=gt_pos[..., 1],
            theta=gt_head,
            width=agent_shape[..., 1].expand(B, T_gt, N),
            length=agent_shape[..., 0].expand(B, T_gt, N)
        )
        assert T_pred == T_gt + 1

        # Drop the extra trailing prediction step so both contours align.
        pred_contour = pred_contour[:, :-1]
        pred_valid_mask = pred_valid_mask[:, :-1]
        assert pred_contour.shape == gt_contour.shape, (pred_contour.shape, gt_contour.shape)

        error_pos = torch.norm(pred_contour - gt_contour, dim=-1).mean(-1)
        error_pos = error_pos[:, 1:]

        reward_valid_mask = pred_valid_mask[:, 1:] & gt_valid_mask[:, 1:]
        error_pos[~reward_valid_mask] = 0.0

        # Now we have B, T=18, N rewards.
        reward = -error_pos.detach()

        # Get it back in the original shape
        reward = reward.reshape(original_B, num_modes_for_eval, -1, N).detach()
        # Reward-to-go: reversed cumulative sum along the time axis.
        returns = torch.flip(torch.cumsum(torch.flip(reward, dims=[2]), dim=2), dims=[2]).detach()

        scenestreamer_tokens = expanded_data_dict["scenestreamer_tokens"]
        all_token = scenestreamer_tokens.output_token
        L = scenestreamer_tokens.L
        assert self.model.no_tg
        # Per step the token axis holds L traffic-light tokens then N agent tokens.
        all_token = all_token.reshape(B, -1, N + L, self.model.d_model)
        motion_token = all_token[:, :, L:]
        if self.model.motion_prenorm is not None:
            motion_token = self.model.motion_prenorm(motion_token)
        motion_logit = self.model.motion_head(motion_token)

        # Get the log probs
        motion_logit = motion_logit[:, :-1]
        motion_logit = motion_logit.reshape(original_B, num_modes_for_eval, -1, N, self.model.num_actions)

        reward_valid_mask = reward_valid_mask.reshape(original_B, num_modes_for_eval, -1, N)
        output_action = output_action.reshape(original_B, num_modes_for_eval, -1, N)

        from scenestreamer.tokenization.motion_tokenizers import START_ACTION as MOTION_START_ACTION
        reward_valid_mask = reward_valid_mask & (output_action != MOTION_START_ACTION)

        has_valid_motion = bool(reward_valid_mask.any())
        if has_valid_motion:
            log_probs = torch.distributions.Categorical(logits=motion_logit[reward_valid_mask]).log_prob(
                output_action[reward_valid_mask]
            )
        else:
            # Zero-valued dummy that keeps the motion head in the autograd graph.
            log_probs = torch.distributions.Categorical(logits=motion_logit.flatten()[:1]).log_prob(
                torch.zeros_like(output_action.flatten()[:1])
            ) * 0.0

        # Called unconditionally so that every rank joins the gather.
        valid_returns = returns[reward_valid_mask]
        adv_mean, adv_std = _all_reduce(valid_returns, device=returns.device,
                                        all_gather_func=self.all_gather)

        if has_valid_motion:
            advantages = (valid_returns - adv_mean) / (adv_std + 1e-5)
        else:
            # Match the dummy log-prob's shape: an empty advantage tensor made
            # the loss the mean of an empty tensor, i.e. NaN.
            advantages = torch.zeros_like(log_probs)

        # Also do reward for traffic light:
        tl_token = all_token[:, :, :L]
        tl_token = self.model.traffic_light_prenorm(tl_token)
        traffic_light_logit = self.model.traffic_light_head(tl_token)
        traffic_light_gt = _repeat_for_modes(data_dict["encoder/traffic_light_state"], num_modes_for_eval)
        traffic_light_mask = _repeat_for_modes(data_dict["encoder/traffic_light_valid_mask"], num_modes_for_eval)
        pred_tl_state = expanded_data_dict["model/traffic_light_state"]

        # The prediction at step t is compared with the ground truth at t + 1.
        pred_tl_state = pred_tl_state[:, :-1]
        traffic_light_logit = traffic_light_logit[:, :-1]
        gt_tl_state = traffic_light_gt[:, 1:]
        traffic_light_mask = traffic_light_mask[:, 1:]

        traffic_light_logit = traffic_light_logit.reshape(original_B, num_modes_for_eval, -1, L,
                                                          traffic_light_logit.shape[-1])
        pred_tl_state = pred_tl_state.reshape(original_B, num_modes_for_eval, -1, L)
        traffic_light_mask = traffic_light_mask.reshape(original_B, num_modes_for_eval, -1, L)
        gt_tl_state = gt_tl_state.reshape(original_B, num_modes_for_eval, -1, L)

        if traffic_light_mask.any():
            tl_log_probs = torch.distributions.Categorical(logits=traffic_light_logit[traffic_light_mask]).log_prob(
                pred_tl_state[traffic_light_mask])

            # For TL, we normalize advantage across the whole batch.
            # This is because it's easy to have model predict all 4 same actions, then adv=0.
            # NOTE(review): unlike the motion reward, this reward is not zeroed
            # at invalid steps before the cumulative sum, so matches on masked
            # entries count towards the returns; confirm this is intended.
            tl_reward = (pred_tl_state == gt_tl_state).float()
            tl_reward = tl_reward.reshape(original_B, num_modes_for_eval, -1, L).detach()
            tl_return = torch.flip(torch.cumsum(torch.flip(tl_reward, dims=[2]), dim=2), dims=[2]).detach()

            tl_adv_mean, tl_adv_std = _all_reduce(tl_return[traffic_light_mask], device=tl_return.device,
                                                  all_gather_func=self.all_gather)

            tl_return = tl_return[traffic_light_mask].detach()
            tl_advantages = (tl_return - tl_adv_mean) / (tl_adv_std + 1e-5)
            tl_reward = tl_reward[traffic_light_mask].detach()
            tl_entropy = utils.safe_entropy(traffic_light_logit[traffic_light_mask])
            tl_accuracy = (
                traffic_light_logit[traffic_light_mask].argmax(-1) == pred_tl_state[traffic_light_mask]
            ).float().mean()
        else:
            # You have to keep this line for the case when there is no traffic light:
            # the other ranks still enter the gather and would otherwise wait forever.
            tl_adv_mean, tl_adv_std = _all_reduce(None, device=traffic_light_logit.device, all_gather_func=self.all_gather)
            tl_log_probs = torch.distributions.Categorical(logits=traffic_light_logit.flatten(0, 2)[:1]).log_prob(
                pred_tl_state.flatten()[:1]) * 0.0
            tl_advantages = torch.zeros_like(tl_log_probs)
            tl_reward = torch.zeros_like(tl_log_probs)
            tl_return = torch.zeros_like(tl_log_probs)
            tl_entropy = torch.zeros_like(tl_log_probs)
            tl_accuracy = torch.zeros_like(tl_log_probs)

        return dict(
            log_probs=log_probs,
            advantages=advantages,
            data_dict=expanded_data_dict,

            tl_entropy=tl_entropy,
            tl_log_probs=tl_log_probs,
            tl_advantages=tl_advantages,
            tl_reward=tl_reward,
            tl_return=tl_return,
            tl_accuracy=tl_accuracy,

            motion_logit=motion_logit,
            pred_valid_mask=pred_valid_mask,
            output_action=output_action,
            reward_valid_mask=reward_valid_mask,
            reward=reward,
            returns=returns,
            pred_pos=pred_pos,
            pred_head=pred_head,
        )

    def get_loss(self, data_dict):
        """Run a fresh rollout and return ``(loss, loss_stat)``.

        The loss is the sum of the clipped surrogate objectives of the motion
        and traffic-light heads. A new rollout is collected on every call, so
        the "old" log-probs are the detached current ones: the ratio is always
        1 and the gradient equals the REINFORCE gradient
        ``-advantage * grad(log_prob)``.
        """
        self.replay_buffer = self.rollout(data_dict)
        self.replay_count = 0

        log_probs = self.replay_buffer["log_probs"]
        advantages = self.replay_buffer["advantages"]

        # GRPO Loss:
        def grpo(new_log_prob, old_log_prob, advantages):
            clip_eps = 0.2
            ratio = (new_log_prob - old_log_prob).exp()
            surr1 = ratio * advantages
            surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
            loss = -torch.min(surr1, surr2)
            return loss.mean()

        motion_loss = grpo(new_log_prob=log_probs, old_log_prob=log_probs.detach(), advantages=advantages.detach())

        tl_log_probs = self.replay_buffer["tl_log_probs"]
        tl_advantages = self.replay_buffer["tl_advantages"]

        tl_loss = grpo(new_log_prob=tl_log_probs, old_log_prob=tl_log_probs.detach(), advantages=tl_advantages.detach())

        loss = motion_loss + tl_loss

        motion_logit = self.replay_buffer["motion_logit"]
        output_action = self.replay_buffer["output_action"]
        reward_valid_mask = self.replay_buffer["reward_valid_mask"]
        reward = self.replay_buffer["reward"]
        returns = self.replay_buffer["returns"]
        tl_reward = self.replay_buffer["tl_reward"]
        tl_return = self.replay_buffer["tl_return"]

        tl_entropy = self.replay_buffer["tl_entropy"]
        tl_accuracy = self.replay_buffer["tl_accuracy"]

        loss_stat = {
            "motion_loss": motion_loss,
            "motion_accuracy": (
                motion_logit[reward_valid_mask].argmax(-1) == output_action[reward_valid_mask]).float().mean(),
            "motion_entropy": utils.safe_entropy(motion_logit[reward_valid_mask]).mean(),
            "motion_reward": reward[reward_valid_mask].mean(),
            "motion_return": returns.mean(),
            "motion_advantages": advantages.mean(),
            "total_loss": loss,

            "traffic_light_loss": tl_loss,
            "traffic_light_advantages": tl_advantages.mean(),
            "traffic_light_reward": tl_reward.mean(),
            "traffic_light_return": tl_return.mean(),
            "traffic_light_entropy": tl_entropy.mean(),
            "traffic_light_accuracy": tl_accuracy,
        }
        return loss, loss_stat
#!/bin/bash
# Launch closed-loop TD3 training with the SCGEN generator: one background
# job per GPU (GPUs 0-2), each with its own seed and its own log file.

# One seed per GPU index.
seeds=(0 100 200 300 400 500 600 700)

# The experiment name is this script's file name without its extension.
filename=$(basename "$0")
EXP_NAME="${filename%.*}"
# Debug alternative for both data paths: /home/yuxin/scenestreamer/debug_scgen
data_dir="/bigdata/yuxin/scenarionet_waymo_training_500"
save_path="${EXP_NAME}"
training_step=1_000_000
lr=1e-4
eval_freq=50000
horizon=100
closed_loop_source_data="/bigdata/yuxin/scenarionet_waymo_training_500"
closed_loop_generator="SCGEN"

mkdir -p "${save_path}"

for i in {0..2}
do
CUDA_VISIBLE_DEVICES="$i" \
python rl_train/train/td3.py \
--exp_name="${EXP_NAME}" \
--wandb_project='scgen' \
--wandb_team='drivingforce' \
--seed="${seeds[$i]}" \
--data_dir="${data_dir}" \
--save_path="${save_path}" \
--training_step="${training_step}" \
--lr="${lr}" \
--eval_freq="${eval_freq}" \
--horizon="${horizon}" \
--closed_loop \
--source_data="${closed_loop_source_data}" \
--closed_loop_generator="${closed_loop_generator}" \
--wandb \
> "${EXP_NAME}_${closed_loop_generator}_seed${seeds[$i]}.log" 2>&1 &
done
#!/bin/bash
# Run a single closed-loop TD3 training job with the CAT generator in the
# foreground, using seeds[2]. No GPU is pinned and output is not redirected.

# Seed table shared with the multi-GPU launch scripts.
seeds=(0 100 200 300 400 500 600 700)

# The experiment name is this script's file name without its extension.
filename=$(basename "$0")
EXP_NAME="${filename%.*}"
# Debug alternative for both data paths: /home/yuxin/scenestreamer/debug_scgen
data_dir="/bigdata/yuxin/scenarionet_waymo_training_500"
save_path="${EXP_NAME}"
training_step=5_000_000
lr=1e-4
eval_freq=50000
horizon=100
closed_loop_source_data="/bigdata/yuxin/scenarionet_waymo_training_500"
closed_loop_generator="CAT"

mkdir -p "${save_path}"

# To fan out over GPUs instead, wrap the command in `for i in {0..4}`, prefix
# it with CUDA_VISIBLE_DEVICES=$i and append:
#   > "${EXP_NAME}_${closed_loop_generator}_seed${seeds[$i]}.log" 2>&1 &
i=2
python rl_train/train/td3.py \
--exp_name="${EXP_NAME}" \
--wandb_project='scgen' \
--wandb_team='drivingforce' \
--seed="${seeds[$i]}" \
--data_dir="${data_dir}" \
--save_path="${save_path}" \
--training_step="${training_step}" \
--lr="${lr}" \
--eval_freq="${eval_freq}" \
--horizon="${horizon}" \
--closed_loop \
--source_data="${closed_loop_source_data}" \
--closed_loop_generator="${closed_loop_generator}" \
--wandb
#!/bin/bash
# Run a single closed-loop TD3 training job with the SceneStreamer generator
# in debug mode, in the foreground, on GPU 0 with seeds[0].

# Seed table shared with the multi-GPU launch scripts.
seeds=(0 100 200 300 400 500 600 700)

filename=$(basename "$0")
data_dir="/bigdata/yuxin/scenarionet_waymo_training_500"
eval_data_dir="/bigdata/yuxin/scenarionet_waymo_validation_100"
training_step=2_000_000
lr=1e-4
eval_freq=50000
horizon=100
# Not passed to the command below; kept for reference.
generator_ckpt_path="/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250428_scenestreamer_v17_notg_finetune_2025-04-28/checkpoints"
model_name="scenestreamer-base-large"
eval_horizon=100
# Debug alternative: /home/yuxin/scenestreamer/debug_scgen
closed_loop_source_data="/bigdata/yuxin/scenarionet_waymo_training_500"
closed_loop_generator="SceneStreamer"
# The experiment name is the script name (no extension) plus the horizon.
EXP_NAME="${filename%.*}_horizon_${horizon}"
save_path="${EXP_NAME}"
# Only used when resuming (see the optional flags below).
ckpt_path="/bigdata/yuxin/0503_SceneStreamer_closed_loop/seed_0_750000_steps.zip"
resumed_step=750000
num_eval_envs=4
eval_ep=100

mkdir -p "${save_path}"

# Optional flags, currently disabled:
#   --wandb
#   --ckpt_path="${ckpt_path}" --resumed_step="${resumed_step}"
# To fan out over GPUs, wrap the command in `for i in {0..3}` and append:
#   >> "${EXP_NAME}_seed${seeds[$i]}.log" 2>&1 &
i=0
CUDA_VISIBLE_DEVICES="${i}" \
python scenestreamer/rl_train/train/td3.py \
--debug \
--exp_name="${EXP_NAME}" \
--wandb_project='scgen' \
--wandb_team='drivingforce' \
--seed="${seeds[$i]}" \
--data_dir="${data_dir}" \
--eval_data_dir="${eval_data_dir}" \
--save_path="${save_path}" \
--training_step="${training_step}" \
--lr="${lr}" \
--eval_freq="${eval_freq}" \
--horizon="${horizon}" \
--model_name="${model_name}" \
--eval_horizon="${eval_horizon}" \
--closed_loop \
--source_data="${closed_loop_source_data}" \
--closed_loop_generator="${closed_loop_generator}" \
--num_eval_envs="${num_eval_envs}" \
--eval_ep="${eval_ep}"
+data_dir="/bigdata/yuxin/SceneStreamer_scenarionet_waymo_training_500" +eval_data_dir="/bigdata/yuxin/scenarionet_waymo_validation_100" +training_step=2_000_000 +lr=1e-4 +eval_freq=50000 +horizon=100 +generator_ckpt_path="/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250428_scenestreamer_v17_notg_finetune_2025-04-28/checkpoints" +eval_horizon=100 +EXP_NAME="${filename%.*}_horizon_${horizon}" +save_path="${EXP_NAME}" + +mkdir -p ${save_path} + +for i in {0..3} +do +CUDA_VISIBLE_DEVICES=$i \ +python scenestreamer/rl_train/train/td3.py \ +--exp_name=${EXP_NAME} \ +--wandb_project='scgen' \ +--wandb_team='drivingforce' \ +--seed=${seeds[$i]} \ +--data_dir=${data_dir} \ +--eval_data_dir=${eval_data_dir} \ +--save_path=${save_path} \ +--training_step=${training_step} \ +--lr=${lr} \ +--eval_freq=${eval_freq} \ +--horizon=${horizon} \ +--source_data=None \ +--closed_loop_generator=None \ +--generator_ckpt_path=${generator_ckpt_path} \ +--eval_horizon=${eval_horizon} \ +--wandb \ +> ${EXP_NAME}_seed${seeds[$i]}.log 2>&1 & +done \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0503_Waymo_open_loop.sh b/scenestreamer/rl_train/scripts/0503_Waymo_open_loop.sh new file mode 100644 index 0000000000000000000000000000000000000000..a6a011389cd6a0a77a6180c81fad9c9d99455c62 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0503_Waymo_open_loop.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Define the seeds for each GPU +seeds=(0 100 200 300 400 500 600 700) + + +filename=$(basename "$0") +extension="${filename##*.}" +EXP_NAME="${filename%.*}" +data_dir="/bigdata/yuxin/scenarionet_waymo_training_500" +eval_data_dir="/bigdata/yuxin/scenarionet_waymo_validation_100" +training_step=2_000_000 +lr=1e-4 +eval_freq=100000 +horizon=200 +generator_ckpt_path="/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250428_scenestreamer_v17_notg_finetune_2025-04-28/checkpoints" +eval_horizon=100 +EXP_NAME="${filename%.*}_horizon_${horizon}" 
+save_path="${EXP_NAME}" + +mkdir -p ${save_path} + +# for i in {0..3} +# do +i=1 +CUDA_VISIBLE_DEVICES=$i \ +python scenestreamer/rl_train/train/td3.py \ +--exp_name=${EXP_NAME} \ +--wandb_project='scgen' \ +--wandb_team='drivingforce' \ +--seed=${seeds[$i]} \ +--data_dir=${data_dir} \ +--eval_data_dir=${eval_data_dir} \ +--save_path=${save_path} \ +--training_step=${training_step} \ +--lr=${lr} \ +--eval_freq=${eval_freq} \ +--horizon=${horizon} \ +--source_data=None \ +--closed_loop_generator=None \ +--generator_ckpt_path=${generator_ckpt_path} \ +--eval_horizon=${eval_horizon} \ +# --wandb \ +# > ${EXP_NAME}_seed${seeds[$i]}.log 2>&1 & +# done \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0503_mixed_Waymo_SceneStreamer_open_loop.sh b/scenestreamer/rl_train/scripts/0503_mixed_Waymo_SceneStreamer_open_loop.sh new file mode 100644 index 0000000000000000000000000000000000000000..a8def60eeddd05d76d91855b46b48be6e963a371 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0503_mixed_Waymo_SceneStreamer_open_loop.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Define the seeds for each GPU +seeds=(0 100 200 300 400 500 600 700) + + +filename=$(basename "$0") +extension="${filename##*.}" +EXP_NAME="${filename%.*}" +data_dir="/bigdata/yuxin/mixed_waymo_SceneStreamer_scenarionet_waymo_training_500" +eval_data_dir="/bigdata/yuxin/scenarionet_waymo_validation_100" +training_step=2_000_000 +lr=1e-4 +eval_freq=50000 +horizon=100 +generator_ckpt_path="/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250428_scenestreamer_v17_notg_finetune_2025-04-28/checkpoints" +eval_horizon=100 +EXP_NAME="${filename%.*}_horizon_${horizon}" +save_path="${EXP_NAME}" + + +mkdir -p ${save_path} + +for i in {0..3} +do +CUDA_VISIBLE_DEVICES=$i \ +python scenestreamer/rl_train/train/td3.py \ +--exp_name=${EXP_NAME} \ +--wandb_project='scgen' \ +--wandb_team='drivingforce' \ +--seed=${seeds[$i]} \ +--data_dir=${data_dir} \ +--eval_data_dir=${eval_data_dir} \ 
+--save_path=${save_path} \ +--training_step=${training_step} \ +--lr=${lr} \ +--eval_freq=${eval_freq} \ +--horizon=${horizon} \ +--source_data=None \ +--closed_loop_generator=None \ +--generator_ckpt_path=${generator_ckpt_path} \ +--eval_horizon=${eval_horizon} \ +--wandb \ +> ${EXP_NAME}_horizon_${horizon}_seed${seeds[$i]}.log 2>&1 & +done \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0508_SceneStreamer-base-large_closed_loop.sh b/scenestreamer/rl_train/scripts/0508_SceneStreamer-base-large_closed_loop.sh new file mode 100644 index 0000000000000000000000000000000000000000..d17ee5232182196be69cad1d09125476ece680e4 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0508_SceneStreamer-base-large_closed_loop.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +# Define the seeds for each GPU +seeds=(0 100 200 300 400 500 600 700) + +filename=$(basename "$0") +extension="${filename##*.}" +data_dir="/bigdata/yuxin/scenarionet_waymo_training_500" +eval_data_dir="/bigdata/yuxin/scenarionet_waymo_validation_100" +training_step=2_000_000 +lr=1e-4 +eval_freq=100000 +horizon=100 +model_name="scenestreamer-base-large" +eval_horizon=100 +closed_loop_source_data="/bigdata/yuxin/scenarionet_waymo_training_500" #"/home/yuxin/scenestreamer/debug_scgen" #"/bigdata/yuxin/scenarionet_waymo_training_500" +closed_loop_generator="SceneStreamer" +EXP_NAME="${filename%.*}_horizon_${horizon}" +save_path="${EXP_NAME}" +num_eval_envs=5 +eval_ep=100 + +mkdir -p ${save_path} + +# for i in {0..3} +# do +i=0 +CUDA_VISIBLE_DEVICES=${i} \ +python scenestreamer/rl_train/train/td3.py \ +--debug \ +--exp_name=${EXP_NAME} \ +--wandb_project='scgen' \ +--wandb_team='drivingforce' \ +--seed=${seeds[$i]} \ +--data_dir=${data_dir} \ +--eval_data_dir=${eval_data_dir} \ +--save_path=${save_path} \ +--training_step=${training_step} \ +--lr=${lr} \ +--eval_freq=${eval_freq} \ +--horizon=${horizon} \ +--model_name=${model_name} \ +--eval_horizon=${eval_horizon} \ +--closed_loop \ 
+--source_data=${closed_loop_source_data} \ +--closed_loop_generator=${closed_loop_generator} \ +--num_eval_envs=${num_eval_envs} \ +--eval_ep=${eval_ep} \ +# --wandb \ +# > ${EXP_NAME}_seed${seeds[$i]}.log 2>&1 & +# done \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0508_mixed_Waymo_SceneStreamer-base-large_open_loop.sh b/scenestreamer/rl_train/scripts/0508_mixed_Waymo_SceneStreamer-base-large_open_loop.sh new file mode 100644 index 0000000000000000000000000000000000000000..48fc97f55e71cd2d1e207b6db47512dc845c0b80 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0508_mixed_Waymo_SceneStreamer-base-large_open_loop.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Define the seeds for each GPU +seeds=(0 100 200 300 400 500 600 700) + + +filename=$(basename "$0") +extension="${filename##*.}" +EXP_NAME="${filename%.*}" +data_dir="/bigdata/yuxin/mixed_Waymo_SceneStreamer_scenestreamer-base-large_scenarionet_waymo_training_500" +eval_data_dir="/bigdata/yuxin/scenarionet_waymo_validation_100" +training_step=2_000_000 +lr=1e-4 +eval_freq=50000 +horizon=100 +# generator_ckpt_path="/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250428_scenestreamer_v17_notg_finetune_2025-04-28/checkpoints" +eval_horizon=100 +EXP_NAME="${filename%.*}_horizon_${horizon}" +save_path="${EXP_NAME}" +model_name="scenestreamer-base-large" +num_eval_envs=5 +eval_ep=100 + +mkdir -p ${save_path} + +for i in {0..3} +do +CUDA_VISIBLE_DEVICES=$i \ +python scenestreamer/rl_train/train/td3.py \ +--exp_name=${EXP_NAME} \ +--wandb_project='scgen' \ +--wandb_team='drivingforce' \ +--seed=${seeds[$i]} \ +--data_dir=${data_dir} \ +--eval_data_dir=${eval_data_dir} \ +--save_path=${save_path} \ +--training_step=${training_step} \ +--lr=${lr} \ +--eval_freq=${eval_freq} \ +--horizon=${horizon} \ +--source_data=None \ +--closed_loop_generator=None \ +--eval_horizon=${eval_horizon} \ +--num_eval_envs=${num_eval_envs} \ +--eval_ep=${eval_ep} \ +--wandb \ +> 
${EXP_NAME}_horizon_${horizon}_seed${seeds[$i]}.log 2>&1 & +done \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0510_mixed_Waymo_SceneStreamer-base-large_4_mode_open_loop.sh b/scenestreamer/rl_train/scripts/0510_mixed_Waymo_SceneStreamer-base-large_4_mode_open_loop.sh new file mode 100644 index 0000000000000000000000000000000000000000..030480fa4b7a2b26107370eeacafa8d070859169 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0510_mixed_Waymo_SceneStreamer-base-large_4_mode_open_loop.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Define the seeds for each GPU +seeds=(0 100 200 300 400 500 600 700) + + +filename=$(basename "$0") +extension="${filename##*.}" +EXP_NAME="${filename%.*}" +data_dir="/bigdata/yuxin/mixed_Waymo_SceneStreamer_scenestreamer-base-large_scenarionet_waymo_training_500_4_mode" +eval_data_dir="/bigdata/yuxin/scenarionet_waymo_validation_100" +training_step=2_000_000 +lr=1e-4 +eval_freq=100000 +horizon=100 +# generator_ckpt_path="/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250428_scenestreamer_v17_notg_finetune_2025-04-28/checkpoints" +eval_horizon=100 +EXP_NAME="${filename%.*}_horizon_${horizon}" +save_path="${EXP_NAME}" +model_name="scenestreamer-base-large" +num_eval_envs=1 +eval_ep=100 + +mkdir -p ${save_path} + +for i in {0..3} +do +CUDA_VISIBLE_DEVICES=$i \ +python scenestreamer/rl_train/train/td3.py \ +--exp_name=${EXP_NAME} \ +--wandb_project='scgen' \ +--wandb_team='drivingforce' \ +--seed=${seeds[$i]} \ +--data_dir=${data_dir} \ +--eval_data_dir=${eval_data_dir} \ +--save_path=${save_path} \ +--training_step=${training_step} \ +--lr=${lr} \ +--eval_freq=${eval_freq} \ +--horizon=${horizon} \ +--source_data=None \ +--closed_loop_generator=None \ +--eval_horizon=${eval_horizon} \ +--num_eval_envs=${num_eval_envs} \ +--eval_ep=${eval_ep} \ +--wandb \ +> ${EXP_NAME}_horizon_${horizon}_seed${seeds[$i]}.log 2>&1 & +done \ No newline at end of file diff --git 
a/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_localscript.sh b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_localscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..09d0348e1dbd7e458f2d00a587537d9b2ae79b02 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_localscript.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +EXP_NAME="0511_CLRL_scenestreamer-base-large_night" +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" + +# ✅ Manually specify available GPU IDs +GPUS=(4 5 6 7) + +# ✅ Define seeds for experiments +SEEDS=(0 100 200 300) + +# Total number of experiments +echo "Total experiments to run: ${#SEEDS[@]}" +echo "Using GPUs: ${GPUS[@]}" + +for ((i=0; i<${#SEEDS[@]}; i++)); do + GPU_INDEX=$((i % ${#GPUS[@]})) + GPU_ID=${GPUS[$GPU_INDEX]} + SEED=${SEEDS[i]} + SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" + mkdir -p "${SAVE_PATH}" + + echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}." + + CUDA_VISIBLE_DEVICES=${GPU_ID} \ + python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=50000 \ + --horizon=100 \ + --model_name=scenestreamer-base-large \ + --eval_horizon=100 \ + --closed_loop \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --closed_loop_generator=SceneStreamer \ + --num_eval_envs=5 \ + --eval_ep=100 \ + --wandb \ + > "${EXP_NAME}_seed${SEED}.log" 2>&1 & + +done + +echo "All ${#SEEDS[@]} experiments launched." 
diff --git a/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_noadaptive_localscript.sh b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_noadaptive_localscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..8649742dcd9e9892fbfae604426a34b3ac6a56a8 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_noadaptive_localscript.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +EXP_NAME="0511_CLRL_scenestreamer-base-large_noadaptive_night" +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" + +# ✅ Manually specify available GPU IDs +GPUS=(2 3 4 5) + +# ✅ Define seeds for experiments +SEEDS=(0 100 200 300) + +# Total number of experiments +echo "Total experiments to run: ${#SEEDS[@]}" +echo "Using GPUs: ${GPUS[@]}" + +for ((i=0; i<${#SEEDS[@]}; i++)); do + GPU_INDEX=$((i % ${#GPUS[@]})) + GPU_ID=${GPUS[$GPU_INDEX]} + SEED=${SEEDS[i]} + SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" + mkdir -p "${SAVE_PATH}" + + echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}." + + CUDA_VISIBLE_DEVICES=${GPU_ID} \ + python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=50000 \ + --horizon=100 \ + --model_name=scenestreamer-base-large \ + --eval_horizon=100 \ + --closed_loop \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --closed_loop_generator=SceneStreamer \ + --num_eval_envs=5 \ + --eval_ep=100 \ + --wandb \ + --no_adaptive \ + > "${EXP_NAME}_seed${SEED}.log" 2>&1 & + +done + +echo "All ${#SEEDS[@]} experiments launched." 
diff --git a/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_rlscript.sh b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_rlscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..981cc4f003f51719c994e87a280bb021ddad4905 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_rlscript.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# Usage: bash run_td3.sh +SEED=$1 + +if [ -z "$SEED" ]; then + echo "Error: Must provide seed as first argument." + exit 1 +fi + +EXP_NAME="0511_CLRL_scenestreamer-base-large" +SAVE_PATH=${EXP_NAME} +mkdir -p ${SAVE_PATH} + +CUDA_VISIBLE_DEVICES=0 \ +python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=100000 \ + --horizon=100 \ + --model_name=scenestreamer-base-large \ + --eval_horizon=100 \ + --closed_loop \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --closed_loop_generator=SceneStreamer \ + --num_eval_envs=5 \ + --eval_ep=500 \ + --wandb \ + > ${EXP_NAME}_seed${SEED}.log 2>&1 diff --git a/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_slurmscript.sh b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_slurmscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..daccd88c42ad2b693c9f98a2afbdbbb1f19fbc24 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-base-large_slurmscript.sh @@ -0,0 +1,21 @@ +#!/bin/bash +#SBATCH --job-name=0511_CLRL_scenestreamer-base-large +#SBATCH --output=0511_CLRL_scenestreamer-base-large_%a.out +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --partition=gpu03 +#SBATCH --array=0-3 + +# Launch 4 jobs + + +# Map array ID to seed 
+SEED=$((SLURM_ARRAY_TASK_ID * 100)) + +echo "Launching TD3 training with seed ${SEED} on $(hostname)" +bash 0511_CLRL_scenestreamer-base-large_rlscript.sh ${SEED} + + +# USAGE: +# +# sbatch 0511_CLRL_scenestreamer-base-large_slurmscript.sh diff --git a/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-full-large-nors_localscript.sh b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-full-large-nors_localscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..01a67dca14b75a70d9b187ce2a7eeb1aee9078f6 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-full-large-nors_localscript.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +EXP_NAME="0511_CLRL_scenestreamer-full-large-nors_night" +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" + +# ✅ Manually specify available GPU IDs +GPUS=(2 3 4 5) + +# ✅ Define seeds for experiments +SEEDS=(0 100 200 300) + +# Total number of experiments +echo "Total experiments to run: ${#SEEDS[@]}" +echo "Using GPUs: ${GPUS[@]}" + +for ((i=0; i<${#SEEDS[@]}; i++)); do + GPU_INDEX=$((i % ${#GPUS[@]})) + GPU_ID=${GPUS[$GPU_INDEX]} + SEED=${SEEDS[i]} + SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" + mkdir -p "${SAVE_PATH}" + + echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}." 
+ + CUDA_VISIBLE_DEVICES=${GPU_ID} \ + python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=50000 \ + --horizon=100 \ + --model_name=scenestreamer-full-large-nors \ + --eval_horizon=100 \ + --closed_loop \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --closed_loop_generator=SceneStreamer \ + --num_eval_envs=5 \ + --eval_ep=100 \ + --wandb \ + > "${EXP_NAME}_seed${SEED}.log" 2>&1 & + +done + +echo "All ${#SEEDS[@]} experiments launched." diff --git a/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-full-large-nors_noadaptive_localscript.sh b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-full-large-nors_noadaptive_localscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..9253dc2fcd41385dc01bd5e3ec4ea330fbc5122b --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-full-large-nors_noadaptive_localscript.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +EXP_NAME="0511_CLRL_scenestreamer-full-large-nors_noadaptive_night" +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" + +# ✅ Manually specify available GPU IDs +GPUS=(2 3 4 5) + +# ✅ Define seeds for experiments +SEEDS=(0 100 200 300) + +# Total number of experiments +echo "Total experiments to run: ${#SEEDS[@]}" +echo "Using GPUs: ${GPUS[@]}" + +for ((i=0; i<${#SEEDS[@]}; i++)); do + GPU_INDEX=$((i % ${#GPUS[@]})) + GPU_ID=${GPUS[$GPU_INDEX]} + SEED=${SEEDS[i]} + SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" + mkdir -p "${SAVE_PATH}" + + echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}." 
+ + CUDA_VISIBLE_DEVICES=${GPU_ID} \ + python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=50000 \ + --horizon=100 \ + --model_name=scenestreamer-full-large-nors \ + --eval_horizon=100 \ + --closed_loop \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --closed_loop_generator=SceneStreamer \ + --num_eval_envs=5 \ + --eval_ep=100 \ + --wandb \ + --no_adaptive \ + > "${EXP_NAME}_seed${SEED}.log" 2>&1 & + +done + +echo "All ${#SEEDS[@]} experiments launched." diff --git a/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-full-large_localscript.sh b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-full-large_localscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..2b052d9dba4c11ef3d5179976bea6cbd4d9a9e1c --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_CLRL_scenestreamer-full-large_localscript.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +EXP_NAME="0511_CLRL_scenestreamer-full-large_night" +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" + +# ✅ Manually specify available GPU IDs +GPUS=(2 3 4 5) + +# ✅ Define seeds for experiments +SEEDS=(0 100 200 300) + +# Total number of experiments +echo "Total experiments to run: ${#SEEDS[@]}" +echo "Using GPUs: ${GPUS[@]}" + +for ((i=0; i<${#SEEDS[@]}; i++)); do + GPU_INDEX=$((i % ${#GPUS[@]})) + GPU_ID=${GPUS[$GPU_INDEX]} + SEED=${SEEDS[i]} + SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" + mkdir -p "${SAVE_PATH}" + + echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}." 
+ + CUDA_VISIBLE_DEVICES=${GPU_ID} \ + python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=50000 \ + --horizon=100 \ + --model_name=scenestreamer-full-large \ + --eval_horizon=100 \ + --closed_loop \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --closed_loop_generator=SceneStreamer \ + --num_eval_envs=5 \ + --eval_ep=100 \ + --wandb \ + > "${EXP_NAME}_seed${SEED}.log" 2>&1 & + +done + +echo "All ${#SEEDS[@]} experiments launched." diff --git a/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-1000.sh b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-1000.sh new file mode 100644 index 0000000000000000000000000000000000000000..f4d228b6a35727785a0ca05f29e1331a739c0eb1 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-1000.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# Usage: bash run_td3.sh +SEED=$1 + +if [ -z "$SEED" ]; then + echo "Error: Must provide seed as first argument." 
+ exit 1 +fi + +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" +EXP_NAME="0511_OLRL_scenestreamer-full-large_500-1000" +SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" +mkdir -p ${SAVE_PATH} + + +python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir="/bigdata/yuxin/mixed_Waymo_SceneStreamer_scenestreamer-full-large_scenarionet_waymo_training_500_2_mode" \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=50000 \ + --horizon=100 \ + --model_name="invalid" \ + --eval_horizon=100 \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --num_eval_envs=1 \ + --eval_ep=100 \ + --wandb \ + > ${EXP_NAME}_seed${SEED}.log 2>&1 \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-1000_slurmscript.sh b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-1000_slurmscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..b2d240482adc6839479fa726cca57765e349fe29 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-1000_slurmscript.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --job-name=0511_OLRL_scenestreamer-full-large_500-1000 +#SBATCH --output=0511_OLRL_scenestreamer-full-large_500-1000_%a.out +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --partition=gpu02 +#SBATCH --array=0-3 + +# Map array ID to seed +SEED=$((SLURM_ARRAY_TASK_ID * 100)) + +echo "Launching TD3 training with seed ${SEED}" +bash 0511_OLRL_scenestreamer-full-large_500-1000.sh ${SEED} diff --git a/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-2000.sh b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-2000.sh new file mode 100644 index 
0000000000000000000000000000000000000000..eff8515af919245f119023e4dcc3aaa37da94813 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-2000.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# Usage: bash run_td3.sh +SEED=$1 + +if [ -z "$SEED" ]; then + echo "Error: Must provide seed as first argument." + exit 1 +fi + +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" +EXP_NAME="0511_OLRL_scenestreamer-full-large_500-2000" +SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" +mkdir -p ${SAVE_PATH} + + +python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir="/bigdata/yuxin/mixed_Waymo_SceneStreamer_scenestreamer-full-large_scenarionet_waymo_training_500_4_mode" \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=50000 \ + --horizon=100 \ + --model_name="invalid" \ + --eval_horizon=100 \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --num_eval_envs=1 \ + --eval_ep=100 \ + --wandb \ + > ${EXP_NAME}_seed${SEED}.log 2>&1 \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-2000_slurmscript.sh b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-2000_slurmscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e32e6dd6e31680c3d9fd4fb1b578d845255efee --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-2000_slurmscript.sh @@ -0,0 +1,21 @@ +#!/bin/bash +#SBATCH --job-name=0511_OLRL_scenestreamer-full-large_500-2000 +#SBATCH --output=0511_OLRL_scenestreamer-full-large_500-2000_%a.out +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --partition=gpu02 +#SBATCH --array=0-3 + +# Launch 4 jobs + + +# Map array ID to seed +SEED=$((SLURM_ARRAY_TASK_ID * 100)) + +echo "Launching TD3 training with 
seed ${SEED}" +bash 0511_OLRL_scenestreamer-full-large_500-2000.sh ${SEED} + + +# USAGE: +# +# sbatch 0511_CLRL_scenestreamer-base-large_slurmscript.sh diff --git a/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-500.sh b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-500.sh new file mode 100644 index 0000000000000000000000000000000000000000..6dd6e8008260f3387802d9ac411295cae7482441 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-500.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# Usage: bash run_td3.sh +SEED=$1 + +if [ -z "$SEED" ]; then + echo "Error: Must provide seed as first argument." + exit 1 +fi + +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" +EXP_NAME="0511_OLRL_scenestreamer-full-large_500-500" +SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" +mkdir -p ${SAVE_PATH} + + +python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir="/bigdata/yuxin/mixed_Waymo_SceneStreamer_scenestreamer-full-large_scenarionet_waymo_training_500_1_mode" \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=50000 \ + --horizon=100 \ + --model_name="invalid" \ + --eval_horizon=100 \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --num_eval_envs=1 \ + --eval_ep=100 \ + --wandb \ + > ${EXP_NAME}_seed${SEED}.log 2>&1 \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-500_slurmscript.sh b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-500_slurmscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..794e2a60e0aec09f474d626086b818d77549c7ee --- /dev/null +++ b/scenestreamer/rl_train/scripts/0511_OLRL_scenestreamer-full-large_500-500_slurmscript.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH 
--job-name=0511_OLRL_scenestreamer-full-large_500-500 +#SBATCH --output=0511_OLRL_scenestreamer-full-large_500-500_%a.out +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --partition=gpu03 +#SBATCH --array=0-3 + +# Map array ID to seed +SEED=$((SLURM_ARRAY_TASK_ID * 100)) + +echo "Launching TD3 training with seed ${SEED}" +bash 0511_OLRL_scenestreamer-full-large_500-500.sh ${SEED} diff --git a/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-base-large_noadaptive_gpu03.sh b/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-base-large_noadaptive_gpu03.sh new file mode 100644 index 0000000000000000000000000000000000000000..a5b9e6b76d75bc92e49334c366619daf62e9eed8 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-base-large_noadaptive_gpu03.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +EXP_NAME="0512_CLRL_scenestreamer-base-large_noadaptive" +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" + +# ✅ Manually specify available GPU IDs +GPUS=(0 1 2 3) + +# ✅ Define seeds for experiments +SEEDS=(400 500 600 700) + +# Total number of experiments +echo "Total experiments to run: ${#SEEDS[@]}" +echo "Using GPUs: ${GPUS[@]}" + +for ((i=0; i<${#SEEDS[@]}; i++)); do + GPU_INDEX=$((i % ${#GPUS[@]})) + GPU_ID=${GPUS[$GPU_INDEX]} + SEED=${SEEDS[i]} + SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" + mkdir -p "${SAVE_PATH}" + + echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}." 
+ + CUDA_VISIBLE_DEVICES=${GPU_ID} \ + python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=100000 \ + --horizon=100 \ + --model_name=scenestreamer-base-large \ + --eval_horizon=100 \ + --closed_loop \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --closed_loop_generator=SceneStreamer \ + --num_eval_envs=1 \ + --eval_ep=100 \ + --wandb \ + --no_adaptive \ + > "${EXP_NAME}_seed${SEED}.log" 2>&1 & + +done + +echo "All ${#SEEDS[@]} experiments launched." diff --git a/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-full-large-nors_gpu05.sh b/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-full-large-nors_gpu05.sh new file mode 100644 index 0000000000000000000000000000000000000000..237dd9ca671848af5da93f8b70c39af41f4b7181 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-full-large-nors_gpu05.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +EXP_NAME="0512_CLRL_scenestreamer-full-large-nors" +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" + +# ✅ Manually specify available GPU IDs +GPUS=(4 5 6 7) + +# ✅ Define seeds for experiments +SEEDS=(400 500 600 700) + +# Total number of experiments +echo "Total experiments to run: ${#SEEDS[@]}" +echo "Using GPUs: ${GPUS[@]}" + +for ((i=0; i<${#SEEDS[@]}; i++)); do + GPU_INDEX=$((i % ${#GPUS[@]})) + GPU_ID=${GPUS[$GPU_INDEX]} + SEED=${SEEDS[i]} + SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" + mkdir -p "${SAVE_PATH}" + + echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}." 
+ + CUDA_VISIBLE_DEVICES=${GPU_ID} \ + python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=100000 \ + --horizon=100 \ + --model_name=scenestreamer-full-large-nors \ + --eval_horizon=100 \ + --closed_loop \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --closed_loop_generator=SceneStreamer \ + --num_eval_envs=1 \ + --eval_ep=100 \ + --wandb \ + > "${EXP_NAME}_seed${SEED}.log" 2>&1 & + +done + +echo "All ${#SEEDS[@]} experiments launched." diff --git a/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-full-large-nors_noadaptive_gpu05.sh b/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-full-large-nors_noadaptive_gpu05.sh new file mode 100644 index 0000000000000000000000000000000000000000..5ee04e5a8edbbbaf6488f0fb551a97098c671634 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-full-large-nors_noadaptive_gpu05.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +EXP_NAME="0512_CLRL_scenestreamer-full-large-nors_noadaptive" +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" + +# ✅ Manually specify available GPU IDs +GPUS=(4 5 6 7) + +# ✅ Define seeds for experiments +SEEDS=(400 500 600 700) + +# Total number of experiments +echo "Total experiments to run: ${#SEEDS[@]}" +echo "Using GPUs: ${GPUS[@]}" + +for ((i=0; i<${#SEEDS[@]}; i++)); do + GPU_INDEX=$((i % ${#GPUS[@]})) + GPU_ID=${GPUS[$GPU_INDEX]} + SEED=${SEEDS[i]} + SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" + mkdir -p "${SAVE_PATH}" + + echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}." 
+ + CUDA_VISIBLE_DEVICES=${GPU_ID} \ + python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=100000 \ + --horizon=100 \ + --model_name=scenestreamer-full-large-nors \ + --eval_horizon=100 \ + --closed_loop \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --closed_loop_generator=SceneStreamer \ + --num_eval_envs=1 \ + --eval_ep=100 \ + --wandb \ + --no_adaptive \ + > "${EXP_NAME}_seed${SEED}.log" 2>&1 & + +done + +echo "All ${#SEEDS[@]} experiments launched." diff --git a/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-full-large_gpu02.sh b/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-full-large_gpu02.sh new file mode 100644 index 0000000000000000000000000000000000000000..89561991072a49c94ea4e2cbfa54158159b31f61 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0512_CLRL_scenestreamer-full-large_gpu02.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +EXP_NAME="0512_CLRL_scenestreamer-full-large" +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" + +# ✅ Manually specify available GPU IDs +GPUS=(0 1 2 3) + +# ✅ Define seeds for experiments +SEEDS=(400 500 600 700) + +# Total number of experiments +echo "Total experiments to run: ${#SEEDS[@]}" +echo "Using GPUs: ${GPUS[@]}" + +for ((i=0; i<${#SEEDS[@]}; i++)); do + GPU_INDEX=$((i % ${#GPUS[@]})) + GPU_ID=${GPUS[$GPU_INDEX]} + SEED=${SEEDS[i]} + SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" + mkdir -p "${SAVE_PATH}" + + echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}." 
+ + CUDA_VISIBLE_DEVICES=${GPU_ID} \ + python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=100000 \ + --horizon=100 \ + --model_name=scenestreamer-full-large \ + --eval_horizon=100 \ + --closed_loop \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --closed_loop_generator=SceneStreamer \ + --num_eval_envs=1 \ + --eval_ep=100 \ + --wandb \ + > "${EXP_NAME}_seed${SEED}.log" 2>&1 & + +done + +echo "All ${#SEEDS[@]} experiments launched." diff --git a/scenestreamer/rl_train/scripts/0513_OLRL_scenestreamer-full-large_500-2000.sh b/scenestreamer/rl_train/scripts/0513_OLRL_scenestreamer-full-large_500-2000.sh new file mode 100644 index 0000000000000000000000000000000000000000..8a59a1f441fdc53a9e204cecb1165234678da628 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0513_OLRL_scenestreamer-full-large_500-2000.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# Usage: bash run_td3.sh +SEED=$1 + +if [ -z "$SEED" ]; then + echo "Error: Must provide seed as first argument." 
+ exit 1 +fi + +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" +EXP_NAME="0513_OLRL_scenestreamer-full-large_500-2000" +SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" +mkdir -p ${SAVE_PATH} + + +python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir="/bigdata/yuxin/mixed_Waymo_SceneStreamer_scenestreamer-full-large_scenarionet_waymo_training_500_4_mode" \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=100000 \ + --horizon=100 \ + --model_name="invalid" \ + --eval_horizon=100 \ + --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \ + --num_eval_envs=1 \ + --eval_ep=100 \ + --wandb \ + > ${EXP_NAME}_seed${SEED}.log 2>&1 \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0513_OLRL_scenestreamer-full-large_500-2000_slurmscript.sh b/scenestreamer/rl_train/scripts/0513_OLRL_scenestreamer-full-large_500-2000_slurmscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..612e9b50178c071510bfdd4fab7eef9e2185c7b5 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0513_OLRL_scenestreamer-full-large_500-2000_slurmscript.sh @@ -0,0 +1,16 @@ +#!/bin/bash +#SBATCH --job-name=0513_OLRL_scenestreamer-full-large_500-2000 +#SBATCH --output=0513_OLRL_scenestreamer-full-large_500-2000_%a.out +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --partition=gpu06 +#SBATCH --array=0-3 + +# Launch 4 jobs + + +# Map array ID to seed +SEED=$((SLURM_ARRAY_TASK_ID * 100)) + +echo "Launching TD3 training with seed ${SEED}" +bash 0513_OLRL_scenestreamer-full-large_500-2000.sh ${SEED} diff --git a/scenestreamer/rl_train/scripts/0514_OLRL_waymo500.sh b/scenestreamer/rl_train/scripts/0514_OLRL_waymo500.sh new file mode 100644 index 0000000000000000000000000000000000000000..45651149758d70017267a3f488cd5908f8193cb1 --- /dev/null +++ 
b/scenestreamer/rl_train/scripts/0514_OLRL_waymo500.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# Usage: bash run_td3.sh +SEED=$1 + +if [ -z "$SEED" ]; then + echo "Error: Must provide seed as first argument." + exit 1 +fi + +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" +EXP_NAME="0514_OLRL_waymo500" +SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" +mkdir -p ${SAVE_PATH} + + +python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir="/bigdata/yuxin/scenarionet_waymo_training_500" \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=100000 \ + --horizon=100 \ + --model_name="invalid" \ + --eval_horizon=100 \ + --source_data="/bigdata/yuxin/scenarionet_waymo_training_500" \ + --num_eval_envs=1 \ + --eval_ep=100 \ + --wandb \ + > ${EXP_NAME}_seed${SEED}.log 2>&1 \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0514_OLRL_waymo500_slurmscript.sh b/scenestreamer/rl_train/scripts/0514_OLRL_waymo500_slurmscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..74f30e624fbb80200aa072dd5b7e7332907d8407 --- /dev/null +++ b/scenestreamer/rl_train/scripts/0514_OLRL_waymo500_slurmscript.sh @@ -0,0 +1,14 @@ +#!/bin/bash +#SBATCH --job-name=0514_OLRL_waymo500 +#SBATCH --output=0514_OLRL_waymo500_%a.out +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --partition=gpu02 +#SBATCH --array=0-7 + + +# Map array ID to seed +SEED=$((SLURM_ARRAY_TASK_ID * 100)) + +echo "Launching TD3 training with seed ${SEED}" +bash 0514_OLRL_waymo500.sh ${SEED} diff --git a/scenestreamer/rl_train/scripts/0521_CLRL_scenestreamer-full-xl_gpu05.sh b/scenestreamer/rl_train/scripts/0521_CLRL_scenestreamer-full-xl_gpu05.sh new file mode 100644 index 0000000000000000000000000000000000000000..34375897204d8c58f066724c636fcceb4d9f6e36 --- /dev/null +++ 
#!/bin/bash
# Launch closed-loop TD3 training with the SceneStreamer generator
# (model: scenestreamer-full-xl). One background job is started per seed,
# assigned round-robin to the GPUs listed below. Each job logs to
# "<EXP_NAME>_seed<SEED>.log" in the current working directory.

# The experiment name is this script's file name without its extension.
filename=$(basename "$0")
EXP_NAME="${filename%.*}"

# Resolve the trainer relative to this script instead of the caller's working
# directory, so the launcher behaves the same no matter where it is invoked.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TRAIN_SCRIPT="${SCRIPT_DIR}/../train/train_td3.py"

SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps"

# Manually specify available GPU IDs
GPUS=(0 1 2 3)

# Define seeds for experiments
SEEDS=(400 500 600 700)

# Total number of experiments
echo "Total experiments to run: ${#SEEDS[@]}"
# [*] joins the array into one word; [@] inside a string splits the argument.
echo "Using GPUs: ${GPUS[*]}"

for ((i=0; i<${#SEEDS[@]}; i++)); do
    GPU_INDEX=$((i % ${#GPUS[@]}))
    GPU_ID=${GPUS[$GPU_INDEX]}
    SEED=${SEEDS[i]}
    SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}"
    mkdir -p "${SAVE_PATH}"

    echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}."

    CUDA_VISIBLE_DEVICES="${GPU_ID}" \
    python "${TRAIN_SCRIPT}" \
        --exp_name="${EXP_NAME}" \
        --wandb_project='scgen' \
        --wandb_team='drivingforce' \
        --seed="${SEED}" \
        --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \
        --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \
        --save_path="${SAVE_PATH}" \
        --training_step=2000000 \
        --lr=1e-4 \
        --eval_freq=100000 \
        --horizon=100 \
        --model_name=scenestreamer-full-xl \
        --eval_horizon=100 \
        --closed_loop \
        --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \
        --closed_loop_generator=SceneStreamer \
        --num_eval_envs=1 \
        --eval_ep=100 \
        --wandb \
        > "${EXP_NAME}_seed${SEED}.log" 2>&1 &

done

# Jobs keep running in the background after this script returns.
echo "All ${#SEEDS[@]} experiments launched."
#!/bin/bash
# Launch closed-loop TD3 training with the SceneStreamer generator
# (model: scenestreamer-full-xl) with the --no_adaptive flag set. One
# background job is started per seed, assigned round-robin to the GPUs listed
# below. Each job logs to "<EXP_NAME>_seed<SEED>.log" in the current working
# directory.

# The experiment name is this script's file name without its extension.
filename=$(basename "$0")
EXP_NAME="${filename%.*}"

# Resolve the trainer relative to this script instead of the caller's working
# directory, so the launcher behaves the same no matter where it is invoked.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TRAIN_SCRIPT="${SCRIPT_DIR}/../train/train_td3.py"

SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps"

# Manually specify available GPU IDs
GPUS=(4 5 6 7)

# Define seeds for experiments
SEEDS=(400 500 600 700)

# Total number of experiments
echo "Total experiments to run: ${#SEEDS[@]}"
# [*] joins the array into one word; [@] inside a string splits the argument.
echo "Using GPUs: ${GPUS[*]}"

for ((i=0; i<${#SEEDS[@]}; i++)); do
    GPU_INDEX=$((i % ${#GPUS[@]}))
    GPU_ID=${GPUS[$GPU_INDEX]}
    SEED=${SEEDS[i]}
    SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}"
    mkdir -p "${SAVE_PATH}"

    echo "Launching seed ${SEED} on GPU ${GPU_ID}. Saving to ${SAVE_PATH}."

    CUDA_VISIBLE_DEVICES="${GPU_ID}" \
    python "${TRAIN_SCRIPT}" \
        --exp_name="${EXP_NAME}" \
        --wandb_project='scgen' \
        --wandb_team='drivingforce' \
        --seed="${SEED}" \
        --data_dir=/bigdata/yuxin/scenarionet_waymo_training_500 \
        --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \
        --save_path="${SAVE_PATH}" \
        --training_step=2000000 \
        --lr=1e-4 \
        --eval_freq=100000 \
        --horizon=100 \
        --model_name=scenestreamer-full-xl \
        --eval_horizon=100 \
        --closed_loop \
        --source_data=/bigdata/yuxin/scenarionet_waymo_training_500 \
        --closed_loop_generator=SceneStreamer \
        --num_eval_envs=1 \
        --eval_ep=100 \
        --wandb \
        --no_adaptive \
        > "${EXP_NAME}_seed${SEED}.log" 2>&1 &

done

# Jobs keep running in the background after this script returns.
echo "All ${#SEEDS[@]} experiments launched."
diff --git a/scenestreamer/rl_train/scripts/0521_OLRL_waymo500.sh b/scenestreamer/rl_train/scripts/0521_OLRL_waymo500.sh new file mode 100644 index 0000000000000000000000000000000000000000..02805b4c5cb140e1733d00470b53b678818c3aef --- /dev/null +++ b/scenestreamer/rl_train/scripts/0521_OLRL_waymo500.sh @@ -0,0 +1,41 @@ +#!/bin/bash + + +filename=$(basename "$0") +extension="${filename##*.}" +EXP_NAME="${filename%.*}" + + + +# Usage: bash run_td3.sh +SEED=$1 + +if [ -z "$SEED" ]; then + echo "Error: Must provide seed as first argument." + exit 1 +fi + +SAVE_PREFIX="/bigdata/zhenghao/scenestreamer/rl_exps" +SAVE_PATH="${SAVE_PREFIX}/${EXP_NAME}_seed${SEED}" +mkdir -p ${SAVE_PATH} + + +python ../train/train_td3.py \ + --exp_name=${EXP_NAME} \ + --wandb_project='scgen' \ + --wandb_team='drivingforce' \ + --seed=${SEED} \ + --data_dir="/bigdata/yuxin/scenarionet_waymo_training_500" \ + --eval_data_dir=/bigdata/yuxin/scenarionet_waymo_validation_100 \ + --save_path=${SAVE_PATH} \ + --training_step=2000000 \ + --lr=1e-4 \ + --eval_freq=100000 \ + --horizon=100 \ + --model_name="invalid" \ + --eval_horizon=100 \ + --source_data="/bigdata/yuxin/scenarionet_waymo_training_500" \ + --num_eval_envs=1 \ + --eval_ep=100 \ + --wandb \ + > ${EXP_NAME}_seed${SEED}.log 2>&1 \ No newline at end of file diff --git a/scenestreamer/rl_train/scripts/0521_OLRL_waymo500_slurmscript.sh b/scenestreamer/rl_train/scripts/0521_OLRL_waymo500_slurmscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..5a5bbfffd2d320fa14d3648432998828e32d92cc --- /dev/null +++ b/scenestreamer/rl_train/scripts/0521_OLRL_waymo500_slurmscript.sh @@ -0,0 +1,14 @@ +#!/bin/bash +#SBATCH --job-name=0521_OLRL_waymo500 +#SBATCH --output=0521_OLRL_waymo500_%a.out +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --partition=gpu05 +#SBATCH --array=0-7 + + +# Map array ID to seed +SEED=$((SLURM_ARRAY_TASK_ID * 100)) + +echo "Launching TD3 training with seed ${SEED}" +bash 0521_OLRL_waymo500.sh 
def get_filenames(folder_path, prefix="sd_"):
    """Return the paths of all regular files in ``folder_path`` whose name starts with ``prefix``.

    The result is sorted so it does not depend on the enumeration order of
    ``os.listdir``; a caller that shuffles the list afterwards therefore gets
    a reproducible order under a fixed random seed.
    """
    candidates = (
        os.path.join(folder_path, name)
        for name in os.listdir(folder_path)
        if name.startswith(prefix)
    )
    return sorted(path for path in candidates if os.path.isfile(path))


class ScenarioOnlineEnvWrapper(ScenarioOnlineEnv):
    """Online scenario env that mixes dataset scenarios with generator-produced ones.

    On every ``reset`` the next scenario file is loaded and handed to the
    generator. With a probability that grows linearly with training progress
    (from 0 up to ``min_prob``) the generator is asked for a new scenario,
    which then replaces the dataset scenario for that episode.
    """

    @classmethod
    def default_config(cls):
        """Extend the parent config with the scheduling options used by this wrapper."""
        config = super().default_config()
        config.update(
            {
                "total_timesteps": 2_000_000,
                "min_prob": 0.5,
                "store_map": False
            }
        )
        return config

    def __init__(self, generator=None, config=None, no_adaptive=False):
        """Create the env.

        :param generator: scenario generator; must provide ``ego_traj``,
            ``before_episode``, ``after_episode``, ``generate`` and
            ``set_no_adaptive`` (all of them are used by this class).
        :param config: env config; must contain ``"data_directory"``.
        :param no_adaptive: if True, ``generator.set_no_adaptive(True)`` is called.
        :raises ValueError: if ``config`` lacks ``"data_directory"`` or
            ``no_adaptive`` is requested without a generator.
        """
        try:
            self.scenario_dataset = config["data_directory"]
        except (TypeError, KeyError) as err:
            raise ValueError('config must be provided and contain "data_directory"') from err

        super(ScenarioOnlineEnvWrapper, self).__init__(config)

        if no_adaptive:
            if generator is None:
                raise ValueError("no_adaptive=True requires a generator")
            generator.set_no_adaptive(True)

        self.generator = generator
        self.scenario_index = 0
        self.all_scenario_files = get_filenames(self.scenario_dataset)
        random.shuffle(self.all_scenario_files)

        self._total_timesteps = self.config["total_timesteps"]
        self.num_timesteps = 0
        self.min_prob = self.config["min_prob"]
        self.in_raw_scenario = False

        assert self.config["store_map"] is False, "store_map should be False in ScenarioOnlineEnvWrapper"

    def set_total_time_steps(self, total_steps):
        """Disabled on purpose: the total step count must come from ``config["total_timesteps"]``."""
        raise ValueError("please don't use this function, set config please...")

    def set_timestep(self, step):
        """Record the trainer's current global step; it drives the generation probability."""
        self.num_timesteps = step

    def reset(self, seed: Union[None, int] = None):
        """Load the next scenario (dataset or generated) and reset the env.

        ``seed`` is accepted for interface compatibility but is not forwarded
        to the parent ``reset`` (this matches the previous behavior).

        :return: ``(observation, info)``; ``info["in_raw_scenario"]`` tells
            whether the unmodified dataset scenario is being played.
        """
        if self.generator.ego_traj:  # first time reset() does not need after_episode
            self.generator.after_episode()  # update current ego traj for current scenario

        # Check before indexing: otherwise a short file list surfaces as an
        # IndexError below instead of this explicit message.
        assert len(
            self.all_scenario_files) == self.num_scenarios, "The number of scenarios is not equal to the number of scenario files"

        original_SD_path = self.all_scenario_files[self.scenario_index % self.num_scenarios]
        scenario_description = read_scenario_data(original_SD_path)

        self.set_scenario(scenario_description)  # by default use the original SD

        self.generator.before_episode(self)  # parse GT info if not done before

        assert self.engine.data_manager.current_scenario['id'] == scenario_description['id'], "Scenario ID mismatch"

        progress = max(min(self.num_timesteps / self._total_timesteps, 1), 0)
        prob = self.min_prob * progress  # (0 -> min_prob)

        status = (
            "Current step: {}, Total steps: {}, Progress: {:.2f}, Probability to generate: {:.2f}/{}. "
            "Current scenario index {}/{}. ".format(
                self.num_timesteps, self._total_timesteps, progress, prob, self.min_prob, self.scenario_index,
                self.num_scenarios
            )
        )

        if np.random.random() < prob:
            print(status + "Generating...")
            new_SD = self.generator.generate()
            if new_SD is not None:
                self.set_scenario(new_SD)
            # If generation returned nothing, the dataset scenario set above is
            # still loaded, so the episode is a raw one.
            self.in_raw_scenario = new_SD is None
        else:
            print(status + "Use original scenario.")
            self.in_raw_scenario = True

        self.scenario_index += 1
        o, i = super().reset()
        i["in_raw_scenario"] = self.in_raw_scenario
        return o, i

    def step(self, *args, **kwargs):
        """Step the parent env and tag the info dict (last element) with ``in_raw_scenario``."""
        ret = super().step(*args, **kwargs)
        ret[-1]["in_raw_scenario"] = self.in_raw_scenario
        return ret
class CustomizedTD3(TD3):
    """TD3 that is built without the SB3 monitor wrapper and logs extra episode-info keys."""

    def __init__(
        self,
        policy: Union[str, type[TD3Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-3,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        policy_delay: int = 2,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        # OffPolicyAlgorithm.__init__ is called directly (TD3.__init__ is
        # bypassed) so that monitor_wrapper=False can be passed.
        OffPolicyAlgorithm.__init__(
            self,
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            sde_support=False,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(spaces.Box,),
            support_multi_env=True,
            monitor_wrapper=False,  # PZH: We do not need to use the original monitor wrapper
        )

        self.policy_delay = policy_delay
        self.target_noise_clip = target_noise_clip
        self.target_policy_noise = target_policy_noise

        if _init_setup_model:
            self._setup_model()

    def _dump_logs(self) -> None:
        """
        Write log.

        In addition to the reward/length means, every non-string, non-None key
        of the most recent episode info is logged as ``rollout/<key>_mean``
        over the buffered episodes, and keys starting with ``total`` are
        logged as ``rollout/<key>_sum`` using the most recent value.
        """
        time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
        fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
        self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
        if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
            self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
            self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))

            # PZH: We modify here to record environment information
            latest_ep_info = self.ep_info_buffer[-1]
            for k, v in latest_ep_info.items():
                if k in ("r", "l") or isinstance(v, str) or v is None:
                    continue
                # Older buffered episodes may not carry this key (or carry a
                # non-numeric placeholder); average only the usable entries
                # instead of failing with a KeyError.
                values = [
                    ep_info[k] for ep_info in self.ep_info_buffer
                    if ep_info.get(k) is not None and not isinstance(ep_info[k], str)
                ]
                self.logger.record("rollout/{}_mean".format(k), safe_mean(values))
            for k, v in latest_ep_info.items():
                if k.startswith("total"):
                    self.logger.record("rollout/{}_sum".format(k), v)

        self.logger.record("time/fps", fps)
        self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
        self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
        if self.use_sde:
            self.logger.record("train/std", (self.actor.get_std()).mean().item())
        if len(self.ep_success_buffer) > 0:
            self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
        # Pass the number of timesteps for tensorboard
        self.logger.dump(step=self.num_timesteps)


class Closed_Loop_TD3(CustomizedTD3):
    """TD3 whose rollout loop keeps the (single) training env's scenario generator in sync.

    After every env step the wrapped env is told the current global step via
    ``set_timestep`` and its generator records the ego history via
    ``log_ego_history``.
    """

    def __init__(self, *args, training_dataset=None, **kwargs):
        # ``training_dataset`` is accepted for call-site compatibility; it is not used here.
        super(Closed_Loop_TD3, self).__init__(*args, **kwargs)
        self.current_step_info = None
        self.current_step_done = None

    def learn(
        self: SelfOffPolicyAlgorithm,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "run",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfOffPolicyAlgorithm:
        """Alternate rollout collection and gradient updates until ``total_timesteps`` is reached."""
        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            callback,
            reset_num_timesteps,
            tb_log_name,
            progress_bar,
        )

        callback.on_training_start(locals(), globals())

        assert self.env is not None, "You must set the environment before calling learn()"
        assert isinstance(self.train_freq, TrainFreq)  # check done in _setup_learn()

        while self.num_timesteps < total_timesteps:
            rollout = self.collect_rollouts(
                self.env,
                train_freq=self.train_freq,
                action_noise=self.action_noise,
                callback=callback,
                learning_starts=self.learning_starts,
                replay_buffer=self.replay_buffer,
                log_interval=log_interval,
            )

            if not rollout.continue_training:
                break

            if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
                # If no `gradient_steps` is specified,
                # do as many gradients steps as steps performed during the rollout
                gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
                # Special case when the user passes `gradient_steps=0`
                if gradient_steps > 0:
                    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)

        callback.on_training_end()
        return self

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        train_freq: TrainFreq,
        replay_buffer: ReplayBuffer,
        action_noise: Optional[ActionNoise] = None,
        learning_starts: int = 0,
        log_interval: Optional[int] = None,
    ) -> RolloutReturn:
        """
        Collect experiences and store them into a ``ReplayBuffer``.

        :param env: The training environment; must expose exactly one
            sub-environment through ``env.envs``.
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param train_freq: How much experience to collect
            by doing rollouts of current policy.
            Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
            or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
            with ``<n>`` being an integer greater than 0.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param replay_buffer:
        :param log_interval: Log data every ``log_interval`` episodes
        :return:
        """
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."
        # changes for closed loop: validate up front, before env.envs[0] is touched.
        assert len(env.envs) == 1

        if env.num_envs > 1:
            assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."

        if self.use_sde:
            self.actor.reset_noise(env.num_envs)

        callback.on_rollout_start()
        continue_training = True

        while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
            if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.actor.reset_noise(env.num_envs)

            # Select action randomly or according to policy
            actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)

            # Rescale and perform action
            new_obs, rewards, dones, infos = env.step(actions)

            # changes for closed loop: publish the step counter (value before
            # this step is counted) and let the generator log the ego history.
            closed_loop_env = env.envs[0].unwrapped
            closed_loop_env.set_timestep(self.num_timesteps)
            closed_loop_env.generator.log_ego_history()
            self.current_step_info = infos
            self.current_step_done = dones

            self.num_timesteps += env.num_envs
            num_collected_steps += 1

            # Give access to local variables
            callback.update_locals(locals())
            # Only stop training if return value is False, not when it is None.
            if not callback.on_step():
                return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes,
                                     continue_training=False)

            # Retrieve reward and episode length if using Monitor wrapper
            self._update_info_buffer(infos, dones)

            # Store data in replay buffer (normalized action and unnormalized observation)
            self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones,
                                   infos)  # type: ignore[arg-type]

            self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

            # For DQN, check if the target network should be updated
            # and update the exploration schedule
            # For SAC/TD3, the update is dones as the same time as the gradient update
            # see https://github.com/hill-a/stable-baselines/issues/900
            self._on_step()

            for idx, done in enumerate(dones):
                if done:
                    # Update stats
                    num_collected_episodes += 1
                    self._episode_num += 1

                    if action_noise is not None:
                        kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                        action_noise.reset(**kwargs)

                    # Log training infos
                    if log_interval is not None and self._episode_num % log_interval == 0:
                        self._dump_logs()

        callback.on_rollout_end()

        return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
def set_seed(seed):
    """Seed every random number generator used during training/evaluation."""
    set_random_seed(seed)
    random.seed(seed)  # Python random
    np.random.seed(seed)  # NumPy random
    torch.manual_seed(seed)  # PyTorch random
    torch.cuda.manual_seed_all(seed)  # PyTorch random (all GPU devices)
    torch.backends.cudnn.deterministic = True  # Use deterministic CuDNN operations
    torch.backends.cudnn.benchmark = False  # Disable CuDNN benchmark for determinism


def evaluate_info(all_info):
    """Average the final-step info dicts of a set of evaluation episodes.

    :param all_info: mapping from episode/scenario key to an info dict with
        the keys ``episode_reward``, ``cost``, ``route_completion``, ``crash``
        and ``episode_length``.
    :return: dict with ``num_scenario`` and the per-episode averages
        (``avg_rewards``, ``avg_costs``, ``avg_completion``,
        ``avg_collisions``, ``avg_length``). With no episodes the averages are
        NaN rather than a ZeroDivisionError. The dict is also printed.
    """
    episodes = list(all_info.values())
    num_scenario = len(episodes)

    def _mean(values):
        # NaN marks "no data" so it cannot be mistaken for a measured zero.
        return sum(values) / num_scenario if num_scenario else float("nan")

    result = {
        "num_scenario": num_scenario,
        "avg_rewards": _mean([ep['episode_reward'] for ep in episodes]),
        "avg_costs": _mean([ep['cost'] for ep in episodes]),
        "avg_completion": _mean([ep['route_completion'] for ep in episodes]),
        # Any truthy crash flag counts as one collision episode.
        "avg_collisions": _mean([1 if ep['crash'] else 0 for ep in episodes]),
        "avg_length": _mean([ep['episode_length'] for ep in episodes]),
    }
    print(result)

    return result
class CustomMonitor(Monitor):
    """Monitor that guarantees the driving metrics exist in the final-step info dict."""

    def __init__(self, env, buffer_size=256, info_keywords=None):
        # Monitor iterates over info_keywords at episode end, so None must
        # become an empty tuple.
        super().__init__(env, info_keywords=info_keywords or ())
        self.ep_info_buffer = deque(maxlen=buffer_size)  # Initialize buffer

    def step(self, action):
        """Step the env; on episode end fill missing metrics with 0."""
        obs, reward, tm, tc, info = super().step(action)

        if tm or tc:
            info.setdefault("route_completion", 0)
            info.setdefault("cost", 0)
            info.setdefault("crash", 0)

        return obs, reward, tm, tc, info


def create_env(config, need_monitor=False):
    """Build a ``ScenarioEnv``; optionally wrap it in ``CustomMonitor``."""
    env = ScenarioEnv(config=config)
    if need_monitor:
        info_keywords = ("episode_reward", "episode_length", "route_completion", "cost", "crash")
        env = CustomMonitor(env, info_keywords=info_keywords)  # Pass the custom metrics

    return env


# Upper bound on consecutive failing env.reset() calls before giving up.
_MAX_RESET_ATTEMPTS = 100


def _load_model(env, checkpoint_path):
    """Load TD3 from ``checkpoint_path``; without one, build a fresh (untrained) TD3."""
    if checkpoint_path:
        return TD3.load(checkpoint_path)
    return TD3(
        "MlpPolicy",
        env,
        action_noise=None,
        learning_rate=1e-4,
        learning_starts=200,
        batch_size=1024,
        tau=0.005,
        gamma=0.99,
        train_freq=1,
        gradient_steps=1,
        device="auto",  # use CUDA when available instead of failing on CPU-only hosts
        seed=0,
        verbose=2,
        tensorboard_log="td3_rl",
    )


def _reset_with_retry(env, max_attempts=_MAX_RESET_ATTEMPTS):
    """Call ``env.reset()`` until it succeeds, at most ``max_attempts`` times.

    :raises RuntimeError: if every attempt fails (chained to the last error).
    """
    last_error = None
    for _ in range(max_attempts):
        try:
            return env.reset()
        except Exception as err:  # KeyboardInterrupt/SystemExit still propagate
            last_error = err
            logging.warning("env.reset() failed (%r); retrying", err)
    raise RuntimeError("env.reset() failed {} times in a row".format(max_attempts)) from last_error


def _run_episode(env, model):
    """Roll out one deterministic episode and return its post-processed final info dict."""
    obs, info = _reset_with_retry(env)

    done = False
    collision = False
    while not done:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, tm, tc, info = env.step(action)

        # Remember a vehicle crash from any step, not only the last one.
        if env.vehicle.crash_vehicle:
            collision = True

        done = tm or tc or info['arrive_dest'] or info['max_step']

    if collision:
        info['crash'] = True

    info['route_completion'] = max(0, min(1, info['route_completion']))  # Ensure range [0,1]
    if info['arrive_dest']:
        info['route_completion'] = 1

    return info


def eval_policy(config_test, checkpoint_path=None, eval_episodes=100):
    """Evaluate a policy for ``eval_episodes`` episodes and return the averaged metrics."""
    return eval_policy_formal(
        config_test, checkpoint_path=checkpoint_path, eval_episodes=eval_episodes, episodes_per_env=1
    )


def eval_policy_formal(config_test, checkpoint_path=None, eval_episodes=100, episodes_per_env=5):
    """Evaluate a policy for ``eval_episodes * episodes_per_env`` episodes.

    :param config_test: config passed to ``create_env``.
    :param checkpoint_path: TD3 checkpoint; if falsy an untrained TD3 is used.
    :return: the dict produced by ``evaluate_info`` over all final-step infos.
    """
    env = create_env(config_test)
    try:
        model = _load_model(env, checkpoint_path)
        total_episodes = eval_episodes * episodes_per_env
        all_info = {ep_count: _run_episode(env, model) for ep_count in range(total_episodes)}
        results = evaluate_info(all_info)
    finally:
        env.close()  # release the simulator even if an episode fails

    return results


def eval_ckpt_for_seeds(config_test, ckpt_root_dir, step_num):
    """Evaluate the ``seed_<seed>_<step_num>_steps.zip`` checkpoints found in ``ckpt_root_dir``.

    :return: ``{"average": ..., "std": ...}`` over the seeds that had a
        checkpoint; the std is the sample std (ddof=1), or 0.0 for one seed.
    :raises FileNotFoundError: if no checkpoint exists for any seed.
    """
    seeds = [0, 100, 200, 300, 400, 500, 600, 700]
    res_all_seeds = {}

    for seed in seeds:
        ckpt_path = os.path.join(ckpt_root_dir, f"seed_{seed}_{int(step_num)}_steps.zip")
        if not os.path.exists(ckpt_path):
            print(f"Checkpoint not found: {ckpt_path}")
            continue

        print(f"Evaluating checkpoint: {ckpt_path}")
        results = eval_policy_formal(config_test=config_test, checkpoint_path=ckpt_path)
        print(f"Results: {results}")

        res_all_seeds[seed] = results

    if not res_all_seeds:
        raise FileNotFoundError(
            f"No checkpoint for step {int(step_num)} found in {ckpt_root_dir} for seeds {seeds}"
        )

    avg_results = {}
    std_results = {}
    per_seed = list(res_all_seeds.values())

    # avg results over all seeds and write to avg_results
    for key in per_seed[0]:
        values = np.array([res[key] for res in per_seed])
        avg_results[key] = sum(res[key] for res in per_seed) / len(per_seed)
        # The sample std is undefined for a single seed; report 0.0 instead of NaN.
        std_results[key] = np.std(values, ddof=1) if len(per_seed) > 1 else 0.0

    print(f"Average results over all seeds: {avg_results}")
    print(f"Standard deviation over all seeds: {std_results}")

    return {"average": avg_results, "std": std_results}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True, help="The dir for held-out envs.")
    parser.add_argument("--ckpt_dir", type=str, required=True, help="The dir checkpoints.")
    parser.add_argument("--eval_horizon", type=int, required=True, help="Eval horizon in MD.")
    parser.add_argument("--ckpt_steps", type=int, required=True,
                        help="number of training steps of the ckpt to eval.")
    args = parser.parse_args()

    TEST_HORIZON = args.eval_horizon
    TEST_DIR = args.data_dir
    ckpt_root_dir = args.ckpt_dir
    ckpt_steps = args.ckpt_steps

    config_test = dict(
        use_render=False,
        manual_control=False,
        show_interface=False,
        data_directory=TEST_DIR,  # scenarionet_waymo_training_500
        start_scenario_index=0,
        num_scenarios=100,
        agent_policy=EnvInputPolicy,
        force_render_fps=10,
        reactive_traffic=False,
        sequential_seed=True,
        # force_reuse_object_name = True,
        horizon=TEST_HORIZON,
        out_of_route_done=False,
        crash_vehicle_done=False,
        crash_object_done=False,
        crash_human_done=False,
        relax_out_of_road_done=False,
    )

    # Alternative (CAT-style) evaluation config; defined but not used below.
    config_test_CAT = dict(
        data_directory=TEST_DIR,
        start_scenario_index=0,
        num_scenarios=100,
        sequential_seed=False,
        agent_policy=EnvInputPolicy,
        force_reuse_object_name=True,
        horizon=50,
        no_light=True,
        no_static_vehicles=True,
        reactive_traffic=False,
        vehicle_config=dict(
            lidar=dict(num_lasers=30, distance=50, num_others=3),
            side_detector=dict(num_lasers=30),
            lane_line_detector=dict(num_lasers=12)),

        # ===== Reward Scheme =====
        success_reward=10.0,
        out_of_road_penalty=10.0,
        crash_vehicle_penalty=1,
        crash_object_penalty=1.0,
        driving_reward=1.0,
        # speed_reward=0.1,
        # use_lateral_reward=False,

        # ===== Cost Scheme =====
        crash_vehicle_cost=1.0,
        crash_object_cost=1.0,
        out_of_road_cost=1.0,

        # ===== Termination Scheme =====
        out_of_route_done=False,
        crash_vehicle_done=False,
        relax_out_of_road_done=True,
    )

    eval_ckpt_for_seeds(config_test, ckpt_root_dir, ckpt_steps)
relax_out_of_road_done=False, + ) + + config_test_CAT = dict( + data_directory=TEST_DIR, + start_scenario_index = 0, + num_scenarios=100, + sequential_seed = False, + agent_policy=EnvInputPolicy, + force_reuse_object_name = True, + horizon = 50, + no_light = True, + no_static_vehicles = True, + reactive_traffic = False, + vehicle_config=dict( + lidar = dict(num_lasers=30,distance=50, num_others=3), + side_detector = dict(num_lasers=30), + lane_line_detector = dict(num_lasers=12)), + + # ===== Reward Scheme ===== + success_reward=10.0, + out_of_road_penalty=10.0, + crash_vehicle_penalty=1, + crash_object_penalty=1.0, + driving_reward=1.0, + # speed_reward=0.1, + # use_lateral_reward=False, + + # ===== Cost Scheme ===== + crash_vehicle_cost=1.0, + crash_object_cost=1.0, + out_of_road_cost=1.0, + + # ===== Termination Scheme ===== + out_of_route_done=False, + crash_vehicle_done=False, + relax_out_of_road_done=True, + ) + + + eval_ckpt_for_seeds(config_test, ckpt_root_dir, ckpt_steps) + + diff --git a/scenestreamer/rl_train/train/scenestreamer_rl_generator.py b/scenestreamer/rl_train/train/scenestreamer_rl_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c3410528323aaf5e8016915b9b29f69cf90501e3 --- /dev/null +++ b/scenestreamer/rl_train/train/scenestreamer_rl_generator.py @@ -0,0 +1,419 @@ +import copy + +import numpy as np +import torch + +from scenestreamer import utils +from scenestreamer.tokenization import get_tokenizer + + +def _convert_to_SD_types(data_dict_agent_type): + type_map = { + 0: "UNSET", + 1: "VEHICLE", + 2: "PEDESTRIAN", + 3: "CYCLIST", + 4: "OTHER" + } + + SD_type_name = type_map[data_dict_agent_type] + + return SD_type_name + + +def transform_to_global_coordinate(data_dict): + map_center = data_dict["metadata/map_center"].reshape(-1, 1, 3) # (1,1,3) + map_heading = data_dict["metadata/map_heading"].reshape(-1, 1, 1) + assert (map_heading == 0).all() + + expanded_mask = 
data_dict["decoder/reconstructed_valid_mask"][:, :, None] + data_dict["decoder/reconstructed_position"] += map_center[:, :, :2] * expanded_mask + return data_dict + + +def overwrite_to_scenario_description_new_agent(output_dict_mode, original_SD, ooi=None, type_convert_map=None, + track_length=91): + """ + Write all tracks in OOI to original_SD (discard all old tracks) + """ + new_SD = copy.deepcopy(original_SD) + if ooi is None: + ooi = output_dict_mode['decoder/agent_id'] # overwrite all agents + + new_SD['metadata']['objects_of_interest'] = [] + new_SD['metadata']['tracks_to_predict'] = {} + new_SD['tracks'] = {} + + sdc_track_name = str(output_dict_mode["decoder/track_name"][output_dict_mode["decoder/sdc_index"]]) + + ego_traj = original_SD['tracks'][sdc_track_name]['state']['position'].copy() + ego_traj[..., -1] = 0 # Reset Z axis to 0 + + ego_avg_pos = ego_traj[original_SD['tracks'][sdc_track_name]['state']['valid']][..., :2].mean(0) + all_avg_pos = output_dict_mode["decoder/reconstructed_position"][10, :].mean(axis=0) + dist = np.linalg.norm(all_avg_pos - ego_avg_pos) + assert dist < 500, f"Original SDC average position {ego_avg_pos} and new SDC average position {all_avg_pos} are not the same. Please check your code." 
+ print(f"Original SDC average position {ego_avg_pos} and all agents average position {all_avg_pos}") + + assert (output_dict_mode["decoder/sdc_index"] == 0) + for id in ooi: + + if id == 0: + new_SD['tracks'][sdc_track_name] = original_SD['tracks'][sdc_track_name] + new_SD['tracks'][sdc_track_name]['state']['position'] = ego_traj + new_SD['metadata']['objects_of_interest'].append(sdc_track_name) + + if sdc_track_name in original_SD['metadata']['tracks_to_predict']: + sdc_tracks_to_predict = original_SD['metadata']['tracks_to_predict'][sdc_track_name] + new_sdc_tracks_to_predict = { + 'difficulty': sdc_tracks_to_predict['difficulty'], + 'object_type': sdc_tracks_to_predict['object_type'], + 'track_id': sdc_tracks_to_predict['track_id'], + 'track_index': 0, + } + new_SD['metadata']['tracks_to_predict'][sdc_track_name] = new_sdc_tracks_to_predict + + else: + new_agent_track_name = str(output_dict_mode["decoder/track_name"][id]) + + agent_type = output_dict_mode['decoder/agent_type'][id] + if type_convert_map is not None: + new_agent_type = type_convert_map[agent_type] + else: + new_agent_type = _convert_to_SD_types(agent_type) + + new_SD['tracks'][new_agent_track_name] = {'state': {}, 'type': new_agent_type, 'metadata': {}} + + agent_traj = output_dict_mode["decoder/reconstructed_position"][:track_length, id, :2] + agent_heading = output_dict_mode["decoder/reconstructed_heading"][:track_length, id] + agent_vel = output_dict_mode["decoder/reconstructed_velocity"][:track_length, id] + agent_traj_mask = output_dict_mode["decoder/reconstructed_valid_mask"][:track_length, id] + + if "decoder/reconstructed_shape" in output_dict_mode and output_dict_mode[ + "decoder/reconstructed_shape"] is not None: + agent_length = output_dict_mode["decoder/reconstructed_shape"][:track_length, id, 0] + agent_width = output_dict_mode["decoder/reconstructed_shape"][:track_length, id, 1] + agent_height = output_dict_mode["decoder/reconstructed_shape"][:track_length, id, 2] + + else: + length 
= float(output_dict_mode['decoder/agent_shape'][10, id, 0]) + width = float(output_dict_mode['decoder/agent_shape'][10, id, 1]) + height = float(output_dict_mode['decoder/agent_shape'][10, id, 2]) + + agent_length = np.full((track_length,), length, dtype=float) + agent_width = np.full((track_length,), width, dtype=float) + agent_height = np.full((track_length,), height, dtype=float) + + agent_state = { + 'position': agent_traj, + 'velocity': agent_vel, + 'heading': agent_heading, + 'valid': agent_traj_mask, + 'length': agent_length, + 'width': agent_width, + 'height': agent_height, + } + + new_track = { + 'state': agent_state, + 'metadata': { + 'dataset': 'SCENESTREAMER', + 'object_id': new_agent_track_name, + 'track_length': track_length, + 'type': new_agent_type + }, + 'type': new_agent_type + } + + new_SD['tracks'][new_agent_track_name] = new_track + new_SD['metadata']['objects_of_interest'].append(new_agent_track_name) + + new_SD['metadata']['tracks_to_predict'][new_agent_track_name] = { + 'difficulty': 0, + 'object_type': new_agent_type, + 'track_id': new_agent_track_name, + 'track_index': int(id), + } + + # set new SID + new_SD['metadata']['sdc_id'] = sdc_track_name + new_SD['id'] = original_SD['id'] + if "id" in new_SD['metadata']: + new_SD['metadata']['id'] = original_SD['metadata']['id'] + new_SD['metadata']['scenario_id'] = original_SD['metadata']['scenario_id'] + new_SD['metadata']['dataset'] = 'SceneStreamer' + + return new_SD + + +def _recursive_check_type(obj, allow_types=(int, float, str, np.ndarray, dict, list, tuple, type(None), set), depth=0): + # copy MD's sanity check here + if isinstance(obj, dict): + for k, v in obj.items(): + print(f"checking key {k}") + assert isinstance(k, str), "Must use string to be dict keys" + _recursive_check_type(v, allow_types, depth=depth + 1) + + if isinstance(obj, list): + for v in obj: + _recursive_check_type(v, allow_types, depth=depth + 1) + + assert isinstance(obj, allow_types), "TypeError in key {}: Object 
type {} not allowed! ({})".format(obj, type(obj), + allow_types) + + if depth > 1000: + raise ValueError() + + +def overwrite_new_sdc_traj_to_SD(new_SD, new_ego_traj, new_ego_heading, new_ego_vel, track_length): + new_ego_mask = np.ones((track_length,), dtype=bool) + + assert new_ego_traj.shape[0] == new_ego_heading.shape[0] and new_ego_heading.shape[0] == new_ego_vel.shape[0] + traj_len = new_ego_traj.shape[0] + + if new_ego_traj.shape[0] < track_length: + padding_length = track_length - new_ego_traj.shape[0] + padding_traj = np.zeros((padding_length, 2)) # For positions + padding_heading = np.zeros((padding_length,)) # For heading + padding_vel = np.zeros((padding_length, 2)) # For velocity + + new_ego_traj = np.concatenate((new_ego_traj, padding_traj), axis=0) + new_ego_heading = np.concatenate((new_ego_heading, padding_heading), axis=0) + new_ego_vel = np.concatenate((new_ego_vel, padding_vel), axis=0) + + new_ego_mask[traj_len:] = 0 + + else: + new_ego_traj = new_ego_traj[:track_length] + new_ego_heading = new_ego_heading[:track_length] + new_ego_vel = new_ego_vel[:track_length] + + sdc_track_name = new_SD['metadata']['sdc_id'] + + original_ego_init_pos = new_SD['tracks'][sdc_track_name]['state']['position'][0][..., :2] + new_ego_init_pos = new_ego_traj[0] + dist = np.linalg.norm(original_ego_init_pos - new_ego_init_pos) + if dist > 1: + print( + f"ERROR?? Original SDC initial position {original_ego_init_pos} and new SDC initial position {new_ego_init_pos} are not the same. 
Please check your code.") + + new_SD['tracks'][sdc_track_name]['state']['position'][..., :2] = new_ego_traj + new_SD['tracks'][sdc_track_name]['state']['velocity'] = new_ego_vel + new_SD['tracks'][sdc_track_name]['state']['heading'] = new_ego_heading + new_SD['tracks'][sdc_track_name]['state']['valid'] = new_ego_mask + for agent_name in new_SD['tracks']: + new_SD['tracks'][agent_name]['state']['position'][..., -1] = 0 # Reset Z axis to 0 + + return new_SD + + +class SceneStreamerRLScenarioGenerator: + def __init__(self, model_name): + + from hydra import initialize_config_dir, compose + from scenestreamer.utils import REPO_ROOT + + if not model_name.endswith(".yaml"): + model_name += ".yaml" + # Load config with Hydra + config_path = REPO_ROOT / "cfgs" + with initialize_config_dir(config_dir=str(config_path), version_base=None): + config = compose(config_name=model_name) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + assert torch.cuda.is_available(), "CUDA is not available, please check your environment." + pl_model = utils.get_model(config=config, device=device) + + config = pl_model.config + config.PREPROCESSING.keep_all_data = True + # Set the maximum number of agents, so we can avoid making prediction for those static agents, thus saving GPU. + config.PREPROCESSING.MAX_AGENTS = 64 + + model_name = model_name.replace(".yaml", "") + assert model_name in [ + "scenestreamer-full-large", + "scenestreamer-full-xl", + "scenestreamer-base-large", + "scenestreamer-full-large-nors"], "Model name not supported. Please use scenestreamer-full-large or scenestreamer-base-large." 
+ if "scenestreamer-full" in model_name: + assert pl_model.config.EVALUATION.TG_REJECT_SAMPLING is True + assert pl_model.config.EVALUATION.TG_SDC_DISTANCE_MASKING is False + self.model_name = model_name + + tokenizer = get_tokenizer(config) + + self.config = config + self.tokenizer = tokenizer + self.pl_model = pl_model + + self.storage = {} + self.cur_adv_agent = None + + self.num_modes = 8 # for now one mode as CAT + self.adv_id = None + self.ego_traj = [] + self.ego_vel = [] + self.ego_heading = [] + + self.scenestreamer_generator = None + self.no_adaptive = False + + def set_no_adaptive(self, no_adaptive): + self.no_adaptive = no_adaptive + + + def GPT_AR(self, input_dict): + if self.scenestreamer_generator is None: + from scenestreamer.infer.scenestreamer_generator import SceneStreamerGenerator + self.scenestreamer_generator = SceneStreamerGenerator( + model=self.pl_model.model, + device=self.pl_model.device, + ) + with torch.no_grad(): + self.scenestreamer_generator.reset(new_data_dict=input_dict) + + if self.model_name in ["scenestreamer-full-large", "scenestreamer-full-large-nors", "scenestreamer-full-xl"]: + output_dict = self.scenestreamer_generator.generate_scenestreamer_initial_state_and_motion( + progress_bar=False, + teacher_forcing_sdc=True, + ) + elif self.model_name == "scenestreamer-base-large": + output_dict = self.scenestreamer_generator.generate_scenestreamer_motion( + progress_bar=False, + teacher_forcing_sdc=True, + ) + else: + raise ValueError("Model name not supported. 
Please use scenestreamer-full-large or scenestreamer-base-large.") + + return output_dict + + def before_episode(self, env=None, scenario_data=None): + + if env is not None: + self.env = env + sid = self.env.engine.data_manager.current_scenario["id"] + else: + assert scenario_data is not None + sid = scenario_data["id"] + + if sid not in self.storage: + # from scenario_net data to scenestreamer data_dict + + if scenario_data is None: + assert env is not None + scenario_data = self.env.engine.data_manager.current_scenario + + sdc_id = scenario_data['metadata']['sdc_id'] + ego_pos = scenario_data['tracks'][sdc_id]['state']['position'][:, :2] + ego_heading = scenario_data['tracks'][sdc_id]['state']['heading'] + ego_vel = scenario_data['tracks'][sdc_id]['state']['velocity'][:, :2] + + sdc_gt_info = {"ego_traj": ego_pos, "ego_heading": ego_heading, "ego_vel": ego_vel} + + sdc_traj = sdc_gt_info["ego_traj"] + sdc_heading = sdc_gt_info["ego_heading"] + sdc_vel = sdc_gt_info["ego_vel"] + + self.storage[sid] = dict( + SDC_traj=sdc_traj, + SDC_heading=sdc_heading, + SDC_vel=sdc_vel, + sdc_initial_pos=sdc_traj[0].copy(), # for later use + ) + + def log_ego_history(self): + obj = self.env.engine.current_track_agent + + self.ego_traj.append(obj.position) + # print("current pos:", obj.position) + self.ego_vel.append(obj.velocity) + self.ego_heading.append(obj.heading_theta) + + def generate(self, scenario_data=None, track_length=91): + if scenario_data is None: + assert self.env is not None + scenario_data = self.env.engine.data_manager.current_scenario + + sid = scenario_data["id"] + sdc_traj = self.storage[sid].get('SDC_traj') + + if sdc_traj.shape[0] <= 10: + print("SDC traj length is too short, please check the scenario data. Skipping editing this scenario. 
") + return None + + sdc_heading = self.storage[sid].get('SDC_heading') + sdc_vel = self.storage[sid].get('SDC_vel') + if isinstance(sdc_traj, list): # first time scenario in training + sdc_traj = np.array(sdc_traj) + sdc_vel = np.array(sdc_vel) + sdc_heading = np.array(sdc_heading) + + if self.no_adaptive: # for ablation + overwritten_sd = copy.deepcopy(scenario_data) + else: + overwritten_sd = overwrite_new_sdc_traj_to_SD(copy.deepcopy(scenario_data), sdc_traj, sdc_heading, sdc_vel, + track_length=track_length) # need to overwrite mask as well + + from scenestreamer.dataset.preprocessor import preprocess_scenario_description_for_motionlm + data_dict = preprocess_scenario_description_for_motionlm( + scenario=overwritten_sd, + config=self.config, + in_evaluation=False, + keep_all_data=True, + tokenizer=self.pl_model.model.motion_tokenizer + ) + + batched_data_dict = utils.batch_data(utils.numpy_to_torch(data_dict, device=self.pl_model.device)) + output_data = self.GPT_AR(batched_data_dict) + batched_data_dict.update(output_data) + data_dict = utils.unbatch_data(utils.torch_to_numpy(batched_data_dict)) + + global_output_dict = transform_to_global_coordinate(data_dict=data_dict) + type_convert_map = { + self.pl_model.model.veh_id: "VEHICLE", + self.pl_model.model.ped_id: "PEDESTRIAN", + self.pl_model.model.cyc_id: "CYCLIST", + } + new_SD = overwrite_to_scenario_description_new_agent(output_dict_mode=global_output_dict, + original_SD=scenario_data, + type_convert_map=type_convert_map) + + return new_SD # return modified scenario description + + def after_episode(self): + latest_ego_traj = np.array(self.ego_traj) # now we have the whole new traj + latest_ego_heading = np.array(self.ego_heading) + latest_ego_vel = np.array(self.ego_vel) + + if len(latest_ego_traj) <= 10: + print('Ignore traj less than 1s') # abandon bad policy + return + + sid = self.env.engine.data_manager.current_scenario["id"] + # print("in after_episode, sid:", sid) + + self.storage[sid]['SDC_traj'] 
= latest_ego_traj + self.storage[sid]['SDC_heading'] = latest_ego_heading + self.storage[sid]['SDC_vel'] = latest_ego_vel + + +if __name__ == '__main__': + g = SceneStreamerRLScenarioGenerator(model_name="scenestreamer-base-large") + from scenestreamer.utils import REPO_ROOT + import pathlib + import pickle + + example_sd = pathlib.Path( + REPO_ROOT) / "data" / "20scenarios" / "sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl" + with open(example_sd, "rb") as f: + scenario_description = pickle.load(f) + + import tqdm + for _ in tqdm.trange(20): + g.before_episode(scenario_data=scenario_description) + new_sd = g.generate(scenario_data=scenario_description) + + import pickle + + with open("sd_1a7143a44e480ca6_scenestreamer_test.pkl", "wb") as f: + pickle.dump(new_sd, f) diff --git a/scenestreamer/rl_train/train/test_prob.py b/scenestreamer/rl_train/train/test_prob.py new file mode 100644 index 0000000000000000000000000000000000000000..be8336e9c4536284fb04c3f473c2cfe2f643734b --- /dev/null +++ b/scenestreamer/rl_train/train/test_prob.py @@ -0,0 +1,10 @@ + +import random + +_total_timesteps = 1_000_000 +min_prob = 0.5 + +for num_timesteps in range(0, 1000000, 10000): + prob = max(1 - (2 * num_timesteps / _total_timesteps) * (1 - min_prob), min_prob) + # prob = max(min(1 - (2 * num_timesteps / _total_timesteps) * (1 - min_prob), 0.5), min_prob) + print(f"prob: {prob:.4f}") \ No newline at end of file diff --git a/scenestreamer/rl_train/train/train_td3.py b/scenestreamer/rl_train/train/train_td3.py new file mode 100644 index 0000000000000000000000000000000000000000..1c67876a74273c93a4ac0ad6a79470c97694ca93 --- /dev/null +++ b/scenestreamer/rl_train/train/train_td3.py @@ -0,0 +1,791 @@ +import argparse +import os +import random +import time +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import gymnasium as gym +import numpy as np +import torch +import wandb +from IPython.display import clear_output +from 
metadrive.envs.scenario_env import ScenarioEnv +from metadrive.policy.env_input_policy import EnvInputPolicy +from metadrive.scenario.utils import get_number_of_scenarios +from stable_baselines3.common.callbacks import CheckpointCallback, BaseCallback +from stable_baselines3.common.callbacks import EvalCallback +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.monitor import ResultsWriter +from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +from stable_baselines3.common.utils import set_random_seed +from stable_baselines3.common.vec_env import DummyVecEnv +from stable_baselines3.common.vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env import sync_envs_normalization +from wandb.integration.sb3 import WandbCallback + +from scenestreamer.rl_train.train.ScenarioOnlineEnvWrapper import ScenarioOnlineEnvWrapper +from scenestreamer.rl_train.train.customized_td3 import CustomizedTD3, Closed_Loop_TD3 +from scenestreamer.rl_train.train.scenestreamer_rl_generator import SceneStreamerRLScenarioGenerator + + +# from metadrive.engine.logger import set_log_level +# set_log_level(logging.ERROR) + +def set_seed(seed): + set_random_seed(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) # PyTorch random (CPU and GPU) + torch.cuda.manual_seed_all(seed) # PyTorch random (all GPU devices) + torch.backends.cudnn.deterministic = True # Use deterministic CuDNN operations + torch.backends.cudnn.benchmark = False # Disable CuDNN benchmark for determinism + + +class Monitor(gym.Wrapper): + """ + A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data. 
+ + :param env: The environment + :param filename: the location to save a log file, can be None for no log + :param allow_early_resets: allows the reset of the environment before it is done + :param reset_keywords: extra keywords for the reset call, + if extra parameters are needed at reset + :param info_keywords: extra information to log, from the information return of env.step() + """ + + EXT = "monitor.csv" + + def __init__( + self, + env, + filename: Optional[str] = None, + allow_early_resets: bool = True, + reset_keywords: Tuple[str, ...] = (), + # info_keywords: Tuple[str, ...] = (), + ): + super(Monitor, self).__init__(env=env) + + # PZH: Step the environment for once to understand the info keys. + self.env.reset() + o, r, tm, tc, i = self.env.step(self.env.action_space.sample()) + info_keywords = tuple(i.keys()) + reset_keywords = tuple(reset_keywords) + ep_info_keywords = tuple("ep_" + k for k in info_keywords) + record_keys = reset_keywords + info_keywords + ep_info_keywords + + self.t_start = time.time() + if filename is not None: + self.results_writer = ResultsWriter( + filename, + header={ + "t_start": self.t_start, + "env_id": env.spec and env.spec.id + }, + extra_keys=record_keys, + ) + else: + self.results_writer = None + self.reset_keywords = reset_keywords + self.info_keywords = info_keywords + self.metadata["info_keywords"] = self.info_keywords + self.allow_early_resets = allow_early_resets + self.rewards = None + self.needs_reset = True + self.episode_returns = [] + self.episode_lengths = [] + self.episode_times = [] + + # PZH: Ours + self.episode_infos = defaultdict(list) + + self.total_steps = 0 + self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() + + def reset(self, **kwargs) -> GymObs: + """ + Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True + + :param kwargs: Extra keywords saved for the next episode. 
only if defined by reset_keywords + :return: the first observation of the environment + """ + if not self.allow_early_resets and not self.needs_reset: + raise RuntimeError( + "Tried to reset an environment before done. If you want to allow early resets, " + "wrap your env with Monitor(env, path, allow_early_resets=True)" + ) + self.rewards = [] + self.needs_reset = False + for key in self.reset_keywords: + value = kwargs.get(key) + if value is None: + raise ValueError(f"Expected you to pass keyword argument {key} into reset") + self.current_reset_info[key] = value + + # PZH: hardcoded here to discard seed to avoid bug in ScenarioEnv + if "seed" in kwargs: + kwargs.pop("seed") + + return self.env.reset(**kwargs) + + def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + """ + Step the environment with the given action + + :param action: the action + :return: observation, reward, done, information + """ + if self.needs_reset: + raise RuntimeError("Tried to step environment that needs reset") + observation, reward, tm, tc, info = self.env.step(action) + self.rewards.append(reward) + + for key in self.info_keywords: + self.episode_infos[key].append(info[key]) + + done = tm or tc + if done: + self.needs_reset = True + ep_rew = sum(self.rewards) + ep_len = len(self.rewards) + ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)} + for key in self.info_keywords: + ep_info[key] = info[key] + ep_data = np.asarray(self.episode_infos[key]) + if ep_data.dtype == object: + pass + else: + # Temporary workaround solution for accessing mean for non float/int + try: + ep_info["epavg_{}".format(key)] = np.mean(ep_data) + ep_info["epsum_{}".format(key)] = np.sum(ep_data) + except: + pass + self.episode_returns.append(ep_rew) + self.episode_lengths.append(ep_len) + self.episode_times.append(time.time() - self.t_start) + ep_info.update(self.current_reset_info) + if self.results_writer: + self.results_writer.write_row(ep_info) + 
info["episode"] = ep_info + self.episode_infos.clear() + self.total_steps += 1 + return observation, reward, tm, tc, info + + def close(self) -> None: + """ + Closes the environment + """ + super(Monitor, self).close() + if self.results_writer is not None: + self.results_writer.close() + + def get_total_steps(self) -> int: + """ + Returns the total number of timesteps + + :return: + """ + return self.total_steps + + def get_episode_rewards(self) -> List[float]: + """ + Returns the rewards of all the episodes + + :return: + """ + return self.episode_returns + + def get_episode_lengths(self) -> List[int]: + """ + Returns the number of timesteps of all the episodes + + :return: + """ + return self.episode_lengths + + def get_episode_times(self) -> List[float]: + """ + Returns the runtime in seconds of all the episodes + + :return: + """ + return self.episode_times + + +def evaluate_info(all_info): + all_scenarios = list(all_info.keys()) + num_scenario = len(all_scenarios) + + _rewards = [] + _costs = [] + _completion = [] + _crash = [] + _episode_length = [] + + for i in all_scenarios: + _rewards.append(all_info[i]['episode_reward']) + _costs.append(all_info[i]['cost']) + _completion.append(all_info[i]['route_completion']) + _crash.append(1 if all_info[i]['crash'] else 0) + _episode_length.append(all_info[i]['episode_length']) + + # Convert lists to NumPy arrays + _rewards = np.array(_rewards) + _costs = np.array(_costs) + _completion = np.array(_completion) + _crash = np.array(_crash) + _episode_length = np.array(_episode_length) + + result = { + "num_scenario": num_scenario, + "avg_rewards": np.mean(_rewards), + "avg_costs": np.mean(_costs), + "avg_completion": np.mean(_completion), + "avg_collisions": np.mean(_crash), + "avg_length": np.mean(_episode_length), + } + + return result + + +class CustomizedFormalevalCallback(EvalCallback): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.evaluations_info_buffer = defaultdict(list) + 
self._all_episode_info = {} + self._episode_counter = 0 + + def _log_success_callback(self, locals_, globals_): + info = locals_["info"] + if locals_["done"]: + info = dict(info) # Shallow copy + + completion = info['route_completion'] + if completion <= 0: + info['route_completion'] = 0 + + if completion >= 1: + info['route_completion'] = 1 + + # if info['arrive_dest']: + # info['route_completion'] = 1 + + info['crash'] = bool(info.get("crash", False)) + + self._all_episode_info[self._episode_counter] = info + self._episode_counter += 1 + + # for k in [ + # "route_completion", "cost", "arrive_dest", "out_of_road", "crash", "episode_reward", + # "episode_energy", "route_completion", "total_cost", "max_step", + # "crash" + # ]: + # if k in info: + # self.evaluations_info_buffer[k].append(info[k]) + + for k in info: + self.evaluations_info_buffer[k].append(info[k]) + + def _on_step(self) -> bool: + continue_training = True + + if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: + # Sync normalization if needed + if self.model.get_vec_normalize_env() is not None: + try: + sync_envs_normalization(self.training_env, self.eval_env) + except AttributeError as e: + raise AssertionError( + "Training and eval envs must be wrapped similarly for normalization." 
+ ) from e + + # Reset success rate buffer + self._is_success_buffer = [] + self.evaluations_info_buffer.clear() + + print("Start evaluating policy for {} episodes!".format(self.n_eval_episodes)) + + episode_rewards, episode_lengths = evaluate_policy( + self.model, + self.eval_env, + n_eval_episodes=self.n_eval_episodes, + render=self.render, + deterministic=self.deterministic, + return_episode_rewards=True, + warn=self.warn, + callback=self._log_success_callback, + ) + + print("Finish evaluating policy for {} episodes!".format(self.n_eval_episodes)) + + if self.log_path is not None: + self.evaluations_timesteps.append(self.num_timesteps) + self.evaluations_results.append(episode_rewards) + self.evaluations_length.append(episode_lengths) + + kwargs = {} + # Save success log if present + if len(self._is_success_buffer) > 0: + self.evaluations_successes.append(self._is_success_buffer) + kwargs["successes"] = self.evaluations_successes + + for k, v in self.evaluations_info_buffer.items(): + assert len(v) <= self.n_eval_episodes + kwargs[k] = v + + np.savez( + self.log_path, + timesteps=self.evaluations_timesteps, + results=self.evaluations_results, + ep_lengths=self.evaluations_length, + **kwargs, + ) + + mean_reward = np.mean(episode_rewards) + std_reward = np.std(episode_rewards) + mean_ep_length = np.mean(episode_lengths) + std_ep_length = np.std(episode_lengths) + + self.last_mean_reward = float(mean_reward) + + if self.verbose >= 1: + print(f"Eval num_timesteps={self.num_timesteps}, episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") + print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}") + + # self.logger.record("eval/mean_reward", float(mean_reward)) + # self.logger.record("eval/mean_ep_length", mean_ep_length) + # self.logger.record("eval/num_episodes", len(episode_rewards)) + + # if len(self._is_success_buffer) > 0: + # success_rate = np.mean(self._is_success_buffer) + # if self.verbose >= 1: + # print(f"Success rate: {100 * 
success_rate:.2f}%") + # self.logger.record("eval/success_rate", success_rate) + + results_dict = evaluate_info(self._all_episode_info) + self._all_episode_info = {} # Reset + self._episode_counter = 0 + + for k, v in results_dict.items(): + self.logger.record(f"eval/{k}", v) + + for k, v in self.evaluations_info_buffer.items(): + # assert len(v) == self.n_eval_episodes + try: + self.logger.record(f"eval/{k}", np.mean(np.asarray(v))) + except: + pass + + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(self.num_timesteps) + + if mean_reward > self.best_mean_reward: + if self.verbose >= 1: + print("New best mean reward!") + if self.best_model_save_path is not None: + self.model.save(os.path.join(self.best_model_save_path, "best_model")) + self.best_mean_reward = float(mean_reward) + if self.callback_on_new_best is not None: + continue_training = self.callback_on_new_best.on_step() + + if self.callback is not None: + continue_training = continue_training and self._on_event() + + return continue_training + + +def create_env(config, need_monitor=False, closed_loop=False, closed_loop_generator="SceneStreamer", model_name=None, + no_adaptive=False): + if closed_loop: + if closed_loop_generator == "SceneStreamer": + assert model_name is not None + generator = SceneStreamerRLScenarioGenerator(model_name) + elif closed_loop_generator == "SCGEN": + raise ValueError("SCGEN is not supported yet.") + generator = SCGEN_Generator("0202_midgpt", mode='SCGEN') + elif closed_loop_generator == "CAT": + raise ValueError("CAT is not supported yet.") + generator = CAT_Generator() + else: + raise ValueError("Unknown closed_loop_generator") + + env = ScenarioOnlineEnvWrapper(config=config, generator=generator, no_adaptive=no_adaptive) + else: + env = ScenarioEnv(config=config) + + if need_monitor: + env = Monitor(env) # Pass the custom metrics + + return env + + +def create_eval_env(eval_config, ): + eval_env = 
ScenarioEnv(config=eval_config) + eval_env = Monitor(eval_env) # Pass the custom metrics + return eval_env + + +class WandbLoggingCallback(BaseCallback): + """Logs TD3 loss and other key training metrics to Weights & Biases (W&B).""" + + def __init__(self, verbose=1): + super(WandbLoggingCallback, self).__init__(verbose) + + def _on_step(self) -> bool: + """Logs training metrics every step.""" + if "loss/critic" in self.locals: + wandb.log({"loss/critic": self.locals["loss/critic"]}) + if "loss/policy" in self.locals: + wandb.log({"loss/policy": self.locals["loss/policy"]}) + + return True + + +def train( + config_train, config_eval, load_model_path=None, seed=None, save_path="./td3", training_steps=None, + lr=None, eval_freq=None, eval_ep=None, wandb_config=None, exp_name="td3", num_eval_envs=None): + assert seed is not None + assert num_eval_envs is not None + set_seed(seed) + train_env = DummyVecEnv([lambda: create_env(config_train, True)]) # use only one training environment + save_prefix = f"seed_{seed}" + callbacks = [] + checkpoint_callback = CheckpointCallback(save_freq=50000, save_path=save_path, name_prefix=save_prefix) + use_wandb = wandb_config.get("use_wandb", False) + if use_wandb: + import wandb + project_name = wandb_config.get("wandb_project", "scgen") + team_name = wandb_config.get("wandb_team", "drivingforce") + wandb.init( + project=project_name, + entity=team_name, + name=f"{exp_name}_seed_{seed}", + group=exp_name, + sync_tensorboard=True, + save_code=True + ) + wandb_callback = WandbCallback(model_save_path=f"./wandb_models/{exp_name}_seed_{seed}", verbose=1) + wandb_loss_callback = WandbLoggingCallback() + callbacks.append(wandb_callback) + callbacks.append(wandb_loss_callback) + + eval_env = SubprocVecEnv([lambda: create_eval_env(config_eval) for _ in range(num_eval_envs)]) + + eval_callback = CustomizedFormalevalCallback( + eval_env, + eval_freq=eval_freq, + n_eval_episodes=eval_ep, + best_model_save_path=save_path, + log_path=save_path, + 
# deterministic=True, + # render=False + ) + + callbacks.append(checkpoint_callback) + callbacks.append(eval_callback) + + if load_model_path: + model = CustomizedTD3.load(load_model_path, env=train_env) + print(f"Resuming training from model at {load_model_path}") + + trained_steps = int(model.num_timesteps) + + remaining_steps = training_steps - trained_steps + print(f"Resuming from {trained_steps} steps; training for {remaining_steps} more steps.") + + model.learn( + total_timesteps=remaining_steps, + reset_num_timesteps=False, # Important: continue counting from previous steps + callback=callbacks, + ) + + else: + model = CustomizedTD3("MlpPolicy", + train_env, + action_noise=None, + learning_rate=lr, + learning_starts=200, + batch_size=1024, + tau=0.005, + gamma=0.99, + train_freq=1, + gradient_steps=1, + device="cuda", + seed=seed, + verbose=2, + tensorboard_log="TD3", + ) + print("Starting new training...") + + model.learn( + total_timesteps=training_steps, + callback=callbacks, + ) + + clear_output() + + +def closed_loop_train( + config_train, config_eval, load_model_path=None, seed=None, save_path=None, + training_steps=None, lr=None, eval_freq=None, eval_ep=None, wandb_config=None, + exp_name="td3", source_data=None, closed_loop_generator="SceneStreamer", + model_name=None, resumed_step=0, num_eval_envs=None, no_adaptive=False): + assert seed is not None + assert eval_ep is not None + set_seed(seed) + + train_env = create_env(config_train, need_monitor=True, closed_loop=True, + closed_loop_generator=closed_loop_generator, model_name=model_name, no_adaptive=no_adaptive) + save_prefix = f"seed_{seed}" + + checkpoint_callback = CheckpointCallback(save_freq=50000, save_path=save_path, name_prefix=save_prefix) + + callbacks = [checkpoint_callback] + + use_wandb = wandb_config.get("use_wandb", False) + if use_wandb: + project_name = wandb_config.get("wandb_project", "scgen") + team_name = wandb_config.get("wandb_team", "drivingforce") + + if load_model_path is 
not None and resumed_step > 0: + resumed_step = resumed_step + wandb.init( + # id=resume_id, + # resume="must", + project="scgen", + entity="drivingforce", + name=f"{exp_name}_seed_{seed}_resumed_{resumed_step}", + group=exp_name, + sync_tensorboard=True, + save_code=True # Save script files in W&B + ) + else: + wandb.init( + project=project_name, + entity=team_name, + name=f"{exp_name}_seed_{seed}", + group=exp_name, + sync_tensorboard=True, + save_code=True # Save script files in W&B + ) + + wandb_callback = WandbCallback(model_save_path=f"./wandb_models/{exp_name}_seed_{seed}", verbose=1) + callbacks.append(wandb_callback) + + # eval_env = None + eval_env = SubprocVecEnv([(lambda: create_eval_env(config_eval)) for _ in range(num_eval_envs)]) + + if eval_env is not None: + eval_callback = CustomizedFormalevalCallback( + eval_env, + eval_freq=eval_freq, + n_eval_episodes=eval_ep, + best_model_save_path=save_path, + log_path=save_path, + # deterministic=True, + # render=False + ) + callbacks.append(eval_callback) + + callbacks.append(checkpoint_callback) + + if load_model_path: + model = Closed_Loop_TD3.load(load_model_path, env=train_env) + print(f"Resuming training from model at {load_model_path}") + + trained_steps = int(model.num_timesteps) + + remaining_steps = training_steps - trained_steps + print(f"Resuming from {trained_steps} steps; training for {remaining_steps} more steps.") + + model.learn( + total_timesteps=remaining_steps, + reset_num_timesteps=False, # Important: continue counting from previous steps + callback=callbacks, + ) + + else: + model = Closed_Loop_TD3("MlpPolicy", + train_env, + action_noise=None, + learning_rate=lr, + learning_starts=200, + batch_size=1024, + tau=0.005, + gamma=0.99, + train_freq=1, + gradient_steps=1, + device="cuda", + seed=seed, + verbose=2, + tensorboard_log=str(save_path), + training_dataset=source_data + ) + print("Starting new training...") + + model.learn( + total_timesteps=training_steps, + callback=callbacks, + 
) + + clear_output() + + +def train_wrapper( + config_train, config_eval, exp_name, seed, save_path, ckpt_path=None, + training_steps=None, eval_freq=None, lr=None, wandb_config=None, closed_loop=False, + closed_loop_dir=None, closed_loop_generator="SceneStreamer", model_name=None, resumed_step=0, + num_eval_envs=None, eval_ep=None, no_adaptive=False): + print("current learning rate:", lr) + + if not closed_loop: + train( + config_train=config_train, + config_eval=config_eval, + load_model_path=ckpt_path, + seed=seed, + save_path=save_path, + training_steps=training_steps, + lr=lr, + eval_freq=eval_freq, + wandb_config=wandb_config, + exp_name=exp_name, + num_eval_envs=num_eval_envs, + eval_ep=eval_ep, + ) + + else: + assert closed_loop_dir is not None, "Please provide the closed loop source data directory." + closed_loop_train( + config_train=config_train, + config_eval=config_eval, + load_model_path=ckpt_path, + seed=seed, + save_path=save_path, + training_steps=training_steps, + lr=lr, + eval_freq=eval_freq, + wandb_config=wandb_config, + exp_name=exp_name, + source_data=closed_loop_dir, + closed_loop_generator=closed_loop_generator, + model_name=model_name, + resumed_step=resumed_step, + num_eval_envs=num_eval_envs, + eval_ep=eval_ep, + no_adaptive=no_adaptive + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", type=str, help="The dir for this batch of experiments.") + parser.add_argument("--eval_data_dir", type=str, help="The dir for this batch of experiments.") + parser.add_argument("--save_path", type=str, help="The dir for checkpoints to save.") + parser.add_argument("--exp_name", default="td3_metadrive", type=str, help="The name for this batch of experiments.") + parser.add_argument("--seed", default=0, type=int, help="The random seed.") + parser.add_argument("--training_step", default=1_000_000, type=int, help="The number of steps in training") + parser.add_argument("--eval_freq", default=50000, 
type=int, help="Eval frequency.") + parser.add_argument("--wandb", action="store_true", help="Set to True to upload stats to wandb.") + parser.add_argument("--wandb_project", type=str, default="", help="The project name for wandb.") + parser.add_argument("--wandb_team", type=str, default="", help="The team name for wandb.") + parser.add_argument("--lr", type=float, default=1e-4, help="learning rate of TD3.") + parser.add_argument("--CAT_config", action="store_true", help="Set to CAT train config") + parser.add_argument("--horizon", type=int, help="training horizon") + parser.add_argument("--eval_horizon", type=int, default=100, help="training horizon") + parser.add_argument("--ckpt_path", type=str, help="pre-trained policy path") + parser.add_argument("--closed_loop", action="store_true", help="closd_loop") + parser.add_argument("--source_data", type=str, help="closd_loop source data directory") + parser.add_argument("--closed_loop_generator", type=str, help="closed_loop_generator") + parser.add_argument("--model_name", type=str, help="model name") + parser.add_argument("--resume_wandb_id", type=str, help="wandb run id to resume training") + parser.add_argument("--resumed_step", type=int, help="resuming step number.") + parser.add_argument("--num_eval_envs", type=int, help="resuming step number.") + parser.add_argument("--eval_ep", type=int, help="number of eval episodes.") + parser.add_argument("--no_adaptive", action="store_true", help="generator takes GT ego traj") + + args = parser.parse_args() + + num_scenario = get_number_of_scenarios(args.data_dir) + print(f"Number of scenarios: {num_scenario}") + print(f"Number of training horizon: {args.horizon}") + + # Assert gpu is there + assert torch.cuda.is_available(), "GPU is not available. Please check your CUDA installation." 
+ + config_train = dict( + store_map=False, + + use_render=False, + manual_control=False, + show_interface=False, + data_directory=args.data_dir, + # "/home/yuxin/scenestreamer/mixed_selected_scgen_diverse_coll_training", # "/home/yuxin/scenestreamer/mixed_selected_CAT_training", # "/home/yuxin/scenestreamer/mixed_training_500_TF", # SCGEN_waymo_training_500 + agent_policy=EnvInputPolicy, + start_scenario_index=0, + num_scenarios=num_scenario, + sequential_seed=False, + horizon=args.horizon, + reactive_traffic=False, + no_static_vehicles=True, + no_light=True, + # crash_vehicle_done=True, + # out_of_route_done=True, + # crash_object_done=True, + # crash_human_done=False, + # relax_out_of_road_done=False, + ) + + config_eval = dict( + store_map=False, + + use_render=False, + manual_control=False, + show_interface=False, + data_directory=args.eval_data_dir, + # "/home/yuxin/scenestreamer/mixed_selected_scgen_diverse_coll_training", # "/home/yuxin/scenestreamer/mixed_selected_CAT_training", # "/home/yuxin/scenestreamer/mixed_training_500_TF", # SCGEN_waymo_training_500 + agent_policy=EnvInputPolicy, + start_scenario_index=0, + num_scenarios=get_number_of_scenarios(args.eval_data_dir), + sequential_seed=True, + horizon=args.eval_horizon, + reactive_traffic=False, + no_static_vehicles=True, + no_light=True, + crash_vehicle_done=False, + out_of_route_done=False, + crash_object_done=False, + crash_human_done=False, + relax_out_of_road_done=False, + ) + + if args.closed_loop: + config_train["total_timesteps"] = args.training_step + + wandb_config = { + "use_wandb": args.wandb, + "wandb_project": args.wandb_project, + "wandb_team": args.wandb_team, + } + + train_wrapper( + config_train=config_train, + config_eval=config_eval, + exp_name=args.exp_name, + seed=args.seed, + save_path=args.save_path, + ckpt_path=args.ckpt_path, + training_steps=args.training_step, + lr=args.lr, + eval_freq=args.eval_freq, + wandb_config=wandb_config, + closed_loop=args.closed_loop, + 
closed_loop_dir=args.source_data, + closed_loop_generator=args.closed_loop_generator, + model_name=args.model_name, + resumed_step=args.resumed_step, + num_eval_envs=args.num_eval_envs, # TODO: num_eval_envs just set it to 5? + eval_ep=args.eval_ep, # TODO: this is wrong. it should be eval_horizon, not eval_ep. + no_adaptive=args.no_adaptive + ) diff --git a/scenestreamer/tokenization/0305_fast_all/delta_normalization_quantiles.json b/scenestreamer/tokenization/0305_fast_all/delta_normalization_quantiles.json new file mode 100644 index 0000000000000000000000000000000000000000..be508c8ba8fed5cbc07420316b2b2f7a4a00c479 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_all/delta_normalization_quantiles.json @@ -0,0 +1 @@ +{"q_lower": [-0.056130945682525635, -0.03426671009510755, -0.21757962703704836, -0.07495682165026665, -0.033623198792338375, -0.21293401718139648, -0.09630700498819351, -0.033363208174705505, -0.2103518009185791, -0.11870235204696655, -0.034003064036369324, -0.2113771677017212, -0.14167103469371795, -0.03508753385394812, -0.21343374252319336], "q_upper": [0.05557198449969292, 2.095459222793579, 0.21746430397033656, 0.06942770779132834, 2.095139503479004, 0.2140974283218382, 0.08460234701633451, 2.095113754272461, 0.21288055181503163, 0.10136498883366526, 2.095697116851806, 0.2135203123092646, 0.119339220225811, 2.0970006465911855, 0.21647340059280396]} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_all/error_mean.json b/scenestreamer/tokenization/0305_fast_all/error_mean.json new file mode 100644 index 0000000000000000000000000000000000000000..4b83489ec84a7befdf4642f1bf2eef376a4f0b5b --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_all/error_mean.json @@ -0,0 +1 @@ +[[0.00046348866418710003, 0.00512541053867593, 0.021890314762240518], [0.0004924957902618292, 0.01621369319469648, 0.02454723051712962], [0.0005961439731003831, 0.03377028682929846, 0.028674030058411584], [0.0008347955116343095, 
0.058032792512467915, 0.029874285496576607], [0.0012092677231452587, 0.09025443860971505, 0.032739394065719025]] \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_all/norm_info_cyc.json b/scenestreamer/tokenization/0305_fast_all/norm_info_cyc.json new file mode 100644 index 0000000000000000000000000000000000000000..b71565c718909ed427919e52bf6f356c26b49e8f --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_all/norm_info_cyc.json @@ -0,0 +1 @@ +{"q_lower": [-0.05620414614677429, -0.04952176660299301, -0.12296247482299805, -0.059109173715114594, -0.049527047015726564, -0.1218917369842529, -0.06488734185695648, -0.04870385266840458, -0.12222493886947629, -0.0770510733127594, -0.04834958184510469, -0.11846284866333005, -0.08916960656642912, -0.04897959642112254, -0.12207889556884766], "q_upper": [0.05849912762641907, 0.9729149580001821, 0.12755169868469196, 0.060657307505607605, 0.9720799922943115, 0.12100558280944773, 0.06705052554607388, 0.9728403091430664, 0.12269333600997914, 0.07628470659255981, 0.9754065752029415, 0.12187862396240234, 0.08799059689044947, 0.9741636753082262, 0.12116074562072754]} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_all/norm_info_ped.json b/scenestreamer/tokenization/0305_fast_all/norm_info_ped.json new file mode 100644 index 0000000000000000000000000000000000000000..03453104fdb8841f37fb90ecebf907f4b41aa447 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_all/norm_info_ped.json @@ -0,0 +1 @@ +{"q_lower": [-0.05694263204932213, -0.04731698837131262, -0.6616688060760498, -0.05341310054063797, -0.04721499979496002, -0.6466223359107973, -0.052964311093091965, -0.04694697543978691, -0.6535398435592653, -0.055489320307970054, -0.04842733725905419, -0.665165858268738, -0.060123561397194866, -0.05154890127480031, -0.6728901505470277], "q_upper": [0.05659624136984333, 0.2950726473331451, 0.6576809883117569, 0.053949579633771956, 0.2932160198688507, 0.6581473350524902, 
0.052648843824863234, 0.29114753007888794, 0.6527402067184136, 0.054819024540483714, 0.2921927571296692, 0.6676728796958922, 0.06044423863291734, 0.2948721647262573, 0.6923989838361616]} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_all/norm_info_veh.json b/scenestreamer/tokenization/0305_fast_all/norm_info_veh.json new file mode 100644 index 0000000000000000000000000000000000000000..3e751c6a5e452e25b3531d333baba780d3c136ef --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_all/norm_info_veh.json @@ -0,0 +1 @@ +{"q_lower": [-0.05588071268051863, -0.02626045051962137, -0.04006481170654297, -0.07810875803232192, -0.025505834873765707, -0.039873600006103516, -0.10224558517336846, -0.024929437786340714, -0.03981208801269531, -0.1268754607439041, -0.025142773985862732, -0.039824485778808594, -0.15168607711791993, -0.025687281861901286, -0.03987455368041992], "q_upper": [0.05508363403379943, 2.194649739265449, 0.03848123550415039, 0.07283832132816315, 2.194388873577119, 0.03825026750564575, 0.09218260884285012, 2.1935806417465233, 0.038222551345825195, 0.11195725649595334, 2.194534349441536, 0.03817129135131836, 0.13210431039333392, 2.1963739585876496, 0.03825235366821289]} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_all/processor_config.json b/scenestreamer/tokenization/0305_fast_all/processor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..8d9a3f725377bee10763626aa8cf6ddb5ad8ac87 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_all/processor_config.json @@ -0,0 +1,8 @@ +{ + "action_dim": null, + "min_token": -22, + "processor_class": "UniversalActionProcessor", + "scale": 10, + "time_horizon": null, + "vocab_size": 1024 +} diff --git a/scenestreamer/tokenization/0305_fast_all/special_tokens_map.json b/scenestreamer/tokenization/0305_fast_all/special_tokens_map.json new file mode 100644 index 
0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_all/special_tokens_map.json @@ -0,0 +1 @@ +{} diff --git a/scenestreamer/tokenization/0305_fast_all/tokenizer.json b/scenestreamer/tokenization/0305_fast_all/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..fd291f427f415552bc7600c5e4f2a001c3a76925 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_all/tokenizer.json @@ -0,0 +1,4847 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": true + }, + "post_processor": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": false, + "use_regex": true + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": false, + "vocab": { + "\u0000": 0, + "\u0001": 1, + "\u0002": 2, + "\u0003": 3, + "\u0004": 4, + "\u0005": 5, + "\u0006": 6, + "\u0007": 7, + "\b": 8, + "\t": 9, + "\n": 10, + "\u000b": 11, + "\f": 12, + "\r": 13, + "\u000e": 14, + "\u000f": 15, + "\u0010": 16, + "\u0011": 17, + "\u0012": 18, + "\u0013": 19, + "\u0014": 20, + "\u0015": 21, + "\u0016": 22, + "\u0017": 23, + "\u0018": 24, + "\u0019": 25, + "\u001a": 26, + "\u001b": 27, + "\u001c": 28, + "\u001d": 29, + "\u001e": 30, + "\u001f": 31, + " ": 32, + "!": 33, + "\"": 34, + "#": 35, + "$": 36, + "%": 37, + "&": 38, + "'": 39, + "(": 40, + ")": 41, + "*": 42, + "+": 43, + ",": 44, + "Ā": 45, + "ā": 46, + "Ă": 47, + "ă": 48, + "Ą": 49, + "ą": 50, + "Ć": 51, + "ć": 52, + "Ĉ": 53, + "ĉ": 54, + "Ċ": 55, + "ċ": 56, + "Č": 57, + "č": 58, + "Ď": 59, 
+ "ď": 60, + "Đ": 61, + "đ": 62, + "Ē": 63, + "ē": 64, + "Ĕ": 65, + "ĕ": 66, + "Ė": 67, + "ė": 68, + "Ę": 69, + "ę": 70, + "Ě": 71, + "ě": 72, + "Ĝ": 73, + "ĝ": 74, + "Ğ": 75, + "ğ": 76, + "Ġ": 77, + "ĖĖ": 78, + "ĖĖĖĖ": 79, + "Ėĕ": 80, + "ĖĖĖĖĖĖĖĖ": 81, + "ėĖĖ": 82, + "ĖĖĖ": 83, + "ĖĖĕ": 84, + "ĖĕĖĖĖĖĖĖĖĖ": 85, + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ": 86, + "ĖĖĖĖĖ": 87, + "ėĀ": 88, + "ėĀĖĕĖĖĖĖĖĖĖĖĖĖĖ": 89, + "Ėė": 90, + "ĖĔ": 91, + "ĖĖĕĖĖ": 92, + "ĘĖĖ": 93, + "ĖĕĖĖ": 94, + "ĕĖ": 95, + "ėĖ": 96, + "ĕĖĖ": 97, + "ĔĖĖ": 98, + "ėĖĖėĖĖ": 99, + "ĖĕĖĖĕ": 100, + "ėĖĖĖĖĖ": 101, + "ėă": 102, + "ēĖĖ": 103, + "ĖĕĖĖĖĖĖ": 104, + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ": 105, + "Ęă": 106, + "ĔĖ": 107, + "ėĖĖĖĖ": 108, + "ėĖĕ": 109, + "ęĖĖ": 110, + "ĕĕ": 111, + "ĖėĖĖ": 112, + "ėā": 113, + "ėĖĖĖĖĖĖĖĖ": 114, + "ėĕ": 115, + "Ėĕĕ": 116, + "ėĖė": 117, + "ĖĖĖĖĕ": 118, + "ĖĖĕĖĖĖĖĖ": 119, + "ĘĖ": 120, + "Ėă": 121, + "ĖĔĖĖ": 122, + "ĖĖĖĖĖĖ": 123, + "ėĂ": 124, + "ĖĕĖĖĖĖ": 125, + "ĖĘ": 126, + "ĘĀ": 127, + "Ėē": 128, + "ĖĖĖĖĖĖĖ": 129, + "ĖĕĖĖĕĖĖ": 130, + "ĖĖĕĖĖĕĖĖ": 131, + "ĖĖĖĖĖĖĖĖĖĖ": 132, + "ĔĖė": 133, + "ėĄ": 134, + "ĒĖĖ": 135, + "ėĖĖĕ": 136, + "ĖĖĖĖĖĕĖĖ": 137, + "ęă": 138, + "Ęā": 139, + "ĖĕĖĖĕĖĖĖĖĖĖĖĖ": 140, + "ėĖĖĕĖĖ": 141, + "ĘĂ": 142, + "ĚĖĖ": 143, + "ĕĖė": 144, + "ĖĖĖĖĖĖĖĖĖĖĖ": 145, + "ĖĀ": 146, + "ĔĖĖĕĖĖ": 147, + "ēĖ": 148, + "ĘĖĕ": 149, + "ĖĔĖĖĕ": 150, + "ĔĖĕ": 151, + "ĘĄ": 152, + "ĖĖĖĖĖĖĖĖĖĖĖĖ": 153, + "ĖĖĖĖĖėĖĖ": 154, + "ĘĖĖėĖĖ": 155, + "ĔĖĖĕ": 156, + "ĘĖė": 157, + "ęĖ": 158, + "ėė": 159, + "ĕĖĕ": 160, + "ĖĂ": 161, + "ĔĖĖĖĖĖ": 162, + "Ėā": 163, + "ĕă": 164, + "ėĖĔ": 165, + "ĖĖĖĖėĖĖ": 166, + "Ęĕ": 167, + "Ĕĕ": 168, + "ĖĄ": 169, + "đĖĖ": 170, + "Ėę": 171, + "ĖĕĖĖėĖĖ": 172, + "ĖĒ": 173, + "ēĖė": 174, + "ĖĕĖė": 175, + "ĘĖĖĖĖĖ": 176, + "ėĖĖĖĖĖĖĖĖĖĖ": 177, + "ĕĖėĖĖ": 178, + "ėĖėĖĖ": 179, + "ĖĕĖĖĖĖĖĖĖĖĕĖĖ": 180, + "ĖĕĖĖĖĖĖĖĖĖėĖĖ": 181, + "ĕĖĘ": 182, + "ęĀ": 183, + "ėĖĘ": 184, + "ĘĖĖĘĖĖ": 185, + "ĖĕĖĖĖĖĖĕĖĖĖĖĖ": 186, + "ėĖĖėĖĖėĖĖ": 187, + "ĖĕĖĖĕĖĖĕĖĖĖĖĖ": 188, + "ĔĖĖĔĖĖ": 189, + "ėāĖĕĖĖĖĖĖĖĖĖĖĖĖ": 190, + "ěĖĖ": 191, + "ėą": 192, + 
"ĖēĖĖ": 193, + "ĕĖĖĕ": 194, + "ėĔ": 195, + "ėĆ": 196, + "ĖĖĖĖĕĖĖ": 197, + "ęĂ": 198, + "ĖĖĖĖĖĖĖĖĖĖėĖĖ": 199, + "ĖėĖĖėĖĖ": 200, + "ĖĖĖĖė": 201, + "Ęăĕ": 202, + "ĖĘĖĖ": 203, + "ēĖĕ": 204, + "ĖĚ": 205, + "ęĖĕ": 206, + "ĖĕĖĖĖĖĖėĖĖĖĖĖ": 207, + "ėć": 208, + "ęā": 209, + "ėĘ": 210, + "ėĖĕĖĖ": 211, + "ĖĖĖĖĖĖĖėĖĖĖĖĖ": 212, + "ėĖĖĖĖĖėĖĖ": 213, + "ĖĖĖĖĖĖĖĖĖĖĕĖĖ": 214, + "ĖĖĖėĖĖ": 215, + "Ėđ": 216, + "ĕĖĔ": 217, + "ėĈ": 218, + "Ěă": 219, + "ėĕĖĖ": 220, + "ęĄ": 221, + "ĖĖĕĖė": 222, + "ėē": 223, + "ĖĐ": 224, + "ėĖĖĖĖĕ": 225, + "ėėĖĖ": 226, + "ĕĀ": 227, + "ėĒ": 228, + "ėĖĖĖĖĖĕĖĖ": 229, + "ęĖė": 230, + "ĕĖē": 231, + "ĕĔ": 232, + "ėĖĖĖĖĖĖ": 233, + "ĖĕĖĕ": 234, + "ĖĕĖĖĕĖĖĕĖĖ": 235, + "ĖĔĖĖĕĖĖ": 236, + "ĖĕĕĖĖĖĖĖĖĖĖĖĖ": 237, + "ĖĕĖĖĕĖĖĕĖĖĕĖĖ": 238, + "ĖĖĖĖĖĖĖĕĖĖĖĖĖ": 239, + "ĕĖĖĖĖ": 240, + "Ėě": 241, + "ėĕĖ": 242, + "ĖĖĖĖėĖĖĖĖĖĖĖĖ": 243, + "ėĐ": 244, + "ĖĕĖĖĕĖĖĖĖĖ": 245, + "ėĖē": 246, + "ėĎ": 247, + "ĔĖĘ": 248, + "ėď": 249, + "ėĖĖĘĖĖ": 250, + "ėđ": 251, + "ĕā": 252, + "ĕĂ": 253, + "ėĖę": 254, + "ĖĜ": 255, + "ĕĖę": 256, + "ĖĖĖĖĕĖĖĖĖĖĖĖĖ": 257, + "ĐĖĖ": 258, + "Ėď": 259, + "ĖĕĖĖĖĖĖĕĖĖĕĖĖ": 260, + "ėę": 261, + "ĖĕĖĖĕĖĖĖĖĖĕĖĖ": 262, + "ĖĔĖĖĖĖĖ": 263, + "ĖėĖĖĖĖĖ": 264, + "ĖĖĖĖĖĖĖĖėĖĖ": 265, + "ėėĖ": 266, + "ĕē": 267, + "ĖĕĖĖĖ": 268, + "ėĚ": 269, + "ĕĕĖĖ": 270, + "ĜĖĖ": 271, + "ĔĖĔ": 272, + "Ęą": 273, + "ĔĖĖėĖĖ": 274, + "ĘĖĖĕĖĖ": 275, + "ĖĖĖĖĖĖĖĖĕĖĖ": 276, + "ĖĕĖĖĔ": 277, + "ĖėĖĖĕ": 278, + "ĖĕĖĖĖĖĖėĖĖėĖĖ": 279, + "ĘĆ": 280, + "ĖĖĖĖĖĕĖĖĖĖĖ": 281, + "ĖĎ": 282, + "ĖĕĖĖĔĖĖ": 283, + "ĘĖĔ": 284, + "ĖĕėĖĖĖĖĖĖĖĖĖĖ": 285, + "ĖĖĖĖĖĖĖėĖĖėĖĖ": 286, + "ĕĖĖĖ": 287, + "Ėĝ": 288, + "ĖĕĖĖĕĖĖĖĖĖėĖĖ": 289, + "ĒĖė": 290, + "Ęć": 291, + "ėĂĖĕĖĖĖĖĖĖĖĖĖĖĖ": 292, + "ĕĄ": 293, + "ėĕĖė": 294, + "ĘĖĘ": 295, + "ĖĔĖĖĔ": 296, + "ĘĖĖĕ": 297, + "ĖĔĖė": 298, + "ėě": 299, + "ĘĈ": 300, + "Ėĕė": 301, + "Ĕă": 302, + "ĘĎ": 303, + "ĖĖĕĖĖĖĖĖĖĖĖ": 304, + "ēĖĖĕĖĖ": 305, + "ĖĕĕĖė": 306, + "Ęď": 307, + "ĖĔĕ": 308, + "ĕĖĒ": 309, + "Ęăĕĕ": 310, + "ĖĔĖĖĕĖĖĖĖĖĖĖĖ": 311, + "ĖĖĖĖėĖĖėĖĖĖĖĖ": 312, + "ĘĒ": 313, + "Ěā": 314, + "ĖĕĖ": 315, + 
"Ęē": 316, + "Ęđ": 317, + "ĕė": 318, + "ĘĐ": 319, + "ĖĖĕĖĕ": 320, + "ĚĂ": 321, + "ĘĔ": 322, + "ĕĖĚ": 323, + "ėĖĚ": 324, + "ĖĞ": 325, + "ėăĖĕĖĖĖĖĖĖĖĖĖĖĖ": 326, + "ĚĀ": 327, + "ėĜ": 328, + "ĔĖėĖĖ": 329, + "ėėĖĕ": 330, + "ėĖĖĕĖĖĖĖĖ": 331, + "ęăĕ": 332, + "ĘĘ": 333, + "ĖĖĖĖĖĖĖĕĖĖĕĖĖ": 334, + "ėăĕ": 335, + "ėĖĒ": 336, + "ēĖĖĔĖĖ": 337, + "ęĖĖėĖĖ": 338, + "ēĕ": 339, + "ĕĖĖėĖĖ": 340, + "ĖĒĖĖ": 341, + "Ėą": 342, + "ĖĖĖĖĖĖĖėĖĖ": 343, + "ĖĖĕĖĖĖĖĖĖĖĖĖĖ": 344, + "ĖĆ": 345, + "ĕĕĖė": 346, + "ęĖĖĘĖĖ": 347, + "Ėć": 348, + "ėĝ": 349, + "ĘĚ": 350, + "ĖĖĖĖėĖĖėĖĖ": 351, + "ĖĈ": 352, + "ĖĔĖĖĖĖĖĖĖĖĖĖĖ": 353, + "ėĀĖĖĖĖĖĖĖĖĖĖĖĖĖ": 354, + "ĘĕĖ": 355, + "ĔĖē": 356, + "ĖĕĖėĖĖ": 357, + "ĖĖĕĖĖĖĖĖėĖĖ": 358, + "ĔĖę": 359, + "ĖĖĖĖĖĕĖĖĕĖĖ": 360, + "Ęę": 361, + "ĖĕĖĖĖĖĖėĖĖ": 362, + "ĖĕėĖĖĖĖ": 363, + "Ĕā": 364, + "ďĖĖ": 365, + "ĖĕĖĖėĖĖĖĖĖĖĖĖ": 366, + "ėĖĖĖĖĖĖĖĖĖĖĖ": 367, + "ėāĖĖĖĖĖĖĖĖĖĖĖĖĖ": 368, + "ĕĖĖĖĖĖ": 369, + "Ėğ": 370, + "ĝĖĖ": 371, + "ĖĔĖĖĔĖĖ": 372, + "ėĖĘĖĖ": 373, + "ĖĔĖĖĕĖĖĕĖĖĖĖĖ": 374, + "ĔĕĖ": 375, + "ėĞ": 376, + "ĖĖĖĖĕĖĖĖĖĖ": 377, + "ĖĖĕĖĖėĖĖ": 378, + "ĖėĖė": 379, + "ĒĖĕ": 380, + "Ęě": 381, + "ĖĖĖĖĖĖĖĕĖĖ": 382, + "ĖĖĖĖėĖĖėĖĖėĖĖ": 383, + "ĖĖĖĖĘĖĖ": 384, + "ĚĖĕ": 385, + "ęĕ": 386, + "ĖĖĖĖĖĕ": 387, + "ėĖĔĖĖ": 388, + "ėėĖė": 389, + "ĔĀ": 390, + "ĖęĖĖ": 391, + "ĖĖĖĖĕĖĖĕĖĖĖĖĖ": 392, + "ĘėĖ": 393, + "ĖĕĕĖĕ": 394, + "ĖėĖĕ": 395, + "ĘĜ": 396, + "ĘĖē": 397, + "ĕĖđ": 398, + "ėĖĖĕĖĖĕĖĖ": 399, + "ĖĖĖĖėĖĖĖĖĖ": 400, + "ĕĖĘĖĖ": 401, + "ĖėĖĖĘĖĖ": 402, + "ĕĒ": 403, + "ĔĂ": 404, + "ĚĖė": 405, + "đĖĖĖĖĖĖĖĖĖĖĖĖ": 406, + "ėĖĕĖĖĖ": 407, + "ėğ": 408, + "ĖĖĖĕ": 409, + "ĚĄ": 410, + "ėĖě": 411, + "ė,": 412, + "ėĄĖĕĖĖĖĖĖĖĖĖĖĖĖ": 413, + "ēĖĖĖĖĖ": 414, + "Ęĝ": 415, + "ěă": 416, + "ĘĖę": 417, + "ēĖĘ": 418, + "ėĔĖĖ": 419, + "ĘĖĖĖĖĖĖĖĖ": 420, + "ĕĖě": 421, + "ėĖđ": 422, + "ĘĞ": 423, + "ėĕĖĕ": 424, + "Ė!": 425, + "ĕĖĕĖĖ": 426, + "Ė,": 427, + "ĖėĖĖĖ": 428, + "ėĘĖĖ": 429, + "ĔĖĒ": 430, + "Ė\"": 431, + "ĕĕĖĖĖĖĖĖĖĖĖĖĖ": 432, + "ĖĖėĖĖĖĖĖĖĖĖĖĖ": 433, + "ĘĖĖĖ": 434, + "ė!": 435, + "ĕĖĖĕĖĖ": 436, + "ĖĖĕĖĖĕĖĖĖĖĖ": 437, + "ěā": 
438, + "ĕę": 439, + "ĖĕĖĖĖĖĖĕĖĖ": 440, + "ēĖĔ": 441, + "ĖĖĖĖėĖĖĖĖĖėĖĖ": 442, + "Ęğ": 443, + "Ę,": 444, + "ĚĖĖĖĖĖĖĖĖĖĖĖĖ": 445, + "ĖĔĖĖĕĖĖĕĖĖĕĖĖ": 446, + "ėėĖĖĖĖ": 447, + "ėĖĖėĖĖĕĖĖ": 448, + "ĖĖĖĖĔ": 449, + "ĖėĖĖĕĖĖ": 450, + "ĔĔ": 451, + "ĘĀĖĖĖĖĖĖĖĖĖĖĖĖĖ": 452, + "ĔĖĚ": 453, + "ęĖĖĖĖĖ": 454, + "ĔĖĖēĖĖ": 455, + "ĕĖĐ": 456, + "ĖĖĖĖĕĖĖĕĖĖ": 457, + "ĔėĖ": 458, + "ĕđ": 459, + "Ę!": 460, + "Ęė": 461, + "ėĖĕėĖĖ": 462, + "ĖĖĖĖĖĖĕ": 463, + "ĖėĖĖĖĖĖĖĖĖĖĖĖ": 464, + "ĕĘ": 465, + "ėĀĖĕĖĖĕĖĖĖĖĖĖĖĖ": 466, + "ĖĖĖĖĖĖĖĕ": 467, + "ė\"": 468, + "ēĖĖēĖĖ": 469, + "ĖĖĖĖėĖĖĖĖĖĕĖĖ": 470, + "ęĖĔ": 471, + "ĒĖ": 472, + "ėĖĜ": 473, + "ĎĖĖ": 474, + "ĘāĖĖĖĖĖĖĖĖĖĖĖĖĖ": 475, + "ėąĖĕĖĖĖĖĖĖĖĖĖĖĖ": 476, + "ēĖĖĕ": 477, + "ĘĖĚ": 478, + "ĖĖĖĖĕĖĖĕĖĖĕĖĖ": 479, + "ėĖĖĖĖĖĖĖĕĖĖ": 480, + "ĕĖĜ": 481, + "ĕĔĖĖ": 482, + "ĞĖĖ": 483, + "ĖĔĖĖėĖĖ": 484, + "ĖĕĖĖėĖĖėĖĖĖĖĖ": 485, + "đĖė": 486, + "ĖĖĖė": 487, + "ĖĕĖĖė": 488, + "ėĖĖĖĖė": 489, + "ėĖėėĖĖ": 490, + "ĖĖĖĖĕĖĖĖĖĖĕĖĖ": 491, + "ĕĖĖĖĖĖĖĖĖ": 492, + "Ėăėĕ": 493, + "ėĖĖĖĖĖĖĖĖĖĖĖĖ": 494, + "ĖĕĖĖĕĖĖėĖĖĖĖĖ": 495, + "ĔĖĖĕĖė": 496, + "ĖĖėĖĖĖĖ": 497, + "ĕėĖĖ": 498, + "ėĖĖėĖĖėĖĖĖĖĖ": 499, + "ĕĚ": 500, + "ěĂ": 501, + "ĖĖĖĖĖĖĖĖĖ": 502, + "ĔĖđ": 503, + "ėăėĖĖ": 504, + "Ė#": 505, + "ĘĖĖęĖĖ": 506, + "ēā": 507, + "ėăĖĕĖĖ": 508, + "ėĖĐ": 509, + "ėāĖĕĖĖĕĖĖĖĖĖĖĖĖ": 510, + "ęĖĘ": 511, + "ĔĖĖĕĖĖĖĖĖ": 512, + "ĘĖĒ": 513, + "ĕĕĖĖĕ": 514, + "ĖĖĖĖĕĖĖĖĖĖėĖĖ": 515, + "ĖēĖė": 516, + "ĔĖĖĕĖĖĕĖĖ": 517, + "ĕĕĖĖĖĖĖ": 518, + "ė#": 519, + "ėĔĖė": 520, + "ęą": 521, + "ĒĖĖĖĖĖĖĖĖĖĖĖĖ": 522, + "ĕĖēĖĖ": 523, + "ěĀ": 524, + "ĖĖĖĖĘ": 525, + "ęĆ": 526, + "ėĖĖĖĖĕĖĖĖĖĖ": 527, + "Ę\"": 528, + "ĕ,": 529, + "ėĖĖĕĖĖėĖĖ": 530, + "ĔĖĖĕĖĖĖĖĖĖĖĖ": 531, + "ĕĖĖĖĖĖĖĖĖĖĖĖĖ": 532, + "ĖĔĖĖĕĖĖĕĖĖ": 533, + "ēă": 534, + "ĔĖĖĖ": 535, + "ęĘ": 536, + "ęć": 537, + "ėĖĖĕĖė": 538, + "ĕě": 539, + "ĖĕĕĖĕĖĖĖĖĖĖĖĖ": 540, + "ęĈ": 541, + "ĖĕĖĖĕĖĖĕĖĖėĖĖ": 542, + "ęĕĖ": 543, + "ęĒ": 544, + "ĖĕĖĖĕĖė": 545, + "ė$": 546, + "ĖĕĖĖĖĖĖĕĖĖėĖĖ": 547, + "ĔĖě": 548, + "ĔĖĖĖĖĖĖĖĖ": 549, + "ĕĐ": 550, + "ĔĄ": 551, + "ĚĖ": 552, + "ĖĔĖĖĕĖĖĖĖĖĕĖĖ": 553, + 
"ėĖĕĖĖĖĖĖĖĖĖ": 554, + "ĕĕĖĕ": 555, + "ęĖĖęĖĖ": 556, + "ĖĕĖĖĕĖĖĕ": 557, + "ĔĖĖĕĖĖėĖĖ": 558, + "ĖĔĖĕ": 559, + "ĘėĖĖ": 560, + "ĖēĖĖĕĖĖ": 561, + "ĖĖĖĖĖĕĖĖėĖĖ": 562, + "ęėĖ": 563, + "Ę#": 564, + "ĕĖď": 565, + "ėĆĖĕĖĖĖĖĖĖĖĖĖĖĖ": 566, + "ĖăėĖĖ": 567, + "ėĖėĖĖĖ": 568, + "ĕĖĖĔĖĖ": 569, + "ĖĕĖĖĘĖĖ": 570, + "ęĔ": 571, + "ęĚ": 572, + "Ēĕ": 573, + "ęě": 574, + "Ė$": 575, + "ĖėĖĖėĖĖĖĖĖĖĖĖ": 576, + "ĘĖě": 577, + "ęĐ": 578, + "ĕĜ": 579, + "ėĕĖĖĕ": 580, + "ĔĖĐ": 581, + "ę,": 582, + "ęđ": 583, + "ėĕĖĖĖĖĖĖĖĖĖĖĖ": 584, + "ė%": 585, + "ėĖĝ": 586, + "ēĕĖ": 587, + "ēĖē": 588, + "ęď": 589, + "ēĀ": 590, + "ĖĖĖĖĔĖĖ": 591, + "ĖĕĖĖĖĖĖėĖĖĕĖĖ": 592, + "ėĖĖĖĖĖĖĖ": 593, + "Ę$": 594, + "ęĎ": 595, + "Ėăė": 596, + "ėĖēĖĖ": 597, + "đę": 598, + "ėĕĖĖĖĖĖ": 599, + "ĔĖĖĕĖĕ": 600, + "ĖĕėĖĕĖĖĖĖĖĖĖĖ": 601, + "ęĜ": 602, + "ĖēĖĖĕ": 603, + "ęăĔ": 604, + "Ēę": 605, + "ĖĕĖĖėĖĖėĖĖėĖĖ": 606, + "ėĕĕ": 607, + "ĖđĖĖ": 608, + "ė&": 609, + "Đę": 610, + "ĕĖĝ": 611, + "ĖĕĖĖėĖĖĖĖĖ": 612, + "Ė%": 613, + "ĕď": 614, + "ĘĕĖė": 615, + "ĘăĔ": 616, + "đĖĕ": 617, + "ėĖĖėĖĖĘĖĖ": 618, + "ēĂ": 619, + "ěĖĕ": 620, + "ĕĖĔĖĖ": 621, + "Ę&": 622, + "ĖĖĖĖĖĖėĖĖ": 623, + "ĖĔĖĖĕĖĖĖĖĖėĖĖ": 624, + "ğĖĖ": 625, + "ĖĖĖĖĖĖĖėĖĖĕĖĖ": 626, + "ēĖę": 627, + "ĖĔĖĖĖĖĖĕĖĖĖĖĖ": 628, + "ėĖĖĖė": 629, + "ĖĕĖĖĖĖĕ": 630, + "ĕĔĖė": 631, + "ėćĖĕĖĖĖĖĖĖĖĖĖĖĖ": 632, + "ėĖď": 633, + "ĖĖĖĖĖĖĖĕĖĖėĖĖ": 634, + "ėăė": 635, + "ēĖĖėĖĖ": 636, + "ĖĕĖĖĕĖĖėĖĖėĖĖ": 637, + "Ę%": 638, + "Ĝā": 639, + "ĘĖđ": 640, + "ĔĖĜ": 641, + "ěĖė": 642, + "ęē": 643, + "ĘĖĜ": 644, + "Ė&": 645, + "ęĝ": 646, + "ĕĖęĖĖ": 647, + "ėĖĖėĖĖĖĖĖĖĖĖ": 648, + "ĕĝ": 649, + "ēĖĒ": 650, + "ĖĘĖĖĘĖĖ": 651, + "ĖĖĖėĖ": 652, + "ĖĚĖĖ": 653, + "ēĔ": 654, + "ĖĔĖĖĖĖĖĖĖĖĕĖĖ": 655, + "ėĂĖĖĖĖĖĖĖĖĖĖĖĖĖ": 656, + "ĘĕĖĖ": 657, + "ĕĎ": 658, + "ėĖĖĕĖĕ": 659, + "ĖĘĖĖėĖĖ": 660, + "ĖĘĖĕ": 661, + "ęĖĖĕ": 662, + "ĖĕĖĖĖĖĖĖĖĖĖĖ": 663, + "ė'": 664, + "ėĖęĖĖ": 665, + "ėĈĖĕĖĖĖĖĖĖĖĖĖĖĖ": 666, + "ĔĖėĕĖĖ": 667, + "ĕć": 668, + "ėăĖĖĖĖĖĖĖĖĖĖĖĖĖ": 669, + "ĕą": 670, + "ĖėĖĖėĖĖėĖĖĖĖĖ": 671, + "ĖĔĖĖĖĖĖĖĖĖėĖĖ": 672, + "ěĖĖĖĖĖĖĖĖĖĖĖĖ": 673, + 
"ĔĖĖĘĖĖ": 674, + "ėĕĖėĖĖ": 675, + "Ėėĕ": 676, + "ĘėĖĕ": 677, + "ĕĆ": 678, + "ēę": 679, + "ĖĕĖĖėĖĖĖĖĖėĖĖ": 680, + "ėăĖĖĖĖ": 681, + "ĕĕĖĖĕĖĖ": 682, + "ĔĖď": 683, + "ĒĖĘ": 684, + "ĖėĖĖėĖĖĖĖĖ": 685, + "Ĕē": 686, + "Ēā": 687, + "Ė'": 688, + "ęĞ": 689, + "ĕēĖĖ": 690, + "ĕĈ": 691, + "ęĖĖĕĖĖ": 692, + "ēĖĚ": 693, + "ĘăĖĖĖĖ": 694, + "ĕĞ": 695, + "Ĝă": 696, + "ĕĕĕ": 697, + "ĖĕĖĘ": 698, + "ĖĕĖĖėĖĖĖĖĖĕĖĖ": 699, + "ĖĕėĖĕ": 700, + "Ĕę": 701, + "ĠĖĖ": 702, + "ďę": 703, + "ĒĖĔ": 704, + "ėėĕ": 705, + "ėăĖĕĖė": 706, + "ĖĕĕĖĔ": 707, + "ėĖėĖĖĖĖĖĖĖĖ": 708, + "ęğ": 709, + "ėĖĕĕĖĖ": 710, + "Ę'": 711, + "Ę+": 712, + "ĔĕĖė": 713, + "ĖėĖĖėĖĖėĖĖ": 714, + "ėĖĕĖĖĕ": 715, + "ĖĕĖĖĘ": 716, + "Ĕ,": 717, + "ĖĕĖĖĕĖĖėĖĖ": 718, + "ęę": 719, + "ĕĖĎ": 720, + "ėĖĖĖĖėĖĖ": 721, + "ėĖĞ": 722, + "ĒĖĖĔĖĖ": 723, + "ĘĖĖĔĖĖ": 724, + "ė(": 725, + "ĖĘĖė": 726, + "ĖĖĕĖĖĕ": 727, + "ĖēĖĖĔĖĖ": 728, + "ĘĖĐ": 729, + "ĔĖĖĕĖĖĔĖĖ": 730, + "ĖĕĕĖĖĖĖ": 731, + "ę!": 732, + "ěĄ": 733, + "ėĖĖĖĖĖĖĖĖĖ": 734, + "ĔĖĖĕĖĖĕĖĖĖĖĖ": 735, + "ĘăĖĖĖĖĖĖĖĖĖĖĖĖĖ": 736, + "ĔĒ": 737, + "ĖėĖ": 738, + "ĖĔĖĖēĖĖ": 739, + "ėĖĖĖĖĕĖĖĕĖĖ": 740, + "ĘāĖĕĖĖĖĖĖĖĖĖĖĖĖ": 741, + "ēėĖ": 742, + "ĖėĖĖĘ": 743, + "ĖēĖĖĖĖĖ": 744, + "ĘăĖė": 745, + "ĘăĕĖĖ": 746, + "ėăĖĔ": 747, + "ė)": 748, + "ĘĀĖĕĖĖĖĖĖĖĖĖĖĖĖ": 749, + "ėĕĖĖĖĖ": 750, + "Ěăĕ": 751, + "ēĖđ": 752, + "ĖĘĖĖĖĖĖ": 753, + "ĕğ": 754, + "ĖĖĖĖĕĖĖėĖĖĖĖĖ": 755, + "ĔĘ": 756, + "ĖĔĖĖĖĖĖėĖĖĖĖĖ": 757, + "ĔĕĖĖ": 758, + "ĜĂ": 759, + "Ę(": 760, + "ĘĖĕėĖĖ": 761, + "Ė)": 762, + "ęĖĚ": 763, + "ĘĖĝ": 764, + "ĔĖĝ": 765, + "ĖĕĖĖėĖĖėĖĖ": 766, + "ĖėĖĖėĖĖėĖĖėĖĖ": 767, + "ėĖėĕĖĖ": 768, + "ĕĖėĖĖĕ": 769, + "Ė(": 770, + "ĖĖĖĖĕĖĖĖ": 771, + "ęĖę": 772, + "ėĖĕėĖĕ": 773, + "ĔĖĖĔĖĖĔĖĖ": 774, + "ĘĂĕ": 775, + "ĕĖĞ": 776, + "Ę*": 777, + "ėĔĖĕ": 778, + "ė*": 779, + "ĖĕĖĖĕĖĖĔĖĖĕĖĖ": 780, + "ėĕĖĖĕĖĖ": 781, + "ĖĕĖĖĕĖĕ": 782, + "ĕĖĒĖĖ": 783, + "Ěĕ": 784, + "ėēĖĖ": 785, + "ėăĕĕ": 786, + ",Ď": 787, + "ėĖĎ": 788, + "ęĖĒ": 789, + "ĕĖėĖĖĖ": 790, + "ĜĀ": 791, + "ę\"": 792, + "ĐĖė": 793, + "ĖĕĕĖĖĖĖĖĖĖėĖĖ": 794, + "ĚĖĔ": 795, + "ĚĖĘ": 796, + "Ę)": 797, + 
"ęĖē": 798, + "ĕ!": 799, + "ėĀĖĕĖĖĖĖĖĖĖĖĕĖĖ": 800, + "ĖĖĖĖėĖĖėĖĖĕĖĖ": 801, + "ėĘĖĕ": 802, + "ėĖėĖĖĕ": 803, + "ėĖĖĕĖĖĔĖĖ": 804, + "ĖĕĕĖĖĖĖĖĖĖĕĖĖ": 805, + "ĘăĕĖĖĖ": 806, + "ĔĖĖĖĖĕ": 807, + "ĘăĖėĖĖ": 808, + "ĖĕĖĔ": 809, + "ēē": 810, + "ĔĖĖĖĖ": 811, + "ĖĕĖĕĖĖ": 812, + "ĚĖĖĘĖĖ": 813, + "Ďę": 814, + "ĒĖĖĕĖĖ": 815, + "ĖĖĖĖĖĖĖė": 816, + "ėĀĖĕĖĖĖĖĖĕĖĖĖĖĖ": 817, + "ĖĖĖĖĖĖėĖĖĖĖĖ": 818, + "ĘėĖė": 819, + "ĖăĖĔ": 820, + "ėĖĕĖĖĕĖĖĖĖĖ": 821, + "ĔĖĎ": 822, + "ĔĖĕĕĖĖ": 823, + "ĘĂĖĖĖĖĖĖĖĖĖĖĖĖĖ": 824, + "Ěą": 825, + "ĖĕĕĖĖĖĖĕĖĖĖĖĖ": 826, + "ė+": 827, + "Ėăĕ": 828, + "ēĖĐ": 829, + "ēĖě": 830, + "ėĖĖėĖĖĖĖĖ": 831, + "ĘĖĖĖĖ": 832, + "ĒĕĖĖĕ": 833, + "ĖĕĕĖĕĖĖĕĖĖĖĖĖ": 834, + "ĖĔĖĖĔĖĖĕĖĖĖĖĖ": 835, + "ĖĖĕĖĖĕĖĖėĖĖ": 836, + "ėĀĖĕĖĖĕĖĖĕĖĖĖĖĖ": 837, + "ėėĖĖĖĖĖ": 838, + "ėĀĖĕĖĖĖĖĖĖĖĖėĖĖ": 839, + "ĖĕĖĖĔĖĖĕĖĖĖĖĖ": 840, + "ĕ\"": 841, + "ĖĖĖĖĕĖĖĕ": 842, + "Ě,": 843, + "đĖ": 844, + "ĒĖĖēĖĖ": 845, + "ĚĆ": 846, + "ĖėĖĖĖĖĖėĖĖĖĖĖ": 847, + "ĚĖĖėĖĖ": 848, + "ĚĕĖ": 849, + "ėĀĕ": 850, + "ĖĖĖĖĖėĖĖĘĖĖ": 851, + "ĔĚ": 852, + "Ė*": 853, + "ĖĕĕĖĕĖĖ": 854, + "ĖĕĕĖĖĖĖĖ": 855, + "ĘăĕĕĖĖ": 856, + "ĖĔĖĖē": 857, + "ĘĄĕ": 858, + "Ěć": 859, + "ĘĕĖĕ": 860, + "ę#": 861, + "ĘĕĖĔ": 862, + "ĖĖĖĖėĖĖĕĖĖĖĖĖ": 863, + "ĚĈ": 864, + "ĒĕĖ": 865, + "ĔĖĖĖĖĖĕĖĖ": 866, + "ėĎĖĕĖĖĖĖĖĖĖĖĖĖĖ": 867, + "Ĕĕĕ": 868, + "ėĖğ": 869, + "ėĕĖĔ": 870, + "Ēă": 871, + "ėăĖĕĖĖĕĖĖĖĖĖĖĖĖ": 872, + "ėāĖĕĖĖĖĖĖĖĖĖĕĖĖ": 873, + "ĖĖĖĖėĖĖĖ": 874, + "ėĂĕ": 875, + "ĘĖĕĖĖĖ": 876, + "ėėĖĖėĖĖ": 877, + "ĖĕĕĖĖ": 878, + "ĖĔĖĖĖĖĖĕĖĖĕĖĖ": 879, + "ĐĖĖĖĖĖĖĖĖĖĖĖĖ": 880, + "ėĂĖĕĖĖĕĖĖĖĖĖĖĖĖ": 881, + "ĔĖĘĖĖ": 882, + "ĖĖĖĖĕĖĖėĖĖ": 883, + "ĘĖėĖĖ": 884, + "ėāĖĕĖĖĖĖĖĕĖĖĖĖĖ": 885, + "ĘĖėėĖĖ": 886, + "ĖĖĖĖĖĕĖĖĔĖĖ": 887, + "ęĖě": 888, + "ĖėĖĖĖĖĖĖĖĖ": 889, + "ĘĖď": 890, + "ĒĀ": 891, + "ĖĔĖĖĖĖĖėĖĖėĖĖ": 892, + "ėāĖĕĖĖĕĖĖĕĖĖĖĖĖ": 893, + "ĒĂ": 894, + "ĔĕĖĕ": 895, + "ĖĔĖĖĖĖĖĖĖĖ": 896, + "ĖĖĕĖĖĖĖ": 897, + "ęăĕĕ": 898, + "ĖĕĖĖēĖĖ": 899, + "Ĕđ": 900, + "ĔĖĕĖĖĖ": 901, + "ĖĖĕĖĖĔĖĖ": 902, + "ĕĖĖĕĖė": 903, + "ėėĖĔ": 904, + "ĕ#": 905, + "ėĖĖĖĖĖĕ": 906, + "ėĖėėĖė": 907, + "ę$": 908, + "ĘĖĖėĖĖĘĖĖ": 909, + 
"ĝā": 910, + "ĖĕĖĖĕĖĖĕĖĖĔĖĖ": 911, + "ėĖĖĖĖĖĖĖėĖĖ": 912, + "ĘĖĖĖĖĖĖĖĖĖĖĖĖ": 913, + "ėāĖĕĖĖĖĖĖĖĖĖėĖĖ": 914, + "ėĖėĖĖėĖĖĖĖĖ": 915, + "ĖĕėĖĖĖĖĖĖĖėĖĖ": 916, + "ĔĖĖĕĖĖĕĖĖĕĖĖ": 917, + "ėĄĖĖĖĖĖĖĖĖĖĖĖĖĖ": 918, + "ėĖĖėĖĕ": 919, + "ĘĖĖėĖĖĖĖĖĖĖĖ": 920, + "ĘĖĞ": 921, + "ĖĖĖĖĕĖĖĕĖĖėĖĖ": 922, + "ĚėĖ": 923, + "ĖĖĖĖĖĕĖė": 924, + "ĔĖĞ": 925, + "ĔĖĖĖĖė": 926, + "ĜĖĕ": 927, + "ĚĖĖęĖĖ": 928, + "ĖėĖĖęĖĖ": 929, + "ĖĕėĖĕĖĖĕĖĖĖĖĖ": 930, + "ĖėĖĖĖĖĖĖĖĖėĖĖ": 931, + "ėăėĕ": 932, + "ĔĖĖĖĖĖėĖĖ": 933, + "ĖĔĖĖĔĖĖĕĖĖĕĖĖ": 934, + "ėĖĕĖĖĖĖĖ": 935, + "ĕĖğ": 936, + "ėďĖĕĖĖĖĖĖĖĖĖĖĖĖ": 937, + "ĖĐĖĖ": 938, + "ėăĖĔĖĖ": 939, + "ęĄĕ": 940, + "ĚĒ": 941, + "ēĘ": 942, + "ēĒ": 943, + "ėęĖĖ": 944, + "ėĖėĖĖĖĖĖ": 945, + "ĘĖĖĘĖĖĘĖĖ": 946, + "ĐĖĕ": 947, + "ĔĖėĖĖĖ": 948, + "ĔėĖĕ": 949, + "ĕăė": 950, + "ēĖĜ": 951, + "ĘĀĕ": 952, + "ēĖď": 953, + "ėăĖĖĖ": 954, + "ę%": 955, + "ėĖĒĖĖ": 956, + "ĕĖĚĖĖ": 957, + "ĕĕĖĖĕĖĖĖĖĖĖĖĖ": 958, + "ĕ$": 959, + "ĘăĖĕĖĖĖĖĖĖĖĖĖĖĖ": 960, + "ęĂĕ": 961, + "ĜĖė": 962, + "ĔėĖĖ": 963, + "ĖĕĖĖē": 964, + "ĖĕĖĖĕĖĖĔĖĖĖĖĖ": 965, + "ėėĖėĖĖ": 966, + ",ď": 967, + "ę&": 968, + "Ěđ": 969, + "ėăĕĕĖĖ": 970, + "ėĖėĘĖĖ": 971, + "ęĖđ": 972, + "ĖĕĖĖĖĖĖĖĖĖĖ": 973, + "ĖĕėĖĖĖĖĖĖĖĕĖĖ": 974, + "ęăĕĖĖ": 975, + "ĖĀĕ": 976, + "ĖėĖĖėĖĖĖĖĖėĖĖ": 977, + "ęĖĜ": 978, + "ĖĖĕĖėĖĖ": 979, + "ĖĔĖĖĕĖĖĖĖĖ": 980, + "ęĖĖĖ": 981, + "ĖăėĕĖĖ": 982, + "ĔĖėĕĖė": 983, + "ĖēĖĖēĖĖ": 984, + "ĖĕĖĖĔĖĖĕĖĖĕĖĖ": 985, + "ĖĖĖĖėĖĖĘĖĖėĖĖ": 986, + "ėĖĖĖĖĕĖĖ": 987, + "ĖėĖĖĖĖĖĖĖĖĕĖĖ": 988, + "ėĖĖĖĖĖĖĖĖėĖĖ": 989, + "ėĘĖė": 990, + "ėĖĚĖĖ": 991, + "ĖĕĖĕĕ": 992, + "ĖĖĕĖĖĕĖĖĕĖĖ": 993, + "ĖĖĖĖĕĖĖėĖĖėĖĖ": 994, + "ĔĖėĖĖĕ": 995, + "ĖĒĖė": 996, + "ėăĖĕĖĖĖĖ": 997, + "ĖĕĖĖĔĖĖĖĖĖĖĖĖ": 998, + "ĕ%": 999, + "ĚĎ": 1000, + "Ěě": 1001, + "ėĐĖĕĖĖĖĖĖĖĖĖĖĖĖ": 1002, + "ĘĄĖĖĖĖĖĖĖĖĖĖĖĖĖ": 1003, + "ėĖĕĘĖĖ": 1004, + "ĖĔĖĖĕĖĖėĖĖĖĖĖ": 1005, + "ĘăĖĖ": 1006, + "ĔĖėĔĖĖ": 1007, + "ĖĕėĖĖĖĖĕĖĖĖĖĖ": 1008, + "ėĖĖĔĖĖ": 1009, + "Ěď": 1010, + "ĖĖĖĖĕĖė": 1011, + "ęăĕĖĖĖ": 1012, + "ĖĕĕĖĖĖĖėĖĖĖĖĖ": 1013, + "ĖĖĖĖĖĖĕĖĖ": 1014, + "ĘĖĖĘ": 1015, + "ēĄ": 1016, + "ĖĔĖĖĕĖĖĕĖĖėĖĖ": 1017, + "ėĕĖĘ": 1018, + 
"ĖăĘ": 1019, + "ĔĐ": 1020, + "Ĕě": 1021, + "ĖĕĖĖėĖĖĕĖĖĖĖĖ": 1022, + "ĖĕĖĖĖĖĖĖ": 1023 + }, + "merges": [ + [ + "Ė", + "Ė" + ], + [ + "ĖĖ", + "ĖĖ" + ], + [ + "Ė", + "ĕ" + ], + [ + "ĖĖĖĖ", + "ĖĖĖĖ" + ], + [ + "ė", + "ĖĖ" + ], + [ + "ĖĖ", + "Ė" + ], + [ + "ĖĖ", + "ĕ" + ], + [ + "Ėĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĕĖĖĖĖĖĖĖĖ", + "ĖĖĖ" + ], + [ + "ĖĖĖĖ", + "Ė" + ], + [ + "ė", + "Ā" + ], + [ + "ėĀ", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ė", + "ė" + ], + [ + "Ė", + "Ĕ" + ], + [ + "ĖĖĕ", + "ĖĖ" + ], + [ + "Ę", + "ĖĖ" + ], + [ + "Ėĕ", + "ĖĖ" + ], + [ + "ĕ", + "Ė" + ], + [ + "ė", + "Ė" + ], + [ + "ĕ", + "ĖĖ" + ], + [ + "Ĕ", + "ĖĖ" + ], + [ + "ėĖĖ", + "ėĖĖ" + ], + [ + "Ėĕ", + "ĖĖĕ" + ], + [ + "ė", + "ĖĖĖĖĖ" + ], + [ + "ė", + "ă" + ], + [ + "ē", + "ĖĖ" + ], + [ + "Ėĕ", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĖĖĖ" + ], + [ + "Ę", + "ă" + ], + [ + "Ĕ", + "Ė" + ], + [ + "ė", + "ĖĖĖĖ" + ], + [ + "ė", + "Ėĕ" + ], + [ + "ę", + "ĖĖ" + ], + [ + "ĕ", + "ĕ" + ], + [ + "Ė", + "ėĖĖ" + ], + [ + "ė", + "ā" + ], + [ + "ė", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ė", + "ĕ" + ], + [ + "Ėĕ", + "ĕ" + ], + [ + "ė", + "Ėė" + ], + [ + "ĖĖĖĖ", + "ĕ" + ], + [ + "ĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "Ę", + "Ė" + ], + [ + "Ė", + "ă" + ], + [ + "ĖĔ", + "ĖĖ" + ], + [ + "ĖĖĖĖ", + "ĖĖ" + ], + [ + "ė", + "Ă" + ], + [ + "Ėĕ", + "ĖĖĖĖ" + ], + [ + "Ė", + "Ę" + ], + [ + "Ę", + "Ā" + ], + [ + "Ė", + "ē" + ], + [ + "ĖĖĖĖ", + "ĖĖĖ" + ], + [ + "Ėĕ", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖ" + ], + [ + "Ĕ", + "Ėė" + ], + [ + "ė", + "Ą" + ], + [ + "Ē", + "ĖĖ" + ], + [ + "ėĖĖ", + "ĕ" + ], + [ + "ĖĖĖĖ", + "ĖĕĖĖ" + ], + [ + "ę", + "ă" + ], + [ + "Ę", + "ā" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĖ", + "ĕĖĖ" + ], + [ + "Ę", + "Ă" + ], + [ + "Ě", + "ĖĖ" + ], + [ + "ĕ", + "Ėė" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĖ" + ], + [ + "Ė", + "Ā" + ], + [ + "Ĕ", + "ĖĖĕĖĖ" + ], + [ + "ē", + "Ė" + ], + [ + "Ę", + "Ėĕ" + ], + [ + "ĖĔ", + "ĖĖĕ" + ], + [ + "Ĕ", + "Ėĕ" + ], + [ + "Ę", + "Ą" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĖĖ" + ], + [ + "ĖĖĖĖĖ", + "ėĖĖ" + ], 
+ [ + "ĘĖĖ", + "ėĖĖ" + ], + [ + "Ĕ", + "ĖĖĕ" + ], + [ + "Ę", + "Ėė" + ], + [ + "ę", + "Ė" + ], + [ + "ė", + "ė" + ], + [ + "ĕ", + "Ėĕ" + ], + [ + "Ė", + "Ă" + ], + [ + "Ĕ", + "ĖĖĖĖĖ" + ], + [ + "Ė", + "ā" + ], + [ + "ĕ", + "ă" + ], + [ + "ė", + "ĖĔ" + ], + [ + "ĖĖĖĖ", + "ėĖĖ" + ], + [ + "Ę", + "ĕ" + ], + [ + "Ĕ", + "ĕ" + ], + [ + "Ė", + "Ą" + ], + [ + "đ", + "ĖĖ" + ], + [ + "Ė", + "ę" + ], + [ + "ĖĕĖĖ", + "ėĖĖ" + ], + [ + "Ė", + "Ē" + ], + [ + "ē", + "Ėė" + ], + [ + "Ėĕ", + "Ėė" + ], + [ + "Ę", + "ĖĖĖĖĖ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĖĖ" + ], + [ + "ĕĖ", + "ėĖĖ" + ], + [ + "ėĖ", + "ėĖĖ" + ], + [ + "ĖĕĖĖĖĖĖĖĖĖ", + "ĕĖĖ" + ], + [ + "ĖĕĖĖĖĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ĕĖ", + "Ę" + ], + [ + "ę", + "Ā" + ], + [ + "ėĖ", + "Ę" + ], + [ + "ĘĖĖ", + "ĘĖĖ" + ], + [ + "ĖĕĖĖĖĖ", + "ĖĕĖĖĖĖĖ" + ], + [ + "ėĖĖėĖĖ", + "ėĖĖ" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĕĖĖĖĖĖ" + ], + [ + "ĔĖĖ", + "ĔĖĖ" + ], + [ + "ėā", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ě", + "ĖĖ" + ], + [ + "ė", + "ą" + ], + [ + "Ė", + "ēĖĖ" + ], + [ + "ĕ", + "ĖĖĕ" + ], + [ + "ė", + "Ĕ" + ], + [ + "ė", + "Ć" + ], + [ + "ĖĖĖĖ", + "ĕĖĖ" + ], + [ + "ę", + "Ă" + ], + [ + "ĖĖĖĖĖĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "Ė", + "ėĖĖėĖĖ" + ], + [ + "ĖĖĖĖ", + "ė" + ], + [ + "Ęă", + "ĕ" + ], + [ + "Ė", + "ĘĖĖ" + ], + [ + "ē", + "Ėĕ" + ], + [ + "Ė", + "Ě" + ], + [ + "ę", + "Ėĕ" + ], + [ + "ĖĕĖĖĖĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "ė", + "ć" + ], + [ + "ę", + "ā" + ], + [ + "ė", + "Ę" + ], + [ + "ė", + "ĖĕĖĖ" + ], + [ + "ĖĖĖĖĖĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "ėĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĖ", + "ėĖĖ" + ], + [ + "Ė", + "đ" + ], + [ + "ĕ", + "ĖĔ" + ], + [ + "ė", + "Ĉ" + ], + [ + "Ě", + "ă" + ], + [ + "ė", + "ĕĖĖ" + ], + [ + "ę", + "Ą" + ], + [ + "ĖĖĕ", + "Ėė" + ], + [ + "ė", + "ē" + ], + [ + "Ė", + "Đ" + ], + [ + "ėĖĖĖĖ", + "ĕ" + ], + [ + "ė", + "ėĖĖ" + ], + [ + "ĕ", + "Ā" + ], + [ + "ė", + "Ē" + ], + [ + "ėĖĖĖĖ", + "ĖĕĖĖ" + ], + [ + "ę", + "Ėė" + ], + [ + "ĕĖ", + "ē" + ], + [ + "ĕ", + "Ĕ" + ], + [ + "ėĖĖĖĖ", + "ĖĖ" + ], + [ + "Ėĕ", + "Ėĕ" + ], + [ + 
"ĖĕĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "ĖĔ", + "ĖĖĕĖĖ" + ], + [ + "Ėĕĕ", + "ĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĕĖĖĕĖĖ" + ], + [ + "ĖĖĖĖĖĖ", + "ĖĕĖĖĖĖĖ" + ], + [ + "ĕ", + "ĖĖĖĖ" + ], + [ + "Ė", + "ě" + ], + [ + "ė", + "ĕĖ" + ], + [ + "ĖĖĖĖ", + "ėĖĖĖĖĖĖĖĖ" + ], + [ + "ė", + "Đ" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "ėĖ", + "ē" + ], + [ + "ė", + "Ď" + ], + [ + "ĔĖ", + "Ę" + ], + [ + "ė", + "ď" + ], + [ + "ėĖĖ", + "ĘĖĖ" + ], + [ + "ė", + "đ" + ], + [ + "ĕ", + "ā" + ], + [ + "ĕ", + "Ă" + ], + [ + "ėĖ", + "ę" + ], + [ + "Ė", + "Ĝ" + ], + [ + "ĕĖ", + "ę" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "Đ", + "ĖĖ" + ], + [ + "Ė", + "ď" + ], + [ + "ĖĕĖĖĖĖ", + "ĖĕĖĖĕĖĖ" + ], + [ + "ė", + "ę" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĖĖĖĕĖĖ" + ], + [ + "ĖĔ", + "ĖĖĖĖĖ" + ], + [ + "Ėė", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ė", + "ėĖ" + ], + [ + "ĕ", + "ē" + ], + [ + "Ėĕ", + "ĖĖĖ" + ], + [ + "ė", + "Ě" + ], + [ + "ĕ", + "ĕĖĖ" + ], + [ + "Ĝ", + "ĖĖ" + ], + [ + "Ĕ", + "ĖĔ" + ], + [ + "Ę", + "ą" + ], + [ + "ĔĖĖ", + "ėĖĖ" + ], + [ + "Ę", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĕĖĖ" + ], + [ + "ĖĕĖĖ", + "Ĕ" + ], + [ + "ĖėĖĖ", + "ĕ" + ], + [ + "ĖĕĖĖĖĖĖ", + "ėĖĖėĖĖ" + ], + [ + "Ę", + "Ć" + ], + [ + "ĖĖĖĖ", + "ĖĕĖĖĖĖĖ" + ], + [ + "Ė", + "Ď" + ], + [ + "ĖĕĖĖ", + "ĔĖĖ" + ], + [ + "Ę", + "ĖĔ" + ], + [ + "Ėĕ", + "ėĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖ", + "ėĖĖėĖĖ" + ], + [ + "ĕ", + "ĖĖĖ" + ], + [ + "Ė", + "ĝ" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĖĖĖėĖĖ" + ], + [ + "Ē", + "Ėė" + ], + [ + "Ę", + "ć" + ], + [ + "ėĂ", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "Ą" + ], + [ + "ėĕ", + "Ėė" + ], + [ + "ĘĖ", + "Ę" + ], + [ + "ĖĔĖĖ", + "Ĕ" + ], + [ + "Ę", + "ĖĖĕ" + ], + [ + "ĖĔ", + "Ėė" + ], + [ + "ė", + "ě" + ], + [ + "Ę", + "Ĉ" + ], + [ + "Ėĕ", + "ė" + ], + [ + "Ĕ", + "ă" + ], + [ + "Ę", + "Ď" + ], + [ + "ĖĖĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ē", + "ĖĖĕĖĖ" + ], + [ + "Ėĕĕ", + "Ėė" + ], + [ + "Ę", + "ď" + ], + [ + "ĖĔ", + "ĕ" + ], + [ + "ĕĖ", + "Ē" + ], + [ + "Ęă", + "ĕĕ" + ], + [ + "ĖĔĖĖĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖėĖĖ", + "ėĖĖĖĖĖ" + ], 
+ [ + "Ę", + "Ē" + ], + [ + "Ě", + "ā" + ], + [ + "Ėĕ", + "Ė" + ], + [ + "Ę", + "ē" + ], + [ + "Ę", + "đ" + ], + [ + "ĕ", + "ė" + ], + [ + "Ę", + "Đ" + ], + [ + "ĖĖĕ", + "Ėĕ" + ], + [ + "Ě", + "Ă" + ], + [ + "Ę", + "Ĕ" + ], + [ + "ĕĖ", + "Ě" + ], + [ + "ėĖ", + "Ě" + ], + [ + "Ė", + "Ğ" + ], + [ + "ėă", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ě", + "Ā" + ], + [ + "ė", + "Ĝ" + ], + [ + "ĔĖ", + "ėĖĖ" + ], + [ + "ė", + "ėĖĕ" + ], + [ + "ėĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "ęă", + "ĕ" + ], + [ + "Ę", + "Ę" + ], + [ + "ĖĖĖĖĖĖ", + "ĖĕĖĖĕĖĖ" + ], + [ + "ėă", + "ĕ" + ], + [ + "ėĖ", + "Ē" + ], + [ + "ēĖĖ", + "ĔĖĖ" + ], + [ + "ęĖĖ", + "ėĖĖ" + ], + [ + "ē", + "ĕ" + ], + [ + "ĕĖĖ", + "ėĖĖ" + ], + [ + "Ė", + "ĒĖĖ" + ], + [ + "Ė", + "ą" + ], + [ + "ĖĖĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ĖĖĕ", + "ĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ė", + "Ć" + ], + [ + "ĕĕ", + "Ėė" + ], + [ + "ęĖĖ", + "ĘĖĖ" + ], + [ + "Ė", + "ć" + ], + [ + "ė", + "ĝ" + ], + [ + "Ę", + "Ě" + ], + [ + "ĖĖĖĖ", + "ėĖĖėĖĖ" + ], + [ + "Ė", + "Ĉ" + ], + [ + "ĖĔ", + "ĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĀ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ę", + "ĕĖ" + ], + [ + "ĔĖ", + "ē" + ], + [ + "Ėĕ", + "ĖėĖĖ" + ], + [ + "ĖĖĕĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ĔĖ", + "ę" + ], + [ + "ĖĖĖĖ", + "ĖĕĖĖĕĖĖ" + ], + [ + "Ę", + "ę" + ], + [ + "ĖĕĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "Ėĕ", + "ėĖĖĖĖ" + ], + [ + "Ĕ", + "ā" + ], + [ + "ď", + "ĖĖ" + ], + [ + "ĖĕĖĖ", + "ėĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĖĖĖ" + ], + [ + "ėā", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "ĖĖĖĖĖ" + ], + [ + "Ė", + "ğ" + ], + [ + "ĝ", + "ĖĖ" + ], + [ + "ĖĔĖĖ", + "ĔĖĖ" + ], + [ + "ėĖ", + "ĘĖĖ" + ], + [ + "ĖĔĖĖĕ", + "ĖĖĕĖĖĖĖĖ" + ], + [ + "Ĕ", + "ĕĖ" + ], + [ + "ė", + "Ğ" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "Ėė", + "Ėė" + ], + [ + "Ē", + "Ėĕ" + ], + [ + "Ę", + "ě" + ], + [ + "ĖĖĖĖĖĖ", + "ĖĕĖĖ" + ], + [ + "ĖĖĖĖ", + "ėĖĖėĖĖėĖĖ" + ], + [ + "ĖĖĖĖ", + "ĘĖĖ" + ], + [ + "Ě", + "Ėĕ" + ], + [ + "ę", + "ĕ" + ], + [ + "ĖĖĖĖ", + "Ėĕ" + ], + [ + "ė", + "ĖĔĖĖ" + ], + [ + "ė", + "ėĖė" + ], + [ + "Ĕ", + "Ā" + ], + [ + "Ė", + "ęĖĖ" 
+ ], + [ + "ĖĖĖĖĕ", + "ĖĖĕĖĖĖĖĖ" + ], + [ + "Ę", + "ėĖ" + ], + [ + "Ėĕĕ", + "Ėĕ" + ], + [ + "Ėė", + "Ėĕ" + ], + [ + "Ę", + "Ĝ" + ], + [ + "ĘĖ", + "ē" + ], + [ + "ĕĖ", + "đ" + ], + [ + "ėĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "ĕĖ", + "ĘĖĖ" + ], + [ + "ĖėĖĖ", + "ĘĖĖ" + ], + [ + "ĕ", + "Ē" + ], + [ + "Ĕ", + "Ă" + ], + [ + "Ě", + "Ėė" + ], + [ + "đ", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĕ", + "ĖĖĖ" + ], + [ + "ė", + "ğ" + ], + [ + "ĖĖ", + "Ėĕ" + ], + [ + "Ě", + "Ą" + ], + [ + "ėĖ", + "ě" + ], + [ + "ė", + "," + ], + [ + "ėĄ", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ē", + "ĖĖĖĖĖ" + ], + [ + "Ę", + "ĝ" + ], + [ + "ě", + "ă" + ], + [ + "ĘĖ", + "ę" + ], + [ + "ē", + "ĖĘ" + ], + [ + "ė", + "ĔĖĖ" + ], + [ + "Ę", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĕĖ", + "ě" + ], + [ + "ėĖ", + "đ" + ], + [ + "Ę", + "Ğ" + ], + [ + "ėĕ", + "Ėĕ" + ], + [ + "Ė", + "!" + ], + [ + "ĕ", + "ĖĕĖĖ" + ], + [ + "Ė", + "," + ], + [ + "ĖėĖĖ", + "Ė" + ], + [ + "ė", + "ĘĖĖ" + ], + [ + "ĔĖ", + "Ē" + ], + [ + "Ė", + "\"" + ], + [ + "ĕĕ", + "ĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖ", + "ėĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ę", + "ĖĖĖ" + ], + [ + "ė", + "!" + ], + [ + "ĕ", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĕ", + "ĖĖĕĖĖĖĖĖ" + ], + [ + "ě", + "ā" + ], + [ + "ĕ", + "ę" + ], + [ + "ĖĕĖĖĖĖ", + "ĖĕĖĖ" + ], + [ + "ē", + "ĖĔ" + ], + [ + "ĖĖĖĖ", + "ėĖĖĖĖĖėĖĖ" + ], + [ + "Ę", + "ğ" + ], + [ + "Ę", + "," + ], + [ + "Ě", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĔĖĖĕ", + "ĖĖĕĖĖĕĖĖ" + ], + [ + "ė", + "ėĖĖĖĖ" + ], + [ + "ėĖĖėĖĖ", + "ĕĖĖ" + ], + [ + "ĖĖĖĖ", + "Ĕ" + ], + [ + "ĖėĖĖ", + "ĕĖĖ" + ], + [ + "Ĕ", + "Ĕ" + ], + [ + "ĘĀ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĔĖ", + "Ě" + ], + [ + "ę", + "ĖĖĖĖĖ" + ], + [ + "ĔĖĖ", + "ēĖĖ" + ], + [ + "ĕĖ", + "Đ" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "Ĕ", + "ėĖ" + ], + [ + "ĕ", + "đ" + ], + [ + "Ę", + "!" 
+ ], + [ + "Ę", + "ė" + ], + [ + "ėĖĕ", + "ėĖĖ" + ], + [ + "ĖĖĖĖ", + "ĖĖĕ" + ], + [ + "Ėė", + "ĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "Ę" + ], + [ + "ėĀ", + "ĖĕĖĖĕĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖ", + "Ėĕ" + ], + [ + "ė", + "\"" + ], + [ + "ēĖĖ", + "ēĖĖ" + ], + [ + "ĖĖĖĖ", + "ėĖĖĖĖĖĕĖĖ" + ], + [ + "ę", + "ĖĔ" + ], + [ + "Ē", + "Ė" + ], + [ + "ėĖ", + "Ĝ" + ], + [ + "Ď", + "ĖĖ" + ], + [ + "Ęā", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėą", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ē", + "ĖĖĕ" + ], + [ + "ĘĖ", + "Ě" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĕĖĖĕĖĖ" + ], + [ + "ėĖĖĖĖĖĖ", + "ĖĕĖĖ" + ], + [ + "ĕĖ", + "Ĝ" + ], + [ + "ĕ", + "ĔĖĖ" + ], + [ + "Ğ", + "ĖĖ" + ], + [ + "ĖĔĖĖ", + "ėĖĖ" + ], + [ + "ĖĕĖĖėĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "đ", + "Ėė" + ], + [ + "ĖĖĖ", + "ė" + ], + [ + "ĖĕĖĖ", + "ė" + ], + [ + "ėĖĖĖĖ", + "ė" + ], + [ + "ėĖė", + "ėĖĖ" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĖĖĖĕĖĖ" + ], + [ + "ĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "Ėă", + "ėĕ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĖĖĖĖ" + ], + [ + "ĖĕĖĖĕĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "ĔĖĖĕ", + "Ėė" + ], + [ + "ĖĖ", + "ėĖĖĖĖ" + ], + [ + "ĕ", + "ėĖĖ" + ], + [ + "ėĖĖėĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "ĕ", + "Ě" + ], + [ + "ě", + "Ă" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "Ė" + ], + [ + "ĔĖ", + "đ" + ], + [ + "ėă", + "ėĖĖ" + ], + [ + "Ė", + "#" + ], + [ + "ĘĖĖ", + "ęĖĖ" + ], + [ + "ē", + "ā" + ], + [ + "ėă", + "ĖĕĖĖ" + ], + [ + "ėĖ", + "Đ" + ], + [ + "ėā", + "ĖĕĖĖĕĖĖĖĖĖĖĖĖ" + ], + [ + "ę", + "ĖĘ" + ], + [ + "Ĕ", + "ĖĖĕĖĖĖĖĖ" + ], + [ + "ĘĖ", + "Ē" + ], + [ + "ĕĕ", + "ĖĖĕ" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĖĖĖėĖĖ" + ], + [ + "Ėē", + "Ėė" + ], + [ + "Ĕ", + "ĖĖĕĖĖĕĖĖ" + ], + [ + "ĕĕ", + "ĖĖĖĖĖ" + ], + [ + "ė", + "#" + ], + [ + "ė", + "ĔĖė" + ], + [ + "ę", + "ą" + ], + [ + "Ē", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕĖ", + "ēĖĖ" + ], + [ + "ě", + "Ā" + ], + [ + "ĖĖĖĖ", + "Ę" + ], + [ + "ę", + "Ć" + ], + [ + "ėĖĖĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "Ę", + "\"" + ], + [ + "ĕ", + "," + ], + [ + "ėĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "ĔĖĖĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĔ", + "ĖĖĕĖĖĕĖĖ" + ], + [ + "ē", + "ă" + ], + [ + "Ĕ", + "ĖĖĖ" + ], + [ 
+ "ę", + "Ę" + ], + [ + "ę", + "ć" + ], + [ + "ėĖĖĕ", + "Ėė" + ], + [ + "ĕ", + "ě" + ], + [ + "Ėĕĕ", + "ĖĕĖĖĖĖĖĖĖĖ" + ], + [ + "ę", + "Ĉ" + ], + [ + "ĖĕĖĖĕĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "ę", + "ĕĖ" + ], + [ + "ę", + "Ē" + ], + [ + "ĖĕĖĖĕ", + "Ėė" + ], + [ + "ė", + "$" + ], + [ + "ĖĕĖĖĖĖ", + "ĖĕĖĖėĖĖ" + ], + [ + "ĔĖ", + "ě" + ], + [ + "Ĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "Đ" + ], + [ + "Ĕ", + "Ą" + ], + [ + "Ě", + "Ė" + ], + [ + "ĖĔĖĖĕ", + "ĖĖĖĖĖĕĖĖ" + ], + [ + "ė", + "ĖĕĖĖĖĖĖĖĖĖ" + ], + [ + "ĕĕ", + "Ėĕ" + ], + [ + "ęĖĖ", + "ęĖĖ" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĕ" + ], + [ + "ĔĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "ĖĔ", + "Ėĕ" + ], + [ + "Ę", + "ėĖĖ" + ], + [ + "Ėē", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĖĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "ę", + "ėĖ" + ], + [ + "Ę", + "#" + ], + [ + "ĕĖ", + "ď" + ], + [ + "ėĆ", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ėă", + "ėĖĖ" + ], + [ + "ėĖėĖĖ", + "Ė" + ], + [ + "ĕĖĖ", + "ĔĖĖ" + ], + [ + "ĖĕĖĖ", + "ĘĖĖ" + ], + [ + "ę", + "Ĕ" + ], + [ + "ę", + "Ě" + ], + [ + "Ē", + "ĕ" + ], + [ + "ę", + "ě" + ], + [ + "Ė", + "$" + ], + [ + "ĖėĖĖ", + "ėĖĖĖĖĖĖĖĖ" + ], + [ + "ĘĖ", + "ě" + ], + [ + "ę", + "Đ" + ], + [ + "ĕ", + "Ĝ" + ], + [ + "ėĕ", + "ĖĖĕ" + ], + [ + "ĔĖ", + "Đ" + ], + [ + "ę", + "," + ], + [ + "ę", + "đ" + ], + [ + "ėĕ", + "ĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ė", + "%" + ], + [ + "ėĖ", + "ĝ" + ], + [ + "ē", + "ĕĖ" + ], + [ + "ē", + "Ėē" + ], + [ + "ę", + "ď" + ], + [ + "ē", + "Ā" + ], + [ + "ĖĖĖĖ", + "ĔĖĖ" + ], + [ + "ĖĕĖĖĖĖĖ", + "ėĖĖĕĖĖ" + ], + [ + "ėĖĖĖĖ", + "ĖĖĖ" + ], + [ + "Ę", + "$" + ], + [ + "ę", + "Ď" + ], + [ + "Ėă", + "ė" + ], + [ + "ėĖ", + "ēĖĖ" + ], + [ + "đ", + "ę" + ], + [ + "ėĕ", + "ĖĖĖĖĖ" + ], + [ + "ĔĖĖĕ", + "Ėĕ" + ], + [ + "Ėĕė", + "ĖĕĖĖĖĖĖĖĖĖ" + ], + [ + "ę", + "Ĝ" + ], + [ + "Ėē", + "ĖĖĕ" + ], + [ + "ęă", + "Ĕ" + ], + [ + "Ē", + "ę" + ], + [ + "ĖĕĖĖ", + "ėĖĖėĖĖėĖĖ" + ], + [ + "ė", + "ĕĕ" + ], + [ + "Ė", + "đĖĖ" + ], + [ + "ė", + "&" + ], + [ + "Đ", + "ę" + ], + [ + "ĕĖ", + "ĝ" + ], + [ + "ĖĕĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "Ė", + "%" + ], + [ + "ĕ", + "ď" + ], + [ + "Ę", + "ĕĖė" + ], 
+ [ + "Ęă", + "Ĕ" + ], + [ + "đ", + "Ėĕ" + ], + [ + "ėĖĖėĖĖ", + "ĘĖĖ" + ], + [ + "ē", + "Ă" + ], + [ + "ě", + "Ėĕ" + ], + [ + "ĕ", + "ĖĔĖĖ" + ], + [ + "Ę", + "&" + ], + [ + "ĖĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ĖĔ", + "ĖĖĕĖĖĖĖĖėĖĖ" + ], + [ + "ğ", + "ĖĖ" + ], + [ + "ĖĖĖĖĖĖĖ", + "ėĖĖĕĖĖ" + ], + [ + "ēĖ", + "ę" + ], + [ + "ĖĔ", + "ĖĖĖĖĖĕĖĖĖĖĖ" + ], + [ + "ėĖĖ", + "Ėė" + ], + [ + "Ėĕ", + "ĖĖĖĖĕ" + ], + [ + "ĕ", + "ĔĖė" + ], + [ + "ėć", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖ", + "ď" + ], + [ + "ĖĖĖĖĖĖ", + "ĖĕĖĖėĖĖ" + ], + [ + "ėă", + "ė" + ], + [ + "ēĖĖ", + "ėĖĖ" + ], + [ + "ĖĕĖĖĕĖĖ", + "ėĖĖėĖĖ" + ], + [ + "Ę", + "%" + ], + [ + "Ĝ", + "ā" + ], + [ + "ĘĖ", + "đ" + ], + [ + "ĔĖ", + "Ĝ" + ], + [ + "ě", + "Ėė" + ], + [ + "ę", + "ē" + ], + [ + "ĘĖ", + "Ĝ" + ], + [ + "Ė", + "&" + ], + [ + "ę", + "ĝ" + ], + [ + "ĕĖ", + "ęĖĖ" + ], + [ + "ėĖĖ", + "ėĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "ĝ" + ], + [ + "ēĖ", + "Ē" + ], + [ + "Ė", + "ĘĖĖĘĖĖ" + ], + [ + "ĖĖĖ", + "ėĖ" + ], + [ + "Ė", + "ĚĖĖ" + ], + [ + "ē", + "Ĕ" + ], + [ + "ĖĔ", + "ĖĖĖĖĖĖĖĖĕĖĖ" + ], + [ + "ėĂ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ę", + "ĕĖĖ" + ], + [ + "ĕ", + "Ď" + ], + [ + "ėĖĖĕ", + "Ėĕ" + ], + [ + "Ė", + "ĘĖĖėĖĖ" + ], + [ + "ĖĘ", + "Ėĕ" + ], + [ + "ę", + "ĖĖĕ" + ], + [ + "ĖĕĖĖĖĖĖĖĖĖ", + "ĖĖ" + ], + [ + "ė", + "'" + ], + [ + "ėĖ", + "ęĖĖ" + ], + [ + "ėĈ", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĔĖė", + "ĕĖĖ" + ], + [ + "ĕ", + "ć" + ], + [ + "ėă", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "ą" + ], + [ + "ĖėĖĖėĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "ĖĔ", + "ĖĖĖĖĖĖĖĖėĖĖ" + ], + [ + "ě", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĔĖĖ", + "ĘĖĖ" + ], + [ + "ė", + "ĕĖėĖĖ" + ], + [ + "Ėė", + "ĕ" + ], + [ + "Ę", + "ėĖĕ" + ], + [ + "ĕ", + "Ć" + ], + [ + "ē", + "ę" + ], + [ + "ĖĕĖĖ", + "ėĖĖĖĖĖėĖĖ" + ], + [ + "ėă", + "ĖĖĖĖ" + ], + [ + "ĕĕ", + "ĖĖĕĖĖ" + ], + [ + "ĔĖ", + "ď" + ], + [ + "Ē", + "ĖĘ" + ], + [ + "ĖėĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "Ĕ", + "ē" + ], + [ + "Ē", + "ā" + ], + [ + "Ė", + "'" + ], + [ + "ę", + "Ğ" + ], + [ + "ĕ", + "ēĖĖ" + ], + [ + "ĕ", + "Ĉ" + ], + [ + "ę", + "ĖĖĕĖĖ" + ], + [ + 
"ēĖ", + "Ě" + ], + [ + "Ęă", + "ĖĖĖĖ" + ], + [ + "ĕ", + "Ğ" + ], + [ + "Ĝ", + "ă" + ], + [ + "ĕĕ", + "ĕ" + ], + [ + "Ėĕ", + "ĖĘ" + ], + [ + "ĖĕĖĖ", + "ėĖĖĖĖĖĕĖĖ" + ], + [ + "Ėĕ", + "ėĖĕ" + ], + [ + "Ĕ", + "ę" + ], + [ + "Ġ", + "ĖĖ" + ], + [ + "ď", + "ę" + ], + [ + "Ē", + "ĖĔ" + ], + [ + "ė", + "ėĕ" + ], + [ + "ėă", + "ĖĕĖė" + ], + [ + "Ėĕĕ", + "ĖĔ" + ], + [ + "ėĖė", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ę", + "ğ" + ], + [ + "ėĖĕ", + "ĕĖĖ" + ], + [ + "Ę", + "'" + ], + [ + "Ę", + "+" + ], + [ + "Ĕ", + "ĕĖė" + ], + [ + "Ė", + "ėĖĖėĖĖėĖĖ" + ], + [ + "ė", + "ĖĕĖĖĕ" + ], + [ + "ĖĕĖĖ", + "Ę" + ], + [ + "Ĕ", + "," + ], + [ + "ĖĕĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "ę", + "ę" + ], + [ + "ĕĖ", + "Ď" + ], + [ + "ėĖĖĖĖ", + "ėĖĖ" + ], + [ + "ėĖ", + "Ğ" + ], + [ + "ĒĖĖ", + "ĔĖĖ" + ], + [ + "ĘĖĖ", + "ĔĖĖ" + ], + [ + "ė", + "(" + ], + [ + "ĖĘ", + "Ėė" + ], + [ + "ĖĖĕ", + "ĖĖĕ" + ], + [ + "ĖēĖĖ", + "ĔĖĖ" + ], + [ + "ĘĖ", + "Đ" + ], + [ + "ĔĖĖĕĖĖ", + "ĔĖĖ" + ], + [ + "Ėĕĕ", + "ĖĖĖĖ" + ], + [ + "ę", + "!" + ], + [ + "ě", + "Ą" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "Ė" + ], + [ + "ĔĖĖĕ", + "ĖĖĕĖĖĖĖĖ" + ], + [ + "Ęă", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ĕ", + "Ē" + ], + [ + "Ėė", + "Ė" + ], + [ + "ĖĔĖĖ", + "ēĖĖ" + ], + [ + "ėĖĖĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "Ęā", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ē", + "ėĖ" + ], + [ + "ĖėĖĖ", + "Ę" + ], + [ + "Ėē", + "ĖĖĖĖĖ" + ], + [ + "Ęă", + "Ėė" + ], + [ + "Ęă", + "ĕĖĖ" + ], + [ + "ėă", + "ĖĔ" + ], + [ + "ė", + ")" + ], + [ + "ĘĀ", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĕ", + "ĖĖĖĖ" + ], + [ + "Ěă", + "ĕ" + ], + [ + "ēĖ", + "đ" + ], + [ + "ĖĘ", + "ĖĖĖĖĖ" + ], + [ + "ĕ", + "ğ" + ], + [ + "ĖĖĖĖĕĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "Ĕ", + "Ę" + ], + [ + "ĖĔĖĖĖĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "Ĕ", + "ĕĖĖ" + ], + [ + "Ĝ", + "Ă" + ], + [ + "Ę", + "(" + ], + [ + "ĘĖĕ", + "ėĖĖ" + ], + [ + "Ė", + ")" + ], + [ + "ęĖ", + "Ě" + ], + [ + "ĘĖ", + "ĝ" + ], + [ + "ĔĖ", + "ĝ" + ], + [ + "ĖĕĖĖ", + "ėĖĖėĖĖ" + ], + [ + "ĖėĖĖėĖĖ", + "ėĖĖėĖĖ" + ], + [ + "ėĖė", + "ĕĖĖ" + ], + [ + "ĕĖ", + "ėĖĖĕ" + ], + [ + "Ė", + "(" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĖ" 
+ ], + [ + "ęĖ", + "ę" + ], + [ + "ėĖĕ", + "ėĖĕ" + ], + [ + "ĔĖĖĔĖĖ", + "ĔĖĖ" + ], + [ + "ĘĂ", + "ĕ" + ], + [ + "ĕĖ", + "Ğ" + ], + [ + "Ę", + "*" + ], + [ + "ė", + "ĔĖĕ" + ], + [ + "ė", + "*" + ], + [ + "ĖĕĖĖĕĖĖ", + "ĔĖĖĕĖĖ" + ], + [ + "ėĕ", + "ĖĖĕĖĖ" + ], + [ + "ĖĕĖĖĕ", + "Ėĕ" + ], + [ + "ĕĖ", + "ĒĖĖ" + ], + [ + "Ě", + "ĕ" + ], + [ + "ė", + "ēĖĖ" + ], + [ + "ėă", + "ĕĕ" + ], + [ + ",", + "Ď" + ], + [ + "ėĖ", + "Ď" + ], + [ + "ęĖ", + "Ē" + ], + [ + "ĕĖėĖĖ", + "Ė" + ], + [ + "Ĝ", + "Ā" + ], + [ + "ę", + "\"" + ], + [ + "Đ", + "Ėė" + ], + [ + "Ėĕĕ", + "ĖĖĖĖĖĖĖėĖĖ" + ], + [ + "Ě", + "ĖĔ" + ], + [ + "Ě", + "ĖĘ" + ], + [ + "Ę", + ")" + ], + [ + "ę", + "Ėē" + ], + [ + "ĕ", + "!" + ], + [ + "ėĀ", + "ĖĕĖĖĖĖĖĖĖĖĕĖĖ" + ], + [ + "ĖĖĖĖėĖĖėĖĖ", + "ĕĖĖ" + ], + [ + "ė", + "ĘĖĕ" + ], + [ + "ėĖ", + "ėĖĖĕ" + ], + [ + "ėĖĖĕĖĖ", + "ĔĖĖ" + ], + [ + "Ėĕĕ", + "ĖĖĖĖĖĖĖĕĖĖ" + ], + [ + "Ęăĕ", + "ĖĖĖ" + ], + [ + "Ĕ", + "ĖĖĖĖĕ" + ], + [ + "Ęă", + "ĖėĖĖ" + ], + [ + "Ėĕ", + "ĖĔ" + ], + [ + "ē", + "ē" + ], + [ + "Ĕ", + "ĖĖĖĖ" + ], + [ + "Ėĕ", + "ĖĕĖĖ" + ], + [ + "ĚĖĖ", + "ĘĖĖ" + ], + [ + "Ď", + "ę" + ], + [ + "Ē", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĖĖĖĖĖ", + "ė" + ], + [ + "ėĀ", + "ĖĕĖĖĖĖĖĕĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "Ę", + "ėĖė" + ], + [ + "Ėă", + "ĖĔ" + ], + [ + "ė", + "ĖĕĖĖĕĖĖĖĖĖ" + ], + [ + "ĔĖ", + "Ď" + ], + [ + "ĔĖĕ", + "ĕĖĖ" + ], + [ + "ĘĂ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ě", + "ą" + ], + [ + "Ėĕĕ", + "ĖĖĖĖĕĖĖĖĖĖ" + ], + [ + "ė", + "+" + ], + [ + "Ėă", + "ĕ" + ], + [ + "ēĖ", + "Đ" + ], + [ + "ēĖ", + "ě" + ], + [ + "ėĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "Ę", + "ĖĖĖĖ" + ], + [ + "Ē", + "ĕĖĖĕ" + ], + [ + "Ėĕĕ", + "ĖĕĖĖĕĖĖĖĖĖ" + ], + [ + "ĖĔĖĖĔ", + "ĖĖĕĖĖĖĖĖ" + ], + [ + "ĖĖĕĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "ėĀ", + "ĖĕĖĖĕĖĖĕĖĖĖĖĖ" + ], + [ + "ė", + "ėĖĖĖĖĖ" + ], + [ + "ėĀ", + "ĖĕĖĖĖĖĖĖĖĖėĖĖ" + ], + [ + "ĖĕĖĖĔ", + "ĖĖĕĖĖĖĖĖ" + ], + [ + "ĕ", + "\"" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĕ" + ], + [ + "Ě", + "," + ], + [ + "đ", + "Ė" + ], + [ + "ĒĖĖ", + "ēĖĖ" + ], + [ + "Ě", + "Ć" + ], + [ + "ĖėĖĖĖĖĖ", + "ėĖĖĖĖĖ" + 
], + [ + "ĚĖĖ", + "ėĖĖ" + ], + [ + "Ě", + "ĕĖ" + ], + [ + "ėĀ", + "ĕ" + ], + [ + "ĖĖĖĖĖėĖĖ", + "ĘĖĖ" + ], + [ + "Ĕ", + "Ě" + ], + [ + "Ė", + "*" + ], + [ + "Ėĕĕ", + "ĖĕĖĖ" + ], + [ + "Ėĕĕ", + "ĖĖĖĖĖ" + ], + [ + "Ęăĕ", + "ĕĖĖ" + ], + [ + "ĖĔĖĖ", + "ē" + ], + [ + "ĘĄ", + "ĕ" + ], + [ + "Ě", + "ć" + ], + [ + "Ę", + "ĕĖĕ" + ], + [ + "ę", + "#" + ], + [ + "Ęĕ", + "ĖĔ" + ], + [ + "ĖĖĖĖ", + "ėĖĖĕĖĖĖĖĖ" + ], + [ + "Ě", + "Ĉ" + ], + [ + "Ē", + "ĕĖ" + ], + [ + "Ĕ", + "ĖĖĖĖĖĕĖĖ" + ], + [ + "ėĎ", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ĕ", + "ĕĕ" + ], + [ + "ėĖ", + "ğ" + ], + [ + "ėĕ", + "ĖĔ" + ], + [ + "Ē", + "ă" + ], + [ + "ėă", + "ĖĕĖĖĕĖĖĖĖĖĖĖĖ" + ], + [ + "ėā", + "ĖĕĖĖĖĖĖĖĖĖĕĖĖ" + ], + [ + "ĖĖĖĖėĖĖ", + "Ė" + ], + [ + "ėĂ", + "ĕ" + ], + [ + "ĘĖĕ", + "ĖĖĖ" + ], + [ + "ė", + "ėĖĖėĖĖ" + ], + [ + "Ėĕ", + "ĕĖĖ" + ], + [ + "ĖĔ", + "ĖĖĖĖĖĕĖĖĕĖĖ" + ], + [ + "Đ", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĂ", + "ĖĕĖĖĕĖĖĖĖĖĖĖĖ" + ], + [ + "ĔĖ", + "ĘĖĖ" + ], + [ + "ĖĖĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "Ę", + "ĖėĖĖ" + ], + [ + "ėā", + "ĖĕĖĖĖĖĖĕĖĖĖĖĖ" + ], + [ + "ĘĖė", + "ėĖĖ" + ], + [ + "ĖĖĖĖĖĕĖĖ", + "ĔĖĖ" + ], + [ + "ęĖ", + "ě" + ], + [ + "Ėė", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĘĖ", + "ď" + ], + [ + "Ē", + "Ā" + ], + [ + "ĖĔĖĖĖĖĖ", + "ėĖĖėĖĖ" + ], + [ + "ėā", + "ĖĕĖĖĕĖĖĕĖĖĖĖĖ" + ], + [ + "Ē", + "Ă" + ], + [ + "Ĕ", + "ĕĖĕ" + ], + [ + "ĖĔ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĕ", + "ĖĖĖĖ" + ], + [ + "ęă", + "ĕĕ" + ], + [ + "ĖĕĖĖ", + "ēĖĖ" + ], + [ + "Ĕ", + "đ" + ], + [ + "ĔĖĕ", + "ĖĖĖ" + ], + [ + "ĖĖĕĖĖ", + "ĔĖĖ" + ], + [ + "ĕĖĖĕ", + "Ėė" + ], + [ + "ėė", + "ĖĔ" + ], + [ + "ĕ", + "#" + ], + [ + "ėĖĖĖĖ", + "Ėĕ" + ], + [ + "ėĖė", + "ėĖė" + ], + [ + "ę", + "$" + ], + [ + "ĘĖĖėĖĖ", + "ĘĖĖ" + ], + [ + "ĝ", + "ā" + ], + [ + "ĖĕĖĖĕĖĖĕĖĖ", + "ĔĖĖ" + ], + [ + "ėĖĖĖĖ", + "ĖĖĖėĖĖ" + ], + [ + "Ę", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėā", + "ĖĕĖĖĖĖĖĖĖĖėĖĖ" + ], + [ + "ėĖėĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "ĖĕėĖĖĖĖ", + "ĖĖĖėĖĖ" + ], + [ + "ĔĖĖĕ", + "ĖĖĕĖĖĕĖĖ" + ], + [ + "ėĄ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĖ", + "ėĖĕ" + ], + [ + "ĘĖĖ", + "ėĖĖĖĖĖĖĖĖ" + ], + [ + 
"ĘĖ", + "Ğ" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĕĖĖėĖĖ" + ], + [ + "Ě", + "ėĖ" + ], + [ + "ĖĖĖĖ", + "ĖĕĖė" + ], + [ + "ĔĖ", + "Ğ" + ], + [ + "Ĕ", + "ĖĖĖĖė" + ], + [ + "Ĝ", + "Ėĕ" + ], + [ + "ĚĖĖ", + "ęĖĖ" + ], + [ + "ĖėĖĖ", + "ęĖĖ" + ], + [ + "Ėĕė", + "ĖĕĖĖĕĖĖĖĖĖ" + ], + [ + "Ėė", + "ĖĖĖĖĖĖĖĖėĖĖ" + ], + [ + "ėă", + "ėĕ" + ], + [ + "Ĕ", + "ĖĖĖĖĖėĖĖ" + ], + [ + "ĖĔĖĖĔ", + "ĖĖĕĖĖĕĖĖ" + ], + [ + "ė", + "ĖĕĖĖĖĖĖ" + ], + [ + "ĕĖ", + "ğ" + ], + [ + "ėď", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĐ", + "ĖĖ" + ], + [ + "ėă", + "ĖĔĖĖ" + ], + [ + "ęĄ", + "ĕ" + ], + [ + "Ě", + "Ē" + ], + [ + "ē", + "Ę" + ], + [ + "ē", + "Ē" + ], + [ + "ė", + "ęĖĖ" + ], + [ + "ėĖė", + "ĖĖĖĖĖ" + ], + [ + "ĘĖĖĘĖĖ", + "ĘĖĖ" + ], + [ + "Đ", + "Ėĕ" + ], + [ + "ĔĖėĖĖ", + "Ė" + ], + [ + "Ĕ", + "ėĖĕ" + ], + [ + "ĕă", + "ė" + ], + [ + "ēĖ", + "Ĝ" + ], + [ + "ĘĀ", + "ĕ" + ], + [ + "ēĖ", + "ď" + ], + [ + "ėă", + "ĖĖĖ" + ], + [ + "ę", + "%" + ], + [ + "ėĖ", + "ĒĖĖ" + ], + [ + "ĕĖ", + "ĚĖĖ" + ], + [ + "ĕĕ", + "ĖĖĕĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "$" + ], + [ + "Ęă", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ęĂ", + "ĕ" + ], + [ + "Ĝ", + "Ėė" + ], + [ + "Ĕ", + "ėĖĖ" + ], + [ + "ĖĕĖĖ", + "ē" + ], + [ + "ĖĕĖĖĕĖĖ", + "ĔĖĖĖĖĖ" + ], + [ + "ė", + "ėĖėĖĖ" + ], + [ + ",", + "ď" + ], + [ + "ę", + "&" + ], + [ + "Ě", + "đ" + ], + [ + "ėă", + "ĕĕĖĖ" + ], + [ + "ėĖė", + "ĘĖĖ" + ], + [ + "ęĖ", + "đ" + ], + [ + "ĖĕĖĖĖĖĖĖĖĖ", + "Ė" + ], + [ + "Ėĕ", + "ėĖĖĖĖĖĖĖĕĖĖ" + ], + [ + "ęă", + "ĕĖĖ" + ], + [ + "ĖĀ", + "ĕ" + ], + [ + "ĖėĖĖ", + "ėĖĖĖĖĖėĖĖ" + ], + [ + "ęĖ", + "Ĝ" + ], + [ + "ĖĖĕ", + "ĖėĖĖ" + ], + [ + "ĖĔ", + "ĖĖĕĖĖĖĖĖ" + ], + [ + "ę", + "ĖĖĖ" + ], + [ + "Ėă", + "ėĕĖĖ" + ], + [ + "ĔĖė", + "ĕĖė" + ], + [ + "ĖēĖĖ", + "ēĖĖ" + ], + [ + "ĖĕĖĖĔ", + "ĖĖĕĖĖĕĖĖ" + ], + [ + "ĖĖĖĖėĖĖ", + "ĘĖĖėĖĖ" + ], + [ + "ėĖĖĖĖ", + "ĕĖĖ" + ], + [ + "Ėė", + "ĖĖĖĖĖĖĖĖĕĖĖ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ė", + "ĘĖė" + ], + [ + "ėĖ", + "ĚĖĖ" + ], + [ + "Ėĕ", + "Ėĕĕ" + ], + [ + "ĖĖĕ", + "ĖĖĕĖĖĕĖĖ" + ], + [ + "ĖĖĖĖĕĖĖ", + "ėĖĖėĖĖ" + ], + [ + "ĔĖ", + "ėĖĖĕ" + ], + [ + "ĖĒ", + 
"Ėė" + ], + [ + "ėă", + "ĖĕĖĖĖĖ" + ], + [ + "ĖĕĖĖĔ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "%" + ], + [ + "Ě", + "Ď" + ], + [ + "Ě", + "ě" + ], + [ + "ėĐ", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĘĄ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĕ", + "ĘĖĖ" + ], + [ + "ĖĔĖĖĕĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "Ęă", + "ĖĖ" + ], + [ + "ĔĖė", + "ĔĖĖ" + ], + [ + "Ėĕ", + "ėĖĖĖĖĕĖĖĖĖĖ" + ], + [ + "ėĖĖ", + "ĔĖĖ" + ], + [ + "Ě", + "ď" + ], + [ + "ĖĖĖĖĕ", + "Ėė" + ], + [ + "ęă", + "ĕĖĖĖ" + ], + [ + "Ėĕĕ", + "ĖĖĖĖėĖĖĖĖĖ" + ], + [ + "ĖĖĖĖ", + "ĖĖĕĖĖ" + ], + [ + "ĘĖĖ", + "Ę" + ], + [ + "ē", + "Ą" + ], + [ + "ĖĔĖĖĕĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "ė", + "ĕĖĘ" + ], + [ + "Ėă", + "Ę" + ], + [ + "Ĕ", + "Đ" + ], + [ + "Ĕ", + "ě" + ], + [ + "ĖĕĖĖ", + "ėĖĖĕĖĖĖĖĖ" + ], + [ + "Ėĕ", + "ĖĖĖĖĖĖ" + ] + ] + } +} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_all/tokenizer_config.json b/scenestreamer/tokenization/0305_fast_all/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..44b81ebde2224b4e3935b02872938beae622c37c --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_all/tokenizer_config.json @@ -0,0 +1,8 @@ +{ + "added_tokens_decoder": {}, + "clean_up_tokenization_spaces": false, + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "processor_class": "UniversalActionProcessor", + "tokenizer_class": "PreTrainedTokenizer" +} diff --git a/scenestreamer/tokenization/0305_fast_cyc_440000/delta_normalization_quantiles.json b/scenestreamer/tokenization/0305_fast_cyc_440000/delta_normalization_quantiles.json new file mode 100644 index 0000000000000000000000000000000000000000..212204312c4d80639c0763981c96a77ca7604b2b --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_cyc_440000/delta_normalization_quantiles.json @@ -0,0 +1 @@ +{"q_lower": [-0.0562463104724884, -0.04942926615476609, -0.12296247482299805, -0.05916598066687584, -0.049494446702301506, -0.12199340581893922, -0.06487450540065766, -0.048804283142089844, 
-0.12228697001934052, -0.0770511893182993, -0.048270415514707565, -0.11826350688934326, -0.08911120891571045, -0.04901319369673729, -0.12249613285064698], "q_upper": [0.058538779616355896, 0.9725202322006226, 0.12751812219619768, 0.06069798842072492, 0.9720370769500732, 0.12094704627990738, 0.06706408478319645, 0.9723510098457337, 0.12257930755615254, 0.07632199041545394, 0.9750965356826785, 0.12193348586559298, 0.08802430331707001, 0.9738962650299072, 0.1204061508178711]} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_cyc_440000/error_mean.json b/scenestreamer/tokenization/0305_fast_cyc_440000/error_mean.json new file mode 100644 index 0000000000000000000000000000000000000000..84fc29c9841ca3705ce02cc2dcf22cb3f6a7eae8 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_cyc_440000/error_mean.json @@ -0,0 +1 @@ +[[6.0896552864887044e-05, 0.0013265651167509809, 0.050996215752343005], [0.01950295918578097, 0.4473946097059047, 0.06308904025233754], [0.0004593604702810042, 0.01064230356511029, 0.057537467073347955], [0.0007838270269644638, 0.018373398431503596, 0.06566936588958061], [0.0011506386418468116, 0.02828774303892772, 0.07266366599924856]] \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_cyc_440000/processor_config.json b/scenestreamer/tokenization/0305_fast_cyc_440000/processor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..8d9a3f725377bee10763626aa8cf6ddb5ad8ac87 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_cyc_440000/processor_config.json @@ -0,0 +1,8 @@ +{ + "action_dim": null, + "min_token": -22, + "processor_class": "UniversalActionProcessor", + "scale": 10, + "time_horizon": null, + "vocab_size": 1024 +} diff --git a/scenestreamer/tokenization/0305_fast_cyc_440000/special_tokens_map.json b/scenestreamer/tokenization/0305_fast_cyc_440000/special_tokens_map.json new file mode 100644 index 
0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_cyc_440000/special_tokens_map.json @@ -0,0 +1 @@ +{} diff --git a/scenestreamer/tokenization/0305_fast_cyc_440000/tokenizer.json b/scenestreamer/tokenization/0305_fast_cyc_440000/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..c5643ff8517ccc669548dfc2872998aaaa202709 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_cyc_440000/tokenizer.json @@ -0,0 +1,4847 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": true + }, + "post_processor": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": false, + "use_regex": true + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": false, + "vocab": { + "\u0000": 0, + "\u0001": 1, + "\u0002": 2, + "\u0003": 3, + "\u0004": 4, + "\u0005": 5, + "\u0006": 6, + "\u0007": 7, + "\b": 8, + "\t": 9, + "\n": 10, + "\u000b": 11, + "\f": 12, + "\r": 13, + "\u000e": 14, + "\u000f": 15, + "\u0010": 16, + "\u0011": 17, + "\u0012": 18, + "\u0013": 19, + "\u0014": 20, + "\u0015": 21, + "\u0016": 22, + "\u0017": 23, + "\u0018": 24, + "\u0019": 25, + "\u001a": 26, + "\u001b": 27, + "\u001c": 28, + "\u001d": 29, + "\u001e": 30, + "\u001f": 31, + " ": 32, + "!": 33, + "\"": 34, + "#": 35, + "$": 36, + "%": 37, + "&": 38, + "'": 39, + "(": 40, + ")": 41, + "*": 42, + "+": 43, + ",": 44, + "Ā": 45, + "ā": 46, + "Ă": 47, + "ă": 48, + "Ą": 49, + "ą": 50, + "Ć": 51, + "ć": 52, + "Ĉ": 53, + "ĉ": 54, + "Ċ": 55, + "ċ": 56, + 
"Č": 57, + "č": 58, + "Ď": 59, + "ď": 60, + "Đ": 61, + "đ": 62, + "Ē": 63, + "ē": 64, + "Ĕ": 65, + "ĕ": 66, + "Ė": 67, + "ė": 68, + "Ę": 69, + "ę": 70, + "Ě": 71, + "ě": 72, + "Ĝ": 73, + "ĝ": 74, + "Ğ": 75, + "ğ": 76, + "Ġ": 77, + "ĖĖ": 78, + "Ėĕ": 79, + "Ėė": 80, + "ĕĕ": 81, + "ėė": 82, + "ĖĖĖĖ": 83, + "ĕĖ": 84, + "ėĖ": 85, + "ėĕ": 86, + "ėĖĖ": 87, + "ĕĖĖ": 88, + "ĖĔ": 89, + "ĘĖĖ": 90, + "ĕė": 91, + "ĔĖĖ": 92, + "ĖĘ": 93, + "ĖĕĖĖ": 94, + "ĖėĖĖ": 95, + "ĕĔ": 96, + "ėĖĕ": 97, + "ēĖĖ": 98, + "ęĖĖ": 99, + "ėĘ": 100, + "ėĔ": 101, + "ĕĖĕ": 102, + "ĕĖė": 103, + "ĕĘ": 104, + "Ėē": 105, + "Ėę": 106, + "ėĖė": 107, + "ĘĖĕ": 108, + "ĔĔ": 109, + "ĘĖ": 110, + "ĕē": 111, + "ĔĖ": 112, + "ĒĖĖ": 113, + "Ęĕ": 114, + "ĔĖė": 115, + "ėĕĕ": 116, + "ĚĖĖ": 117, + "Ĕĕ": 118, + "ėę": 119, + "ėĖĖĖĖ": 120, + "ĔĖĕ": 121, + "ĕĖĖĖĖ": 122, + "ėē": 123, + "ęĖĕ": 124, + "ĖĒ": 125, + "ĘĘ": 126, + "ĕĕĖĖ": 127, + "ĖĚ": 128, + "ĖĔĖĖ": 129, + "ĘĖė": 130, + "ĕėė": 131, + "ėĕĖ": 132, + "Ĕė": 133, + "ĕę": 134, + "Ęė": 135, + "Ėđ": 136, + "ēĖė": 137, + "ĔėĖ": 138, + "ėėĖĖ": 139, + "ĖĘĖĖ": 140, + "ĕėĖ": 141, + "ĕĒ": 142, + "ĘĕĖ": 143, + "Ėě": 144, + "ėĕĖĖ": 145, + "ĔĕĖ": 146, + "ĘėĖ": 147, + "ĖĖĖĖĖĖ": 148, + "ĖĂ": 149, + "ĚĖĕ": 150, + "ĖėĖĕ": 151, + "Ĕē": 152, + "Ęĕĕ": 153, + "đĖĖ": 154, + "ĔĘ": 155, + "ēĖĕ": 156, + "ĖĖĖĖĖ": 157, + "ěĖĖ": 158, + "ĕĂ": 159, + "ėĚ": 160, + "Ĕĕĕ": 161, + "ĕėĖĖ": 162, + "ėĒ": 163, + "ĖĐ": 164, + "ĖĕĖĕ": 165, + "ęĖė": 166, + "ĕđ": 167, + "ĕėĕ": 168, + "ĘĔ": 169, + "ėĂ": 170, + "Ĕėė": 171, + "ĖĜ": 172, + "Ęėĕ": 173, + "ĕĚ": 174, + "ęĕĖ": 175, + "ĒĖė": 176, + "ēėĖ": 177, + "ęėĖ": 178, + "ēĕĖ": 179, + "ĖĕĖė": 180, + "ęĕĕ": 181, + "ĕĔĖĖ": 182, + "ėĖĕĖĖ": 183, + "ėě": 184, + "Ęę": 185, + "ĕĐ": 186, + "ėđ": 187, + "Ėď": 188, + "Ęėė": 189, + "Ęē": 190, + "ĒĖĕ": 191, + "Ĕĕė": 192, + "Ĕėĕ": 193, + "ĕě": 194, + "ĔĂ": 195, + "ĕĖĕĖĖ": 196, + "Ĕę": 197, + "ĖėĖė": 198, + "ĕĖĖĕĖĖ": 199, + "ĕĖėĖĖ": 200, + "ĔĒ": 201, + "ĖėĖ": 202, + "ēē": 203, + "ėĖĖėĖĖ": 204, + "ėĐ": 205, + "ęėĕ": 206, + "ĜĖĖ": 
207, + "ĖĕĖ": 208, + "ĕď": 209, + "ĘĚ": 210, + "ėėĖ": 211, + "ěĖĕ": 212, + "ĒĕĖ": 213, + "ėĖėĖĖ": 214, + "ĕĕĖ": 215, + "ĚėĖ": 216, + "ĒėĖ": 217, + "Ęĕė": 218, + "ĘĖĔ": 219, + "ēėė": 220, + "Ĕđ": 221, + "ėĜ": 222, + "ėĘĖĖ": 223, + "ĖĎ": 224, + "ĚĕĖ": 225, + "ĚĖė": 226, + "Ėĝ": 227, + "ėėĖĕ": 228, + "ēĕĕ": 229, + "ĔĖĘ": 230, + "ĘĂ": 231, + "ĐĖĖ": 232, + "ĔĖĔ": 233, + "ėď": 234, + "Ěĕĕ": 235, + "ĖĘĖĕ": 236, + "ęę": 237, + "ĘĖĖĖĖ": 238, + "ĕĜ": 239, + "ĖĕĖĖĖĖ": 240, + "ėĖĖĕĖĖ": 241, + "ĖēĖĖ": 242, + "ĕĕĖė": 243, + "Ēē": 244, + "ĕĕĖĕ": 245, + "ĕĎ": 246, + "Ěėĕ": 247, + "ĘĖĘ": 248, + "đĖė": 249, + "ēĕė": 250, + "ĘĒ": 251, + "ēĔ": 252, + "Ęě": 253, + "ėĝ": 254, + "ĔĕĖĖ": 255, + "ęėė": 256, + "ĔĐ": 257, + "ĕėĖĕ": 258, + "ēĒ": 259, + "ĖĞ": 260, + "ĖęĖĖ": 261, + "Ęđ": 262, + "ėĎ": 263, + "ēėĕ": 264, + "ĔĖĖĖĖ": 265, + "ĖėĖĖĖĖ": 266, + "Ēėė": 267, + "ęĘ": 268, + "đĖĕ": 269, + "ėĕĖĕ": 270, + "ĖĔĖĕ": 271, + "ĖĔĖė": 272, + "ĕĝ": 273, + "ĔĚ": 274, + "ėĕĖė": 275, + "ĖĕĖĖĖ": 276, + "ĕĘĖĖ": 277, + "Ĕď": 278, + "ěėĖ": 279, + "ěĕĖ": 280, + "Ēĕĕ": 281, + "ĕēĖĖ": 282, + "ēę": 283, + "ėėĖė": 284, + "ĔĖĖĕĖĖ": 285, + "ĘĐ": 286, + "Ěę": 287, + "ĜĖĕ": 288, + "đėĖ": 289, + "ĘėĖĖ": 290, + "ęē": 291, + "ěĖė": 292, + "đĕĖ": 293, + "ėĞ": 294, + "ēĂ": 295, + "ĕėĖė": 296, + "ĝĖĖ": 297, + "Ėğ": 298, + "Ēĕė": 299, + "ĖĕĕĖĖ": 300, + "ĕĖĘ": 301, + "ėĖĔ": 302, + "Ęď": 303, + "ĘĖĖėĖĖ": 304, + "ĕĞ": 305, + "ęĕė": 306, + "ēđ": 307, + "ėĔĖĖ": 308, + "ęĖĔ": 309, + "Ěėė": 310, + "ěĕĕ": 311, + "ėęĖĖ": 312, + "ĔĎ": 313, + "ėĕĕĖĖ": 314, + "ĒĒ": 315, + "ĖĕĖĖĕĖĖ": 316, + "Ĕě": 317, + "ĘĜ": 318, + "ĘĕĖĖ": 319, + "ęĚ": 320, + "ĕĔĖė": 321, + "ėĖĖĖĕ": 322, + "ęĒ": 323, + "ĕĖĖĖĖĖ": 324, + "ęĔ": 325, + "ĖĖĖĖĖĖĖĖ": 326, + "ęđ": 327, + "ėĘĖĕ": 328, + "Ēėĕ": 329, + "ēĐ": 330, + "ėğ": 331, + "ĘĕĔ": 332, + "ėĖĖĖĖĖ": 333, + "ĘĎ": 334, + "ēĖĘ": 335, + "ĔĕĖĕ": 336, + "ěėĕ": 337, + "ĖĔėĖ": 338, + "ĐĖė": 339, + "ęĖ": 340, + "ĖėėĖ": 341, + "ĖęĖĕ": 342, + "ďĖĖ": 343, + "ĔĕĘ": 344, + "ĕĖĖėĖĖ": 345, + "ēĖ": 346, + "đėė": 347, + 
"ĖĔĕĖ": 348, + "Ėā": 349, + "ĔĜ": 350, + "ĖėĕĖ": 351, + "ĘėĖĕ": 352, + "ĘėĔ": 353, + "ĘėĖė": 354, + "ĕĖĖĖĕ": 355, + "ęě": 356, + "ĕğ": 357, + "ĖėĖĖėĖĖ": 358, + "ęĂ": 359, + "đē": 360, + "ĕĔĖĕ": 361, + "Ęĝ": 362, + "ĚĚ": 363, + "ēď": 364, + "ĖĖĖĕ": 365, + "ĖĘĖė": 366, + "ĘĖē": 367, + "ĖĘėĖ": 368, + "ĕĕĖĖĖĖ": 369, + "Ěē": 370, + "ęĐ": 371, + "ĔĕĖė": 372, + "Ěĕė": 373, + "ėėĖĖĖĖ": 374, + "ĕā": 375, + "ĖĘĕĖ": 376, + "Ēę": 377, + "ĔėĖĖ": 378, + "ĔėĘ": 379, + "ĚĖĔ": 380, + "Ėėĕĕ": 381, + "ĔĖē": 382, + "ĔĖĖėĖĖ": 383, + "Ėă": 384, + "ĖĕĕĖ": 385, + "ĘĖĖĕĖĖ": 386, + "ĔĕĔ": 387, + "ęď": 388, + "ĘĕĖė": 389, + "ĔėĖĕ": 390, + "Ēđ": 391, + "ĖĕėĖ": 392, + "ĒĐ": 393, + "ĘĕĘ": 394, + "ĝĖĕ": 395, + "đĕĕ": 396, + "ĔĖę": 397, + "ĐėĖ": 398, + "ĜĕĖ": 399, + "ĔĖĖĔĖĖ": 400, + "ĖĖĖĖėĖĖ": 401, + "ĘĖę": 402, + "ĐĖĕ": 403, + "ĜėĖ": 404, + "ĖĖĖĖĕĖĖ": 405, + "ĘĞ": 406, + "ėĖĖĖĖĖĖ": 407, + "ĖĕĖĖėĖĖ": 408, + "ěę": 409, + "ėėĕ": 410, + "ĘĕĖĕ": 411, + "ĖėĖĖĕĖĖ": 412, + "ĒĖĘ": 413, + "ēĖĔ": 414, + "Ĕĝ": 415, + "Ėĕĕĕ": 416, + "ĜĖė": 417, + "ĘĖĖĘĖĖ": 418, + "ēĚ": 419, + "ėă": 420, + "đĕė": 421, + "ĐĕĖ": 422, + "ĖĒĖĖ": 423, + "ĔėĔ": 424, + "ĚĒ": 425, + "ĕă": 426, + "Ĝĕĕ": 427, + "ěėė": 428, + "ĔĔĖĖ": 429, + "ėĕė": 430, + "ĕĖĖĖė": 431, + "ĕĘĖĕ": 432, + "đĒ": 433, + "ēĎ": 434, + "ĘėĘ": 435, + "đėĕ": 436, + "ėėė": 437, + "Ė!": 438, + "ĖēĖė": 439, + "ĕĖĖĖĖĖĖ": 440, + "Ēď": 441, + "ĕĒĖĖ": 442, + "ęĖĘ": 443, + "ĘĖĕĖĖ": 444, + "ėėĕĕ": 445, + "ĔėĖė": 446, + "ęĎ": 447, + "ĖĕėĖĖ": 448, + "ĔĞ": 449, + "Đē": 450, + "ěĖĔ": 451, + "ĖĘĕĕ": 452, + "ėā": 453, + "ė!": 454, + "ĖĚĖĖ": 455, + "ĘĘĖĖ": 456, + "ęĕĔ": 457, + "đđ": 458, + "ĖėėĖĖ": 459, + "Ė\"": 460, + "ėĖĔĖĖ": 461, + "Ěě": 462, + "ęĜ": 463, + "Ęğ": 464, + "ĕėėĖĖ": 465, + "ĕĕĕĕ": 466, + "ęėĔ": 467, + "ĘĘĕ": 468, + "ĘĔĖ": 469, + "ĕ!": 470, + "ėĕĖĖĖĖ": 471, + "ĔĖėĖĖ": 472, + "ĞĖĖ": 473, + "ēě": 474, + "Ĝėĕ": 475, + "ėĖĖĖė": 476, + "ĒĚ": 477, + "ĕĕė": 478, + "ėĖē": 479, + "ĒĖĔ": 480, + "ěĕė": 481, + "ĘĘĖ": 482, + "ĕĖĘĖĖ": 483, + "ĔĖĖĖĖĖ": 484, + "ĕęĖĖ": 485, + 
"ėĖĘĖĖ": 486, + "Đėė": 487, + "ėĖĕĕĖĖ": 488, + "ĔĘĖ": 489, + "ĚĐ": 490, + "ėĖĕėĖĖ": 491, + "ĎĖĖ": 492, + "ĖēĖĕ": 493, + "ĕĖĔĖĖ": 494, + "Ěđ": 495, + "Ĕă": 496, + "ęĝ": 497, + "ĒĂ": 498, + "ęĕ": 499, + "ĚĕĔ": 500, + "ęė": 501, + "ĖēėĖ": 502, + "ĖėĕĖĖ": 503, + "ĕėĖĖĖĖ": 504, + "ĕėĕĕ": 505, + "ĕĖē": 506, + "ĘĘĖĕ": 507, + "ėęĖĕ": 508, + "ďĖĕ": 509, + "Đĕĕ": 510, + "Ėėėė": 511, + "Ęă": 512, + "Ęĕē": 513, + "ė\"": 514, + "ĘĔĕ": 515, + "ĖĔĕĕ": 516, + "ĖĖĖĖĖĕ": 517, + "ĒĎ": 518, + "Ėėėĕ": 519, + "Ěď": 520, + "ĕĖĖĔĖĖ": 521, + "ēĕĘ": 522, + "ęĖĖėĖĖ": 523, + "Ėėė": 524, + "ēė": 525, + "ěĚ": 526, + "ėĖėĖĕ": 527, + "ęĘĕ": 528, + "ĖēĕĖ": 529, + "ęĔĖ": 530, + "Ė#": 531, + "ĕĖėĖĕ": 532, + "ĕĖĔ": 533, + "ĖĘėĕ": 534, + "ĘĖĖĖĖĖ": 535, + "ďĖė": 536, + "ĕĕĕ": 537, + "đĐ": 538, + "Ėėĕ": 539, + "ĖęėĖ": 540, + "ĖĔėė": 541, + "ēėĘ": 542, + "ēĜ": 543, + "ēĖĖĕĖĖ": 544, + "Ēě": 545, + "ĚėĔ": 546, + "ĖęĕĖ": 547, + "ĖĈ": 548, + "ĖđĖĖ": 549, + "Ĕā": 550, + "ēĕ": 551, + "ęĘĖ": 552, + "ĕēĖė": 553, + "ĕĖĖĖ": 554, + "ėĚĖĖ": 555, + "ĕ\"": 556, + "ėĖĘ": 557, + "ėĖĕĖĕ": 558, + "ęĞ": 559, + "ęĖĖĖĖ": 560, + "Ĕğ": 561, + "đĖĘ": 562, + "ĖěĖĖ": 563, + "ĝėĖ": 564, + "ĚĖĘ": 565, + "ĖĔĖĖĖĖ": 566, + "ėĔĖĕ": 567, + "ėĖĖĘĖĖ": 568, + "ĕĕėė": 569, + "ĝĕĖ": 570, + "ėĕėė": 571, + "ĕĖĕĖĕ": 572, + "đĚ": 573, + "ĖęĖė": 574, + "ėēĖĖ": 575, + "ĚĜ": 576, + "ĒĕĘ": 577, + "ěē": 578, + "ĖėĖĕĖĖ": 579, + "ĝĖė": 580, + "Ėĕėė": 581, + "ĕĖėĕĖ": 582, + "Đėĕ": 583, + "Ęėĕĕ": 584, + "ēĘĖ": 585, + "ėĖĖĕĕ": 586, + "ĐĒ": 587, + "ĕĖĕĕĖĖ": 588, + "Ĝėė": 589, + "ĚĎ": 590, + "ďĕĖ": 591, + "ĕėĕĖĖ": 592, + "ĘęĖĖ": 593, + "ďėĖ": 594, + "ĚĘĖ": 595, + "ęĔĕ": 596, + "Đĕė": 597, + "ĚĔĖ": 598, + "ĕĖĕĕĖ": 599, + "ĖĔėĕ": 600, + "Ĕĕē": 601, + "ěĒ": 602, + "ĖĖĖė": 603, + "ėĘĖė": 604, + "ēĝ": 605, + "ēĖĖĖĖ": 606, + "ėĖėėĖ": 607, + "ĒĔĖ": 608, + "ė#": 609, + "ĔēĖĖ": 610, + "ĕĈ": 611, + "ĒĘĖ": 612, + "ĘĖĕėĖĖ": 613, + "ĘėĖĖĖĖ": 614, + "ėĖĕĖė": 615, + "ĕėĕĖ": 616, + "ĖėĖĖĘĖĖ": 617, + "ĕ#": 618, + "ĔĖĕĖĖ": 619, + "Ĝĕė": 620, + "đę": 621, + "ĕĖĕĖė": 622, + 
"ĕĔėĖ": 623, + "ĔĔĖĕ": 624, + "đď": 625, + "ĞĖĕ": 626, + "ĖĘĖĖĖĖ": 627, + "ĕĖĖĕĕ": 628, + "ėĘĕĕ": 629, + "ėĔĖė": 630, + "Ęā": 631, + "ĖĕĖĖĔĖĖ": 632, + "Ę!": 633, + "ĚĂ": 634, + "ĕĖĕėĖ": 635, + "ĕēĖĕ": 636, + "Ěĝ": 637, + "ĜĖĔ": 638, + "ęėĖĖ": 639, + "ĚĘĕ": 640, + "ĖĚĖĕ": 641, + "ėĈ": 642, + "ēĘ": 643, + "ėĖĕĕĖ": 644, + "ėĖĕėĖ": 645, + "ēĔĖ": 646, + "ęĖĖĘĖĖ": 647, + "ĐĐ": 648, + "ĒėĘ": 649, + "Ė$": 650, + "ėėėė": 651, + "ĕĔĕĖ": 652, + "ĒĜ": 653, + "ēă": 654, + "ěě": 655, + "ėĖĕĖĖĖ": 656, + "ĔĕĖĖĖĖ": 657, + "ĝĕĕ": 658, + "đĎ": 659, + "ęėĖė": 660, + "ĕĘĖė": 661, + "ěėĔ": 662, + "ėĕĖėĖĖ": 663, + "ėĖĖĖĖĖĖĖĖ": 664, + "ĖĖĖĖĖĖĖ": 665, + "ĘĔĔ": 666, + "ĠĖĖ": 667, + "ęğ": 668, + "ĕĕėĖ": 669, + "ęĕĘ": 670, + "đĖĔ": 671, + "ĕĖėėĖ": 672, + "ĕĖėĖė": 673, + "Ĕ!": 674, + "Ėć": 675, + "ěĕĔ": 676, + "ēĕĔ": 677, + "ĝėĕ": 678, + "ĔĔĖė": 679, + "ĔĘĕ": 680, + "ēėĔ": 681, + "Ėĕĕ": 682, + "ėėĖĕĖĖ": 683, + "Ĕėē": 684, + "Ęėē": 685, + "ĕĖĖėĖ": 686, + "ēĖĖĔĖĖ": 687, + "ęėĘ": 688, + "ĕĖĖĕĖ": 689, + "ēĞ": 690, + "ėĖĖĕĖ": 691, + "Ĕėę": 692, + "Ĝę": 693, + "ĖĖĖĖĖė": 694, + "ĖĕĖĕĖĖ": 695, + "ĕĕėĖĖ": 696, + "ĔĈ": 697, + "Ėęėĕ": 698, + "Ėĕė": 699, + "ĘĖėĖĖ": 700, + "ěđ": 701, + "ĖĒĖė": 702, + "Ęėę": 703, + "ĜĚ": 704, + "Ėęĕĕ": 705, + "ėĖėĕĖ": 706, + "ĎĖė": 707, + "ĔĘė": 708, + "ėć": 709, + "ěď": 710, + "Đď": 711, + "ĕęĖĕ": 712, + "ĕĔėė": 713, + "ĘĖĖĔĖĖ": 714, + "ĕđĖĖ": 715, + "ĕĕĖĕĖĖ": 716, + "ėėĖĔ": 717, + "ēğ": 718, + "đĕĘ": 719, + "ĖĕĖėĖĖ": 720, + "ĕĕĕĖ": 721, + "ĕ$": 722, + "ĘĖĕĕĖĖ": 723, + "đėĘ": 724, + "ĕĕĖėĖĖ": 725, + "ėėėĕ": 726, + "ěĎ": 727, + "ěĔĖ": 728, + "ĕėĖĕĖĖ": 729, + "ĐĖĘ": 730, + "ĚĞ": 731, + "ėėĕĖĖ": 732, + "ĘĈ": 733, + "ė$": 734, + "ĎėĖ": 735, + "Ę\"": 736, + "ĖĔĖĖĕĖĖ": 737, + "ĒĖ": 738, + "ėĕĖĕĖĖ": 739, + "ĕĔĖĖĖĖ": 740, + "Đđ": 741, + "ėĖĖĔĖĖ": 742, + "Ēĝ": 743, + "ĕć": 744, + "ĕĕĖĖĖ": 745, + "ĕĔĕĕ": 746, + "ďėė": 747, + "ĚĔĕ": 748, + "ěĜ": 749, + "ğĖĖ": 750, + "ėėĕĖ": 751, + "ĖĒĖĕ": 752, + "đĔĖ": 753, + "ėĖėĖė": 754, + "ėėĖėĖĖ": 755, + "ĕĖĖĖĖĖĖĖĖ": 756, + "ĖĔĕė": 757, + "ĒėĔ": 
758, + "ĔĘĖĖ": 759, + "ĕĘĕĖ": 760, + "ĕēėĖ": 761, + "ĕĕĕĖĖ": 762, + "ĖĕĖĘ": 763, + "ęĔĔ": 764, + "ĕĕĖĖĕĖĖ": 765, + "ĕĖĕėĖĖ": 766, + "ďĕė": 767, + "Ĕėĕĕ": 768, + "ěĖĘ": 769, + "đĘĖ": 770, + "ęĖē": 771, + "ĝĕė": 772, + "ĖėĖėĖĖ": 773, + "ēĕĖĖ": 774, + "ėĚĖĕ": 775, + "ĕēĕĖ": 776, + "ĔĘĖĕ": 777, + "Ę#": 778, + "ęĘĖĖ": 779, + "ėĘėĖ": 780, + "ėĕėĖĖ": 781, + "ĘĕĖĖĖĖ": 782, + "ĚĔĔ": 783, + "ĘĔĖĖ": 784, + "ēĕĖĕ": 785, + "ĖėĖĔ": 786, + "ĘĔė": 787, + "ĞĕĖ": 788, + "ēā": 789, + "Ė%": 790, + "Ėēėė": 791, + "ēĔĕ": 792, + "ėėĖĘ": 793, + "ęă": 794, + "ĕĘėĖ": 795, + "ėĕėĖ": 796, + "ēĖę": 797, + "ĖĒėĖ": 798, + "ĐĎ": 799, + "ĔĒĖĖ": 800, + "ėĘĕĖ": 801, + "Ėēĕĕ": 802, + "ĘęĖĕ": 803, + "ěĐ": 804, + "ėėėĖ": 805, + "ĔĘĘ": 806, + "ėĖĖėĖ": 807, + "ėĒĖĖ": 808, + "ĖĘĖĖėĖĖ": 809, + "ěĘĕ": 810, + "ďėĕ": 811, + "ĖĆ": 812, + "ĝėė": 813, + "ėęĖė": 814, + "đě": 815, + "ĖĖĖĖĖĖĖĕ": 816, + "ęĕĖė": 817, + "ĘĔĖė": 818, + "Ĕ\"": 819, + "ėĖĖėĖĕ": 820, + "ėĕĖĖĖ": 821, + "ęĖĖĕĖĖ": 822, + "ęĔė": 823, + "ęėĖĕ": 824, + "ďĕĕ": 825, + "ĒĕĔ": 826, + "ĎĖĕ": 827, + "ēĘė": 828, + "ēĖĖėĖĖ": 829, + "ĒĘė": 830, + "Ęć": 831, + "ĒĞ": 832, + "Ėĕėĕ": 833, + "ēĖē": 834, + "ĕĖėĕĖĖ": 835, + "ěĘĖ": 836, + "ĖĚĕĖ": 837, + "ĖĄ": 838, + "ĖĐĖĖ": 839, + "ĕĚĖĖ": 840, + "ėĘĖĖĖĖ": 841, + "ėĘėĕ": 842, + "Ğĕĕ": 843, + "ĕĒĖė": 844, + "ďĐ": 845, + "ĔēĖė": 846, + "ĕĆ": 847, + "ęėĕĕ": 848, + "ēĔė": 849, + "ĖĕĖĔ": 850, + "ĖėĖĖĖĖĖ": 851, + "ĞĖė": 852, + "ėĔĕĕ": 853, + "ęĈ": 854, + "ĕėĖėĖĖ": 855, + "ďď": 856, + "ĒĔĕ": 857, + "ėĖĖėė": 858, + "ĔĔĖ": 859, + "ĘĚĖĖ": 860, + "ĜĕĔ": 861, + "ęĖę": 862, + "ĕĄ": 863, + "ĘĖĖėĖĕ": 864, + "đĂ": 865, + "ĔėĖĖĖĖ": 866, + "ĔĕĖĕĖĖ": 867, + "ėēĖė": 868, + "ĔėĖĕĖĖ": 869, + "đĜ": 870, + "ė%": 871, + "ĞėĖ": 872, + "ĕĕĖĘ": 873, + "ĎĕĖ": 874, + "ĕ%": 875, + "ėĆ": 876, + "ĕĖĕĖĖĖ": 877, + "ēĘĕ": 878, + "ęĖĕĖĖ": 879, + "ę!": 880, + "ĖėĖĖĔĖĖ": 881, + "ĔĖėĕĖĖ": 882, + "ĐĖĔ": 883, + "Ė,": 884, + "ĖĒĕĖ": 885, + "ēĔĘ": 886, + "ĕĕĖĔ": 887, + "ĐĕĘ": 888, + "ĒĘĕ": 889, + "Ěğ": 890, + "ěĔĔ": 891, + "ĚĕĘ": 892, + "ĚėĘ": 893, + 
"ĖĕĔĖĖ": 894, + "ĕĕėĕ": 895, + "ĕĖĕĔĖĖ": 896, + "ĖĚėĖ": 897, + "ĒĔė": 898, + "ĖĜĖĖ": 899, + "ēĕĖė": 900, + "ėą": 901, + "ĖĘėė": 902, + "ĚĖ": 903, + "ĘĕĖėĖĖ": 904, + "ĐĘĖ": 905, + "ėęĕĖ": 906, + "ďđ": 907, + "ĝĖĔ": 908, + "ĖĕĖĖĖĕ": 909, + "ĕĔėĕ": 910, + "ĕēėė": 911, + "ĜėĔ": 912, + "ĒĔĘ": 913, + "ėěĖĖ": 914, + "Ę$": 915, + "Ėēėĕ": 916, + "ĎĐ": 917, + "ĘėĖėĖĖ": 918, + "ĒĖę": 919, + "ėĖĕĕĕ": 920, + "Ė&": 921, + "ĔĖĖēĖĖ": 922, + "ėęĕĕ": 923, + "ĖĖĖĖĘĖĖ": 924, + "Ėą": 925, + "ĕĖĕĕĕ": 926, + "Ĕć": 927, + "ĘėĖĕĖĖ": 928, + "ĘĖĖęĖĖ": 929, + "ĔĕĖėĖĖ": 930, + "ĜĔĖ": 931, + "ĘĕĕĖĖ": 932, + "Ĕ#": 933, + "ĕą": 934, + "ĘėĕĖĖ": 935, + "ęĕĖĖ": 936, + "ĜĒ": 937, + "ėĕĕĖĕ": 938, + "Ďėė": 939, + "ēĈ": 940, + "ę\"": 941, + "ĖĚĖė": 942, + "ė,": 943, + "ėĄ": 944, + "ĕĘĖ": 945, + "ďĖĘ": 946, + "đėĔ": 947, + "ėĕĕĖė": 948, + "ĖĔĖĖĔĖĖ": 949, + "Ĝě": 950, + "ĠĖĕ": 951, + "ĔĖĕĕĖĖ": 952, + "ĚĖē": 953, + "ėĕėĕ": 954, + "ĘĖĖĘĖĕ": 955, + "ĘĘĖė": 956, + "ėėĖĖėĖĖ": 957, + "ĐėĘ": 958, + "ĕ,": 959, + "ėęėĖ": 960, + "ĖĖėĖ": 961, + "Ďď": 962, + "ĔĄ": 963, + "ĘĖĕĖĖĖ": 964, + "ėēĖĕ": 965, + "ĖėĖĕĕĖĖ": 966, + "ė&": 967, + "ĖěĖĕ": 968, + "Ę%": 969, + "ĚĔė": 970, + "ēĔĖĖ": 971, + "ĘĔĖĕ": 972, + "ĖĕĖĖĘĖĖ": 973, + "ēėĖĕ": 974, + "ĕĕĖĖėĖĖ": 975, + "Ēă": 976, + "ĝĚ": 977, + "ĖėĖĘ": 978, + "ĖĕĖĖė": 979, + "ėĘĖ": 980, + "Ĕĕėė": 981, + "ēĖėĖĖ": 982, + "ĖėĖĖĖĕ": 983, + "ĕĖėĕĕ": 984, + "ėĖĖĖĖĖĕĖĖ": 985, + "ĖĔĖĖėĖĖ": 986, + "ęĕē": 987, + "ėĔĖ": 988, + "ĕĖėėĖĖ": 989, + "ęĘė": 990, + "ĕĖĖĘĖĖ": 991, + "Ėēĕė": 992, + "ęĘĖĕ": 993, + "ďĒ": 994, + "ėĖĕĘĖĖ": 995, + "ęā": 996, + "ęĖĖĖĖĖ": 997, + "ėĖĕĔĖĖ": 998, + "ĕĘėė": 999, + "Ę,": 1000, + "ěĔĕ": 1001, + "ďĎ": 1002, + "Ė'": 1003, + "ĘĆ": 1004, + "ęć": 1005, + "ĖĖĖĖĔĖĖ": 1006, + "ėĔĕĖ": 1007, + "ĕĘĕĕ": 1008, + "ĐĔĖ": 1009, + "đĕĔ": 1010, + "ĜĖĘ": 1011, + "ĕĖĕėė": 1012, + "ĕĒĖĕ": 1013, + "ďē": 1014, + "ĕĖĖėė": 1015, + "ėĖĖĖĕĖĖ": 1016, + "ĕēĕĕ": 1017, + "Ĕ$": 1018, + "ęĖĕėĖĖ": 1019, + "ė'": 1020, + "ĕĔĕė": 1021, + "ĔĖĖĖĕ": 1022, + "ĔĘĖė": 1023 + }, + "merges": [ + [ + "Ė", + "Ė" 
+ ], + [ + "Ė", + "ĕ" + ], + [ + "Ė", + "ė" + ], + [ + "ĕ", + "ĕ" + ], + [ + "ė", + "ė" + ], + [ + "ĖĖ", + "ĖĖ" + ], + [ + "ĕ", + "Ė" + ], + [ + "ė", + "Ė" + ], + [ + "ė", + "ĕ" + ], + [ + "ė", + "ĖĖ" + ], + [ + "ĕ", + "ĖĖ" + ], + [ + "Ė", + "Ĕ" + ], + [ + "Ę", + "ĖĖ" + ], + [ + "ĕ", + "ė" + ], + [ + "Ĕ", + "ĖĖ" + ], + [ + "Ė", + "Ę" + ], + [ + "Ėĕ", + "ĖĖ" + ], + [ + "Ėė", + "ĖĖ" + ], + [ + "ĕ", + "Ĕ" + ], + [ + "ė", + "Ėĕ" + ], + [ + "ē", + "ĖĖ" + ], + [ + "ę", + "ĖĖ" + ], + [ + "ė", + "Ę" + ], + [ + "ė", + "Ĕ" + ], + [ + "ĕ", + "Ėĕ" + ], + [ + "ĕ", + "Ėė" + ], + [ + "ĕ", + "Ę" + ], + [ + "Ė", + "ē" + ], + [ + "Ė", + "ę" + ], + [ + "ė", + "Ėė" + ], + [ + "Ę", + "Ėĕ" + ], + [ + "Ĕ", + "Ĕ" + ], + [ + "Ę", + "Ė" + ], + [ + "ĕ", + "ē" + ], + [ + "Ĕ", + "Ė" + ], + [ + "Ē", + "ĖĖ" + ], + [ + "Ę", + "ĕ" + ], + [ + "Ĕ", + "Ėė" + ], + [ + "ė", + "ĕĕ" + ], + [ + "Ě", + "ĖĖ" + ], + [ + "Ĕ", + "ĕ" + ], + [ + "ė", + "ę" + ], + [ + "ė", + "ĖĖĖĖ" + ], + [ + "Ĕ", + "Ėĕ" + ], + [ + "ĕ", + "ĖĖĖĖ" + ], + [ + "ė", + "ē" + ], + [ + "ę", + "Ėĕ" + ], + [ + "Ė", + "Ē" + ], + [ + "Ę", + "Ę" + ], + [ + "ĕĕ", + "ĖĖ" + ], + [ + "Ė", + "Ě" + ], + [ + "ĖĔ", + "ĖĖ" + ], + [ + "Ę", + "Ėė" + ], + [ + "ĕ", + "ėė" + ], + [ + "ė", + "ĕĖ" + ], + [ + "Ĕ", + "ė" + ], + [ + "ĕ", + "ę" + ], + [ + "Ę", + "ė" + ], + [ + "Ė", + "đ" + ], + [ + "ē", + "Ėė" + ], + [ + "Ĕ", + "ėĖ" + ], + [ + "ėė", + "ĖĖ" + ], + [ + "Ė", + "ĘĖĖ" + ], + [ + "ĕ", + "ėĖ" + ], + [ + "ĕ", + "Ē" + ], + [ + "Ę", + "ĕĖ" + ], + [ + "Ė", + "ě" + ], + [ + "ėĕ", + "ĖĖ" + ], + [ + "Ĕ", + "ĕĖ" + ], + [ + "Ę", + "ėĖ" + ], + [ + "ĖĖĖĖ", + "ĖĖ" + ], + [ + "Ė", + "Ă" + ], + [ + "Ě", + "Ėĕ" + ], + [ + "Ėė", + "Ėĕ" + ], + [ + "Ĕ", + "ē" + ], + [ + "Ę", + "ĕĕ" + ], + [ + "đ", + "ĖĖ" + ], + [ + "Ĕ", + "Ę" + ], + [ + "ē", + "Ėĕ" + ], + [ + "ĖĖĖĖ", + "Ė" + ], + [ + "ě", + "ĖĖ" + ], + [ + "ĕ", + "Ă" + ], + [ + "ė", + "Ě" + ], + [ + "Ĕ", + "ĕĕ" + ], + [ + "ĕ", + "ėĖĖ" + ], + [ + "ė", + "Ē" + ], + [ + "Ė", + "Đ" + ], + [ + "Ėĕ", + "Ėĕ" + ], + [ + "ę", + 
"Ėė" + ], + [ + "ĕ", + "đ" + ], + [ + "ĕ", + "ėĕ" + ], + [ + "Ę", + "Ĕ" + ], + [ + "ė", + "Ă" + ], + [ + "Ĕ", + "ėė" + ], + [ + "Ė", + "Ĝ" + ], + [ + "Ę", + "ėĕ" + ], + [ + "ĕ", + "Ě" + ], + [ + "ę", + "ĕĖ" + ], + [ + "Ē", + "Ėė" + ], + [ + "ē", + "ėĖ" + ], + [ + "ę", + "ėĖ" + ], + [ + "ē", + "ĕĖ" + ], + [ + "Ėĕ", + "Ėė" + ], + [ + "ę", + "ĕĕ" + ], + [ + "ĕ", + "ĔĖĖ" + ], + [ + "ė", + "ĖĕĖĖ" + ], + [ + "ė", + "ě" + ], + [ + "Ę", + "ę" + ], + [ + "ĕ", + "Đ" + ], + [ + "ė", + "đ" + ], + [ + "Ė", + "ď" + ], + [ + "Ę", + "ėė" + ], + [ + "Ę", + "ē" + ], + [ + "Ē", + "Ėĕ" + ], + [ + "Ĕ", + "ĕė" + ], + [ + "Ĕ", + "ėĕ" + ], + [ + "ĕ", + "ě" + ], + [ + "Ĕ", + "Ă" + ], + [ + "ĕ", + "ĖĕĖĖ" + ], + [ + "Ĕ", + "ę" + ], + [ + "Ėė", + "Ėė" + ], + [ + "ĕĖĖ", + "ĕĖĖ" + ], + [ + "ĕ", + "ĖėĖĖ" + ], + [ + "Ĕ", + "Ē" + ], + [ + "Ėė", + "Ė" + ], + [ + "ē", + "ē" + ], + [ + "ėĖĖ", + "ėĖĖ" + ], + [ + "ė", + "Đ" + ], + [ + "ę", + "ėĕ" + ], + [ + "Ĝ", + "ĖĖ" + ], + [ + "Ėĕ", + "Ė" + ], + [ + "ĕ", + "ď" + ], + [ + "Ę", + "Ě" + ], + [ + "ėė", + "Ė" + ], + [ + "ě", + "Ėĕ" + ], + [ + "Ē", + "ĕĖ" + ], + [ + "ė", + "ĖėĖĖ" + ], + [ + "ĕĕ", + "Ė" + ], + [ + "Ě", + "ėĖ" + ], + [ + "Ē", + "ėĖ" + ], + [ + "Ę", + "ĕė" + ], + [ + "Ę", + "ĖĔ" + ], + [ + "ē", + "ėė" + ], + [ + "Ĕ", + "đ" + ], + [ + "ė", + "Ĝ" + ], + [ + "ė", + "ĘĖĖ" + ], + [ + "Ė", + "Ď" + ], + [ + "Ě", + "ĕĖ" + ], + [ + "Ě", + "Ėė" + ], + [ + "Ė", + "ĝ" + ], + [ + "ėė", + "Ėĕ" + ], + [ + "ē", + "ĕĕ" + ], + [ + "Ĕ", + "ĖĘ" + ], + [ + "Ę", + "Ă" + ], + [ + "Đ", + "ĖĖ" + ], + [ + "Ĕ", + "ĖĔ" + ], + [ + "ė", + "ď" + ], + [ + "Ě", + "ĕĕ" + ], + [ + "ĖĘ", + "Ėĕ" + ], + [ + "ę", + "ę" + ], + [ + "Ę", + "ĖĖĖĖ" + ], + [ + "ĕ", + "Ĝ" + ], + [ + "Ėĕ", + "ĖĖĖĖ" + ], + [ + "ėĖĖ", + "ĕĖĖ" + ], + [ + "Ė", + "ēĖĖ" + ], + [ + "ĕĕ", + "Ėė" + ], + [ + "Ē", + "ē" + ], + [ + "ĕĕ", + "Ėĕ" + ], + [ + "ĕ", + "Ď" + ], + [ + "Ě", + "ėĕ" + ], + [ + "Ę", + "ĖĘ" + ], + [ + "đ", + "Ėė" + ], + [ + "ē", + "ĕė" + ], + [ + "Ę", + "Ē" + ], + [ + "ē", + "Ĕ" + ], + [ + "Ę", 
+ "ě" + ], + [ + "ė", + "ĝ" + ], + [ + "Ĕ", + "ĕĖĖ" + ], + [ + "ę", + "ėė" + ], + [ + "Ĕ", + "Đ" + ], + [ + "ĕė", + "Ėĕ" + ], + [ + "ē", + "Ē" + ], + [ + "Ė", + "Ğ" + ], + [ + "Ė", + "ęĖĖ" + ], + [ + "Ę", + "đ" + ], + [ + "ė", + "Ď" + ], + [ + "ē", + "ėĕ" + ], + [ + "Ĕ", + "ĖĖĖĖ" + ], + [ + "Ėė", + "ĖĖĖĖ" + ], + [ + "Ē", + "ėė" + ], + [ + "ę", + "Ę" + ], + [ + "đ", + "Ėĕ" + ], + [ + "ėĕ", + "Ėĕ" + ], + [ + "ĖĔ", + "Ėĕ" + ], + [ + "ĖĔ", + "Ėė" + ], + [ + "ĕ", + "ĝ" + ], + [ + "Ĕ", + "Ě" + ], + [ + "ėĕ", + "Ėė" + ], + [ + "ĖĕĖĖ", + "Ė" + ], + [ + "ĕ", + "ĘĖĖ" + ], + [ + "Ĕ", + "ď" + ], + [ + "ě", + "ėĖ" + ], + [ + "ě", + "ĕĖ" + ], + [ + "Ē", + "ĕĕ" + ], + [ + "ĕ", + "ēĖĖ" + ], + [ + "ē", + "ę" + ], + [ + "ėė", + "Ėė" + ], + [ + "ĔĖĖ", + "ĕĖĖ" + ], + [ + "Ę", + "Đ" + ], + [ + "Ě", + "ę" + ], + [ + "Ĝ", + "Ėĕ" + ], + [ + "đ", + "ėĖ" + ], + [ + "Ę", + "ėĖĖ" + ], + [ + "ę", + "ē" + ], + [ + "ě", + "Ėė" + ], + [ + "đ", + "ĕĖ" + ], + [ + "ė", + "Ğ" + ], + [ + "ē", + "Ă" + ], + [ + "ĕė", + "Ėė" + ], + [ + "ĝ", + "ĖĖ" + ], + [ + "Ė", + "ğ" + ], + [ + "Ē", + "ĕė" + ], + [ + "Ėĕ", + "ĕĖĖ" + ], + [ + "ĕĖ", + "Ę" + ], + [ + "ėĖ", + "Ĕ" + ], + [ + "Ę", + "ď" + ], + [ + "ĘĖĖ", + "ėĖĖ" + ], + [ + "ĕ", + "Ğ" + ], + [ + "ę", + "ĕė" + ], + [ + "ē", + "đ" + ], + [ + "ė", + "ĔĖĖ" + ], + [ + "ę", + "ĖĔ" + ], + [ + "Ě", + "ėė" + ], + [ + "ě", + "ĕĕ" + ], + [ + "ė", + "ęĖĖ" + ], + [ + "Ĕ", + "Ď" + ], + [ + "ėĕĕ", + "ĖĖ" + ], + [ + "Ē", + "Ē" + ], + [ + "ĖĕĖĖ", + "ĕĖĖ" + ], + [ + "Ĕ", + "ě" + ], + [ + "Ę", + "Ĝ" + ], + [ + "Ę", + "ĕĖĖ" + ], + [ + "ę", + "Ě" + ], + [ + "ĕĔ", + "Ėė" + ], + [ + "ėĖĖ", + "Ėĕ" + ], + [ + "ę", + "Ē" + ], + [ + "ĕĖĖĖĖ", + "Ė" + ], + [ + "ę", + "Ĕ" + ], + [ + "ĖĖĖĖ", + "ĖĖĖĖ" + ], + [ + "ę", + "đ" + ], + [ + "ėĘ", + "Ėĕ" + ], + [ + "Ē", + "ėĕ" + ], + [ + "ē", + "Đ" + ], + [ + "ė", + "ğ" + ], + [ + "Ę", + "ĕĔ" + ], + [ + "ėĖĖĖĖ", + "Ė" + ], + [ + "Ę", + "Ď" + ], + [ + "ē", + "ĖĘ" + ], + [ + "Ĕ", + "ĕĖĕ" + ], + [ + "ě", + "ėĕ" + ], + [ + "ĖĔ", + "ėĖ" + ], + [ + "Đ", 
+ "Ėė" + ], + [ + "ę", + "Ė" + ], + [ + "Ėė", + "ėĖ" + ], + [ + "Ėę", + "Ėĕ" + ], + [ + "ď", + "ĖĖ" + ], + [ + "Ĕ", + "ĕĘ" + ], + [ + "ĕĖĖ", + "ėĖĖ" + ], + [ + "ē", + "Ė" + ], + [ + "đ", + "ėė" + ], + [ + "ĖĔ", + "ĕĖ" + ], + [ + "Ė", + "ā" + ], + [ + "Ĕ", + "Ĝ" + ], + [ + "Ėė", + "ĕĖ" + ], + [ + "Ę", + "ėĖĕ" + ], + [ + "Ę", + "ėĔ" + ], + [ + "Ę", + "ėĖė" + ], + [ + "ĕĖĖ", + "Ėĕ" + ], + [ + "ę", + "ě" + ], + [ + "ĕ", + "ğ" + ], + [ + "ĖėĖĖ", + "ėĖĖ" + ], + [ + "ę", + "Ă" + ], + [ + "đ", + "ē" + ], + [ + "ĕĔ", + "Ėĕ" + ], + [ + "Ę", + "ĝ" + ], + [ + "Ě", + "Ě" + ], + [ + "ē", + "ď" + ], + [ + "ĖĖ", + "Ėĕ" + ], + [ + "ĖĘ", + "Ėė" + ], + [ + "Ę", + "Ėē" + ], + [ + "ĖĘ", + "ėĖ" + ], + [ + "ĕĕ", + "ĖĖĖĖ" + ], + [ + "Ě", + "ē" + ], + [ + "ę", + "Đ" + ], + [ + "Ĕ", + "ĕĖė" + ], + [ + "Ě", + "ĕė" + ], + [ + "ėė", + "ĖĖĖĖ" + ], + [ + "ĕ", + "ā" + ], + [ + "ĖĘ", + "ĕĖ" + ], + [ + "Ē", + "ę" + ], + [ + "Ĕ", + "ėĖĖ" + ], + [ + "Ĕ", + "ėĘ" + ], + [ + "Ě", + "ĖĔ" + ], + [ + "Ėė", + "ĕĕ" + ], + [ + "Ĕ", + "Ėē" + ], + [ + "ĔĖĖ", + "ėĖĖ" + ], + [ + "Ė", + "ă" + ], + [ + "Ėĕ", + "ĕĖ" + ], + [ + "ĘĖĖ", + "ĕĖĖ" + ], + [ + "Ĕ", + "ĕĔ" + ], + [ + "ę", + "ď" + ], + [ + "Ę", + "ĕĖė" + ], + [ + "Ĕ", + "ėĖĕ" + ], + [ + "Ē", + "đ" + ], + [ + "Ėĕ", + "ėĖ" + ], + [ + "Ē", + "Đ" + ], + [ + "Ę", + "ĕĘ" + ], + [ + "ĝ", + "Ėĕ" + ], + [ + "đ", + "ĕĕ" + ], + [ + "Ĕ", + "Ėę" + ], + [ + "Đ", + "ėĖ" + ], + [ + "Ĝ", + "ĕĖ" + ], + [ + "ĔĖĖ", + "ĔĖĖ" + ], + [ + "ĖĖĖĖ", + "ėĖĖ" + ], + [ + "Ę", + "Ėę" + ], + [ + "Đ", + "Ėĕ" + ], + [ + "Ĝ", + "ėĖ" + ], + [ + "ĖĖĖĖ", + "ĕĖĖ" + ], + [ + "Ę", + "Ğ" + ], + [ + "ėĖĖĖĖ", + "ĖĖ" + ], + [ + "ĖĕĖĖ", + "ėĖĖ" + ], + [ + "ě", + "ę" + ], + [ + "ėė", + "ĕ" + ], + [ + "Ę", + "ĕĖĕ" + ], + [ + "ĖėĖĖ", + "ĕĖĖ" + ], + [ + "Ē", + "ĖĘ" + ], + [ + "ē", + "ĖĔ" + ], + [ + "Ĕ", + "ĝ" + ], + [ + "Ėĕ", + "ĕĕ" + ], + [ + "Ĝ", + "Ėė" + ], + [ + "ĘĖĖ", + "ĘĖĖ" + ], + [ + "ē", + "Ě" + ], + [ + "ė", + "ă" + ], + [ + "đ", + "ĕė" + ], + [ + "Đ", + "ĕĖ" + ], + [ + "Ė", + "ĒĖĖ" + ], + [ + "Ĕ", 
+ "ėĔ" + ], + [ + "Ě", + "Ē" + ], + [ + "ĕ", + "ă" + ], + [ + "Ĝ", + "ĕĕ" + ], + [ + "ě", + "ėė" + ], + [ + "Ĕ", + "ĔĖĖ" + ], + [ + "ėĕ", + "ė" + ], + [ + "ĕĖĖ", + "Ėė" + ], + [ + "ĕĘ", + "Ėĕ" + ], + [ + "đ", + "Ē" + ], + [ + "ē", + "Ď" + ], + [ + "Ę", + "ėĘ" + ], + [ + "đ", + "ėĕ" + ], + [ + "ėė", + "ė" + ], + [ + "Ė", + "!" + ], + [ + "Ėē", + "Ėė" + ], + [ + "ĕĖĖĖĖ", + "ĖĖ" + ], + [ + "Ē", + "ď" + ], + [ + "ĕ", + "ĒĖĖ" + ], + [ + "ę", + "ĖĘ" + ], + [ + "Ę", + "ĖĕĖĖ" + ], + [ + "ėė", + "ĕĕ" + ], + [ + "Ĕ", + "ėĖė" + ], + [ + "ę", + "Ď" + ], + [ + "Ėĕ", + "ėĖĖ" + ], + [ + "Ĕ", + "Ğ" + ], + [ + "Đ", + "ē" + ], + [ + "ě", + "ĖĔ" + ], + [ + "ĖĘ", + "ĕĕ" + ], + [ + "ė", + "ā" + ], + [ + "ė", + "!" + ], + [ + "Ė", + "ĚĖĖ" + ], + [ + "Ę", + "ĘĖĖ" + ], + [ + "ę", + "ĕĔ" + ], + [ + "đ", + "đ" + ], + [ + "Ėė", + "ėĖĖ" + ], + [ + "Ė", + "\"" + ], + [ + "ėĖ", + "ĔĖĖ" + ], + [ + "Ě", + "ě" + ], + [ + "ę", + "Ĝ" + ], + [ + "Ę", + "ğ" + ], + [ + "ĕėė", + "ĖĖ" + ], + [ + "ĕĕ", + "ĕĕ" + ], + [ + "ę", + "ėĔ" + ], + [ + "Ę", + "Ęĕ" + ], + [ + "Ę", + "ĔĖ" + ], + [ + "ĕ", + "!" 
+ ], + [ + "ėĕ", + "ĖĖĖĖ" + ], + [ + "Ĕ", + "ĖėĖĖ" + ], + [ + "Ğ", + "ĖĖ" + ], + [ + "ē", + "ě" + ], + [ + "Ĝ", + "ėĕ" + ], + [ + "ėĖĖ", + "Ėė" + ], + [ + "Ē", + "Ě" + ], + [ + "ĕĕ", + "ė" + ], + [ + "ėĖ", + "ē" + ], + [ + "Ē", + "ĖĔ" + ], + [ + "ě", + "ĕė" + ], + [ + "Ę", + "ĘĖ" + ], + [ + "ĕĖ", + "ĘĖĖ" + ], + [ + "Ĕ", + "ĖĖĖĖĖ" + ], + [ + "ĕ", + "ęĖĖ" + ], + [ + "ėĖ", + "ĘĖĖ" + ], + [ + "Đ", + "ėė" + ], + [ + "ėĖĕ", + "ĕĖĖ" + ], + [ + "Ĕ", + "ĘĖ" + ], + [ + "Ě", + "Đ" + ], + [ + "ėĖĕ", + "ėĖĖ" + ], + [ + "Ď", + "ĖĖ" + ], + [ + "Ėē", + "Ėĕ" + ], + [ + "ĕĖ", + "ĔĖĖ" + ], + [ + "Ě", + "đ" + ], + [ + "Ĕ", + "ă" + ], + [ + "ę", + "ĝ" + ], + [ + "Ē", + "Ă" + ], + [ + "ę", + "ĕ" + ], + [ + "Ě", + "ĕĔ" + ], + [ + "ę", + "ė" + ], + [ + "Ėē", + "ėĖ" + ], + [ + "Ėė", + "ĕĖĖ" + ], + [ + "ĕė", + "ĖĖĖĖ" + ], + [ + "ĕė", + "ĕĕ" + ], + [ + "ĕĖ", + "ē" + ], + [ + "Ę", + "ĘĖĕ" + ], + [ + "ėę", + "Ėĕ" + ], + [ + "ď", + "Ėĕ" + ], + [ + "Đ", + "ĕĕ" + ], + [ + "Ėė", + "ėė" + ], + [ + "Ę", + "ă" + ], + [ + "Ę", + "ĕē" + ], + [ + "ė", + "\"" + ], + [ + "Ę", + "Ĕĕ" + ], + [ + "ĖĔ", + "ĕĕ" + ], + [ + "ĖĖĖĖ", + "Ėĕ" + ], + [ + "Ē", + "Ď" + ], + [ + "Ėė", + "ėĕ" + ], + [ + "Ě", + "ď" + ], + [ + "ĕĖĖ", + "ĔĖĖ" + ], + [ + "ē", + "ĕĘ" + ], + [ + "ęĖĖ", + "ėĖĖ" + ], + [ + "Ėė", + "ė" + ], + [ + "ē", + "ė" + ], + [ + "ě", + "Ě" + ], + [ + "ėĖė", + "Ėĕ" + ], + [ + "ę", + "Ęĕ" + ], + [ + "Ėē", + "ĕĖ" + ], + [ + "ę", + "ĔĖ" + ], + [ + "Ė", + "#" + ], + [ + "ĕĖė", + "Ėĕ" + ], + [ + "ĕĖ", + "Ĕ" + ], + [ + "ĖĘ", + "ėĕ" + ], + [ + "Ę", + "ĖĖĖĖĖ" + ], + [ + "ď", + "Ėė" + ], + [ + "ĕĕ", + "ĕ" + ], + [ + "đ", + "Đ" + ], + [ + "Ėė", + "ĕ" + ], + [ + "Ėę", + "ėĖ" + ], + [ + "ĖĔ", + "ėė" + ], + [ + "ē", + "ėĘ" + ], + [ + "ē", + "Ĝ" + ], + [ + "ēĖĖ", + "ĕĖĖ" + ], + [ + "Ē", + "ě" + ], + [ + "Ě", + "ėĔ" + ], + [ + "Ėę", + "ĕĖ" + ], + [ + "Ė", + "Ĉ" + ], + [ + "Ėđ", + "ĖĖ" + ], + [ + "Ĕ", + "ā" + ], + [ + "ē", + "ĕ" + ], + [ + "ę", + "ĘĖ" + ], + [ + "ĕē", + "Ėė" + ], + [ + "ĕĖĖ", + "Ė" + ], + [ + "ė", + "ĚĖĖ" + 
], + [ + "ĕ", + "\"" + ], + [ + "ėĖ", + "Ę" + ], + [ + "ėĖĕ", + "Ėĕ" + ], + [ + "ę", + "Ğ" + ], + [ + "ę", + "ĖĖĖĖ" + ], + [ + "Ĕ", + "ğ" + ], + [ + "đ", + "ĖĘ" + ], + [ + "Ėě", + "ĖĖ" + ], + [ + "ĝ", + "ėĖ" + ], + [ + "Ě", + "ĖĘ" + ], + [ + "ĖĔ", + "ĖĖĖĖ" + ], + [ + "ėĔ", + "Ėĕ" + ], + [ + "ėĖĖ", + "ĘĖĖ" + ], + [ + "ĕĕ", + "ėė" + ], + [ + "ĝ", + "ĕĖ" + ], + [ + "ėĕ", + "ėė" + ], + [ + "ĕĖĕ", + "Ėĕ" + ], + [ + "đ", + "Ě" + ], + [ + "Ėę", + "Ėė" + ], + [ + "ė", + "ēĖĖ" + ], + [ + "Ě", + "Ĝ" + ], + [ + "Ē", + "ĕĘ" + ], + [ + "ě", + "ē" + ], + [ + "Ėė", + "ĖĕĖĖ" + ], + [ + "ĝ", + "Ėė" + ], + [ + "Ėĕ", + "ėė" + ], + [ + "ĕĖė", + "ĕĖ" + ], + [ + "Đ", + "ėĕ" + ], + [ + "Ę", + "ėĕĕ" + ], + [ + "ē", + "ĘĖ" + ], + [ + "ėĖĖ", + "ĕĕ" + ], + [ + "Đ", + "Ē" + ], + [ + "ĕĖĕ", + "ĕĖĖ" + ], + [ + "Ĝ", + "ėė" + ], + [ + "Ě", + "Ď" + ], + [ + "ď", + "ĕĖ" + ], + [ + "ĕ", + "ėĕĖĖ" + ], + [ + "Ę", + "ęĖĖ" + ], + [ + "ď", + "ėĖ" + ], + [ + "Ě", + "ĘĖ" + ], + [ + "ę", + "Ĕĕ" + ], + [ + "Đ", + "ĕė" + ], + [ + "Ě", + "ĔĖ" + ], + [ + "ĕĖĕ", + "ĕĖ" + ], + [ + "ĖĔ", + "ėĕ" + ], + [ + "Ĕ", + "ĕē" + ], + [ + "ě", + "Ē" + ], + [ + "ĖĖ", + "Ėė" + ], + [ + "ėĘ", + "Ėė" + ], + [ + "ē", + "ĝ" + ], + [ + "ē", + "ĖĖĖĖ" + ], + [ + "ėĖė", + "ėĖ" + ], + [ + "Ē", + "ĔĖ" + ], + [ + "ė", + "#" + ], + [ + "Ĕ", + "ēĖĖ" + ], + [ + "ĕ", + "Ĉ" + ], + [ + "Ē", + "ĘĖ" + ], + [ + "ĘĖĕ", + "ėĖĖ" + ], + [ + "Ę", + "ėĖĖĖĖ" + ], + [ + "ėĖĕ", + "Ėė" + ], + [ + "ĕė", + "ĕĖ" + ], + [ + "ĖėĖĖ", + "ĘĖĖ" + ], + [ + "ĕ", + "#" + ], + [ + "Ĕ", + "ĖĕĖĖ" + ], + [ + "Ĝ", + "ĕė" + ], + [ + "đ", + "ę" + ], + [ + "ĕĖĕ", + "Ėė" + ], + [ + "ĕĔ", + "ėĖ" + ], + [ + "ĔĔ", + "Ėĕ" + ], + [ + "đ", + "ď" + ], + [ + "Ğ", + "Ėĕ" + ], + [ + "ĖĘ", + "ĖĖĖĖ" + ], + [ + "ĕĖĖ", + "ĕĕ" + ], + [ + "ėĘ", + "ĕĕ" + ], + [ + "ėĔ", + "Ėė" + ], + [ + "Ę", + "ā" + ], + [ + "ĖĕĖĖ", + "ĔĖĖ" + ], + [ + "Ę", + "!" 
+ ], + [ + "Ě", + "Ă" + ], + [ + "ĕĖĕ", + "ėĖ" + ], + [ + "ĕē", + "Ėĕ" + ], + [ + "Ě", + "ĝ" + ], + [ + "Ĝ", + "ĖĔ" + ], + [ + "ę", + "ėĖĖ" + ], + [ + "Ě", + "Ęĕ" + ], + [ + "ĖĚ", + "Ėĕ" + ], + [ + "ė", + "Ĉ" + ], + [ + "ē", + "Ę" + ], + [ + "ėĖĕ", + "ĕĖ" + ], + [ + "ėĖĕ", + "ėĖ" + ], + [ + "ē", + "ĔĖ" + ], + [ + "ęĖĖ", + "ĘĖĖ" + ], + [ + "Đ", + "Đ" + ], + [ + "Ē", + "ėĘ" + ], + [ + "Ė", + "$" + ], + [ + "ėė", + "ėė" + ], + [ + "ĕĔ", + "ĕĖ" + ], + [ + "Ē", + "Ĝ" + ], + [ + "ē", + "ă" + ], + [ + "ě", + "ě" + ], + [ + "ėĖĕĖĖ", + "Ė" + ], + [ + "Ĕĕ", + "ĖĖĖĖ" + ], + [ + "ĝ", + "ĕĕ" + ], + [ + "đ", + "Ď" + ], + [ + "ę", + "ėĖė" + ], + [ + "ĕĘ", + "Ėė" + ], + [ + "ě", + "ėĔ" + ], + [ + "ėĕ", + "ĖėĖĖ" + ], + [ + "ėĖĖĖĖ", + "ĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖ", + "Ė" + ], + [ + "Ę", + "ĔĔ" + ], + [ + "Ġ", + "ĖĖ" + ], + [ + "ę", + "ğ" + ], + [ + "ĕĕ", + "ėĖ" + ], + [ + "ę", + "ĕĘ" + ], + [ + "đ", + "ĖĔ" + ], + [ + "ĕĖė", + "ėĖ" + ], + [ + "ĕĖė", + "Ėė" + ], + [ + "Ĕ", + "!" + ], + [ + "Ė", + "ć" + ], + [ + "ě", + "ĕĔ" + ], + [ + "ē", + "ĕĔ" + ], + [ + "ĝ", + "ėĕ" + ], + [ + "ĔĔ", + "Ėė" + ], + [ + "Ĕ", + "Ęĕ" + ], + [ + "ē", + "ėĔ" + ], + [ + "Ėĕ", + "ĕ" + ], + [ + "ėė", + "ĖĕĖĖ" + ], + [ + "Ĕ", + "ėē" + ], + [ + "Ę", + "ėē" + ], + [ + "ĕĖĖ", + "ėĖ" + ], + [ + "ēĖĖ", + "ĔĖĖ" + ], + [ + "ę", + "ėĘ" + ], + [ + "ĕĖĖ", + "ĕĖ" + ], + [ + "ē", + "Ğ" + ], + [ + "ėĖĖ", + "ĕĖ" + ], + [ + "Ĕ", + "ėę" + ], + [ + "Ĝ", + "ę" + ], + [ + "ĖĖĖĖ", + "Ėė" + ], + [ + "Ėĕ", + "ĖĕĖĖ" + ], + [ + "ĕĕ", + "ėĖĖ" + ], + [ + "Ĕ", + "Ĉ" + ], + [ + "Ėę", + "ėĕ" + ], + [ + "Ėĕ", + "ė" + ], + [ + "Ę", + "ĖėĖĖ" + ], + [ + "ě", + "đ" + ], + [ + "ĖĒ", + "Ėė" + ], + [ + "Ę", + "ėę" + ], + [ + "Ĝ", + "Ě" + ], + [ + "Ėę", + "ĕĕ" + ], + [ + "ėĖė", + "ĕĖ" + ], + [ + "Ď", + "Ėė" + ], + [ + "Ĕ", + "Ęė" + ], + [ + "ė", + "ć" + ], + [ + "ě", + "ď" + ], + [ + "Đ", + "ď" + ], + [ + "ĕ", + "ęĖĕ" + ], + [ + "ĕĔ", + "ėė" + ], + [ + "ĘĖĖ", + "ĔĖĖ" + ], + [ + "ĕ", + "đĖĖ" + ], + [ + "ĕĕ", + "ĖĕĖĖ" + ], + [ + "ėė", + "ĖĔ" + ], + [ + 
"ē", + "ğ" + ], + [ + "đ", + "ĕĘ" + ], + [ + "Ėĕ", + "ĖėĖĖ" + ], + [ + "ĕĕ", + "ĕĖ" + ], + [ + "ĕ", + "$" + ], + [ + "ĘĖĕ", + "ĕĖĖ" + ], + [ + "đ", + "ėĘ" + ], + [ + "ĕĕ", + "ĖėĖĖ" + ], + [ + "ėė", + "ėĕ" + ], + [ + "ě", + "Ď" + ], + [ + "ě", + "ĔĖ" + ], + [ + "ĕė", + "ĖĕĖĖ" + ], + [ + "Đ", + "ĖĘ" + ], + [ + "Ě", + "Ğ" + ], + [ + "ėė", + "ĕĖĖ" + ], + [ + "Ę", + "Ĉ" + ], + [ + "ė", + "$" + ], + [ + "Ď", + "ėĖ" + ], + [ + "Ę", + "\"" + ], + [ + "ĖĔĖĖ", + "ĕĖĖ" + ], + [ + "Ē", + "Ė" + ], + [ + "ėĕ", + "ĖĕĖĖ" + ], + [ + "ĕĔ", + "ĖĖĖĖ" + ], + [ + "Đ", + "đ" + ], + [ + "ėĖĖ", + "ĔĖĖ" + ], + [ + "Ē", + "ĝ" + ], + [ + "ĕ", + "ć" + ], + [ + "ĕĕĖĖ", + "Ė" + ], + [ + "ĕĔ", + "ĕĕ" + ], + [ + "ď", + "ėė" + ], + [ + "Ě", + "Ĕĕ" + ], + [ + "ě", + "Ĝ" + ], + [ + "ğ", + "ĖĖ" + ], + [ + "ėė", + "ĕĖ" + ], + [ + "ĖĒ", + "Ėĕ" + ], + [ + "đ", + "ĔĖ" + ], + [ + "ėĖė", + "Ėė" + ], + [ + "ėė", + "ĖėĖĖ" + ], + [ + "ĕĖĖĖĖ", + "ĖĖĖĖ" + ], + [ + "ĖĔ", + "ĕė" + ], + [ + "Ē", + "ėĔ" + ], + [ + "Ĕ", + "ĘĖĖ" + ], + [ + "ĕĘ", + "ĕĖ" + ], + [ + "ĕē", + "ėĖ" + ], + [ + "ĕĕ", + "ĕĖĖ" + ], + [ + "Ėĕ", + "ĖĘ" + ], + [ + "ę", + "ĔĔ" + ], + [ + "ĕĕĖĖ", + "ĕĖĖ" + ], + [ + "ĕĖĕ", + "ėĖĖ" + ], + [ + "ď", + "ĕė" + ], + [ + "Ĕ", + "ėĕĕ" + ], + [ + "ě", + "ĖĘ" + ], + [ + "đ", + "ĘĖ" + ], + [ + "ę", + "Ėē" + ], + [ + "ĝ", + "ĕė" + ], + [ + "Ėė", + "ĖėĖĖ" + ], + [ + "ē", + "ĕĖĖ" + ], + [ + "ė", + "ĚĖĕ" + ], + [ + "ĕē", + "ĕĖ" + ], + [ + "Ĕ", + "ĘĖĕ" + ], + [ + "Ę", + "#" + ], + [ + "ę", + "ĘĖĖ" + ], + [ + "ėĘ", + "ėĖ" + ], + [ + "ėĕ", + "ėĖĖ" + ], + [ + "Ęĕ", + "ĖĖĖĖ" + ], + [ + "Ě", + "ĔĔ" + ], + [ + "Ę", + "ĔĖĖ" + ], + [ + "ē", + "ĕĖĕ" + ], + [ + "Ėė", + "ĖĔ" + ], + [ + "Ę", + "Ĕė" + ], + [ + "Ğ", + "ĕĖ" + ], + [ + "ē", + "ā" + ], + [ + "Ė", + "%" + ], + [ + "Ėē", + "ėė" + ], + [ + "ē", + "Ĕĕ" + ], + [ + "ėė", + "ĖĘ" + ], + [ + "ę", + "ă" + ], + [ + "ĕĘ", + "ėĖ" + ], + [ + "ėĕ", + "ėĖ" + ], + [ + "ē", + "Ėę" + ], + [ + "ĖĒ", + "ėĖ" + ], + [ + "Đ", + "Ď" + ], + [ + "Ĕ", + "ĒĖĖ" + ], + [ + "ėĘ", + "ĕĖ" + ], + [ + 
"Ėē", + "ĕĕ" + ], + [ + "Ę", + "ęĖĕ" + ], + [ + "ě", + "Đ" + ], + [ + "ėė", + "ėĖ" + ], + [ + "Ĕ", + "ĘĘ" + ], + [ + "ėĖĖ", + "ėĖ" + ], + [ + "ė", + "ĒĖĖ" + ], + [ + "ĖĘĖĖ", + "ėĖĖ" + ], + [ + "ě", + "Ęĕ" + ], + [ + "ď", + "ėĕ" + ], + [ + "Ė", + "Ć" + ], + [ + "ĝ", + "ėė" + ], + [ + "ėę", + "Ėė" + ], + [ + "đ", + "ě" + ], + [ + "ĖĖĖĖĖĖ", + "Ėĕ" + ], + [ + "ę", + "ĕĖė" + ], + [ + "Ę", + "ĔĖė" + ], + [ + "Ĕ", + "\"" + ], + [ + "ėĖĖ", + "ėĖĕ" + ], + [ + "ėĕĖĖ", + "Ė" + ], + [ + "ęĖĖ", + "ĕĖĖ" + ], + [ + "ę", + "Ĕė" + ], + [ + "ę", + "ėĖĕ" + ], + [ + "ď", + "ĕĕ" + ], + [ + "Ē", + "ĕĔ" + ], + [ + "Ď", + "Ėĕ" + ], + [ + "ē", + "Ęė" + ], + [ + "ēĖĖ", + "ėĖĖ" + ], + [ + "Ē", + "Ęė" + ], + [ + "Ę", + "ć" + ], + [ + "Ē", + "Ğ" + ], + [ + "Ėĕ", + "ėĕ" + ], + [ + "ē", + "Ėē" + ], + [ + "ĕĖė", + "ĕĖĖ" + ], + [ + "ě", + "ĘĖ" + ], + [ + "ĖĚ", + "ĕĖ" + ], + [ + "Ė", + "Ą" + ], + [ + "ĖĐ", + "ĖĖ" + ], + [ + "ĕ", + "ĚĖĖ" + ], + [ + "ėĘ", + "ĖĖĖĖ" + ], + [ + "ėĘ", + "ėĕ" + ], + [ + "Ğ", + "ĕĕ" + ], + [ + "ĕĒ", + "Ėė" + ], + [ + "ď", + "Đ" + ], + [ + "Ĕ", + "ēĖė" + ], + [ + "ĕ", + "Ć" + ], + [ + "ę", + "ėĕĕ" + ], + [ + "ē", + "Ĕė" + ], + [ + "Ėĕ", + "ĖĔ" + ], + [ + "Ėė", + "ĖĖĖĖĖ" + ], + [ + "Ğ", + "Ėė" + ], + [ + "ėĔ", + "ĕĕ" + ], + [ + "ę", + "Ĉ" + ], + [ + "ĕė", + "ĖėĖĖ" + ], + [ + "ď", + "ď" + ], + [ + "Ē", + "Ĕĕ" + ], + [ + "ėĖĖ", + "ėė" + ], + [ + "ĔĔ", + "Ė" + ], + [ + "Ę", + "ĚĖĖ" + ], + [ + "Ĝ", + "ĕĔ" + ], + [ + "ę", + "Ėę" + ], + [ + "ĕ", + "Ą" + ], + [ + "ĘĖĖ", + "ėĖĕ" + ], + [ + "đ", + "Ă" + ], + [ + "Ĕ", + "ėĖĖĖĖ" + ], + [ + "Ĕĕ", + "ĖĕĖĖ" + ], + [ + "ėē", + "Ėė" + ], + [ + "Ĕė", + "ĖĕĖĖ" + ], + [ + "đ", + "Ĝ" + ], + [ + "ė", + "%" + ], + [ + "Ğ", + "ėĖ" + ], + [ + "ĕĕ", + "ĖĘ" + ], + [ + "Ď", + "ĕĖ" + ], + [ + "ĕ", + "%" + ], + [ + "ė", + "Ć" + ], + [ + "ĕĖĕĖĖ", + "Ė" + ], + [ + "ē", + "Ęĕ" + ], + [ + "ę", + "ĖĕĖĖ" + ], + [ + "ę", + "!" 
+ ], + [ + "ĖėĖĖ", + "ĔĖĖ" + ], + [ + "ĔĖė", + "ĕĖĖ" + ], + [ + "Đ", + "ĖĔ" + ], + [ + "Ė", + "," + ], + [ + "ĖĒ", + "ĕĖ" + ], + [ + "ē", + "ĔĘ" + ], + [ + "ĕĕ", + "ĖĔ" + ], + [ + "Đ", + "ĕĘ" + ], + [ + "Ē", + "Ęĕ" + ], + [ + "Ě", + "ğ" + ], + [ + "ě", + "ĔĔ" + ], + [ + "Ě", + "ĕĘ" + ], + [ + "Ě", + "ėĘ" + ], + [ + "Ėĕ", + "ĔĖĖ" + ], + [ + "ĕĕ", + "ėĕ" + ], + [ + "ĕĖĕ", + "ĔĖĖ" + ], + [ + "ĖĚ", + "ėĖ" + ], + [ + "Ē", + "Ĕė" + ], + [ + "ĖĜ", + "ĖĖ" + ], + [ + "ē", + "ĕĖė" + ], + [ + "ė", + "ą" + ], + [ + "ĖĘ", + "ėė" + ], + [ + "Ě", + "Ė" + ], + [ + "Ęĕ", + "ĖėĖĖ" + ], + [ + "Đ", + "ĘĖ" + ], + [ + "ėę", + "ĕĖ" + ], + [ + "ď", + "đ" + ], + [ + "ĝ", + "ĖĔ" + ], + [ + "ĖĕĖĖ", + "Ėĕ" + ], + [ + "ĕĔ", + "ėĕ" + ], + [ + "ĕē", + "ėė" + ], + [ + "Ĝ", + "ėĔ" + ], + [ + "Ē", + "ĔĘ" + ], + [ + "ė", + "ěĖĖ" + ], + [ + "Ę", + "$" + ], + [ + "Ėē", + "ėĕ" + ], + [ + "Ď", + "Đ" + ], + [ + "Ęė", + "ĖėĖĖ" + ], + [ + "Ē", + "Ėę" + ], + [ + "ėĖĕ", + "ĕĕ" + ], + [ + "Ė", + "&" + ], + [ + "ĔĖĖ", + "ēĖĖ" + ], + [ + "ėę", + "ĕĕ" + ], + [ + "ĖĖĖĖ", + "ĘĖĖ" + ], + [ + "Ė", + "ą" + ], + [ + "ĕĖĕ", + "ĕĕ" + ], + [ + "Ĕ", + "ć" + ], + [ + "Ęė", + "ĖĕĖĖ" + ], + [ + "ĘĖĖ", + "ęĖĖ" + ], + [ + "Ĕĕ", + "ĖėĖĖ" + ], + [ + "Ĝ", + "ĔĖ" + ], + [ + "Ę", + "ĕĕĖĖ" + ], + [ + "Ĕ", + "#" + ], + [ + "ĕ", + "ą" + ], + [ + "Ę", + "ėĕĖĖ" + ], + [ + "ę", + "ĕĖĖ" + ], + [ + "Ĝ", + "Ē" + ], + [ + "ėĕĕ", + "Ėĕ" + ], + [ + "Ď", + "ėė" + ], + [ + "ē", + "Ĉ" + ], + [ + "ę", + "\"" + ], + [ + "ĖĚ", + "Ėė" + ], + [ + "ė", + "," + ], + [ + "ė", + "Ą" + ], + [ + "ĕĘ", + "Ė" + ], + [ + "ď", + "ĖĘ" + ], + [ + "đ", + "ėĔ" + ], + [ + "ėĕĕ", + "Ėė" + ], + [ + "ĖĔĖĖ", + "ĔĖĖ" + ], + [ + "Ĝ", + "ě" + ], + [ + "Ġ", + "Ėĕ" + ], + [ + "ĔĖĕ", + "ĕĖĖ" + ], + [ + "Ě", + "Ėē" + ], + [ + "ėĕ", + "ėĕ" + ], + [ + "ĘĖĖ", + "ĘĖĕ" + ], + [ + "ĘĘ", + "Ėė" + ], + [ + "ėėĖĖ", + "ėĖĖ" + ], + [ + "Đ", + "ėĘ" + ], + [ + "ĕ", + "," + ], + [ + "ėę", + "ėĖ" + ], + [ + "ĖĖ", + "ėĖ" + ], + [ + "Ď", + "ď" + ], + [ + "Ĕ", + "Ą" + ], + [ + "Ę", + "ĖĕĖĖĖ" + 
], + [ + "ėē", + "Ėĕ" + ], + [ + "ĖėĖĕ", + "ĕĖĖ" + ], + [ + "ė", + "&" + ], + [ + "Ėě", + "Ėĕ" + ], + [ + "Ę", + "%" + ], + [ + "Ě", + "Ĕė" + ], + [ + "ē", + "ĔĖĖ" + ], + [ + "Ę", + "ĔĖĕ" + ], + [ + "ĖĕĖĖ", + "ĘĖĖ" + ], + [ + "ē", + "ėĖĕ" + ], + [ + "ĕĕĖĖ", + "ėĖĖ" + ], + [ + "Ē", + "ă" + ], + [ + "ĝ", + "Ě" + ], + [ + "Ėė", + "ĖĘ" + ], + [ + "ĖĕĖĖ", + "ė" + ], + [ + "ėĘ", + "Ė" + ], + [ + "Ĕĕ", + "ėė" + ], + [ + "ē", + "ĖėĖĖ" + ], + [ + "ĖėĖĖ", + "Ėĕ" + ], + [ + "ĕĖė", + "ĕĕ" + ], + [ + "ėĖĖĖĖ", + "ĖĕĖĖ" + ], + [ + "ĖĔĖĖ", + "ėĖĖ" + ], + [ + "ę", + "ĕē" + ], + [ + "ėĔ", + "Ė" + ], + [ + "ĕĖė", + "ėĖĖ" + ], + [ + "ę", + "Ęė" + ], + [ + "ĕĖĖ", + "ĘĖĖ" + ], + [ + "Ėē", + "ĕė" + ], + [ + "ę", + "ĘĖĕ" + ], + [ + "ď", + "Ē" + ], + [ + "ėĖĕ", + "ĘĖĖ" + ], + [ + "ę", + "ā" + ], + [ + "ę", + "ĖĖĖĖĖ" + ], + [ + "ėĖĕ", + "ĔĖĖ" + ], + [ + "ĕĘ", + "ėė" + ], + [ + "Ę", + "," + ], + [ + "ě", + "Ĕĕ" + ], + [ + "ď", + "Ď" + ], + [ + "Ė", + "'" + ], + [ + "Ę", + "Ć" + ], + [ + "ę", + "ć" + ], + [ + "ĖĖĖĖ", + "ĔĖĖ" + ], + [ + "ėĔ", + "ĕĖ" + ], + [ + "ĕĘ", + "ĕĕ" + ], + [ + "Đ", + "ĔĖ" + ], + [ + "đ", + "ĕĔ" + ], + [ + "Ĝ", + "ĖĘ" + ], + [ + "ĕĖĕ", + "ėė" + ], + [ + "ĕĒ", + "Ėĕ" + ], + [ + "ď", + "ē" + ], + [ + "ĕĖĖ", + "ėė" + ], + [ + "ėĖĖ", + "ĖĕĖĖ" + ], + [ + "ĕē", + "ĕĕ" + ], + [ + "Ĕ", + "$" + ], + [ + "ęĖĕ", + "ėĖĖ" + ], + [ + "ė", + "'" + ], + [ + "ĕĔ", + "ĕė" + ], + [ + "ĔĖĖ", + "Ėĕ" + ], + [ + "Ĕ", + "ĘĖė" + ] + ] + } +} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_cyc_440000/tokenizer_config.json b/scenestreamer/tokenization/0305_fast_cyc_440000/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..44b81ebde2224b4e3935b02872938beae622c37c --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_cyc_440000/tokenizer_config.json @@ -0,0 +1,8 @@ +{ + "added_tokens_decoder": {}, + "clean_up_tokenization_spaces": false, + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + 
"processor_class": "UniversalActionProcessor", + "tokenizer_class": "PreTrainedTokenizer" +} diff --git a/scenestreamer/tokenization/0305_fast_ped_4000000/delta_normalization_quantiles.json b/scenestreamer/tokenization/0305_fast_ped_4000000/delta_normalization_quantiles.json new file mode 100644 index 0000000000000000000000000000000000000000..61ffb978827f503b729b4b94fb20ee55e3d467ec --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_ped_4000000/delta_normalization_quantiles.json @@ -0,0 +1 @@ +{"q_lower": [-0.05709691323339939, -0.04747965559363365, -0.6622537469863893, -0.053571314886212355, -0.04730982825160027, -0.6465871715545655, -0.05305027216672897, -0.04703669250011444, -0.654660701751709, -0.0555944050475955, -0.048526575043797496, -0.6671109199523926, -0.060208022594451904, -0.051760585978627205, -0.6740370631217958], "q_upper": [0.056708107106387316, 0.2951338437199576, 0.6598919129371574, 0.05403243750333786, 0.2931670981645541, 0.6609598350524877, 0.05273184873163668, 0.2911098754405972, 0.6547465324401855, 0.054818840585648974, 0.2922115921974182, 0.6682721972465515, 0.06055043738335352, 0.2947467863559716, 0.6932802295684706]} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_ped_4000000/error_mean.json b/scenestreamer/tokenization/0305_fast_ped_4000000/error_mean.json new file mode 100644 index 0000000000000000000000000000000000000000..428c8098c82be451ab039ce463288ccbed1c1616 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_ped_4000000/error_mean.json @@ -0,0 +1 @@ +[[0.0009200551834461406, 0.0005993730614175804, 0.02961482031974784], [0.0012037112044310564, 0.0014241823432694278, 0.03598430569527111], [0.0013748478273013951, 0.0026271077530224227, 0.04444853739405181], [0.0014320452151866628, 0.004539943947601046, 0.049008310165982644], [0.0017214882922979352, 0.007024024440325715, 0.05161147338950369]] \ No newline at end of file diff --git 
a/scenestreamer/tokenization/0305_fast_ped_4000000/processor_config.json b/scenestreamer/tokenization/0305_fast_ped_4000000/processor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..8d9a3f725377bee10763626aa8cf6ddb5ad8ac87 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_ped_4000000/processor_config.json @@ -0,0 +1,8 @@ +{ + "action_dim": null, + "min_token": -22, + "processor_class": "UniversalActionProcessor", + "scale": 10, + "time_horizon": null, + "vocab_size": 1024 +} diff --git a/scenestreamer/tokenization/0305_fast_ped_4000000/special_tokens_map.json b/scenestreamer/tokenization/0305_fast_ped_4000000/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_ped_4000000/special_tokens_map.json @@ -0,0 +1 @@ +{} diff --git a/scenestreamer/tokenization/0305_fast_ped_4000000/tokenizer.json b/scenestreamer/tokenization/0305_fast_ped_4000000/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..5bbdd13a14c6212e0f0358f3ee2a84d5596ee805 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_ped_4000000/tokenizer.json @@ -0,0 +1,4847 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": true + }, + "post_processor": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": false, + "use_regex": true + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": false, + "vocab": { + "\u0000": 0, + "\u0001": 1, + "\u0002": 2, + 
"\u0003": 3, + "\u0004": 4, + "\u0005": 5, + "\u0006": 6, + "\u0007": 7, + "\b": 8, + "\t": 9, + "\n": 10, + "\u000b": 11, + "\f": 12, + "\r": 13, + "\u000e": 14, + "\u000f": 15, + "\u0010": 16, + "\u0011": 17, + "\u0012": 18, + "\u0013": 19, + "\u0014": 20, + "\u0015": 21, + "\u0016": 22, + "\u0017": 23, + "\u0018": 24, + "\u0019": 25, + "\u001a": 26, + "\u001b": 27, + "\u001c": 28, + "\u001d": 29, + "\u001e": 30, + "\u001f": 31, + " ": 32, + "!": 33, + "\"": 34, + "#": 35, + "$": 36, + "%": 37, + "&": 38, + "'": 39, + "(": 40, + ")": 41, + "*": 42, + "+": 43, + ",": 44, + "Ā": 45, + "ā": 46, + "Ă": 47, + "ă": 48, + "Ą": 49, + "ą": 50, + "Ć": 51, + "ć": 52, + "Ĉ": 53, + "ĉ": 54, + "Ċ": 55, + "ċ": 56, + "Č": 57, + "č": 58, + "Ď": 59, + "ď": 60, + "Đ": 61, + "đ": 62, + "Ē": 63, + "ē": 64, + "Ĕ": 65, + "ĕ": 66, + "Ė": 67, + "ė": 68, + "Ę": 69, + "ę": 70, + "Ě": 71, + "ě": 72, + "Ĝ": 73, + "ĝ": 74, + "Ğ": 75, + "ğ": 76, + "Ġ": 77, + "ĖĖ": 78, + "ĕĖ": 79, + "ėĖ": 80, + "ĕĖĖ": 81, + "ėĖĖ": 82, + "ĖĖĖĖ": 83, + "ĘĖĖ": 84, + "ĔĖĖ": 85, + "ĘĖ": 86, + "ĔĖ": 87, + "ĖĖĖ": 88, + "ėĕĖ": 89, + "ĕĕĖ": 90, + "ĕėĖ": 91, + "ėėĖ": 92, + "ĕĕ": 93, + "ėė": 94, + "ęĖ": 95, + "ęĖĖ": 96, + "ēĖĖ": 97, + "ĖĖĖĖĖ": 98, + "ĘĕĖ": 99, + "ĕė": 100, + "ēĖ": 101, + "ĔĕĖ": 102, + "ĔėĖ": 103, + "ĘėĖ": 104, + "ĚĖ": 105, + "ĕĖĖĖĖ": 106, + "ėĕ": 107, + "ėĖĖĖĖ": 108, + "ĖėĖ": 109, + "ĖĕĖ": 110, + "ĒĖ": 111, + "ėĖĖĖĖĖ": 112, + "ĚĖĖ": 113, + "ĕĖĖĖĖĖ": 114, + "ęĕĖ": 115, + "ěĖ": 116, + "ĒĖĖ": 117, + "ēėĖ": 118, + "ēĕĖ": 119, + "ĕĔ": 120, + "ęėĖ": 121, + "ėĘ": 122, + "ėĖĖėĖĖ": 123, + "ĕĖĖĕĖĖ": 124, + "ĕĘ": 125, + "ėĔ": 126, + "đĖ": 127, + "ĕĘĖ": 128, + "ėĖĖĕĖĖ": 129, + "ĕĕĖĖ": 130, + "ėĘĖ": 131, + "ĜĖ": 132, + "ėĔĖ": 133, + "ĕĔĖ": 134, + "ėĖĖĖ": 135, + "ėĕĖĖ": 136, + "ĖĖĖĖĖĖ": 137, + "ĆĖ": 138, + "Ĕĕ": 139, + "ěĖĖ": 140, + "Ĕė": 141, + "ĚĕĖ": 142, + "ĕĖĖĖ": 143, + "Ęĕ": 144, + "ėĖĖĕĖ": 145, + "Ęė": 146, + "ĒėĖ": 147, + "ĕĖĖėĖĖ": 148, + "ĐĖ": 149, + "ĒĕĖ": 150, + "đĖĖ": 151, + "ėĖĖėĖ": 152, + "ĕē": 153, + 
"ĚėĖ": 154, + "ĔĘĖ": 155, + "ėę": 156, + "ĘĖĖĖĖĖ": 157, + "ĘĘĖ": 158, + "ĖĘĖ": 159, + "ĘĔĖ": 160, + "ĝĖ": 161, + "ĔĖĖĖĖĖ": 162, + "ĖĖĖĖĖĖĖĖ": 163, + "ĕę": 164, + "ĔĔĖ": 165, + "ćĖ": 166, + "ėē": 167, + "ėęĖ": 168, + "ĕęĖ": 169, + "ĕĖĖĕĖ": 170, + "ėĖĖĘĖĖ": 171, + "ĔĔ": 172, + "ĘĘ": 173, + "ďĖ": 174, + "ėĖĖĔĖĖ": 175, + "ĕĖĖėĖ": 176, + "ėĕĖĖĖĖ": 177, + "ėėĖĖ": 178, + "ĖĖĖĖĖĖĖ": 179, + "ĖęĖ": 180, + "ĕĖĖĔĖĖ": 181, + "ĕĖĖĘĖĖ": 182, + "ĘĔ": 183, + "ĕĒ": 184, + "ĖĔĖ": 185, + "ėĚ": 186, + "ĕėĖĖ": 187, + "ĕēĖ": 188, + "ĘĆ": 189, + "ĔĆ": 190, + "ĕĕĖĖĖĖ": 191, + "Ėć": 192, + "ĜĖĖ": 193, + "ĔĘ": 194, + "ĕėė": 195, + "ĕĔĖĖ": 196, + "ĕėĖĖĖĖ": 197, + "ěĕĖ": 198, + "ėēĖ": 199, + "ėĒ": 200, + "ėĕĕ": 201, + "ēĘĖ": 202, + "ĘĖĖėĖĖ": 203, + "ĖĆ": 204, + "ĘĖĖĕĖĖ": 205, + "ĕĚ": 206, + "ėĚĖ": 207, + "đėĖ": 208, + "đĕĖ": 209, + "ąĖ": 210, + "ĎĖ": 211, + "ėėĖĖĖĖ": 212, + "ĐĖĖ": 213, + "ęĔĖ": 214, + "ĕĆ": 215, + "ĕĚĖ": 216, + "ĞĖ": 217, + "ėĆ": 218, + "ęĘĖ": 219, + "Ęę": 220, + "Ĕē": 221, + "ĔĖĖĕĖĖ": 222, + "ėć": 223, + "ĖĚĖ": 224, + "ěėĖ": 225, + "ĔęĖ": 226, + "ēĔĖ": 227, + "ĘęĖ": 228, + "ĔĖĖėĖĖ": 229, + "ĕć": 230, + "Ęē": 231, + "ėĘĖĖ": 232, + "ėĕė": 233, + "ĕđ": 234, + "ėĕĖĕĖĖ": 235, + "ėĖĖĖĕĖ": 236, + "Ĕę": 237, + "ėĕĖėĖĖ": 238, + "ėě": 239, + "ĕĖĖĖĕĖ": 240, + "ĖĖĖĕĖ": 241, + "ėđ": 242, + "ėĖĖĖėĖ": 243, + "ĖēĖ": 244, + "Ęĕĕ": 245, + "ĕĕĖĕĖĖ": 246, + "Ėą": 247, + "ĘĖĖĘĖĖ": 248, + "ĖĖĖėĖ": 249, + "ĕĖĖĖėĖ": 250, + "ĕĕĖėĖĖ": 251, + ",Ė": 252, + "ĕě": 253, + "ĕĘĖĖ": 254, + "ęĖĖĖĖĖ": 255, + "Ĕėė": 256, + "ĕĐ": 257, + "ĔĖĖĔĖĖ": 258, + "ĔēĖ": 259, + "ĕėĖėĖĖ": 260, + "ėěĖ": 261, + "ĕĒĖ": 262, + "ĘĕĖĖĖĖ": 263, + "ėĔĖĖ": 264, + "ĕėĖĕĖĖ": 265, + "ĔĒ": 266, + "ėėĖėĖĖ": 267, + "Ĕĕė": 268, + "ĘĚ": 269, + "ĕěĖ": 270, + "ĕą": 271, + "ĘĚĖ": 272, + "ĘēĖ": 273, + "ēĖĖĖĖĖ": 274, + "ĘĒ": 275, + "ĖěĖ": 276, + "ėą": 277, + "ėėĖĕĖĖ": 278, + "Ęėė": 279, + "Ĕĕĕ": 280, + "ėĐ": 281, + "ĔĚĖ": 282, + "ęę": 283, + "Ęć": 284, + "ėĒĖ": 285, + "ĈĖ": 286, + "ėĜ": 287, + "Ęėĕ": 288, + "ĘĖĖĔĖĖ": 289, + "ēē": 290, + "Ĕć": 
291, + "ğĖ": 292, + "ĘĖĖĖĖ": 293, + "ĒĘĖ": 294, + "Ęĕė": 295, + "ĝĖĖ": 296, + "ĔĖĖĖĖ": 297, + "ĕĖĖĖĖĖĖ": 298, + "ęē": 299, + "ĔĕĖĖĖĖ": 300, + "ĚĔĖ": 301, + "ĕď": 302, + "ĔėĖĖĖĖ": 303, + "ĕĕĕĖ": 304, + "ēĆ": 305, + "ĔĚ": 306, + "ėĖĖėĕĖ": 307, + "ęĆ": 308, + "ĔĖĖĘĖĖ": 309, + "ēęĖ": 310, + "ęĕĕ": 311, + "ėęĖĖ": 312, + "ĖĒĖ": 313, + "ĕĜ": 314, + "ėď": 315, + "ĘėĖĖĖĖ": 316, + "ĕēĖĖ": 317, + "ėĖĖĖĖĖĖ": 318, + "ĕĖĖĖĖĖĖĖ": 319, + "ĒĔĖ": 320, + "ĚĘĖ": 321, + "ĜĕĖ": 322, + "ėĖĖĕėĖ": 323, + "ĕĖĖėĕĖ": 324, + "ĕĖĖĖĖĖĖĖĖ": 325, + "ęęĖ": 326, + "ėĖĖĕĕĖ": 327, + "Ĕđ": 328, + "ĕĕėĖ": 329, + "ĐėĖ": 330, + "ĐĕĖ": 331, + "ĔĖė": 332, + "Ęđ": 333, + "ėĖĖĖĖĖĖĖĖ": 334, + "ėĕĖĖĕĖ": 335, + "ĕĖĖĕėĖ": 336, + "ėĖĖėėĖ": 337, + "ēėė": 338, + "ęėĕ": 339, + "Ĕėĕ": 340, + "ĕĖĖĕĕĖ": 341, + "ēę": 342, + "ēĕė": 343, + "ĖĖĕĖ": 344, + "Ęě": 345, + "ėėĕĖ": 346, + "ėĖĖęĖĖ": 347, + "ĕđĖ": 348, + "ĘĖĖĖĕĖ": 349, + "Ėē": 350, + "ĜėĖ": 351, + "ėĕĖĖėĖ": 352, + "ĕĎ": 353, + "ĕėĕĖ": 354, + "ėĖĖĖĖĖĖĖ": 355, + "ĘěĖ": 356, + "ĕĖė": 357, + "ďĖĖ": 358, + "ĕĖĖėėĖ": 359, + "Ėę": 360, + "ėĕĖĔĖĖ": 361, + "ĔěĖ": 362, + "ĘĖė": 363, + "ĕėėĖ": 364, + "ĕĕĖĖĕĖ": 365, + "ĕĕĖĔĖĖ": 366, + "ėĎ": 367, + "ėĖĖēĖĖ": 368, + "ėĕĖĘĖĖ": 369, + "Ĕą": 370, + "Ęą": 371, + "ĕĖĖēĖĖ": 372, + "ęĒ": 373, + "ĔĒĖ": 374, + "ėĝ": 375, + "ĕėĖĖĕĖ": 376, + "ĕĖĖęĖĖ": 377, + "ėĜĖ": 378, + "ėĖė": 379, + "ĕĜĖ": 380, + "ĖĜĖ": 381, + "ĖĆĖ": 382, + "ēĒ": 383, + "ĔĖĖĖĕĖ": 384, + "ĕĕĖĖėĖ": 385, + "Ĕě": 386, + "ĘĖĖĕĖ": 387, + "ĔĖĖĖ": 388, + "ĘĖĖĖėĖ": 389, + "ėėėĖ": 390, + "Ěē": 391, + "ĕėĖĖėĖ": 392, + "ĕĔĕĖ": 393, + "ĘĐ": 394, + "Ěę": 395, + "ĔĐ": 396, + "ĘĖĖĖ": 397, + "ĘĆĖ": 398, + "ėėĖĖĕĖ": 399, + "ēēĖ": 400, + "ĔĖĖĖėĖ": 401, + "ēć": 402, + "ėđĖ": 403, + "ĔĆĖ": 404, + "ęėė": 405, + "ĕĝ": 406, + "ėĖĖĘĕĖ": 407, + "ĕĕĖĘĖĖ": 408, + "ĘĒĖ": 409, + "ĖĒ": 410, + "Ěĕĕ": 411, + "ęć": 412, + "ēĕĕ": 413, + "ĠĖ": 414, + "ęĖĖĕĖĖ": 415, + "ĖđĖ": 416, + "ęĚ": 417, + "ĖćĖ": 418, + "ėĆĖ": 419, + "ĔĖĖĕĖ": 420, + "ĘĖĕ": 421, + "ęĖĖėĖĖ": 422, + "ĕĆĖ": 423, + "Ēē": 424, + "ĕėĖĘĖĖ": 
425, + "ęēĖ": 426, + "ėėĖĖėĖ": 427, + "ēĚĖ": 428, + "ęĕė": 429, + "ĕĖĕĖ": 430, + "ĘĕĖĕĖĖ": 431, + "ėėĖĘĖĖ": 432, + "Ĕď": 433, + "ĕėĖĔĖĖ": 434, + "ėĕĕĖ": 435, + "ĘĜ": 436, + "ĖĖĖĖĕĖĖ": 437, + "ĕęĖĖ": 438, + "ęĚĖ": 439, + "ĘĖĖėĖ": 440, + "ĘĕĖėĖĖ": 441, + "Ęď": 442, + "ėėĖĔĖĖ": 443, + "ēĖĖĕĖĖ": 444, + "Ěėĕ": 445, + "ėĕėĖ": 446, + "ęĔ": 447, + "ėĖėĖ": 448, + "ėĕĖėĕĖ": 449, + "ĕĔėĖ": 450, + "ĖėĖĖĖĖ": 451, + "Ēėė": 452, + "đĘĖ": 453, + "ēĔ": 454, + "ėĞ": 455, + "ěĔĖ": 456, + "ĕĈ": 457, + "ėĈ": 458, + "ęđ": 459, + "ėĖĖĔėĖ": 460, + "ęĖĖĘĖĖ": 461, + "ĕĐĖ": 462, + "Ēĕė": 463, + "ĕĖĖĖĖĖĕĖĖ": 464, + "ėĖĕĖ": 465, + "ĖĈ": 466, + "ĖĚ": 467, + "ęĘ": 468, + "ĒĆ": 469, + "ĔĖĕ": 470, + "ėĖĖĔĕĖ": 471, + "ėĚĖĖ": 472, + "ĔĖĖėĖ": 473, + "ėĖĖĘėĖ": 474, + "ĖĕĖĖĖĖ": 475, + "ēėĕ": 476, + "ēđ": 477, + "ĘĕĖĘĖĖ": 478, + "ĕĞ": 479, + "Ēę": 480, + "ėĖĖĖĖĖĕĖĖ": 481, + "ĕĕė": 482, + "ēĚ": 483, + "ėćĖ": 484, + "ęĕĖĖĖĖ": 485, + "Ėđ": 486, + "ėĕĖĕėĖ": 487, + "ĕĖĖĖĖĖėĖĖ": 488, + "ĔĕĖĕĖĖ": 489, + "ēĖĖĔĖĖ": 490, + "ēĖė": 491, + "ĘĎ": 492, + "ęĖĖĔĖĖ": 493, + "ĔėĖĕĖĖ": 494, + "đĔĖ": 495, + "ĕćĖ": 496, + "ēĖĖėĖĖ": 497, + "ĔĜĖ": 498, + "ĘĜĖ": 499, + "ĘėĖĕĖĖ": 500, + "ĘĕĖĔĖĖ": 501, + "ĚĆ": 502, + "ĔĎ": 503, + "ĕĕĖĖĖ": 504, + "ĖĖĖĖėĖĖ": 505, + "ĚĒ": 506, + "ēĖĖĘĖĖ": 507, + "ĒęĖ": 508, + "ĘĕĖĖĕĖ": 509, + "ĕĒĖĖ": 510, + "ĔđĖ": 511, + "ĕĕĖĕĕĖ": 512, + "ěĘĖ": 513, + "ĘĖĖĖĖĖĖĖĖ": 514, + "ęě": 515, + "ėĖĖĖĖĖėĖĖ": 516, + "ėėė": 517, + "ĒĒ": 518, + "ēė": 519, + "ĕĖĕĖĖ": 520, + "ėĐĖ": 521, + "ĔĈ": 522, + "ĕėĖĕėĖ": 523, + "ĘĕĖĖ": 524, + "ĖĐĖ": 525, + "ĔĜ": 526, + "ėĕĖĕĕĖ": 527, + "ēĘ": 528, + "ĝĕĖ": 529, + "ĕĕĖĖĖĖĖ": 530, + "ĚĚ": 531, + "ĔėĖėĖĖ": 532, + "ĘĈ": 533, + "ēėĖĖĖĖ": 534, + "ĖĖėĖ": 535, + "ēą": 536, + "ĔĕĖĖ": 537, + "ĔĖĖĖĖĖĖĖĖ": 538, + "ęė": 539, + "ĘėĖėĖĖ": 540, + "ĖĖĖĖĕĖ": 541, + "ėĕĖėėĖ": 542, + "ĘĖĖęĖĖ": 543, + "ĖĖĖĖĖĖĖĖĖĖ": 544, + "ďėĖ": 545, + "ĚĖĖĖĖĖ": 546, + "ĔĕĖėĖĖ": 547, + "ĔĕĖĔĖĖ": 548, + "ĘĖĖĕĕĖ": 549, + "ĘĔĕ": 550, + "ĘđĖ": 551, + "ėĖĕ": 552, + "ďĕĖ": 553, + "ĕĝĖ": 554, + "ēĕĖĖĖĖ": 555, + "ĚęĖ": 
556, + "ėēĖĖ": 557, + "ĘĖĖėėĖ": 558, + "Ēć": 559, + "ęą": 560, + "ėĝĖ": 561, + "ĘĕĖĖėĖ": 562, + "ĘĖĖėĕĖ": 563, + "ēěĖ": 564, + "ęĐ": 565, + "ĕĖĖĘĕĖ": 566, + "ĕėĖĖĖ": 567, + "ĘĖĖĕėĖ": 568, + "ĖĐ": 569, + "ĔėĖĖ": 570, + "ėĕĖĖĖ": 571, + "ĖĝĖ": 572, + "ėėĖėėĖ": 573, + "ęĖĖĖĕĖ": 574, + "ęėĖĖĖĖ": 575, + "ĔėĖĔĖĖ": 576, + "ĖĖĖĖĔĖĖ": 577, + "ĔĕĔ": 578, + "Ěć": 579, + "ĘėĖĖ": 580, + "ĖĖĖĖĖĕĖ": 581, + "ĔĕĖĖĕĖ": 582, + "ĕĕĖĕėĖ": 583, + "ęĖĖĕĖ": 584, + "Ěĕė": 585, + "ĖąĖ": 586, + "ĘĘĕ": 587, + "Ěėė": 588, + "ĞĖĖ": 589, + "ēĐ": 590, + "ēĒĖ": 591, + "ĘėĖĔĖĖ": 592, + "ęěĖ": 593, + "ĝėĖ": 594, + "ėĘĕĖ": 595, + "ĘĖĖēĖĖ": 596, + "ĖĘĖĖĖĖ": 597, + "ĕĕĖėėĖ": 598, + "Ėď": 599, + "ĒĖĖĖĖĖ": 600, + "ĕėĖėĕĖ": 601, + "ēĖĖĖĕĖ": 602, + "ĕĖĖĔėĖ": 603, + "ĖĎ": 604, + "ėėĕ": 605, + "ĔėĖĖĕĖ": 606, + "Ęĝ": 607, + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ": 608, + "ęĕ": 609, + "ėĘĖĖĖĖ": 610, + "ěĕĕ": 611, + "!Ė": 612, + "Ēĕĕ": 613, + "ĖĖĖĖėĖ": 614, + "ęĖė": 615, + "ĒēĖ": 616, + "ĘėĖĖĕĖ": 617, + "ęĖĕ": 618, + "ĖĖĖĖĖėĖ": 619, + "ĕĕĖėĕĖ": 620, + "ĘėĘ": 621, + "ĕĘĖĖĖĖ": 622, + "ėĕĖĖĖĖĖĖĖ": 623, + "ĘėĖĘĖĖ": 624, + "ęĖĖĖėĖ": 625, + "ĔĕĖĖėĖ": 626, + "Ēđ": 627, + "ēĕ": 628, + "ĕĖėĖ": 629, + "ēě": 630, + "ĕĕĕ": 631, + "ĔėĖĖėĖ": 632, + "Ěđ": 633, + "ėĖĖĖĖĖėĖ": 634, + "ĕĖėĖĖ": 635, + "ĕĖĖĔĕĖ": 636, + "ėėĖėĕĖ": 637, + "ĔĖĖēĖĖ": 638, + "ęĔĕ": 639, + "ĔĖĖĕėĖ": 640, + "ĘĕĘ": 641, + "ęĘĕ": 642, + "ėėĖĕĕĖ": 643, + "ĕĖĕ": 644, + "ĔėĘ": 645, + "ēĖĖĖėĖ": 646, + "Ėě": 647, + "ĔėĔ": 648, + "ĕďĖ": 649, + "ĔĖĖėĕĖ": 650, + "ęď": 651, + "ĔĔė": 652, + "ĕėĕ": 653, + "ėĔĖĖĖĖ": 654, + "ėėĖĕėĖ": 655, + "đėė": 656, + "ĎĖĖ": 657, + "ĚēĖ": 658, + "ĔĖĖĕĕĖ": 659, + "ĘĕĔ": 660, + "ĘėĖĖėĖ": 661, + "ěėĕ": 662, + "ęĈ": 663, + "ēď": 664, + "ĔĖĖęĖĖ": 665, + "ēĈ": 666, + "ĔėĖĘĖĖ": 667, + "ĕąĖ": 668, + "ĕĖĖĖĖĖėĖ": 669, + "ĔĕĖĘĖĖ": 670, + "ĔĕĘ": 671, + "ęĒĖ": 672, + "ęĖĖĖ": 673, + "ėğ": 674, + "ĒĚ": 675, + "đĕė": 676, + "ĔĐĖ": 677, + "ĕĔĖĖĖĖ": 678, + "Ĕĝ": 679, + "ēĖĖĖ": 680, + "ėĖĖĖĖĖĕĖ": 681, + "ĕĖĖĘėĖ": 682, + "ĖĔĖĖĖĖ": 683, + "ĖďĖ": 684, + "ĕėĖĕĕĖ": 685, + "ėĘėĖ": 
686, + "ĕĕĕĕ": 687, + "ĖĖĖĖĖĖėĖ": 688, + "ęĜ": 689, + "ėĖĖėĖĖėĖĖ": 690, + "ĕĚĖĖ": 691, + "Ěě": 692, + "ĕĖĖĖĖĖĕĖ": 693, + "ėąĖ": 694, + "ĜĔĖ": 695, + "ĘėĔ": 696, + "ĕėĖėėĖ": 697, + "ėďĖ": 698, + "ĕĕĖĖĖĖĖĖĖ": 699, + "ĕēĕĖ": 700, + "đē": 701, + "Ēėĕ": 702, + "ĔĖĖėėĖ": 703, + "ĐĘĖ": 704, + "ĘĖĖĘĕĖ": 705, + "ęĖĖėĖ": 706, + "ėėėė": 707, + "ĒĚĖ": 708, + "ĘćĖ": 709, + "ĕĖĖĕĖĖĕĖĖ": 710, + "ėėĖĖĖ": 711, + "ěē": 712, + "ĚĐ": 713, + "ĖĖĖĖĖĖĖĖĖ": 714, + "ēĔė": 715, + "ĔĘė": 716, + "ęĎ": 717, + "đĒ": 718, + "ęĖĖęĖĖ": 719, + "ĕğ": 720, + "ĘĕĖėĕĖ": 721, + "ēĎ": 722, + "ĒĐ": 723, + "ĔćĖ": 724, + "ĕėĖĖĖĖĖĖĖ": 725, + "ėěĖĖ": 726, + "ĖĖĖĖĘĖĖ": 727, + "ĕĕĖĖĕĖĖ": 728, + "ĘĞ": 729, + "ěĒ": 730, + "ėėĕė": 731, + "đĆ": 732, + "ěę": 733, + "ĖĖĖĖĖĖĕĖ": 734, + "ēĆĖ": 735, + "ĕĖĖėĖĖĖĖĖ": 736, + "ĖĚĖĖ": 737, + "ēĖĕ": 738, + "ęĆĖ": 739, + "ėĖĖėĖĖĖĖĖ": 740, + "ėĕĖēĖĖ": 741, + "ĘĝĖ": 742, + "ĔĝĖ": 743, + "ėĕĖęĖĖ": 744, + "ĕĕĖēĖĖ": 745, + "+Ė": 746, + "ėĕĖĘĕĖ": 747, + "ĐĔĖ": 748, + "ĒĖė": 749, + "\"Ė": 750, + "Ēą": 751, + "ĚĘĕ": 752, + "ėĖĖė": 753, + "ĕėĕĕ": 754, + "ěĆ": 755, + "ėĖĖėĖĖĕĖĖ": 756, + "ĘĕĖĕėĖ": 757, + "ĜĘĖ": 758, + "ęĕĖĕĖĖ": 759, + "ěĚ": 760, + "ĘĐĖ": 761, + "ēĖĖēĖĖ": 762, + "ĚĚĖ": 763, + "ĕĕėė": 764, + "ęĕĖėĖĖ": 765, + "ēĜĖ": 766, + "ėėĖė": 767, + "ėĕĖĔėĖ": 768, + "ĕĖĖĕĖĖėĖĖ": 769, + "ēĘė": 770, + "ĕĘĕĖ": 771, + "ĒĈ": 772, + "ĔĔĕ": 773, + "ěđ": 774, + "Ěď": 775, + "ĔĖĖĔėĖ": 776, + "đć": 777, + "ĕēėĖ": 778, + "ēđĖ": 779, + "ĕĕĖęĖĖ": 780, + "ĘĖĖĘėĖ": 781, + "ėĖĖĕĖĖėĖĖ": 782, + "ĚĔĕ": 783, + "ĕĖĖĕĖĖĖĖĖ": 784, + "ĔĖĖĔĕĖ": 785, + "ĕĕĖĔĕĖ": 786, + "ěć": 787, + "ĕėĖęĖĖ": 788, + "ĕėĖēĖĖ": 789, + "Ěą": 790, + "ėĕėė": 791, + "ĘĔė": 792, + "ĕĎĖ": 793, + "ĚĈ": 794, + "ĔĖĖĘĕĖ": 795, + "đę": 796, + "ėĖĕĖĖ": 797, + "ĕĘėĖ": 798, + "Ēď": 799, + "ĖėĖĖ": 800, + "ĘĖĖĔĕĖ": 801, + "ėėĖęĖĖ": 802, + "ėĖĖĖĖĖĖĖĖĖĖĖ": 803, + "ĕėĕė": 804, + "ĕėĖĘĕĖ": 805, + "ĖĕĖĖ": 806, + "đęĖ": 807, + "ēĜ": 808, + "ėĖĖĕĖĖĖĖĖ": 809, + "ėėĖēĖĖ": 810, + "ēĖĖĕĖ": 811, + "ĘĖĖĔėĖ": 812, + "ĕėĖĔėĖ": 813, + "ėėĖĖĖĖĖĖĖ": 814, + "ěĕė": 
815, + "ėĔĕĖ": 816, + "ěėė": 817, + "ĒĔė": 818, + "đđ": 819, + "ĕĖĖĖĖĖĖĖĖĖĖĖ": 820, + "ęĜĖ": 821, + "ĖęĖĖĖĖ": 822, + "Ėĕĕ": 823, + "ĔĞ": 824, + "ĔďĖ": 825, + "ĘĘė": 826, + "ėĖĖĚĖĖ": 827, + "ĚĎ": 828, + "ėėĖĔĕĖ": 829, + "ėėĕĕ": 830, + "ĖĖĖĖĖĖĖĖĖĖĖ": 831, + "ĕĕĖĘĕĖ": 832, + "ėęĕĖ": 833, + "ĒĖĖĕĖĖ": 834, + "ėĕĖĔĕĖ": 835, + "ĕĞĖ": 836, + "ėĖėĖĖ": 837, + "ĕĖĖėĖĖėĖĖ": 838, + "ėėĖĘĕĖ": 839, + "ėĕĖė": 840, + "ęĖĖēĖĖ": 841, + "ĕĕĖĔėĖ": 842, + "Ĝĕĕ": 843, + "ĖĎĖ": 844, + "ėĖĖĔĖ": 845, + "ĒĎ": 846, + "ĖėĖėĖĖ": 847, + "đĕĕ": 848, + "ĚĖĖĕĖĖ": 849, + "ĒĘė": 850, + "ĕĖĖĔĖ": 851, + "ĚĔ": 852, + "ėĞĖ": 853, + "ĕĕĕė": 854, + "ĕĕĖĘėĖ": 855, + "ĖĞĖ": 856, + "Ēě": 857, + "ėĒĖĖ": 858, + "ĞĕĖ": 859, + "ėĎĖ": 860, + "Ęğ": 861, + "ĔĘĕ": 862, + "ĚĖĖėĖĖ": 863, + "ēĖĖęĖĖ": 864, + "ėĖĖĘĖ": 865, + "Ĝē": 866, + "ĖėĖĕĖĖ": 867, + "ėėĖĘėĖ": 868, + "ėĕĖĘėĖ": 869, + "ēėĖĕĖĖ": 870, + "ĔĖĖĘėĖ": 871, + "ėęĖĖĖĖ": 872, + "ęĕĖĖĕĖ": 873, + "ĕđĖĖ": 874, + "ěĐ": 875, + "ěĖĖĕĖ": 876, + "ęėĖėĖĖ": 877, + "ėĔėĖ": 878, + "ęėĘ": 879, + "ĚĜ": 880, + "ėĖĖĖĕĖĖĖĖ": 881, + "ėėĖĔėĖ": 882, + "ĒĒĖ": 883, + "*Ė": 884, + "ėĖĖĕĖĖĕĖĖ": 885, + "ēĕĔ": 886, + "ēĕĖĕĖĖ": 887, + "ĕėĖĔĕĖ": 888, + "Đėė": 889, + "ĘĘĖĖ": 890, + "Ĝėĕ": 891, + "ĘĕĖęĖĖ": 892, + "ĘĕĖĕĕĖ": 893, + "ĔĔĖĖ": 894, + "ęĝ": 895, + "Đē": 896, + "Đĕė": 897, + "ęđĖ": 898, + "ĕęĖĖĖĖ": 899, + "đĚ": 900, + "ĜĒ": 901, + "ĎėĖ": 902, + "ĎĕĖ": 903, + "ėĖĖęĕĖ": 904, + "ĘĕĖēĖĖ": 905, + "ĄĖ": 906, + "ĖėĖĖĕĖ": 907, + "ĕėĖĘėĖ": 908, + "ĕĕĔĖ": 909, + "ğĖĖ": 910, + "ĞėĖ": 911, + "ĕĖĖĖĕĖĖĖĖ": 912, + "ĘĘĖĖĖĖ": 913, + "ĖĕĖĕĖĖ": 914, + "ĒěĖ": 915, + "ěēĖ": 916, + "ĚĖĕ": 917, + "ĖĔ": 918, + "ĔėĖėĕĖ": 919, + "ęėĖĕĖĖ": 920, + "ĖĜ": 921, + "ĕĕĖĕĖ": 922, + "ěě": 923, + "đēĖ": 924, + "ĔĘĖĖĖĖ": 925, + "ęĕĘ": 926, + "ĕĖĖĘĖ": 927, + "đėĕ": 928, + "ěęĖ": 929, + "ĔėĖĕėĖ": 930, + "ĚĘ": 931, + "ĕėĖĕĖ": 932, + "đĐ": 933, + "ęĔė": 934, + "Ė,Ė": 935, + "ėĖĖĒĖĖ": 936, + "ĕĕėĕ": 937, + "ĘėĖĕĕĖ": 938, + "ĝĔĖ": 939, + "ĕėĖėĖ": 940, + "ėĖĖĖėĖĖĖĖ": 941, + "ėėĖėĖ": 942, + "ěĘĕ": 943, + "ĘďĖ": 944, + "ėĘĕė": 
945, + "ęĕĖĘĖĖ": 946, + "ęĕĖĔĖĖ": 947, + "ĚĖĖĘĖĖ": 948, + "ĖĕĖėĖĖ": 949, + "ĕĖĖėĖĖĕĖĖ": 950, + "ēėĔ": 951, + "ĖĘ": 952, + "ėęėĖ": 953, + "ĕĔĖĖĖĖĖ": 954, + "ĐĒ": 955, + "ėĕĖĕĖ": 956, + "ĘĔĖĖĖĖ": 957, + "ęĕĔ": 958, + "ĘĕĖĘĕĖ": 959, + "ĕĕĖėĖ": 960, + "ĚĒĖ": 961, + "ēėĖėĖĖ": 962, + "ēĔĕ": 963, + "#Ė": 964, + "ĒĖĖĔĖĖ": 965, + "ĘĔĖĖ": 966, + "ĔĕĖĕĕĖ": 967, + "ĚĕĖĖĖĖ": 968, + "ĘąĖ": 969, + "ĔĕĖėėĖ": 970, + "ēĕĖėĖĖ": 971, + "ĒĖĖėĖĖ": 972, + "ęĘė": 973, + "ĔĘĖĖ": 974, + "Ěė": 975, + "ęĕĖĖėĖ": 976, + "ėĕĖėĖ": 977, + "ĔėĖēĖĖ": 978, + ")Ė": 979, + "ēėĘ": 980, + "ĘĕĖėėĖ": 981, + "ėĘĖĖĕĖ": 982, + "ĒĔ": 983, + "đĈ": 984, + "ęĖĖėĕĖ": 985, + "ĔąĖ": 986, + "ĘėĖėėĖ": 987, + "ěĈ": 988, + "ďĘĖ": 989, + "ęĖĖĕĕĖ": 990, + "ĖēĖĖĖĖ": 991, + "ĔĕĖēĖĖ": 992, + "ĚĖė": 993, + "ĚĖĖĔĖĖ": 994, + "ĕĖĖĖėĖĖĖĖ": 995, + "ĚěĖ": 996, + "ĕĔĕĕ": 997, + "ĕĘĖĖĕĖ": 998, + "ēĐĖ": 999, + "ėĘėė": 1000, + "ĘĚĖĖ": 1001, + "ĕĖĖĖĖĖĘĖĖ": 1002, + "ėėĖĕĖ": 1003, + "ěď": 1004, + "ēĕĘ": 1005, + "ēėĖĔĖĖ": 1006, + "ęėĔ": 1007, + "ĘĕĖĔėĖ": 1008, + "đď": 1009, + "ėĖĖĖĖĖĔĖĖ": 1010, + "ėėĖĕ": 1011, + "đą": 1012, + "ėĕĖĕ": 1013, + "ėĖĖĕ": 1014, + "ĕĖĖĖĖĖĔĖĖ": 1015, + "ĖĘĖĖ": 1016, + "Ĝę": 1017, + "ĕ,Ė": 1018, + "ēėĖĖĕĖ": 1019, + "ĘėĖęĖĖ": 1020, + "ėėėĕ": 1021, + "ėĖĖĖĖĖĘĖĖ": 1022, + "ĔĕĖėĕĖ": 1023 + }, + "merges": [ + [ + "Ė", + "Ė" + ], + [ + "ĕ", + "Ė" + ], + [ + "ė", + "Ė" + ], + [ + "ĕ", + "ĖĖ" + ], + [ + "ė", + "ĖĖ" + ], + [ + "ĖĖ", + "ĖĖ" + ], + [ + "Ę", + "ĖĖ" + ], + [ + "Ĕ", + "ĖĖ" + ], + [ + "Ę", + "Ė" + ], + [ + "Ĕ", + "Ė" + ], + [ + "ĖĖ", + "Ė" + ], + [ + "ė", + "ĕĖ" + ], + [ + "ĕ", + "ĕĖ" + ], + [ + "ĕ", + "ėĖ" + ], + [ + "ė", + "ėĖ" + ], + [ + "ĕ", + "ĕ" + ], + [ + "ė", + "ė" + ], + [ + "ę", + "Ė" + ], + [ + "ę", + "ĖĖ" + ], + [ + "ē", + "ĖĖ" + ], + [ + "ĖĖĖĖ", + "Ė" + ], + [ + "Ę", + "ĕĖ" + ], + [ + "ĕ", + "ė" + ], + [ + "ē", + "Ė" + ], + [ + "Ĕ", + "ĕĖ" + ], + [ + "Ĕ", + "ėĖ" + ], + [ + "Ę", + "ėĖ" + ], + [ + "Ě", + "Ė" + ], + [ + "ĕĖĖ", + "ĖĖ" + ], + [ + "ė", + "ĕ" + ], + [ + "ėĖĖ", + "ĖĖ" + ], + [ + "Ė", + "ėĖ" 
+ ], + [ + "Ė", + "ĕĖ" + ], + [ + "Ē", + "Ė" + ], + [ + "ėĖĖ", + "ĖĖĖ" + ], + [ + "Ě", + "ĖĖ" + ], + [ + "ĕĖĖ", + "ĖĖĖ" + ], + [ + "ę", + "ĕĖ" + ], + [ + "ě", + "Ė" + ], + [ + "Ē", + "ĖĖ" + ], + [ + "ē", + "ėĖ" + ], + [ + "ē", + "ĕĖ" + ], + [ + "ĕ", + "Ĕ" + ], + [ + "ę", + "ėĖ" + ], + [ + "ė", + "Ę" + ], + [ + "ėĖĖ", + "ėĖĖ" + ], + [ + "ĕĖĖ", + "ĕĖĖ" + ], + [ + "ĕ", + "Ę" + ], + [ + "ė", + "Ĕ" + ], + [ + "đ", + "Ė" + ], + [ + "ĕ", + "ĘĖ" + ], + [ + "ėĖĖ", + "ĕĖĖ" + ], + [ + "ĕ", + "ĕĖĖ" + ], + [ + "ė", + "ĘĖ" + ], + [ + "Ĝ", + "Ė" + ], + [ + "ė", + "ĔĖ" + ], + [ + "ĕ", + "ĔĖ" + ], + [ + "ėĖĖ", + "Ė" + ], + [ + "ė", + "ĕĖĖ" + ], + [ + "ĖĖĖĖ", + "ĖĖ" + ], + [ + "Ć", + "Ė" + ], + [ + "Ĕ", + "ĕ" + ], + [ + "ě", + "ĖĖ" + ], + [ + "Ĕ", + "ė" + ], + [ + "Ě", + "ĕĖ" + ], + [ + "ĕĖĖ", + "Ė" + ], + [ + "Ę", + "ĕ" + ], + [ + "ėĖĖ", + "ĕĖ" + ], + [ + "Ę", + "ė" + ], + [ + "Ē", + "ėĖ" + ], + [ + "ĕĖĖ", + "ėĖĖ" + ], + [ + "Đ", + "Ė" + ], + [ + "Ē", + "ĕĖ" + ], + [ + "đ", + "ĖĖ" + ], + [ + "ėĖĖ", + "ėĖ" + ], + [ + "ĕ", + "ē" + ], + [ + "Ě", + "ėĖ" + ], + [ + "Ĕ", + "ĘĖ" + ], + [ + "ė", + "ę" + ], + [ + "Ę", + "ĖĖĖĖĖ" + ], + [ + "Ę", + "ĘĖ" + ], + [ + "Ė", + "ĘĖ" + ], + [ + "Ę", + "ĔĖ" + ], + [ + "ĝ", + "Ė" + ], + [ + "Ĕ", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĖĖ", + "ĖĖĖĖ" + ], + [ + "ĕ", + "ę" + ], + [ + "Ĕ", + "ĔĖ" + ], + [ + "ć", + "Ė" + ], + [ + "ė", + "ē" + ], + [ + "ė", + "ęĖ" + ], + [ + "ĕ", + "ęĖ" + ], + [ + "ĕĖĖ", + "ĕĖ" + ], + [ + "ėĖĖ", + "ĘĖĖ" + ], + [ + "Ĕ", + "Ĕ" + ], + [ + "Ę", + "Ę" + ], + [ + "ď", + "Ė" + ], + [ + "ėĖĖ", + "ĔĖĖ" + ], + [ + "ĕĖĖ", + "ėĖ" + ], + [ + "ė", + "ĕĖĖĖĖ" + ], + [ + "ė", + "ėĖĖ" + ], + [ + "ĖĖĖĖ", + "ĖĖĖ" + ], + [ + "Ė", + "ęĖ" + ], + [ + "ĕĖĖ", + "ĔĖĖ" + ], + [ + "ĕĖĖ", + "ĘĖĖ" + ], + [ + "Ę", + "Ĕ" + ], + [ + "ĕ", + "Ē" + ], + [ + "Ė", + "ĔĖ" + ], + [ + "ė", + "Ě" + ], + [ + "ĕ", + "ėĖĖ" + ], + [ + "ĕ", + "ēĖ" + ], + [ + "Ę", + "Ć" + ], + [ + "Ĕ", + "Ć" + ], + [ + "ĕ", + "ĕĖĖĖĖ" + ], + [ + "Ė", + "ć" + ], + [ + "Ĝ", + "ĖĖ" + ], + [ + "Ĕ", + "Ę" + ], + 
[ + "ĕ", + "ėė" + ], + [ + "ĕ", + "ĔĖĖ" + ], + [ + "ĕ", + "ėĖĖĖĖ" + ], + [ + "ě", + "ĕĖ" + ], + [ + "ė", + "ēĖ" + ], + [ + "ė", + "Ē" + ], + [ + "ė", + "ĕĕ" + ], + [ + "ē", + "ĘĖ" + ], + [ + "ĘĖĖ", + "ėĖĖ" + ], + [ + "Ė", + "Ć" + ], + [ + "ĘĖĖ", + "ĕĖĖ" + ], + [ + "ĕ", + "Ě" + ], + [ + "ė", + "ĚĖ" + ], + [ + "đ", + "ėĖ" + ], + [ + "đ", + "ĕĖ" + ], + [ + "ą", + "Ė" + ], + [ + "Ď", + "Ė" + ], + [ + "ė", + "ėĖĖĖĖ" + ], + [ + "Đ", + "ĖĖ" + ], + [ + "ę", + "ĔĖ" + ], + [ + "ĕ", + "Ć" + ], + [ + "ĕ", + "ĚĖ" + ], + [ + "Ğ", + "Ė" + ], + [ + "ė", + "Ć" + ], + [ + "ę", + "ĘĖ" + ], + [ + "Ę", + "ę" + ], + [ + "Ĕ", + "ē" + ], + [ + "ĔĖĖ", + "ĕĖĖ" + ], + [ + "ė", + "ć" + ], + [ + "Ė", + "ĚĖ" + ], + [ + "ě", + "ėĖ" + ], + [ + "Ĕ", + "ęĖ" + ], + [ + "ē", + "ĔĖ" + ], + [ + "Ę", + "ęĖ" + ], + [ + "ĔĖĖ", + "ėĖĖ" + ], + [ + "ĕ", + "ć" + ], + [ + "Ę", + "ē" + ], + [ + "ė", + "ĘĖĖ" + ], + [ + "ė", + "ĕė" + ], + [ + "ĕ", + "đ" + ], + [ + "ėĕĖ", + "ĕĖĖ" + ], + [ + "ėĖĖ", + "ĖĕĖ" + ], + [ + "Ĕ", + "ę" + ], + [ + "ėĕĖ", + "ėĖĖ" + ], + [ + "ė", + "ě" + ], + [ + "ĕĖĖ", + "ĖĕĖ" + ], + [ + "ĖĖĖ", + "ĕĖ" + ], + [ + "ė", + "đ" + ], + [ + "ėĖĖ", + "ĖėĖ" + ], + [ + "Ė", + "ēĖ" + ], + [ + "Ę", + "ĕĕ" + ], + [ + "ĕĕĖ", + "ĕĖĖ" + ], + [ + "Ė", + "ą" + ], + [ + "ĘĖĖ", + "ĘĖĖ" + ], + [ + "ĖĖĖ", + "ėĖ" + ], + [ + "ĕĖĖ", + "ĖėĖ" + ], + [ + "ĕĕĖ", + "ėĖĖ" + ], + [ + ",", + "Ė" + ], + [ + "ĕ", + "ě" + ], + [ + "ĕ", + "ĘĖĖ" + ], + [ + "ę", + "ĖĖĖĖĖ" + ], + [ + "Ĕ", + "ėė" + ], + [ + "ĕ", + "Đ" + ], + [ + "ĔĖĖ", + "ĔĖĖ" + ], + [ + "Ĕ", + "ēĖ" + ], + [ + "ĕėĖ", + "ėĖĖ" + ], + [ + "ė", + "ěĖ" + ], + [ + "ĕ", + "ĒĖ" + ], + [ + "Ę", + "ĕĖĖĖĖ" + ], + [ + "ė", + "ĔĖĖ" + ], + [ + "ĕėĖ", + "ĕĖĖ" + ], + [ + "Ĕ", + "Ē" + ], + [ + "ėėĖ", + "ėĖĖ" + ], + [ + "Ĕ", + "ĕė" + ], + [ + "Ę", + "Ě" + ], + [ + "ĕ", + "ěĖ" + ], + [ + "ĕ", + "ą" + ], + [ + "Ę", + "ĚĖ" + ], + [ + "Ę", + "ēĖ" + ], + [ + "ē", + "ĖĖĖĖĖ" + ], + [ + "Ę", + "Ē" + ], + [ + "Ė", + "ěĖ" + ], + [ + "ė", + "ą" + ], + [ + "ėėĖ", + "ĕĖĖ" + ], + [ + "Ę", + "ėė" 
+ ], + [ + "Ĕ", + "ĕĕ" + ], + [ + "ė", + "Đ" + ], + [ + "Ĕ", + "ĚĖ" + ], + [ + "ę", + "ę" + ], + [ + "Ę", + "ć" + ], + [ + "ė", + "ĒĖ" + ], + [ + "Ĉ", + "Ė" + ], + [ + "ė", + "Ĝ" + ], + [ + "Ę", + "ėĕ" + ], + [ + "ĘĖĖ", + "ĔĖĖ" + ], + [ + "ē", + "ē" + ], + [ + "Ĕ", + "ć" + ], + [ + "ğ", + "Ė" + ], + [ + "Ę", + "ĖĖĖĖ" + ], + [ + "Ē", + "ĘĖ" + ], + [ + "Ę", + "ĕė" + ], + [ + "ĝ", + "ĖĖ" + ], + [ + "Ĕ", + "ĖĖĖĖ" + ], + [ + "ĕĖĖ", + "ĖĖĖĖ" + ], + [ + "ę", + "ē" + ], + [ + "Ĕ", + "ĕĖĖĖĖ" + ], + [ + "Ě", + "ĔĖ" + ], + [ + "ĕ", + "ď" + ], + [ + "Ĕ", + "ėĖĖĖĖ" + ], + [ + "ĕ", + "ĕĕĖ" + ], + [ + "ē", + "Ć" + ], + [ + "Ĕ", + "Ě" + ], + [ + "ėĖĖ", + "ėĕĖ" + ], + [ + "ę", + "Ć" + ], + [ + "ĔĖĖ", + "ĘĖĖ" + ], + [ + "ē", + "ęĖ" + ], + [ + "ę", + "ĕĕ" + ], + [ + "ė", + "ęĖĖ" + ], + [ + "Ė", + "ĒĖ" + ], + [ + "ĕ", + "Ĝ" + ], + [ + "ė", + "ď" + ], + [ + "Ę", + "ėĖĖĖĖ" + ], + [ + "ĕ", + "ēĖĖ" + ], + [ + "ėĖĖ", + "ĖĖĖĖ" + ], + [ + "ĕĖĖ", + "ĖĖĖĖĖ" + ], + [ + "Ē", + "ĔĖ" + ], + [ + "Ě", + "ĘĖ" + ], + [ + "Ĝ", + "ĕĖ" + ], + [ + "ėĖĖ", + "ĕėĖ" + ], + [ + "ĕĖĖ", + "ėĕĖ" + ], + [ + "ĕĖĖ", + "ĖĖĖĖĖĖ" + ], + [ + "ę", + "ęĖ" + ], + [ + "ėĖĖ", + "ĕĕĖ" + ], + [ + "Ĕ", + "đ" + ], + [ + "ĕ", + "ĕėĖ" + ], + [ + "Đ", + "ėĖ" + ], + [ + "Đ", + "ĕĖ" + ], + [ + "ĔĖ", + "ė" + ], + [ + "Ę", + "đ" + ], + [ + "ėĖĖ", + "ĖĖĖĖĖĖ" + ], + [ + "ėĕĖĖ", + "ĕĖ" + ], + [ + "ĕĖĖ", + "ĕėĖ" + ], + [ + "ėĖĖ", + "ėėĖ" + ], + [ + "ē", + "ėė" + ], + [ + "ę", + "ėĕ" + ], + [ + "Ĕ", + "ėĕ" + ], + [ + "ĕĖĖ", + "ĕĕĖ" + ], + [ + "ē", + "ę" + ], + [ + "ē", + "ĕė" + ], + [ + "ĖĖ", + "ĕĖ" + ], + [ + "Ę", + "ě" + ], + [ + "ė", + "ėĕĖ" + ], + [ + "ėĖĖ", + "ęĖĖ" + ], + [ + "ĕ", + "đĖ" + ], + [ + "ĘĖĖ", + "ĖĕĖ" + ], + [ + "Ė", + "ē" + ], + [ + "Ĝ", + "ėĖ" + ], + [ + "ėĕĖĖ", + "ėĖ" + ], + [ + "ĕ", + "Ď" + ], + [ + "ĕ", + "ėĕĖ" + ], + [ + "ėĖĖ", + "ĖĖĖĖĖ" + ], + [ + "Ę", + "ěĖ" + ], + [ + "ĕĖ", + "ė" + ], + [ + "ď", + "ĖĖ" + ], + [ + "ĕĖĖ", + "ėėĖ" + ], + [ + "Ė", + "ę" + ], + [ + "ėĕĖ", + "ĔĖĖ" + ], + [ + "Ĕ", + "ěĖ" + ], + [ + "ĘĖ", 
+ "ė" + ], + [ + "ĕ", + "ėėĖ" + ], + [ + "ĕĕĖĖ", + "ĕĖ" + ], + [ + "ĕĕĖ", + "ĔĖĖ" + ], + [ + "ė", + "Ď" + ], + [ + "ėĖĖ", + "ēĖĖ" + ], + [ + "ėĕĖ", + "ĘĖĖ" + ], + [ + "Ĕ", + "ą" + ], + [ + "Ę", + "ą" + ], + [ + "ĕĖĖ", + "ēĖĖ" + ], + [ + "ę", + "Ē" + ], + [ + "Ĕ", + "ĒĖ" + ], + [ + "ė", + "ĝ" + ], + [ + "ĕ", + "ėĖĖĕĖ" + ], + [ + "ĕĖĖ", + "ęĖĖ" + ], + [ + "ė", + "ĜĖ" + ], + [ + "ėĖ", + "ė" + ], + [ + "ĕ", + "ĜĖ" + ], + [ + "Ė", + "ĜĖ" + ], + [ + "Ė", + "ĆĖ" + ], + [ + "ē", + "Ē" + ], + [ + "ĔĖĖ", + "ĖĕĖ" + ], + [ + "ĕĕĖĖ", + "ėĖ" + ], + [ + "Ĕ", + "ě" + ], + [ + "ĘĖĖ", + "ĕĖ" + ], + [ + "ĔĖĖ", + "Ė" + ], + [ + "ĘĖĖ", + "ĖėĖ" + ], + [ + "ė", + "ėėĖ" + ], + [ + "Ě", + "ē" + ], + [ + "ĕ", + "ėĖĖėĖ" + ], + [ + "ĕ", + "ĔĕĖ" + ], + [ + "Ę", + "Đ" + ], + [ + "Ě", + "ę" + ], + [ + "Ĕ", + "Đ" + ], + [ + "ĘĖĖ", + "Ė" + ], + [ + "Ę", + "ĆĖ" + ], + [ + "ė", + "ėĖĖĕĖ" + ], + [ + "ē", + "ēĖ" + ], + [ + "ĔĖĖ", + "ĖėĖ" + ], + [ + "ē", + "ć" + ], + [ + "ė", + "đĖ" + ], + [ + "Ĕ", + "ĆĖ" + ], + [ + "ę", + "ėė" + ], + [ + "ĕ", + "ĝ" + ], + [ + "ėĖĖ", + "ĘĕĖ" + ], + [ + "ĕĕĖ", + "ĘĖĖ" + ], + [ + "Ę", + "ĒĖ" + ], + [ + "Ė", + "Ē" + ], + [ + "Ě", + "ĕĕ" + ], + [ + "ę", + "ć" + ], + [ + "ē", + "ĕĕ" + ], + [ + "Ġ", + "Ė" + ], + [ + "ęĖĖ", + "ĕĖĖ" + ], + [ + "Ė", + "đĖ" + ], + [ + "ę", + "Ě" + ], + [ + "Ė", + "ćĖ" + ], + [ + "ė", + "ĆĖ" + ], + [ + "ĔĖĖ", + "ĕĖ" + ], + [ + "ĘĖ", + "ĕ" + ], + [ + "ęĖĖ", + "ėĖĖ" + ], + [ + "ĕ", + "ĆĖ" + ], + [ + "Ē", + "ē" + ], + [ + "ĕėĖ", + "ĘĖĖ" + ], + [ + "ę", + "ēĖ" + ], + [ + "ė", + "ėĖĖėĖ" + ], + [ + "ē", + "ĚĖ" + ], + [ + "ę", + "ĕė" + ], + [ + "ĕĖ", + "ĕĖ" + ], + [ + "ĘĕĖ", + "ĕĖĖ" + ], + [ + "ėėĖ", + "ĘĖĖ" + ], + [ + "Ĕ", + "ď" + ], + [ + "ĕėĖ", + "ĔĖĖ" + ], + [ + "ė", + "ĕĕĖ" + ], + [ + "Ę", + "Ĝ" + ], + [ + "ĖĖĖĖ", + "ĕĖĖ" + ], + [ + "ĕ", + "ęĖĖ" + ], + [ + "ę", + "ĚĖ" + ], + [ + "ĘĖĖ", + "ėĖ" + ], + [ + "ĘĕĖ", + "ėĖĖ" + ], + [ + "Ę", + "ď" + ], + [ + "ėėĖ", + "ĔĖĖ" + ], + [ + "ēĖĖ", + "ĕĖĖ" + ], + [ + "Ě", + "ėĕ" + ], + [ + "ė", + "ĕėĖ" + ], + [ + 
"ę", + "Ĕ" + ], + [ + "ėĖ", + "ėĖ" + ], + [ + "ėĕĖ", + "ėĕĖ" + ], + [ + "ĕ", + "ĔėĖ" + ], + [ + "Ė", + "ėĖĖĖĖ" + ], + [ + "Ē", + "ėė" + ], + [ + "đ", + "ĘĖ" + ], + [ + "ē", + "Ĕ" + ], + [ + "ė", + "Ğ" + ], + [ + "ě", + "ĔĖ" + ], + [ + "ĕ", + "Ĉ" + ], + [ + "ė", + "Ĉ" + ], + [ + "ę", + "đ" + ], + [ + "ėĖĖ", + "ĔėĖ" + ], + [ + "ęĖĖ", + "ĘĖĖ" + ], + [ + "ĕ", + "ĐĖ" + ], + [ + "Ē", + "ĕė" + ], + [ + "ĕĖĖĖĖĖ", + "ĕĖĖ" + ], + [ + "ėĖ", + "ĕĖ" + ], + [ + "Ė", + "Ĉ" + ], + [ + "Ė", + "Ě" + ], + [ + "ę", + "Ę" + ], + [ + "Ē", + "Ć" + ], + [ + "ĔĖ", + "ĕ" + ], + [ + "ėĖĖ", + "ĔĕĖ" + ], + [ + "ė", + "ĚĖĖ" + ], + [ + "ĔĖĖ", + "ėĖ" + ], + [ + "ėĖĖ", + "ĘėĖ" + ], + [ + "Ė", + "ĕĖĖĖĖ" + ], + [ + "ē", + "ėĕ" + ], + [ + "ē", + "đ" + ], + [ + "ĘĕĖ", + "ĘĖĖ" + ], + [ + "ĕ", + "Ğ" + ], + [ + "Ē", + "ę" + ], + [ + "ėĖĖĖĖĖ", + "ĕĖĖ" + ], + [ + "ĕĕ", + "ė" + ], + [ + "ē", + "Ě" + ], + [ + "ė", + "ćĖ" + ], + [ + "ę", + "ĕĖĖĖĖ" + ], + [ + "Ė", + "đ" + ], + [ + "ėĕĖ", + "ĕėĖ" + ], + [ + "ĕĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ĔĕĖ", + "ĕĖĖ" + ], + [ + "ēĖĖ", + "ĔĖĖ" + ], + [ + "ēĖ", + "ė" + ], + [ + "Ę", + "Ď" + ], + [ + "ęĖĖ", + "ĔĖĖ" + ], + [ + "ĔėĖ", + "ĕĖĖ" + ], + [ + "đ", + "ĔĖ" + ], + [ + "ĕ", + "ćĖ" + ], + [ + "ēĖĖ", + "ėĖĖ" + ], + [ + "Ĕ", + "ĜĖ" + ], + [ + "Ę", + "ĜĖ" + ], + [ + "ĘėĖ", + "ĕĖĖ" + ], + [ + "ĘĕĖ", + "ĔĖĖ" + ], + [ + "Ě", + "Ć" + ], + [ + "Ĕ", + "Ď" + ], + [ + "ĕĕĖĖ", + "Ė" + ], + [ + "ĖĖĖĖ", + "ėĖĖ" + ], + [ + "Ě", + "Ē" + ], + [ + "ēĖĖ", + "ĘĖĖ" + ], + [ + "Ē", + "ęĖ" + ], + [ + "Ę", + "ĕĖĖĕĖ" + ], + [ + "ĕ", + "ĒĖĖ" + ], + [ + "Ĕ", + "đĖ" + ], + [ + "ĕĕĖ", + "ĕĕĖ" + ], + [ + "ě", + "ĘĖ" + ], + [ + "Ę", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ę", + "ě" + ], + [ + "ėĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ėė", + "ė" + ], + [ + "Ē", + "Ē" + ], + [ + "ē", + "ė" + ], + [ + "ĕĖ", + "ĕĖĖ" + ], + [ + "ė", + "ĐĖ" + ], + [ + "Ĕ", + "Ĉ" + ], + [ + "ĕėĖ", + "ĕėĖ" + ], + [ + "Ę", + "ĕĖĖ" + ], + [ + "Ė", + "ĐĖ" + ], + [ + "Ĕ", + "Ĝ" + ], + [ + "ėĕĖ", + "ĕĕĖ" + ], + [ + "ē", + "Ę" + ], + [ + "ĝ", + "ĕĖ" + ], + [ + 
"ĕ", + "ĕĖĖĖĖĖ" + ], + [ + "Ě", + "Ě" + ], + [ + "ĔėĖ", + "ėĖĖ" + ], + [ + "Ę", + "Ĉ" + ], + [ + "ē", + "ėĖĖĖĖ" + ], + [ + "ĖĖ", + "ėĖ" + ], + [ + "ē", + "ą" + ], + [ + "Ĕ", + "ĕĖĖ" + ], + [ + "Ĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ę", + "ė" + ], + [ + "ĘėĖ", + "ėĖĖ" + ], + [ + "ĖĖĖĖ", + "ĕĖ" + ], + [ + "ėĕĖ", + "ėėĖ" + ], + [ + "ĘĖĖ", + "ęĖĖ" + ], + [ + "ĖĖĖĖ", + "ĖĖĖĖĖĖ" + ], + [ + "ď", + "ėĖ" + ], + [ + "Ě", + "ĖĖĖĖĖ" + ], + [ + "ĔĕĖ", + "ėĖĖ" + ], + [ + "ĔĕĖ", + "ĔĖĖ" + ], + [ + "ĘĖĖ", + "ĕĕĖ" + ], + [ + "Ę", + "Ĕĕ" + ], + [ + "Ę", + "đĖ" + ], + [ + "ėĖ", + "ĕ" + ], + [ + "ď", + "ĕĖ" + ], + [ + "ĕ", + "ĝĖ" + ], + [ + "ē", + "ĕĖĖĖĖ" + ], + [ + "Ě", + "ęĖ" + ], + [ + "ė", + "ēĖĖ" + ], + [ + "ĘĖĖ", + "ėėĖ" + ], + [ + "Ē", + "ć" + ], + [ + "ę", + "ą" + ], + [ + "ė", + "ĝĖ" + ], + [ + "Ę", + "ĕĖĖėĖ" + ], + [ + "ĘĖĖ", + "ėĕĖ" + ], + [ + "ē", + "ěĖ" + ], + [ + "ę", + "Đ" + ], + [ + "ĕĖĖ", + "ĘĕĖ" + ], + [ + "ĕ", + "ėĖĖĖ" + ], + [ + "ĘĖĖ", + "ĕėĖ" + ], + [ + "Ė", + "Đ" + ], + [ + "Ĕ", + "ėĖĖ" + ], + [ + "ėĕĖĖ", + "Ė" + ], + [ + "Ė", + "ĝĖ" + ], + [ + "ėėĖ", + "ėėĖ" + ], + [ + "ę", + "ĖĖĖĕĖ" + ], + [ + "ę", + "ėĖĖĖĖ" + ], + [ + "ĔėĖ", + "ĔĖĖ" + ], + [ + "ĖĖĖĖ", + "ĔĖĖ" + ], + [ + "Ĕ", + "ĕĔ" + ], + [ + "Ě", + "ć" + ], + [ + "Ę", + "ėĖĖ" + ], + [ + "ĖĖĖĖĖ", + "ĕĖ" + ], + [ + "Ĕ", + "ĕĖĖĕĖ" + ], + [ + "ĕĕĖ", + "ĕėĖ" + ], + [ + "ęĖĖ", + "ĕĖ" + ], + [ + "Ě", + "ĕė" + ], + [ + "Ė", + "ąĖ" + ], + [ + "Ę", + "Ęĕ" + ], + [ + "Ě", + "ėė" + ], + [ + "Ğ", + "ĖĖ" + ], + [ + "ē", + "Đ" + ], + [ + "ē", + "ĒĖ" + ], + [ + "ĘėĖ", + "ĔĖĖ" + ], + [ + "ę", + "ěĖ" + ], + [ + "ĝ", + "ėĖ" + ], + [ + "ė", + "ĘĕĖ" + ], + [ + "ĘĖĖ", + "ēĖĖ" + ], + [ + "Ė", + "ĘĖĖĖĖ" + ], + [ + "ĕĕĖ", + "ėėĖ" + ], + [ + "Ė", + "ď" + ], + [ + "Ē", + "ĖĖĖĖĖ" + ], + [ + "ĕėĖ", + "ėĕĖ" + ], + [ + "ē", + "ĖĖĖĕĖ" + ], + [ + "ĕĖĖ", + "ĔėĖ" + ], + [ + "Ė", + "Ď" + ], + [ + "ėė", + "ĕ" + ], + [ + "Ĕ", + "ėĖĖĕĖ" + ], + [ + "Ę", + "ĝ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĖĖĖ" + ], + [ + "ę", + "ĕ" + ], + [ + "ėĘ", + "ĖĖĖĖ" + ], + [ + "ě", + 
"ĕĕ" + ], + [ + "!", + "Ė" + ], + [ + "Ē", + "ĕĕ" + ], + [ + "ĖĖĖĖ", + "ėĖ" + ], + [ + "ęĖ", + "ė" + ], + [ + "Ē", + "ēĖ" + ], + [ + "Ę", + "ėĖĖĕĖ" + ], + [ + "ęĖ", + "ĕ" + ], + [ + "ĖĖĖĖĖ", + "ėĖ" + ], + [ + "ĕĕĖ", + "ėĕĖ" + ], + [ + "Ę", + "ėĘ" + ], + [ + "ĕĘ", + "ĖĖĖĖ" + ], + [ + "ėĕĖĖ", + "ĖĖĖĖĖ" + ], + [ + "ĘėĖ", + "ĘĖĖ" + ], + [ + "ę", + "ĖĖĖėĖ" + ], + [ + "Ĕ", + "ĕĖĖėĖ" + ], + [ + "Ē", + "đ" + ], + [ + "ē", + "ĕ" + ], + [ + "ĕĖ", + "ėĖ" + ], + [ + "ē", + "ě" + ], + [ + "ĕĕ", + "ĕ" + ], + [ + "Ĕ", + "ėĖĖėĖ" + ], + [ + "Ě", + "đ" + ], + [ + "ėĖĖĖĖĖ", + "ėĖ" + ], + [ + "ĕĖ", + "ėĖĖ" + ], + [ + "ĕĖĖ", + "ĔĕĖ" + ], + [ + "ėėĖ", + "ėĕĖ" + ], + [ + "ĔĖĖ", + "ēĖĖ" + ], + [ + "ę", + "Ĕĕ" + ], + [ + "ĔĖĖ", + "ĕėĖ" + ], + [ + "Ę", + "ĕĘ" + ], + [ + "ę", + "Ęĕ" + ], + [ + "ėėĖ", + "ĕĕĖ" + ], + [ + "ĕĖ", + "ĕ" + ], + [ + "Ĕ", + "ėĘ" + ], + [ + "ē", + "ĖĖĖėĖ" + ], + [ + "Ė", + "ě" + ], + [ + "Ĕ", + "ėĔ" + ], + [ + "ĕ", + "ďĖ" + ], + [ + "ĔĖĖ", + "ėĕĖ" + ], + [ + "ę", + "ď" + ], + [ + "Ĕ", + "Ĕė" + ], + [ + "ĕė", + "ĕ" + ], + [ + "ėĔ", + "ĖĖĖĖ" + ], + [ + "ėėĖ", + "ĕėĖ" + ], + [ + "đ", + "ėė" + ], + [ + "Ď", + "ĖĖ" + ], + [ + "Ě", + "ēĖ" + ], + [ + "ĔĖĖ", + "ĕĕĖ" + ], + [ + "Ę", + "ĕĔ" + ], + [ + "Ę", + "ėĖĖėĖ" + ], + [ + "ě", + "ėĕ" + ], + [ + "ę", + "Ĉ" + ], + [ + "ē", + "ď" + ], + [ + "ĔĖĖ", + "ęĖĖ" + ], + [ + "ē", + "Ĉ" + ], + [ + "ĔėĖ", + "ĘĖĖ" + ], + [ + "ĕ", + "ąĖ" + ], + [ + "ĕĖĖĖĖĖ", + "ėĖ" + ], + [ + "ĔĕĖ", + "ĘĖĖ" + ], + [ + "Ĕ", + "ĕĘ" + ], + [ + "ę", + "ĒĖ" + ], + [ + "ę", + "ĖĖĖ" + ], + [ + "ė", + "ğ" + ], + [ + "Ē", + "Ě" + ], + [ + "đ", + "ĕė" + ], + [ + "Ĕ", + "ĐĖ" + ], + [ + "ĕĔ", + "ĖĖĖĖ" + ], + [ + "Ĕ", + "ĝ" + ], + [ + "ē", + "ĖĖĖ" + ], + [ + "ėĖĖĖĖĖ", + "ĕĖ" + ], + [ + "ĕĖĖ", + "ĘėĖ" + ], + [ + "Ė", + "ĔĖĖĖĖ" + ], + [ + "Ė", + "ďĖ" + ], + [ + "ĕėĖ", + "ĕĕĖ" + ], + [ + "ė", + "ĘėĖ" + ], + [ + "ĕĕ", + "ĕĕ" + ], + [ + "ĖĖĖĖĖĖ", + "ėĖ" + ], + [ + "ę", + "Ĝ" + ], + [ + "ėĖĖėĖĖ", + "ėĖĖ" + ], + [ + "ĕ", + "ĚĖĖ" + ], + [ + "Ě", + "ě" + ], + [ + "ĕĖĖĖĖĖ", + 
"ĕĖ" + ], + [ + "ė", + "ąĖ" + ], + [ + "Ĝ", + "ĔĖ" + ], + [ + "Ę", + "ėĔ" + ], + [ + "ĕėĖ", + "ėėĖ" + ], + [ + "ė", + "ďĖ" + ], + [ + "ĕĕĖĖ", + "ĖĖĖĖĖ" + ], + [ + "ĕ", + "ēĕĖ" + ], + [ + "đ", + "ē" + ], + [ + "Ē", + "ėĕ" + ], + [ + "ĔĖĖ", + "ėėĖ" + ], + [ + "Đ", + "ĘĖ" + ], + [ + "ĘĖĖ", + "ĘĕĖ" + ], + [ + "ęĖĖ", + "ėĖ" + ], + [ + "ėė", + "ėė" + ], + [ + "Ē", + "ĚĖ" + ], + [ + "Ę", + "ćĖ" + ], + [ + "ĕĖĖĕĖĖ", + "ĕĖĖ" + ], + [ + "ė", + "ėĖĖĖ" + ], + [ + "ě", + "ē" + ], + [ + "Ě", + "Đ" + ], + [ + "ĖĖĖĖ", + "ĖĖĖĖĖ" + ], + [ + "ē", + "Ĕė" + ], + [ + "Ĕ", + "Ęė" + ], + [ + "ę", + "Ď" + ], + [ + "đ", + "Ē" + ], + [ + "ęĖĖ", + "ęĖĖ" + ], + [ + "ĕ", + "ğ" + ], + [ + "ĘĕĖ", + "ėĕĖ" + ], + [ + "ē", + "Ď" + ], + [ + "Ē", + "Đ" + ], + [ + "Ĕ", + "ćĖ" + ], + [ + "ĕėĖĖ", + "ĖĖĖĖĖ" + ], + [ + "ė", + "ěĖĖ" + ], + [ + "ĖĖĖĖ", + "ĘĖĖ" + ], + [ + "ĕ", + "ĕĖĖĕĖĖ" + ], + [ + "Ę", + "Ğ" + ], + [ + "ě", + "Ē" + ], + [ + "ėė", + "ĕė" + ], + [ + "đ", + "Ć" + ], + [ + "ě", + "ę" + ], + [ + "ĖĖĖĖĖĖ", + "ĕĖ" + ], + [ + "ē", + "ĆĖ" + ], + [ + "ĕĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "Ė", + "ĚĖĖ" + ], + [ + "ēĖ", + "ĕ" + ], + [ + "ę", + "ĆĖ" + ], + [ + "ėĖĖ", + "ėĖĖĖĖĖ" + ], + [ + "ėĕĖ", + "ēĖĖ" + ], + [ + "Ę", + "ĝĖ" + ], + [ + "Ĕ", + "ĝĖ" + ], + [ + "ėĕĖ", + "ęĖĖ" + ], + [ + "ĕĕĖ", + "ēĖĖ" + ], + [ + "+", + "Ė" + ], + [ + "ėĕĖ", + "ĘĕĖ" + ], + [ + "Đ", + "ĔĖ" + ], + [ + "ĒĖ", + "ė" + ], + [ + "\"", + "Ė" + ], + [ + "Ē", + "ą" + ], + [ + "Ě", + "Ęĕ" + ], + [ + "ėĖĖ", + "ė" + ], + [ + "ĕė", + "ĕĕ" + ], + [ + "ě", + "Ć" + ], + [ + "ėĖĖėĖĖ", + "ĕĖĖ" + ], + [ + "ĘĕĖ", + "ĕėĖ" + ], + [ + "Ĝ", + "ĘĖ" + ], + [ + "ęĕĖ", + "ĕĖĖ" + ], + [ + "ě", + "Ě" + ], + [ + "Ę", + "ĐĖ" + ], + [ + "ēĖĖ", + "ēĖĖ" + ], + [ + "Ě", + "ĚĖ" + ], + [ + "ĕĕ", + "ėė" + ], + [ + "ęĕĖ", + "ėĖĖ" + ], + [ + "ē", + "ĜĖ" + ], + [ + "ėėĖ", + "ė" + ], + [ + "ėĕĖ", + "ĔėĖ" + ], + [ + "ĕĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "ē", + "Ęė" + ], + [ + "ĕ", + "ĘĕĖ" + ], + [ + "Ē", + "Ĉ" + ], + [ + "Ĕ", + "Ĕĕ" + ], + [ + "ě", + "đ" + ], + [ + "Ě", + "ď" + ], + [ + 
"ĔĖĖ", + "ĔėĖ" + ], + [ + "đ", + "ć" + ], + [ + "ĕ", + "ēėĖ" + ], + [ + "ē", + "đĖ" + ], + [ + "ĕĕĖ", + "ęĖĖ" + ], + [ + "ĘĖĖ", + "ĘėĖ" + ], + [ + "ėĖĖĕĖĖ", + "ėĖĖ" + ], + [ + "Ě", + "Ĕĕ" + ], + [ + "ĕĖĖ", + "ĕĖĖĖĖĖ" + ], + [ + "ĔĖĖ", + "ĔĕĖ" + ], + [ + "ĕĕĖ", + "ĔĕĖ" + ], + [ + "ě", + "ć" + ], + [ + "ĕėĖ", + "ęĖĖ" + ], + [ + "ĕėĖ", + "ēĖĖ" + ], + [ + "Ě", + "ą" + ], + [ + "ėĕ", + "ėė" + ], + [ + "Ę", + "Ĕė" + ], + [ + "ĕ", + "ĎĖ" + ], + [ + "Ě", + "Ĉ" + ], + [ + "ĔĖĖ", + "ĘĕĖ" + ], + [ + "đ", + "ę" + ], + [ + "ėĖ", + "ĕĖĖ" + ], + [ + "ĕ", + "ĘėĖ" + ], + [ + "Ē", + "ď" + ], + [ + "Ė", + "ėĖĖ" + ], + [ + "ĘĖĖ", + "ĔĕĖ" + ], + [ + "ėėĖ", + "ęĖĖ" + ], + [ + "ėĖĖĖĖĖĖ", + "ĖĖĖĖĖ" + ], + [ + "ĕė", + "ĕė" + ], + [ + "ĕėĖ", + "ĘĕĖ" + ], + [ + "Ė", + "ĕĖĖ" + ], + [ + "đ", + "ęĖ" + ], + [ + "ē", + "Ĝ" + ], + [ + "ėĖĖ", + "ĕĖĖĖĖĖ" + ], + [ + "ėėĖ", + "ēĖĖ" + ], + [ + "ēĖĖ", + "ĕĖ" + ], + [ + "ĘĖĖ", + "ĔėĖ" + ], + [ + "ĕėĖ", + "ĔėĖ" + ], + [ + "ėėĖĖ", + "ĖĖĖĖĖ" + ], + [ + "ě", + "ĕė" + ], + [ + "ė", + "ĔĕĖ" + ], + [ + "ě", + "ėė" + ], + [ + "Ē", + "Ĕė" + ], + [ + "đ", + "đ" + ], + [ + "ĕĖĖĖĖĖĖ", + "ĖĖĖĖĖ" + ], + [ + "ę", + "ĜĖ" + ], + [ + "Ėę", + "ĖĖĖĖ" + ], + [ + "Ė", + "ĕĕ" + ], + [ + "Ĕ", + "Ğ" + ], + [ + "Ĕ", + "ďĖ" + ], + [ + "Ę", + "Ęė" + ], + [ + "ėĖĖ", + "ĚĖĖ" + ], + [ + "Ě", + "Ď" + ], + [ + "ėėĖ", + "ĔĕĖ" + ], + [ + "ėė", + "ĕĕ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĖ" + ], + [ + "ĕĕĖ", + "ĘĕĖ" + ], + [ + "ė", + "ęĕĖ" + ], + [ + "ĒĖĖ", + "ĕĖĖ" + ], + [ + "ėĕĖ", + "ĔĕĖ" + ], + [ + "ĕ", + "ĞĖ" + ], + [ + "ėĖ", + "ėĖĖ" + ], + [ + "ĕĖĖ", + "ėĖĖėĖĖ" + ], + [ + "ėėĖ", + "ĘĕĖ" + ], + [ + "ėĕĖ", + "ė" + ], + [ + "ęĖĖ", + "ēĖĖ" + ], + [ + "ĕĕĖ", + "ĔėĖ" + ], + [ + "Ĝ", + "ĕĕ" + ], + [ + "Ė", + "ĎĖ" + ], + [ + "ėĖĖ", + "ĔĖ" + ], + [ + "Ē", + "Ď" + ], + [ + "ĖėĖ", + "ėĖĖ" + ], + [ + "đ", + "ĕĕ" + ], + [ + "ĚĖĖ", + "ĕĖĖ" + ], + [ + "Ē", + "Ęė" + ], + [ + "ĕĖĖ", + "ĔĖ" + ], + [ + "Ě", + "Ĕ" + ], + [ + "ė", + "ĞĖ" + ], + [ + "ĕĕ", + "ĕė" + ], + [ + "ĕĕĖ", + "ĘėĖ" + ], + [ + "Ė", + "ĞĖ" + ], 
+ [ + "Ē", + "ě" + ], + [ + "ė", + "ĒĖĖ" + ], + [ + "Ğ", + "ĕĖ" + ], + [ + "ė", + "ĎĖ" + ], + [ + "Ę", + "ğ" + ], + [ + "Ĕ", + "Ęĕ" + ], + [ + "ĚĖĖ", + "ėĖĖ" + ], + [ + "ēĖĖ", + "ęĖĖ" + ], + [ + "ėĖĖ", + "ĘĖ" + ], + [ + "Ĝ", + "ē" + ], + [ + "ĖėĖ", + "ĕĖĖ" + ], + [ + "ėėĖ", + "ĘėĖ" + ], + [ + "ėĕĖ", + "ĘėĖ" + ], + [ + "ēėĖ", + "ĕĖĖ" + ], + [ + "ĔĖĖ", + "ĘėĖ" + ], + [ + "ėę", + "ĖĖĖĖ" + ], + [ + "ę", + "ĕĖĖĕĖ" + ], + [ + "ĕ", + "đĖĖ" + ], + [ + "ě", + "Đ" + ], + [ + "ěĖĖ", + "ĕĖ" + ], + [ + "ęėĖ", + "ėĖĖ" + ], + [ + "ė", + "ĔėĖ" + ], + [ + "ę", + "ėĘ" + ], + [ + "Ě", + "Ĝ" + ], + [ + "ėĖĖĖ", + "ĕĖĖĖĖ" + ], + [ + "ėėĖ", + "ĔėĖ" + ], + [ + "Ē", + "ĒĖ" + ], + [ + "*", + "Ė" + ], + [ + "ėĖĖ", + "ĕĖĖĕĖĖ" + ], + [ + "ē", + "ĕĔ" + ], + [ + "ēĕĖ", + "ĕĖĖ" + ], + [ + "ĕėĖ", + "ĔĕĖ" + ], + [ + "Đ", + "ėė" + ], + [ + "Ę", + "ĘĖĖ" + ], + [ + "Ĝ", + "ėĕ" + ], + [ + "ĘĕĖ", + "ęĖĖ" + ], + [ + "ĘĕĖ", + "ĕĕĖ" + ], + [ + "Ĕ", + "ĔĖĖ" + ], + [ + "ę", + "ĝ" + ], + [ + "Đ", + "ē" + ], + [ + "Đ", + "ĕė" + ], + [ + "ę", + "đĖ" + ], + [ + "ĕę", + "ĖĖĖĖ" + ], + [ + "đ", + "Ě" + ], + [ + "Ĝ", + "Ē" + ], + [ + "Ď", + "ėĖ" + ], + [ + "Ď", + "ĕĖ" + ], + [ + "ėĖĖ", + "ęĕĖ" + ], + [ + "ĘĕĖ", + "ēĖĖ" + ], + [ + "Ą", + "Ė" + ], + [ + "Ė", + "ėĖĖĕĖ" + ], + [ + "ĕėĖ", + "ĘėĖ" + ], + [ + "ĕĕ", + "ĔĖ" + ], + [ + "ğ", + "ĖĖ" + ], + [ + "Ğ", + "ėĖ" + ], + [ + "ĕĖĖĖ", + "ĕĖĖĖĖ" + ], + [ + "ĘĘ", + "ĖĖĖĖ" + ], + [ + "ĖĕĖ", + "ĕĖĖ" + ], + [ + "Ē", + "ěĖ" + ], + [ + "ě", + "ēĖ" + ], + [ + "ĚĖ", + "ĕ" + ], + [ + "Ė", + "Ĕ" + ], + [ + "ĔėĖ", + "ėĕĖ" + ], + [ + "ęėĖ", + "ĕĖĖ" + ], + [ + "Ė", + "Ĝ" + ], + [ + "ĕĕĖ", + "ĕĖ" + ], + [ + "ě", + "ě" + ], + [ + "đ", + "ēĖ" + ], + [ + "ĔĘ", + "ĖĖĖĖ" + ], + [ + "ę", + "ĕĘ" + ], + [ + "ĕĖĖ", + "ĘĖ" + ], + [ + "đ", + "ėĕ" + ], + [ + "ě", + "ęĖ" + ], + [ + "ĔėĖ", + "ĕėĖ" + ], + [ + "Ě", + "Ę" + ], + [ + "ĕėĖ", + "ĕĖ" + ], + [ + "đ", + "Đ" + ], + [ + "ę", + "Ĕė" + ], + [ + "Ė", + ",Ė" + ], + [ + "ėĖĖ", + "ĒĖĖ" + ], + [ + "ĕĕ", + "ėĕ" + ], + [ + "ĘėĖ", + "ĕĕĖ" + ], + [ + "ĝ", 
+ "ĔĖ" + ], + [ + "ĕėĖ", + "ėĖ" + ], + [ + "ėĖĖĖ", + "ėĖĖĖĖ" + ], + [ + "ėėĖ", + "ėĖ" + ], + [ + "ě", + "Ęĕ" + ], + [ + "Ę", + "ďĖ" + ], + [ + "ėĘ", + "ĕė" + ], + [ + "ęĕĖ", + "ĘĖĖ" + ], + [ + "ęĕĖ", + "ĔĖĖ" + ], + [ + "ĚĖĖ", + "ĘĖĖ" + ], + [ + "ĖĕĖ", + "ėĖĖ" + ], + [ + "ĕĖĖ", + "ėĖĖĕĖĖ" + ], + [ + "ē", + "ėĔ" + ], + [ + "Ė", + "Ę" + ], + [ + "ė", + "ęėĖ" + ], + [ + "ĕĔ", + "ĖĖĖĖĖ" + ], + [ + "Đ", + "Ē" + ], + [ + "ėĕĖ", + "ĕĖ" + ], + [ + "ĘĔ", + "ĖĖĖĖ" + ], + [ + "ę", + "ĕĔ" + ], + [ + "ĘĕĖ", + "ĘĕĖ" + ], + [ + "ĕĕĖ", + "ėĖ" + ], + [ + "Ě", + "ĒĖ" + ], + [ + "ēėĖ", + "ėĖĖ" + ], + [ + "ē", + "Ĕĕ" + ], + [ + "#", + "Ė" + ], + [ + "ĒĖĖ", + "ĔĖĖ" + ], + [ + "Ę", + "ĔĖĖ" + ], + [ + "ĔĕĖ", + "ĕĕĖ" + ], + [ + "Ě", + "ĕĖĖĖĖ" + ], + [ + "Ę", + "ąĖ" + ], + [ + "ĔĕĖ", + "ėėĖ" + ], + [ + "ēĕĖ", + "ėĖĖ" + ], + [ + "ĒĖĖ", + "ėĖĖ" + ], + [ + "ę", + "Ęė" + ], + [ + "Ĕ", + "ĘĖĖ" + ], + [ + "Ě", + "ė" + ], + [ + "ę", + "ĕĖĖėĖ" + ], + [ + "ėĕĖ", + "ėĖ" + ], + [ + "ĔėĖ", + "ēĖĖ" + ], + [ + ")", + "Ė" + ], + [ + "ē", + "ėĘ" + ], + [ + "ĘĕĖ", + "ėėĖ" + ], + [ + "ėĘĖĖ", + "ĕĖ" + ], + [ + "Ē", + "Ĕ" + ], + [ + "đ", + "Ĉ" + ], + [ + "ęĖĖ", + "ėĕĖ" + ], + [ + "Ĕ", + "ąĖ" + ], + [ + "ĘėĖ", + "ėėĖ" + ], + [ + "ě", + "Ĉ" + ], + [ + "ď", + "ĘĖ" + ], + [ + "ęĖĖ", + "ĕĕĖ" + ], + [ + "Ėē", + "ĖĖĖĖ" + ], + [ + "ĔĕĖ", + "ēĖĖ" + ], + [ + "ĚĖ", + "ė" + ], + [ + "ĚĖĖ", + "ĔĖĖ" + ], + [ + "ĕĖĖĖ", + "ėĖĖĖĖ" + ], + [ + "Ě", + "ěĖ" + ], + [ + "ĕĔ", + "ĕĕ" + ], + [ + "ĕĘĖĖ", + "ĕĖ" + ], + [ + "ē", + "ĐĖ" + ], + [ + "ėĘ", + "ėė" + ], + [ + "Ę", + "ĚĖĖ" + ], + [ + "ĕĖĖĖĖĖ", + "ĘĖĖ" + ], + [ + "ėėĖ", + "ĕĖ" + ], + [ + "ě", + "ď" + ], + [ + "ē", + "ĕĘ" + ], + [ + "ēėĖ", + "ĔĖĖ" + ], + [ + "ę", + "ėĔ" + ], + [ + "ĘĕĖ", + "ĔėĖ" + ], + [ + "đ", + "ď" + ], + [ + "ėĖĖĖĖĖ", + "ĔĖĖ" + ], + [ + "ėėĖ", + "ĕ" + ], + [ + "đ", + "ą" + ], + [ + "ėĕĖ", + "ĕ" + ], + [ + "ėĖĖ", + "ĕ" + ], + [ + "ĕĖĖĖĖĖ", + "ĔĖĖ" + ], + [ + "Ė", + "ĘĖĖ" + ], + [ + "Ĝ", + "ę" + ], + [ + "ĕ", + ",Ė" + ], + [ + "ē", + "ėĖĖĕĖ" + ], + [ + "ĘėĖ", + 
"ęĖĖ" + ], + [ + "ėė", + "ėĕ" + ], + [ + "ėĖĖĖĖĖ", + "ĘĖĖ" + ], + [ + "ĔĕĖ", + "ėĕĖ" + ] + ] + } +} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_ped_4000000/tokenizer_config.json b/scenestreamer/tokenization/0305_fast_ped_4000000/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..44b81ebde2224b4e3935b02872938beae622c37c --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_ped_4000000/tokenizer_config.json @@ -0,0 +1,8 @@ +{ + "added_tokens_decoder": {}, + "clean_up_tokenization_spaces": false, + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "processor_class": "UniversalActionProcessor", + "tokenizer_class": "PreTrainedTokenizer" +} diff --git a/scenestreamer/tokenization/0305_fast_veh_5000000/delta_normalization_quantiles.json b/scenestreamer/tokenization/0305_fast_veh_5000000/delta_normalization_quantiles.json new file mode 100644 index 0000000000000000000000000000000000000000..06522a1cb8585b3ca6017a7c5db6359fd8cf1396 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_veh_5000000/delta_normalization_quantiles.json @@ -0,0 +1 @@ +{"q_lower": [-0.05596816554665565, -0.02602481350302696, -0.04004049301147461, -0.07802368082106113, -0.02550339575856924, -0.0398712158203125, -0.10218419015407562, -0.02493762530386448, -0.039841651916503906, -0.1266891425848007, -0.025245580431073906, -0.03985453128814698, -0.15155693650245666, -0.02577203346416354, -0.039835453033447266], "q_upper": [0.05516141653060913, 2.2003298497199997, 0.038505794405937155, 0.07299125269055365, 2.2003488397598217, 0.038251579403877245, 0.09229004696011536, 2.199215903282165, 0.03821897506713867, 0.11211258292198178, 2.1999488830566385, 0.038151146173477146, 0.13218084573745714, 2.20082770347595, 0.0382561767101286]} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_veh_5000000/error_mean.json 
b/scenestreamer/tokenization/0305_fast_veh_5000000/error_mean.json new file mode 100644 index 0000000000000000000000000000000000000000..7ab663440e86a3f4d843e5f19b2f58e6cbdf7c7c --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_veh_5000000/error_mean.json @@ -0,0 +1 @@ +[[0.0004786132866771012, 0.004060047930190917, 0.005571852738386794], [0.0008418189217362506, 0.014933188340568017, 0.006157294823575788], [0.0008461391505910589, 0.03096727067823814, 0.007005410314550139], [0.0009346813160780386, 0.05354768548491225, 0.01063678851399226], [0.0013970259123356357, 0.08213329261006687, 0.011346835764575368]] \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_veh_5000000/processor_config.json b/scenestreamer/tokenization/0305_fast_veh_5000000/processor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..8d9a3f725377bee10763626aa8cf6ddb5ad8ac87 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_veh_5000000/processor_config.json @@ -0,0 +1,8 @@ +{ + "action_dim": null, + "min_token": -22, + "processor_class": "UniversalActionProcessor", + "scale": 10, + "time_horizon": null, + "vocab_size": 1024 +} diff --git a/scenestreamer/tokenization/0305_fast_veh_5000000/special_tokens_map.json b/scenestreamer/tokenization/0305_fast_veh_5000000/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93 --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_veh_5000000/special_tokens_map.json @@ -0,0 +1 @@ +{} diff --git a/scenestreamer/tokenization/0305_fast_veh_5000000/tokenizer.json b/scenestreamer/tokenization/0305_fast_veh_5000000/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..8799da32c09eea864953d60d6d5cdcbbc80e359f --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_veh_5000000/tokenizer.json @@ -0,0 +1,4847 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": 
[], + "normalizer": null, + "pre_tokenizer": { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": true + }, + "post_processor": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": false, + "use_regex": true + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": false, + "vocab": { + "\u0000": 0, + "\u0001": 1, + "\u0002": 2, + "\u0003": 3, + "\u0004": 4, + "\u0005": 5, + "\u0006": 6, + "\u0007": 7, + "\b": 8, + "\t": 9, + "\n": 10, + "\u000b": 11, + "\f": 12, + "\r": 13, + "\u000e": 14, + "\u000f": 15, + "\u0010": 16, + "\u0011": 17, + "\u0012": 18, + "\u0013": 19, + "\u0014": 20, + "\u0015": 21, + "\u0016": 22, + "\u0017": 23, + "\u0018": 24, + "\u0019": 25, + "\u001a": 26, + "\u001b": 27, + "\u001c": 28, + "\u001d": 29, + "\u001e": 30, + "\u001f": 31, + " ": 32, + "!": 33, + "\"": 34, + "#": 35, + "$": 36, + "%": 37, + "&": 38, + "'": 39, + "(": 40, + ")": 41, + "*": 42, + "+": 43, + ",": 44, + "Ā": 45, + "ā": 46, + "Ă": 47, + "ă": 48, + "Ą": 49, + "ą": 50, + "Ć": 51, + "ć": 52, + "Ĉ": 53, + "ĉ": 54, + "Ċ": 55, + "ċ": 56, + "Č": 57, + "č": 58, + "Ď": 59, + "ď": 60, + "Đ": 61, + "đ": 62, + "Ē": 63, + "ē": 64, + "Ĕ": 65, + "ĕ": 66, + "Ė": 67, + "ė": 68, + "Ę": 69, + "ę": 70, + "Ě": 71, + "ě": 72, + "Ĝ": 73, + "ĝ": 74, + "Ğ": 75, + "ğ": 76, + "Ġ": 77, + "ĖĖ": 78, + "ĖĖĖĖ": 79, + "ĖĖĖĖĖĖĖĖ": 80, + "Ėĕ": 81, + "Ėė": 82, + "ĖĖĖĖĖ": 83, + "ėĀ": 84, + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ": 85, + "ĖĖĕ": 86, + "ėĀĖĖĖĖĖĖĖĖĖĖĖĖĖ": 87, + "ĖĖė": 88, + "ėĕ": 89, + "ĖĖĖ": 90, + "ĖĔ": 91, + "ĖĘ": 92, + "ĖĖĖĕ": 93, + "ĖĖĖė": 94, + "ėĖĕ": 95, + "ĕĕ": 96, + "Ėėĕ": 97, + "ĖĖĔ": 98, + "Ėĕĕ": 99, + "ėė": 100, + "ėĖė": 101, + "ėĖĖĖĖĖĖĖĖ": 102, + "ĖĖĖĖĕ": 103, + "ĖĖĘ": 104, + "ĖĖĖĖĖĕ": 
105, + "ĖĖĖĖĖĖĖĖĖĖĖ": 106, + "ĖĖĖĖė": 107, + "Ėē": 108, + "ĖĖĖĖĖė": 109, + "ĖĖĕĖĖĕ": 110, + "Ėę": 111, + "ėĔ": 112, + "ėĖĖė": 113, + "ėā": 114, + "ėĖĖ": 115, + "ĕĖė": 116, + "ĖĕĖĖĕ": 117, + "ėĖĖĖĖĖĖĖĖĖĖĖĖ": 118, + "Ęĕ": 119, + "ĖĖĖĖĖĖ": 120, + "ĖĖĕĖĖ": 121, + "ĖĖėĖĖ": 122, + "ĕĖĕ": 123, + "ėĖĖĖĖ": 124, + "ĖĒ": 125, + "ĘĖĕ": 126, + "ėĖĖĕ": 127, + "ĖĖĖĖĖĖĖĖĖĖĖĖ": 128, + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ": 129, + "ėĂ": 130, + "ĖĚ": 131, + "ĖĖĖĖĖĖĖĖĖ": 132, + "ėĖĔ": 133, + "ĖĖĕĖė": 134, + "ĔĖė": 135, + "ĖĖĖĖĖĖĖė": 136, + "ĖĖėĖĕ": 137, + "Ĕĕ": 138, + "ĖĖĖĖĖĖĖĕ": 139, + "ėĘ": 140, + "ĘĀ": 141, + "ĕĖĖ": 142, + "ĖĀ": 143, + "ėē": 144, + "ėă": 145, + "ĖĖē": 146, + "ĘĖė": 147, + "ĖĖėĖė": 148, + "ėĖĘ": 149, + "ėĕĖĖĖĖĖĖĖĖĖĖĖ": 150, + "ėĖ": 151, + "ėĄ": 152, + "ėĕĖĖĕ": 153, + "ėĖĖĖė": 154, + "ĖĕĖĖė": 155, + "ĖĖę": 156, + "ĖĕĖė": 157, + "ĖĖĖĖĖĖĖĖĖĖ": 158, + "ĕĖĘ": 159, + "ėĕĖė": 160, + "ĕĔ": 161, + "ėą": 162, + "Ėđ": 163, + "ėĕĕ": 164, + "ĖĖĕĖĕ": 165, + "ĔĖĕ": 166, + "ėĖĖĖĕ": 167, + "ėę": 168, + "Ėě": 169, + "ėĒ": 170, + "ĖĖėĖĖė": 171, + "ėĆ": 172, + "ĖėĕĖė": 173, + "ĕĖĔ": 174, + "ėć": 175, + "ĕĖĖĕ": 176, + "ĔĖĘ": 177, + "ĖĕĖĖĖĖĖĖĖĖ": 178, + "ėĈ": 179, + "ĖĖĖĖĖĖĕ": 180, + "Ėā": 181, + "ĖĕĖĖĖ": 182, + "ėđ": 183, + "ĖėĖĖĖ": 184, + "ĖĐ": 185, + "ĖėĕĖĖ": 186, + "ėĐ": 187, + "ĘĔ": 188, + "Ęā": 189, + "ėĎ": 190, + "ėď": 191, + "ĖĜ": 192, + "ėėĖĕ": 193, + "ĘĖĔ": 194, + "ėĕĖĕ": 195, + "ĖėĖĖė": 196, + "ĘĖĖ": 197, + "ĖĖĖĖĖĖė": 198, + "ėĚ": 199, + "ĔĖĖ": 200, + "ĖĕĖĕ": 201, + "ėĖē": 202, + "ėĕĖĖĖĖĖĖĖĖ": 203, + "ĖĕĕĖė": 204, + "Ėď": 205, + "Ęē": 206, + "ėĖĖĖĖĖė": 207, + "ĖĖĕĖĖė": 208, + "ęĕ": 209, + "ĖĖĕĖĖĖĖĖ": 210, + "ĖĕĖĖĖĖĖĕ": 211, + "ĖĕĕĖĖ": 212, + "ėėĕ": 213, + "ėĖĖĖĖĖ": 214, + "ĖĖĕĕ": 215, + "ĖĖėĖĖĖĖĖ": 216, + "ĖĕĖĖĕĖĖĕ": 217, + "ĖĖĖėĖĖ": 218, + "ēĕ": 219, + "ĖĖĖĖĖĕĖĖ": 220, + "ĖĖĖĕĖĖ": 221, + "ėĖĖĘ": 222, + "ĖėĕĖĕ": 223, + "ĖĕĕĖĕ": 224, + "ĖĎ": 225, + "Ėĝ": 226, + "ĘĘ": 227, + "ĖĂ": 228, + "ĖĖėĕ": 229, + "ėě": 230, + "ęĖĕ": 231, + "ėĖĖĖĖĕ": 232, + "ĔĖĔ": 233, + "ĀĖĖĖĖĖĖĖĖĖĖĖĖ": 234, + "ĖĖĖĖĕĕ": 235, + 
"ĖĖĕĖĖĖ": 236, + "Ėă": 237, + "ēĖė": 238, + "ėĖę": 239, + "ĖĖĖĖĖėĖĖ": 240, + "ĘĖĖė": 241, + "ĘĂ": 242, + "ėĀėĕ": 243, + "ĘĒ": 244, + "ėĖĖĖ": 245, + "ĖĖĒ": 246, + "ĖĖĔĖė": 247, + ",ĖĖĖĖĖĖĖĖĖĖĖĖ": 248, + "ĖĄ": 249, + "ĖĕėĖĕ": 250, + "ĘĖĘ": 251, + "Ęę": 252, + "ėĜ": 253, + "ĕē": 254, + "ĔĖĖĕ": 255, + "ėĖĖĖĖĖĕ": 256, + "ĖĖėĖĖĖ": 257, + "ĖĞ": 258, + "ĖĖĕĖĖĕĖĖ": 259, + "ėėĖė": 260, + "ėĀĖĕ": 261, + "Ęă": 262, + "Ęđ": 263, + "Ėą": 264, + "ĕĕĕ": 265, + "ėĕĖĖĕĖĖĕ": 266, + "ėĖĖĖĖė": 267, + "ĖĖĚ": 268, + "ĖĖĖėĖĕ": 269, + "ĖĈ": 270, + "ĘĚ": 271, + "ĖĆ": 272, + "ĔĔ": 273, + "Ėć": 274, + "ĘĄ": 275, + "ĘĐ": 276, + "ėĝ": 277, + "ĖĖėĖĖĕ": 278, + "ĖĖĖĕĖė": 279, + "ėĕĖĖĖĖĖĕ": 280, + "ęĖė": 281, + "Ęď": 282, + "ĘĎ": 283, + "ĕĖę": 284, + "ĖĖĔĖĖ": 285, + "ĖėĖĕ": 286, + "Ęą": 287, + "ĘĖĖĕ": 288, + "ĖĔĖĖĕ": 289, + "ĖĖĖėĖė": 290, + "Ėğ": 291, + "ĘĖĖĘ": 292, + "ĖĔĕ": 293, + "ĖĕĔ": 294, + "ėĞ": 295, + "Ęě": 296, + "ĘĆ": 297, + "ĖĕĖĖĕĖĖĖĖĖĖĖĖ": 298, + "ĖĖĖĖėĖĖĖ": 299, + "ĘĈ": 300, + "ĕĖĖė": 301, + "Ęć": 302, + "ĖĖĖĖĕĖĖĖ": 303, + "ĖėĖĖĕ": 304, + "ēĖĘ": 305, + "ėĖĖėĖĖ": 306, + "Ęĕĕ": 307, + "ėāĖĖĖĖĖĖĖĖĖĖĖĖĖ": 308, + "ĖĖĖĖĔ": 309, + "ėğ": 310, + "ėĀĖĕĖĖĖĖĖĖĖĖĖĖĖ": 311, + "ĔĖę": 312, + "ĖĖĖĖĖĖĖĖĕ": 313, + "ĖĕĖĖĖĖĕ": 314, + "ėĖėĕ": 315, + "ĖĖĖĖėĕ": 316, + "ĖĕĖĖĖĖĖė": 317, + "ĖĖėĖĖėĖĖ": 318, + "ĖĖĖĔ": 319, + "ĘĜ": 320, + "ĕĖėĕ": 321, + "ĕĖē": 322, + "ĕĀ": 323, + "ęĀ": 324, + "ėĖĒ": 325, + "ĕĒ": 326, + "ĕę": 327, + "ĖėĖė": 328, + "ēĖĕ": 329, + "ĖĖĘĖĕ": 330, + "ĘĖē": 331, + "Ęĝ": 332, + "ĕĖĕĕ": 333, + "ĖėĖĖ": 334, + "ėĖĖĔ": 335, + "ėĕĖĖĕĖĖĖĖĖĖĖĖ": 336, + "ĖĖĖĖėė": 337, + "ĖĖĖĕĖĕ": 338, + "ėĖĖĖĖĖĖĖĖĖĖ": 339, + "ĖĖĖĖĖĖĖėĖĖĖ": 340, + "ėĖĖėĖĕ": 341, + "ĕĘ": 342, + "ĕĕĖė": 343, + "ĖĖĖĖėĖĕ": 344, + "ėĀėĖĖĖĖĖĖĖĖĖĖĖĖ": 345, + "ĕĚ": 346, + "ĖĕĖĕĕ": 347, + "ĖĖĖĖĘ": 348, + "Ė!": 349, + "ĖĖĖĖĖĖĖĕĖĖĖ": 350, + "ĖĕĖĖ": 351, + "ėėĖĖė": 352, + "ĖĔĖė": 353, + "ĘĞ": 354, + "ĖĕĖėĕ": 355, + "ėĕĖĕĕ": 356, + "ĘĕĖė": 357, + "ęĖĔ": 358, + "ĔĖėĕ": 359, + "ėĕĖĖĖĖĖė": 360, + "ĖĖĘĖĖ": 361, + "ėĕĖĖĖĖė": 362, + "ėĕĖėĕ": 363, + 
"ĖĖĔĖĕ": 364, + "ė!": 365, + "ĕĖĖĖĖĖĖĖĖ": 366, + "ĘėĖĕ": 367, + "ĖĖĖĖĖĖėĖĖ": 368, + "ėĀĖĖĖĖĖĖĖĖ": 369, + "ĖĖĖĖĖĕĖĖĖ": 370, + "ėĖĖėĖė": 371, + "ĖĖĖĖĖėĖĖĖ": 372, + "ėĖĚ": 373, + "Ęğ": 374, + "ĖĖĖĕĖĖĖĖĖ": 375, + "ė\"": 376, + "ėĘĖĕ": 377, + "ĖĖĖĖĖĖĕĖĖ": 378, + "ėĕĖĖĖĖĕ": 379, + "ĕĖ": 380, + "ĕě": 381, + "ėĖĕĖĖĖ": 382, + "ėĕĖĖė": 383, + "ĖĖĖĖĖĖĖĖė": 384, + "Ĕē": 385, + "ĖĖĘĖė": 386, + "ĘĖĖĖĖĖĖĖĖĖĖĖĖ": 387, + "ĕđ": 388, + "Ė\"": 389, + "ĖĖĖĖėĖĖė": 390, + "ėĖĕĖĖĕ": 391, + "ėĔĖė": 392, + "ĖĖđ": 393, + "ĖĖĕĖĖĕĖĖĖĖĖ": 394, + "ėĔĖĖĕ": 395, + "ĖĖĖĖĖĔ": 396, + "Ė,": 397, + "ė,": 398, + "ĔĖē": 399, + "ĖĖĖĘ": 400, + "ė#": 401, + "ėĖĕĕ": 402, + "ĖĖĔĖĖĕ": 403, + "ĕĕĖĖĕ": 404, + "ĖĖĖėĖĖĖĖĖ": 405, + "Ę!": 406, + "ė$": 407, + "ęĔ": 408, + "ĘĖę": 409, + "ĕĜ": 410, + "ėĖĕėĖĕ": 411, + "ĖĖě": 412, + "ĘĖĖĖĖĖĖĖĖ": 413, + "ėĀėĖĖĖĖĖĖĖĖ": 414, + "ĖĕĖĖĖĖė": 415, + "Ė#": 416, + "ęĚ": 417, + "ĔĖĖĔ": 418, + "ėĖėĖĖĖ": 419, + "ĖĖĖĖĕĖĖĕ": 420, + "ĕĐ": 421, + "ęĒ": 422, + "Ęėė": 423, + "ĔĖĕĕ": 424, + "ėāĖĕĖĖĖĖĖĖĖĖĖĖĖ": 425, + "ĖĖėĖĖĖĖĖĖĖĖĖĖ": 426, + "ėĂĖĖĖĖĖĖĖĖĖĖĖĖĖ": 427, + "ėĀėĕĖĖĖĖĖĖĖĖĖĖĖ": 428, + "ęĖĘ": 429, + "ėĖĕĖĖė": 430, + "ĖĕĖĖĖĖĖĕĖĖĖĖĖ": 431, + "ĖĖĖĕĕ": 432, + "Ę\"": 433, + "ėĖėĖĖė": 434, + "ĕĕĖĖĖĖĖĖĖĖĖĖĖ": 435, + "ėĕėĖĕ": 436, + "ēĖĔ": 437, + "ĖĘĕ": 438, + "ėĘĕ": 439, + "Ę,": 440, + "ĕĖĖĖĖĖĖĖĖĖĖĖĖ": 441, + "ėĖĕĖĖĖĖĖĖĖĖĖĖ": 442, + "Ė$": 443, + "ēĔ": 444, + "ĖĔĖĖĔ": 445, + "Ęėĕ": 446, + "ęĖĖ": 447, + "ĕĝ": 448, + "ĖĖĖĖĖėĖĕ": 449, + "ĖĖĖĕĖĖĕ": 450, + "ęę": 451, + "ēĖĖ": 452, + "ėĕĖĖĕĖė": 453, + "ėĖėĖĖĕ": 454, + "ĕā": 455, + "ĒĖė": 456, + "ĕĖĚ": 457, + "Ę$": 458, + "ĖĕĖĖĕĖĖĕĖĖĖĖĖ": 459, + "ĖĖĖĖĖĖĖĖĖĖėĖĖ": 460, + "ė%": 461, + "ėĖĖĕĖĖĕ": 462, + "ēĖę": 463, + "ĖĖĕĖĖĖĖĖĖĖĖĖĖ": 464, + "ęā": 465, + "Ę#": 466, + "ĕĕĖĕ": 467, + "ĖėĕĖĖĖĖĖ": 468, + "ĖĖĖĖĖĖĖėĖĖĖĖĖ": 469, + "ĕď": 470, + "ęě": 471, + "ĖėĕĖĘ": 472, + "ėĖĖĖĖĖĖĖĖĖėĖĖ": 473, + "ėĕĕĖĕ": 474, + "ĖĖĕĖĖĖĖĖĖĖĖĖ": 475, + "ĕĖĖĖĖĖĖĖĖĖĖĖ": 476, + "ėāėĖĖĖĖĖĖĖĖĖĖĖĖ": 477, + "ĕ,": 478, + "ėĖđ": 479, + "ĖĖĖĖĖĕĖė": 480, + "ĚĖĕ": 481, + "ĖĕĖĖĖĖĖĖĖĖĕĖĖ": 482, 
+ "ĕĖĒ": 483, + "ęđ": 484, + "ĖĖĕĖėĖĖĖ": 485, + "ĕĞ": 486, + "ĖĖėĖĕĖĖĖ": 487, + "Ė%": 488, + "ĕĖĖĖ": 489, + "ė&": 490, + "ĖĖĖĖĖĕĖĖĖĖĖ": 491, + "ĕĎ": 492, + "ĖĕĖĖĔ": 493, + "ĖĖėĖĖĖĖĖĖ": 494, + "ėĖĖĖĖĖĖėĖĖĖĖĖ": 495, + "ĖĕĖĖĖĖĖĖĖĖĖĖĕ": 496, + "ėĖėėĖė": 497, + "ĖĖĖĖėĖĖĕ": 498, + "ęē": 499, + "ėĖėĖĖĖĖĖĖĖĖĖĖ": 500, + "ĖĖĖĖĖĖĖĕĖĖĖĖĖ": 501, + "ĖĕĕĖĖĖĖĖ": 502, + "ĖĖĖĖĖĖĖĖĖĖĕĖĖ": 503, + "ĔĒ": 504, + "ėĀĖĖĖĕ": 505, + "ėĖĖĖĖĖĖė": 506, + "ĕėė": 507, + "ĖĕĖĖĖĖĖĖĖĖėĖĖ": 508, + "ęĜ": 509, + "ĔĖĖė": 510, + "ėĕĖĘ": 511, + "Ĕĕĕ": 512, + "ėĖĖĖĖĖĖĖĖĖĖĖ": 513, + "ĕğ": 514, + "ėĖĖĖĖĖĖĖĖĖĕĖĖ": 515, + "ĖĖĖĖĕĖĖė": 516, + "ĖĔĖĖė": 517, + "ĖĖĖĖĖĘ": 518, + "ėĖĖĖĖĖĖĕ": 519, + "ęĐ": 520, + "ėĀĖĖĖė": 521, + "ėĀėĖĖĕ": 522, + "ĖĖĖĖĖĖĖĖĖĖĖĖĕ": 523, + "ĖĕĕĖĘ": 524, + "ęĂ": 525, + "ėĖėėĖĕ": 526, + "ė'": 527, + "ĔĖĚ": 528, + "ęă": 529, + "Ę%": 530, + "ĖĕĖĖĕĖė": 531, + "ėĕĕĕ": 532, + "ėăĖĖĖĖĖĖĖĖĖĖĖĖĖ": 533, + "ėĖĖĖĖĖĖĕĖĖĖĖĖ": 534, + "ĘĖĖĖė": 535, + "Ę)": 536, + "ė)": 537, + "ĘĖĒ": 538, + "ĖĖĕĖĖĕĖė": 539, + "ėĕĖĖĖĖĖĖĖĖĕĖĖ": 540, + "ėĕĖĖĖĖĖĕĖĖĖĖĖ": 541, + "ĖĖėĖĖĖĖĖĖĖĖĖ": 542, + "ėĖě": 543, + "ėĕĖĖĔ": 544, + "Ė&": 545, + "ĖĖĘĖĖė": 546, + "ęď": 547, + "ė(": 548, + "ĕĖĖĔ": 549, + "ĖĔĖĕ": 550, + "ĕĖėĖĖĕ": 551, + "ĚĖė": 552, + "ęĖē": 553, + "ęĝ": 554, + "ęĄ": 555, + "ėĕĖĖĕĖĖĕĖĖĖĖĖ": 556, + "ĖĔĖėĕ": 557, + "ėĖĖėĖĖĖĖĖ": 558, + "ĖĖĖĖėĖĖĖĖĖĖĖĖ": 559, + "ėĕĔ": 560, + "ĕĖĖĖĖ": 561, + "ęĎ": 562, + "ĖĕĖĖĖĖĖĖĖĖĖĖė": 563, + "ĖĖĖĖĖėĖė": 564, + "Ę&": 565, + "ėĖĖĖĖĖĖĖĖĖĖĖė": 566, + "ėĔĖĕ": 567, + "ėĀėĖĖė": 568, + "ėĖĕėĖĖ": 569, + "ĕă": 570, + "ĕ!": 571, + "ĕĂ": 572, + "Ė'": 573, + "ĖĖĖĖĖĖĕĖĖĖĖĖĖ": 574, + "ĖĖĖĖĖĖĖĖĖĖĖĖė": 575, + "ėĕĖĖĖĖĖĖĖĖėĖĖ": 576, + "ĘĕĖĖĕ": 577, + "ĖĕĕĖĔ": 578, + "ĕĄ": 579, + "ēĖĖĕ": 580, + "ęĞ": 581, + "ėĔĕ": 582, + "ĘĖĖĖĕ": 583, + "ĖĖĖĖĖĖĖĖĖė": 584, + "Ēĕ": 585, + "ĖĘĖĕ": 586, + "ėĕėė": 587, + "ėĖėėĖĖ": 588, + "ėĕĖĖĖĖĖĖĖĖĖĖė": 589, + "ėāĖĕ": 590, + "ėĖėĖė": 591, + "Ěĕ": 592, + "ėāėĕĖĖĖĖĖĖĖĖĖĖĖ": 593, + "ėĖĖĖĖĖĖĖĖĖĖĖĕ": 594, + "ĖĖĖĕĖĖĕĖĖ": 595, + "ĖĖĕĖĖĕĖĖĕĖĖ": 596, + "ĒĖĕ": 597, + "ĘĕĖĕ": 598, + "Ę(": 
599, + "ėĖĖĖėĖĖĖĖĖĖĖĖ": 600, + "ĖĘĖĖė": 601, + "ĕ\"": 602, + "ĕĈ": 603, + "ĕĖĖĖĖĖ": 604, + "ęĕĕ": 605, + "ĖĖĖĖĖĕĖĕ": 606, + "ĖĖĖĖĖĖĖ": 607, + "ėĖĖĖėĖė": 608, + "ęą": 609, + "ęğ": 610, + "ĖĖĖĖĖĖĖĖĖĕ": 611, + "Ĕ,": 612, + "ĕć": 613, + "ĒĖĘ": 614, + "ĖĖĕĖĕĖĖĖ": 615, + "ėĖĖĕĖĖ": 616, + "ĔĚ": 617, + "Ę'": 618, + "ĕĖĖĖė": 619, + "ę,": 620, + "ęĖĖė": 621, + "ĖĖĖėĖĖėĖĖ": 622, + "ĕą": 623, + "ĖĖĐ": 624, + "ĖĖĔĖĖė": 625, + "ĕĆ": 626, + "ēē": 627, + "Ė(": 628, + "ęĈ": 629, + "ėĖĕĖė": 630, + "ĖĖĘĖĖĕ": 631, + "ėĔĖėĕ": 632, + "ĘĕĖĖĖĖĖĖĖĖĖĖĖ": 633, + "ĕĖėĖĖė": 634, + "ėĕĖĖĖĖĖĖĖĖĖĖĕ": 635, + "Ĕėė": 636, + "ĖĖēĖė": 637, + "ĘėĖė": 638, + "ĖĕĕĖĖĖĖĖĖĖĖĖĖ": 639, + "ėĄĖĖĖĖĖĖĖĖĖĖĖĖĖ": 640, + "Ĕě": 641, + "ėĀĖĕĖė": 642, + "ĖĔĕĖĕ": 643, + "ĕĖėĖĖĖĖĖĖĖĖĖ": 644, + "ęć": 645, + "ĖĘĖĖĕ": 646, + "ĖĖĖĖĖėĖĖĖĖĖ": 647, + "Ĕđ": 648, + "ĖĖĖĖĖĖĖėĖĖ": 649, + "ęĆ": 650, + "ĖĖĖĖĖĖĖĖėĖĖ": 651, + "ĖĖĕĖĖĕĖĕ": 652, + "ĖĖėĖėĖĖĖ": 653, + "ĖĔėĖĕ": 654, + "ĖĕĖĖĖĖĖĖ": 655, + "ĖĖĖĖĖĖĖĖĖėĖĖĖ": 656, + "ĘĖĚ": 657, + "Ėĕĕĕ": 658, + "ĕĖėĖĖĖ": 659, + "ĖĖĖĖĖĖĖĖĕĖĖ": 660, + "ĖĕėĖĖĖĖĖĖĖĖĖĖ": 661, + "ĕ#": 662, + "ĔĖĒ": 663, + "ĖĖĖĖĖĖĖĖĖĕĖĖĖ": 664, + "ĖĖĜ": 665, + "ĕė": 666, + "ėĔĖĖĔ": 667, + "ĖėĕĖĖĕĖĖ": 668, + "ėĀĖĕĖĖĕ": 669, + "ĕėĖĕ": 670, + "ĖĖĖĖĕĖĖĖĖĖĖĖĖ": 671, + "ėĖĖėĖĖė": 672, + "ėĖĖĖĕĖĖĖĖĖĖĖĖ": 673, + "ĖĖĖĖĕĕĖĖ": 674, + "ėĖĖĕĖė": 675, + "ĖĖĖĖėĖė": 676, + "ĖĖĖĖĖĖĖĕĖĖ": 677, + "ĖĕĕĖĖĖĖĕ": 678, + "ėĕĖĔ": 679, + "Ė)": 680, + "ĕ$": 681, + "ĘĖĖę": 682, + "Ėĕėė": 683, + "ĕĖě": 684, + "ėĖĐ": 685, + "ĖĖėĖĕėĖĖ": 686, + "Ĕę": 687, + "ēĖē": 688, + "ĚĖĔ": 689, + "ĖĖĖĖĖėĖĖĕ": 690, + "ĖĖĖėĖĖĕ": 691, + "ĖĖĖĖėĕĖĖ": 692, + "ĘĖĖĔ": 693, + "ĖėĕĖĔ": 694, + "ėĖĖĖĖĖĖĖĖėĖĖĖ": 695, + "ĖĕĖĘ": 696, + ",Ď": 697, + "ĀĐ": 698, + "ĖĕĕĖĖĕĖĖ": 699, + "ėĕĕĖĖĖĖĖĖĖĖĖĖ": 700, + "ĖĖĔĖĖĕĖĖ": 701, + "ĖĖĖĖĖĕĖĖĕ": 702, + "ę!": 703, + "ĖĖĖėĖĖė": 704, + "ėĖĖĖĖĖĖĖĖĕĖĖĖ": 705, + "ĘĘĖĕ": 706, + "ėĕĖĕĖĖė": 707, + "ĖĕĕĖĖĕ": 708, + "ėĕėĖĖĖĖĖĖĖĖĖĖ": 709, + "ĖĖĖĕĖĖė": 710, + "ĘĖ": 711, + "ĖĖĖėĕ": 712, + "ėąĖĖĖĖĖĖĖĖĖĖĖĖĖ": 713, + "ĖĕĖĖĕĖĖĕĖĖĕĖĖ": 714, + "ĘĖĖĖĖ": 715, + 
"ĔĀ": 716, + "ĖĖĖĖĖĕĖĖė": 717, + "ĖėĔ": 718, + "ĔĜ": 719, + "ĖĖĖĖėėĖĖ": 720, + "ėĕĖĖĖĖĖĖĖė": 721, + "ęĖĖĕ": 722, + "Āđ": 723, + "ėĕĖĖĖĖĖĖĖĕ": 724, + "ėĖĖĖĖĖėĖĖĖĖĖĖ": 725, + "ėĘĖė": 726, + "ĖĖĖĖĖėĖĖė": 727, + "ę\"": 728, + "ĕ%": 729, + "ēĖĖĔ": 730, + "ęĖę": 731, + "ėėĖĔ": 732, + "ėĖĖĕĕ": 733, + "ĖĖĖĖėĖĖĖĖĖ": 734, + "ĖĖĘĖĖĘ": 735, + "ėĂėĖĖĖĖĖĖĖĖĖĖĖĖ": 736, + "ĖĖĖĖĕĖĖĖĖĖĖ": 737, + "ĘĔĖė": 738, + "ĕĖĕĖĖė": 739, + "ĔĐ": 740, + "ėĖĖĕĖĖė": 741, + "ėĂĖĕĖĖĖĖĖĖĖĖĖĖĖ": 742, + "ĖĖĕĖĖĖĖĕ": 743, + "ėĀėĕĖė": 744, + "Āď": 745, + "ĖĖĖĕĖĘ": 746, + "ėėĔ": 747, + "ĘĕĖĘ": 748, + "ĕĖđ": 749, + "ēĖĚ": 750, + "ėĀĕĕ": 751, + "ēĒ": 752, + "ĖĕĖĖĖĖĖĖĖė": 753, + "ĖĕĖĖĖĖĖĖĖĕ": 754, + "ĖĖĖĖĖĖėĖĕ": 755, + "ęĘ": 756, + "ėĖĕėĖė": 757, + "ĘėĖĖė": 758, + "ĖĕėĖĖĖĖ": 759, + "ĖĖĖėĖĔ": 760, + "ĖĖĖĖĕėĖĖ": 761, + "ĖĖĕĖĖĕĖĖĖ": 762, + "Ĕĝ": 763, + "ĖĕėĖĖ": 764, + "ĕ&": 765, + "ėėĕĖė": 766, + "ėĖĖėĖĖėĖĖ": 767, + "ėėĖĖĕ": 768, + "ĀĒ": 769, + "ĕĖĖĖĕ": 770, + "ę#": 771, + "ĘĖĕĕ": 772, + "ėĖĜ": 773, + "ĖĕĖĖĖĖĖĕĖĖĕĖĖ": 774, + "ĖĖĖĖĖĖĖĖĖĖĕ": 775, + "ĖĕĖĖĕĖĖĖ": 776, + "ĖĖĔĖĘ": 777, + "ĖėĕĖėĖĖĖ": 778, + "ėėĕĕ": 779, + "ėĀėĕĖĖĕ": 780, + "ĔĖ": 781, + "ĖĖėĖĖĖĖĕ": 782, + "ĘĔĕ": 783, + "ĖĖĖĖĖĖĖėĖĖėĖĖ": 784, + "ėĖĖĖĖĖĖĖĕ": 785, + "Ėĕē": 786, + "ėĆĖĖĖĖĖĖĖĖĖĖĖĖĖ": 787, + "ĚĚ": 788, + "ĔĞ": 789, + "ĕ'": 790, + "ĖĖĖĖĖĖĕĖė": 791, + "ĔĕĖĕ": 792, + "ę$": 793, + "ėĖĖĖĖĖĖĖė": 794, + "ĖĖĖĕĖĖĖĖĖĖĖĖĖ": 795, + "ĔĕĖė": 796, + "ĖĖĖĖĖĕĖĖĕĖĖ": 797, + "ė*": 798, + "ĘĖđ": 799, + "ėĕėĕ": 800, + "ĖĕėĖė": 801, + "ĖĖėėĖĕ": 802, + "Ęė": 803, + "ĖĕĖėĖĖĕ": 804, + "ĀĎ": 805, + "ĔĖě": 806, + "ėĖĖĖĖĖĕĖĖĖĖĖĖ": 807, + "ĕĖĕĖĖĖ": 808, + "ĖĖĕĖėĖĖĕ": 809, + "ėĕĖĖĖĖĖĖĖėĖĖĖ": 810, + "ē,": 811, + "ėĖĖĖĖĖĖėĖĖėĖĖ": 812, + "ēĕĕ": 813, + "ĖĖĖĖėĖĖĖĖĖĖ": 814, + "ĖĖėĖĖėĖĖĖĖĖ": 815, + "ĕĔĖė": 816, + "Ĕď": 817, + "Ěě": 818, + "ĖĕĖĔ": 819, + "ĖĖĔĖĖĔ": 820, + "ĖĖĖĖĖĖėĖĖĖĖĖĖ": 821, + "ĖĕĖĖĖĖĖĖĖėĖĖĖ": 822, + "ĘĔĖĕ": 823, + "ĖĕĖĖĖĖĖėĖĖĖĖĖ": 824, + "Ĕğ": 825, + "ĚĀ": 826, + "ĖēĖė": 827, + "ĘĖĖėĖĖ": 828, + "ėĖĕĖĖĖĖĖĖĖĖĖ": 829, + "ĖĖĖĖĖĖĖĖĖĖė": 830, + "ĒĖę": 831, + "ęĖĖĘ": 
832, + "ĘĖĖėĖĕ": 833, + "ĒĖĔ": 834, + "ĖĖėĖėėĖĖ": 835, + "ėĖĖĕĖĖĖĖĖĖĖĖĖ": 836, + "ĖĖĕĖėĖĖė": 837, + "ėėĖ": 838, + "ėĕĖĖĕĖĖĕĖĖĕĖĖ": 839, + "ėĕĖĖĖĖĖĖĖĕĖĖĖ": 840, + "ĖĕĖĖĖĖĖĖĖĕĖĖĖ": 841, + "ĖĖĖĖĖĖĕĖĖĖ": 842, + "ėăėĖĖĖĖĖĖĖĖĖĖĖĖ": 843, + "ėĖĖĖĕĖė": 844, + "ęĕĖė": 845, + "ėēĖė": 846, + "ėĕĖĖĖĖ": 847, + "Ěđ": 848, + "ĖĘĖė": 849, + "Ė*": 850, + "ĕĖĖĖĖĕ": 851, + "ĕ(": 852, + "ĖĖĘĖĖĖ": 853, + "ĖĖėĖĖėĖĖĖ": 854, + "ĘĖĖĖĖĖ": 855, + "ėĕĖ": 856, + "ĔĖĖĕĖĖ": 857, + "ėĕėĖĖĖĖ": 858, + "ĖĕĖĖĖĖĕĖĖĖĖĖĖ": 859, + "ĖėĕĖĖĕ": 860, + "ĖĖĖĖĖĖėĖĖĖ": 861, + "ĚĒ": 862, + "ĖĖęĖĕ": 863, + "ĕĕĕĕ": 864, + "ĔĖĖē": 865, + "ĔĘ": 866, + "ĖĖď": 867, + "ĖĔĖĕĕ": 868, + "ę%": 869, + "ĖĖĔĖĖĖ": 870, + "ĖĖĝ": 871, + "ĕĖĖĘ": 872, + "ĕĖĖĖĖĖĕ": 873, + "ĖĕĖĖĖĖĕĕ": 874, + "ęĖĒ": 875, + "ėćĖĖĖĖĖĖĖĖĖĖĖĖĖ": 876, + "ĕėĘ": 877, + "ĔĎ": 878, + "ĖėĖĔ": 879, + "ĖĕĖėĖĖĖĖĖĖĖĖĖ": 880, + "ėĕĖĖ": 881, + "ėĖĖĖėĖĖė": 882, + "Ě,": 883, + "ĚĄ": 884, + "ĖĖĖĖėĖĖėĖĖĖĖĖ": 885, + "ęėĕ": 886, + "ĖĘĖĖĖ": 887, + "ėėĖĖ": 888, + "ĖĕĖĖĕĖĖĖĖĖ": 889, + "Ěă": 890, + "ėĕĖĖĖĖĖėĖĖĖĖĖ": 891, + "ėėĖĖĘ": 892, + "ĖĔĖĖĕĖĖĕ": 893, + "ĚĖĘ": 894, + "ėėĖĘ": 895, + "ĕĔĖĖĕ": 896, + "ĖĔĖĖĖ": 897, + "ėĖĖĖėĖĕ": 898, + "ĕ)": 899, + "ĘĖĖėĖė": 900, + "ĚĜ": 901, + "ēđ": 902, + "ĖĖĖĕĖėĖĖĖ": 903, + "ĖĖĖĖĕĖė": 904, + ",ď": 905, + "Ĕ!": 906, + "ĘĖĕėĖĕ": 907, + "ĖĖĖĖĖĖĖĕĖĖĕĖĖ": 908, + "ėĂėĕĖĖĖĖĖĖĖĖĖĖĖ": 909, + "ĖĕĖĖĕĖĖė": 910, + "ĖĖĖĖĖĖėĖĖĖĖĖ": 911, + "ĚĐ": 912, + "ĕĖĜ": 913, + "ĘĖĖĖĖĖė": 914, + "ėĖď": 915, + "ĖėėĖĖ": 916, + "ĕĕĖĘ": 917, + "ĖĖĕĖĖĖĖĖĖ": 918, + "ĕĖĖĕĖĖĖĖĖ": 919, + "ĖėĖĖĘ": 920, + "ėĖĖĕĖĕ": 921, + "ėĖĖĘĖĕ": 922, + "ėĖĖĖĖĖĖĖĖĖ": 923, + "ėĕĖĖĖĖėĖĖĖĖĖĖ": 924, + "ėĖĖĖėĖĖėĖĖĖĖĖ": 925, + "ėĀĘĕ": 926, + "ĔėĘ": 927, + "ĖĖĖĖĖėĖĖėĖĖ": 928, + "ĔĕĖĖĖĖĖĖĖĖĖĖĖ": 929, + ",Ĉ": 930, + "ēĖėĕ": 931, + "ĖėėĖė": 932, + "ĔĖĖĕĖė": 933, + "ĖĖĕĖĖėĖĖĖ": 934, + "ĕĕĖĕĕ": 935, + "ėĈĖĖĖĖĖĖĖĖĖĖĖĖĖ": 936, + "ėĔĖĘ": 937, + "ėėėė": 938, + "ėĖĖĖĖĖĖĕĖĖĕĖĖ": 939, + "ę&": 940, + "ĖĖĔĕ": 941, + "ėĕĖėĖĖė": 942, + "ĖĕĖĕĖĖė": 943, + "ĔĖĖĖĖĖ": 944, + "ėĕĖĖĖĖĖĕĖĖĕĖĖ": 945, + "ĔĖđ": 946, + 
"ĖĕĖĖĕĖĖĖĖĖĕĖĖ": 947, + "ĖĖĖėĖĘ": 948, + "ėĄėĖĖĖĖĖĖĖĖĖĖĖĖ": 949, + "Ĕ\"": 950, + "ĖĖĔĖĖĖĖĖ": 951, + "ēě": 952, + "ĕĕĖ": 953, + "ĖĖĖĖĖĖėĖė": 954, + "ĖēĖĖĕ": 955, + "đĖė": 956, + "ĖĖĖĕĖĖĖĖĕ": 957, + "ĔĖĖĖĖĖĖĖĖĖĖĖĖ": 958, + "ėĖĔĖĖĕ": 959, + "ĖĖĖėĖĕĖĖĖ": 960, + "Ėĕėĕ": 961, + "ėĕĖėĖĖĖĖĖĖĖĖĖ": 962, + "ėĖĕĕĖė": 963, + "ėĀĖĕĖĕ": 964, + "ĕĖĕĖĖĖĖĖĖĖĖĖ": 965, + "ĔĕĖĖĕ": 966, + "ėĔĖĕĕ": 967, + "ĖĖėĖĖėĖĖėĖĖ": 968, + "Ĕā": 969, + "ĘĖě": 970, + "ĘėĔ": 971, + "Ěą": 972, + "ĖĖĖĖĕĖĖĖĖĖ": 973, + "ĖĕĖĕĖĖĕ": 974, + "Ĕ#": 975, + "ėĖĖĖĖĖėĖĖ": 976, + "ĖėĖĖėĖĕ": 977, + "ĕĕĔ": 978, + "ė+": 979, + "ĖĕĖĖėĖĖĖ": 980, + "Ěď": 981, + "ĕĕĖėĕ": 982, + "ĕĕĖĖė": 983, + "ĔĖėĖĖĖ": 984, + "ĕĖĖĕĖĖ": 985, + "ĚĖē": 986, + "ĕĕĘ": 987, + "Ěā": 988, + "ĖĖėėĖĖ": 989, + "ėăĖĕĖĖĖĖĖĖĖĖĖĖĖ": 990, + "ĖĖĖĕĖĔ": 991, + "ęėė": 992, + "Ěĝ": 993, + "ĖĖĕĕĖĖ": 994, + "ęĖĖĖĖĖĖĖĖĖĖĖĖ": 995, + "ĔėĖĕ": 996, + "ėĕē": 997, + "ĖĖĕĕĖė": 998, + "ėĕĖĕĖĖĕ": 999, + "ĖĖėĕĖė": 1000, + "ĘĘĖė": 1001, + "ęĖĖę": 1002, + "ĖĖĕėĖĖ": 1003, + "ĒĒ": 1004, + "ėĀĖĖĖĖĖĖĕ": 1005, + "ėąėĖĖĖĖĖĖĖĖĖĖĖĖ": 1006, + "ĖĖėĖĖĖĕ": 1007, + "ĖĖĖĖĖĖĕĖĕ": 1008, + "ĕĔĕ": 1009, + "ęėĖĕ": 1010, + "ěĖĕ": 1011, + "Ĕ$": 1012, + "ĕĖĖĕĖĖĖĖĖĖĖĖ": 1013, + "ĖĖĖĖĕĕĖė": 1014, + "ĖĕĕĖĖĖĖė": 1015, + "ę'": 1016, + "Ę*": 1017, + "ĚĎ": 1018, + "ĖĔĖĘ": 1019, + "ĖĖĖĖĖĖĕĖĖĖĖĖ": 1020, + "Ė+": 1021, + "ĀĈ": 1022, + "ĖĖėĕĖĖ": 1023 + }, + "merges": [ + [ + "Ė", + "Ė" + ], + [ + "ĖĖ", + "ĖĖ" + ], + [ + "ĖĖĖĖ", + "ĖĖĖĖ" + ], + [ + "Ė", + "ĕ" + ], + [ + "Ė", + "ė" + ], + [ + "ĖĖĖĖ", + "Ė" + ], + [ + "ė", + "Ā" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĖĖĖ" + ], + [ + "ĖĖ", + "ĕ" + ], + [ + "ėĀ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖ", + "ė" + ], + [ + "ė", + "ĕ" + ], + [ + "ĖĖ", + "Ė" + ], + [ + "Ė", + "Ĕ" + ], + [ + "Ė", + "Ę" + ], + [ + "ĖĖ", + "Ėĕ" + ], + [ + "ĖĖ", + "Ėė" + ], + [ + "ė", + "Ėĕ" + ], + [ + "ĕ", + "ĕ" + ], + [ + "Ėė", + "ĕ" + ], + [ + "ĖĖ", + "Ĕ" + ], + [ + "Ėĕ", + "ĕ" + ], + [ + "ė", + "ė" + ], + [ + "ė", + "Ėė" + ], + [ + "ė", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖ", + "ĕ" + ], + [ + "ĖĖ", + "Ę" + ], 
+ [ + "ĖĖĖĖ", + "Ėĕ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĖ" + ], + [ + "ĖĖĖĖ", + "ė" + ], + [ + "Ė", + "ē" + ], + [ + "ĖĖĖĖ", + "Ėė" + ], + [ + "ĖĖĕ", + "ĖĖĕ" + ], + [ + "Ė", + "ę" + ], + [ + "ė", + "Ĕ" + ], + [ + "ė", + "ĖĖė" + ], + [ + "ė", + "ā" + ], + [ + "ė", + "ĖĖ" + ], + [ + "ĕ", + "Ėė" + ], + [ + "Ėĕ", + "ĖĖĕ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĖĖĖĖ" + ], + [ + "Ę", + "ĕ" + ], + [ + "ĖĖĖĖ", + "ĖĖ" + ], + [ + "ĖĖĕ", + "ĖĖ" + ], + [ + "ĖĖė", + "ĖĖ" + ], + [ + "ĕ", + "Ėĕ" + ], + [ + "ė", + "ĖĖĖĖ" + ], + [ + "Ė", + "Ē" + ], + [ + "Ę", + "Ėĕ" + ], + [ + "ė", + "ĖĖĕ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĖĖ" + ], + [ + "Ėĕ", + "ĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ė", + "Ă" + ], + [ + "Ė", + "Ě" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "Ė" + ], + [ + "ė", + "ĖĔ" + ], + [ + "ĖĖĕ", + "Ėė" + ], + [ + "Ĕ", + "Ėė" + ], + [ + "ĖĖĖĖ", + "ĖĖĖė" + ], + [ + "ĖĖė", + "Ėĕ" + ], + [ + "Ĕ", + "ĕ" + ], + [ + "ĖĖĖĖ", + "ĖĖĖĕ" + ], + [ + "ė", + "Ę" + ], + [ + "Ę", + "Ā" + ], + [ + "ĕ", + "ĖĖ" + ], + [ + "Ė", + "Ā" + ], + [ + "ė", + "ē" + ], + [ + "ė", + "ă" + ], + [ + "ĖĖ", + "ē" + ], + [ + "Ę", + "Ėė" + ], + [ + "ĖĖė", + "Ėė" + ], + [ + "ė", + "ĖĘ" + ], + [ + "ėĕ", + "ĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ė", + "Ė" + ], + [ + "ė", + "Ą" + ], + [ + "ėĕ", + "ĖĖĕ" + ], + [ + "ė", + "ĖĖĖė" + ], + [ + "Ėĕ", + "ĖĖė" + ], + [ + "ĖĖ", + "ę" + ], + [ + "Ėĕ", + "Ėė" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖ" + ], + [ + "ĕ", + "ĖĘ" + ], + [ + "ėĕ", + "Ėė" + ], + [ + "ĕ", + "Ĕ" + ], + [ + "ė", + "ą" + ], + [ + "Ė", + "đ" + ], + [ + "ėĕ", + "ĕ" + ], + [ + "ĖĖĕ", + "Ėĕ" + ], + [ + "Ĕ", + "Ėĕ" + ], + [ + "ė", + "ĖĖĖĕ" + ], + [ + "ė", + "ę" + ], + [ + "Ė", + "ě" + ], + [ + "ė", + "Ē" + ], + [ + "ĖĖė", + "ĖĖė" + ], + [ + "ė", + "Ć" + ], + [ + "Ėėĕ", + "Ėė" + ], + [ + "ĕ", + "ĖĔ" + ], + [ + "ė", + "ć" + ], + [ + "ĕ", + "ĖĖĕ" + ], + [ + "Ĕ", + "ĖĘ" + ], + [ + "Ėĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ė", + "Ĉ" + ], + [ + "ĖĖĖĖ", + "ĖĖĕ" + ], + [ + "Ė", + "ā" + ], + [ + "Ėĕ", + "ĖĖĖ" + ], + [ + "ė", + "đ" + ], + [ + "Ėė", + "ĖĖĖ" + ], + [ + "Ė", + "Đ" + ], + [ + "Ėėĕ", + "ĖĖ" + ], + [ + 
"ė", + "Đ" + ], + [ + "Ę", + "Ĕ" + ], + [ + "Ę", + "ā" + ], + [ + "ė", + "Ď" + ], + [ + "ė", + "ď" + ], + [ + "Ė", + "Ĝ" + ], + [ + "ė", + "ėĖĕ" + ], + [ + "Ę", + "ĖĔ" + ], + [ + "ėĕ", + "Ėĕ" + ], + [ + "Ėė", + "ĖĖė" + ], + [ + "Ę", + "ĖĖ" + ], + [ + "ĖĖĖĖ", + "ĖĖė" + ], + [ + "ė", + "Ě" + ], + [ + "Ĕ", + "ĖĖ" + ], + [ + "Ėĕ", + "Ėĕ" + ], + [ + "ė", + "Ėē" + ], + [ + "ėĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "Ėĕĕ", + "Ėė" + ], + [ + "Ė", + "ď" + ], + [ + "Ę", + "ē" + ], + [ + "ė", + "ĖĖĖĖĖė" + ], + [ + "ĖĖĕ", + "ĖĖė" + ], + [ + "ę", + "ĕ" + ], + [ + "ĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "Ėĕ", + "ĖĖĖĖĖĕ" + ], + [ + "Ėĕĕ", + "ĖĖ" + ], + [ + "ė", + "ėĕ" + ], + [ + "ė", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĕ", + "ĕ" + ], + [ + "ĖĖė", + "ĖĖĖĖĖ" + ], + [ + "Ėĕ", + "ĖĖĕĖĖĕ" + ], + [ + "ĖĖĖė", + "ĖĖ" + ], + [ + "ē", + "ĕ" + ], + [ + "ĖĖĖĖĖĕ", + "ĖĖ" + ], + [ + "ĖĖĖĕ", + "ĖĖ" + ], + [ + "ė", + "ĖĖĘ" + ], + [ + "Ėėĕ", + "Ėĕ" + ], + [ + "Ėĕĕ", + "Ėĕ" + ], + [ + "Ė", + "Ď" + ], + [ + "Ė", + "ĝ" + ], + [ + "Ę", + "Ę" + ], + [ + "Ė", + "Ă" + ], + [ + "ĖĖė", + "ĕ" + ], + [ + "ė", + "ě" + ], + [ + "ę", + "Ėĕ" + ], + [ + "ė", + "ĖĖĖĖĕ" + ], + [ + "Ĕ", + "ĖĔ" + ], + [ + "Ā", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖ", + "ĕĕ" + ], + [ + "ĖĖĕ", + "ĖĖĖ" + ], + [ + "Ė", + "ă" + ], + [ + "ē", + "Ėė" + ], + [ + "ė", + "Ėę" + ], + [ + "ĖĖĖĖĖė", + "ĖĖ" + ], + [ + "Ę", + "ĖĖė" + ], + [ + "Ę", + "Ă" + ], + [ + "ėĀ", + "ėĕ" + ], + [ + "Ę", + "Ē" + ], + [ + "ė", + "ĖĖĖ" + ], + [ + "ĖĖ", + "Ē" + ], + [ + "ĖĖĔ", + "Ėė" + ], + [ + ",", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ė", + "Ą" + ], + [ + "Ėĕ", + "ėĖĕ" + ], + [ + "Ę", + "ĖĘ" + ], + [ + "Ę", + "ę" + ], + [ + "ė", + "Ĝ" + ], + [ + "ĕ", + "ē" + ], + [ + "Ĕ", + "ĖĖĕ" + ], + [ + "ė", + "ĖĖĖĖĖĕ" + ], + [ + "ĖĖė", + "ĖĖĖ" + ], + [ + "Ė", + "Ğ" + ], + [ + "ĖĖĕĖĖĕ", + "ĖĖ" + ], + [ + "ėė", + "Ėė" + ], + [ + "ėĀ", + "Ėĕ" + ], + [ + "Ę", + "ă" + ], + [ + "Ę", + "đ" + ], + [ + "Ė", + "ą" + ], + [ + "ĕĕ", + "ĕ" + ], + [ + "ėĕ", + "ĖĖĕĖĖĕ" + ], + [ + "ė", + "ĖĖĖĖė" + ], + [ + "ĖĖ", + "Ě" + ], + [ + "ĖĖĖė", 
+ "Ėĕ" + ], + [ + "Ė", + "Ĉ" + ], + [ + "Ę", + "Ě" + ], + [ + "Ė", + "Ć" + ], + [ + "Ĕ", + "Ĕ" + ], + [ + "Ė", + "ć" + ], + [ + "Ę", + "Ą" + ], + [ + "Ę", + "Đ" + ], + [ + "ė", + "ĝ" + ], + [ + "ĖĖė", + "ĖĖĕ" + ], + [ + "ĖĖĖĕ", + "Ėė" + ], + [ + "ėĕ", + "ĖĖĖĖĖĕ" + ], + [ + "ę", + "Ėė" + ], + [ + "Ę", + "ď" + ], + [ + "Ę", + "Ď" + ], + [ + "ĕ", + "Ėę" + ], + [ + "ĖĖĔ", + "ĖĖ" + ], + [ + "Ėė", + "Ėĕ" + ], + [ + "Ę", + "ą" + ], + [ + "Ę", + "ĖĖĕ" + ], + [ + "ĖĔ", + "ĖĖĕ" + ], + [ + "ĖĖĖė", + "Ėė" + ], + [ + "Ė", + "ğ" + ], + [ + "Ę", + "ĖĖĘ" + ], + [ + "ĖĔ", + "ĕ" + ], + [ + "Ėĕ", + "Ĕ" + ], + [ + "ė", + "Ğ" + ], + [ + "Ę", + "ě" + ], + [ + "Ę", + "Ć" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖė", + "ĖĖĖ" + ], + [ + "Ę", + "Ĉ" + ], + [ + "ĕ", + "ĖĖė" + ], + [ + "Ę", + "ć" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĖ" + ], + [ + "Ėė", + "ĖĖĕ" + ], + [ + "ē", + "ĖĘ" + ], + [ + "ėĖĖė", + "ĖĖ" + ], + [ + "Ę", + "ĕĕ" + ], + [ + "ėā", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖ", + "Ĕ" + ], + [ + "ė", + "ğ" + ], + [ + "ėĀ", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ĕ", + "Ėę" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĕ" + ], + [ + "Ėĕ", + "ĖĖĖĖĕ" + ], + [ + "ė", + "Ėėĕ" + ], + [ + "ĖĖĖĖ", + "ėĕ" + ], + [ + "Ėĕ", + "ĖĖĖĖĖė" + ], + [ + "ĖĖė", + "ĖĖėĖĖ" + ], + [ + "ĖĖĖ", + "Ĕ" + ], + [ + "Ę", + "Ĝ" + ], + [ + "ĕ", + "Ėėĕ" + ], + [ + "ĕ", + "Ėē" + ], + [ + "ĕ", + "Ā" + ], + [ + "ę", + "Ā" + ], + [ + "ė", + "ĖĒ" + ], + [ + "ĕ", + "Ē" + ], + [ + "ĕ", + "ę" + ], + [ + "Ėė", + "Ėė" + ], + [ + "ē", + "Ėĕ" + ], + [ + "ĖĖĘ", + "Ėĕ" + ], + [ + "Ę", + "Ėē" + ], + [ + "Ę", + "ĝ" + ], + [ + "ĕ", + "Ėĕĕ" + ], + [ + "Ėė", + "ĖĖ" + ], + [ + "ė", + "ĖĖĔ" + ], + [ + "ėĕĖĖĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖ", + "ėė" + ], + [ + "ĖĖĖĕ", + "Ėĕ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĖĖ" + ], + [ + "ĖĖĖĖĖĖĖė", + "ĖĖĖ" + ], + [ + "ėĖĖė", + "Ėĕ" + ], + [ + "ĕ", + "Ę" + ], + [ + "ĕĕ", + "Ėė" + ], + [ + "ĖĖĖĖ", + "ėĖĕ" + ], + [ + "ėĀ", + "ėĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "Ě" + ], + [ + "Ėĕ", + "Ėĕĕ" + ], + [ + "ĖĖĖĖ", + "Ę" + ], + [ + "Ė", + "!" 
+ ], + [ + "ĖĖĖĖĖĖĖĕ", + "ĖĖĖ" + ], + [ + "Ėĕ", + "ĖĖ" + ], + [ + "ėė", + "ĖĖė" + ], + [ + "ĖĔ", + "Ėė" + ], + [ + "Ę", + "Ğ" + ], + [ + "Ėĕ", + "Ėėĕ" + ], + [ + "ėĕ", + "Ėĕĕ" + ], + [ + "Ę", + "ĕĖė" + ], + [ + "ę", + "ĖĔ" + ], + [ + "Ĕ", + "Ėėĕ" + ], + [ + "ėĕ", + "ĖĖĖĖĖė" + ], + [ + "ĖĖĘ", + "ĖĖ" + ], + [ + "ėĕ", + "ĖĖĖĖė" + ], + [ + "ėĕ", + "Ėėĕ" + ], + [ + "ĖĖĔ", + "Ėĕ" + ], + [ + "ė", + "!" + ], + [ + "ĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "Ę", + "ėĖĕ" + ], + [ + "ĖĖĖĖ", + "ĖĖėĖĖ" + ], + [ + "ėĀ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĕ", + "ĖĖĖ" + ], + [ + "ėĖĖė", + "Ėė" + ], + [ + "ĖĖĖĖĖė", + "ĖĖĖ" + ], + [ + "ė", + "ĖĚ" + ], + [ + "Ę", + "ğ" + ], + [ + "ĖĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "ė", + "\"" + ], + [ + "ė", + "ĘĖĕ" + ], + [ + "ĖĖĖĖ", + "ĖĖĕĖĖ" + ], + [ + "ėĕ", + "ĖĖĖĖĕ" + ], + [ + "ĕ", + "Ė" + ], + [ + "ĕ", + "ě" + ], + [ + "ėĖĕ", + "ĖĖĖ" + ], + [ + "ėĕ", + "ĖĖė" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ė" + ], + [ + "Ĕ", + "ē" + ], + [ + "ĖĖĘ", + "Ėė" + ], + [ + "Ę", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "đ" + ], + [ + "Ė", + "\"" + ], + [ + "ĖĖĖĖė", + "ĖĖė" + ], + [ + "ėĖĕ", + "ĖĖĕ" + ], + [ + "ėĔ", + "Ėė" + ], + [ + "ĖĖ", + "đ" + ], + [ + "ĖĖĕĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "ėĔ", + "ĖĖĕ" + ], + [ + "ĖĖĖĖĖ", + "Ĕ" + ], + [ + "Ė", + "," + ], + [ + "ė", + "," + ], + [ + "Ĕ", + "Ėē" + ], + [ + "ĖĖĖ", + "Ę" + ], + [ + "ė", + "#" + ], + [ + "ėĖĕ", + "ĕ" + ], + [ + "ĖĖĔ", + "ĖĖĕ" + ], + [ + "ĕĕ", + "ĖĖĕ" + ], + [ + "ĖĖĖė", + "ĖĖĖĖĖ" + ], + [ + "Ę", + "!" 
+ ], + [ + "ė", + "$" + ], + [ + "ę", + "Ĕ" + ], + [ + "Ę", + "Ėę" + ], + [ + "ĕ", + "Ĝ" + ], + [ + "ėĖĕ", + "ėĖĕ" + ], + [ + "ĖĖ", + "ě" + ], + [ + "Ę", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ėĀ", + "ėĖĖĖĖĖĖĖĖ" + ], + [ + "Ėĕ", + "ĖĖĖĖė" + ], + [ + "Ė", + "#" + ], + [ + "ę", + "Ě" + ], + [ + "Ĕ", + "ĖĖĔ" + ], + [ + "ėĖė", + "ĖĖĖ" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĕ" + ], + [ + "ĕ", + "Đ" + ], + [ + "ę", + "Ē" + ], + [ + "Ę", + "ėė" + ], + [ + "Ĕ", + "Ėĕĕ" + ], + [ + "ėā", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖė", + "ĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĂ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĀ", + "ėĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ę", + "ĖĘ" + ], + [ + "ėĖĕ", + "ĖĖė" + ], + [ + "ĖĕĖĖĖĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĖĕ", + "ĕ" + ], + [ + "Ę", + "\"" + ], + [ + "ėĖė", + "ĖĖė" + ], + [ + "ĕĕ", + "ĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĕ", + "ėĖĕ" + ], + [ + "ē", + "ĖĔ" + ], + [ + "ĖĘ", + "ĕ" + ], + [ + "ė", + "Ęĕ" + ], + [ + "Ę", + "," + ], + [ + "ĕ", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĕ", + "ĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ė", + "$" + ], + [ + "ē", + "Ĕ" + ], + [ + "ĖĔ", + "ĖĖĔ" + ], + [ + "Ę", + "ėĕ" + ], + [ + "ę", + "ĖĖ" + ], + [ + "ĕ", + "ĝ" + ], + [ + "ĖĖĖĖĖė", + "Ėĕ" + ], + [ + "ĖĖĖĕ", + "ĖĖĕ" + ], + [ + "ę", + "ę" + ], + [ + "ē", + "ĖĖ" + ], + [ + "ėĕ", + "ĖĖĕĖė" + ], + [ + "ėĖė", + "ĖĖĕ" + ], + [ + "ĕ", + "ā" + ], + [ + "Ē", + "Ėė" + ], + [ + "ĕ", + "ĖĚ" + ], + [ + "Ę", + "$" + ], + [ + "ĖĕĖĖĕĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖėĖĖ" + ], + [ + "ė", + "%" + ], + [ + "ė", + "ĖĖĕĖĖĕ" + ], + [ + "ē", + "Ėę" + ], + [ + "ĖĖĕ", + "ĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ę", + "ā" + ], + [ + "Ę", + "#" + ], + [ + "ĕĕ", + "Ėĕ" + ], + [ + "Ėėĕ", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖė", + "ĖĖĖĖĖ" + ], + [ + "ĕ", + "ď" + ], + [ + "ę", + "ě" + ], + [ + "Ėėĕ", + "ĖĘ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĖėĖĖ" + ], + [ + "ėĕ", + "ĕĖĕ" + ], + [ + "ĖĖĕ", + "ĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "ĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėā", + "ėĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "," + ], + [ + "ėĖ", + "đ" + ], + [ + "ĖĖĖĖĖĕ", + "Ėė" + ], + [ + "Ě", + "Ėĕ" + ], + [ + "ĖĕĖĖĖĖĖĖĖĖ", + "ĕĖĖ" + ], + [ + "ĕ", + 
"ĖĒ" + ], + [ + "ę", + "đ" + ], + [ + "ĖĖĕĖė", + "ĖĖĖ" + ], + [ + "ĕ", + "Ğ" + ], + [ + "ĖĖėĖĕ", + "ĖĖĖ" + ], + [ + "Ė", + "%" + ], + [ + "ĕ", + "ĖĖĖ" + ], + [ + "ė", + "&" + ], + [ + "ĖĖĖĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "ĕ", + "Ď" + ], + [ + "Ėĕ", + "ĖĖĔ" + ], + [ + "ĖĖė", + "ĖĖĖĖĖĖ" + ], + [ + "ėĖĖĖĖ", + "ĖĖėĖĖĖĖĖ" + ], + [ + "ĖĕĖĖĖĖĖĖĖĖ", + "ĖĖĕ" + ], + [ + "ėĖė", + "ėĖė" + ], + [ + "ĖĖĖĖė", + "ĖĖĕ" + ], + [ + "ę", + "ē" + ], + [ + "ėĖė", + "ĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "Ėĕĕ", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĕĖĖ" + ], + [ + "Ĕ", + "Ē" + ], + [ + "ėĀ", + "ĖĖĖĕ" + ], + [ + "ėĖĖĖĖ", + "ĖĖė" + ], + [ + "ĕ", + "ėė" + ], + [ + "ĖĕĖĖĖĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ę", + "Ĝ" + ], + [ + "Ĕ", + "ĖĖė" + ], + [ + "ėĕ", + "ĖĘ" + ], + [ + "Ĕ", + "ĕĕ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĖĖĖ" + ], + [ + "ĕ", + "ğ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĖĕĖĖ" + ], + [ + "ĖĖĖĖĕ", + "ĖĖė" + ], + [ + "ĖĔ", + "ĖĖė" + ], + [ + "ĖĖĖĖĖ", + "Ę" + ], + [ + "ėĖĖĖĖ", + "ĖĖĕ" + ], + [ + "ę", + "Đ" + ], + [ + "ėĀ", + "ĖĖĖė" + ], + [ + "ėĀ", + "ėĖĖĕ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĖĖĕ" + ], + [ + "Ėĕĕ", + "ĖĘ" + ], + [ + "ę", + "Ă" + ], + [ + "ėĖė", + "ėĖĕ" + ], + [ + "ė", + "'" + ], + [ + "Ĕ", + "ĖĚ" + ], + [ + "ę", + "ă" + ], + [ + "Ę", + "%" + ], + [ + "ĖĕĖĖĕ", + "Ėė" + ], + [ + "ėĕ", + "ĕĕ" + ], + [ + "ėă", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĖĖĖ", + "ĖĖĕĖĖĖĖĖ" + ], + [ + "Ę", + "ĖĖĖė" + ], + [ + "Ę", + ")" + ], + [ + "ė", + ")" + ], + [ + "Ę", + "ĖĒ" + ], + [ + "ĖĖĕĖĖĕ", + "Ėė" + ], + [ + "ėĕĖĖĖĖĖĖĖĖ", + "ĕĖĖ" + ], + [ + "ėĕĖĖĖĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "ĖĖė", + "ĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖ", + "ě" + ], + [ + "ėĕ", + "ĖĖĔ" + ], + [ + "Ė", + "&" + ], + [ + "ĖĖĘ", + "ĖĖė" + ], + [ + "ę", + "ď" + ], + [ + "ė", + "(" + ], + [ + "ĕ", + "ĖĖĔ" + ], + [ + "ĖĔ", + "Ėĕ" + ], + [ + "ĕĖė", + "ĖĖĕ" + ], + [ + "Ě", + "Ėė" + ], + [ + "ę", + "Ėē" + ], + [ + "ę", + "ĝ" + ], + [ + "ę", + "Ą" + ], + [ + "ėĕĖĖĕĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "ĖĔ", + "Ėėĕ" + ], + [ + "ėĖĖė", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĖĖ", + "ėĖĖĖĖĖĖĖĖ" + 
], + [ + "ėĕ", + "Ĕ" + ], + [ + "ĕ", + "ĖĖĖĖ" + ], + [ + "ę", + "Ď" + ], + [ + "ĖĕĖĖĖĖĖĖĖĖ", + "ĖĖė" + ], + [ + "ĖĖĖĖĖė", + "Ėė" + ], + [ + "Ę", + "&" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĖĖĖė" + ], + [ + "ėĔ", + "Ėĕ" + ], + [ + "ėĀ", + "ėĖĖė" + ], + [ + "ėĖĕ", + "ėĖĖ" + ], + [ + "ĕ", + "ă" + ], + [ + "ĕ", + "!" + ], + [ + "ĕ", + "Ă" + ], + [ + "Ė", + "'" + ], + [ + "ĖĖĖĖĖĖĕ", + "ĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĖĖė" + ], + [ + "ėĕĖĖĖĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "Ęĕ", + "ĖĖĕ" + ], + [ + "Ėĕĕ", + "ĖĔ" + ], + [ + "ĕ", + "Ą" + ], + [ + "ē", + "ĖĖĕ" + ], + [ + "ę", + "Ğ" + ], + [ + "ėĔ", + "ĕ" + ], + [ + "Ę", + "ĖĖĖĕ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "Ėė" + ], + [ + "Ē", + "ĕ" + ], + [ + "ĖĘ", + "Ėĕ" + ], + [ + "ėĕ", + "ėė" + ], + [ + "ėĖė", + "ėĖĖ" + ], + [ + "ėĕĖĖĖĖĖĖĖĖ", + "ĖĖė" + ], + [ + "ėā", + "Ėĕ" + ], + [ + "ėĖė", + "Ėė" + ], + [ + "Ě", + "ĕ" + ], + [ + "ėā", + "ėĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĖĖĖĕ" + ], + [ + "ĖĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĕĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "Ē", + "Ėĕ" + ], + [ + "Ęĕ", + "Ėĕ" + ], + [ + "Ę", + "(" + ], + [ + "ėĖĖĖė", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĘ", + "ĖĖė" + ], + [ + "ĕ", + "\"" + ], + [ + "ĕ", + "Ĉ" + ], + [ + "ĕ", + "ĖĖĖĖĖ" + ], + [ + "ę", + "ĕĕ" + ], + [ + "ĖĖĖĖĖĕ", + "Ėĕ" + ], + [ + "ĖĖĖĖ", + "ĖĖĖ" + ], + [ + "ėĖĖĖė", + "Ėė" + ], + [ + "ę", + "ą" + ], + [ + "ę", + "ğ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "Ėĕ" + ], + [ + "Ĕ", + "," + ], + [ + "ĕ", + "ć" + ], + [ + "Ē", + "ĖĘ" + ], + [ + "ĖĖĕĖĕ", + "ĖĖĖ" + ], + [ + "ė", + "ĖĖĕĖĖ" + ], + [ + "Ĕ", + "Ě" + ], + [ + "Ę", + "'" + ], + [ + "ĕ", + "ĖĖĖė" + ], + [ + "ę", + "," + ], + [ + "ę", + "ĖĖė" + ], + [ + "ĖĖĖė", + "ĖĖėĖĖ" + ], + [ + "ĕ", + "ą" + ], + [ + "ĖĖ", + "Đ" + ], + [ + "ĖĖĔ", + "ĖĖė" + ], + [ + "ĕ", + "Ć" + ], + [ + "ē", + "ē" + ], + [ + "Ė", + "(" + ], + [ + "ę", + "Ĉ" + ], + [ + "ėĖĕ", + "Ėė" + ], + [ + "ĖĖĘ", + "ĖĖĕ" + ], + [ + "ėĔ", + "Ėėĕ" + ], + [ + "Ęĕ", + "ĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕĖė", + "ĖĖė" + ], + [ + "ėĕĖĖĖĖĖĖĖĖ", + "ĖĖĕ" + ], + [ + "Ĕ", + "ėė" + ], + [ + "ĖĖē", + "Ėė" + ], + [ + 
"Ę", + "ėĖė" + ], + [ + "Ėĕĕ", + "ĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĄ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ĕ", + "ě" + ], + [ + "ėĀ", + "ĖĕĖė" + ], + [ + "ĖĔ", + "ĕĖĕ" + ], + [ + "ĕĖė", + "ĖĖĖĖĖĖĖĖĖ" + ], + [ + "ę", + "ć" + ], + [ + "ĖĘ", + "ĖĖĕ" + ], + [ + "ĖĖĖĖĖė", + "ĖĖĖĖĖ" + ], + [ + "Ĕ", + "đ" + ], + [ + "ĖĖĖĖĖĖĖė", + "ĖĖ" + ], + [ + "ę", + "Ć" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ėĖĖ" + ], + [ + "ĖĖĕĖĖĕ", + "Ėĕ" + ], + [ + "ĖĖėĖė", + "ĖĖĖ" + ], + [ + "ĖĔ", + "ėĖĕ" + ], + [ + "Ėĕ", + "ĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖėĖĖĖ" + ], + [ + "Ę", + "ĖĚ" + ], + [ + "Ėĕ", + "ĕĕ" + ], + [ + "ĕĖė", + "ĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĕĖĖ" + ], + [ + "Ėĕ", + "ėĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "#" + ], + [ + "Ĕ", + "ĖĒ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĕĖĖĖ" + ], + [ + "ĖĖ", + "Ĝ" + ], + [ + "ĕ", + "ė" + ], + [ + "ėĔ", + "ĖĖĔ" + ], + [ + "Ėėĕ", + "ĖĖĕĖĖ" + ], + [ + "ėĀ", + "ĖĕĖĖĕ" + ], + [ + "ĕ", + "ėĖĕ" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĖė", + "ĖĖė" + ], + [ + "ėĖĖĖĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĕĕ", + "ĖĖ" + ], + [ + "ėĖĖĕ", + "Ėė" + ], + [ + "ĖĖĖĖ", + "ėĖė" + ], + [ + "ĖĖĖĖĖĖĖĕ", + "ĖĖ" + ], + [ + "Ėĕĕ", + "ĖĖĖĖĕ" + ], + [ + "ėĕ", + "ĖĔ" + ], + [ + "Ė", + ")" + ], + [ + "ĕ", + "$" + ], + [ + "Ę", + "ĖĖę" + ], + [ + "Ėĕ", + "ėė" + ], + [ + "ĕ", + "Ėě" + ], + [ + "ėĖ", + "Đ" + ], + [ + "ĖĖėĖĕ", + "ėĖĖ" + ], + [ + "Ĕ", + "ę" + ], + [ + "ē", + "Ėē" + ], + [ + "Ě", + "ĖĔ" + ], + [ + "ĖĖĖĖĖė", + "ĖĖĕ" + ], + [ + "ĖĖĖė", + "ĖĖĕ" + ], + [ + "ĖĖĖĖėĕ", + "ĖĖ" + ], + [ + "Ę", + "ĖĖĔ" + ], + [ + "Ėėĕ", + "ĖĔ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ėĖĖĖ" + ], + [ + "Ėĕ", + "ĖĘ" + ], + [ + ",", + "Ď" + ], + [ + "Ā", + "Đ" + ], + [ + "Ėĕĕ", + "ĖĖĕĖĖ" + ], + [ + "ėĕĕ", + "ĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĔ", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĖĖĖĕ", + "ĖĖĕ" + ], + [ + "ę", + "!" 
+ ], + [ + "ĖĖĖė", + "ĖĖė" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "ĕĖĖĖ" + ], + [ + "Ę", + "ĘĖĕ" + ], + [ + "ėĕ", + "ĖĕĖĖė" + ], + [ + "Ėĕĕ", + "ĖĖĕ" + ], + [ + "ėĕ", + "ėĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĕ", + "ĖĖė" + ], + [ + "Ę", + "Ė" + ], + [ + "ĖĖĖė", + "ĕ" + ], + [ + "ėą", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĕĖĖĕĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "Ę", + "ĖĖĖĖ" + ], + [ + "Ĕ", + "Ā" + ], + [ + "ĖĖĖĖĖĕ", + "ĖĖė" + ], + [ + "Ėė", + "Ĕ" + ], + [ + "Ĕ", + "Ĝ" + ], + [ + "ĖĖĖĖėė", + "ĖĖ" + ], + [ + "ėĕ", + "ĖĖĖĖĖĖĖė" + ], + [ + "ę", + "ĖĖĕ" + ], + [ + "Ā", + "đ" + ], + [ + "ėĕ", + "ĖĖĖĖĖĖĖĕ" + ], + [ + "ėĖĖĖĖĖė", + "ĖĖĖĖĖĖ" + ], + [ + "ėĘ", + "Ėė" + ], + [ + "ĖĖĖĖĖė", + "ĖĖė" + ], + [ + "ę", + "\"" + ], + [ + "ĕ", + "%" + ], + [ + "ē", + "ĖĖĔ" + ], + [ + "ę", + "Ėę" + ], + [ + "ėė", + "ĖĔ" + ], + [ + "ėĖĖĕ", + "ĕ" + ], + [ + "ĖĖĖĖė", + "ĖĖĖĖĖ" + ], + [ + "ĖĖĘ", + "ĖĖĘ" + ], + [ + "ėĂ", + "ėĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĖĖĖĖ" + ], + [ + "Ę", + "ĔĖė" + ], + [ + "ĕĖĕ", + "ĖĖė" + ], + [ + "Ĕ", + "Đ" + ], + [ + "ėĖĖĕ", + "ĖĖė" + ], + [ + "ėĂ", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĕ", + "ĖĖĖĖĕ" + ], + [ + "ėĀ", + "ėĕĖė" + ], + [ + "Ā", + "ď" + ], + [ + "ĖĖĖĕ", + "ĖĘ" + ], + [ + "ėė", + "Ĕ" + ], + [ + "Ęĕ", + "ĖĘ" + ], + [ + "ĕ", + "Ėđ" + ], + [ + "ē", + "ĖĚ" + ], + [ + "ėĀ", + "ĕĕ" + ], + [ + "ē", + "Ē" + ], + [ + "Ėĕ", + "ĖĖĖĖĖĖĖė" + ], + [ + "Ėĕ", + "ĖĖĖĖĖĖĖĕ" + ], + [ + "ĖĖĖĖ", + "ĖĖėĖĕ" + ], + [ + "ę", + "Ę" + ], + [ + "ėĖĕ", + "ėĖė" + ], + [ + "Ę", + "ėĖĖė" + ], + [ + "Ėĕ", + "ėĖĖĖĖ" + ], + [ + "ĖĖĖė", + "ĖĔ" + ], + [ + "ĖĖĖĖĕ", + "ėĖĖ" + ], + [ + "ĖĖĕĖĖĕ", + "ĖĖĖ" + ], + [ + "Ĕ", + "ĝ" + ], + [ + "Ėĕ", + "ėĖĖ" + ], + [ + "ĕ", + "&" + ], + [ + "ė", + "ėĕĖė" + ], + [ + "ėĖĖė", + "ĖĖėĖĖ" + ], + [ + "ėė", + "ĖĖĕ" + ], + [ + "Ā", + "Ē" + ], + [ + "ĕ", + "ĖĖĖĕ" + ], + [ + "ę", + "#" + ], + [ + "Ę", + "Ėĕĕ" + ], + [ + "ėĖ", + "Ĝ" + ], + [ + "ĖĕĖĖĖĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖĕ" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĖ" + ], + [ + "ĖĖĔ", + "ĖĘ" + ], + [ + "ĖėĕĖė", + "ĖĖĖ" + ], + [ + "ė", + "ėĕĕ" + 
], + [ + "ėĀ", + "ėĕĖĖĕ" + ], + [ + "Ĕ", + "Ė" + ], + [ + "ĖĖė", + "ĖĖĖĖĕ" + ], + [ + "Ę", + "Ĕĕ" + ], + [ + "ĖĖĖĖĖĖĖė", + "ĖĖėĖĖ" + ], + [ + "ėĖĖĖĖ", + "ĖĖĖĕ" + ], + [ + "Ėĕ", + "ē" + ], + [ + "ėĆ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ě", + "Ě" + ], + [ + "Ĕ", + "Ğ" + ], + [ + "ĕ", + "'" + ], + [ + "ĖĖĖĖ", + "ĖĖĕĖė" + ], + [ + "Ĕ", + "ĕĖĕ" + ], + [ + "ę", + "$" + ], + [ + "ėĖĖĖĖ", + "ĖĖĖė" + ], + [ + "ĖĖĖĕ", + "ĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ĕ", + "ĕĖė" + ], + [ + "ĖĖĖĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "ė", + "*" + ], + [ + "Ę", + "Ėđ" + ], + [ + "ėĕ", + "ėĕ" + ], + [ + "Ėĕ", + "ėĖė" + ], + [ + "ĖĖė", + "ėĖĕ" + ], + [ + "Ę", + "ė" + ], + [ + "ĖĕĖė", + "ĖĖĕ" + ], + [ + "Ā", + "Ď" + ], + [ + "Ĕ", + "Ėě" + ], + [ + "ėĖĖĖĖĖĕ", + "ĖĖĖĖĖĖ" + ], + [ + "ĕĖĕ", + "ĖĖĖ" + ], + [ + "ĖĖĕĖė", + "ĖĖĕ" + ], + [ + "ėĕ", + "ĖĖĖĖĖĖĖėĖĖĖ" + ], + [ + "ē", + "," + ], + [ + "ėĖĖĖĖ", + "ĖĖėĖĖėĖĖ" + ], + [ + "ē", + "ĕĕ" + ], + [ + "ĖĖĖĖė", + "ĖĖĖĖĖĖ" + ], + [ + "ĖĖėĖĖė", + "ĖĖĖĖĖ" + ], + [ + "ĕ", + "ĔĖė" + ], + [ + "Ĕ", + "ď" + ], + [ + "Ě", + "ě" + ], + [ + "Ėĕ", + "ĖĔ" + ], + [ + "ĖĖĔ", + "ĖĖĔ" + ], + [ + "ĖĖĖĖĖĖė", + "ĖĖĖĖĖĖ" + ], + [ + "Ėĕ", + "ĖĖĖĖĖĖĖėĖĖĖ" + ], + [ + "Ę", + "ĔĖĕ" + ], + [ + "ĖĕĖĖĖĖĖė", + "ĖĖĖĖĖ" + ], + [ + "Ĕ", + "ğ" + ], + [ + "Ě", + "Ā" + ], + [ + "Ėē", + "Ėė" + ], + [ + "Ę", + "ĖĖėĖĖ" + ], + [ + "ėĖĕ", + "ĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĖĖ", + "ĖĖė" + ], + [ + "Ē", + "Ėę" + ], + [ + "ę", + "ĖĖĘ" + ], + [ + "Ę", + "ĖĖėĖĕ" + ], + [ + "Ē", + "ĖĔ" + ], + [ + "ĖĖėĖė", + "ėĖĖ" + ], + [ + "ėĖĖĕ", + "ĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĕĖė", + "ĖĖė" + ], + [ + "ėė", + "Ė" + ], + [ + "ėĕĖĖĕĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "ėĕ", + "ĖĖĖĖĖĖĖĕĖĖĖ" + ], + [ + "Ėĕ", + "ĖĖĖĖĖĖĖĕĖĖĖ" + ], + [ + "ĖĖĖĖĖĖĕ", + "ĖĖĖ" + ], + [ + "ėă", + "ėĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĖĖĕ", + "Ėė" + ], + [ + "ę", + "ĕĖė" + ], + [ + "ėē", + "Ėė" + ], + [ + "ėĕ", + "ĖĖĖĖ" + ], + [ + "Ě", + "đ" + ], + [ + "ĖĘ", + "Ėė" + ], + [ + "Ė", + "*" + ], + [ + "ĕ", + "ĖĖĖĖĕ" + ], + [ + "ĕ", + "(" + ], + [ + "ĖĖĘ", + "ĖĖĖ" + ], + [ + "ĖĖėĖĖė", + "ĖĖĖ" + ], + 
[ + "Ę", + "ĖĖĖĖĖ" + ], + [ + "ėĕ", + "Ė" + ], + [ + "Ĕ", + "ĖĖĕĖĖ" + ], + [ + "ėĕ", + "ėĖĖĖĖ" + ], + [ + "ĖĕĖĖĖĖĕ", + "ĖĖĖĖĖĖ" + ], + [ + "Ėėĕ", + "ĖĖĕ" + ], + [ + "ĖĖĖĖĖĖė", + "ĖĖĖ" + ], + [ + "Ě", + "Ē" + ], + [ + "ĖĖę", + "Ėĕ" + ], + [ + "ĕĕ", + "ĕĕ" + ], + [ + "Ĕ", + "ĖĖē" + ], + [ + "Ĕ", + "Ę" + ], + [ + "ĖĖ", + "ď" + ], + [ + "ĖĔ", + "Ėĕĕ" + ], + [ + "ę", + "%" + ], + [ + "ĖĖĔ", + "ĖĖĖ" + ], + [ + "ĖĖ", + "ĝ" + ], + [ + "ĕ", + "ĖĖĘ" + ], + [ + "ĕ", + "ĖĖĖĖĖĕ" + ], + [ + "Ėĕ", + "ĖĖĖĖĕĕ" + ], + [ + "ę", + "ĖĒ" + ], + [ + "ėć", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĕ", + "ėĘ" + ], + [ + "Ĕ", + "Ď" + ], + [ + "Ėė", + "ĖĔ" + ], + [ + "ĖĕĖė", + "ĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĕ", + "ĖĖ" + ], + [ + "ėĖĖĖė", + "ĖĖė" + ], + [ + "Ě", + "," + ], + [ + "Ě", + "Ą" + ], + [ + "ĖĖĖĖė", + "ĖĖėĖĖĖĖĖ" + ], + [ + "ę", + "ėĕ" + ], + [ + "ĖĘ", + "ĖĖĖ" + ], + [ + "ėė", + "ĖĖ" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "Ě", + "ă" + ], + [ + "ėĕĖĖĖĖĖė", + "ĖĖĖĖĖ" + ], + [ + "ėė", + "ĖĖĘ" + ], + [ + "ĖĔ", + "ĖĖĕĖĖĕ" + ], + [ + "Ě", + "ĖĘ" + ], + [ + "ėė", + "ĖĘ" + ], + [ + "ĕĔ", + "ĖĖĕ" + ], + [ + "ĖĔ", + "ĖĖĖ" + ], + [ + "ėĖĖĖė", + "Ėĕ" + ], + [ + "ĕ", + ")" + ], + [ + "Ę", + "ĖĖėĖė" + ], + [ + "Ě", + "Ĝ" + ], + [ + "ē", + "đ" + ], + [ + "ĖĖĖĕ", + "ĖėĖĖĖ" + ], + [ + "ĖĖĖĖĕ", + "Ėė" + ], + [ + ",", + "ď" + ], + [ + "Ĕ", + "!" 
+ ], + [ + "ĘĖĕ", + "ėĖĕ" + ], + [ + "ĖĖĖĖĖĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "ėĂ", + "ėĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĕĖĖĕ", + "ĖĖė" + ], + [ + "ĖĖĖĖĖĖė", + "ĖĖĖĖĖ" + ], + [ + "Ě", + "Đ" + ], + [ + "ĕ", + "ĖĜ" + ], + [ + "Ę", + "ĖĖĖĖĖė" + ], + [ + "ėĖ", + "ď" + ], + [ + "Ėė", + "ėĖĖ" + ], + [ + "ĕĕ", + "ĖĘ" + ], + [ + "ĖĖĕ", + "ĖĖĖĖĖĖ" + ], + [ + "ĕĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "Ėė", + "ĖĖĘ" + ], + [ + "ėĖĖĕ", + "Ėĕ" + ], + [ + "ėĖĖĘ", + "Ėĕ" + ], + [ + "ėĖĖĖĖĖĖĖĖ", + "Ė" + ], + [ + "ėĕĖĖĖĖė", + "ĖĖĖĖĖĖ" + ], + [ + "ėĖĖĖė", + "ĖĖėĖĖĖĖĖ" + ], + [ + "ėĀ", + "Ęĕ" + ], + [ + "Ĕ", + "ėĘ" + ], + [ + "ĖĖĖĖĖė", + "ĖĖėĖĖ" + ], + [ + "Ĕĕ", + "ĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + ",", + "Ĉ" + ], + [ + "ē", + "Ėėĕ" + ], + [ + "Ėė", + "ėĖė" + ], + [ + "Ĕ", + "ĖĖĕĖė" + ], + [ + "ĖĖĕĖĖė", + "ĖĖĖ" + ], + [ + "ĕĕ", + "Ėĕĕ" + ], + [ + "ėĈ", + "ĖĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĔ", + "ĖĘ" + ], + [ + "ėė", + "ėė" + ], + [ + "ėĖĖĖĖ", + "ĖĖĕĖĖĕĖĖ" + ], + [ + "ę", + "&" + ], + [ + "ĖĖĔ", + "ĕ" + ], + [ + "ėĕĖė", + "ĖĖė" + ], + [ + "Ėĕ", + "ĖĕĖĖė" + ], + [ + "Ĕ", + "ĖĖĖĖĖ" + ], + [ + "ėĕĖĖĖĖĖĕ", + "ĖĖĕĖĖ" + ], + [ + "Ĕ", + "Ėđ" + ], + [ + "ĖĕĖĖĕ", + "ĖĖĖĖĖĕĖĖ" + ], + [ + "ĖĖĖė", + "ĖĘ" + ], + [ + "ėĄ", + "ėĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ĕ", + "\"" + ], + [ + "ĖĖĔ", + "ĖĖĖĖĖ" + ], + [ + "ē", + "ě" + ], + [ + "ĕĕ", + "Ė" + ], + [ + "ĖĖĖĖ", + "ĖĖėĖė" + ], + [ + "Ėē", + "ĖĖĕ" + ], + [ + "đ", + "Ėė" + ], + [ + "ĖĖĖĕ", + "ĖĖĖĖĕ" + ], + [ + "Ĕ", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĔ", + "ĖĖĕ" + ], + [ + "ĖĖĖė", + "ĖĕĖĖĖ" + ], + [ + "Ėĕ", + "ėĕ" + ], + [ + "ėĕĖė", + "ĖĖĖĖĖĖĖĖĖ" + ], + [ + "ėĖĕ", + "ĕĖė" + ], + [ + "ėĀ", + "ĖĕĖĕ" + ], + [ + "ĕĖĕ", + "ĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ĕĕ", + "ĖĖĕ" + ], + [ + "ėĔ", + "Ėĕĕ" + ], + [ + "ĖĖėĖĖė", + "ĖĖėĖĖ" + ], + [ + "Ĕ", + "ā" + ], + [ + "Ę", + "Ėě" + ], + [ + "Ę", + "ėĔ" + ], + [ + "Ě", + "ą" + ], + [ + "ĖĖĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "Ėĕ", + "ĖĕĖĖĕ" + ], + [ + "Ĕ", + "#" + ], + [ + "ėĖĖĖĖĖė", + "ĖĖ" + ], + [ + "Ėė", + "ĖĖėĖĕ" + ], + [ + "ĕĕ", + "Ĕ" + ], + [ + "ė", + "+" + ], + [ + "ĖĕĖĖė", + "ĖĖĖ" + ], + 
[ + "Ě", + "ď" + ], + [ + "ĕĕ", + "Ėėĕ" + ], + [ + "ĕĕ", + "ĖĖė" + ], + [ + "ĔĖė", + "ĖĖĖ" + ], + [ + "ĕ", + "ĖĖĕĖĖ" + ], + [ + "Ě", + "Ėē" + ], + [ + "ĕĕ", + "Ę" + ], + [ + "Ě", + "ā" + ], + [ + "ĖĖė", + "ėĖĖ" + ], + [ + "ėă", + "ĖĕĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĕ", + "ĖĔ" + ], + [ + "ę", + "ėė" + ], + [ + "Ě", + "ĝ" + ], + [ + "ĖĖĕ", + "ĕĖĖ" + ], + [ + "ę", + "ĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "Ĕ", + "ėĖĕ" + ], + [ + "ėĕ", + "ē" + ], + [ + "ĖĖĕ", + "ĕĖė" + ], + [ + "ėĕ", + "ĖĕĖĖĕ" + ], + [ + "ĖĖė", + "ĕĖė" + ], + [ + "Ę", + "ĘĖė" + ], + [ + "ę", + "ĖĖę" + ], + [ + "ĖĖĕ", + "ėĖĖ" + ], + [ + "Ē", + "Ē" + ], + [ + "ėĀ", + "ĖĖĖĖĖĖĕ" + ], + [ + "ėą", + "ėĖĖĖĖĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖė", + "ĖĖĖĕ" + ], + [ + "ĖĖĖĖ", + "ĖĖĕĖĕ" + ], + [ + "ĕ", + "Ĕĕ" + ], + [ + "ę", + "ėĖĕ" + ], + [ + "ě", + "Ėĕ" + ], + [ + "Ĕ", + "$" + ], + [ + "ĕĖĖĕ", + "ĖĖĖĖĖĖĖĖ" + ], + [ + "ĖĖĖĖĕĕ", + "Ėė" + ], + [ + "Ėĕĕ", + "ĖĖĖĖė" + ], + [ + "ę", + "'" + ], + [ + "Ę", + "*" + ], + [ + "Ě", + "Ď" + ], + [ + "ĖĔ", + "ĖĘ" + ], + [ + "ĖĖĖĖĖĖĕ", + "ĖĖĖĖĖ" + ], + [ + "Ė", + "+" + ], + [ + "Ā", + "Ĉ" + ], + [ + "ĖĖė", + "ĕĖĖ" + ] + ] + } +} \ No newline at end of file diff --git a/scenestreamer/tokenization/0305_fast_veh_5000000/tokenizer_config.json b/scenestreamer/tokenization/0305_fast_veh_5000000/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..44b81ebde2224b4e3935b02872938beae622c37c --- /dev/null +++ b/scenestreamer/tokenization/0305_fast_veh_5000000/tokenizer_config.json @@ -0,0 +1,8 @@ +{ + "added_tokens_decoder": {}, + "clean_up_tokenization_spaces": false, + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "processor_class": "UniversalActionProcessor", + "tokenizer_class": "PreTrainedTokenizer" +} diff --git a/scenestreamer/tokenization/__init__.py b/scenestreamer/tokenization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6ef260ba6fb51ad1904f09faf855423eab24e3 --- /dev/null +++ 
b/scenestreamer/tokenization/__init__.py @@ -0,0 +1,36 @@ +from scenestreamer.tokenization.biycle_tokenizer import BicycleModelTokenizerFixed0124 +from scenestreamer.tokenization.diffusion_tokenizer import DiffusionTokenizer, SPECIAL_INVALID, SPECIAL_START, SPECIAL_VALID +from scenestreamer.tokenization.motion_tokenizers import DeltaDeltaTokenizer, START_ACTION, END_ACTION, DeltaTokenizer + + +def get_tokenizer(config): + if config.USE_DIFFUSION: + from scenestreamer.tokenization.diffusion_tokenizer import DiffusionTokenizer + return DiffusionTokenizer(config) + + if config.TOKENIZATION.TOKENIZATION_METHOD == "delta": + return DeltaTokenizer(config) + elif config.TOKENIZATION.TOKENIZATION_METHOD == "delta_delta": + return DeltaDeltaTokenizer(config) + elif config.TOKENIZATION.TOKENIZATION_METHOD == "precomputed_delta_delta": + return PrecomputedDeltaDeltaTokenizer(config) + elif config.TOKENIZATION.TOKENIZATION_METHOD == "bicycle": + raise ValueError() + return BicycleModelTokenizer(config) + elif config.TOKENIZATION.TOKENIZATION_METHOD == "bicycle_noavg": + raise ValueError() + return BicycleModelTokenizerNoAVG(config) + elif config.TOKENIZATION.TOKENIZATION_METHOD == "bicycle_interpolated": + return BicycleModelInterpolatedTokenizer(config) + elif config.TOKENIZATION.TOKENIZATION_METHOD == "BicycleModelTokenizerFixed0124": + return BicycleModelTokenizerFixed0124(config) + elif config.TOKENIZATION.TOKENIZATION_METHOD == "fast": + from scenestreamer.tokenization.fast_tokenizer import FastTokenizer + return FastTokenizer(config) + else: + raise ValueError("Unknown tokenizer: {}".format(config.TOKENIZATION.TOKENIZATION_METHOD)) + + +def get_action_dim(config): + t = get_tokenizer(config) + return t.num_actions diff --git a/scenestreamer/tokenization/biycle_tokenizer.py b/scenestreamer/tokenization/biycle_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3bbac4d7374dc28bf4cae312af30e3f40ad06257 --- /dev/null +++ 
b/scenestreamer/tokenization/biycle_tokenizer.py @@ -0,0 +1,1300 @@ +import numpy as np +import torch + +from scenestreamer.tokenization.motion_tokenizers import ( + DeltaDeltaTokenizer, + get_relative_velocity, + START_ACTION, + END_ACTION, + BaseTokenizer, + interpolate, + interpolate_heading, +) +from scenestreamer.utils import rotate +from scenestreamer.utils import utils + + +class BicycleModelTokenizerFixed0124(BaseTokenizer): + ACC_MAX = 10 # m/s2 + YAW_RATE_MAX = np.pi / 2 # Just set to < 90 deg otherwise the tan() function will be too large. + + def __init__(self, config): + super().__init__(config) + self.config = config + self.bin_centers = None + assert self.config.DELTA_POS_IS_VELOCITY + + ACC_MAX = self.ACC_MAX + YAW_RATE_MAX = self.YAW_RATE_MAX + print("BicycleModelTokenizer: ACC_MAX: ", ACC_MAX, "YAW_RATE_MAX: ", YAW_RATE_MAX) + + self.x_max = ACC_MAX + self.x_min = -self.x_max + self.y_max = YAW_RATE_MAX + self.y_min = -self.y_max + # assert self.y_max < np.pi / 2 + + self.num_bins = config.TOKENIZATION.NUM_BINS + # assert self.num_bins == 33 + self.num_actions = self.num_bins**2 + + self.acceleration_bins = torch.linspace(self.x_min, self.x_max, self.num_bins) + self.steering_bins = torch.linspace(self.y_min, self.y_max, self.num_bins) + + self.default_action = self.num_bins**2 // 2 + + a_grid, delta_grid = torch.meshgrid(self.acceleration_bins, self.steering_bins, indexing='ij') + a_grid = a_grid.flatten() # .to(device) # Shape: (num_bins^2,) + delta_grid = delta_grid.flatten() # .to(device) # Shape: (num_bins^2,) + + self.a_grid_flat = a_grid + self.delta_grid_flat = delta_grid + self.bin_centers_flat = torch.stack([a_grid, delta_grid], dim=-1).cpu().numpy() # Shape: (num_bins^2, 2) + + self.use_type_specific_bins = False + + num_bins = self.num_bins + # Create coordinate grid centered at (0,0) + y, x = np.ogrid[-(num_bins // 2):(num_bins + 1) // 2, -(num_bins // 2):(num_bins + 1) // 2] + # Calculate the distance from the center + 
dist_from_center = np.sqrt(x**2 + y**2) + # Normalize distances so that the center is -1 and edges are 0 + max_distance = dist_from_center.max() + min_val = 1e-5 + normalized_dist = ((dist_from_center / max_distance) - 1) * min_val + + # Flatten to get a (num_bins^2,) vector + self.noise = torch.from_numpy(normalized_dist.ravel()).reshape(1, num_bins * num_bins, 1) + + def tokenize(self, data_dict, backward_prediction=False, **kwargs): + """ + + Args: + data_dict: Input data + + Returns: + Discretized action in an int array with shape (num time steps for actions, num agents). + """ + + if backward_prediction: + return self._tokenize_backward_prediction(data_dict, **kwargs) + + # TODO: Hardcoded here... + if self.config.GPT_STYLE: + start_step = 0 + else: + start_step = 2 + + # ===== Hole Filling ===== + data_dict = self.hole_filling(data_dict) + + # ===== Get initial data ===== + # If we don't clone here, the following hole-filling code will overwrite raw data. + agent_pos = data_dict["decoder/agent_position"] # .clone() + agent_heading = data_dict["decoder/agent_heading"] # .clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"] # .clone() + agent_velocity = data_dict["decoder/agent_velocity"] # .clone() + agent_shape = data_dict["decoder/current_agent_shape"] # .clone() + agent_type = data_dict["decoder/agent_type"] # .clone() + B, T_full, N, _ = agent_pos.shape + # assert T_full == 91 + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + agent_pos_full = agent_pos.clone() + agent_heading_full = agent_heading.clone() + agent_velocity_full = agent_velocity.clone() + agent_valid_mask_full = agent_valid_mask.clone() + agent_pos = agent_pos[:, ::self.num_skipped_steps] + agent_heading = agent_heading[:, ::self.num_skipped_steps] + agent_valid_mask = agent_valid_mask[:, ::self.num_skipped_steps] + agent_velocity = agent_velocity[:, ::self.num_skipped_steps] + T_chunks = agent_pos.shape[1] + # assert T_chunks == 19 + + # ===== Build up some 
variables ===== + current_pos = agent_pos[:, start_step:start_step + 1, ..., :2] + current_heading = agent_heading[:, start_step:start_step + 1] + current_vel = agent_velocity[:, start_step:start_step + 1, ..., :2] + current_valid_mask = agent_valid_mask[:, start_step:start_step + 1] + + init_pos = current_pos.clone() + init_heading = current_heading.clone() + init_vel = current_vel.clone() + init_valid_mask = current_valid_mask.clone() + + assert self.config.DELTA_POS_IS_VELOCITY + init_delta = get_relative_velocity(current_vel, current_heading) + + # Select correct bins: + bin_centers = self.get_bin_centers(agent_type) + + target_action = [] + target_action_valid_mask = [] + reconstruction_list = [] + relative_delta_pos_list = [] + pos = [] + heading = [] + vel = [] + + # ===== Loop to reconstruct the scenario ===== + tokenization_state = None + for next_step in range(start_step + 1, T_chunks): + res = self._tokenize_a_step( + current_pos=current_pos, + current_heading=current_heading, + current_vel=current_vel, + current_valid_mask=current_valid_mask, + next_pos=agent_pos[:, next_step:next_step + 1, ..., :2], # (B, 1, N, 2) + next_heading=agent_heading[:, next_step:next_step + 1], # (B, 1, N) + next_valid_mask=agent_valid_mask[:, next_step:next_step + 1], # (B, 1, N) + next_velocity=agent_velocity[:, next_step:next_step + 1, ..., :2], # (B, 1, N, 2) + bin_centers=bin_centers, + add_noise=False, + topk=self.config.TOKENIZATION.NOISE_TOPK, + agent_shape=agent_shape, + agent_type=agent_type, + dt=self.dt, + tokenization_state=tokenization_state, + agent_pos_full=agent_pos_full[:, (next_step - 1) * + self.num_skipped_steps:next_step * self.num_skipped_steps + 1], + agent_heading_full=agent_heading_full[:, (next_step - 1) * + self.num_skipped_steps:next_step * self.num_skipped_steps + 1], + agent_velocity_full=agent_velocity_full[:, (next_step - 1) * + self.num_skipped_steps:next_step * self.num_skipped_steps + 1], + agent_valid_mask_full=agent_valid_mask_full[:, 
(next_step - 1) * + self.num_skipped_steps:next_step * self.num_skipped_steps + + 1], + ) + tokenization_state = res + + best_action = res["action"] + recon_next_pos = res["pos"] + recon_next_heading = res["heading"] + recon_next_vel = res["vel"] + recon_next_valid_mask = res["mask"] + recon_next_delta_pos = res["delta_pos"] # The input delta for next step. + + best_action = best_action.reshape(B, 1, N) + + # ===== Process the target action/valid mask ===== + target_action_valid_mask.append(recon_next_valid_mask.clone()) + target_action.append(best_action) + + # Some debug asserts + assert (best_action[recon_next_valid_mask] >= 0).all() + assert (best_action[~recon_next_valid_mask] == -1).all() + + # ===== Process the "current_xxx" for next step ===== + if self.config.GPT_STYLE: + assert self.config.TOKENIZATION.ALLOW_SKIP_STEP + if self.config.TOKENIZATION.ALLOW_SKIP_STEP: + # Use the next valid mask as the valid mask for next step. + # In contrast, if this flag is False, then we will use "next valid mask & if it's not removed" for next + # step. 
+ next_valid_mask = agent_valid_mask[:, next_step:next_step + 1] + newly_added = torch.logical_and(~recon_next_valid_mask, next_valid_mask) + if newly_added.any(): + recon_next_pos[newly_added] = agent_pos[:, next_step:next_step + 1, ..., :2][newly_added] + recon_next_heading[newly_added] = agent_heading[:, next_step:next_step + 1][newly_added] + recon_next_vel[newly_added] = agent_velocity[:, next_step:next_step + 1, ..., :2][newly_added] + + recon_next_delta_pos[newly_added] = get_relative_velocity( + vel=agent_velocity[:, next_step:next_step + 1, ..., :2][newly_added], + heading=agent_heading[:, next_step:next_step + 1][newly_added], + ) + recon_next_valid_mask[newly_added] = next_valid_mask[newly_added] + + relative_delta_pos_list.append(recon_next_delta_pos) + current_vel = recon_next_vel + current_heading = recon_next_heading + current_pos = recon_next_pos + current_valid_mask = recon_next_valid_mask + pos.append(current_pos.clone()) + heading.append(current_heading.clone()) + vel.append(current_vel.clone()) + + # ===== Postprocess and prepare the "start action" ===== + # In GPT style, some agents will be added in the middle of the scene. + # So we need to find out when they are in and add a start action before that step. + # In non-GPT style, we only need to prepare the start action for the first step. + target_actions = torch.cat(target_action, dim=1) # (B, T_skipped, N) + target_action_valid_mask = torch.cat(target_action_valid_mask, dim=1) # (B, T_skipped, N) + relative_delta_pos_list = torch.cat(relative_delta_pos_list, dim=1) # (B, T_skipped, N) + pos = torch.cat(pos, dim=1) + heading = torch.cat(heading, dim=1) + vel = torch.cat(vel, dim=1) + + pos = torch.cat([init_pos, pos], dim=1) + heading = torch.cat([init_heading, heading], dim=1) + vel = torch.cat([init_vel, vel], dim=1) + relative_delta_pos_list = torch.cat([init_delta, relative_delta_pos_list], dim=1) + + # If not in back prediction, what will be: + # 1. The first tokens in input_actions? 
START_ACTION + # 2. The last tokens in input_actions? Just the tokens at t=18 (t=85~90) + # 3. The first tokens in target_actions? The tokens at t=0 (t=0~5) for GPT and t=2 otherwise. + # 4. The last tokens in target_actions? All -1 because there is no GT for t=19 (t=90~95) + assert self.config.GPT_STYLE + # Search for the first step that has newly added agents + assert start_step == 0 + already_tokenized = init_valid_mask.clone() + start_action = torch.full_like(target_actions[:, :1], -1) + start_action[init_valid_mask] = START_ACTION + assert target_actions.shape[1] == T_chunks - 1 + input_action = torch.cat([start_action, target_actions], dim=1) + input_action_valid_mask = torch.cat([init_valid_mask, target_action_valid_mask], dim=1) + for next_step in range(start_step + 1, T_chunks): + next_valid_mask = agent_valid_mask[:, next_step:next_step + 1] + is_newly_added = torch.logical_and(~already_tokenized, next_valid_mask) + if is_newly_added.any(): + input_action[:, next_step:next_step + 1][is_newly_added] = START_ACTION + input_action_valid_mask[:, next_step:next_step + 1][is_newly_added] = \ + next_valid_mask[is_newly_added] + already_tokenized = torch.logical_or(already_tokenized, is_newly_added) + + target_actions = torch.cat([target_actions, target_actions.new_full((B, 1, N), -1)], dim=1) + target_action_valid_mask = torch.cat( + [target_action_valid_mask, target_action_valid_mask.new_zeros((B, 1, N))], dim=1 + ) + data_dict["in_backward_prediction"] = False + assert (agent_valid_mask[:, start_step:] >= target_action_valid_mask).all() + assert (agent_valid_mask[:, start_step + 1:] >= target_action_valid_mask[:, :-1]).all() + assert (agent_valid_mask[:, start_step:] >= input_action_valid_mask).all() + + # # Some debug asserts for backward prediction: + # assert (target_actions[:, :-1] == flipped_target_actions[:, :-1].flip(dims=[1])).all() + # minp = (input_action * (input_action != START_ACTION)) + # minp = minp * (input_action != -1) + # mfinp = 
(flipped_input_action * (flipped_input_action != END_ACTION)) + # mfinp = mfinp * (flipped_input_action != -1) + # assert (minp[:, 1:] == mfinp[:, 1:].flip(dims=[1])).all() + # assert (pos == flipped_pos.flip(dims=[1])).all() + # assert (heading == flipped_heading.flip(dims=[1])).all() + # assert (vel == flipped_vel.flip(dims=[1])).all() + + data_dict["decoder/target_action"] = target_actions + data_dict["decoder/target_action_valid_mask"] = target_action_valid_mask + data_dict["decoder/input_action"] = input_action + data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask + data_dict["decoder/modeled_agent_delta"] = relative_delta_pos_list + data_dict["decoder/modeled_agent_position"] = pos + data_dict["decoder/modeled_agent_heading"] = heading + data_dict["decoder/modeled_agent_velocity"] = vel + + # Debug: + # pos_diff = (pos - agent_pos[..., :2]).norm(dim=-1).numpy() + # heading_diff = utils.wrap_to_pi(heading - agent_heading).abs().numpy() + # vel_diff = (vel - agent_velocity[..., :2]).norm(dim=-1).numpy() + + # All input actions should be >0 + assert (input_action[input_action_valid_mask] >= 0).all() + assert (target_actions[target_action_valid_mask] >= 0).all() + assert (input_action[~input_action_valid_mask] == -1).all() + assert (target_actions[~target_action_valid_mask] == -1).all() + + return data_dict, {"reconstruction_list": reconstruction_list} + + def detokenize( + self, + data_dict, + interpolation=True, + detokenizing_gt=False, + backward_prediction=False, + flip_wrong_heading=False, + autoregressive_start_step=2, + **kwargs, + ): # actions, current_pos, current_vel, current_heading): + + if backward_prediction: + return self._detokenize_backward_prediction( + data_dict, interpolation=interpolation, detokenizing_gt=detokenizing_gt, **kwargs + ) + + # TODO: Hardcoded here... 
+ assert self.config.GPT_STYLE + start_step = 0 + + # ===== Get initial data ===== + agent_pos = data_dict["decoder/agent_position"].clone() + agent_heading = data_dict["decoder/agent_heading"].clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"].clone() + agent_velocity = data_dict["decoder/agent_velocity"].clone() + agent_shape = data_dict["decoder/current_agent_shape"].clone() + agent_type = data_dict["decoder/agent_type"].clone() + if detokenizing_gt: + target_action_valid_mask = data_dict["decoder/target_action_valid_mask"] + input_mask = data_dict["decoder/input_action_valid_mask"] + B, T_full, N, _ = agent_pos.shape + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + agent_pos = agent_pos[:, ::self.num_skipped_steps].clone() + agent_heading = agent_heading[:, ::self.num_skipped_steps] + agent_valid_mask = agent_valid_mask[:, ::self.num_skipped_steps] + agent_velocity = agent_velocity[:, ::self.num_skipped_steps] + # T_chunks = agent_pos.shape[1] + + # ===== Prepare some variables ===== + action = data_dict["decoder/output_action"] + T_actions = action.shape[1] + T_generated_chunks = T_actions + start_step + + current_pos = agent_pos[:, start_step:start_step + 1, ..., :2].clone() + current_heading = agent_heading[:, start_step:start_step + 1].clone() + current_vel = agent_velocity[:, start_step:start_step + 1, ..., :2].clone() + current_valid_mask = agent_valid_mask[:, start_step:start_step + 1].clone() + + if detokenizing_gt: + # Merge input mask with target mask + input_mask = input_mask & target_action_valid_mask + + reconstructed_pos_list = [current_pos.clone()] + reconstructed_heading_list = [current_heading.clone()] + reconstructed_vel_list = [current_vel.clone()] + + already_interpolated = False + reconstructed_pos_full_list = [current_pos.clone()] + reconstructed_heading_full_list = [current_heading.clone()] + reconstructed_vel_full_list = [current_vel.clone()] + + # Select correct bins: + bin_centers = 
self.get_bin_centers(agent_type) + + kwargs["detokenization_state"] = None + + for curr_step in range(T_generated_chunks): + if curr_step < start_step: + next_pos = agent_pos[:, curr_step + 1:curr_step + 2, ..., :2] + next_heading = agent_heading[:, curr_step + 1:curr_step + 2] + next_vel = agent_velocity[:, curr_step + 1:curr_step + 2, ..., :2] + next_valid_mask = agent_valid_mask[:, curr_step + 1:curr_step + 2] + + else: + # We assume that starting from start_step, the agent valid mask will not change. + action_step = curr_step - start_step + action_valid_mask_step = input_mask[:, action_step:action_step + 1] + + act = action[:, action_step:action_step + 1] + assert (act[action_valid_mask_step] != -1).all() + res = self._detokenize_a_step( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=action_valid_mask_step, + current_vel=current_vel, + action=act, + agent_shape=agent_shape, + agent_type=agent_type, + bin_centers=bin_centers, + dt=self.dt, + flip_wrong_heading=flip_wrong_heading, + **kwargs + ) + kwargs["detokenization_state"] = res + + next_pos, next_heading, next_vel = res["pos"], res["heading"], res["vel"] + assert "delta_pos" in res + next_pos = next_pos.reshape(B, 1, N, 2) + next_heading = next_heading.reshape(B, 1, N) + next_vel = next_vel.reshape(B, 1, N, 2) + next_valid_mask = current_valid_mask + + # ===== A special case: fill in the info for the agents added in next step ===== + # ===== Another special case: if you are detokenizing the raw tokenized data, you need to fill in + # the info for the agents added in the next step. 
===== + if (curr_step < autoregressive_start_step) or (detokenizing_gt and curr_step < T_generated_chunks - 1): + # Fill in the initial states of newly added agents + action_valid_mask_next_step = input_mask[:, action_step + 1:action_step + 2] + newly_added = torch.logical_and(~action_valid_mask_step, action_valid_mask_next_step) + next_pos[newly_added] = agent_pos[:, curr_step + 1:curr_step + 2, ..., :2][newly_added] + next_heading[newly_added] = agent_heading[:, curr_step + 1:curr_step + 2][newly_added] + next_vel[newly_added] = agent_velocity[:, curr_step + 1:curr_step + 2, ..., :2][newly_added] + next_valid_mask[newly_added] = action_valid_mask_next_step[newly_added] + if "reconstructed_position" in res: + # If some agents are added in the next step, the "last step" in reconstructed chunk + # aka the 5-th step in the chunk should be replaced by the GT states. + res["reconstructed_position"][-1][newly_added] = agent_pos[:, curr_step + 1:curr_step + 2, + ..., :2][newly_added] + res["reconstructed_heading"][-1][newly_added] = agent_heading[:, curr_step + 1:curr_step + + 2][newly_added] + res["reconstructed_velocity"][-1][newly_added] = agent_velocity[:, curr_step + 1:curr_step + 2, + ..., :2][newly_added] + + if "reconstructed_position" in res: + already_interpolated = True + reconstructed_pos_full_list.extend(res["reconstructed_position"]) + reconstructed_heading_full_list.extend(res["reconstructed_heading"]) + reconstructed_vel_full_list.extend(res["reconstructed_velocity"]) + + current_pos = next_pos + current_heading = next_heading + current_vel = next_vel + current_valid_mask = next_valid_mask + + reconstructed_pos_list.append(current_pos.clone()) + reconstructed_heading_list.append(current_heading.clone()) + reconstructed_vel_list.append(current_vel.clone()) + + reconstructed_pos = torch.cat(reconstructed_pos_list, dim=1) + reconstructed_heading = torch.cat(reconstructed_heading_list, dim=1) + reconstructed_vel = torch.cat(reconstructed_vel_list, dim=1) + + 
# Every input token has it's own position (before the action). + # As we have 19 tokens, and the last one token will lead us to a new place, + # So it's totally 20 positions. + assert reconstructed_pos.shape[1] == T_generated_chunks + 1 + assert input_mask.shape[1] == T_generated_chunks - start_step + + # Interpolation + if interpolation: + + if already_interpolated: + reconstructed_pos = torch.cat(reconstructed_pos_full_list, dim=1) + reconstructed_heading = torch.cat(reconstructed_heading_full_list, dim=1) + reconstructed_vel = torch.cat(reconstructed_vel_full_list, dim=1) + + else: + + new_reconstructed_pos = interpolate(reconstructed_pos, self.num_skipped_steps, remove_first_step=False) + assert (new_reconstructed_pos[:, ::5] == reconstructed_pos).all() + reconstructed_pos = new_reconstructed_pos + + reconstructed_heading = interpolate_heading( + reconstructed_heading, self.num_skipped_steps, remove_first_step=False + ) + reconstructed_vel = interpolate(reconstructed_vel, self.num_skipped_steps, remove_first_step=False) + + input_mask_augmented = torch.cat([agent_valid_mask[:, :start_step], input_mask], dim=1) + assert input_mask_augmented.shape[1] == T_generated_chunks + valid = input_mask_augmented + valid = valid.reshape(B, -1, 1, N).expand(-1, -1, self.num_skipped_steps, -1).reshape(B, -1, N) + valid = torch.cat([valid, input_mask[:, -1:]], dim=1) + reconstructed_valid_mask = valid + + # Mask out: + reconstructed_pos = reconstructed_pos * reconstructed_valid_mask.unsqueeze(-1) + reconstructed_vel = reconstructed_vel * reconstructed_valid_mask.unsqueeze(-1) + reconstructed_heading = reconstructed_heading * reconstructed_valid_mask + + # We ensure that the output must be 5*T_chunks+1 + assert reconstructed_pos.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_valid_mask.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_vel.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert 
reconstructed_heading.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + else: + reconstructed_valid_mask = input_mask + + data_dict["decoder/reconstructed_position"] = reconstructed_pos + data_dict["decoder/reconstructed_heading"] = reconstructed_heading + data_dict["decoder/reconstructed_velocity"] = reconstructed_vel + data_dict["decoder/reconstructed_valid_mask"] = reconstructed_valid_mask + + return data_dict + + def _tokenize_backward_prediction(self, data_dict, **kwargs): + start_step = 0 + + # ===== Hole Filling ===== + data_dict = self.hole_filling(data_dict) + + # ===== Get initial data ===== + # If we don't clone here, the following hole-filling code will overwrite raw data. + agent_pos = data_dict["decoder/agent_position"] # .clone() + agent_heading = data_dict["decoder/agent_heading"] # .clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"] # .clone() + agent_velocity = data_dict["decoder/agent_velocity"] # .clone() + agent_shape = data_dict["decoder/current_agent_shape"] # .clone() + agent_type = data_dict["decoder/agent_type"] # .clone() + B, T_full, N, _ = agent_pos.shape + # assert T_full == 91 + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + agent_pos = agent_pos[:, ::self.num_skipped_steps] + agent_heading = agent_heading[:, ::self.num_skipped_steps] + agent_valid_mask = agent_valid_mask[:, ::self.num_skipped_steps] + agent_velocity = agent_velocity[:, ::self.num_skipped_steps] + T_chunks = agent_pos.shape[1] + # assert T_chunks == 19 + + # ===== Build up some variables ===== + current_pos = agent_pos[:, -1:, ..., :2] + current_heading = agent_heading[:, -1:] + current_vel = agent_velocity[:, -1:, ..., :2] + current_valid_mask = agent_valid_mask[:, -1:] + + init_pos = current_pos.clone() + init_heading = current_heading.clone() + init_vel = current_vel.clone() + init_valid_mask = current_valid_mask.clone() + + init_delta = get_relative_velocity(current_vel, current_heading) + + # Select correct bins: + 
bin_centers = self.get_bin_centers(agent_type) + + target_action = [] + target_action_valid_mask = [] + reconstruction_list = [] + relative_delta_pos_list = [] + pos = [] + heading = [] + vel = [] + + # ===== Loop to reconstruct the scenario ===== + for backward_next_step in range(1, T_chunks): + # backward_next_step = 1, ..., 18 + + forward_next_step = T_chunks - backward_next_step - 1 + # forward_next_step = 17, ..., 0 + + res = self._tokenize_a_step( + current_pos=current_pos, + current_heading=current_heading, + current_vel=current_vel, + current_valid_mask=current_valid_mask, + next_pos=agent_pos[:, forward_next_step:forward_next_step + 1, ..., :2], # (B, 1, N, 2) + next_heading=agent_heading[:, forward_next_step:forward_next_step + 1], # (B, 1, N) + next_valid_mask=agent_valid_mask[:, forward_next_step:forward_next_step + 1], # (B, 1, N) + next_velocity=agent_velocity[:, forward_next_step:forward_next_step + 1, ..., :2], # (B, 1, N, 2) + bin_centers=bin_centers, + add_noise=False, + topk=self.config.TOKENIZATION.NOISE_TOPK, + agent_shape=agent_shape, + agent_type=agent_type, + dt=-self.dt, + **kwargs + ) + + best_action = res["action"] + recon_next_pos = res["pos"] + recon_next_heading = res["heading"] + recon_next_vel = res["vel"] + recon_next_valid_mask = res["mask"] + recon_next_delta_pos = res["delta_pos"] # The input delta for next step. + + best_action = best_action.reshape(B, 1, N) + + # ===== Process the target action/valid mask ===== + target_action_valid_mask.append(recon_next_valid_mask.clone()) + target_action.append(best_action) + + # Some debug asserts + assert (best_action[recon_next_valid_mask] >= 0).all() + assert (best_action[~recon_next_valid_mask] == -1).all() + + # ===== Process the "current_xxx" for next step ===== + # Use the next valid mask as the valid mask for next step. + # In contrast, if this flag is False, then we will use "next valid mask & if it's not removed" for next + # step. 
+ next_valid_mask = agent_valid_mask[:, forward_next_step:forward_next_step + 1] + newly_added = torch.logical_and(~recon_next_valid_mask, next_valid_mask) + if newly_added.any(): + recon_next_pos[newly_added] = agent_pos[:, forward_next_step:forward_next_step + 1, + ..., :2][newly_added] + recon_next_heading[newly_added] = agent_heading[:, forward_next_step:forward_next_step + 1][newly_added] + recon_next_vel[newly_added] = agent_velocity[:, forward_next_step:forward_next_step + 1, + ..., :2][newly_added] + recon_next_delta_pos[newly_added] = get_relative_velocity( + vel=agent_velocity[:, forward_next_step:forward_next_step + 1, ..., :2][newly_added], + heading=agent_heading[:, forward_next_step:forward_next_step + 1][newly_added], + ) + recon_next_valid_mask[newly_added] = next_valid_mask[newly_added] + + relative_delta_pos_list.append(recon_next_delta_pos) + current_vel = recon_next_vel + current_heading = recon_next_heading + current_pos = recon_next_pos + current_valid_mask = recon_next_valid_mask + pos.append(current_pos.clone()) + heading.append(current_heading.clone()) + vel.append(current_vel.clone()) + + # ===== Postprocess and prepare the "start action" ===== + # In GPT style, some agents will be added in the middle of the scene. + # So we need to find out when they are in and add a start action before that step. + # In non-GPT style, we only need to prepare the start action for the first step. 
+ target_actions = torch.cat(target_action, dim=1) # (B, T_skipped, N) + target_action_valid_mask = torch.cat(target_action_valid_mask, dim=1) # (B, T_skipped, N) + relative_delta_pos_list = torch.cat(relative_delta_pos_list, dim=1) # (B, T_skipped, N) + pos = torch.cat(pos, dim=1) + heading = torch.cat(heading, dim=1) + vel = torch.cat(vel, dim=1) + + pos = torch.cat([init_pos, pos], dim=1) + heading = torch.cat([init_heading, heading], dim=1) + vel = torch.cat([init_vel, vel], dim=1) + relative_delta_pos_list = torch.cat([init_delta, relative_delta_pos_list], dim=1) + + # Search for the first step that has newly added agents + assert start_step == 0 + already_tokenized = init_valid_mask.clone() + start_action = torch.full_like(target_actions[:, :1], -1) + start_action[init_valid_mask] = END_ACTION + assert target_actions.shape[1] == T_chunks - 1 + input_action = torch.cat([start_action, target_actions], dim=1) + input_action_valid_mask = torch.cat([init_valid_mask, target_action_valid_mask], dim=1) + for backward_next_step in range(1, T_chunks): + forward_next_step = T_chunks - backward_next_step - 1 + next_valid_mask = agent_valid_mask[:, forward_next_step:forward_next_step + 1] + is_newly_added = torch.logical_and(~already_tokenized, next_valid_mask) + if is_newly_added.any(): + input_action[:, backward_next_step:backward_next_step + 1][is_newly_added] = END_ACTION + input_action_valid_mask[:, backward_next_step:backward_next_step + 1][is_newly_added] = \ + next_valid_mask[is_newly_added] + already_tokenized = torch.logical_or(already_tokenized, is_newly_added) + + target_actions = torch.cat([target_actions, target_actions.new_full((B, 1, N), -1)], dim=1) + target_action_valid_mask = torch.cat( + [target_action_valid_mask, target_action_valid_mask.new_zeros((B, 1, N))], dim=1 + ) + data_dict["in_backward_prediction"] = True + + flipped_agent_valid_mask = agent_valid_mask.flip(dims=[1]) + assert (flipped_agent_valid_mask[:, start_step:] >= 
target_action_valid_mask).all() + assert (flipped_agent_valid_mask[:, start_step + 1:] >= target_action_valid_mask[:, :-1]).all() + assert (flipped_agent_valid_mask[:, start_step:] >= input_action_valid_mask).all() + + data_dict["decoder/target_action"] = target_actions + data_dict["decoder/target_action_valid_mask"] = target_action_valid_mask + data_dict["decoder/input_action"] = input_action + data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask + data_dict["decoder/modeled_agent_delta"] = relative_delta_pos_list + data_dict["decoder/modeled_agent_position"] = pos + data_dict["decoder/modeled_agent_heading"] = heading + data_dict["decoder/modeled_agent_velocity"] = vel + + # All input actions should be >0 + assert (input_action[input_action_valid_mask] >= 0).all() + assert (target_actions[target_action_valid_mask] >= 0).all() + assert (input_action[~input_action_valid_mask] == -1).all() + assert (target_actions[~target_action_valid_mask] == -1).all() + + return data_dict, {"reconstruction_list": reconstruction_list} + + def _tokenize_a_step_backward( + self, *, current_pos, current_heading, current_valid_mask, current_vel, next_pos, next_heading, next_valid_mask, + next_vel, add_noise, agent_shape, dt, **kwargs + ): + assert dt > 0 + + device = current_pos.device + B, T, N, _ = current_pos.shape + assert T == 1 + + # Prepare grid of accelerations and steering angles for batch processing + a_grid = self.a_grid_flat.to(device) # Shape: (num_bins^2,) + delta_grid = self.delta_grid_flat.to(device) # Shape: (num_bins^2,) + + num_candidates = a_grid.shape[0] + a_grid_exp = a_grid.view(1, -1, 1).expand(B, num_candidates, N) # (B, num_candidates, N, 1) + yaw_rate = delta_grid.view(1, -1, 1).expand(B, num_candidates, N) # (B, num_candidates, N, 1) + + # Repeat current states for batch computation + current_pos_exp = current_pos.expand(B, num_candidates, N, 2) # (B, num_candidates, N, 2) + current_heading_exp = current_heading.expand(B, num_candidates, N) 
# (B, num_candidates, N) + next_heading_exp = next_heading.expand(B, num_candidates, N) # (B, num_candidates, N) + + # Next speed: + next_speed = next_vel.norm(dim=-1).expand(B, num_candidates, N) + current_speed_candidate = next_speed - a_grid_exp * dt + average_speed = (current_speed_candidate + next_speed) / 2 + + # wheel_base = agent_shape[..., 0].reshape(B, 1, N) + # yaw_rate = (average_speed / wheel_base) * torch.tan(delta_grid_exp) + delta_theta = yaw_rate * dt + current_heading_candidate = utils.wrap_to_pi(next_heading_exp - delta_theta) + average_heading = utils.wrap_to_pi(utils.average_heading(current_heading_candidate, next_heading_exp)) + + average_velocity_candidate = rotate(average_speed, torch.zeros_like(average_speed), angle=average_heading) + current_velocity_candidate = rotate( + current_speed_candidate, torch.zeros_like(current_speed_candidate), angle=current_heading_candidate + ) + + current_pos_reconstructed = next_pos - average_velocity_candidate * dt + current_pos_reconstructed = current_pos_reconstructed.expand(B, num_candidates, N, 2) + + contour = utils.cal_polygon_contour_torch( + x=current_pos_reconstructed[..., 0], + y=current_pos_reconstructed[..., 1], + theta=current_heading_candidate, + width=agent_shape[..., 1].reshape(B, 1, N), + length=agent_shape[..., 0].reshape(B, 1, N) + ) + + gt_contour = utils.cal_polygon_contour_torch( + x=current_pos_exp[..., 0], + y=current_pos_exp[..., 1], + theta=current_heading_exp, + width=agent_shape[..., 1].reshape(B, 1, N), + length=agent_shape[..., 0].reshape(B, 1, N) + ) + + error_pos = torch.norm(contour - gt_contour, dim=-1).mean(-1) + error = error_pos # + error_heading + + error = error + self.noise + + if add_noise: + # Get top-k actions based on the error + candidates = error.topk(5, largest=False, dim=1).indices + best_action = torch.gather( + candidates, index=torch.randint(0, 5, size=(B, 1, N)).to(candidates.device), dim=1 + ).squeeze(1) + raise ValueError() + + else: + # Pick the best 
bin with the least error: + min_result = error.min(dim=1) + best_action = min_result.indices + + # Update reconstructed position and velocity according to the best action: + ind = best_action.reshape(B, 1, N, 1).expand(B, 1, N, 2).clone() + mask = ind == -1 + ind[mask] = self.default_action # Workaround the gather can't handle -1 + reconstructed_pos = torch.gather(current_pos_reconstructed, index=ind, dim=1) + reconstructed_vel = torch.gather(current_velocity_candidate, index=ind, dim=1) + + ind = best_action.reshape(B, 1, N).clone() + mask = ind == -1 + ind[mask] = self.default_action # Workaround the gather can't handle -1 + reconstructed_heading = torch.gather(current_heading_candidate, index=ind, dim=1) + + valid_mask = current_valid_mask & next_valid_mask + assert current_pos.shape == reconstructed_pos.shape + best_action = best_action.reshape(B, 1, N) + best_action[~valid_mask] = -1 + reconstructed_pos[~valid_mask] = 0 + reconstructed_vel[~valid_mask] = 0 + reconstructed_heading[~valid_mask] = 0 + assert (best_action[valid_mask] >= 0).all() + assert (best_action[~valid_mask] == -1).all() + assert self.num_bins == 33 + + # Just return the relative velocity. 
+ relative_delta_pos = get_relative_velocity(reconstructed_vel, reconstructed_heading) + relative_delta_pos[mask] = 0 + + # AID = 0 + # masked_best_action = best_action * valid_mask + # delta = torch.gather(delta_grid_exp, index=masked_best_action.reshape(B, 1, N), dim=1) + # acc = torch.gather(a_grid_exp, index=masked_best_action.reshape(B, 1, N), dim=1) + # action = best_action[0,0,AID] + # average_velocity_candidate = average_velocity_candidate[0,action,AID] + # current_velocity_candidate = current_velocity_candidate[0,action,AID] + # print( + # f"[TOK] AID{AID} CUR POS: {current_pos[0, 0, AID].cpu().numpy()}, " + # f"RECON POS: {reconstructed_pos[0, 0, AID].cpu().numpy()}, " + # f"NEXT POS: {next_pos[0, 0, AID].cpu().numpy()}, " + # f"Action: {best_action[0, 0, AID].cpu().numpy()}, " + # f"ACC: {acc[0, 0, AID].cpu().numpy()}, " + # f"STEER: {delta[0, 0, AID].cpu().numpy()}, " + # # f"CUR VEL: {current_vel[0, 0, AID].norm(dim=-1).cpu().numpy()}, " + # f"RECON VEL: {reconstructed_vel[0, 0, AID].norm(dim=-1).cpu().numpy()}, " + # f"VALID: {valid_mask[0, 0, AID].cpu().numpy()}", + # f"CUR VEL: {current_velocity_candidate.cpu().numpy()}", + # f"AVG VEL: {average_velocity_candidate.cpu().numpy()}", + # ) + + return dict( + action=best_action, + pos=reconstructed_pos, + heading=reconstructed_heading, + vel=reconstructed_vel, + mask=valid_mask, + delta_pos=relative_delta_pos + ) + + def _tokenize_a_step( + self, *, current_pos, current_heading, current_valid_mask, current_vel, next_pos, next_heading, next_valid_mask, + add_noise, agent_shape, dt, **kwargs + ): + + if dt < 0: + # TODO: This is a trick to handle the backward prediction. We flip current/next states and dt. + # Might cause confusion to other users. 
+ return self._tokenize_a_step_backward( + current_pos=next_pos, + current_heading=next_heading, + current_valid_mask=next_valid_mask, + current_vel=None, + next_vel=current_vel, + next_pos=current_pos, + next_heading=current_heading, + next_valid_mask=current_valid_mask, + add_noise=add_noise, + agent_shape=agent_shape, + dt=-dt, + **kwargs + # current_pos, current_heading, current_valid_mask, current_vel, next_pos, next_heading, + # next_valid_mask, + # add_noise, agent_shape, -dt, **kwargs + ) + + device = current_pos.device + B, T, N, _ = current_pos.shape + assert T == 1 + + # Prepare grid of accelerations and steering angles for batch processing + a_grid = self.a_grid_flat.to(device) # Shape: (num_bins^2,) + delta_grid = self.delta_grid_flat.to(device) # Shape: (num_bins^2,) + + num_candidates = a_grid.shape[0] + a_grid_exp = a_grid.view(1, -1, 1).expand(B, num_candidates, N) # (B, num_candidates, N, 1) + delta_grid_exp = delta_grid.view(1, -1, 1).expand(B, num_candidates, N) # (B, num_candidates, N, 1) + + # Repeat current states for batch computation + current_pos_exp = current_pos.expand(B, num_candidates, N, 2) # (B, num_candidates, N, 2) + current_heading_exp = current_heading.expand(B, num_candidates, N) # (B, num_candidates, N) + + # Current speed in local frame: + current_speed = current_vel.norm(dim=-1) # (B, num_candidates, N) + current_speed = current_speed.expand(B, num_candidates, N) # (B, num_candidates, N, 2) + next_speed_candidate = current_speed + a_grid_exp * self.dt + average_speed = (current_speed + next_speed_candidate) / 2 + + # wheel_base = agent_shape[..., 0].reshape(B, 1, N) + # yaw_rate = (average_speed / wheel_base) * torch.tan(delta_grid_exp) + yaw_rate = delta_grid_exp + delta_theta = yaw_rate * self.dt + next_heading_candidate = utils.wrap_to_pi(current_heading_exp + delta_theta) + average_heading = utils.wrap_to_pi(utils.average_heading(next_heading_candidate, current_heading_exp)) + + # Rotate velocity vector to update both v_x 
and v_y + next_velocity_candidate = rotate( + next_speed_candidate, torch.zeros_like(next_speed_candidate), angle=next_heading_candidate + ) # (B, num_candidates, N, 2) + + average_next_velocity = rotate(average_speed, torch.zeros_like(average_speed), angle=average_heading) + next_pos_candidate = current_pos_exp + average_next_velocity * self.dt + + no_displacement_mask = None + invalid_next_pos = None + + contour = utils.cal_polygon_contour_torch( + x=next_pos_candidate[..., 0], + y=next_pos_candidate[..., 1], + theta=next_heading_candidate, + width=agent_shape[..., 1].reshape(B, 1, N), + length=agent_shape[..., 0].reshape(B, 1, N) + ) + + gt_contour = utils.cal_polygon_contour_torch( + x=next_pos[..., 0], + y=next_pos[..., 1], + theta=next_heading, + width=agent_shape[..., 1].reshape(B, 1, N), + length=agent_shape[..., 0].reshape(B, 1, N) + ) + + error_pos = torch.norm(contour - gt_contour, dim=-1).mean(-1) + error = error_pos # + error_heading + + # Add the very small noise to break the tie! 
+ if self.noise.device != error.device: + self.noise = self.noise.to(error.device) + error = error + self.noise + + if add_noise: + # Get top-k actions based on the error + TOPK = 5 + candidates = error.topk(TOPK, largest=False, dim=1).indices + best_action = torch.gather( + candidates, index=torch.randint(0, TOPK, size=(B, 1, N)).to(candidates.device), dim=1 + ).squeeze(1) + raise ValueError() + else: + # Pick the best bin with the least error: + min_result = error.min(dim=1) + best_action = min_result.indices + + if no_displacement_mask is not None: + best_action[no_displacement_mask.squeeze(1)] = self.default_action + + if invalid_next_pos is not None: + best_action[invalid_next_pos.squeeze(1)] = self.default_action + + # Update reconstructed position and velocity according to the best action: + ind = best_action.reshape(B, 1, N, 1).expand(B, 1, N, 2).clone() + mask = ind == -1 + ind[mask] = 0 + reconstructed_pos = torch.gather(next_pos_candidate, index=ind, dim=1) + reconstructed_vel = torch.gather(next_velocity_candidate, index=ind, dim=1) + + ind = best_action.reshape(B, 1, N).clone() + mask = ind == -1 + ind[mask] = 0 + reconstructed_heading = torch.gather(next_heading_candidate, index=ind, dim=1) + + valid_mask = current_valid_mask & next_valid_mask + assert current_pos.shape == reconstructed_pos.shape + + if invalid_next_pos is not None: + reconstructed_heading[invalid_next_pos] = current_heading[invalid_next_pos] + + if no_displacement_mask is not None: + reconstructed_heading[no_displacement_mask] = current_heading[no_displacement_mask] + + best_action = best_action.reshape(B, 1, N) + best_action[~valid_mask] = -1 + reconstructed_pos[~valid_mask] = 0 + reconstructed_vel[~valid_mask] = 0 + reconstructed_heading[~valid_mask] = 0 + relative_delta_pos = get_relative_velocity(reconstructed_vel, reconstructed_heading) + relative_delta_pos[~valid_mask] = 0 + + assert (best_action[valid_mask] >= 0).all() + assert (best_action[~valid_mask] == -1).all() + + # AID 
= 0 + # masked_best_action = best_action * valid_mask + # delta = torch.gather(delta_grid_exp, index=masked_best_action.reshape(B, 1, N), dim=1) + # acc = torch.gather(a_grid_exp, index=masked_best_action.reshape(B, 1, N), dim=1) + # print( + # f"[TOK] AID{AID} CUR POS: {current_pos[0, 0, AID].cpu().numpy()}, " + # f"RECON POS: {reconstructed_pos[0, 0, AID].cpu().numpy()}, " + # f"GT POS: {next_pos[0, 0, AID].cpu().numpy()}, " + # f"Action: {best_action[0, 0, AID].cpu().numpy()}, " + # f"ACC: {acc[0, 0, AID].cpu().numpy()}, " + # f"STEER: {delta[0, 0, AID].cpu().numpy()}, " + # f"CUR VEL: {current_vel[0, 0, AID].norm(dim=-1).cpu().numpy()}, " + # f"RECON VEL: {reconstructed_vel[0, 0, AID].norm(dim=-1).cpu().numpy()}, " + # f"VALID: {valid_mask[0, 0, AID].cpu().numpy()}" + # ) + + return dict( + action=best_action, + pos=reconstructed_pos, + heading=reconstructed_heading, + vel=reconstructed_vel, + mask=valid_mask, + delta_pos=relative_delta_pos + ) + + def _detokenize_backward_prediction( + self, + data_dict, + interpolation=True, + detokenizing_gt=False, + flip_wrong_heading=False, + ): # actions, current_pos, current_vel, current_heading): + """ + Compared to the non-gpt style, this function dynamically adds new agents into the scene. + A very interesting point here is we can't start with 'current position' in the data. + Because the model is predicting according to the first few tokens, which already have some errors. + """ + # TODO: Hardcoded here... 
+ assert self.config.GPT_STYLE + start_step = 0 + # autoregressive_start_step = 2 + + # ===== Get initial data ===== + agent_pos = data_dict["decoder/agent_position"].clone() + agent_heading = data_dict["decoder/agent_heading"].clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"].clone() + agent_velocity = data_dict["decoder/agent_velocity"].clone() + agent_shape = data_dict["decoder/current_agent_shape"].clone() + agent_type = data_dict["decoder/agent_type"].clone() + target_action_valid_mask = data_dict["decoder/target_action_valid_mask"] + input_mask = data_dict["decoder/input_action_valid_mask"] + B, T_full, N, _ = agent_pos.shape + assert T_full == 91 # TODO: hardcoded + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + agent_pos = agent_pos[:, ::self.num_skipped_steps] + agent_heading = agent_heading[:, ::self.num_skipped_steps] + agent_valid_mask = agent_valid_mask[:, ::self.num_skipped_steps] + agent_velocity = agent_velocity[:, ::self.num_skipped_steps] + T_chunks = agent_pos.shape[1] + assert T_chunks == 19 # TODO: hardcoded + + # ===== Prepare some variables ===== + action = data_dict["decoder/output_action"] + T_actions = action.shape[1] + if T_actions + start_step != T_chunks: + print( + "WARNING: The number of actions is not consistent with the number of raw data chunks! You have {} actions, start step is {} and the number of chunks is {}." 
+ .format(T_actions, start_step, T_chunks) + ) + T_generated_chunks = T_actions + start_step + + current_pos = agent_pos[:, -1:, ..., :2] + current_heading = agent_heading[:, -1:] + current_vel = agent_velocity[:, -1:, ..., :2] + current_valid_mask = agent_valid_mask[:, -1:] + + if detokenizing_gt: + # Merge input mask with target mask + input_mask = input_mask & target_action_valid_mask + + reconstructed_pos_list = [current_pos.clone()] + reconstructed_heading_list = [current_heading.clone()] + reconstructed_vel_list = [current_vel.clone()] + + # Select correct bins: + bin_centers = self.get_bin_centers(agent_type) + + for curr_backward_step in range(T_generated_chunks): + # curr_backward_step = 0, 1, ..., 18 + + curr_forward_step = T_chunks - curr_backward_step - 1 + # curr_forward_step = 18, 17, ..., 0 + + next_forward_step = curr_forward_step - 1 + # next_forward_step = 17, 16, ..., -1 + + action_valid_mask_step = input_mask[:, curr_backward_step:curr_backward_step + 1] + act = action[:, curr_backward_step:curr_backward_step + 1] + assert (act[action_valid_mask_step] != -1).all() + res = self._detokenize_a_step( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=action_valid_mask_step, + current_vel=current_vel, + action=act, + agent_shape=agent_shape, + agent_type=agent_type, + bin_centers=bin_centers, + dt=-self.dt, + flip_wrong_heading=flip_wrong_heading, + ) + next_pos, next_heading, next_vel = res["pos"], res["heading"], res["vel"] + next_pos = next_pos.reshape(B, 1, N, 2) + next_heading = next_heading.reshape(B, 1, N) + next_vel = next_vel.reshape(B, 1, N, 2) + next_valid_mask = current_valid_mask + + # if detokenizing_gt and curr_backward_step < T_generated_chunks - 1: + # TODO: Here the detokenizing_gt is ignored and we always add new agents in. 
+ if curr_backward_step < T_generated_chunks - 1: + # Fill in the initial states of newly added agents + action_valid_mask_next_step = input_mask[:, curr_backward_step + 1:curr_backward_step + 2] + newly_added = torch.logical_and(~action_valid_mask_step, action_valid_mask_next_step) + next_pos[newly_added] = agent_pos[:, next_forward_step:next_forward_step + 1, ..., :2][newly_added] + next_heading[newly_added] = agent_heading[:, next_forward_step:next_forward_step + 1][newly_added] + next_vel[newly_added] = agent_velocity[:, next_forward_step:next_forward_step + 1, ..., :2][newly_added] + next_valid_mask[newly_added] = action_valid_mask_next_step[newly_added] + + current_pos = next_pos + current_heading = next_heading + current_vel = next_vel + current_valid_mask = next_valid_mask + + reconstructed_pos_list.append(current_pos.clone()) + reconstructed_heading_list.append(current_heading.clone()) + reconstructed_vel_list.append(current_vel.clone()) + + reconstructed_pos = torch.cat(reconstructed_pos_list, dim=1) + reconstructed_heading = torch.cat(reconstructed_heading_list, dim=1) + reconstructed_vel = torch.cat(reconstructed_vel_list, dim=1) + + # Every input token has it's own position (before the action). + # As we have 19 tokens, and the last one token will lead us to a new place, + # So it's totally 20 positions. + assert reconstructed_pos.shape[1] == T_generated_chunks + 1 + assert input_mask.shape[1] == T_generated_chunks - start_step + + # TODO: Not sure if we should return flipped data or not. 
+ reconstructed_pos = reconstructed_pos.flip(dims=[1]) + reconstructed_heading = reconstructed_heading.flip(dims=[1]) + reconstructed_vel = reconstructed_vel.flip(dims=[1]) + input_mask = input_mask.flip(dims=[1]) + + # Interpolation + if interpolation: + new_reconstructed_pos = interpolate(reconstructed_pos, self.num_skipped_steps, remove_first_step=False) + assert (new_reconstructed_pos[:, ::5] == reconstructed_pos).all() + reconstructed_pos = new_reconstructed_pos + + reconstructed_heading = interpolate_heading( + reconstructed_heading, self.num_skipped_steps, remove_first_step=False + ) + reconstructed_vel = interpolate(reconstructed_vel, self.num_skipped_steps, remove_first_step=False) + + # input_mask_augmented = torch.cat([agent_valid_mask[:, :start_step], input_mask], dim=1) + input_mask_augmented = input_mask + assert input_mask_augmented.shape[1] == T_generated_chunks + valid = input_mask_augmented + valid = valid.reshape(B, -1, 1, N).expand(-1, -1, self.num_skipped_steps, -1).reshape(B, -1, N) + valid = torch.cat([valid, input_mask[:, -1:]], dim=1) + reconstructed_valid_mask = valid + + # Mask out: + reconstructed_pos = reconstructed_pos * reconstructed_valid_mask.unsqueeze(-1) + reconstructed_vel = reconstructed_vel * reconstructed_valid_mask.unsqueeze(-1) + reconstructed_heading = reconstructed_heading * reconstructed_valid_mask + + # We ensure that the output must be 5*T_chunks+1 + assert reconstructed_pos.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_valid_mask.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_vel.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_heading.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + else: + reconstructed_valid_mask = input_mask + + data_dict["decoder/reconstructed_position"] = reconstructed_pos + data_dict["decoder/reconstructed_heading"] = reconstructed_heading + 
data_dict["decoder/reconstructed_velocity"] = reconstructed_vel + data_dict["decoder/reconstructed_valid_mask"] = reconstructed_valid_mask + + return data_dict + + def _detokenize_a_step_backward( + self, *, current_pos, current_heading, current_valid_mask, current_vel, action, agent_shape, dt, **kwargs + ): + assert dt > 0 + + assert action.ndim == 3 + B, T_action, N = action.shape + assert T_action == 1 + # Retrieve acceleration and steering angle based on decoded bins + if self.acceleration_bins.device != action.device: + self.acceleration_bins = self.acceleration_bins.to(action.device) + self.steering_bins = self.steering_bins.to(action.device) + + action_expanded = action.reshape(B, T_action, N, 1).expand(B, T_action, N, 1).clone() + # TODO: This line is wrong. Some invalid action will have non-neg1 value. + mask = (action_expanded == -1) | (action_expanded == START_ACTION) | (action_expanded == END_ACTION) + action_expanded[mask] = 0 + + acceleration_bins = self.acceleration_bins.reshape(1, 1, 1, -1).expand(B, T_action, N, -1) + steering_bins = self.steering_bins.reshape(1, 1, 1, -1).expand(B, T_action, N, -1) + + # Decode the action into acceleration and steering angle bins + best_a_idx = action_expanded // self.num_bins + best_delta_idx = action_expanded % self.num_bins + + best_acceleration = torch.gather(acceleration_bins, index=best_a_idx, axis=3).squeeze(-1) + best_steering = torch.gather(steering_bins, index=best_delta_idx, axis=3).squeeze(-1) + + # Next speed: + next_heading_exp = current_heading + next_pos = current_pos + next_vel = current_vel + next_speed = next_vel.norm(dim=-1) + current_speed_candidate = next_speed - best_acceleration * dt + average_speed = (current_speed_candidate + next_speed) / 2 + + # wheel_base = agent_shape[..., 0].reshape(B, 1, N) + # yaw_rate = (average_speed / wheel_base) * torch.tan(best_steering) + yaw_rate = best_steering + delta_theta = yaw_rate * dt + current_heading_candidate = utils.wrap_to_pi(next_heading_exp - 
delta_theta) + average_heading = utils.wrap_to_pi(utils.average_heading(current_heading_candidate, next_heading_exp)) + + reconstructed_vel = rotate( + current_speed_candidate, torch.zeros_like(current_speed_candidate), angle=current_heading_candidate + ) + average_velocity = rotate(average_speed, torch.zeros_like(average_speed), angle=average_heading) + + reconstructed_pos = next_pos - average_velocity * dt + reconstructed_heading = current_heading_candidate + + # Masking + valid_mask = current_valid_mask.reshape(B, 1, N, 1).expand(B, 1, N, 2) + reconstructed_pos[~valid_mask] = 0 + reconstructed_vel[~valid_mask] = 0 + reconstructed_heading[~valid_mask[..., 0]] = 0 + + # AID = 0 + # print( + # f"[DETOK] AID{AID} PRED POS: {reconstructed_pos[0,0,AID].cpu().numpy()}, " + # f"PRED HEAD: {reconstructed_heading[0,0,AID]}, " + # f"PRED VEL: {reconstructed_vel[0,0,AID].norm(dim=-1).cpu().numpy()}, " + # f"SPEED: {next_speed[0,0,AID]:.4f}, " + # f"NEXT POS: {current_pos[0,0,AID].cpu().numpy()}, " + # f"NEXT HEAD: {current_heading[0,0,AID]:.4f}, " + # f"VALID: {valid_mask[0,0,AID].cpu().numpy()}" + # f"ACTION: {action[0,0,AID].cpu().numpy()}, " + # ) + + relative_delta_pos = get_relative_velocity(reconstructed_vel, reconstructed_heading) + relative_delta_pos[~valid_mask] = 0 + + return dict( + pos=reconstructed_pos, heading=reconstructed_heading, vel=reconstructed_vel, delta_pos=relative_delta_pos + ) + + def _detokenize_a_step( + self, + *, + current_pos, + current_heading, + current_valid_mask, + current_vel, + action, + agent_shape=None, + dt=None, + **kwargs + ): + if dt is not None and dt < 0: + # TODO: This is a trick to handle the backward prediction. We flip current/next states and dt. + # Might cause confusion to other users. 
+ return self._detokenize_a_step_backward( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=current_valid_mask, + current_vel=current_vel, + action=action, + agent_shape=agent_shape, + dt=-dt, + **kwargs + ) + action = action.clone() + + assert action.ndim == 3 + B, T_action, N = action.shape + assert T_action == 1 + # Retrieve acceleration and steering angle based on decoded bins + if self.acceleration_bins.device != action.device: + self.acceleration_bins = self.acceleration_bins.to(action.device) + self.steering_bins = self.steering_bins.to(action.device) + + mask = (action == -1) | (action == START_ACTION) | (action == END_ACTION) | (~current_valid_mask) + action[mask] = 0 + action_expanded = action.reshape(B, T_action, N, 1).expand(B, T_action, N, 1).clone() + + acceleration_bins = self.acceleration_bins.reshape(1, 1, 1, -1).expand(B, T_action, N, -1) + steering_bins = self.steering_bins.reshape(1, 1, 1, -1).expand(B, T_action, N, -1) + + # Decode the action into acceleration and steering angle bins + best_a_idx = action_expanded // self.num_bins + best_delta_idx = action_expanded % self.num_bins + + best_acceleration = torch.gather(acceleration_bins, index=best_a_idx, axis=3).squeeze(-1) + best_steering = torch.gather(steering_bins, index=best_delta_idx, axis=3).squeeze(-1) + + # Update velocity components + current_speed = current_vel.norm(dim=-1) # (B, N) + next_speed = current_speed + best_acceleration.reshape_as(current_speed) * self.dt + average_speed = (current_speed + next_speed) / 2 + + # Compute yaw rate and resulting change in heading + # wheelbase = agent_shape[..., 0].reshape(B, 1, N) # shape = Length, Width, Height + # yaw_rate = (average_speed / wheelbase) * torch.tan(best_steering.squeeze(-1)) + yaw_rate = best_steering.reshape_as(current_heading) + delta_theta = yaw_rate * self.dt + next_heading = utils.wrap_to_pi(current_heading + delta_theta) + reconstructed_heading = next_heading.reshape(B, 1, N) + 
average_heading = utils.wrap_to_pi(utils.average_heading(current_heading, next_heading)) + + average_velocity = rotate(average_speed, torch.zeros_like(average_speed), angle=average_heading) + next_velocity = rotate(next_speed, torch.zeros_like(next_speed), angle=next_heading) + reconstructed_vel = next_velocity.reshape(B, 1, N, 2) + + next_pos = current_pos + average_velocity * self.dt + reconstructed_pos = next_pos.reshape(B, 1, N, 2) + + # Masking + valid_mask = current_valid_mask.reshape(B, 1, N, 1).expand(B, 1, N, 2) + reconstructed_pos[~valid_mask] = 0 + reconstructed_vel[~valid_mask] = 0 + reconstructed_heading[~valid_mask[..., 0]] = 0 + + relative_delta_pos = get_relative_velocity(reconstructed_vel, reconstructed_heading) + relative_delta_pos[~valid_mask] = 0 + + return dict( + pos=reconstructed_pos, heading=reconstructed_heading, vel=reconstructed_vel, delta_pos=relative_delta_pos + ) diff --git a/scenestreamer/tokenization/delta.py b/scenestreamer/tokenization/delta.py new file mode 100644 index 0000000000000000000000000000000000000000..fcca57ae265d1232bde41df11d2502ff27dfb446 --- /dev/null +++ b/scenestreamer/tokenization/delta.py @@ -0,0 +1,602 @@ +# from scenestreamer.tokenization.motion_tokenizers import BaseTokenizer + +import argparse +import datetime +import os +import pathlib +from typing import * + +import lightning.pytorch as pl +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import wandb +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger +from scipy.cluster.vq import kmeans2 + +from scenestreamer.dataset.datamodule import SceneStreamerDataModule +from scenestreamer.models.layers import common_layers +# from scenestreamer.models.motionlm_lightning import MotionLMLightning +# from scenestreamer.tokenization.tokenizers import rotate +from scenestreamer.utils import global_config, cfg_from_yaml_file, REPO_ROOT, get_time_str +from 
scenestreamer.utils import lr_schedule +# from scenestreamer.tokenization.tokenizers import DeltaTokenizer, DeltaDeltaTokenizer +from scenestreamer.utils import rotate, unwrap, wrap_to_pi + + +def compute_3d_translation(data_dict, num_skipped_steps, offset=0): + future_pos = data_dict["encoder/agent_position"] + future_heading = data_dict["encoder/agent_heading"] + future_valid_mask = data_dict["encoder/agent_valid_mask"] + + assert offset < num_skipped_steps + + future_pos = future_pos[:, offset:] + future_heading = future_heading[:, offset:] + future_valid_mask = future_valid_mask[:, offset:] + + current_pos = future_pos[:, 0] + current_heading = future_heading[:, 0] + current_valid_mask = future_valid_mask[:, 0] + + # T_action = future_pos.shape[1] + + B, T, N, _ = future_pos.shape + + # T_action = T + + # future_pos_sliced = future_pos[:, num_skipped_steps - 1::self.num_skipped_steps] + # assert future_pos_sliced.shape[1] == T_action + # + # future_heading_sliced = future_heading[:, self.num_skipped_steps - 1::self.num_skipped_steps] + # assert future_heading_sliced.shape[1] == T_action + # + # future_valid_mask = future_valid_mask[:, self.num_skipped_steps - 1::self.num_skipped_steps] + # assert future_valid_mask.shape[1] == T_action + + reconstructed_pos = current_pos[..., :2].clone().reshape(B, 1, N, 2) + reconstructed_heading = current_heading.clone().reshape(B, 1, N) + reconstructed_valid_mask = current_valid_mask.clone().reshape(B, 1, N) + + target_action = [] + target_action_valid_mask = [] + + delta_heading = [] + + # reconstruction_error = [] # For stats + + current_t = 0 + + while True: + current_t += num_skipped_steps + + if current_t + num_skipped_steps > T: + break + + # Real position at this step: + real_pos = future_pos[:, current_t:current_t + num_skipped_steps, ..., :2] # (1, N, 2) + real_heading = future_heading[:, current_t:current_t + num_skipped_steps] # (1, N, 2) + + # Update valid mask + real_valid_mask = future_valid_mask[:, 
current_t:current_t + num_skipped_steps] + real_valid_mask = real_valid_mask.all(dim=1, keepdims=True) + reconstructed_valid_mask = torch.logical_and(reconstructed_valid_mask, real_valid_mask) + assert reconstructed_valid_mask.shape == (B, 1, N) + target_action_valid_mask.append(reconstructed_valid_mask) + + abs_delta = real_pos - reconstructed_pos + y_axis_in_relative_coord = reconstructed_heading.repeat(1, num_skipped_steps, 1) + x_axis_in_relative_coord = y_axis_in_relative_coord - np.pi / 2 + candidate_pos = rotate(abs_delta[..., 0], abs_delta[..., 1], -x_axis_in_relative_coord) + + target_action.append(candidate_pos) + delta_heading.append(wrap_to_pi(real_heading - reconstructed_heading)) + + new_reconstructed_pos = real_pos[:, -1:] + reconstructed_pos = new_reconstructed_pos + reconstructed_heading = real_heading[:, -1:] + + target_actions = torch.stack(target_action, dim=1) # (B, T_skipped, N) + delta_heading = torch.stack(delta_heading, dim=1) # (B, T_skipped, N) + + deltas = torch.concat([target_actions, delta_heading[..., None]], dim=-1) + deltas = deltas.swapaxes(-3, -2) + + assert deltas.ndim == 5 + + target_action_valid_mask = torch.concatenate(target_action_valid_mask, dim=1) # (B, T_skipped, N) + + return deltas, target_action_valid_mask + + +class VectorQuantizer(nn.Module): + """ + PZH: From huggingface + + + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. 
def __init__(
    self,
    n_e: int,
    vq_embed_dim: int,
    beta: float = 0.25,
    remap=None,
    unknown_index: str = "random",
    sane_index_shape: bool = False,
    # legacy: bool = True,
    legacy: bool = False,
):
    """Create the codebook and, optionally, the index-remapping tables.

    Args:
        n_e: Number of codebook entries (rows of ``self.embedding``).
        vq_embed_dim: Size of each codebook vector.
        beta: Weight applied to one of the two quantization loss terms in
            ``forward``; which term is selected by ``legacy``.
        remap: Optional path passed to ``np.load``. When given, the loaded
            array is registered as the ``used`` buffer and indices are
            remapped through it.
        unknown_index: How ``remap_to_used`` treats indices that are not
            in ``used``: ``"random"``, ``"extra"`` (one additional slot is
            appended after the used ones), or an explicit integer.
        sane_index_shape: Stored as-is; ``forward`` uses it to decide
            whether to reshape the returned indices.
        legacy: Stored as-is; selects the loss variant in ``forward``.
    """
    super().__init__()
    self.n_e = n_e
    self.vq_embed_dim = vq_embed_dim
    self.beta = beta
    self.legacy = legacy

    self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
    self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    self.remap = remap
    if self.remap is not None:
        self.register_buffer("used", torch.tensor(np.load(self.remap)))
        self.used: torch.Tensor
        self.re_embed = self.used.shape[0]
        self.unknown_index = unknown_index  # "random" or "extra" or integer
        if self.unknown_index == "extra":
            # Reserve one slot right after the used indices for unknowns.
            self.unknown_index = self.re_embed
            self.re_embed = self.re_embed + 1
        print(
            f"Remapping {self.n_e} indices to {self.re_embed} indices. "
            f"Using {self.unknown_index} for unknown indices."
        )
    else:
        self.re_embed = n_e

    self.sane_index_shape = sane_index_shape

    # Flag buffer read by ``forward`` to run the k-means codebook
    # initialization only once; being a buffer, it is saved in checkpoints.
    self.register_buffer('data_initialized', torch.zeros(1))


def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
    """Map full-codebook indices to their positions inside ``self.used``.

    Indices that do not occur in ``self.used`` are replaced by a random
    value in ``[0, self.re_embed)`` when ``self.unknown_index`` is
    ``"random"``, and by ``self.unknown_index`` otherwise.

    Args:
        inds: Integer tensor with at least two dimensions.

    Returns:
        A tensor of remapped indices with the same shape as ``inds``.
    """
    ishape = inds.shape
    assert len(ishape) > 1
    inds = inds.reshape(ishape[0], -1)
    used = self.used.to(inds)
    # match[b, i, j] == 1 iff inds[b, i] equals used[j].
    match = (inds[:, :, None] == used[None, None, ...]).long()
    new = match.argmax(-1)
    # A row of all zeros means the index is not among the used ones.
    unknown = match.sum(2) < 1
    if self.unknown_index == "random":
        new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
    else:
        new[unknown] = self.unknown_index
    return new.reshape(ishape)


def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
    """Map positions inside ``self.used`` back to full-codebook indices.

    When an extra "unknown" slot exists (``self.re_embed`` is larger than
    the number of used indices), positions at or beyond the end of
    ``self.used`` are treated as position 0. The input tensor is left
    untouched.

    Args:
        inds: Integer tensor with at least two dimensions.

    Returns:
        A tensor of full-codebook indices with the same shape as ``inds``.
    """
    ishape = inds.shape
    assert len(ishape) > 1
    inds = inds.reshape(ishape[0], -1)
    used = self.used.to(inds)
    n_used = self.used.shape[0]
    if self.re_embed > n_used:  # extra token
        # reshape() may return a view of the caller's tensor, so the
        # out-of-range positions are zeroed out-of-place to avoid
        # overwriting the caller's data.
        inds = inds.masked_fill(inds >= n_used, 0)
    # Broadcast ``used`` across the batch without copying, then look up.
    back = torch.gather(used[None, :].expand(inds.shape[0], -1), 1, inds)
    return back.reshape(ishape)

def forward(self, z: torch.FloatTensor, disable=False) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]: + # reshape z -> (batch, height, width, channel) and flatten + # z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.vq_embed_dim) + + # PZH: https://github.com/karpathy/deep-vector-quantization/blob/c3c026a1ccea369bc892ad6dde5e6d6cd5a508a4/dvq/model/quantize.py + # DeepMind def does not do this but I find I have to... ;\ + if self.training and self.data_initialized.item() == 0: + print('running kmeans!!') # data driven initialization for the embeddings + rp = torch.randperm(z_flattened.size(0)) + kd = kmeans2(z_flattened[rp[:20000]].data.cpu().numpy(), self.n_e, minit='points') + self.embedding.weight.data.copy_(torch.from_numpy(kd[0])) + self.data_initialized.fill_(1) + # TODO: this won't work in multi-GPU setups + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z)**2) + torch.mean((z_q - z.detach())**2) + else: + loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2) + + # preserve gradients + z_q: torch.FloatTensor = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + if disable: + return z, loss, (perplexity, min_encodings, 
min_encoding_indices) + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + # return z, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q: torch.FloatTensor = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class RelationEncoder(nn.Module): + def __init__(self, d_model=128, num_layers=2): # , num_heads=4): + super().__init__() + self.d_model = d_model + self.proj = common_layers.build_mlps( + c_in=15, + mlp_channels=[d_model] * num_layers, + ret_before_act=True, + ) + + def forward(self, diff, mask, batch_dict): + B, T, N, _ = diff.shape + x = diff[mask] + x = self.proj(x) + x = unwrap(x, mask) + return x, mask + + +class RelationDecoder(nn.Module): + def __init__(self, d_model=128, num_layers=2): + super(RelationDecoder, self).__init__() + self.d_model = d_model + self.prediction_head = common_layers.build_mlps( + c_in=d_model, mlp_channels=[d_model] * num_layers + [15], ret_before_act=True + ) + + def forward(self, latent, mask, batch_dict): + B, T, N, _ = latent.shape + x = latent[mask] + x = unwrap(self.prediction_head(x), mask) + return x + + +class DeltaVAE(pl.LightningModule): + def __init__(self, config): + super().__init__() + self.config = config + d_model = 512 + self.enc = RelationEncoder(num_layers=4, d_model=d_model) + self.dec = RelationDecoder(num_layers=4, d_model=d_model) + self.quantizer = VectorQuantizer(1024, d_model) + self.save_hyperparameters() + + def forward(self, batch_dict): + with torch.no_grad(): + data, mask 
= compute_3d_translation( + batch_dict, num_skipped_steps=5, offset=np.random.randint(0, 5) + ) # , num_samples=None) + data = data.flatten(-2, -1) + latent, mask = self.enc(data, mask, batch_dict) + z, quant_loss, (perplexity, min_encodings, min_encoding_indices) = self.quantizer(latent, disable=False) + # emask = get_mask(mask) + # count = emask.sum(-1, keepdims=True) + # count = torch.masked_fill(count, count == 0, 1) + # target = (data * emask[..., None]).sum(-2) / count + return { + "output": self.dec(z, mask=mask, batch_dict=batch_dict), + "target": data, + # "rel_matrix": data, + "quant_loss": quant_loss, + # "dist": posterior, + "data": batch_dict, + "valid_mask": mask, + "quant_idxs": min_encoding_indices, + } + + def get_loss(self, data_dict): + output_logit = data_dict["output"] + + # target_action = data_dict["target"] + target_action = data_dict["target"] + mask = data_dict["valid_mask"] # (B, N) + + # Masking + output_logit = output_logit[mask] + target_action = target_action[mask] + + mse = nn.functional.mse_loss(input=output_logit, target=target_action) + loss = (mse * 1 + data_dict["quant_loss"] * 10) + + # output_logit_scaled = output_logit.clone() + # target_action_scaled = target_action.clone() + # + # recon_rel_matrix = pairwise_relative_diff(data_dict["output"]) + # rel_matrix = data_dict["rel_matrix"] + # emask = get_mask(mask) + # # recon_loss1 = nn.functional.l1_loss(input=recon_rel_matrix[emask], target=rel_matrix[emask]) + # recon_loss2 = nn.functional.l1_loss(input=-recon_rel_matrix[emask], target=rel_matrix[emask]) + + # debugging: cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally + # encodings = F.one_hot(data_dict["quant_idxs"][data_dict["valid_mask"].flatten()], self.reltok.quantizer.n_e).float().reshape(-1, self.reltok.quantizer.n_e) + # flat_mask = get_mask(data_dict["valid_mask"]).flatten() + flat_mask = mask.flatten() + encodings = F.one_hot(data_dict["quant_idxs"][flat_mask], + self.quantizer.n_e).float().reshape(-1, self.quantizer.n_e) + + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + # self.log('val_perplexity', perplexity, prog_bar=True) + # self.log('val_cluster_use', cluster_use, prog_bar=True) + + # scaled_mse = nn.functional.mse_loss(input=output_logit_scaled, target=target_action_scaled) + # scaled_norm = (output_logit_scaled[..., :1] - target_action_scaled[..., :1]).norm(dim=-1).mean() + + loss_stat = { + # "recon/loss1": recon_loss1, + # "recon/loss2": recon_loss2, + "loss/total_loss": loss, + "loss/mse": mse, + "mse": mse, + "perplexity": perplexity, + "cluster_use": cluster_use, + # "scaled_mse": scaled_mse, + # "scaled_norm": scaled_norm, + "loss/quant_loss": data_dict["quant_loss"], # ["codebook_loss"], + # "loss/commitment_loss": data_dict["quant_loss"]["commitment_loss"], + "output/output_mean": output_logit.mean(), + "output/output_max": output_logit.max(), + "output/output_min": output_logit.min(), + "output/target_mean": target_action.mean(), + "output/target_max": target_action.max(), + "output/target_min": target_action.min(), + "quant/quant_idxs_mean": data_dict["quant_idxs"][flat_mask].float().mean(), + "quant/quant_idxs_max": data_dict["quant_idxs"][flat_mask].float().max(), + "quant/quant_idxs_min": data_dict["quant_idxs"][flat_mask].float().min(), + } + try: + loss_stat["lr"] = self.trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0] + except RuntimeError: + # When debugging, the model might not be attached to a trainer. 
+ pass + return loss, loss_stat + + def training_step(self, data_dict, batch_idx): + data_dict = self(data_dict) + loss, loss_stat = self.get_loss(data_dict) + self.log_dict( + {f"train/{k}": float(v) + for k, v in loss_stat.items()}, + batch_size=data_dict["data"]["encoder/agent_feature"].shape[0], + # on_epoch=True, + prog_bar=True, + ) + self.log('monitoring_step', float(self.global_step)) + return loss + + def configure_optimizers(self): + """Required by Lightning.""" + opt_cfg = self.config.OPTIMIZATION + optimizer = torch.optim.AdamW( + self.parameters(), + lr=opt_cfg.get("LR"), + weight_decay=opt_cfg.get('WEIGHT_DECAY', 0), + betas=(0.9, 0.95), + eps=1e-5 + ) + scheduler = lr_schedule.get_cosine_schedule_with_warmup( + optimizer=optimizer, + # num_warmup_steps=opt_cfg.WARMUP_STEPS, + # num_training_steps=opt_cfg.TRAINING_STEPS, + num_warmup_steps=200, # TODO + num_training_steps=opt_cfg.TRAINING_STEPS, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step" + }, + } + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='arg parser') + + # Experiment + parser.add_argument( + '--cfg_file', + type=str, + default="cfgs/motion_debug.yaml", + help='The config file path, relative to the repo root.' 
+ ) + parser.add_argument('--exp_name', type=str, default='train_reltok', help='Experiment name.') + parser.add_argument('--ckpt', type=str, default=None, help='Path to pretrained checkpoint.') + parser.add_argument('--log_dir', type=str, default=None, help='Path to store all logs/ckpts/files.') + parser.add_argument('--debug', action='store_true', default=False, help='Whether to quickly set debug config.') + parser.add_argument('--eval', action='store_true', default=False, help='Whether to evaluate the model.') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--wandb', action='store_true', default=False, help='Whether to use wandb logging.') + + # Training + parser.add_argument('--batch_size', type=int, default=20, required=False, help='Batch size for training.') + parser.add_argument( + '--prefetch_factor', type=int, default=2, required=False, help='Datamodule prefetch factor for training.' + ) + parser.add_argument( + '--limit_train_batches', + type=int, + default=-1, + required=False, + help='Number of validation steps in each iteration.' + ) + parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for dataloader.') + parser.add_argument('--epochs', type=int, default=None, required=False, help='Number of epochs for training.') + + # Validation + parser.add_argument('--val_batch_size', type=int, default=6, required=False, help='Batch size for validation.') + parser.add_argument( + '--val_num_workers', type=int, default=4, help='Number of workers for dataloader in validation.' + ) + parser.add_argument( + '--num_sanity_val_steps', + type=int, + default=20, + required=False, + help='Number of validation steps before first training epoch.' + ) + parser.add_argument( + '--limit_val_batches', + type=int, + default=-1, + required=False, + help='Number of validation steps in each iteration. Default to whole validation dataset.' 
+ ) + + args = parser.parse_args() + + pl.seed_everything(args.seed) + print("Everything is seeded to: ", args.seed) + + # Set up config + cfg_file = REPO_ROOT / args.cfg_file + config = cfg_from_yaml_file(cfg_file, global_config) + exp_name = args.exp_name + max_epochs = args.epochs #or config.OPTIMIZATION.NUM_EPOCHS + batch_size = args.batch_size + val_batch_size = args.val_batch_size + num_workers = args.num_workers + val_num_workers = args.val_num_workers + log_dir = args.log_dir or None + if log_dir is not None: + log_dir = pathlib.Path(log_dir) + + # Setup wandb logger + trial_id = get_time_str() + name = "{}_{}".format(exp_name, trial_id) + if log_dir: + save_dir = log_dir / "lightning_logs" + else: + save_dir = os.path.join(REPO_ROOT, "lightning_logs") + if args.wandb and not args.eval: + with open(os.path.abspath(os.path.expanduser("~/wandb_api_key_file.txt")), "rt") as fp: + api_key = fp.readline().strip() + wandb.login(key=api_key) + logger = WandbLogger( + name=name, + save_dir=save_dir, + id=name, + project="scenestreamer", + log_model=True, + group=exp_name, + ) + else: + logger = TensorBoardLogger(save_dir=save_dir, name=exp_name) + + # Set up trainer arguments + callbacks = [ + ModelCheckpoint( + filename=str(name) + "_{epoch}-{step}", + monitor="monitoring_step", + every_n_epochs=1, + save_last=True, + auto_insert_metric_name=True, + mode="max", + save_top_k=-1, + save_on_train_epoch_end=True, + ), + ModelCheckpoint( + filename=str(name) + "_{epoch}-{step}", + train_time_interval=datetime.timedelta(minutes=30), + auto_insert_metric_name=True, + save_on_train_epoch_end=True, + every_n_train_steps=None, + every_n_epochs=None, + ) + ] + trainer_kwargs = dict( + num_sanity_val_steps=args.num_sanity_val_steps, + limit_val_batches=args.limit_val_batches if args.limit_val_batches > 0 else None, + limit_train_batches=args.limit_train_batches if args.limit_train_batches > 0 else None, + gradient_clip_val=config.OPTIMIZATION.GRAD_NORM_CLIP, + 
max_epochs=max_epochs, + callbacks=callbacks, + logger=logger, + accelerator="auto", + devices="auto", + log_every_n_steps=2, + # strategy='ddp_find_unused_parameters_true' + ) + if args.debug: + # from lightning.pytorch.profilers import PyTorchProfiler + # profiler = PyTorchProfiler(filename="profile") + trainer_kwargs.update( + num_sanity_val_steps=0, + # profiler=profiler, + detect_anomaly=True, + limit_val_batches=2, + limit_train_batches=2, + log_every_n_steps=1, + ) + num_workers = 0 + val_num_workers = 0 + datamodule = SceneStreamerDataModule( + config, + train_batch_size=batch_size, + train_num_workers=num_workers, + train_prefetch_factor=args.prefetch_factor, + val_batch_size=val_batch_size, + val_num_workers=val_num_workers, + val_prefetch_factor=args.prefetch_factor, + ) + if torch.cuda.device_count() > 1: + trainer_kwargs["strategy"] = 'ddp' + # trainer_kwargs["strategy"] = 'ddp_find_unused_parameters_true' + if log_dir: + trainer_kwargs["default_root_dir"] = log_dir + + # Set up trainer + trainer = pl.Trainer(**trainer_kwargs) + + # Set up model + ckpt_path = args.ckpt + if ckpt_path is not None: + ckpt_path = os.path.join(REPO_ROOT, ckpt_path) + assert os.path.isfile(ckpt_path), ckpt_path + assert ckpt_path.endswith(".ckpt"), ckpt_path + print("==============================") + print("Loading checkpoint: ", ckpt_path) + print("==============================") + + model = DeltaVAE(config=config) + + if args.eval: + trainer.validate(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + else: + trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) diff --git a/scenestreamer/tokenization/diffusion_tokenizer.py b/scenestreamer/tokenization/diffusion_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..933eebcd8b26473339280d087f3f07d41aa5227b --- /dev/null +++ b/scenestreamer/tokenization/diffusion_tokenizer.py @@ -0,0 +1,530 @@ +import numpy as np +import torch + +from 
scenestreamer.tokenization.motion_tokenizers import DeltaDeltaTokenizer, BaseTokenizer, get_relative_velocity, STEPS_PER_SECOND +from scenestreamer.utils import utils +import json + +SPECIAL_INVALID = 0 +SPECIAL_VALID = 1 +SPECIAL_START = 2 +SPECIAL_MASKED = 3 + +cyc_std = [ + 0.19134897, 0.28934157, 0.52083063, 0.19307183, 0.4897204, 0.617495, 0.19566636, 0.6996683, 0.6919608, 0.21254513, + 0.93830705, 0.7470725, 0.23419037, 1.1697725, 0.80575556 +] +cyc_mean = [ + 0.0007314072, 0.38584468, -0.0005979259, 0.0015717002, 0.77193886, -0.0023218163, 0.002200475, 1.1587898, + -0.0010374307, 0.0025438545, 1.5457957, -0.0017486577, 0.002413424, 1.9331585, -0.0021125625 +] + +# Just put it here in case need it. It's output 20 values including Z axis. +# output_std +# [0.19134897, 0.28934157, 0.023779996, 0.52083063, 0.19307183, 0.4897204, 0.028927919, 0.617495, 0.19566636, 0.6996683, 0.034308836, 0.6919608, 0.21254513, 0.93830705, 0.03995999, 0.7470725, 0.23419037, 1.1697725, 0.046037905, 0.80575556] +# output_mean +# [0.0007314072, 0.38584468, 6.2437284e-05, -0.0005979259, 0.0015717002, 0.77193886, 0.00020317949, -0.0023218163, 0.002200475, 1.1587898, 0.00037277717, -0.0010374307, 0.0025438545, 1.5457957, 0.00043893827, -0.0017486577, 0.002413424, 1.9331585, 0.0006425823, -0.0021125625] + +# PED: +# output_std +ped_std = [ + 0.04235751, 0.074799694, 0.68056464, 0.045664273, 0.13535422, 0.7549149, 0.049251433, 0.19858274, 0.81499416, + 0.05508655, 0.26276276, 0.86660916, 0.06212374, 0.32706332, 0.9120234 +] +# output_mean +ped_mean = [ + 1.4271492e-05, 0.09508697, -0.00029620097, 8.967689e-05, 0.18981859, -0.00076825253, 0.00014018304, 0.2842135, + -0.000643687, 0.00017854733, 0.37860408, -0.0008555841, 0.00028553946, 0.4729763, -0.0006006232 +] +# PED 20: +# output_std +# [0.04235751, 0.074799694, 0.015911372, 0.68056464, 0.045664273, 0.13535422, 0.019296514, 0.7549149, 0.049251433, 0.19858274, 0.022285515, 0.81499416, 0.05508655, 0.26276276, 0.025021093, 0.86660916, 
class DiffusionTokenizer(DeltaDeltaTokenizer):
    """Continuous (non-discretised) motion tokenizer.

    ``tokenize`` cuts every trajectory into overlapping windows of 6 samples taken every
    5 steps. With ``use_delta_delta`` enabled (the only mode switched on in ``__init__``)
    each window is encoded as 5 ``(acceleration, yaw_rate)`` pairs, and
    ``_detokenize_a_step`` integrates such a chunk back into speed, heading and position.
    """

    def __init__(self, config):
        # BaseTokenizer.__init__ is called directly, so DeltaDeltaTokenizer.__init__ does not run.
        BaseTokenizer.__init__(self, config)

        self.should_standardize = config.TOKENIZATION.SHOULD_STANDARDIZE

        # Per-agent-type statistics; the JSON is keyed by type id ("1", "2", "3"), which are
        # stored below as the vehicle, pedestrian and cyclist statistics respectively.
        stats_path = utils.REPO_ROOT / "scenestreamer" / "tokenization" / "motion_stats_FORMAL.json"
        with open(stats_path, "r") as f:
            motion_stats = json.load(f)
        self.veh_mean = torch.tensor(motion_stats["1"]["mean"]).reshape(1, 1, 1, 15)
        self.veh_std = torch.tensor(motion_stats["1"]["std"]).reshape(1, 1, 1, 15)
        self.ped_mean = torch.tensor(motion_stats["2"]["mean"]).reshape(1, 1, 1, 15)
        self.ped_std = torch.tensor(motion_stats["2"]["std"]).reshape(1, 1, 1, 15)
        self.cyc_mean = torch.tensor(motion_stats["3"]["mean"]).reshape(1, 1, 1, 15)
        self.cyc_std = torch.tensor(motion_stats["3"]["std"]).reshape(1, 1, 1, 15)

        # Exactly one of the two encodings may be active at a time.
        self.use_delta = False
        self.use_delta_delta = True

        assert not (self.use_delta_delta is True and self.use_delta is True)

    def _get_stat(self, motion, agent_type):
        """Return ``(mean, std)`` tensors shaped like ``motion``.

        Args:
            motion: Tensor of shape (B, T, N, D).
            agent_type: Integer tensor of shape (B, T, N); ids 1 / 2 / 3 select the
                vehicle / pedestrian / cyclist statistics, any other id gets mean 0.

        Note:
            The per-type std is currently overwritten with the constant 10 (see the TODO
            below), so only the mean is type specific.
        """
        B, T, N, D = motion.shape
        if motion.device != self.ped_std.device:
            # Lazily move the cached statistics to the device of the incoming batch.
            for name in ("ped_std", "ped_mean", "cyc_std", "cyc_mean", "veh_std", "veh_mean"):
                setattr(self, name, getattr(self, name).to(motion.device))

        agent_type = agent_type.unsqueeze(-1)

        mean = torch.zeros_like(motion)
        mean = torch.where(agent_type == 1, self.veh_mean.expand(B, T, N, -1), mean)
        mean = torch.where(agent_type == 2, self.ped_mean.expand(B, T, N, -1), mean)
        mean = torch.where(agent_type == 3, self.cyc_mean.expand(B, T, N, -1), mean)

        std = torch.ones_like(motion)
        std = torch.where(agent_type == 1, self.veh_std.expand(B, T, N, -1), std)
        std = torch.where(agent_type == 2, self.ped_std.expand(B, T, N, -1), std)
        std = torch.where(agent_type == 3, self.cyc_std.expand(B, T, N, -1), std)

        # TODO: Just nullify the STD.
        std = std.fill_(10)

        return mean, std

    def standardize(self, motion, agent_type, valid_mask):
        """Standardize ``motion`` and zero every invalid entry.

        ``motion`` is modified in place on the early-return paths. When standardization is
        disabled, or a delta / delta-delta encoding is active, the values are passed through
        and only the invalid entries are zeroed.

        Args:
            motion: Tensor of shape (B, T, N, 15) on the full path.
            agent_type: Tensor of shape (B, N) or (B, T, N).
            valid_mask: Boolean tensor of shape (B, T, N).
        """
        if not self.should_standardize:
            motion[~valid_mask] = 0
            return motion

        if self.use_delta or self.use_delta_delta:
            # TODO: Reconsider this.
            motion[~valid_mask] = 0
            return motion

        assert motion.ndim == 4
        B, T, N, D = motion.shape
        if agent_type.ndim == 2:
            agent_type = agent_type.unsqueeze(1).expand(-1, T, -1)
        assert agent_type.ndim == 3
        assert D == 15
        assert valid_mask.shape == (B, T, N)

        # Heading is not mean/std normalized (to avoid weird wrap-around artefacts); it is
        # wrapped to [-pi, pi] and divided by pi instead.
        heading = motion.reshape(B, T, N, 5, 3)[:, :, :, :, -1]
        normalized_heading = utils.wrap_to_pi(heading) / np.pi

        mean, std = self._get_stat(motion, agent_type)
        motion = (motion - mean) / std
        motion[~valid_mask] = 0

        motion = motion.reshape(B, T, N, 5, 3)
        motion[:, :, :, :, -1] = normalized_heading
        motion = motion.reshape(B, T, N, D)

        # Writing the heading back touched invalid entries again, so re-apply the mask.
        motion[~valid_mask] = 0
        return motion

    def unstandardize(self, motion, agent_type, valid_mask):
        """Inverse of :meth:`standardize`; same arguments, same in-place/pass-through rules."""
        if not self.should_standardize:
            motion[~valid_mask] = 0
            return motion

        if self.use_delta or self.use_delta_delta:
            # TODO: Reconsider this.
            motion[~valid_mask] = 0
            return motion

        assert motion.ndim == 4
        B, T, N, D = motion.shape
        if agent_type.ndim == 2:
            agent_type = agent_type.unsqueeze(1).expand(-1, T, -1)
        assert agent_type.ndim == 3
        assert D == 15
        assert valid_mask.shape == (B, T, N)
        mean, std = self._get_stat(motion, agent_type)

        # Heading was scaled by 1/pi rather than mean/std normalized; undo exactly that.
        heading = motion.reshape(B, T, N, 5, 3)[:, :, :, :, -1]
        unnormalized_heading = utils.wrap_to_pi(heading * np.pi)

        motion = motion * std + mean

        motion = motion.reshape(B, T, N, 5, 3)
        motion[:, :, :, :, -1] = unnormalized_heading
        motion = motion.reshape(B, T, N, D)

        motion[~valid_mask] = 0
        return motion

    def tokenize(self, data_dict, **kwargs):
        """Convert raw agent trajectories into continuous motion chunks.

        Reads ``decoder/agent_{position,heading,valid_mask,velocity}`` and
        ``decoder/agent_type`` from ``data_dict`` and writes the ``decoder/input_*``,
        ``decoder/target_*`` and ``decoder/modeled_agent_*`` entries.

        Returns:
            ``(data_dict, {})`` - the updated dict and an empty info dict.
        """
        agent_pos = data_dict["decoder/agent_position"].clone()
        agent_heading = data_dict["decoder/agent_heading"].clone()
        agent_valid_mask = data_dict["decoder/agent_valid_mask"].clone()
        agent_velocity = data_dict["decoder/agent_velocity"].clone()

        B, T_full, N, _ = agent_pos.shape
        # TODO: hardcoded
        assert T_full == 91
        assert agent_pos.ndim == 4
        # Note: [::5] slicing keeps steps 0, 5, 10 (the current step!), 15, ...

        # ===== Hole filling =====
        data_dict = self.hole_filling(data_dict)

        def unfold(t):
            """Window (B, T, N, ...) into (B, T', 6, N, ...): 6 samples per window, stride 5."""
            assert t.shape[1] == T_full
            t = t.unfold(dimension=1, size=6, step=5)
            if t.ndim == 5:
                t = t.permute(0, 1, 4, 2, 3)
            elif t.ndim == 4:
                t = t.permute(0, 1, 3, 2)
            else:
                raise ValueError
            return t

        # NOTE: the unfolded tensors are views with overlapping memory (consecutive windows
        # share one sample), so they must never be written to in place.
        agent_pos = unfold(agent_pos)
        agent_heading = unfold(agent_heading)
        agent_valid_mask = unfold(agent_valid_mask)
        agent_velocity = unfold(agent_velocity)

        T_action = T_full // self.num_skipped_steps
        assert T_action == agent_pos.shape[1] == 18

        start_valid_mask = agent_valid_mask[:, :1, 0].clone()

        # A window is a valid target only if all 6 of its samples are valid.
        target_valid_mask = agent_valid_mask.all(dim=2)
        input_valid_mask = torch.cat([start_valid_mask, target_valid_mask], dim=1)
        target_valid_mask = torch.cat([target_valid_mask, target_valid_mask.new_zeros(B, 1, N)], dim=1)

        # State at the first sample of every window, plus the last sample of the last window.
        pos = torch.cat([agent_pos[:, :, 0], agent_pos[:, -1:, -1]], dim=1)
        heading = utils.wrap_to_pi(torch.cat([agent_heading[:, :, 0], agent_heading[:, -1:, -1]], dim=1))
        vel = torch.cat([agent_velocity[:, :, 0], agent_velocity[:, -1:, -1]], dim=1)

        relative_pos = agent_pos[:, :, :] - agent_pos[:, :, :1]
        relative_heading = utils.wrap_to_pi(agent_heading[:, :, :] - agent_heading[:, :, :1])

        def transform_to_relative_pos(x, h):
            """Rotate window-relative offsets ``x`` into the frame given by heading ``h``."""
            assert x.ndim == 5
            assert h.ndim == 4
            # Because of standardization we must follow the dataset's coordinate convention
            # strictly: the local Y axis is the heading direction, the local X axis is 90
            # degrees clockwise from it.
            local_y_wrt_global_x = h
            local_x_wrt_global_x = local_y_wrt_global_x - np.pi / 2
            return utils.rotate(x=x[..., 0], y=x[..., 1], angle=-local_x_wrt_global_x, assert_shape=False)

        each_step_heading = agent_heading[:, :, 0]
        rotated_pos = transform_to_relative_pos(relative_pos, each_step_heading.unsqueeze(2).expand(-1, -1, 6, -1))
        rotated_pos = rotated_pos.permute((0, 1, 3, 2, 4))

        relative_heading = relative_heading.permute((0, 1, 3, 2)).unsqueeze(-1)

        if self.use_delta:
            # Per-step increments along the in-window (-2) axis.
            heading_feature = relative_heading[..., 1:, :] - relative_heading[..., :-1, :]
            pos_feature = rotated_pos[..., 1:, :] - rotated_pos[..., :-1, :]

        elif self.use_delta_delta:
            dt = (1 / STEPS_PER_SECOND)
            window_valid = target_valid_mask[:, :-1]

            # Mask a *copy*: writing through the overlapping unfold view would also zero the
            # first sample of the following window and corrupt its first acceleration value.
            velocity = agent_velocity.permute((0, 1, 3, 2, 4)).clone()
            velocity[~window_valid] = 0
            old_h = agent_heading.permute((0, 1, 3, 2))

            # Speed magnitude per sample: (B, T', N, 6).
            speed = torch.norm(velocity, dim=-1)

            # Acceleration is the change in speed, yaw rate the change in heading: (B, T', N, 5).
            acceleration = (speed[..., 1:] - speed[..., :-1]) / dt
            acceleration[~window_valid] = 0

            yaw_rate = utils.wrap_to_pi(old_h[..., 1:] - old_h[..., :-1]) / dt
            yaw_rate[~window_valid] = 0

            pos_feature = acceleration.unsqueeze(-1)
            heading_feature = yaw_rate.unsqueeze(-1)

        else:
            heading_feature = relative_heading[..., 1:, :]
            pos_feature = rotated_pos[..., 1:, :]

        # Concatenate on the feature axis: (acc, yaw_rate) in delta-delta mode and
        # (x, y, heading) otherwise. (torch.stack cannot combine the differently sized
        # position and heading features of the other two modes.)
        target_motion = torch.cat([pos_feature, heading_feature], dim=-1)
        target_motion = target_motion.reshape(B, T_action, N, -1)

        target_agent_motion = torch.cat(
            [target_motion, target_motion.new_zeros(B, 1, N, target_motion.shape[-1])], dim=1
        )
        target_agent_motion = self.standardize(
            target_agent_motion, agent_type=data_dict["decoder/agent_type"], valid_mask=target_valid_mask
        )

        # Special token (total 5 options):
        # 0: invalid token
        # 1: valid token
        # 2: just started.
        # 3: masked.
        # 4: not used.
        input_special_token = torch.full(
            (B, T_action + 1, N), SPECIAL_INVALID, dtype=torch.int64, device=start_valid_mask.device
        )

        # Mark every step that precedes the agent's last valid step as MASKED; valid steps
        # are overwritten with VALID / START right below.
        cumsum = input_valid_mask.cumsum(dim=1)
        max_cumsum = cumsum.max(dim=1, keepdim=True).values
        input_special_token[cumsum < max_cumsum] = SPECIAL_MASKED

        already_started = input_valid_mask[:, :1].clone()
        input_special_token[input_valid_mask] = SPECIAL_VALID
        input_special_token[:, :1][input_valid_mask[:, :1]] = SPECIAL_START
        for step in range(1, T_action):
            newly_added = (~already_started) & input_valid_mask[:, step:step + 1]
            input_special_token[:, step:step + 1][newly_added] = SPECIAL_START
            already_started = torch.logical_or(already_started, newly_added)

        # The input sequence is the target sequence shifted right by one step.
        input_agent_motion = torch.cat(
            [target_motion.new_zeros(B, 1, N, target_motion.shape[-1]), target_motion], dim=1
        )
        # Allow the model to know the existence of masked agents.
        input_valid_mask[input_special_token == SPECIAL_MASKED] = 1

        data_dict["decoder/input_action"] = input_special_token
        data_dict["decoder/modeled_agent_position"] = pos
        data_dict["decoder/modeled_agent_heading"] = heading
        data_dict["decoder/modeled_agent_velocity"] = vel
        data_dict["decoder/modeled_agent_delta"] = get_relative_velocity(vel=vel, heading=heading)

        data_dict["decoder/target_agent_motion"] = target_agent_motion
        data_dict["decoder/target_action_valid_mask"] = target_valid_mask
        data_dict["decoder/input_agent_motion"] = input_agent_motion
        data_dict["decoder/input_action_valid_mask"] = input_valid_mask

        # NOTE: `input_special_token[~input_valid_mask] == 0` does not hold in general
        # because of the MASKED token.
        return data_dict, {}

    def _detokenize_a_step(
        self, *, current_pos, current_heading, current_valid_mask, current_vel, action, agent_type, **kwargs
    ):
        """Roll one motion chunk (5 sub-steps) forward from the current agent state.

        Args:
            current_pos: (B, 1, N, 2) positions.
            current_heading: (B, 1, N) headings.
            current_valid_mask: (B, 1, N) boolean mask.
            current_vel: (B, 1, N, C) velocities.
            action: (B, 1, N, D) motion chunk; unstandardized here.
            agent_type: Forwarded to :meth:`unstandardize`.

        Returns:
            Dict with the state after the chunk (``pos``, ``heading``, ``vel``,
            ``delta_pos``) and the 5 intermediate states (``reconstructed_pos``,
            ``reconstructed_heading``).
        """
        B, _, N, _ = action.shape

        action = self.unstandardize(action, agent_type=agent_type, valid_mask=current_valid_mask)
        action = action.reshape(B, 1, N, 5, -1)
        dt = (1 / STEPS_PER_SECOND)

        if self.use_delta_delta:
            acceleration = action[..., 0]
            yaw_rate = action[..., 1]

            # Integrate acceleration into speed, starting from the current speed.
            initial_speed = torch.norm(current_vel, dim=-1, keepdim=True)
            speed = torch.cumsum(torch.cat([initial_speed, acceleration * dt], dim=-1), dim=-1)
            speed = speed[..., 1:]

            # Integrate yaw rate into heading, starting from the current heading.
            reconstructed_h = torch.cumsum(torch.cat([current_heading.unsqueeze(-1), yaw_rate * dt], dim=-1), dim=-1)
            reconstructed_h = utils.wrap_to_pi(reconstructed_h)
            reconstructed_h = reconstructed_h[..., 1:]

            # Velocity in the global frame.
            velocity = torch.stack(
                [speed * torch.cos(reconstructed_h), speed * torch.sin(reconstructed_h)], dim=-1
            )

            # Integrate velocity into position, starting from the current position.
            delta_p = velocity * dt
            reconstructed_p = torch.cumsum(torch.cat([current_pos.unsqueeze(-2), delta_p], dim=-2), dim=-2)
            reconstructed_p = reconstructed_p[..., 1:, :]

            delta_pos = get_relative_velocity(vel=velocity[..., -1, :], heading=reconstructed_h[..., -1])

            reconstructed_p[~current_valid_mask] = 0
            reconstructed_h[~current_valid_mask] = 0
            delta_pos[~current_valid_mask] = 0

            return dict(
                pos=reconstructed_p[:, :, :, -1],
                heading=reconstructed_h[:, :, :, -1],
                reconstructed_pos=reconstructed_p,
                reconstructed_heading=reconstructed_h,
                delta_pos=delta_pos,
                vel=velocity[:, :, :, -1],
            )

        agent_pos_change = action[..., 0:2]
        agent_heading_change = action[..., 2:3]
        if self.use_delta:
            # The chunk stores per-step increments; accumulate them into offsets.
            agent_pos_change = agent_pos_change.cumsum(dim=-2)
            agent_heading_change = agent_heading_change.cumsum(dim=-2)

        # Offsets live in the agent frame (local Y = heading); rotate back to the global frame.
        local_y_wrt_global_x = current_heading
        local_x_wrt_global_x = local_y_wrt_global_x - np.pi / 2

        rotated_agent_pos_change = utils.rotate(
            x=agent_pos_change[..., 0],
            y=agent_pos_change[..., 1],
            angle=local_x_wrt_global_x.reshape(B, 1, N, 1).expand(-1, -1, -1, 5)
        )
        start_pos = current_pos.reshape(B, 1, N, 1, 2)
        reconstructed_pos = start_pos + rotated_agent_pos_change
        reconstructed_heading = utils.wrap_to_pi(current_heading.reshape(B, 1, N, 1) + agent_heading_change.squeeze(-1))

        delta_pos = get_relative_velocity(
            vel=rotated_agent_pos_change[..., -1, :], heading=agent_heading_change[..., -1, 0]
        )

        # Velocity is not part of this encoding, so estimate it from consecutive
        # reconstructed positions (finite differences in the global frame).
        previous_pos = torch.cat([start_pos, reconstructed_pos[..., :-1, :]], dim=-2)
        reconstructed_velocity = (reconstructed_pos - previous_pos) / dt

        return dict(
            pos=reconstructed_pos[:, :, :, -1],
            heading=reconstructed_heading[:, :, :, -1],
            reconstructed_pos=reconstructed_pos,
            reconstructed_heading=reconstructed_heading,
            vel=reconstructed_velocity[:, :, :, -1],
            delta_pos=delta_pos,
        )

    def detokenize(self, data_dict, **kwargs):
        """Attach an all-zero ``decoder/output_score`` of shape (B, N) to ``data_dict``.

        The reconstructed position, heading and valid mask must already be present.
        """
        assert "decoder/reconstructed_position" in data_dict
        assert "decoder/reconstructed_heading" in data_dict
        assert "decoder/reconstructed_valid_mask" in data_dict
        pos = data_dict["decoder/reconstructed_position"]
        B, T, N, _ = pos.shape
        data_dict["decoder/output_score"] = pos.new_zeros(size=(B, N))
        return data_dict
class UniversalActionProcessor(ProcessorMixin):
    """
    Action tokenizer: DCT over time, integer quantisation, then byte-pair encoding.

    Copied from: https://huggingface.co/physical-intelligence/fast/blob/main/processing_action_tokenizer.py
    """
    attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
    bpe_tokenizer_class: str = "AutoTokenizer"

    def __init__(
        self,
        bpe_tokenizer: PreTrainedTokenizerFast,
        scale: float = 10,
        vocab_size: int = 1024,
        min_token: int = 0,
        *,
        action_dim: int | None = None,
        time_horizon: int | None = None,
    ):
        """
        Args:
            bpe_tokenizer: Trained BPE tokenizer over the quantised-coefficient alphabet.
            scale: Multiplier applied to DCT coefficients before rounding.
            vocab_size: Size of the BPE vocabulary.
            min_token: Smallest quantised coefficient seen when fitting; subtracted so
                that characters start at code point 0. NOTE: the default value 0 is
                rejected by the check below, so the fitted offset must be passed in.
            action_dim: Action dimensionality used when decoding.
            time_horizon: Number of time steps per chunk used when decoding.
        """
        self.scale = scale
        self.vocab_size = vocab_size
        self.min_token = min_token
        assert min_token != 0, "min_token must be the (non-zero) offset found by fit()"
        # Action horizon and dimension needed during decoding. These can be specified
        # in three ways (in order of priority):
        # 1. passed in as kwargs to decode()
        # 2. in the constructor
        # 3. cached from the last time decode() was called
        self.time_horizon = time_horizon
        self.action_dim = action_dim
        self.called_time_horizon = time_horizon
        self.called_action_dim = action_dim

        super().__init__(bpe_tokenizer)

    def __call__(self, action_chunk: np.ndarray) -> list[list[int]]:
        """Encode ``[batch, timesteps, action_dim]`` (or a single 2-D chunk) into token ids.

        Returns one list of BPE token ids per batch element.
        """
        assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
        if action_chunk.ndim == 2:
            action_chunk = action_chunk[None, ...]

        # Cache the time horizon and action dimension for decoding
        self.called_time_horizon = action_chunk.shape[-2]
        self.called_action_dim = action_chunk.shape[-1]

        dct_coeff = dct(action_chunk, axis=1, norm="ortho")
        dct_coeff = np.around(dct_coeff * self.scale)

        tokens = []
        for elem in dct_coeff:
            # Shift by min_token and clamp at 0 so every coefficient maps to a valid code point.
            code_points = np.maximum(elem.flatten() - self.min_token, 0).astype(int)
            token_str = "".join(map(chr, code_points))
            tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
        return tokens

    def decode(
        self,
        tokens: list[list[int]],
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Decode token id sequences back into action chunks.

        Decoding is best effort: a sequence that cannot be decoded into exactly
        ``(time_horizon, action_dim)`` coefficients yields an all-zero coefficient
        matrix and is flagged in the second return value.

        Returns:
            ``(actions, errors)`` where ``actions`` has shape
            ``(len(tokens), time_horizon, action_dim)`` and ``errors`` holds 1 for every
            sequence that failed to decode and 0 otherwise.
        """
        self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
        self.action_dim = action_dim or self.action_dim or self.called_action_dim

        # Cache the time horizon and action dimension for the next call
        self.called_time_horizon = self.time_horizon
        self.called_action_dim = self.action_dim

        assert (
            self.time_horizon is not None and self.action_dim is not None
        ), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."

        expected_shape = (self.time_horizon, self.action_dim)
        if len(tokens) == 0:
            # np.stack() refuses an empty list; return correctly shaped empty results instead.
            return np.zeros((0, *expected_shape)), np.zeros((0,), dtype=int)

        decoded_actions = []
        error_rate = []

        for token in tokens:
            try:
                decoded_tokens = self.bpe_tokenizer.decode(token)
                decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
                decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
                # Explicit check (not `assert`) so malformed sequences are still caught under -O.
                if decoded_dct_coeff.shape != expected_shape:
                    raise ValueError(
                        f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, "
                        f"expected ({self.time_horizon}, {self.action_dim})"
                    )
                error_rate.append(0)
            except Exception:
                # Deliberately broad and silent: generated token sequences are frequently
                # malformed, and callers track failures through the returned error flags.
                error_rate.append(1)
                decoded_dct_coeff = np.zeros(expected_shape)
            decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
        assert len(error_rate) == len(decoded_actions)
        return np.stack(decoded_actions), np.stack(error_rate)

    @classmethod
    def fit(
        cls,
        action_data: list[np.ndarray],
        scale: float = 10,
        vocab_size: int = 1024,
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
    ) -> "UniversalActionProcessor":
        """Train the BPE tokenizer on ``action_data`` and return a ready processor.

        Args:
            action_data: Action chunks, each of shape ``[timesteps, action_dim]``.
            scale: Multiplier applied to DCT coefficients before rounding.
            vocab_size: Target BPE vocabulary size.
            time_horizon: Stored on the processor for decoding.
            action_dim: Stored on the processor for decoding.
        """
        # Run DCT over all inputs and quantise once; the result is reused below.
        quantized = [np.around(dct(a, axis=0, norm="ortho").flatten() * scale) for a in action_data]

        # Find the token range
        all_quantized = np.concatenate(quantized)
        max_token = int(all_quantized.max())
        min_token = int(all_quantized.min())
        min_vocab_size = max_token - min_token

        assert (
            min_vocab_size <= vocab_size
        ), f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
        if min_vocab_size + 100 > vocab_size:
            logging.warning(
                f"Initial alphabet size {min_vocab_size} is almost as large as the vocab "
                f"size {vocab_size}, consider increasing vocab size"
            )

        # Make token iterator for BPE training
        def _token_iter():
            for coeffs in quantized:
                yield "".join(map(chr, (coeffs - min_token).astype(int)))

        # Train BPE tokenizer
        bpe = ByteLevelBPETokenizer()

        # Set up the entire range of possible tokens as the initial alphabet
        alphabet = [chr(i) for i in range(max_token - min_token + 1)]
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            show_progress=True,
            special_tokens=[],
            initial_alphabet=alphabet,
            max_token_length=10000,
        )

        # Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
        # because it doesn't support custom alphabets)
        bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)

        return cls(
            PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
            scale=scale,
            vocab_size=vocab_size,
            min_token=min_token,
            time_horizon=time_horizon,
            action_dim=action_dim,
        )
def normalize_actions(data, lower_percentile=1, upper_percentile=99, predefined_quantiles=None):
    """
    Applies quantile normalization to every (time step, action dimension) feature of the data.

    The last two axes of ``data`` are treated as features; for the usual shape (N, 5, 3) that
    is 5 time steps x 3 action dimensions = 15 features, numbered row-major (feature ``i`` is
    ``data[..., i // 3, i % 3]``). Each feature is mapped linearly so that its lower quantile
    becomes -1 and its upper quantile becomes 1, then clipped to [-1, 1]. A feature whose two
    quantiles coincide is only clipped.

    Parameters:
        data (np.ndarray): Input array of shape (N, 5, 3).
        lower_percentile (float): Lower percentile (default 1) used if quantiles are not predefined.
        upper_percentile (float): Upper percentile (default 99) used if quantiles are not predefined.
        predefined_quantiles (dict or None): If provided, must have keys 'q_lower' and 'q_upper', each
            holding one quantile value per feature (shape (15,) for the usual input).

    Returns:
        tuple: A tuple containing:
            - normalized_data (np.ndarray): Array of the same shape and dtype as data, with values in [-1, 1].
            - quantiles (dict): Dictionary with keys 'q_lower' and 'q_upper' used for normalization
              (``predefined_quantiles`` itself when it was given).
    """
    n_steps, n_dims = data.shape[-2], data.shape[-1]
    n_features = n_steps * n_dims

    if predefined_quantiles is None:
        # One column per feature, so a single percentile call along axis 0 covers all of them.
        flat = data.reshape(-1, n_features)
        bounds = np.percentile(flat, [lower_percentile, upper_percentile], axis=0).astype(np.float64)
        q_lower_arr, q_upper_arr = bounds[0], bounds[1]
        quantiles = {'q_lower': q_lower_arr, 'q_upper': q_upper_arr}
    else:
        quantiles = predefined_quantiles
        q_lower_arr = np.asarray(quantiles['q_lower'], dtype=np.float64).reshape(-1)[:n_features]
        q_upper_arr = np.asarray(quantiles['q_upper'], dtype=np.float64).reshape(-1)[:n_features]

    # Lay the per-feature quantiles out like the trailing axes of `data` so they broadcast.
    q_lower = q_lower_arr.reshape(n_steps, n_dims)
    q_upper = q_upper_arr.reshape(n_steps, n_dims)

    # Degenerate features (zero quantile range) are passed through unscaled; the dummy
    # denominator of 1 only avoids a division by zero in the branch that is discarded.
    degenerate = q_upper == q_lower
    scale = 2.0 / np.where(degenerate, 1.0, q_upper - q_lower)
    mapped = np.where(degenerate, data, (data - q_lower) * scale - 1)

    normalized_data = np.empty_like(data)
    normalized_data[...] = np.clip(mapped, -1, 1)
    return normalized_data, quantiles
def denormalize_actions(normalized_data, quantiles):
    """
    Reverses the quantile normalization of every (time step, action dimension) feature.

    The last two axes of ``normalized_data`` are treated as features, numbered row-major
    (for the usual shape (N, 5, 3), feature ``i`` is ``normalized_data[..., i // 3, i % 3]``).
    A feature whose two quantiles coincide is returned unchanged.

    Parameters:
        normalized_data (np.ndarray): Normalized data array of shape (N, 5, 3) with values in [-1, 1].
        quantiles (dict): A dictionary with keys 'q_lower' and 'q_upper', each holding one quantile
                          value per feature (shape (15,) for the usual input).

    Returns:
        np.ndarray: Denormalized data array of the same shape and dtype as normalized_data.
    """
    n_steps, n_dims = normalized_data.shape[-2], normalized_data.shape[-1]
    n_features = n_steps * n_dims

    # Lay the per-feature quantiles out like the trailing axes of the data so they broadcast.
    q_lower = np.asarray(quantiles['q_lower'], dtype=np.float64).reshape(-1)[:n_features].reshape(n_steps, n_dims)
    q_upper = np.asarray(quantiles['q_upper'], dtype=np.float64).reshape(-1)[:n_features].reshape(n_steps, n_dims)

    # The dummy denominator of 1 only avoids a division by zero for degenerate features,
    # whose values are passed through untouched.
    degenerate = q_upper == q_lower
    scale = 2.0 / np.where(degenerate, 1.0, q_upper - q_lower)
    restored = np.where(degenerate, normalized_data, (normalized_data + 1) / scale + q_lower)

    denormalized_data = np.empty_like(normalized_data)
    denormalized_data[...] = restored
    return denormalized_data


def get_norm_info(path):
    """Load a JSON object from ``path`` and return it with every value converted to a numpy array."""
    with open(path, "r", encoding="utf-8") as f:
        norm_info = json.load(f)
    return {k: np.asarray(v) for k, v in norm_info.items()}
def __init__(self, config):
    """Set up the shared FAST tokenizer and the per-type normalization tables.

    One ``UniversalActionProcessor`` checkpoint is loaded and registered
    under all three agent-kind keys ("cyc", "ped", "veh"); only the
    normalization quantiles differ per kind.

    Args:
        config: Forwarded unchanged to ``BicycleModelTokenizerFixed0124.__init__``.
    """
    BicycleModelTokenizerFixed0124.__init__(self, config)

    from scenestreamer.utils import REPO_ROOT

    self.use_type_specific_bins = False

    shared_tokenizer = UniversalActionProcessor.from_pretrained(
        REPO_ROOT / "scenestreamer/tokenization/0305_fast_all", time_horizon=5, action_dim=3
    )

    # Quantile files, listed in the order they are loaded.
    norm_info_files = {
        "cyc": "scenestreamer/tokenization/0305_fast_all/norm_info_cyc.json",
        "ped": "scenestreamer/tokenization/0305_fast_all/norm_info_ped.json",
        "veh": "scenestreamer/tokenization/0305_fast_all/norm_info_veh.json",
    }
    loaded_norm_infos = {}
    for kind, relative_path in norm_info_files.items():
        loaded_norm_infos[kind] = get_norm_info(REPO_ROOT / relative_path)

    self.fast_tokenizers = dict.fromkeys(("cyc", "ped", "veh"), shared_tokenizer)
    self.fast_tokenizer = shared_tokenizer
    self.norm_infos = loaded_norm_infos

    # NOTE(review): presumably 1024 vocabulary entries plus 3 special ids -
    # confirm against the checkpoint's vocab size.
    self.num_actions = 1024 + 3
def tokenize(self, data_dict, backward_prediction=False, **kwargs):
    """Encode ground-truth agent motion into per-chunk FAST action tokens.

    The trajectory is cut into windows of 6 consecutive frames with stride 5
    (neighbouring windows share one frame). Each window is encoded by
    ``_tokenize_a_step``; the state reconstructed from the tokens, not the
    ground truth, becomes the starting state of the next window. Agents that
    become valid mid-scene are re-seeded from ground truth and receive
    ``MOTION_START_ACTION`` as their input token at that step.

    Args:
        data_dict: Scenario dict. Reads ``decoder/agent_position``,
            ``decoder/agent_heading``, ``decoder/agent_valid_mask``,
            ``decoder/agent_velocity``, ``decoder/current_agent_shape`` and
            ``decoder/agent_type``. Only batch size 1 is supported.
        backward_prediction: Must be falsy.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        tuple: ``(data_dict, info)``. ``data_dict`` is the dict returned by
        ``hole_filling`` with ``decoder/target_action``,
        ``decoder/target_action_valid_mask``, ``decoder/input_action``,
        ``decoder/input_action_valid_mask`` and the ``decoder/modeled_agent_*``
        entries added. Token tensors are padded with -1 along the last axis.
        ``info`` is ``{"reconstruction_list": []}`` (the list is never filled).

    Raises:
        ValueError: If ``backward_prediction`` is truthy.
    """
    if backward_prediction:
        raise ValueError("FastTokenizer does not support backward prediction.")

    # TODO: Hardcoded here...
    assert self.config.GPT_STYLE
    start_step = 0

    # ===== Hole Filling =====
    data_dict = self.hole_filling(data_dict)

    # ===== Get initial data =====
    agent_pos = data_dict["decoder/agent_position"]
    agent_heading = data_dict["decoder/agent_heading"]
    agent_valid_mask = data_dict["decoder/agent_valid_mask"]
    agent_velocity = data_dict["decoder/agent_velocity"]
    agent_shape = data_dict["decoder/current_agent_shape"]
    agent_type = data_dict["decoder/agent_type"]
    B, _, N, _ = agent_pos.shape
    assert agent_pos.ndim == 4

    # ===== Build the per-chunk windows =====
    # unfold puts the window axis last; swapaxes moves it in front of the
    # feature axis. The last frame of window k is the first frame of k + 1.
    # NOTE(review): window size 6 / stride 5 are hard-coded and presumably
    # require self.num_skipped_steps == 5 - confirm.
    agent_pos_chunk = agent_pos.unfold(dimension=1, size=6, step=5).swapaxes(-1, -2)
    agent_heading_chunk = agent_heading.unfold(dimension=1, size=6, step=5)
    agent_velocity_chunk = agent_velocity.unfold(dimension=1, size=6, step=5).swapaxes(-1, -2)

    # A window is usable only when all 6 of its frames are valid.
    agent_valid_mask_chunk = agent_valid_mask.unfold(dimension=1, size=6, step=5).all(dim=-1)

    # ===== Skip some steps =====
    agent_pos = agent_pos[:, ::self.num_skipped_steps]
    agent_heading = agent_heading[:, ::self.num_skipped_steps]
    agent_valid_mask = agent_valid_mask[:, ::self.num_skipped_steps]
    agent_velocity = agent_velocity[:, ::self.num_skipped_steps]

    # Pad the window mask with the last downsampled validity step.
    agent_valid_mask_chunk = torch.cat([agent_valid_mask_chunk, agent_valid_mask[:, -1:]], dim=1)

    T_chunks = agent_pos.shape[1]

    # ===== Build up some variables =====
    current_pos = agent_pos[:, start_step:start_step + 1, ..., :2]
    current_heading = agent_heading[:, start_step:start_step + 1]
    current_vel = agent_velocity[:, start_step:start_step + 1, ..., :2]
    current_valid_mask = agent_valid_mask[:, start_step:start_step + 1]

    init_pos = current_pos.clone()
    init_heading = current_heading.clone()
    init_vel = current_vel.clone()
    init_valid_mask = current_valid_mask.clone()

    assert self.config.DELTA_POS_IS_VELOCITY
    init_delta = get_relative_velocity(current_vel, current_heading)

    # Select correct bins:
    bin_centers = self.get_bin_centers(agent_type)

    target_action = []
    target_action_valid_mask = []
    reconstruction_list = []
    relative_delta_pos_list = []
    pos = []
    heading = []
    vel = []

    # ===== Loop to reconstruct the scenario =====
    tokenization_state = None
    for next_step in range(start_step + 1, T_chunks):
        res = self._tokenize_a_step(
            current_pos=current_pos,
            current_heading=current_heading,
            current_vel=current_vel,
            current_valid_mask=current_valid_mask,
            next_pos=agent_pos[:, next_step:next_step + 1, ..., :2],  # (B, 1, N, 2)
            next_heading=agent_heading[:, next_step:next_step + 1],  # (B, 1, N)
            next_valid_mask=agent_valid_mask_chunk[:, next_step - 1:next_step],  # (B, 1, N)
            next_velocity=agent_velocity[:, next_step:next_step + 1, ..., :2],  # (B, 1, N, 2)
            bin_centers=bin_centers,
            add_noise=False,
            topk=self.config.TOKENIZATION.NOISE_TOPK,
            agent_shape=agent_shape,
            agent_type=agent_type,
            dt=self.dt,
            tokenization_state=tokenization_state,
            agent_pos_full=agent_pos_chunk[:, next_step - 1],
            agent_heading_full=agent_heading_chunk[:, next_step - 1],
            agent_velocity_full=agent_velocity_chunk[:, next_step - 1],
        )
        tokenization_state = res

        recon_next_pos = res["pos"]
        recon_next_heading = res["heading"]
        recon_next_vel = res["vel"]
        recon_next_valid_mask = res["mask"]
        recon_next_delta_pos = res["delta_pos"]  # The input delta for next step.

        # ===== Process the target action/valid mask =====
        # Clone first: the mask is edited in place below for newly added agents.
        target_action_valid_mask.append(recon_next_valid_mask.clone())
        target_action.append(res["action"])

        # ===== Process the "current_xxx" for next step =====
        if self.config.GPT_STYLE:
            assert self.config.TOKENIZATION.ALLOW_SKIP_STEP
        if self.config.TOKENIZATION.ALLOW_SKIP_STEP:
            # Use the next valid mask as the valid mask for next step, and seed
            # agents that just appeared with their ground-truth state.
            next_valid_mask = agent_valid_mask[:, next_step:next_step + 1]
            newly_added = torch.logical_and(~recon_next_valid_mask, next_valid_mask)
            if newly_added.any():
                gt_pos = agent_pos[:, next_step:next_step + 1, ..., :2]
                gt_heading = agent_heading[:, next_step:next_step + 1]
                gt_vel = agent_velocity[:, next_step:next_step + 1, ..., :2]
                assert not (gt_pos[newly_added] == 0.0).all(-1).any()
                recon_next_pos[newly_added] = gt_pos[newly_added]
                recon_next_heading[newly_added] = gt_heading[newly_added]
                recon_next_vel[newly_added] = gt_vel[newly_added]
                recon_next_delta_pos[newly_added] = get_relative_velocity(
                    vel=gt_vel[newly_added],
                    heading=gt_heading[newly_added],
                )
                recon_next_valid_mask[newly_added] = next_valid_mask[newly_added]

        relative_delta_pos_list.append(recon_next_delta_pos)
        current_vel = recon_next_vel
        current_heading = recon_next_heading
        current_pos = recon_next_pos
        current_valid_mask = recon_next_valid_mask
        pos.append(current_pos.clone())
        heading.append(current_heading.clone())
        vel.append(current_vel.clone())

    # ===== Pack the variable-length token lists into one padded tensor =====
    max_token_len = 0
    for step_tokens in target_action:
        # default=0 keeps an agent-free scene from raising on an empty max().
        max_token_len = max(max_token_len, max((len(v) for v in step_tokens), default=0))
    target_actions = torch.full((B, T_chunks - 1, N, max_token_len), -1, dtype=torch.long)
    assert B == 1
    for i, step_tokens in enumerate(target_action):
        for j, tokens in enumerate(step_tokens):
            target_actions[0, i, j, :len(tokens)] = torch.from_numpy(np.asarray(tokens))

    target_action_valid_mask = torch.cat(target_action_valid_mask, dim=1)  # (B, T_skipped, N)
    relative_delta_pos_list = torch.cat(relative_delta_pos_list, dim=1)  # (B, T_skipped, N)
    pos = torch.cat(pos, dim=1)
    heading = torch.cat(heading, dim=1)
    vel = torch.cat(vel, dim=1)

    pos = torch.cat([init_pos, pos], dim=1)
    heading = torch.cat([init_heading, heading], dim=1)
    vel = torch.cat([init_vel, vel], dim=1)
    relative_delta_pos_list = torch.cat([init_delta, relative_delta_pos_list], dim=1)

    # ===== Postprocess and prepare the "start action" =====
    # In GPT style, some agents will be added in the middle of the scene, so we
    # find the step at which each one first appears and put a start action there.
    assert self.config.GPT_STYLE
    assert start_step == 0
    already_tokenized = init_valid_mask.clone()
    start_action = torch.full_like(target_actions[:, :1], -1)
    start_action[init_valid_mask] = MOTION_START_ACTION
    assert target_actions.shape[1] == T_chunks - 1
    # Inputs are the targets shifted right by one step.
    input_action = torch.cat([start_action, target_actions], dim=1)
    input_action_valid_mask = torch.cat([init_valid_mask, target_action_valid_mask], dim=1)
    for next_step in range(start_step + 1, T_chunks):
        next_valid_mask = agent_valid_mask[:, next_step:next_step + 1]
        is_newly_added = torch.logical_and(~already_tokenized, next_valid_mask)
        if is_newly_added.any():
            input_action[:, next_step:next_step + 1][is_newly_added] = MOTION_START_ACTION
            input_action_valid_mask[:, next_step:next_step + 1][is_newly_added] = \
                next_valid_mask[is_newly_added]
        already_tokenized = torch.logical_or(already_tokenized, is_newly_added)

    # The last step has no ground-truth target: pad it with -1 / invalid.
    target_actions = torch.cat([target_actions, target_actions.new_full((B, 1, N, max_token_len), -1)], dim=1)
    target_action_valid_mask = torch.cat(
        [target_action_valid_mask, target_action_valid_mask.new_zeros((B, 1, N))], dim=1
    )
    data_dict["in_backward_prediction"] = False
    assert (agent_valid_mask[:, start_step:] >= target_action_valid_mask).all()
    assert (agent_valid_mask[:, start_step + 1:] >= target_action_valid_mask[:, :-1]).all()
    assert (agent_valid_mask[:, start_step:] >= input_action_valid_mask).all()

    data_dict["decoder/target_action"] = target_actions
    data_dict["decoder/target_action_valid_mask"] = target_action_valid_mask
    data_dict["decoder/input_action"] = input_action
    data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask
    data_dict["decoder/modeled_agent_delta"] = relative_delta_pos_list
    data_dict["decoder/modeled_agent_position"] = pos
    data_dict["decoder/modeled_agent_heading"] = heading
    data_dict["decoder/modeled_agent_velocity"] = vel

    # Valid slots carry at least one real token; invalid slots are all -1.
    assert (input_action[input_action_valid_mask] >= 0).any(-1).all()
    assert (target_actions[target_action_valid_mask] >= 0).any(-1).all()
    assert (input_action[~input_action_valid_mask] == -1).all(-1).all()
    assert (target_actions[~target_action_valid_mask] == -1).all(-1).all()

    return data_dict, {"reconstruction_list": reconstruction_list}
def _tokenize_a_step(
    self, *, current_pos, current_heading, current_valid_mask, current_vel, next_pos, next_heading, next_valid_mask,
    add_noise, agent_shape, dt, agent_pos_full, agent_heading_full, agent_velocity_full, agent_type, next_velocity,
    **kwargs
):
    """Tokenize one 6-frame window for every agent and decode it straight back.

    The window is expressed in each agent's current local frame, turned into
    5 frame-to-frame deltas of (position, heading), normalized with the
    quantiles of the agent's type and encoded by the FAST tokenizer. The
    tokens are then detokenized so the caller continues from the state the
    tokens actually reproduce.

    Args:
        current_pos: Current positions; unpacked as (B, 1, N, 2).
        current_heading: Current headings, reshaped to (B, N, 1).
        current_valid_mask: Validity of each agent at the current step.
        next_valid_mask: Validity of each agent over the window being encoded.
        dt: Step duration; negative values are rejected.
        agent_pos_full: Window positions, (B, N, 6, >=2); only x, y are used.
        agent_heading_full: Window headings, (B, N, 6).
        agent_type: Integer type per agent, indexed as (B, N);
            1 = vehicle, 2 = pedestrian, 3 = cyclist.
        current_vel, next_pos, next_heading, add_noise, agent_shape,
        agent_velocity_full, next_velocity, **kwargs: Accepted for interface
            compatibility; unused here.

    Returns:
        dict: ``action`` (list of N token sequences, ``[]`` for agents not
        valid across the window), ``pos``, ``heading``, ``vel``, ``delta_pos``
        (decoded state, zeroed where invalid) and ``mask``.

    Raises:
        ValueError: If ``dt`` is negative.
    """
    if dt < 0:
        raise ValueError("FastTokenizer does not support backward prediction.")

    B, _, N, _ = current_pos.shape
    assert B == 1

    assert agent_pos_full.ndim == 4
    agent_pos_full = agent_pos_full[..., :2]  # (B, N, 6, 2)
    assert agent_heading_full.ndim == 3

    valid_mask = torch.logical_and(current_valid_mask, next_valid_mask)

    # Move the window into the agent's local frame: translate, then rotate by
    # the negative current heading.
    local_pos = agent_pos_full - current_pos.reshape(B, N, 1, 2)
    local_pos = utils.rotate(
        local_pos[..., 0], local_pos[..., 1],
        -current_heading.reshape(B, N, 1).expand(-1, -1, 6)
    )
    local_heading = utils.wrap_to_pi(agent_heading_full - current_heading.reshape(B, N, 1))

    # Stack into (B, N, 6, 3) and zero agents that are not valid throughout.
    chunk = torch.cat([local_pos, local_heading[..., None]], dim=-1)
    chunk = chunk.masked_fill_(~valid_mask.reshape(B, N, 1, 1).expand(-1, -1, 6, 3), 0)

    # Frame-to-frame deltas, (B, N, 5, 3), with x and y swapped; the decoder
    # applies the same swap to undo it.
    chunk_delta = chunk[:, :, 1:] - chunk[:, :, :-1]
    chunk_delta = chunk_delta[..., [1, 0, 2]]

    # Encode each agent type with its own quantiles. A type with no agents is
    # skipped so the tokenizer is never called on an empty batch.
    tokenized_chunk = [None] * N
    for kind, type_id in (("ped", 2), ("cyc", 3), ("veh", 1)):
        type_mask = agent_type == type_id
        if not type_mask.any():
            continue
        normalized, _ = normalize_actions(
            chunk_delta[type_mask].numpy(), predefined_quantiles=self.norm_infos[kind]
        )
        normalized = torch.from_numpy(normalized).float()
        encoded = self.fast_tokenizers[kind](normalized)
        # Boolean indexing keeps agent order, so row k belongs to the k-th
        # agent of this type.
        agent_indices = torch.nonzero(type_mask[0]).flatten().tolist()
        for row, agent_index in enumerate(agent_indices):
            tokenized_chunk[agent_index] = encoded[row]

    detok = self._detokenize_a_step(
        current_pos=current_pos,
        current_heading=current_heading,
        current_valid_mask=current_valid_mask,
        action=tokenized_chunk,
        agent_type=agent_type
    )
    recon_pos = detok["pos"]
    recon_heading = detok["heading"]
    recon_vel = detok["vel"]
    recon_delta = detok["delta_pos"]

    recon_pos[~valid_mask] = 0
    recon_heading[~valid_mask] = 0
    recon_vel[~valid_mask] = 0
    recon_delta[~valid_mask] = 0

    # Agents not valid across the whole window contribute no tokens.
    for i, is_valid in enumerate(valid_mask[0, 0]):
        if not is_valid:
            tokenized_chunk[i] = []

    return dict(
        action=tokenized_chunk,
        pos=recon_pos,
        heading=recon_heading,
        vel=recon_vel,
        mask=valid_mask,
        delta_pos=recon_delta,
    )
def detokenize(
    self,
    data_dict,
    interpolation=True,
    detokenizing_gt=False,
    backward_prediction=False,
    flip_wrong_heading=False,
    autoregressive_start_step=2,
    **kwargs,
):  # actions, current_pos, current_vel, current_heading):
    """Roll action tokens forward into frame-level trajectories.

    Starting from the first frame in ``data_dict``, each token step in
    ``decoder/output_action`` is decoded by ``_detokenize_a_step`` and chained
    onto the previous decoded state. The 5 frames decoded per step are
    collected, so the output has ``num_skipped_steps * T + 1`` frames for
    ``T`` token steps.

    Args:
        data_dict: Reads the ``decoder/agent_*`` entries,
            ``decoder/current_agent_shape``, ``decoder/output_action`` and
            ``decoder/input_action_valid_mask`` (plus
            ``decoder/target_action_valid_mask`` when ``detokenizing_gt``).
        interpolation: Not read in this method.
        detokenizing_gt: When true, the input mask is intersected with the
            target mask, and agents that appear later are seeded from ground
            truth at every step but the last.
        backward_prediction: Not read in this method.
        flip_wrong_heading: Forwarded to ``_detokenize_a_step``.
        autoregressive_start_step: Before this step, agents that appear are
            always seeded from ground truth.
        **kwargs: Forwarded to ``_detokenize_a_step``; the key
            ``"detokenization_state"`` is overwritten here.

    Returns:
        The same ``data_dict`` with the ``decoder/reconstructed_*`` entries
        (position, heading, velocity, valid_mask) added.
    """
    # TODO: Hardcoded here...
    assert self.config.GPT_STYLE
    start_step = 0

    # ===== Get initial data =====
    # Cloned because the states are edited in place further down.
    agent_pos = data_dict["decoder/agent_position"].clone()
    agent_heading = data_dict["decoder/agent_heading"].clone()
    agent_valid_mask = data_dict["decoder/agent_valid_mask"].clone()
    agent_velocity = data_dict["decoder/agent_velocity"].clone()
    agent_shape = data_dict["decoder/current_agent_shape"].clone()
    agent_type = data_dict["decoder/agent_type"].clone()
    if detokenizing_gt:
        target_action_valid_mask = data_dict["decoder/target_action_valid_mask"]
    # Read unconditionally: the per-step loop below indexes input_mask whether
    # or not detokenizing_gt is set.
    input_mask = data_dict["decoder/input_action_valid_mask"]
    B, T_full, N, _ = agent_pos.shape
    assert agent_pos.ndim == 4

    # ===== Skip some steps =====
    # Keep one frame per chunk so index k is the state at the start of chunk k.
    agent_pos = agent_pos[:, ::self.num_skipped_steps].clone()
    agent_heading = agent_heading[:, ::self.num_skipped_steps]
    agent_valid_mask = agent_valid_mask[:, ::self.num_skipped_steps]
    agent_velocity = agent_velocity[:, ::self.num_skipped_steps]
    # T_chunks = agent_pos.shape[1]

    # ===== Prepare some variables =====
    action = data_dict["decoder/output_action"]
    T_actions = action.shape[1]
    T_generated_chunks = T_actions + start_step

    current_pos = agent_pos[:, start_step:start_step + 1, ..., :2].clone()
    current_heading = agent_heading[:, start_step:start_step + 1].clone()
    current_vel = agent_velocity[:, start_step:start_step + 1, ..., :2].clone()
    current_valid_mask = agent_valid_mask[:, start_step:start_step + 1].clone()

    if detokenizing_gt:
        # Merge input mask with target mask
        input_mask = input_mask & target_action_valid_mask

    # Chunk-level states: one entry per token step plus the initial state.
    reconstructed_pos_list = [current_pos.clone()]
    reconstructed_heading_list = [current_heading.clone()]
    reconstructed_vel_list = [current_vel.clone()]

    # Frame-level states: the initial frame plus the frames of every chunk.
    already_interpolated = False
    reconstructed_pos_full_list = [current_pos.clone()]
    reconstructed_heading_full_list = [current_heading.clone()]
    reconstructed_vel_full_list = [current_vel.clone()]

    # Select correct bins:
    bin_centers = self.get_bin_centers(agent_type)

    kwargs["detokenization_state"] = None

    for curr_step in range(T_generated_chunks):

        # We assume that starting from start_step, the agent valid mask will not change.
        action_step = curr_step - start_step
        action_valid_mask_step = input_mask[:, action_step:action_step + 1]

        act = action[:, action_step:action_step + 1]
        # Every valid agent must carry at least one non-padding token.
        assert (act[action_valid_mask_step] != -1).any(-1).all()
        res = self._detokenize_a_step(
            current_pos=current_pos,
            current_heading=current_heading,
            current_valid_mask=action_valid_mask_step,
            current_vel=current_vel,
            action=act,
            agent_shape=agent_shape,
            agent_type=agent_type,
            bin_centers=bin_centers,
            dt=self.dt,
            flip_wrong_heading=flip_wrong_heading,
            **kwargs
        )
        # The previous step's result is handed to the next call.
        kwargs["detokenization_state"] = res

        next_pos, next_heading, next_vel = res["pos"], res["heading"], res["vel"]
        assert "delta_pos" in res
        next_pos = next_pos.reshape(B, 1, N, 2)
        next_heading = next_heading.reshape(B, 1, N)
        next_vel = next_vel.reshape(B, 1, N, 2)
        # NOTE: this is an alias, not a copy; the masked assignment below
        # therefore also updates current_valid_mask.
        next_valid_mask = current_valid_mask

        # ===== A special case: fill in the info for the agents added in next step =====
        # ===== Another special case: if you are detokenizing the raw tokenized data, you need to fill in
        # the info for the agents added in the next step. =====
        if (curr_step < autoregressive_start_step) or (detokenizing_gt and curr_step < T_generated_chunks - 1):
            # Fill in the initial states of newly added agents
            action_valid_mask_next_step = input_mask[:, action_step + 1:action_step + 2]
            newly_added = torch.logical_and(~action_valid_mask_step, action_valid_mask_next_step)
            if newly_added.any():
                next_pos[newly_added] = agent_pos[:, curr_step + 1:curr_step + 2, ..., :2][newly_added]
                next_heading[newly_added] = agent_heading[:, curr_step + 1:curr_step + 2][newly_added]
                next_vel[newly_added] = agent_velocity[:, curr_step + 1:curr_step + 2, ..., :2][newly_added]
                next_valid_mask[newly_added] = action_valid_mask_next_step[newly_added]
                if "reconstructed_position" in res:
                    # If some agents are added in the next step, the "last step" in reconstructed chunk
                    # aka the 5-th step in the chunk should be replaced by the GT states.
                    assert (agent_pos[:, curr_step + 1:curr_step + 2, ..., :2][newly_added][..., 0] != 0).all()
                    res["reconstructed_position"][-1][newly_added] = agent_pos[:, curr_step + 1:curr_step + 2,
                                                                               ..., :2][newly_added]
                    res["reconstructed_heading"][-1][newly_added] = agent_heading[:, curr_step + 1:curr_step +
                                                                                  2][newly_added]
                    res["reconstructed_velocity"][-1][newly_added] = agent_velocity[:, curr_step + 1:curr_step + 2,
                                                                                    ..., :2][newly_added]

        if "reconstructed_position" in res:
            already_interpolated = True
            reconstructed_pos_full_list.extend(res["reconstructed_position"])
            reconstructed_heading_full_list.extend(res["reconstructed_heading"])
            reconstructed_vel_full_list.extend(res["reconstructed_velocity"])

        current_pos = next_pos
        current_heading = next_heading
        current_vel = next_vel
        current_valid_mask = next_valid_mask

        reconstructed_pos_list.append(current_pos.clone())
        reconstructed_heading_list.append(current_heading.clone())
        reconstructed_vel_list.append(current_vel.clone())

    # The chunk-level tensors are only used for the shape checks just below;
    # they are replaced by the frame-level tensors afterwards.
    reconstructed_pos = torch.cat(reconstructed_pos_list, dim=1)
    reconstructed_heading = torch.cat(reconstructed_heading_list, dim=1)
    reconstructed_vel = torch.cat(reconstructed_vel_list, dim=1)

    # Every input token has its own position (before the action), and the
    # last token leads to one more position: T tokens give T + 1 positions.
    assert reconstructed_pos.shape[1] == T_generated_chunks + 1
    assert input_mask.shape[1] == T_generated_chunks - start_step

    # Interpolation
    reconstructed_pos = torch.cat(reconstructed_pos_full_list, dim=1)
    reconstructed_heading = torch.cat(reconstructed_heading_full_list, dim=1)
    reconstructed_vel = torch.cat(reconstructed_vel_full_list, dim=1)

    input_mask_augmented = torch.cat([agent_valid_mask[:, :start_step], input_mask], dim=1)
    assert input_mask_augmented.shape[1] == T_generated_chunks
    valid = input_mask_augmented
    # Repeat each chunk's mask num_skipped_steps times to reach frame rate,
    # then append the last chunk's mask once more for the final frame.
    valid = valid.reshape(B, -1, 1, N).expand(-1, -1, self.num_skipped_steps, -1).reshape(B, -1, N)
    valid = torch.cat([valid, input_mask[:, -1:]], dim=1)
    reconstructed_valid_mask = valid

    # Mask out:
    reconstructed_pos = reconstructed_pos * reconstructed_valid_mask.unsqueeze(-1)
    reconstructed_vel = reconstructed_vel * reconstructed_valid_mask.unsqueeze(-1)
    reconstructed_heading = reconstructed_heading * reconstructed_valid_mask

    # We ensure that the output must be num_skipped_steps * T_chunks + 1
    assert reconstructed_pos.shape[1] == self.num_skipped_steps * T_generated_chunks + 1
    assert reconstructed_valid_mask.shape[1] == self.num_skipped_steps * T_generated_chunks + 1
    assert reconstructed_vel.shape[1] == self.num_skipped_steps * T_generated_chunks + 1
    assert reconstructed_heading.shape[1] == self.num_skipped_steps * T_generated_chunks + 1

    data_dict["decoder/reconstructed_position"] = reconstructed_pos
    data_dict["decoder/reconstructed_heading"] = reconstructed_heading
    data_dict["decoder/reconstructed_velocity"] = reconstructed_vel
    data_dict["decoder/reconstructed_valid_mask"] = reconstructed_valid_mask

    return data_dict
def _detokenize_a_step(self, *, current_pos, current_heading, current_valid_mask, action, agent_type, **kwargs):
    """Decode one step of FAST tokens into the next 5 frames of every agent.

    Args:
        current_pos: Positions the chunk starts from; reshaped to (B, N, 1, 2).
        current_heading: Headings the chunk starts from; reshaped to (B, N, 1).
        current_valid_mask: Agents to decode; everything else is zero-filled.
        action: Either a list with one token sequence per agent, or a
            (B, 1, N, max_tokens) integer tensor padded with -1.
        agent_type: Integer type per agent; 1 = vehicle, 2 = pedestrian,
            3 = cyclist.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        dict: ``pos``, ``heading``, ``vel`` and ``delta_pos`` for the last
        frame of the chunk; ``reconstructed_position``,
        ``reconstructed_heading`` and ``reconstructed_velocity`` as lists of
        5 per-frame tensors; ``error_rate_full`` with one decode error rate
        per agent.
    """
    device = current_pos.device

    if isinstance(action, list):
        tokens = action
        B, _, N, _ = current_pos.shape
    else:
        B, _, N, _ = action.shape

        assert action.max() < self.fast_tokenizer.vocab_size

        assert action.ndim == 4
        assert action.shape[1] == 1
        flat_action = action.squeeze(1)  # (B, N, max_num_tokens)

        # One row per agent, with the -1 padding stripped.
        flat_action = flat_action.reshape(-1, flat_action.shape[-1])
        tokens = [row[row != -1] for row in flat_action]

    chunk_full = torch.zeros((B * N, 5, 3)).to(device)
    error_rate_full = np.zeros((B * N,))

    # Same decode recipe for every agent type; only the tokenizer entry and
    # the quantiles differ. Order matches the original: ped, cyc, veh.
    for kind, type_id in (("ped", 2), ("cyc", 3), ("veh", 1)):
        selected = (agent_type == type_id) & current_valid_mask.reshape(B, N)
        selected = selected.reshape(-1)
        if not selected.any():
            continue

        hits = [i for i in range(len(tokens)) if selected[i]]
        decoded, error_rate = self.fast_tokenizers[kind].decode([tokens[i] for i in hits])

        if isinstance(selected, torch.Tensor):
            error_rate_full[selected.cpu().numpy()] = error_rate
        else:
            error_rate_full[selected] = error_rate

        decoded = denormalize_actions(decoded, quantiles=self.norm_infos[kind])
        decoded = torch.from_numpy(decoded).float().to(device)
        # Scatter the decoded rows back to the agents they belong to.
        for row, agent_index in enumerate(hits):
            chunk_full[agent_index] = decoded[row]

    # Reshape back
    chunk_full = chunk_full.reshape(B, N, 5, 3)

    # Deltas -> cumulative offsets from the current state.
    chunk_full = chunk_full.cumsum(dim=-2)

    # Swap x and y (inverse of the swap applied when tokenizing).
    chunk_full = chunk_full[..., [1, 0, 2]]

    # Rotate the local offsets into the global frame.
    chunk_pos = utils.rotate(
        chunk_full[..., 0], chunk_full[..., 1],
        current_heading.reshape(B, N, 1).expand(-1, -1, 5)
    )
    chunk_head = utils.wrap_to_pi(chunk_full[..., 2] + current_heading.reshape(B, N, 1).expand(-1, -1, 5))

    # Translation
    chunk_pos = chunk_pos + current_pos.reshape(B, N, 1, 2).expand(-1, -1, 5, 2)

    # Zero out agents that were not decoded.
    chunk_pos = chunk_pos.masked_fill_(~current_valid_mask.reshape(B, N, 1, 1).expand(-1, -1, 5, 2), 0)
    chunk_head = chunk_head.masked_fill_(~current_valid_mask.reshape(B, N, 1).expand(-1, -1, 5), 0)

    # State at the end of the chunk.
    reconstructed_pos = chunk_pos[:, :, -1].reshape(B, 1, N, 2)
    reconstructed_heading = chunk_head[:, :, -1].reshape(B, 1, N)

    reconstructed_vel = (reconstructed_pos - current_pos) / self.dt

    relative_delta_pos = get_relative_velocity(reconstructed_vel, reconstructed_heading)

    # Per-frame velocity by finite differences over the 6-frame window.
    chunk_pos_6steps = torch.cat([current_pos.reshape(B, N, 1, 2), chunk_pos], dim=-2)
    recon_vel = (chunk_pos_6steps[:, :, 1:] - chunk_pos_6steps[:, :, :-1]) / (self.dt / self.num_skipped_steps)

    return dict(
        pos=reconstructed_pos,
        heading=reconstructed_heading,
        vel=reconstructed_vel,
        delta_pos=relative_delta_pos,
        reconstructed_position=[chunk_pos[:, :, i].unsqueeze(1) for i in range(5)],  # (B, N, 5, 2)
        reconstructed_heading=[chunk_head[:, :, i].unsqueeze(1) for i in range(5)],  # (B, N, 5)
        reconstructed_velocity=[recon_vel[:, :, i].unsqueeze(1) for i in range(5)],  # (B, N, 5, 2)
        error_rate_full=error_rate_full,
    )
@dataclass
class Tokens:
    """A token sequence with its validity mask and per-token causal offsets.

    The three arrays always share one shape and are either all numpy arrays
    or all torch tensors.
    """

    # Token ids.
    ids: Union[torch.Tensor, np.ndarray]
    # Boolean validity mask.
    mask: Union[torch.Tensor, np.ndarray]

    # For the n-th token, a classic decoder causal mask has offset n: row n-1
    # of the mask is 1 for the first n columns and 0 afterwards. When several
    # objects are updated at the same time they must not depend on each other,
    # so all of them share one offset equal to the sequence length. A value of
    # -1 passed to `create` is shorthand for "the sequence length".
    causal_mask_offset: Union[torch.Tensor, np.ndarray]
    # Sequence length; `create` checks that no offset exceeds it.
    length: int

    @classmethod
    def create(cls, ids, mask, causal_mask_offset, length=None, use_numpy=None, device=None):
        """Build a validated ``Tokens``.

        Args:
            ids: Token ids. A numpy array or torch tensor decides the backend;
                for any other type ``use_numpy`` must be given.
            mask: Validity mask, same shape as ``ids``.
            causal_mask_offset: Offsets, same shape as ``ids``; -1 entries are
                replaced by ``length``. The caller's array is not modified.
            length: Sequence length. Defaults to ``len(ids)``, which requires
                1-D ``ids``.
            use_numpy: Backend choice when ``ids`` is neither an array nor a
                tensor.
            device: Torch device for the torch backend.

        Raises:
            ValueError: If the backend cannot be determined.
        """
        if isinstance(ids, np.ndarray):
            use_numpy = True
        elif isinstance(ids, torch.Tensor):
            use_numpy = False
        elif use_numpy is None:
            raise ValueError("use_numpy must be specified when the ids is not a numpy array or a torch tensor.")

        if use_numpy:
            ids = np.asarray(ids, dtype=int)
            mask = np.asarray(mask, dtype=bool)
            # np.array copies, so the -1 substitution below cannot leak into
            # the caller's array.
            causal_mask_offset = np.array(causal_mask_offset, dtype=int)
        else:
            ids = torch.as_tensor(ids, device=device).int()
            mask = torch.as_tensor(mask, device=device).bool()
            # .int() returns the input itself when it is already int32.
            causal_mask_offset = torch.as_tensor(causal_mask_offset, device=device).int().clone()
        if length is None:
            assert ids.ndim == 1, ids.shape
            length = len(ids)
        causal_mask_offset[causal_mask_offset == -1] = length
        assert causal_mask_offset.min() > 0, causal_mask_offset.min()
        assert ids.shape == mask.shape == causal_mask_offset.shape, (ids.shape, mask.shape, causal_mask_offset.shape)
        assert causal_mask_offset.max() <= length, (causal_mask_offset.max(), length)

        return cls(ids=ids, mask=mask, causal_mask_offset=causal_mask_offset, length=length)

    @classmethod
    def concatenate(cls, group_list, axis=-1):
        """Join several ``Tokens``, shifting each group's offsets by the total
        length of the groups before it."""
        if isinstance(group_list[0].ids, torch.Tensor):
            cat = torch.cat
        else:
            cat = np.concatenate
        data = cat([t.ids for t in group_list], axis=axis)
        mask = cat([t.mask for t in group_list], axis=axis)

        shift = 0
        shifted_offsets = []
        for t in group_list:
            shifted_offsets.append(t.causal_mask_offset + shift)
            shift += len(t)
        causal_mask_offset = cat(shifted_offsets, axis=axis)

        return cls.create(data, mask, causal_mask_offset, length=shift)

    def __len__(self):
        return self.length

    def to_tensor(self, batch_size, device):
        """Return a new torch-backed ``Tokens`` repeated ``batch_size`` times
        along a new leading axis.

        ``self`` is left untouched, so the method can be called repeatedly.
        ``device`` is applied only when the data is converted from numpy.
        """
        assert self.ids.ndim == 1
        ids, mask, offset = self.ids, self.mask, self.causal_mask_offset
        if isinstance(ids, np.ndarray):
            ids = torch.as_tensor(ids, device=device)
            mask = torch.as_tensor(mask, device=device)
            offset = torch.as_tensor(offset, device=device)
        # repeat() on a 1-D tensor with two factors yields (batch_size, L).
        return Tokens.create(
            ids.repeat(batch_size, 1),
            mask.repeat(batch_size, 1),
            offset.repeat(batch_size, 1),
            self.length,
            use_numpy=False
        )

    def unbatch(self):
        """Split a batched (2-D) ``Tokens`` into one ``Tokens`` per row."""
        assert self.ids.ndim == 2
        return [
            Tokens.create(self.ids[i], self.mask[i], self.causal_mask_offset[i], self.length, use_numpy=False)
            for i in range(len(self.ids))
        ]

    @staticmethod
    def block_causal_mask_offset(number_of_objects, batch_size, device):
        """Offsets for objects updated simultaneously: every entry equals
        ``number_of_objects``."""
        return torch.full((batch_size, number_of_objects), number_of_objects, dtype=torch.long, device=device)
self.length + + def to_tensor(self, batch_size, device): + assert self.ids.ndim == 1 + if isinstance(self.ids, np.ndarray): + self.ids = torch.as_tensor(self.ids, device=device).unsqueeze(0) + self.mask = torch.as_tensor(self.mask, device=device).unsqueeze(0) + self.causal_mask_offset = torch.as_tensor(self.causal_mask_offset, device=device).unsqueeze(0) + return Tokens.create( + self.ids.repeat(batch_size, 1), + self.mask.repeat(batch_size, 1), + self.causal_mask_offset.repeat(batch_size, 1), + self.length, + use_numpy=False + ) + + def unbatch(self): + assert self.ids.ndim == 2 + return [ + Tokens.create(self.ids[i], self.mask[i], self.causal_mask_offset[i], self.length, use_numpy=False) + for i in range(len(self.ids)) + ] + + @staticmethod + def block_causal_mask_offset(number_of_objects, batch_size, device): + return torch.empty((batch_size, number_of_objects), dtype=int, device=device).fill_(number_of_objects) + + # def add_step(self, step: int): + # if isinstance(self.ids, torch.Tensor): + # self.step = torch.zeros_like(self.ids).fill_(step) + # else: + # self.step = step + # return self + + +def translate_id(ids, min, max, allow_invalid=False, reverse=False): + if reverse: + ids = ids - min + return ids + + assert isinstance(min, int) + assert isinstance(max, int) + + if isinstance(ids, int): + if allow_invalid: + if ids == -1: + return -1 + else: + pass + else: + ids = -min - 1 + ids = ids + min + assert ids < max + assert ids >= min + return ids + + else: + assert isinstance(ids, (np.ndarray, torch.Tensor)) + if allow_invalid: + ids = copy.deepcopy(ids) + ids[ids == -1] = -min - 1 + else: + ids[ids == -1] = max - min - 1 # Set it to the maximum value + ids = ids + min + if allow_invalid: + max_val = ids[ids != -1].max() + min_val = ids[ids != -1].min() + else: + max_val = ids.max() + min_val = ids.min() + assert max_val < max + assert min_val >= min + return ids + + +def in_range(ids, min, max): + if isinstance(ids, torch.Tensor): + ret1 = ids >= min + 
ret2 = ids < max + return torch.logical_and(ret1, ret2) + elif isinstance(ids, np.ndarray): + ret1 = ids >= min + ret2 = ids < max + return np.logical_and(ret1, ret2) + else: + return min <= ids < max + + +class GenTokenizer: + NUM_OPERATIONS = 8 # TODO: Need to be changed to 8 for scenestreamer. + NUM_ACTIONS = 169 # TODO: from config + NUM_NOOP = 1 + + STEP_START = 0 + STEP_END = 1 + UPDATE_START = 2 + UPDATE_END = 3 + ADD_START = 4 + ADD_END = 5 + REMOVE_START = 6 + REMOVE_END = 7 + + @classmethod + def get_num_actions(cls, config): + return ( + cls.NUM_OPERATIONS + cls.NUM_ACTIONS + cls.NUM_NOOP + config.PREPROCESSING.MAX_MAP_FEATURES + + config.PREPROCESSING.MAX_AGENTS + ) + + @classmethod + def get_agent_id_range(cls, config): + return ( + cls.NUM_OPERATIONS + config.PREPROCESSING.MAX_MAP_FEATURES, + cls.NUM_OPERATIONS + config.PREPROCESSING.MAX_MAP_FEATURES + config.PREPROCESSING.MAX_AGENTS + ) + + @classmethod + def get_action_id_range(cls, config): + return ( + cls.NUM_OPERATIONS + config.PREPROCESSING.MAX_MAP_FEATURES + config.PREPROCESSING.MAX_AGENTS, + cls.NUM_OPERATIONS + config.PREPROCESSING.MAX_MAP_FEATURES + config.PREPROCESSING.MAX_AGENTS + + cls.NUM_ACTIONS + cls.NUM_NOOP + ) + + @classmethod + def get_map_id_range(cls, config): + return (cls.NUM_OPERATIONS, cls.NUM_OPERATIONS + config.PREPROCESSING.MAX_MAP_FEATURES) + + @classmethod + def get_agent_id(cls, agent_id, config, allow_invalid=True, reverse=False): + return translate_id(agent_id, *cls.get_agent_id_range(config), allow_invalid=allow_invalid, reverse=reverse) + + @classmethod + def get_action_id(cls, action_id, config, allow_invalid=False, reverse=False): + # action_id = cls.add_invalid(action_id, allow_invalid) + return translate_id(action_id, *cls.get_action_id_range(config), allow_invalid=allow_invalid, reverse=reverse) + + @classmethod + def get_map_id(cls, map_id, config, allow_invalid=False, reverse=False): + # action_id = cls.add_invalid(action_id, allow_invalid) + return 
translate_id(map_id, *cls.get_map_id_range(config), allow_invalid=allow_invalid, reverse=reverse) + + @classmethod + def is_action_tokens(cls, tokens, config): + return in_range(tokens, *cls.get_action_id_range(config)) + + @classmethod + def is_agent_tokens(cls, tokens, config): + return in_range(tokens, *cls.get_agent_id_range(config)) + + @classmethod + def get_step_start_tokens(cls): + """ + [UPDATE_START, ] + """ + return Tokens.create([cls.STEP_START], [True], [1], use_numpy=True) + + @classmethod + def get_step_end_tokens(cls): + """ + [UPDATE_START, ] + """ + return Tokens.create([cls.STEP_END], [True], [1], use_numpy=True) + + @classmethod + def get_update_start_tokens(cls): + """ + [UPDATE_START, ] + """ + return Tokens.create([cls.UPDATE_START], [True], [1], use_numpy=True) + + @classmethod + def get_update_pre_tokens(cls, agent_id, valid_mask, config): + """ + Generate ids for: [(agent_id x N)] + """ + return Tokens.create(cls.get_agent_id(agent_id, config), valid_mask, [len(agent_id)] * len(agent_id)) + + @classmethod + def get_update_end_tokens(cls, use_numpy=True, device=None): + """ + Generate ids for: [UPDATE_END, ] + """ + return Tokens.create([cls.UPDATE_END], [True], [1], use_numpy=use_numpy, device=device) + + @classmethod + def get_update_operation(cls, agent_id, action_id, valid_mask, config): + """ + Compared to the GenTokenizer, we remove those invalid agents here to save some tokens. + """ + assert agent_id.shape == action_id.shape == valid_mask.shape + + valid_action_id = action_id[valid_mask] + valid_agent_id = agent_id[valid_mask] + N = len(valid_action_id) + + action_tokens = Tokens.create(cls.get_action_id(valid_action_id, config), [True] * N, [N] * N) + tokens = Tokens.concatenate( + [ + cls.get_update_start_tokens(), + cls.get_update_pre_tokens(valid_agent_id, [True] * N, config), action_tokens, + cls.get_update_end_tokens() + ] + ) + + # Update operation is more "grounded" compared to ADD and REMOVE. 
+ # because you don't need to predict whether you have finished your operation. + # when you see "UPDATE_END", you should determine whether the next is "REMOVE_START" or "STEP_END". + should_predict = [False] + [True] * N + [False] * N + [True] + is_gt = [True] + [False] * N + [True] * N + [False] + return tokens, should_predict, is_gt + + @classmethod + def get_token_names(cls, tokens: Tokens, config: Dict): + def _get_names_for_a_sequence(tokens): + out = [] + if isinstance(tokens, Tokens): + data = tokens.ids + mask = tokens.mask + else: + data = tokens + mask = None + for i in range(len(tokens)): + if mask is not None and bool(mask[i]) is False: + out.append("INVALID") + continue + if data[i] == cls.STEP_START: + out.append("STEP_START") + elif data[i] == cls.STEP_END: + out.append("STEP_END") + elif data[i] == cls.UPDATE_START: + out.append("UPDATE_START") + elif data[i] == cls.UPDATE_END: + out.append("UPDATE_END") + # elif ids.ids[i] == cls.REMOVE_OBJECT: + # out.append("REMOVE_OBJECT") + elif cls.is_agent_tokens(data[i], config): + out.append("AGENT_{}".format(cls.get_agent_id(data[i], config, reverse=True))) + elif cls.is_action_tokens(data[i], config): + out.append("ACTION_{}".format(cls.get_action_id(data[i], config, reverse=True))) + elif data[i] == -1: + out.append("INVALID") + else: + raise ValueError("Invalid token id: {}".format(data[i])) + return out + + if isinstance(tokens, Tokens): + data = tokens.ids + else: + data = tokens + + if data.ndim == 1: + return _get_names_for_a_sequence(tokens) + + elif data.ndim == 2: + if isinstance(tokens, Tokens): + tokens = tokens.unbatch() + else: + tokens = list(tokens) + return [_get_names_for_a_sequence(t) for t in tokens] + + else: + raise ValueError() + + +class SceneStreamerTokenizer(GenTokenizer): + @classmethod + def get_add_start_tokens(cls): + """ + [ADD_START, ] + """ + return Tokens.create([cls.ADD_START], [True], [1], use_numpy=True) + + @classmethod + def get_add_end_tokens(cls): + """ + [ADD_END, 
] + """ + return Tokens.create([cls.ADD_END], [True], [1], use_numpy=True) + + @classmethod + def get_remove_start_tokens(cls): + """ + [REMOVE_START, ] + """ + return Tokens.create([cls.REMOVE_START], [True], [1], use_numpy=True) + + @classmethod + def get_remove_end_tokens(cls): + """ + [REMOVE_END, ] + """ + return Tokens.create([cls.REMOVE_END], [True], [1], use_numpy=True) + + @classmethod + def get_add_operation(cls, agent_id, valid_mask, last_valid_mask, config): + if last_valid_mask is None: + # All valid objects should be added + new_agent_id = agent_id[valid_mask] + num_new_agents = len(new_agent_id) + else: + # Only the newly added objects should be added + new_agent_id = agent_id[valid_mask & ~last_valid_mask] + num_new_agents = len(new_agent_id) + if num_new_agents == 0: + return None, None, None + agent_tokens = Tokens.create( + cls.get_agent_id(new_agent_id, config), [True] * num_new_agents, list(range(1, 1 + num_new_agents)) + ) + # TODO: Each agent token should follow a map token. + # TODO: They should be interleaved. + tokens = Tokens.concatenate([cls.get_add_start_tokens(), agent_tokens, cls.get_add_end_tokens()]) + + # when see ADD_START, you should start making prediction. You should also predict whether ADD_END. + # when you saw ADD_END, you must predict UPDATE_START. 
+ should_predict = [True] + [True] * num_new_agents + [True] + is_gt = [True] + [True] * num_new_agents + [True] + + return tokens, should_predict, is_gt + + @classmethod + def get_remove_operation(cls, agent_id, valid_mask, last_valid_mask, next_valid_mask, config): + if last_valid_mask is None: + return None, None, None + removed_agent_id = agent_id[last_valid_mask & ~valid_mask] + num_removed_agents = len(removed_agent_id) + if num_removed_agents == 0: + return None, None, None + agent_tokens = Tokens.create( + cls.get_agent_id(removed_agent_id, config), [True] * num_removed_agents, + list(range(1, 1 + num_removed_agents)) + ) + tokens = Tokens.concatenate([cls.get_remove_start_tokens(), agent_tokens, cls.get_remove_end_tokens()]) + + # when you see REMOVE_START, you should start making prediction. You should also predict whether REMOVE_END. + should_predict = [True] + [True] * num_removed_agents + [True] + is_gt = [True] + [True] * num_removed_agents + [True] + return tokens, should_predict, is_gt + + @classmethod + def get_token_names(cls, tokens: Tokens, config: Dict): + def _get_names_for_a_sequence(tokens): + out = [] + if isinstance(tokens, Tokens): + data = tokens.ids + mask = tokens.mask + else: + data = tokens + mask = None + for i in range(len(tokens)): + if mask is not None and bool(mask[i]) is False: + out.append("INVALID") + continue + if data[i] == cls.STEP_START: + out.append("STEP_START") + elif data[i] == cls.STEP_END: + out.append("STEP_END") + elif data[i] == cls.UPDATE_START: + out.append("UPDATE_START") + elif data[i] == cls.UPDATE_END: + out.append("UPDATE_END") + elif data[i] == cls.ADD_START: + out.append("ADD_START") + elif data[i] == cls.ADD_END: + out.append("ADD_END") + elif data[i] == cls.REMOVE_START: + out.append("REMOVE_START") + elif data[i] == cls.REMOVE_END: + out.append("REMOVE_END") + elif cls.is_agent_tokens(data[i], config): + out.append("AGENT_{}".format(cls.get_agent_id(data[i], config, reverse=True))) + elif 
cls.is_action_tokens(data[i], config): + out.append("ACTION_{}".format(cls.get_action_id(data[i], config, reverse=True))) + elif data[i] == -1: + out.append("INVALID") + else: + raise ValueError() + return out + + if isinstance(tokens, Tokens): + data = tokens.ids + else: + data = tokens + + if data.ndim == 1: + return _get_names_for_a_sequence(tokens) + + elif data.ndim == 2: + if isinstance(tokens, Tokens): + tokens = tokens.unbatch() + else: + tokens = list(tokens) + return [_get_names_for_a_sequence(t) for t in tokens] + + else: + raise ValueError() diff --git a/scenestreamer/tokenization/motion_stats_FORMAL.json b/scenestreamer/tokenization/motion_stats_FORMAL.json new file mode 100644 index 0000000000000000000000000000000000000000..9c8900106a1eb3111856a0cf9b05efe6a8bf9772 --- /dev/null +++ b/scenestreamer/tokenization/motion_stats_FORMAL.json @@ -0,0 +1,164 @@ +{ + "1": { + "mean": [ + -0.0003308499018940778, + 0.2672094444930485, + 0.00015538629172154734, + -0.0008176972447234687, + 0.5350654866619728, + 0.0002942970587200157, + -0.0014659175129346465, + 0.803500919984849, + 0.00043831233628253605, + -0.0022543672358130377, + 1.0725385406831256, + 0.0005798511557248863, + -0.0032049514248220176, + 1.342167788047052, + 0.0007221756914018174 + ], + "std": [ + 0.003715834401211816, + 0.31477121581981743, + 0.0033817797439947223, + 0.006425054000426722, + 0.6296766588467924, + 0.004352423792047084, + 0.009527855025679127, + 0.9446285455329261, + 0.00541197444055908, + 0.01384426844527193, + 1.2597877930169958, + 0.00652777815004148, + 0.019039948524574007, + 1.575096035523675, + 0.00766339992282911 + ], + "variance": [ + 2860.5747236552397, + 20527242.7269445, + 2369.360978064942, + 8552.52118961064, + 82143988.19092312, + 3924.6681955513727, + 18807.49154980076, + 184868418.17942598, + 6068.0907666401445, + 39708.23818515886, + 328802708.45969415, + 8828.183814404589, + 75105.56792230486, + 513990002.14526296, + 12167.001880504762 + ], + "count": 207176549 + }, 
+ "2": { + "mean": [ + 6.605737401599643e-05, + 0.08890981750321883, + 0.00015065101744678278, + 0.0001637614541285082, + 0.17761178329257915, + 0.00038819900799972256, + 0.00022986280543272364, + 0.26613053592585556, + 0.0007149983459247902, + 0.0002911213860348422, + 0.3546160558429683, + 0.0008559739922100919, + 0.0003676181476921518, + 0.4430406499487639, + 0.0010568536534470813 + ], + "std": [ + 0.007753842123415143, + 0.04253014312279845, + 0.055171130822452986, + 0.008410925149344573, + 0.08440795293039188, + 0.06064542570631366, + 0.009410215854386303, + 0.12657437571524094, + 0.06454712487252261, + 0.011029414922225273, + 0.16875474250661104, + 0.06878773886146691, + 0.012661988088702432, + 0.2108811573370102, + 0.06904798191339484 + ], + "variance": [ + 876.1786122637436, + 26360.426218756573, + 44359.07800767968, + 1030.9705883216645, + 103830.62669577164, + 53598.771736656636, + 1290.4997081454626, + 233480.34573240427, + 60717.31301869645, + 1772.8161102050992, + 415021.61196300405, + 68957.38874154376, + 2336.4825448986726, + 648088.4659439061, + 69480.14551000056 + ], + "count": 14573328 + }, + "3": { + "mean": [ + 0.0008559199544655345, + 0.3795221111132833, + 0.0007421707918483278, + 0.001524082515419561, + 0.759621667019723, + 0.001620632893061715, + 0.0018925932942431786, + 1.140493156599199, + 0.0024769377138832947, + 0.0020451382862025943, + 1.5214457746175274, + 0.0028766065799815804, + 0.0017514382821425072, + 1.9027782396626325, + 0.0038912939244633276 + ], + "std": [ + 0.07886546180662052, + 0.2139357704520633, + 0.07789648124700384, + 0.08109374508633674, + 0.4120566371456612, + 0.08291366519748732, + 0.08550935940022134, + 0.6132930546984772, + 0.08829202360772906, + 0.08754665670801602, + 0.816212101984284, + 0.09140064058626646, + 0.0960183070540082, + 1.0180237476581013, + 0.09218218023538403 + ], + "variance": [ + 9336.091491182693, + 68700.23276726932, + 9108.08505867415, + 9871.11275291698, + 254862.08125077438, + 10319.14285360544, 
+ 10975.358206569534, + 564582.6015327502, + 11701.306063367438, + 11504.573651660288, + 999994.1448150611, + 12539.778827322052, + 13838.833571631929, + 1555633.2443215435, + 12755.143493287247 + ], + "count": 1501037 + } +} \ No newline at end of file diff --git a/scenestreamer/tokenization/motion_tokenizers.py b/scenestreamer/tokenization/motion_tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d9b0d3a6cf2ef1bf2a124519f82679ebeddf20 --- /dev/null +++ b/scenestreamer/tokenization/motion_tokenizers.py @@ -0,0 +1,1871 @@ +""" +This file implements how we translate a trajectory to a sequence of discretized actions +and the reverse process. +""" +import logging +import pathlib + +import numpy as np +import torch +import torch.nn.functional as F + +from scenestreamer.utils import rotate, wrap_to_pi +from scenestreamer.utils import utils + +from scipy.interpolate import CubicSpline +FOLDER = pathlib.Path(__file__).resolve().parent + +logger = logging.getLogger(__file__) + +STEPS_PER_SECOND = 10 + +# TODO: +beta = 1.0 + +# Define a special action +START_ACTION = 1_000_000 +END_ACTION = 7_777_777 + + +def nucleus_sampling(logits, p=None, epsilon=1e-8): + # TODO: duplicate code. 
+ p = p or 0.9 + + # logits = logits.clamp(-20, 20) + + # Replace NaN and Inf values in logits to avoid errors in entropy computation + logits = torch.where(torch.isnan(logits), torch.zeros_like(logits).fill_(-1e9), logits) + logits = torch.where(torch.isinf(logits), torch.zeros_like(logits).fill_(-1e9), logits) + + # Adding a small epsilon to logits to avoid log(0) + # logits = logits + epsilon + + # Convert logits to probabilities + probs = torch.softmax(logits, dim=-1) + + # Sort the probabilities to identify the top-p cutoff + sorted_probs, sorted_indices = torch.sort(probs, descending=True) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + + # Remove tokens with cumulative probability above the threshold p + cutoff_index = cumulative_probs > p + # Shift the mask to the right to keep the first token above the threshold + cutoff_index[..., 1:] = cutoff_index[..., :-1].clone() + cutoff_index[..., 0] = False + + # Zero out the probabilities for tokens not in the top-p set + sorted_probs.masked_fill_(cutoff_index, 0) + + # Recover the original order of the probabilities + original_probs = torch.zeros_like(probs) + original_probs.scatter_(dim=-1, index=sorted_indices, src=sorted_probs) + + # original_probs += epsilon + + # Sample from the adjusted probability distribution + # try: + sampled_token_index = torch.distributions.Categorical(probs=original_probs).sample() + # except ValueError: + # import ipdb; ipdb.set_trace() + # print(1111111) + + return sampled_token_index + + +def get_bin_centers(x_min, x_max, y_min, y_max, x_num_bins, y_num_bins): + # Create linearly spaced bins for x and y + x_bins = np.linspace(x_min, x_max, x_num_bins) + y_bins = np.linspace(y_min, y_max, y_num_bins) + + # Create a meshgrid of x and y bin coordinates + x_grid, y_grid = np.meshgrid(x_bins, y_bins, indexing='ij') + + # Stack the grid coordinates to create the 2D bins + xy_bins = np.stack((x_grid, y_grid), axis=-1).reshape(-1, 2) + assert xy_bins.shape == (x_num_bins * 
y_num_bins, 2) + + return xy_bins.astype(np.float32) + + +def infer_heading(*, current_pos, last_pos, last_heading, min_displacement=-777, flip_heading=False, **kwargs): + assert min_displacement != -777 + + if flip_heading: + current_pos, last_pos = last_pos, current_pos + + # Ensure the input shapes are correct + assert current_pos.shape == last_pos.shape + B, T, N, D = current_pos.shape + last_heading = last_heading.reshape(B, T, N) + # Calculate displacement + displacement = current_pos - last_pos + if isinstance(current_pos, np.ndarray): + heading = np.arctan2(displacement[..., 1], displacement[..., 0]) + else: + heading = torch.arctan2(displacement[..., 1], displacement[..., 0]) + + # if flip_heading: + # # Note that you can't flip heading after masking. Should do it before masking. + # heading = heading + np.pi + + # Apply the previous heading for static or minimally moving objects + if min_displacement is not None: + movement_mask = displacement.norm(dim=-1) >= min_displacement + heading[~movement_mask] = last_heading[~movement_mask] + heading = utils.wrap_to_pi(heading) + + # mask = utils.wrap_to_pi(heading - last_heading).abs() > np.deg2rad(90) + # heading[mask] = last_heading[mask] + + return heading + + +def rotate_bin_to_absolute_heading(bin_center, heading): + B, num_actions, N, _ = bin_center.shape + y_axis_in_relative_coord = heading + x_axis_in_relative_coord = y_axis_in_relative_coord - np.pi / 2 + abs_pos = rotate(bin_center[..., 0], bin_center[..., 1], x_axis_in_relative_coord.expand(B, num_actions, N)) + return abs_pos + + +def _reconstruct_delta_pos_from_abs_vel(vel, heading, dt): + # TODO: WHAT"S WRONG HERE??????????? + # TODO: WHAT"S WRONG HERE??????????? + # TODO: WHAT"S WRONG HERE??????????? + # TODO: WHAT"S WRONG HERE??????????? + vel = utils.rotate(vel[..., 0], vel[..., 1], angle=-heading) + pos = vel * dt + return pos + + +def get_relative_velocity(vel, heading): + # TODO: WHAT"S WRONG HERE??????????? 
+ # TODO: WHAT"S WRONG HERE??????????? + # TODO: WHAT"S WRONG HERE??????????? + # TODO: WHAT"S WRONG HERE??????????? + return utils.rotate(vel[..., 0], vel[..., 1], angle=-heading) + + +def interpolate_reconstructed_valid_mask(input_valid_mask, fine_factor=5): + """ + It's quite tricky, for input mask: 0=T, 5=T, 10=T, 15=F + We need to interpolate it to: 0=T, ..., 10=T, 11=T, ..., 14=T, 15=T, 16=F, ... + This is because our model predicts future 5 steps so it's actually cover the last+1 macro step. + """ + valid = input_valid_mask + # Offset 1 step in the time dimension. + + B, T, N = input_valid_mask.shape + + valid = valid.reshape(B, -1, 1, N).expand(-1, -1, fine_factor, -1).reshape(B, -1, N) + valid = torch.cat([valid, input_valid_mask[:, -1:]], dim=1) + reconstructed_valid_mask = valid + + def find_last_valid(mask): + B, T, N = mask.shape + indices = mask * torch.arange(T, device=mask.device).reshape(1, T, 1).expand(*mask.shape) + indices = indices.argmax(1, keepdims=True) + return indices + + last_valid = find_last_valid(input_valid_mask) + last_valid_plus_one = (last_valid + 1) * fine_factor + last_valid_plus_one = torch.minimum( + last_valid_plus_one, torch.tensor(reconstructed_valid_mask.shape[1] - 1, device=valid.device) + ) + + # Set last_valid to 1 in valid. + reconstructed_valid_mask.scatter_(1, last_valid_plus_one, 1) + reconstructed_valid_mask[~input_valid_mask.any(1, keepdims=True).expand_as(reconstructed_valid_mask)] = 0 + + return reconstructed_valid_mask + + +def interpolate_trajectory_spline(pos, heading, vel, mask, fine_factor=5): + """ + Interpolate pos, heading, vel from coarse steps (e.g., every 5 frames) to finer steps (every frame). + + Args: + pos: (B, T, N, 2) positions at macro steps. + heading: (B, T, N) headings at macro steps. + vel: (B, T, N, 2) velocities at macro steps. + original_times: array/list of shape (T,) with the original macro step times. For example, [0, 5, 10, 15, ...] + fine_factor: How many subdivisions per macro step. 
If original spacing is 5 frames, fine_factor=5 means + you'll get one sample per frame. + + Returns: + fine_times: (T_fine,) new time array with fine steps. + pos_fine: (B, T_fine, N, 2) + heading_fine: (B, T_fine, N) + vel_fine: (B, T_fine, N, 2) + """ + + # Convert to numpy + pos_np = pos.cpu().numpy() # (B, T, N, 2) + heading_np = heading.cpu().numpy() # (B, T, N) + vel_np = vel.cpu().numpy() # (B, T, N, 2) + mask_np = mask.cpu().numpy() # (B, T, N) + + B, T, N, _ = pos.shape + + # The original interval might be something like every 5 frames + # We know original_times: e.g. [0, 5, 10, ...] + # Let's construct new times with finer resolution + # last_time = T * fine_factor + dt_coarse = fine_factor + dt_fine = 1 + last_time = (T - 1) * fine_factor + T_fine = (T - 1) * fine_factor + 1 + fine_times = np.linspace(0, last_time, T_fine) + + original_times = np.linspace(0, last_time, T) + + # Flatten B and N into one dimension: BN = B*N + BN = B * N + + # Reorder axes so T is first for spline fitting: + # pos: (T, B, N, 2) -> (T, BN, 2) + # pos_reshaped = pos_np.transpose(1, 0, 2, 3).reshape(T, BN, 2) + # vel_reshaped = vel_np.transpose(1, 0, 2, 3).reshape(T, BN, 2) + + # heading: (B, T, N) -> (T, B, N) -> (T, BN) + # heading_reshaped = heading_np.transpose(1, 0, 2).reshape(T, BN) + + # Unwrap heading along time axis to avoid discontinuities + # np.unwrap operates along axis=0 (time), this works vectorized for all BN trajectories + # heading_unwrapped = np.unwrap(heading_reshaped, axis=0) + + reconstructed_valid_mask = interpolate_reconstructed_valid_mask(mask[:, :-1]) + reconstructed_valid_mask_np = reconstructed_valid_mask.cpu().numpy() + + pos_fine = np.zeros((B, T_fine, N, 2), dtype=pos_np.dtype) + heading_fine = np.zeros((B, T_fine, N), dtype=heading_np.dtype) + vel_fine = np.zeros((B, T_fine, N, 2), dtype=vel_np.dtype) + + # Loop over batch and agents + for b in range(B): + for n in range(N): + # Extract the series for this agent + pos_series = pos_np[b, :, n, :] 
# Shape: (T, 2) + heading_series = heading_np[b, :, n] # Shape: (T,) + vel_series = vel_np[b, :, n, :] # Shape: (T, 2) + m = reconstructed_valid_mask_np[b, :, n][::5] + + if not m.any(): + # If all positions are masked, skip this agent + continue + + if m.sum() == 1: + ind = m.nonzero()[0].item() + pos_fine[b, ind * fine_factor, n] = pos_series[ind] + heading_fine[b, ind * fine_factor, n] = heading_series[ind] + vel_fine[b, ind * fine_factor, n] = vel_series[ind] + continue + + # Unwrap heading to avoid discontinuities + unwrapped_heading = np.unwrap(heading_series) + + # Interpolate each dimension of position and velocity with a spline + for dim in range(2): + # Position + cs_pos = CubicSpline(original_times[m], pos_series[:, dim][m]) + pos_fine[b, :, n, dim] = cs_pos(fine_times) + + # Velocity + cs_vel = CubicSpline(original_times[m], vel_series[:, dim][m]) + vel_fine[b, :, n, dim] = cs_vel(fine_times) + + # Interpolate heading (after unwrap) + cs_heading = CubicSpline(original_times[m], unwrapped_heading[m], extrapolate=False) + heading_fine_unwrapped = cs_heading(fine_times) + # Wrap back to [-pi, pi] + heading_fine[b, :, n] = utils.wrap_to_pi(heading_fine_unwrapped) + + pos_fine[~reconstructed_valid_mask_np] = 0 + vel_fine[~reconstructed_valid_mask_np] = 0 + heading_fine[~reconstructed_valid_mask_np] = 0 + + # Convert back to PyTorch tensors + pos_fine = torch.from_numpy(pos_fine).to(pos.device) + heading_fine = torch.from_numpy(heading_fine).to(heading.device) + vel_fine = torch.from_numpy(vel_fine).to(vel.device) + # valid_mask = torch.from_numpy(reconstructed_valid_mask).to(mask.device) + + return {"position": pos_fine, "velocity": vel_fine, "heading": heading_fine, "valid_mask": reconstructed_valid_mask} + + +def interpolate(input_tensor, num_skipped_steps, remove_first_step=True): + """ + TODO: This is linear interpolation on position, which might be incorrect as we need to consider heading. 
+ """ + is_4d = False + if input_tensor.ndim == 4: + is_4d = True + _, _, N, D = input_tensor.shape + tensor = input_tensor.flatten(2, 3) + else: + tensor = input_tensor + B, T_before_plus_1, _ = tensor.shape + T_before = T_before_plus_1 - 1 + tensor = tensor.permute(0, 2, 1) # Reshape tensor to put the time dimension last + T_after = num_skipped_steps * T_before + interpolated = F.interpolate(tensor, size=T_after + 1, mode="linear", align_corners=True) + if remove_first_step: + interpolated = interpolated[:, :, 1:] + else: + T_after = T_after + 1 + interpolated = interpolated.permute(0, 2, 1) # Reshape back + if is_4d: + interpolated = interpolated.reshape(B, T_after, N, D) + assert interpolated.shape[:2] == (B, T_after) + # assert interpolated[:, 4::5] == input_tensor[:, 1:] + return interpolated + + +def interpolate_heading(input_tensor, num_skipped_steps, remove_first_step=True): + is_4d = False + if input_tensor.ndim == 4: + is_4d = True + _, _, N, D = input_tensor.shape + tensor = input_tensor.flatten(2, 3) + else: + tensor = input_tensor + B, T_before_plus_1, _ = tensor.shape + T_before = T_before_plus_1 - 1 + tensor = tensor.permute(0, 2, 1) # Reshape tensor to put the time dimension last + T_after = num_skipped_steps * T_before + + # Circular interpolation for headings + headings_cos = torch.cos(tensor) + headings_sin = torch.sin(tensor) + headings_cos_interp = F.interpolate(headings_cos, size=T_after + 1, mode="linear", align_corners=True) + headings_sin_interp = F.interpolate(headings_sin, size=T_after + 1, mode="linear", align_corners=True) + + # Recompose interpolated headings + interpolated = torch.atan2(headings_sin_interp, headings_cos_interp) + + # + # interpolated = F.interpolate(tensor, size=T_after + 1, mode="linear", align_corners=True) + if remove_first_step: + interpolated = interpolated[:, :, 1:] + else: + T_after = T_after + 1 + interpolated = interpolated.permute(0, 2, 1) # Reshape back + if is_4d: + interpolated = interpolated.reshape(B, 
T_after, N, D) + assert interpolated.shape[:2] == (B, T_after) + # assert interpolated[:, 4::5] == input_tensor[:, 1:] + + interpolated = utils.wrap_to_pi(interpolated) + return interpolated + + +class BaseTokenizer: + get_relative_velocity = get_relative_velocity + + def __init__(self, config): + self.num_skipped_steps = config.TOKENIZATION.NUM_SKIPPED_STEPS + self.predict_all_agents = config.TRAINING.PREDICT_ALL_AGENTS + self.dt = (1 / STEPS_PER_SECOND) * self.num_skipped_steps + + def detokenize_numpy_array( + self, data_dict, interpolation=True, detokenizing_gt=False, backward_prediction=False, **kwargs + ): + with torch.no_grad(): + new_data_dict = self.detokenize( + self._numpy_to_tensor(data_dict), + interpolation=interpolation, + detokenizing_gt=detokenizing_gt, + backward_prediction=backward_prediction, + **kwargs + ) + data_dict = self._tensor_to_numpy(new_data_dict) + return data_dict + + def _numpy_to_tensor(self, data_dict): + # Translate to tensors + new_data_dict = {"in_evaluation": torch.from_numpy(np.array([data_dict["in_evaluation"]]))} + + for k, v in data_dict.items(): + if k.startswith("decoder/") and isinstance(v, np.ndarray): + if np.issubdtype(v.dtype, np.number) or v.dtype == bool: + new_data_dict[k] = torch.from_numpy(v).unsqueeze(dim=0) + else: + pass + + # TODO: The device is default to CPU for now. Might be set from config. 
def _tensor_to_numpy(self, data_dict):
    """Strip the leading batch axis from every entry of ``data_dict``, in place.

    Tensors are moved to CPU and converted to NumPy; plain ``bool`` values are
    kept as they are. Every tensor must have a leading axis of length 1.

    Args:
        data_dict: Mapping whose values are ``torch.Tensor`` or ``bool``. It
            must contain the key ``"in_evaluation"``.

    Returns:
        The same dict object, with ``"in_evaluation"`` reduced to a Python bool.

    Raises:
        ValueError: If a value is neither a tensor nor a bool.
    """
    for key, value in data_dict.items():
        if isinstance(value, torch.Tensor):
            batched = value.detach().cpu().numpy()
        elif isinstance(value, bool):
            batched = [value]
        else:
            raise ValueError("Unknown type: {}".format(type(value)))
        assert len(batched) == 1, "Expected a batch of exactly one sample."
        # Re-binding an existing key does not change the dict size, so this is
        # safe while iterating.
        data_dict[key] = batched[0]
    # np.all works for NumPy scalars/arrays and for a plain Python bool, so an
    # ``in_evaluation`` flag passed as ``True``/``False`` no longer crashes.
    data_dict["in_evaluation"] = bool(np.all(data_dict["in_evaluation"]))
    return data_dict


def tokenize_numpy_array(self, data_dict, **kwargs):
    """Tokenize a single NumPy sample and return NumPy results.

    The input goes through ``self._numpy_to_tensor``, then ``self.tokenize``
    (under ``torch.no_grad()``), then ``self._tensor_to_numpy``.

    Returns:
        ``(data_dict, stat)`` where ``stat`` is the second value returned by
        ``self.tokenize``.
    """
    with torch.no_grad():
        tensor_dict = self._numpy_to_tensor(data_dict)
        tokenized, stat = self.tokenize(tensor_dict, **kwargs)
        data_dict = self._tensor_to_numpy(tokenized)
    return data_dict, stat


def tokenize(self, data_dict, **kwargs):
    """Turn trajectories into discrete actions. Must be implemented by subclasses."""
    raise NotImplementedError


def detokenize(self, data_dict, interpolation=True, detokenizing_gt=False, backward_prediction=False, **kwargs):
    """Turn discrete actions back into trajectories. Must be implemented by subclasses."""
    raise NotImplementedError


def detokenize_step(self, *args, **kwargs):
    """Public single-step entry point; forwards everything to ``self._detokenize_a_step``."""
    return self._detokenize_a_step(*args, **kwargs)


def get_motion_feature(self):
    """Build a per-bin feature ``[bin center..., L2 norm, atan2(y, x)]``.

    Reads ``self.bin_centers_flat`` (a NumPy array whose last axis holds the
    x/y components) and appends two extra channels along the last axis.
    """
    centers = torch.from_numpy(self.bin_centers_flat)
    magnitude = centers.norm(p=2, dim=-1).unsqueeze(-1)
    direction = torch.atan2(centers[..., 1], centers[..., 0]).unsqueeze(-1)
    return torch.cat([centers, magnitude, direction], dim=-1)


def get_bin_centers(self, agent_type):
    """Broadcast the bin-center table to every agent in the batch.

    Args:
        agent_type: Integer tensor of shape ``(B, N)``. With type-specific bins,
            ``agent_type - 1`` selects the table row (Veh: 0, Ped: 1, Cyc: 2) and
            negative results are clamped to row 0. The tensor is not modified.

    Returns:
        A ``(B, num_actions, N, 2)`` tensor, or ``None`` when
        ``self.bin_centers`` is ``None``.
    """
    B, N = agent_type.shape
    if self.bin_centers is None:
        return None
    table = self.bin_centers.to(agent_type.device)
    if not self.use_type_specific_bins:
        return table.expand(B, self.num_actions, N, 2)
    table = table.expand(B, self.num_actions, N, 3, 2)
    # ``agent_type - 1`` allocates a new tensor, so the caller's data is untouched.
    # torch.gather needs an int64 index, hence the explicit cast.
    row = (agent_type - 1).clamp(min=0).long()
    row = row.reshape(B, 1, N, 1, 1).expand(B, table.shape[1], N, 1, 2)
    return torch.gather(table, dim=-2, index=row).squeeze(-2)


def hole_filling(self, data_dict):
    """Fill single-step validity holes (valid, invalid, valid) by interpolation.

    Only every ``self.num_skipped_steps``-th time step is inspected. When an
    agent is invalid at a sampled step but valid at both neighbouring sampled
    steps, its position and velocity become the mean of the neighbours, its
    heading becomes their circular mean, and it is marked valid.

    NOTE: No clone is taken. The strided slices are views, so the tensors
    stored in ``data_dict`` are modified in place.

    Args:
        data_dict: Must hold ``decoder/agent_position`` (4-D, time on axis 1),
            ``decoder/agent_heading``, ``decoder/agent_velocity`` and
            ``decoder/agent_valid_mask`` (boolean).

    Returns:
        The same ``data_dict`` object.

    Raises:
        ValueError: If the position tensor is not 4-D.
    """
    stride = self.num_skipped_steps
    full_pos = data_dict["decoder/agent_position"]
    if full_pos.ndim != 4:
        raise ValueError("decoder/agent_position must be 4-D, got {} dims".format(full_pos.ndim))

    # Strided views: every write below lands directly in the data_dict tensors,
    # so no explicit write-back is needed afterwards.
    agent_pos = full_pos[:, ::stride]
    agent_heading = data_dict["decoder/agent_heading"][:, ::stride]
    agent_valid_mask = data_dict["decoder/agent_valid_mask"][:, ::stride]
    agent_velocity = data_dict["decoder/agent_velocity"][:, ::stride]
    T_chunks = agent_pos.shape[1]

    # The first and last sampled steps have only one neighbour and are skipped.
    for i in range(1, T_chunks - 1):
        before = slice(i - 1, i)
        here = slice(i, i + 1)
        after = slice(i + 1, i + 2)

        is_rare_case = agent_valid_mask[:, after] & agent_valid_mask[:, before] & ~agent_valid_mask[:, here]
        if not is_rare_case.any():
            continue

        int_pos = (agent_pos[:, before] + agent_pos[:, after]) / 2
        int_vel = (agent_velocity[:, before] + agent_velocity[:, after]) / 2

        # Circular interpolation for headings, so angles near +/-pi average correctly.
        head_s = agent_heading[:, before]
        head_e = agent_heading[:, after]
        int_heading = torch.atan2(
            torch.sin(head_s) + torch.sin(head_e),
            torch.cos(head_s) + torch.cos(head_e),
        )

        # A trailing singleton axis lets torch.where broadcast over the
        # coordinate axis, whatever its size.
        per_coord = is_rare_case.unsqueeze(-1)
        agent_pos[:, here] = torch.where(per_coord, int_pos, agent_pos[:, here])
        agent_heading[:, here] = torch.where(is_rare_case, int_heading, agent_heading[:, here])
        agent_velocity[:, here] = torch.where(per_coord, int_vel, agent_velocity[:, here])
        agent_valid_mask[:, here] = torch.logical_or(agent_valid_mask[:, here], is_rare_case)

    return data_dict
+ # x_max = config.TOKENIZATION.X_MAX / STEPS_PER_SECOND * self.num_skipped_steps + # x_min = config.TOKENIZATION.X_MIN / STEPS_PER_SECOND * self.num_skipped_steps + # y_max = config.TOKENIZATION.Y_MAX / STEPS_PER_SECOND * self.num_skipped_steps + # y_min = config.TOKENIZATION.Y_MIN / STEPS_PER_SECOND * self.num_skipped_steps + assert config.TOKENIZATION.X_MAX == 3.5, "X_MAX is deprecated!" + + x_num_bins = y_num_bins = config.TOKENIZATION.NUM_BINS + self.num_bins = config.TOKENIZATION.NUM_BINS + self.num_actions = x_num_bins * y_num_bins + self.config = config + + x_limit_veh = config.TOKENIZATION.VEH_LIMIT / STEPS_PER_SECOND * self.num_skipped_steps + x_limit_ped = config.TOKENIZATION.PED_LIMIT / STEPS_PER_SECOND * self.num_skipped_steps + x_limit_cyc = config.TOKENIZATION.CYC_LIMIT / STEPS_PER_SECOND * self.num_skipped_steps + # assert x_num_bins == y_num_bins == 33 + + # Precompute the bin positions. In the future, we can load them from dataset. + bin_veh = get_bin_centers( + x_min=-x_limit_veh, + x_max=x_limit_veh, + y_min=-x_limit_veh, + y_max=x_limit_veh, + x_num_bins=x_num_bins, + y_num_bins=y_num_bins + ) + if x_limit_veh == x_limit_ped and x_limit_veh == x_limit_cyc: + self.bin_centers_flat = bin_veh + # Assert if (dx=0, dy=0) are in the bin centers. 
def get_bin_centers(self, agent_type):
    """Select, for every agent, the bin-center table matching its type.

    Args:
        agent_type: Integer tensor of shape ``(B, N)``. With type-specific bins,
            ``agent_type - 1`` is the table row (Veh: 0, Ped: 1, Cyc: 2);
            negative results are clamped to row 0. The tensor is never
            modified, so no defensive clone is needed.

    Returns:
        A ``(B, num_actions, N, 2)`` tensor on ``agent_type``'s device, or
        ``None`` when ``self.bin_centers`` is ``None``.
    """
    B, N = agent_type.shape
    if self.bin_centers is None:
        return None

    table = self.bin_centers.to(agent_type.device)
    if not self.use_type_specific_bins:
        # One shared table: broadcast it over batch and agents.
        return table.expand(B, self.num_actions, N, 2)

    table = table.expand(B, self.num_actions, N, 3, 2)
    # ``agent_type - 1`` already allocates a fresh tensor. torch.gather requires
    # an int64 index, so cast explicitly to accept other integer dtypes.
    row = (agent_type - 1).clamp(min=0).long()
    row = row.reshape(B, 1, N, 1, 1).expand(B, table.shape[1], N, 1, 2)
    # Pick one of the 3 per-type rows for each (action, agent), then drop that axis.
    return torch.gather(table, dim=-2, index=row).squeeze(-2)
===== Build up some variables ===== + current_pos = agent_pos[:, start_step:start_step + 1, ..., :2] + current_heading = agent_heading[:, start_step:start_step + 1] + current_vel = agent_velocity[:, start_step:start_step + 1, ..., :2] + current_valid_mask = agent_valid_mask[:, start_step:start_step + 1] + + init_pos = current_pos.clone() + init_heading = current_heading.clone() + init_vel = current_vel.clone() + init_valid_mask = current_valid_mask.clone() + init_delta = _reconstruct_delta_pos_from_abs_vel(current_vel, current_heading, dt=self.dt) + + # Select correct bins: + bin_centers = self.get_bin_centers(agent_type) + + target_action = [] + target_action_valid_mask = [] + reconstruction_list = [] + relative_delta_pos_list = [] + pos = [] + heading = [] + vel = [] + + # ===== Loop to reconstruct the scenario ===== + tokenization_state = None + for next_step in range(start_step + 1, T_chunks): + res = self._tokenize_a_step( + current_pos=current_pos, + current_heading=current_heading, + current_vel=current_vel, + current_valid_mask=current_valid_mask, + next_pos=agent_pos[:, next_step:next_step + 1, ..., :2], # (B, 1, N, 2) + next_heading=agent_heading[:, next_step:next_step + 1], # (B, 1, N) + next_valid_mask=agent_valid_mask[:, next_step:next_step + 1], # (B, 1, N) + next_velocity=agent_velocity[:, next_step:next_step + 1, ..., :2], # (B, 1, N, 2) + bin_centers=bin_centers, + add_noise=self.add_noise, + topk=self.config.TOKENIZATION.NOISE_TOPK, + agent_shape=agent_shape, + agent_type=agent_type, + dt=self.dt, + tokenization_state=tokenization_state, + agent_pos_full=agent_pos_full[:, (next_step - 1) * + self.num_skipped_steps:next_step * self.num_skipped_steps + 1], + agent_heading_full=agent_heading_full[:, (next_step - 1) * + self.num_skipped_steps:next_step * self.num_skipped_steps + 1], + agent_velocity_full=agent_velocity_full[:, (next_step - 1) * + self.num_skipped_steps:next_step * self.num_skipped_steps + 1], + 
agent_valid_mask_full=agent_valid_mask_full[:, (next_step - 1) * + self.num_skipped_steps:next_step * self.num_skipped_steps + + 1], + ) + tokenization_state = res + + best_action = res["action"] + recon_next_pos = res["pos"] + recon_next_heading = res["heading"] + recon_next_vel = res["vel"] + recon_next_valid_mask = res["mask"] + recon_next_delta_pos = res["delta_pos"] # The input delta for next step. + + best_action = best_action.reshape(B, 1, N) + + # ===== Process the target action/valid mask ===== + target_action_valid_mask.append(recon_next_valid_mask.clone()) + target_action.append(best_action) + + # Some debug asserts + assert (best_action[recon_next_valid_mask] >= 0).all() + assert (best_action[~recon_next_valid_mask] == -1).all() + + # ===== Process the "current_xxx" for next step ===== + if self.config.GPT_STYLE: + assert self.config.TOKENIZATION.ALLOW_SKIP_STEP + if self.config.TOKENIZATION.ALLOW_SKIP_STEP: + # Use the next valid mask as the valid mask for next step. + # In contrast, if this flag is False, then we will use "next valid mask & if it's not removed" for next + # step. 
+ next_valid_mask = agent_valid_mask[:, next_step:next_step + 1] + newly_added = torch.logical_and(~recon_next_valid_mask, next_valid_mask) + if newly_added.any(): + recon_next_pos[newly_added] = agent_pos[:, next_step:next_step + 1, ..., :2][newly_added] + recon_next_heading[newly_added] = agent_heading[:, next_step:next_step + 1][newly_added] + recon_next_vel[newly_added] = agent_velocity[:, next_step:next_step + 1, ..., :2][newly_added] + recon_next_delta_pos[newly_added] = _reconstruct_delta_pos_from_abs_vel( + vel=agent_velocity[:, next_step:next_step + 1, ..., :2][newly_added], + heading=agent_heading[:, next_step:next_step + 1][newly_added], + dt=self.dt + ) + recon_next_valid_mask[newly_added] = next_valid_mask[newly_added] + + relative_delta_pos_list.append(recon_next_delta_pos) + current_vel = recon_next_vel + current_heading = recon_next_heading + current_pos = recon_next_pos + current_valid_mask = recon_next_valid_mask + pos.append(current_pos.clone()) + heading.append(current_heading.clone()) + vel.append(current_vel.clone()) + + # ===== Postprocess and prepare the "start action" ===== + # In GPT style, some agents will be added in the middle of the scene. + # So we need to find out when they are in and add a start action before that step. + # In non-GPT style, we only need to prepare the start action for the first step. + target_actions = torch.cat(target_action, dim=1) # (B, T_skipped, N) + target_action_valid_mask = torch.cat(target_action_valid_mask, dim=1) # (B, T_skipped, N) + relative_delta_pos_list = torch.cat(relative_delta_pos_list, dim=1) # (B, T_skipped, N) + pos = torch.cat(pos, dim=1) + heading = torch.cat(heading, dim=1) + vel = torch.cat(vel, dim=1) + + pos = torch.cat([init_pos, pos], dim=1) + heading = torch.cat([init_heading, heading], dim=1) + vel = torch.cat([init_vel, vel], dim=1) + relative_delta_pos_list = torch.cat([init_delta, relative_delta_pos_list], dim=1) + + # If not in back prediction, what will be: + # 1. 
The first tokens in input_actions? START_ACTION + # 2. The last tokens in input_actions? Just the tokens at t=18 (t=85~90) + # 3. The first tokens in target_actions? The tokens at t=0 (t=0~5) for GPT and t=2 otherwise. + # 4. The last tokens in target_actions? All -1 because there is no GT for t=19 (t=90~95) + if self.config.GPT_STYLE: + # Search for the first step that has newly added agents + assert start_step == 0 + already_tokenized = init_valid_mask.clone() + start_action = torch.full_like(target_actions[:, :1], -1) + start_action[init_valid_mask] = START_ACTION + assert target_actions.shape[1] == T_chunks - 1 + input_action = torch.cat([start_action, target_actions], dim=1) + input_action_valid_mask = torch.cat([init_valid_mask, target_action_valid_mask], dim=1) + for next_step in range(start_step + 1, T_chunks): + next_valid_mask = agent_valid_mask[:, next_step:next_step + 1] + is_newly_added = torch.logical_and(~already_tokenized, next_valid_mask) + if is_newly_added.any(): + input_action[:, next_step:next_step + 1][is_newly_added] = START_ACTION + input_action_valid_mask[:, next_step:next_step + 1][is_newly_added] = \ + next_valid_mask[is_newly_added] + already_tokenized = torch.logical_or(already_tokenized, is_newly_added) + + else: + start_action = torch.full_like(target_actions[:, :1], -1) + start_action[init_valid_mask] = START_ACTION + input_action = torch.cat([start_action, target_actions], dim=1) + input_action_valid_mask = torch.cat([init_valid_mask.reshape(B, 1, N), target_action_valid_mask], dim=1) + + target_actions = torch.cat([target_actions, target_actions.new_full((B, 1, N), -1)], dim=1) + target_action_valid_mask = torch.cat( + [target_action_valid_mask, target_action_valid_mask.new_zeros((B, 1, N))], dim=1 + ) + data_dict["in_backward_prediction"] = False + assert (agent_valid_mask[:, start_step:] >= target_action_valid_mask).all() + assert (agent_valid_mask[:, start_step + 1:] >= target_action_valid_mask[:, :-1]).all() + assert 
def _tokenize_backward_prediction(self, data_dict, **kwargs):
    """Tokenize the scenario in reverse time order (last sampled step first).

    The trajectory is sub-sampled every ``self.num_skipped_steps`` steps and
    walked from the final sampled step back to the first one. Each transition
    is discretized by ``self._tokenize_a_step`` with a negative ``dt``. Agents
    that only become valid further back in time are injected with their
    ground-truth state and receive ``END_ACTION`` as their first input token.

    Args:
        data_dict: Scenario tensors under ``decoder/agent_position``,
            ``decoder/agent_heading``, ``decoder/agent_velocity``,
            ``decoder/agent_valid_mask``, ``decoder/current_agent_shape`` and
            ``decoder/agent_type``. It is passed through ``self.hole_filling``
            first.
        **kwargs: Forwarded unchanged to ``self._tokenize_a_step``.

    Returns:
        ``(data_dict, {"reconstruction_list": []})``. ``data_dict`` gains
        ``decoder/target_action``, ``decoder/input_action``, their valid masks,
        the ``decoder/modeled_agent_*`` tensors, and
        ``in_backward_prediction = True``.

    Raises:
        ValueError: If ``self.config.GPT_STYLE`` is disabled.
    """
    if not self.config.GPT_STYLE:
        raise ValueError("Backward-prediction tokenization requires config.GPT_STYLE to be enabled.")
    start_step = 0

    # ===== Hole Filling =====
    data_dict = self.hole_filling(data_dict)

    # ===== Get initial data (no clones are taken here) =====
    agent_pos = data_dict["decoder/agent_position"]
    agent_heading = data_dict["decoder/agent_heading"]
    agent_valid_mask = data_dict["decoder/agent_valid_mask"]
    agent_velocity = data_dict["decoder/agent_velocity"]
    agent_shape = data_dict["decoder/current_agent_shape"]
    agent_type = data_dict["decoder/agent_type"]
    B, _, N, _ = agent_pos.shape
    assert agent_pos.ndim == 4

    # ===== Skip some steps =====
    agent_pos = agent_pos[:, ::self.num_skipped_steps]
    agent_heading = agent_heading[:, ::self.num_skipped_steps]
    agent_valid_mask = agent_valid_mask[:, ::self.num_skipped_steps]
    agent_velocity = agent_velocity[:, ::self.num_skipped_steps]
    T_chunks = agent_pos.shape[1]

    # ===== Start from the LAST sampled step =====
    current_pos = agent_pos[:, -1:, ..., :2]
    current_heading = agent_heading[:, -1:]
    current_vel = agent_velocity[:, -1:, ..., :2]
    current_valid_mask = agent_valid_mask[:, -1:]

    init_pos = current_pos.clone()
    init_heading = current_heading.clone()
    init_vel = current_vel.clone()
    init_valid_mask = current_valid_mask.clone()

    # NOTE: +180deg here, the heading is reversed for the backward direction.
    init_delta = _reconstruct_delta_pos_from_abs_vel(current_vel, current_heading + np.pi, dt=self.dt)

    # Select correct bins:
    bin_centers = self.get_bin_centers(agent_type)

    target_action = []
    target_action_valid_mask = []
    reconstruction_list = []  # Never filled here; returned for interface parity.
    relative_delta_pos_list = []
    pos = []
    heading = []
    vel = []

    # ===== Loop to reconstruct the scenario, walking backward in time =====
    for backward_next_step in range(1, T_chunks):
        # backward_next_step counts up from 1 while forward_next_step counts
        # down from T_chunks - 2 to 0.
        forward_next_step = T_chunks - backward_next_step - 1
        step_slice = slice(forward_next_step, forward_next_step + 1)

        res = self._tokenize_a_step(
            current_pos=current_pos,
            current_heading=current_heading,
            current_vel=current_vel,
            current_valid_mask=current_valid_mask,
            next_pos=agent_pos[:, step_slice, ..., :2],  # (B, 1, N, 2)
            next_heading=agent_heading[:, step_slice],  # (B, 1, N)
            next_valid_mask=agent_valid_mask[:, step_slice],  # (B, 1, N)
            next_velocity=agent_velocity[:, step_slice, ..., :2],  # (B, 1, N, 2)
            bin_centers=bin_centers,
            add_noise=self.add_noise,
            topk=self.config.TOKENIZATION.NOISE_TOPK,
            agent_shape=agent_shape,
            agent_type=agent_type,
            dt=-self.dt,
            **kwargs
        )

        best_action = res["action"].reshape(B, 1, N)
        recon_next_pos = res["pos"]
        recon_next_heading = res["heading"]
        recon_next_vel = res["vel"]
        recon_next_valid_mask = res["mask"]
        recon_next_delta_pos = res["delta_pos"]  # The input delta for next step.

        # ===== Process the target action/valid mask =====
        # The mask is cloned BEFORE newly added agents are merged in below, so
        # the target mask only covers agents that actually produced an action.
        target_action_valid_mask.append(recon_next_valid_mask.clone())
        target_action.append(best_action)

        # Some debug asserts
        assert (best_action[recon_next_valid_mask] >= 0).all()
        assert (best_action[~recon_next_valid_mask] == -1).all()

        # ===== Process the "current_xxx" for next step =====
        # Agents that are valid in the data at this step but were not
        # reconstructed are seeded with their ground-truth state.
        next_valid_mask = agent_valid_mask[:, step_slice]
        newly_added = torch.logical_and(~recon_next_valid_mask, next_valid_mask)
        if newly_added.any():
            gt_vel = agent_velocity[:, step_slice, ..., :2][newly_added]
            gt_heading = agent_heading[:, step_slice][newly_added]
            recon_next_pos[newly_added] = agent_pos[:, step_slice, ..., :2][newly_added]
            recon_next_heading[newly_added] = gt_heading
            recon_next_vel[newly_added] = gt_vel
            # Same convention as ``init_delta``: heading + pi with a positive dt.
            recon_next_delta_pos[newly_added] = _reconstruct_delta_pos_from_abs_vel(
                vel=gt_vel,
                heading=gt_heading + np.pi,
                dt=self.dt
            )
            recon_next_valid_mask[newly_added] = next_valid_mask[newly_added]

        relative_delta_pos_list.append(recon_next_delta_pos)
        current_vel = recon_next_vel
        current_heading = recon_next_heading
        current_pos = recon_next_pos
        current_valid_mask = recon_next_valid_mask
        pos.append(current_pos.clone())
        heading.append(current_heading.clone())
        vel.append(current_vel.clone())

    # ===== Postprocess and prepare the "start action" =====
    # Agents may enter in the middle of the (reversed) scene, so we find the
    # step where each one first appears and put an END_ACTION there.
    target_actions = torch.cat(target_action, dim=1)  # (B, T_skipped, N)
    target_action_valid_mask = torch.cat(target_action_valid_mask, dim=1)  # (B, T_skipped, N)
    relative_delta_pos_list = torch.cat(relative_delta_pos_list, dim=1)
    pos = torch.cat([init_pos] + pos, dim=1)
    heading = torch.cat([init_heading] + heading, dim=1)
    vel = torch.cat([init_vel] + vel, dim=1)
    relative_delta_pos_list = torch.cat([init_delta, relative_delta_pos_list], dim=1)

    # Agents valid at the last sampled step start with END_ACTION.
    already_tokenized = init_valid_mask.clone()
    start_action = torch.full_like(target_actions[:, :1], -1)
    start_action[init_valid_mask] = END_ACTION
    assert target_actions.shape[1] == T_chunks - 1
    input_action = torch.cat([start_action, target_actions], dim=1)
    input_action_valid_mask = torch.cat([init_valid_mask, target_action_valid_mask], dim=1)
    for backward_next_step in range(1, T_chunks):
        forward_next_step = T_chunks - backward_next_step - 1
        token_slice = slice(backward_next_step, backward_next_step + 1)
        next_valid_mask = agent_valid_mask[:, forward_next_step:forward_next_step + 1]
        is_newly_added = torch.logical_and(~already_tokenized, next_valid_mask)
        if is_newly_added.any():
            input_action[:, token_slice][is_newly_added] = END_ACTION
            input_action_valid_mask[:, token_slice][is_newly_added] = next_valid_mask[is_newly_added]
        already_tokenized = torch.logical_or(already_tokenized, is_newly_added)

    # Pad the targets with one all-invalid step so inputs and targets align.
    target_actions = torch.cat([target_actions, target_actions.new_full((B, 1, N), -1)], dim=1)
    target_action_valid_mask = torch.cat(
        [target_action_valid_mask, target_action_valid_mask.new_zeros((B, 1, N))], dim=1
    )
    data_dict["in_backward_prediction"] = True

    # Token validity may never exceed the (time-flipped) data validity.
    flipped_agent_valid_mask = agent_valid_mask.flip(dims=[1])
    assert (flipped_agent_valid_mask[:, start_step:] >= target_action_valid_mask).all()
    assert (flipped_agent_valid_mask[:, start_step + 1:] >= target_action_valid_mask[:, :-1]).all()
    assert (flipped_agent_valid_mask[:, start_step:] >= input_action_valid_mask).all()

    data_dict["decoder/target_action"] = target_actions
    data_dict["decoder/target_action_valid_mask"] = target_action_valid_mask
    data_dict["decoder/input_action"] = input_action
    data_dict["decoder/input_action_valid_mask"] = input_action_valid_mask
    data_dict["decoder/modeled_agent_delta"] = relative_delta_pos_list
    data_dict["decoder/modeled_agent_position"] = pos
    data_dict["decoder/modeled_agent_heading"] = heading
    data_dict["decoder/modeled_agent_velocity"] = vel

    # Valid entries must hold a real action (>= 0); invalid entries must be -1.
    assert (input_action[input_action_valid_mask] >= 0).all()
    assert (target_actions[target_action_valid_mask] >= 0).all()
    assert (input_action[~input_action_valid_mask] == -1).all()
    assert (target_actions[~target_action_valid_mask] == -1).all()

    return data_dict, {"reconstruction_list": reconstruction_list}
+ if self.config.GPT_STYLE: + start_step = 0 + else: + start_step = 2 + + # ===== Get initial data ===== + agent_pos = data_dict["decoder/agent_position"].clone() + agent_heading = data_dict["decoder/agent_heading"].clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"].clone() + agent_velocity = data_dict["decoder/agent_velocity"].clone() + agent_shape = data_dict["decoder/current_agent_shape"].clone() + agent_type = data_dict["decoder/agent_type"].clone() + if detokenizing_gt: + target_action_valid_mask = data_dict["decoder/target_action_valid_mask"] + input_mask = data_dict["decoder/input_action_valid_mask"] + B, T_full, N, _ = agent_pos.shape + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + agent_pos = agent_pos[:, ::self.num_skipped_steps].clone() + agent_heading = agent_heading[:, ::self.num_skipped_steps] + agent_valid_mask = agent_valid_mask[:, ::self.num_skipped_steps] + agent_velocity = agent_velocity[:, ::self.num_skipped_steps] + # T_chunks = agent_pos.shape[1] + + # ===== Prepare some variables ===== + action = data_dict["decoder/output_action"] + T_actions = action.shape[1] + # if T_actions + start_step != T_chunks: + # print( + # "WARNING: The number of actions is not consistent with the number of raw data chunks! You have {} actions, start step is {} and the number of chunks is {}." 
+ # .format(T_actions, start_step, T_chunks) + # ) + T_generated_chunks = T_actions + start_step + + current_pos = agent_pos[:, start_step:start_step + 1, ..., :2].clone() + current_heading = agent_heading[:, start_step:start_step + 1].clone() + current_vel = agent_velocity[:, start_step:start_step + 1, ..., :2].clone() + current_valid_mask = agent_valid_mask[:, start_step:start_step + 1].clone() + + if detokenizing_gt: + # Merge input mask with target mask + input_mask = input_mask & target_action_valid_mask + + reconstructed_pos_list = [current_pos.clone()] + reconstructed_heading_list = [current_heading.clone()] + reconstructed_vel_list = [current_vel.clone()] + + already_interpolated = False + reconstructed_pos_full_list = [current_pos.clone()] + reconstructed_heading_full_list = [current_heading.clone()] + reconstructed_vel_full_list = [current_vel.clone()] + + # Select correct bins: + bin_centers = self.get_bin_centers(agent_type) + + kwargs["detokenization_state"] = None + + for curr_step in range(T_generated_chunks): + if curr_step < start_step: + next_pos = agent_pos[:, curr_step + 1:curr_step + 2, ..., :2] + next_heading = agent_heading[:, curr_step + 1:curr_step + 2] + next_vel = agent_velocity[:, curr_step + 1:curr_step + 2, ..., :2] + next_valid_mask = agent_valid_mask[:, curr_step + 1:curr_step + 2] + + else: + # We assume that starting from start_step, the agent valid mask will not change. 
+ action_step = curr_step - start_step + action_valid_mask_step = input_mask[:, action_step:action_step + 1] + + act = action[:, action_step:action_step + 1] + assert (act[action_valid_mask_step] != -1).all() + res = self._detokenize_a_step( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=action_valid_mask_step, + current_vel=current_vel, + action=act, + agent_shape=agent_shape, + agent_type=agent_type, + bin_centers=bin_centers, + dt=self.dt, + flip_wrong_heading=flip_wrong_heading, + **kwargs + ) + kwargs["detokenization_state"] = res + + next_pos, next_heading, next_vel = res["pos"], res["heading"], res["vel"] + assert "delta_pos" in res + next_pos = next_pos.reshape(B, 1, N, 2) + next_heading = next_heading.reshape(B, 1, N) + next_vel = next_vel.reshape(B, 1, N, 2) + next_valid_mask = current_valid_mask + + # ===== A special case: fill in the info for the agents added in next step ===== + # ===== Another special case: if you are detokenizing the raw tokenized data, you need to fill in + # the info for the agents added in the next step. ===== + if (curr_step < autoregressive_start_step) or (detokenizing_gt and curr_step < T_generated_chunks - 1): + # Fill in the initial states of newly added agents + action_valid_mask_next_step = input_mask[:, action_step + 1:action_step + 2] + newly_added = torch.logical_and(~action_valid_mask_step, action_valid_mask_next_step) + next_pos[newly_added] = agent_pos[:, curr_step + 1:curr_step + 2, ..., :2][newly_added] + next_heading[newly_added] = agent_heading[:, curr_step + 1:curr_step + 2][newly_added] + next_vel[newly_added] = agent_velocity[:, curr_step + 1:curr_step + 2, ..., :2][newly_added] + next_valid_mask[newly_added] = action_valid_mask_next_step[newly_added] + if "reconstructed_position" in res: + # If some agents are added in the next step, the "last step" in reconstructed chunk + # aka the 5-th step in the chunk should be replaced by the GT states. 
+ res["reconstructed_position"][-1][newly_added] = agent_pos[:, curr_step + 1:curr_step + 2, + ..., :2][newly_added] + res["reconstructed_heading"][-1][newly_added] = agent_heading[:, curr_step + 1:curr_step + + 2][newly_added] + res["reconstructed_velocity"][-1][newly_added] = agent_velocity[:, curr_step + 1:curr_step + 2, + ..., :2][newly_added] + + if "reconstructed_position" in res: + already_interpolated = True + reconstructed_pos_full_list.extend(res["reconstructed_position"]) + reconstructed_heading_full_list.extend(res["reconstructed_heading"]) + reconstructed_vel_full_list.extend(res["reconstructed_velocity"]) + + current_pos = next_pos + current_heading = next_heading + current_vel = next_vel + current_valid_mask = next_valid_mask + + reconstructed_pos_list.append(current_pos.clone()) + reconstructed_heading_list.append(current_heading.clone()) + reconstructed_vel_list.append(current_vel.clone()) + + reconstructed_pos = torch.cat(reconstructed_pos_list, dim=1) + reconstructed_heading = torch.cat(reconstructed_heading_list, dim=1) + reconstructed_vel = torch.cat(reconstructed_vel_list, dim=1) + + # Every input token has it's own position (before the action). + # As we have 19 tokens, and the last one token will lead us to a new place, + # So it's totally 20 positions. 
+ assert reconstructed_pos.shape[1] == T_generated_chunks + 1 + assert input_mask.shape[1] == T_generated_chunks - start_step + + # Interpolation + if interpolation: + + if already_interpolated: + reconstructed_pos = torch.cat(reconstructed_pos_full_list, dim=1) + reconstructed_heading = torch.cat(reconstructed_heading_full_list, dim=1) + reconstructed_vel = torch.cat(reconstructed_vel_full_list, dim=1) + + else: + + # spline_res = interpolate_trajectory_spline( + # pos=reconstructed_pos, + # heading=reconstructed_heading, + # vel=reconstructed_vel, + # mask=torch.cat([input_mask, input_mask[:, -1:]], dim=1), + # ) + # reconstructed_pos = spline_res["position"] + # reconstructed_heading = spline_res["heading"] + # reconstructed_vel = spline_res["velocity"] + # reconstructed_valid_mask = spline_res["valid_mask"] + + new_reconstructed_pos = interpolate(reconstructed_pos, self.num_skipped_steps, remove_first_step=False) + assert (new_reconstructed_pos[:, ::5] == reconstructed_pos).all() + reconstructed_pos = new_reconstructed_pos + + reconstructed_heading = interpolate_heading( + reconstructed_heading, self.num_skipped_steps, remove_first_step=False + ) + reconstructed_vel = interpolate(reconstructed_vel, self.num_skipped_steps, remove_first_step=False) + + input_mask_augmented = torch.cat([agent_valid_mask[:, :start_step], input_mask], dim=1) + assert input_mask_augmented.shape[1] == T_generated_chunks + valid = input_mask_augmented + valid = valid.reshape(B, -1, 1, N).expand(-1, -1, self.num_skipped_steps, -1).reshape(B, -1, N) + valid = torch.cat([valid, input_mask[:, -1:]], dim=1) + reconstructed_valid_mask = valid + + # Mask out: + reconstructed_pos = reconstructed_pos * reconstructed_valid_mask.unsqueeze(-1) + reconstructed_vel = reconstructed_vel * reconstructed_valid_mask.unsqueeze(-1) + reconstructed_heading = reconstructed_heading * reconstructed_valid_mask + + # We ensure that the output must be 5*T_chunks+1 + assert reconstructed_pos.shape[1] == 
self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_valid_mask.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_vel.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_heading.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + else: + reconstructed_valid_mask = input_mask + + data_dict["decoder/reconstructed_position"] = reconstructed_pos + data_dict["decoder/reconstructed_heading"] = reconstructed_heading + data_dict["decoder/reconstructed_velocity"] = reconstructed_vel + data_dict["decoder/reconstructed_valid_mask"] = reconstructed_valid_mask + + return data_dict + + def _detokenize_backward_prediction( + self, + data_dict, + interpolation=True, + detokenizing_gt=False, + flip_wrong_heading=False, + ): # actions, current_pos, current_vel, current_heading): + """ + Compared to the non-gpt style, this function dynamically adds new agents into the scene. + A very interesting point here is we can't start with 'current position' in the data. + Because the model is predicting according to the first few tokens, which already have some errors. + """ + # TODO: Hardcoded here... 
+ assert self.config.GPT_STYLE + start_step = 0 + # autoregressive_start_step = 2 + + # ===== Get initial data ===== + agent_pos = data_dict["decoder/agent_position"].clone() + agent_heading = data_dict["decoder/agent_heading"].clone() + agent_valid_mask = data_dict["decoder/agent_valid_mask"].clone() + agent_velocity = data_dict["decoder/agent_velocity"].clone() + agent_shape = data_dict["decoder/current_agent_shape"].clone() + agent_type = data_dict["decoder/agent_type"].clone() + target_action_valid_mask = data_dict["decoder/target_action_valid_mask"] + input_mask = data_dict["decoder/input_action_valid_mask"] + B, T_full, N, _ = agent_pos.shape + assert T_full == 91 # TODO: hardcoded + assert agent_pos.ndim == 4 + + # ===== Skip some steps ===== + agent_pos = agent_pos[:, ::self.num_skipped_steps] + agent_heading = agent_heading[:, ::self.num_skipped_steps] + agent_valid_mask = agent_valid_mask[:, ::self.num_skipped_steps] + agent_velocity = agent_velocity[:, ::self.num_skipped_steps] + T_chunks = agent_pos.shape[1] + assert T_chunks == 19 # TODO: hardcoded + + # ===== Prepare some variables ===== + action = data_dict["decoder/output_action"] + T_actions = action.shape[1] + if T_actions + start_step != T_chunks: + print( + "WARNING: The number of actions is not consistent with the number of raw data chunks! You have {} actions, start step is {} and the number of chunks is {}." 
+ .format(T_actions, start_step, T_chunks) + ) + T_generated_chunks = T_actions + start_step + + current_pos = agent_pos[:, -1:, ..., :2] + current_heading = agent_heading[:, -1:] + current_vel = agent_velocity[:, -1:, ..., :2] + current_valid_mask = agent_valid_mask[:, -1:] + + if detokenizing_gt: + # Merge input mask with target mask + input_mask = input_mask & target_action_valid_mask + + reconstructed_pos_list = [current_pos.clone()] + reconstructed_heading_list = [current_heading.clone()] + reconstructed_vel_list = [current_vel.clone()] + + # Select correct bins: + bin_centers = self.get_bin_centers(agent_type) + + for curr_backward_step in range(T_generated_chunks): + # curr_backward_step = 0, 1, ..., 18 + + curr_forward_step = T_chunks - curr_backward_step - 1 + # curr_forward_step = 18, 17, ..., 0 + + next_forward_step = curr_forward_step - 1 + # next_forward_step = 17, 16, ..., -1 + + action_valid_mask_step = input_mask[:, curr_backward_step:curr_backward_step + 1] + act = action[:, curr_backward_step:curr_backward_step + 1] + assert (act[action_valid_mask_step] != -1).all() + res = self._detokenize_a_step( + current_pos=current_pos, + current_heading=current_heading, + current_valid_mask=action_valid_mask_step, + current_vel=current_vel, + action=act, + agent_shape=agent_shape, + agent_type=agent_type, + bin_centers=bin_centers, + dt=-self.dt, + flip_wrong_heading=flip_wrong_heading, + ) + next_pos, next_heading, next_vel = res["pos"], res["heading"], res["vel"] + next_pos = next_pos.reshape(B, 1, N, 2) + next_heading = next_heading.reshape(B, 1, N) + next_vel = next_vel.reshape(B, 1, N, 2) + next_valid_mask = current_valid_mask + + # if detokenizing_gt and curr_backward_step < T_generated_chunks - 1: + # TODO: Here the detokenizing_gt is ignored and we always add new agents in. 
+ if curr_backward_step < T_generated_chunks - 1: + # Fill in the initial states of newly added agents + action_valid_mask_next_step = input_mask[:, curr_backward_step + 1:curr_backward_step + 2] + newly_added = torch.logical_and(~action_valid_mask_step, action_valid_mask_next_step) + next_pos[newly_added] = agent_pos[:, next_forward_step:next_forward_step + 1, ..., :2][newly_added] + next_heading[newly_added] = agent_heading[:, next_forward_step:next_forward_step + 1][newly_added] + next_vel[newly_added] = agent_velocity[:, next_forward_step:next_forward_step + 1, ..., :2][newly_added] + next_valid_mask[newly_added] = action_valid_mask_next_step[newly_added] + + current_pos = next_pos + current_heading = next_heading + current_vel = next_vel + current_valid_mask = next_valid_mask + + reconstructed_pos_list.append(current_pos.clone()) + reconstructed_heading_list.append(current_heading.clone()) + reconstructed_vel_list.append(current_vel.clone()) + + reconstructed_pos = torch.cat(reconstructed_pos_list, dim=1) + reconstructed_heading = torch.cat(reconstructed_heading_list, dim=1) + reconstructed_vel = torch.cat(reconstructed_vel_list, dim=1) + + # Every input token has it's own position (before the action). + # As we have 19 tokens, and the last one token will lead us to a new place, + # So it's totally 20 positions. + assert reconstructed_pos.shape[1] == T_generated_chunks + 1 + assert input_mask.shape[1] == T_generated_chunks - start_step + + # TODO: Not sure if we should return flipped data or not. 
+ reconstructed_pos = reconstructed_pos.flip(dims=[1]) + reconstructed_heading = reconstructed_heading.flip(dims=[1]) + reconstructed_vel = reconstructed_vel.flip(dims=[1]) + input_mask = input_mask.flip(dims=[1]) + + # Interpolation + if interpolation: + new_reconstructed_pos = interpolate(reconstructed_pos, self.num_skipped_steps, remove_first_step=False) + assert (new_reconstructed_pos[:, ::5] == reconstructed_pos).all() + reconstructed_pos = new_reconstructed_pos + + reconstructed_heading = interpolate_heading( + reconstructed_heading, self.num_skipped_steps, remove_first_step=False + ) + reconstructed_vel = interpolate(reconstructed_vel, self.num_skipped_steps, remove_first_step=False) + + # input_mask_augmented = torch.cat([agent_valid_mask[:, :start_step], input_mask], dim=1) + input_mask_augmented = input_mask + assert input_mask_augmented.shape[1] == T_generated_chunks + valid = input_mask_augmented + valid = valid.reshape(B, -1, 1, N).expand(-1, -1, self.num_skipped_steps, -1).reshape(B, -1, N) + valid = torch.cat([valid, input_mask[:, -1:]], dim=1) + reconstructed_valid_mask = valid + + # Mask out: + reconstructed_pos = reconstructed_pos * reconstructed_valid_mask.unsqueeze(-1) + reconstructed_vel = reconstructed_vel * reconstructed_valid_mask.unsqueeze(-1) + reconstructed_heading = reconstructed_heading * reconstructed_valid_mask + + # We ensure that the output must be 5*T_chunks+1 + assert reconstructed_pos.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_valid_mask.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_vel.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + assert reconstructed_heading.shape[1] == self.num_skipped_steps * T_generated_chunks + 1 + else: + reconstructed_valid_mask = input_mask + + data_dict["decoder/reconstructed_position"] = reconstructed_pos + data_dict["decoder/reconstructed_heading"] = reconstructed_heading + 
data_dict["decoder/reconstructed_velocity"] = reconstructed_vel + data_dict["decoder/reconstructed_valid_mask"] = reconstructed_valid_mask + + return data_dict + + def _tokenize_a_step( + self, *, current_pos, current_heading, current_valid_mask, current_vel, next_pos, next_heading, next_valid_mask, + next_velocity, bin_centers, add_noise, topk, agent_shape, agent_type, dt, **kwargs + ): + B, _, N, _ = current_pos.shape + + valid_mask = torch.logical_and(current_valid_mask, next_valid_mask) + + delta_vel = rotate_bin_to_absolute_heading(bin_centers, current_heading) + + candidate_vel = delta_vel + current_vel + + candidate_pos = candidate_vel * dt + current_pos + + flip_heading_accordingly = kwargs.get("flip_heading_accordingly", True) + + candidate_heading = infer_heading( + current_pos=candidate_pos, + last_pos=current_pos.expand(-1, self.num_actions, -1, -1), + last_heading=current_heading.expand(-1, self.num_actions, -1), + min_displacement=self.config.TOKENIZATION.MIN_DISPLACEMENT, + flip_heading=(flip_heading_accordingly and dt < 0) + ) + + contour = utils.cal_polygon_contour_torch( + x=candidate_pos[..., 0], + y=candidate_pos[..., 1], + theta=candidate_heading, + width=agent_shape[..., 1].reshape(B, 1, N), + length=agent_shape[..., 0].reshape(B, 1, N) + ) + + gt_contour = utils.cal_polygon_contour_torch( + x=next_pos[..., 0], + y=next_pos[..., 1], + theta=next_heading, + width=agent_shape[..., 1].reshape(B, 1, N), + length=agent_shape[..., 0].reshape(B, 1, N) + ) + + error_ade = torch.norm(candidate_pos - next_pos, dim=-1) + error_ade = error_ade * valid_mask + + error_pos = torch.norm(contour - gt_contour, dim=-1).mean(-1) + error_pos = error_pos * valid_mask # masking + assert error_pos.ndim == 3 + + # error_heading = utils.wrap_to_pi(candidate_vel_heading - next_heading.expand(-1, self.num_actions, -1)) + # error_heading = error_heading.abs() * valid_mask + + if self.config.TOKENIZATION.USE_CONTOUR_ERROR: + error = error_pos # + error_heading + + else: + 
error = error_ade # + error_heading + + if add_noise: + # raise ValueError() + print("Noise is not supported in the current version.") + # sampled_action = nucleus_sampling(logits=1 / (error.permute(0, 2, 1) + 1e-6), p=0.95) + # sampled_error = torch.gather(error, 1, sampled_action.unsqueeze(1)).squeeze(1) + # best_action = sampled_action + min_result = error.min(dim=1) + best_action = min_result.indices + + else: + # Pick the best bin with the least error: + min_result = error.min(dim=1) + best_action = min_result.indices + + best_action[~valid_mask.squeeze(1)] = -1 + + # Update reconstructed position and velocity according to the best action: + ind = best_action.reshape(B, 1, N, 1).expand(B, 1, N, 2).clone() + mask = ind == -1 + ind[mask] = self.default_action # Workaround the gather can't handle -1 + reconstructed_pos = torch.gather(candidate_pos, index=ind, dim=1) + reconstructed_vel = torch.gather(candidate_vel, index=ind, dim=1) + reconstructed_heading = torch.gather(candidate_heading, index=ind[..., 0], dim=1) + + reconstructed_vel[mask] = 0 + reconstructed_pos[mask] = 0 + reconstructed_heading[~valid_mask] = 0 + assert current_pos.shape == reconstructed_pos.shape + + # FIXME: This is actually wrong in backward prediction. It's the flipped version of "current velocity". + # But that's OK if Tokenization/Autoregressive share the same code. 
+ relative_delta_pos = reconstructed_pos - current_pos + relative_delta_pos = utils.rotate( + relative_delta_pos[..., 0], relative_delta_pos[..., 1], angle=-reconstructed_heading + ) + relative_delta_pos[mask] = 0 + + # AID = 0 + # print("CUR {}, Recon Pos: {}, GT Pos {}, Cur Vel: {}, Vel: {}, CUR Head: {}, RECON Head: {}".format( + # current_pos[0,0,AID], + # reconstructed_pos[0,0,AID], + # next_pos[0,0,AID], + # current_vel[0,0,AID], + # reconstructed_vel[0,0,AID], + # current_heading[0,0,AID], + # reconstructed_heading[0,0,AID] + # )) + + return dict( + action=best_action, + pos=reconstructed_pos, + heading=reconstructed_heading, + vel=reconstructed_vel, + mask=valid_mask, + delta_pos=relative_delta_pos, + ) + + # def detokenize_for_step(self, data_dict, action): + # # get reconstructed heading: + # return self._detokenize_a_step( + # current_pos=data_dict["decoder/current_agent_position"], + # current_heading=data_dict["decoder/current_agent_heading"], + # current_valid_mask=data_dict["decoder/current_agent_valid_mask"], + # current_vel=data_dict["decoder/current_agent_velocity"], + # action=action + # ) + + def _detokenize_a_step( + self, *, current_pos, current_heading, current_valid_mask, current_vel, action, bin_centers, dt, + flip_wrong_heading, **kwargs + ): + assert action.ndim == 3 + B, T_action, N = action.shape + + assert T_action == 1 + + # TODO: delta_pos computing is updated. 
+ if self.config.DELTA_POS_IS_VELOCITY: + raise ValueError + + # if self.bin_centers.device != action.device: + # self.bin_centers = self.bin_centers.to(action.device) + # bin_centers = self.bin_centers + # bin_centers = bin_centers.reshape(1, 1, 1, self.num_actions, 2).expand(B, T_action, N, self.num_actions, 2) + + action_expanded = action.reshape(B, T_action, N, 1).expand(B, T_action, N, 2).clone() + mask = (action_expanded == -1) | (action_expanded == START_ACTION) | (action_expanded == END_ACTION) + action_expanded[mask] = 0 + delta_vel_candidates = torch.gather(bin_centers, index=action_expanded, axis=1) # .squeeze(1) + # delta_vel_candidates[mask.squeeze(-2)] = 0 + # assert (current_valid_mask == (action!=-1)).all() + + unrotated_delta_vel = delta_vel_candidates + + reconstructed_pos = torch.clone(current_pos[..., :2]).reshape(B, 1, N, 2) + reconstructed_heading = torch.clone(current_heading).reshape(B, 1, N) + reconstructed_vel = torch.clone(current_vel).reshape(B, 1, N, 2) + + # Reconstruct position and heading: + delta_vel = rotate_bin_to_absolute_heading(unrotated_delta_vel, reconstructed_heading) + new_reconstructed_vel = delta_vel + reconstructed_vel + new_reconstructed_pos = new_reconstructed_vel * dt + reconstructed_pos + + flip_heading_accordingly = kwargs.get("flip_heading_accordingly", True) + new_reconstructed_heading = infer_heading( + current_pos=new_reconstructed_pos, + last_pos=reconstructed_pos, + last_heading=reconstructed_heading, + current_velocity=reconstructed_vel, + # init_pos=init_pos, + min_displacement=self.config.TOKENIZATION.MIN_DISPLACEMENT, + min_displacement_init=self.config.TOKENIZATION.MIN_DISPLACEMENT_INIT, + min_speed=self.config.TOKENIZATION.MIN_SPEED, + smooth_factor=self.config.TOKENIZATION.SMOOTH_FACTOR, + max_heading_diff=self.config.TOKENIZATION.MAX_HEADING_DIFF, + flip_heading=flip_heading_accordingly and dt < 0, + # ema_heading=ema_heading + ) + + # PZH: This is a dirty workaround! 
+ if flip_wrong_heading: + wrong_heading_mask = utils.wrap_to_pi(new_reconstructed_heading - + reconstructed_heading).abs() > np.deg2rad(90) + wrong_heading_mask = wrong_heading_mask & current_valid_mask + # Flipped?? + new_reconstructed_heading[wrong_heading_mask] = utils.wrap_to_pi( + new_reconstructed_heading[wrong_heading_mask] + np.pi + ) + + new_reconstructed_heading = new_reconstructed_heading.reshape(B, 1, N) + + # Update reconstructed pos and vel + reconstructed_pos = new_reconstructed_pos + assert reconstructed_pos.shape == (B, 1, N, 2) + reconstructed_vel = new_reconstructed_vel + + reconstructed_pos = reconstructed_pos.reshape(B, N, 2) + new_reconstructed_heading = new_reconstructed_heading.reshape(B, N) + reconstructed_vel = reconstructed_vel.reshape(B, N, 2) + + # Masking + reconstructed_pos = (current_valid_mask.reshape(B, N, 1).expand(B, N, 2) * reconstructed_pos) + new_reconstructed_heading = (current_valid_mask.reshape(B, N) * new_reconstructed_heading) + + reconstructed_pos = reconstructed_pos.reshape(B, 1, N, 2) + new_reconstructed_heading = new_reconstructed_heading.reshape(B, 1, N) + reconstructed_vel = reconstructed_vel.reshape(B, 1, N, 2) + + relative_delta_pos = reconstructed_pos.reshape(B, 1, N, 2) - current_pos + relative_delta_pos = utils.rotate( + relative_delta_pos[..., 0], + relative_delta_pos[..., 1], + angle=-new_reconstructed_heading.reshape(B, 1, N) + np.pi + ) + # AID = 14 + # print( + # "POS: {}, HEAD: {}, VEL: {}, SPEED: {}, unrotated_delta_vel: {}, cur vel {}".format( + # reconstructed_pos[0, 0, AID].cpu().numpy(), + # reconstructed_heading[0, 0, AID], + # reconstructed_vel[0, 0, AID].norm(dim=-1).cpu().numpy(), + # reconstructed_vel.norm(dim=-1)[0, 0, AID], + # unrotated_delta_vel[0, 0, AID].cpu().numpy(), + # current_vel[0, 0, AID].norm(dim=-1) + # ) + # ) + + return dict( + pos=reconstructed_pos, + heading=new_reconstructed_heading, + vel=reconstructed_vel, + delta_pos=relative_delta_pos, + # 
trajectory_pos=rotated_selected_trajs_pos, + # trajectory_heading=rotated_selected_trajs_head + ) + + +class DeltaTokenizer(DeltaDeltaTokenizer): + def __init__(self, config): + BaseTokenizer.__init__(self, config) + + from scenestreamer.utils import REPO_ROOT + # import numpy as np + import pickle + + self.use_type_specific_bins = False + + with open(REPO_ROOT / config.DELTA_TOKENIZER_FILE_NAME, 'rb') as f: + veh = pickle.load(f) + all_trajs = veh["trajs"] + all_head = veh["heading"] + + self.num_actions = len(all_trajs) + + self.all_trajs = torch.from_numpy(all_trajs).float() + self.bin_centers = self.all_trajs[:, -1].reshape(1, self.num_actions, 1, 2) + + self.config = config + self.all_heading = torch.from_numpy(all_head).float() + + self.default_action = 0 # We set action 0 to be all zeros. + self.add_noise = config.TOKENIZATION.ADD_NOISE + + def get_motion_feature(self): + # m = torch.from_numpy(self.bin_centers_flat) + m = self.all_trajs[:, -1] # (1025, 2) + dist = m.norm(p=2, dim=-1).unsqueeze(-1) + heading = self.all_heading[:, -1] + return torch.cat([m, dist, heading], dim=-1) + + def _tokenize_a_step( + self, *, current_pos, current_heading, current_valid_mask, current_vel, next_pos, next_heading, next_valid_mask, + next_velocity, bin_centers, add_noise, agent_shape, **kwargs + ): + B, _, N, _ = current_pos.shape + valid_mask = torch.logical_and(current_valid_mask, next_valid_mask) + delta_pos = rotate_bin_to_absolute_heading(bin_centers, current_heading) + candidate_pos = delta_pos + current_pos + head = self.all_heading[:, -1].reshape(1, -1, 1).expand(B, -1, N) + candidate_heading = current_heading.reshape(B, 1, N) + head + candidate_pos = candidate_pos.reshape(B, -1, N, 2) + contour = utils.cal_polygon_contour_torch( + x=candidate_pos[..., 0], + y=candidate_pos[..., 1], + theta=candidate_heading, + width=agent_shape[..., 1].reshape(B, 1, N), + length=agent_shape[..., 0].reshape(B, 1, N) + ) + gt_contour = utils.cal_polygon_contour_torch( + 
x=next_pos[..., 0], + y=next_pos[..., 1], + theta=next_heading, + width=agent_shape[..., 1].reshape(B, 1, N), + length=agent_shape[..., 0].reshape(B, 1, N) + ) + error_pos = torch.norm(contour - gt_contour, dim=-1).mean(-1) + error = error_pos * valid_mask + + if add_noise: + raise ValueError() + sampled_action = nucleus_sampling(logits=1 / (error.permute(0, 2, 1) + 1e-6), p=0.95) + sampled_error = torch.gather(error, 1, sampled_action.unsqueeze(1)).squeeze(1) + best_action = sampled_action + + else: + # Pick the best bin with the least error: + min_result = error.min(dim=1) + best_action = min_result.indices + + best_action[~valid_mask.squeeze(1)] = -1 + + # Update reconstructed position and velocity according to the best action: + ind = best_action.reshape(B, 1, N, 1).expand(B, 1, N, 2).clone() + mask = ind == -1 + ind[mask] = self.default_action # Workaround the gather can't handle -1 + reconstructed_pos = torch.gather(candidate_pos, index=ind, dim=1) + reconstructed_pos[mask] = 0 + assert current_pos.shape == reconstructed_pos.shape + + if self.all_heading.device != reconstructed_pos.device: + self.all_heading = self.all_heading.to(reconstructed_pos.device) + all_heading = self.all_heading[:, -1].reshape(1, self.num_actions, 1).expand(B, -1, N) + ind = best_action.reshape(B, 1, N).clone() + ind[ind == -1] = self.default_action + reconstructed_heading = torch.gather(all_heading, index=ind, dim=1) + reconstructed_heading = reconstructed_heading + current_heading + reconstructed_heading[~valid_mask] = 0 + + reconstructed_vel = (reconstructed_pos - current_pos) / self.dt + reconstructed_vel[~valid_mask] = 0 + + relative_delta_pos = get_relative_velocity(reconstructed_vel, reconstructed_heading) + relative_delta_pos[~valid_mask] = 0 + + best_action[~valid_mask.squeeze(1)] = -1 + assert (best_action[valid_mask.squeeze(1)] >= 0).all() + assert (best_action[~valid_mask.squeeze(1)] == -1).all() + # AID = 26 + # print("CUR {}, Recon Pos: {}, GT Pos {}, Cur Vel: {}, Vel: 
{}, CUR Head: {}, RECON Head: {}".format( + # current_pos[0,0,AID], + # reconstructed_pos[0,0,AID], + # next_pos[0,0,AID], + # current_vel[0,0,AID], + # reconstructed_vel[0,0,AID], + # current_heading[0,0,AID], + # reconstructed_heading[0,0,AID] + # )) + return dict( + action=best_action, + pos=reconstructed_pos, + heading=reconstructed_heading, + vel=reconstructed_vel, + mask=valid_mask, + delta_pos=relative_delta_pos, + ) + + def _detokenize_a_step(self, *, current_pos, current_heading, current_valid_mask, current_vel, action, **kwargs): + assert action.ndim == 3 + B, T_action, N = action.shape + + assert T_action == 1 + + bin_centers = self.bin_centers.to(action.device) + bin_centers = bin_centers.reshape(1, 1, 1, self.num_actions, 2).expand(B, T_action, N, self.num_actions, 2) + action_expanded = action.reshape(B, T_action, N, 1, 1).expand(B, T_action, N, 1, 2).clone() + mask = (action_expanded == -1) | (action_expanded == START_ACTION) + action_expanded[mask] = 0 + delta_pos_candidates = torch.gather(bin_centers, index=action_expanded, axis=3).squeeze(-2) + + reconstructed_pos = torch.clone(current_pos[..., :2]).reshape(B, 1, N, 2) + reconstructed_heading = torch.clone(current_heading).reshape(B, 1, N) + + # Reconstruct position and heading: + delta_pos = rotate_bin_to_absolute_heading(delta_pos_candidates, current_heading.reshape(B, 1, N)) + new_reconstructed_pos = delta_pos + reconstructed_pos + + if self.all_trajs.device != new_reconstructed_pos.device: + self.all_trajs = self.all_trajs.to(new_reconstructed_pos.device) + if self.all_heading.device != reconstructed_heading.device: + self.all_heading = self.all_heading.to(reconstructed_heading.device) + + all_trajs = self.all_trajs.reshape(1, self.num_actions, 1, 5, 2).expand(B, -1, N, -1, -1) + action_expanded_for_traj = action.reshape(B, T_action, N, 1, 1).expand(B, T_action, N, 5, 2).clone() + mask = (action_expanded_for_traj == -1) | (action_expanded_for_traj == START_ACTION) + 
action_expanded_for_traj[mask] = 0 + selected_trajs = torch.gather(all_trajs, index=action_expanded_for_traj, axis=1).squeeze(1) + + all_heading = self.all_heading.reshape(1, self.num_actions, 1, 5).expand(B, -1, N, -1) + action_expanded_for_traj = action.reshape(B, T_action, N, 1).expand(B, T_action, N, 5).clone() + mask = (action_expanded_for_traj == -1) | (action_expanded_for_traj == START_ACTION) + action_expanded_for_traj[mask] = 0 + selected_heading = torch.gather(all_heading, index=action_expanded_for_traj, axis=1).squeeze(1) + + reconstructed_heading = selected_heading[..., -1] + current_heading.reshape(B, N) + reconstructed_heading = reconstructed_heading.reshape(B, 1, N) + + new_reconstructed_vel = (new_reconstructed_pos - reconstructed_pos) / self.dt + reconstructed_vel = new_reconstructed_vel + + # Update reconstructed pos and vel + reconstructed_pos = new_reconstructed_pos + assert reconstructed_pos.shape == (B, 1, N, 2) + + reconstructed_pos = reconstructed_pos.reshape(B, N, 2) + reconstructed_heading = reconstructed_heading.reshape(B, N) + reconstructed_vel = reconstructed_vel.reshape(B, N, 2) + + # Masking + reconstructed_pos = (current_valid_mask.reshape(B, N, 1).expand(B, N, 2) * reconstructed_pos) + reconstructed_heading = (current_valid_mask.reshape(B, N) * reconstructed_heading) + + rotated_selected_trajs_pos = rotate_bin_to_absolute_heading( + selected_trajs[..., :2], + current_heading.reshape(B, N, 1).expand(B, N, 5) + ) + rotated_selected_trajs_pos = rotated_selected_trajs_pos + current_pos.reshape(B, N, 1, 2).expand(B, N, 5, 2) + rotated_selected_trajs_head = selected_heading + current_heading.reshape(B, N, 1).expand(B, N, 5) + + full_pos = torch.cat([current_pos.swapaxes(1, 2), rotated_selected_trajs_pos], dim=2) + rotated_selected_trajs_vel = (full_pos[:, :, 1:] - full_pos[:, :, :-1]) / (self.dt / self.num_skipped_steps) + + # Masking + rotated_selected_trajs_pos = ( + current_valid_mask.reshape(B, N, 1, 1).expand(B, N, 5, 2) * 
rotated_selected_trajs_pos + ) + rotated_selected_trajs_head = ( + current_valid_mask.reshape(B, N, 1).expand(B, N, 5) * rotated_selected_trajs_head + ) + rotated_selected_trajs_vel = ( + current_valid_mask.reshape(B, N, 1, 1).expand(B, N, 5, 2) * rotated_selected_trajs_vel + ) + # AID = 26 + # print( + # "ACTION: {}, Cur Pos: {}, POS: {}, HEAD: {}, VEL: {}, SPEED: {}, cur vel {}, cur mask {}".format( + # action[0, 0, AID].cpu().numpy(), + # current_pos[0, 0, AID].cpu().numpy(), + # reconstructed_pos[0, AID].cpu().numpy(), + # reconstructed_heading[0, AID], + # reconstructed_vel[0, AID].norm(dim=-1).cpu().numpy(), + # reconstructed_vel.norm(dim=-1)[0, AID], + # # unrotated_delta_vel[0, 0, AID].cpu().numpy(), + # current_vel[0, 0, AID].norm(dim=-1), + # current_valid_mask[0, 0, AID] + # ) + # ) + + relative_delta_pos = get_relative_velocity(reconstructed_vel, reconstructed_heading) + relative_delta_pos[~current_valid_mask.reshape(B, N)] = 0 + + return dict( + pos=reconstructed_pos, + heading=reconstructed_heading, + vel=reconstructed_vel, + delta_pos=relative_delta_pos, + # trajectory_pos=rotated_selected_trajs_pos, + reconstructed_position=[ + rotated_selected_trajs_pos[:, :, t].unsqueeze(1) for t in range(self.num_skipped_steps) + ], + reconstructed_heading=[ + rotated_selected_trajs_head[:, :, t].unsqueeze(1) for t in range(self.num_skipped_steps) + ], + reconstructed_velocity=[ + rotated_selected_trajs_vel[:, :, t].unsqueeze(1) for t in range(self.num_skipped_steps) + ], + ) diff --git a/scenestreamer/tokenization/precomputed_delta_delta_0309sol1.json b/scenestreamer/tokenization/precomputed_delta_delta_0309sol1.json new file mode 100644 index 0000000000000000000000000000000000000000..732f995876fb0d5254fbff8ad07a42ec3759375f --- /dev/null +++ b/scenestreamer/tokenization/precomputed_delta_delta_0309sol1.json @@ -0,0 +1 @@ +[[0.0, 0.0], [-0.20100729167461395, 0.042658522725105286], [-0.10357233136892319, 0.3789677619934082], [-0.03251465782523155, 
0.047527577728033066], [0.04116864874958992, 1.2327746152877808], [-0.032443735748529434, 0.048020292073488235], [-0.032479096204042435, 0.04755720496177673], [-0.09009642153978348, 0.1263166218996048], [-0.03248092532157898, 0.04755433648824692], [-2.72074556350708, 0.3129623532295227], [-0.04366391524672508, 0.05442005768418312], [-0.018815165385603905, -0.6968593597412109], [0.012525119818747044, -0.9254288673400879], [-0.031056443229317665, 0.050685085356235504], [-0.2824927866458893, -0.03523877635598183], [-0.025265831500291824, 0.021728189662098885], [0.5176170468330383, 3.9028046131134033], [-0.6928201913833618, 0.2704077959060669], [0.0757674127817154, -0.2867741882801056], [-0.032369036227464676, 0.046994566917419434], [-0.032476115971803665, 0.04750480502843857], [-0.03246999531984329, 0.04749852046370506], [-0.03247496485710144, 0.047506727278232574], [-0.032474540174007416, 0.04750504344701767], [-0.032474640756845474, 0.04750223830342293], [-0.07899434119462967, 0.2333170771598816], [-0.032474011182785034, 0.047507889568805695], [-0.2586832046508789, -1.2880054712295532], [-0.09805482625961304, -1.8860783576965332], [-0.15976740419864655, 1.475889801979065], [-0.032473400235176086, 0.04751427471637726], [-0.9238372445106506, 1.2366536855697632], [-0.03406164050102234, 0.05306003615260124], [-0.10951928049325943, 0.07464853674173355], [7.680830478668213, -0.7040635347366333], [-0.1522618681192398, 0.7807302474975586], [-0.03247406333684921, 0.047503430396318436], [-0.041728466749191284, 0.09425577521324158], [-0.16747406125068665, 1.8190141916275024], [-0.04666057601571083, 0.11335548758506775], [-0.07174667716026306, 1.0457067489624023], [-0.14856556057929993, -0.11613085120916367], [0.003343712305650115, -0.3750915229320526], [-0.032472360879182816, 0.04751567542552948], [0.12460913509130478, 0.8626418709754944], [-0.032471463084220886, 0.047525230795145035], [-0.032470520585775375, 0.04753248021006584], [-0.03246995061635971, 0.04753703624010086], 
[-0.5059569478034973, -0.23372747004032135], [-0.0324707068502903, 0.04752534627914429], [-0.03246961161494255, 0.04753819853067398], [-0.03246978670358658, 0.04753333330154419], [-0.7948559522628784, -0.5127915143966675], [1.3322755098342896, 1.059065341949463], [-0.3675495386123657, 0.3209506571292877], [-0.03630969300866127, 0.049348413944244385], [5.583672523498535, -0.2747773826122284], [2.1563963890075684, 0.04438663274049759], [0.019743584096431732, 0.16897693276405334], [-0.04815419018268585, 0.3195549547672272], [-0.031539883464574814, 0.1349296271800995], [-0.15258513391017914, -0.5913354754447937], [-4.021631240844727, -0.6159193515777588], [-0.116386279463768, -1.0715410709381104], [0.1332291066646576, 0.26388630270957947], [-0.03246922791004181, 0.04753951355814934], [0.5159450769424438, -0.008527029305696487], [0.9686608910560608, -0.9318268299102783], [-1.7038300037384033, 2.1474685668945312], [-0.06062861904501915, -0.8612353801727295], [-0.032469410449266434, 0.04753553867340088], [-0.23973315954208374, -0.9838830232620239], [-0.03246929123997688, 0.04753703996539116], [-0.002825119998306036, -0.1925535649061203], [-1.0905084609985352, -1.044264793395996], [-0.032468825578689575, 0.04753953590989113], [-0.032468780875205994, 0.04753968119621277], [-2.1329143047332764, -1.4166032075881958], [5.851795196533203, 0.821886420249939], [-0.08804911375045776, -0.24739529192447662], [0.14984603226184845, -1.0016599893569946], [0.0011452793842181563, 0.24939745664596558], [-0.02904891036450863, -1.3754806518554688], [-0.032468944787979126, 0.04753631353378296], [-0.44368743896484375, -0.7775915265083313], [0.04404611513018608, -0.08605606108903885], [0.27502113580703735, 0.17970246076583862], [-0.03246787190437317, 0.04754769802093506], [3.5471396446228027, -0.2175915688276291], [4.872360706329346, -0.8988540172576904], [-0.03246750682592392, 0.0475505031645298], [-0.5730357766151428, -1.1196258068084717], [-0.4879508316516876, 0.018804769963026047], 
[0.5970646739006042, 0.739807665348053], [-0.03246859461069107, 0.04753582179546356], [-0.10738460719585419, -0.4411945641040802], [-0.03246768191456795, 0.047544464468955994], [-1.2472903728485107, 0.3417221009731293], [0.05304281413555145, 0.13606542348861694], [0.08667632937431335, 1.3969056606292725], [-0.11430513858795166, -0.024819355458021164], [-0.03246745467185974, 0.0475471131503582], [-6.162646293640137, -0.49631527066230774], [-0.2675101161003113, 1.2785576581954956], [-0.032467663288116455, 0.047544073313474655], [-0.7488318681716919, -1.6717889308929443], [-0.03426986560225487, -0.5714247226715088], [-0.014826134778559208, 2.4642765522003174], [-0.013251347467303276, 0.6020581722259521], [-0.03246716037392616, 0.04754907265305519], [0.20180058479309082, -0.37372949719429016], [0.2522398829460144, 0.029735036194324493], [2.964193820953369, 1.1256965398788452], [-0.016015049070119858, -0.0320463590323925], [-0.271144837141037, 0.5612770915031433], [-0.03246711567044258, 0.04754813387989998], [0.06684739142656326, -0.616408109664917], [-0.03246590867638588, 0.047563452273607254], [0.00153100467287004, 0.07023830711841583], [-0.3357997536659241, -0.2184734344482422], [-0.03246595710515976, 0.0475616455078125], [-0.029145628213882446, 0.1799589991569519], [-0.03246613219380379, 0.04755804315209389], [-0.7368745803833008, 1.9910902976989746], [-0.05263886973261833, -1.2401652336120605], [-0.3170037269592285, 0.1471141129732132], [0.2504652738571167, -1.0537148714065552], [-0.02756655029952526, 0.07044760882854462], [-0.010538161732256413, 0.4094603657722473], [-0.032465532422065735, 0.04756470397114754], [-0.03246523439884186, 0.047567810863256454], [6.585717678070068, 2.9219000339508057], [-0.42692703008651733, 1.0166434049606323], [-39.58283233642578, 0.4639221131801605], [0.39297083020210266, 0.5405029058456421], [-0.03246573358774185, 0.047561921179294586], [0.07182163000106812, 0.5589873194694519], [0.47078368067741394, -0.7922490835189819], 
[0.27705028653144836, 1.8382720947265625], [4.495794296264648, 0.30848559737205505], [-0.06578400731086731, 1.3110488653182983], [3.2361834049224854, -0.9126150608062744], [-0.16824449598789215, 0.0008915771613828838], [-0.025684362277388573, 1.430021047592163], [0.17217093706130981, -1.742404580116272], [-0.03246529772877693, 0.047566723078489304], [-0.03246545046567917, 0.04756418243050575], [-0.03246455267071724, 0.0475754588842392], [3.473351240158081, 2.373555898666382], [-0.09114215523004532, 0.5255845189094543], [0.07596435397863388, -0.8240886330604553], [-0.22803476452827454, -2.1909186840057373], [0.12634555995464325, 0.06238046661019325], [0.36864370107650757, -1.2992624044418335], [-0.7427727580070496, -0.13358095288276672], [0.13100215792655945, -0.02707989141345024], [-0.012279110960662365, 1.5988585948944092], [0.8070716857910156, -1.7627655267715454], [0.319932222366333, -0.2534741759300232], [-0.2033846080303192, -0.24070513248443604], [0.611739456653595, -0.33292993903160095], [-0.060580309480428696, 0.46196600794792175], [0.2935931980609894, 1.0393402576446533], [1.225479006767273, -0.14826539158821106], [0.29431724548339844, -0.47038769721984863], [-0.11001357436180115, -1.643731713294983], [-0.03246476128697395, 0.04757201299071312], [-0.1588619202375412, -5.449110507965088], [-0.03457324206829071, 0.04869924858212471], [0.24736911058425903, -0.8449870347976685], [0.0024610282853245735, 0.10476410388946533], [-0.03457324206829071, 0.04869924858212471], [1.7988827228546143, -0.9625368118286133], [-0.03457324206829071, 0.04869924858212471], [-0.24044057726860046, 0.39968711137771606], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.1775430142879486, -0.7644380331039429], [-0.1514424830675125, 0.12461462616920471], [0.21454034745693207, -2.3073508739471436], [-0.03457324206829071, 0.04869924858212471], 
[0.01732325181365013, -1.0291292667388916], [6.3890557289123535, -0.12087410688400269], [-0.08859125524759293, 0.17459750175476074], [0.0901530459523201, 0.4255724251270294], [0.0322876013815403, 0.7286819219589233], [0.060046855360269547, 1.0907224416732788], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.6045233607292175, 0.6308645009994507], [0.11795935779809952, -0.5127959847450256], [-0.0978274941444397, 0.020011678338050842], [0.16658911108970642, -0.1597650945186615], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [0.47626015543937683, -0.6018805503845215], [0.10419448465108871, 0.1760389357805252], [-0.06827446073293686, 0.04632747173309326], [-0.4321858882904053, -0.5257142186164856], [-0.06675803661346436, 0.08172103762626648], [-0.050145167857408524, -0.10444431006908417], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [0.056639574468135834, 0.32279276847839355], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [0.3795260488986969, 0.3290144205093384], [-0.6496068239212036, -0.6628745794296265], [-0.29441529512405396, -1.4926807880401611], [-0.2198297083377838, 0.23265188932418823], [-0.03457324206829071, 0.04869924858212471], [0.20859821140766144, 0.43978577852249146], [-0.08401114493608475, 0.6412590146064758], [0.4859842360019684, -1.0296193361282349], [-0.03457324206829071, 0.04869924858212471], [2.6716673374176025, -1.7013599872589111], [-1.156517505645752, -0.38881316781044006], [-1.7699925899505615, -0.058728862553834915], [0.4163183867931366, 1.4228426218032837], [-0.03457324206829071, 0.04869924858212471], [0.04336346313357353, 0.011278250254690647], [-0.03457324206829071, 0.04869924858212471], [-0.14804908633232117, 0.28097864985466003], 
[-0.03457324206829071, 0.04869924858212471], [-0.45895466208457947, 0.20197166502475739], [0.8043020963668823, 0.3152182102203369], [0.1902092546224594, -0.7035972476005554], [-0.03457324206829071, 0.04869924858212471], [-0.049262918531894684, 0.0676737129688263], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.22960926592350006, 0.10629866272211075], [-0.1768663376569748, 9.026327133178711], [-0.27732327580451965, -0.41121906042099], [-0.04312817007303238, -3.0872185230255127], [-0.0862421914935112, 1.1698386669158936], [-0.03457324206829071, 0.04869924858212471], [-0.035073112696409225, -1.5059555768966675], [-0.28570789098739624, -0.6776381134986877], [-0.5391536355018616, -0.4054335057735443], [-0.03457324206829071, 0.04869924858212471], [0.05825640261173248, -1.1526819467544556], [-0.03457324206829071, 0.04869924858212471], [0.06421855837106705, 0.07584027200937271], [0.14679105579853058, -1.460405707359314], [-0.05782574415206909, 0.9177041053771973], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.03457324206829071, 0.04869924858212471], [-0.38127878308296204, 0.061903826892375946], [3.3697235584259033, 7.92602014541626], [0.7607005834579468, 0.24012523889541626], [-43.059993743896484, -13.406118392944336], [46.00641632080078, -10.580183982849121], [-0.1794726550579071, -0.8819662928581238], [0.0339733250439167, 0.9806593060493469], [-29.785377502441406, 12.521777153015137], [-0.18114283680915833, -23.23348617553711], [0.19215773046016693, 2.4312076568603516], [19.205677032470703, -4.436672687530518], [0.4614512026309967, -0.005844674073159695], [-9.867729187011719, 0.5327489972114563], [5.755468368530273, 13.866148948669434], [3.338019609451294, 1.0778652429580688], [16.62271499633789, 13.956756591796875], [-0.013572365045547485, 7.363010883331299], [-1.0980985164642334, 
0.33860766887664795], [0.9505332708358765, -0.6947017908096313], [6.102337837219238, -3.6540467739105225], [-29.60430335998535, -24.984054565429688], [-2.4665331840515137, -7.534351348876953], [38.141136169433594, 13.490243911743164], [1.2444486618041992, -0.06270994991064072], [-1.4514052867889404, -0.9631136655807495], [18.199981689453125, -30.14586067199707], [-0.2954864501953125, -4.156425952911377], [-21.202800750732422, 22.68663787841797], [-0.17634055018424988, 0.3171546757221222], [14.944695472717285, 27.136547088623047], [0.11301840096712112, -0.5768436193466187], [-5.2295708656311035, 2.778974771499634], [-46.13640213012695, 14.125810623168945], [-26.4930419921875, 1.2081525325775146], [-19.466840744018555, -17.379236221313477], [0.4834428131580353, 0.6852759122848511], [-0.19156503677368164, -0.1647273749113083], [-0.015268741175532341, -1.596873164176941], [5.483791351318359, 3.52213978767395], [31.893966674804688, -2.5440289974212646], [-1.6269553899765015, 0.9746137261390686], [-2.3411922454833984, 3.7954421043395996], [-0.655098021030426, -0.5895565152168274], [0.17256774008274078, 12.377542495727539], [37.160701751708984, -33.179283142089844], [-0.03165079653263092, 5.46018648147583], [0.7204809188842773, -0.22051844000816345], [-14.561258316040039, 13.265552520751953], [-57.958221435546875, -0.13014210760593414], [3.2210190296173096, -1.0914171934127808], [12.069998741149902, -17.847551345825195], [-20.903343200683594, 39.96489715576172], [-5.611117839813232, -0.7425530552864075], [0.9887650609016418, -1.5799787044525146], [27.06534194946289, 42.0356559753418], [-9.969817161560059, -49.274169921875], [71.74842834472656, 0.7838032841682434], [10.085864067077637, 0.1604057252407074], [16.96276092529297, 3.4473965167999268], [1.516270399093628, -13.976123809814453], [-0.7259449362754822, -0.03244495764374733], [-57.846778869628906, 33.32236099243164], [-1.0805821418762207, 2.1229443550109863], [1.2002519369125366, 0.6836123466491699], 
[-0.6835423111915588, 0.9950168132781982], [0.3347822427749634, 0.2756745517253876], [2.041701316833496, -3.0156655311584473], [1.7024462223052979, 1.8217071294784546], [26.04505157470703, 8.308212280273438], [-8.118915557861328, 5.966681003570557], [0.06848184764385223, 0.18277591466903687], [-0.12702995538711548, 0.07366414368152618], [61.40410614013672, -27.93779945373535], [0.20192646980285645, -0.028560495004057884], [0.31291496753692627, -0.2743363678455353], [0.3514486253261566, -1.0223441123962402], [-0.15868711471557617, 1.570615291595459], [-0.2162129282951355, -0.4704780578613281], [-2.8180463314056396, -0.8841894865036011], [-69.3528060913086, -24.045650482177734], [-3.037479877471924, 0.9345249533653259], [-4.809382915496826, 12.698028564453125], [1.768396019935608, -0.7398825287818909], [8.628571510314941, 6.709376811981201], [23.440414428710938, -14.657163619995117], [-39.34572982788086, 0.21909211575984955], [0.47317978739738464, -0.5522514581680298], [-0.252657949924469, 0.6587550044059753], [-0.0032274879049509764, 15.579906463623047], [1.176527500152588, 31.579349517822266], [5.804084777832031, 0.006870333105325699], [-0.047830577939748764, 3.738269329071045], [-1.8760923147201538, 0.030339431017637253], [-7.958590030670166, -4.968172550201416], [10.630573272705078, -8.216493606567383], [0.5866058468818665, 1.367844581604004], [49.504493713378906, 5.1417646408081055], [28.486900329589844, 23.534847259521484], [2.6219332218170166, -6.9159932136535645], [-12.289016723632812, -30.020212173461914], [-0.6547898650169373, -1.2186588048934937], [0.008567409589886665, -0.053413793444633484], [2.075314998626709, 0.2586975395679474], [0.0862123966217041, 0.504858672618866], [-9.253722190856934, 26.1356143951416], [-0.5637832283973694, 0.37319669127464294], [-29.386123657226562, -40.56601333618164], [34.24925994873047, -17.715116500854492], [-17.068954467773438, -5.602968692779541], [0.07920746505260468, 47.26346206665039], [-0.12064173817634583, 
9.70920181274414], [0.24159158766269684, -2.5233798027038574], [-9.507444381713867, -13.868274688720703], [-3.3085031509399414, -3.3652963638305664], [-0.45821768045425415, -0.2501350939273834], [14.577071189880371, -48.624813079833984], [-3.539350748062134, 7.219913482666016], [2.209853410720825, 4.120955467224121], [-1.0627635717391968, -2.2018277645111084], [53.40488052368164, 24.91563606262207], [-44.857872009277344, -29.482202529907227], [3.3452959060668945, -34.45347213745117], [-29.309476852416992, -10.144152641296387], [-34.17702865600586, 27.601341247558594], [0.1896757185459137, 19.895605087280273], [-1.1341708898544312, -0.2760099768638611], [-17.481443405151367, 5.003342151641846], [0.03948484733700752, -0.28262242674827576], [0.2994389832019806, -0.13576234877109528], [-5.121236324310303, 242.8358612060547], [-1.2052065134048462, 0.1220259964466095], [143.41397094726562, 54.326263427734375], [-2.279202461242676, 58.25896453857422], [-156.9510498046875, -20.015668869018555], [-0.014055541716516018, 0.20731759071350098], [1.374221682548523, 0.32440829277038574], [-158.80003356933594, 103.67757415771484], [231.13705444335938, 246.80514526367188], [3.3009347915649414, 0.3031063675880432], [11.545523643493652, 485.6884460449219], [-0.21960796415805817, -0.25955870747566223], [-0.29010072350502014, 0.44607579708099365], [110.25254821777344, -125.1540298461914], [0.15082783997058868, -0.854674756526947], [-3.612292766571045, -0.08442164212465286], [-0.03933897614479065, 119.9073486328125], [51.067264556884766, -29.74639129638672], [1.021894097328186, -0.23216229677200317], [-221.98971557617188, 219.03306579589844], [151.25474548339844, 18.918107986450195], [-0.6009636521339417, 0.3186725974082947], [0.15239161252975464, 0.5861682891845703], [-1.0521501302719116, -0.6039580702781677], [60.218780517578125, 226.00115966796875], [-2.592090129852295, -0.4216771423816681], [57.644866943359375, 89.176513671875], [2.1533901691436768, -0.16497808694839478], 
[0.12487722933292389, 2.6399991512298584], [0.7277470231056213, -0.4369269013404846], [218.32110595703125, -2.9663777351379395], [0.17153489589691162, 0.34051862359046936], [0.029759306460618973, 81.8800277709961], [-0.677707850933075, -0.13221202790737152], [-1.3753873109817505, 0.6350052356719971], [0.6520480513572693, 0.18742793798446655], [-2.1982460021972656, 0.20207619667053223], [10.713054656982422, -60.05792236328125], [-0.011438802815973759, 0.05536484345793724], [-7.087655544281006, 35.39435958862305], [-0.14037960767745972, 0.06426693499088287], [-0.41726285219192505, -0.42341873049736023], [-0.008601522073149681, -0.30840763449668884], [0.7162342667579651, 0.951939046382904], [-0.10852634906768799, -2.990787982940674], [1.0739291906356812, -1.1622544527053833], [76.726318359375, -12.038759231567383], [-0.4220333397388458, 1.7053868770599365], [-0.3260860741138458, 0.05285017564892769], [2.681631326675415, 1.0795985460281372], [0.39646169543266296, -0.5450063943862915], [0.7176943421363831, -0.11270690709352493], [-18.498249053955078, 74.17709350585938], [-4.752356052398682, 0.3969094753265381], [-0.23242498934268951, -1.1457115411758423], [-0.923009991645813, -1.2249879837036133], [-1.694688320159912, -0.8771937489509583], [1.5060322284698486, -0.3287917375564575], [3.917891025543213, -0.2009863257408142], [-0.21781297028064728, 1.0726827383041382], [-0.16262681782245636, -0.4761689305305481], [-0.04999591037631035, 0.40887799859046936], [1.1025432348251343, -0.6070537567138672], [2.819150686264038, -0.36071816086769104], [0.2909708023071289, 0.04532775282859802], [-2.200861930847168, 0.8971433639526367], [-2.686286211013794, 11.987119674682617], [-0.3789004385471344, 0.23194603621959686], [0.23331332206726074, -0.3428000509738922], [4.541215896606445, 0.6560301780700684], [-0.4912729859352112, 0.659020185470581], [-0.9360176920890808, -13.462583541870117], [0.48837023973464966, -0.011247940361499786], [1.4806723594665527, 0.8927144408226013], 
[12.967082023620605, -11.308036804199219], [2.038822650909424, 0.40921106934547424], [-1.4036434888839722, -0.24899663031101227], [-0.7053748965263367, -0.40656623244285583], [-22.698444366455078, 255.7769775390625], [-0.039389874786138535, 219.82040405273438], [0.910959780216217, 0.4519810676574707], [1.8803032636642456, -0.6991255283355713], [0.6653987169265747, -1.9281164407730103], [-0.8305219411849976, 0.10313118994235992], [0.08756177872419357, -0.5522916913032532], [-0.5869425535202026, -0.7659205794334412], [-0.06009300798177719, -0.10724862664937973], [-25.87381935119629, 113.78138732910156], [-3.4191091060638428, 0.7751022577285767], [14.404867172241211, 35.077327728271484], [-0.4336215555667877, -0.15804176032543182], [-1.919376015663147, -0.22628116607666016], [-0.990425169467926, -0.17608724534511566], [1.1910101175308228, 1.8618247509002686], [-0.23502139747142792, -0.08100711554288864], [-0.03936442732810974, 142.44595336914062], [-2.307331085205078, -2.135190010070801], [2.4060773849487305, -1.807517170906067], [0.2084456831216812, 1.4578547477722168], [-0.7630151510238647, 0.9796803593635559], [-26.243202209472656, 11.536364555358887], [-0.12914270162582397, 0.6921336054801941], [1.7368009090423584, 0.021521244198083878], [2.599273920059204, 0.1724216639995575], [-1.288973331451416, 1.470885157585144], [0.3093835115432739, -1.2690668106079102], [-0.5429993867874146, 0.05969151109457016], [0.474513977766037, 0.5240129828453064], [0.11390524357557297, -0.17197294533252716], [-1.6558159589767456, 0.18133489787578583], [-0.20862145721912384, -0.7397524118423462], [0.9606093168258667, 0.082832932472229], [-0.039389874786138535, 102.56448364257812], [0.6044726371765137, -0.8324936628341675], [0.7135443091392517, 5.387773513793945], [-0.18325303494930267, 0.23097573220729828], [1.30592679977417, -0.029710853472352028], [0.1445894092321396, 0.14808151125907898], [-0.9142131209373474, 0.4645828604698181], [0.4867987632751465, -0.2587350904941559], 
[0.12037110328674316, -0.015903029590845108], [-2.824378490447998, 0.2295389175415039], [0.38369816541671753, 0.23286356031894684], [-0.12303656339645386, -6.188194751739502], [-1.8092983961105347, 2.901294469833374], [0.19091881811618805, 0.9187283515930176], [-0.31483814120292664, -1.7853933572769165]] \ No newline at end of file diff --git a/scenestreamer/tokenization/precomputed_delta_delta_0309sol6c.json b/scenestreamer/tokenization/precomputed_delta_delta_0309sol6c.json new file mode 100644 index 0000000000000000000000000000000000000000..709cbca5e9c99057fa9f1daf202fe3b17813fa5f --- /dev/null +++ b/scenestreamer/tokenization/precomputed_delta_delta_0309sol6c.json @@ -0,0 +1 @@ +[[0.00026434984058142097, -0.000998910516500473], [0.7088356018066406, 30.805255889892578], [-5.777011871337891, -12.431294441223145], [17.909698486328125, -13.271759033203125], [0.025514010339975357, 11.096918106079102], [-10.464719772338867, 3.839756727218628], [17.860837936401367, 10.371227264404297], [10.969579696655273, 0.049477338790893555], [-1.2718284130096436, -16.058727264404297], [4.841934680938721, -8.027181625366211], [-1.3089356422424316, -5.448477268218994], [1.768057942390442, 2.684217691421509], [-0.8315563201904297, 46.590267181396484], [2.819387435913086, -0.11693938076496124], [19.22503662109375, -29.132617950439453], [-18.495933532714844, 3.1892361640930176], [-2.7526257038116455, 0.5103709697723389], [9.290264129638672, 9.476231575012207], [0.6113970279693604, -23.84191131591797], [-10.667922973632812, -6.320938587188721], [0.3580303490161896, 20.675373077392578], [-0.8873229622840881, -2.7437920570373535], [-0.12489937618374825, 7.0042500495910645], [1.0458880066871643, -1.5643610954284668], [-0.032846637070178986, 2.9840924739837646], [17.264467239379883, -6.670576572418213], [-19.484058380126953, 12.266918182373047], [-13.248194694519043, -10.719786643981934], [-0.23360124230384827, -9.55393123626709], [-6.789799690246582, -18.09844398498535], 
[-14.392882347106934, -2.5093016624450684], [15.844526290893555, 5.997669219970703], [5.749706268310547, -3.9098668098449707], [0.07039797306060791, -19.27470588684082], [-3.600513458251953, -28.58978843688965], [-0.6709483861923218, 26.015283584594727], [9.831847190856934, -7.0401105880737305], [-0.8947106401125591, -0.3049529492855072], [11.532182693481445, 4.861447334289551], [6.153838157653809, 0.7267402410507202], [-2.0571296215057373, 4.97422456741333], [0.03135548532009125, 1.4810140132904053], [2.5953657627105713, -3.102954864501953], [17.7473201751709, 15.268266677856445], [-4.002257823944092, -5.40598726272583], [-14.1517972946167, 1.3648643493652344], [-9.587149620056152, -0.5145817995071411], [-7.814518928527832, 5.397807598114014], [2.8520805835723877, 7.944614410400391], [0.03472340106964111, 13.625947952270508], [0.025864727795124054, -2.5670385360717773], [15.013154983520508, -9.706660270690918], [21.258480072021484, -6.71762752532959], [-6.148342132568359, -8.51479721069336], [-2.8976309299468994, -1.8691421747207642], [1.42465078830719, -0.1784999668598175], [0.9317193031311035, -5.693777084350586], [0.1503884419798851, -0.9450680315494537], [3.671100616455078, -11.632243156433105], [3.6645703315734863, -18.25322914123535], [-11.532023429870605, -14.580163955688477], [11.113574028015137, -3.134721279144287], [-0.038326919078826904, 34.34180450439453], [8.197751998901367, 5.3842549324035645], [0.3465169370174408, 9.029275894165039], [-21.835805892944336, 4.559760093688965], [-0.11084352433681488, 4.518217325210571], [-4.60203742980957, 1.724625825881958], [8.89199447631836, -10.270989418029785], [2.1572821140289307, 0.7659869194030762], [6.40976095199585, 9.303670883178711], [9.393778800964355, 2.757859230041504], [-6.939807891845703, -3.9466092586517334], [3.007993698120117, -1.3977612257003784], [-1.0034838914871216, 0.8983618021011353], [-10.551980972290039, 6.664337158203125], [-1.4964711666107178, -7.655055522918701], [-1.8380585312843323, 
2.089560031890869], [0.3693857192993164, 23.10028839111328], [3.0855185985565186, 5.762034893035889], [-0.08186604827642441, -3.953968048095703], [-0.3696967661380768, -13.576131820678711], [0.1636839658021927, 0.7840569019317627], [2.064678907394409, -7.3919854164123535], [8.211236953735352, -2.9622578620910645], [2.045382261276245, 11.976278305053711], [3.475128173828125, 1.2488161325454712], [5.055067539215088, 3.007617950439453], [-8.769730567932129, -7.0500807762146], [-0.2934892177581787, -1.6964106559753418], [-5.369734764099121, 3.970533609390259], [-0.2046348750591278, -11.621964454650879], [0.06908255815505981, 18.45418357849121], [-4.108065128326416, -1.171015977859497], [1.255661964416504, 0.7281436721483866], [-1.4148463010787964, -0.7815871834754944], [-3.922402858734131, -11.969165802001953], [3.4497861862182617, -4.734394550323486], [4.47063684463501, -0.33235982060432434], [0.46180427074432373, 5.477315902709961], [-15.58374309539795, 2.250784397125244], [3.7186248302459717, -2.250965118408203], [-1.817298412322998, -3.7843410968780518], [2.8357276916503906, 4.172008991241455], [1.0581730604171753, -2.6844804286956787], [0.5579033295313518, -0.20213976005713147], [-0.9036059379577637, 14.849746704101562], [-0.9420291781425476, 3.073174476623535], [0.5325174927711487, -7.939874172210693], [-2.3401260375976562, 0.36932800710201263], [-0.29967691004276276, 0.375987246632576], [4.796463966369629, 0.9992530345916748], [2.432006359100342, -1.1541025638580322], [-2.097280502319336, -1.2326164841651917], [1.4243375062942505, 4.7577595710754395], [-0.36873580515384674, -0.33795298635959625], [0.02718379472692807, 1.9641128381093342], [0.20917761623859404, -0.2631525844335556], [-0.9786620438098907, -1.8235825300216675], [-1.590072214603424, 0.5225403606891632], [-0.4389476776123047, -7.212259769439697], [0.7351102530956268, 0.4191250056028366], [0.9563969671726227, 1.6177560091018677], [-0.09544921666383743, 12.446911811828613], [-3.6595959663391113, 
1.7627038955688477], [-1.0172446966171265, 2.0921194553375244], [0.013215870906909302, 0.3507876048485438], [-1.9570096731185913, -2.2502872943878174], [0.08784899255260825, -2.112941265106201], [-2.512773036956787, -0.41214390099048615], [2.123035192489624, -5.082457065582275], [-0.39609718322753906, -18.007122039794922], [0.07971013337373734, 3.919536590576172], [-7.805483341217041, 4.260200500488281], [1.9679993391036987, -0.03095754235982895], [0.05228213407099247, 2.353350520133972], [-3.1014840602874756, -0.14624424278736115], [-1.9640202522277832, 1.0646508038043976], [2.145549774169922, 1.7508795261383057], [1.351671576499939, -3.609724998474121], [-1.1323267221450806, 0.2526739686727524], [2.0865530967712402, 4.003533840179443], [-1.079329013824463, -1.0985249280929565], [0.5484986901283264, 6.211134433746338], [-0.035287268459796906, -0.35089063218661715], [-0.24326792359352112, 8.543231964111328], [2.648028612136841, -1.994925618171692], [-0.5965510308742523, -1.2531514167785645], [1.0900849103927612, -0.4649277925491333], [0.11765208033223946, -1.4991998672485352], [-3.192969560623169, 1.1137669086456299], [4.38342809677124, -1.6560075283050537], [1.777410348256429, -0.39215681950251263], [3.652460813522339, 0.07307739555835724], [3.716505765914917, -0.8460269570350647], [0.22655007243156433, -5.582303047180176], [-0.7021942734718323, 10.987932205200195], [-1.618802785873413, 0.02314339578151703], [0.5231624096632004, -1.843578815460205], [0.3007906973361969, 9.708359718322754], [2.985959053039551, 0.4782446622848511], [-0.12509731948375702, -4.593304634094238], [-0.10131272152066231, 0.0954496081918478], [-3.3905415534973145, -5.509091377258301], [-0.27076513320207596, 1.1984248757362366], [-0.3710556924343109, 1.7473304271697998], [-1.060715913772583, 1.4463436603546143], [1.967591404914856, -1.4343783855438232], [-3.6140782833099365, -0.6701768040657043], [-1.655011534690857, -3.169455051422119], [-3.7421107292175293, 0.14742609858512878], 
[-0.06340217962861061, -0.8491809666156769], [1.336456537246704, -1.023759365081787], [1.313909649848938, -2.311392903327942], [-2.8513143062591553, -1.1023764610290527], [-4.463316440582275, 1.062060832977295], [3.2368760108947754, -0.4539453983306885], [0.36593571305274963, -1.1403228640556335], [0.6489569544792175, 2.446542978286743], [-0.34212973713874817, 3.7679853439331055], [0.010558314621448517, 3.452025890350342], [0.15346696972846985, 5.068790435791016], [-0.6854808330535889, 2.539550542831421], [0.7393325567245483, -0.5700272023677826], [-0.03709305077791214, -3.3814358711242676], [3.108727216720581, 1.197811484336853], [0.16780863516032696, 0.07637290237471461], [0.5150232911109924, 1.3177307844161987], [2.668258547782898, -0.49722930788993835], [0.5159062743186951, 4.1279401779174805], [-0.5216069519519806, 0.05771670117974281], [-0.35782304406166077, -2.0899733304977417], [0.023735759779810905, 7.56247091293335], [-1.6139593124389648, -1.7015652656555176], [0.4052180349826813, 3.099813938140869], [2.4495365619659424, 0.25175392627716064], [-0.3105985124905904, -0.9099085529645283], [-1.5158457159996033, -0.49229639768600464], [2.844167709350586, -0.8807631731033325], [-0.48166853189468384, 0.5930574735005696], [0.7584388852119446, -2.9007811546325684], [-0.2664032514606203, -0.031035952270030975], [0.08538249731063843, -0.5944119155406952], [-0.05616738125681878, 0.7938180685043335], [1.7412750720977783, 0.9102497696876526], [0.7408931255340576, -0.8843732476234436], [0.32595177491505944, -0.546072373787562], [-0.8358883261680603, -0.8234978020191193], [0.05100584030151367, 9.401427268981934], [0.5495907664299011, -2.52315616607666], [0.43536247313022614, 0.8562805652618408], [0.07540751248598099, -2.997840404510498], [1.9963009357452393, -0.7548007965087891], [-0.2974889874458313, 2.378384590148926], [0.1583878993988037, 5.990407943725586], [0.051415344607084995, -0.11701767034828664], [1.159417986869812, 0.07943920604884624], [-3.752032995223999, 
-1.3550511598587036], [0.02988374757114798, 0.17700461404664175], [0.27247194200754166, 0.5200352668762207], [3.3890960216522217, -1.414064645767212], [-1.851188063621521, 4.659206867218018], [1.4570242166519165, 2.7975308895111084], [-0.5394153594970703, -2.8558785915374756], [0.2841554284095764, 0.28442053496837616], [-0.43083344399929047, -0.554456815123558], [-0.12671937296787897, -0.18722273409366608], [0.6074647903442383, 0.14373890310525894], [2.25974440574646, -0.17834892868995667], [-0.7725498676300049, 0.1015239879488945], [-2.9393248558044434, 0.9692080020904541], [-0.03217330668121576, 2.6167261600494385], [0.23448888957500458, -4.109840393066406], [-1.643755316734314, -0.9965870380401611], [0.05555347508440415, 1.1017545859018962], [-1.0568946599960327, 1.1365869045257568], [1.0798102617263794, 0.9997133612632751], [-0.13185426220297813, -0.6221873164176941], [-0.008557954430580134, 0.5211360156536102], [0.008306674659252167, -1.7881015539169312], [-0.5188526511192322, 0.9365230202674866], [1.5034211874008179, 0.1979008913040161], [0.7683070302009583, -1.184828758239746], [-0.200625941157341, -1.359721839427948], [-1.1721680164337158, 3.1513113975524902], [-0.002296289739509419, 0.009542482553256876], [0.8593295216560364, -0.17529422044754028], [-1.8782658576965332, -0.08231405913829803], [-0.1809538513422012, 0.2373181939125061], [0.0006525677939255986, -1.165860891342163], [0.28442839682102206, -0.05784455277025699], [1.3488988876342773, -0.575585663318634], [-1.5432809591293335, -3.9041175842285156], [-1.1695648431777954, -1.3512797355651855], [1.8939968347549438, 0.24449698626995087], [2.725404977798462, 0.2826811969280243]] \ No newline at end of file diff --git a/scenestreamer/tokenization/reltok.py b/scenestreamer/tokenization/reltok.py new file mode 100644 index 0000000000000000000000000000000000000000..c3bf4f98b5cc71d7c8789d389e44be41d4ed941f --- /dev/null +++ b/scenestreamer/tokenization/reltok.py @@ -0,0 +1,729 @@ +# from 
RELATION_DIM = 2


def masked_average(tensor, mask, dim):
    """Average ``tensor`` along ``dim`` using ``mask`` as per-element weights.

    Args:
        tensor: Values to average.
        mask: Tensor with exactly the same shape as ``tensor``. Each value is
            multiplied by its mask entry, so a 0/1 (or boolean) mask selects
            the elements that contribute to the average.
        dim: Dimension to reduce.

    Returns:
        ``(tensor * mask).sum(dim) / count`` where ``count`` is the mask sum
        along ``dim``, floored at 1 so that fully-masked slices do not divide
        by zero.
    """
    assert tensor.shape == mask.shape
    valid_count = mask.sum(dim=dim)
    # Floor the divisor at 1: a slice with no valid element has a zero sum.
    valid_count = torch.max(valid_count, torch.ones_like(valid_count))
    weighted_sum = (tensor * mask).sum(dim=dim)
    return weighted_sum / valid_count


def get_mask(mask):
    """Turn a per-element mask of shape (B, N) into a pairwise mask (B, N, N).

    Entry ``[b, i, j]`` of the result is ``mask[b, i] & mask[b, j]``, so the
    result is symmetric in its last two dimensions: a pair is kept only when
    both of its members are kept. ``mask`` must support ``&`` (boolean or
    integer dtype).
    """
    B, N = mask.shape
    # along_cols[b, i, j] == mask[b, j]; its transpose gives mask[b, i].
    along_cols = mask.unsqueeze(1).expand(B, N, N)
    return along_cols & along_cols.transpose(1, 2)


def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """Create a standard-normal tensor on ``device`` with the requested ``dtype``.

    Args:
        shape: Output shape as a tuple or list; ``shape[0]`` is the batch size.
        generator: Optional generator, or a list of generators. A list with
            more than one entry seeds each batch element individually:
            element ``i`` is drawn with ``generator[i]`` using shape
            ``(1, *shape[1:])`` and the pieces are concatenated along dim 0.
            A list of length one is treated like a single generator.
        device: Target device (a ``torch.device`` or a device string).
            Defaults to CPU.
        dtype: Optional dtype forwarded to ``torch.randn``.
        layout: Optional layout; defaults to ``torch.strided``.

    Returns:
        The sampled tensor, moved to ``device``.

    Raises:
        ValueError: If a CUDA generator is passed while ``device`` has a
            different device type.

    If a CPU generator is passed together with a non-CPU ``device``, sampling
    happens on the CPU and the result is moved to ``device`` afterwards.
    """
    # Local import: the module does not otherwise set up a logger.
    import logging

    logger = logging.getLogger(__name__)

    # Normalise the inputs so list shapes and string devices behave like
    # tuples and torch.device objects in the code below.
    shape = tuple(shape)
    if isinstance(device, str):
        device = torch.device(device)

    # Device on which the tensor is sampled; defaults to the target device.
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # Make sure a generator list of length 1 is treated like a non-list.
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # One draw per batch element, each with its own generator.
        per_sample_shape = (1, ) + shape[1:]
        latents = [
            torch.randn(per_sample_shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents


def pairwise_relative_diff(positions):
    """
    Compute pairwise relative diffs for a batch of objects.
    For the output [b, i, j, :], it means the relative differences of [b, j] - [b, i],
    which is the pos of j in i's coordinate system.

    Parameters:
    - positions: A PyTorch tensor of shape (B, N, ...), e.g. (B, N, 2)

    Returns:
    - A PyTorch tensor of shape (B, N, N, ...) containing pairwise relative positions.
    """
    # (B, 1, N, ...) - (B, N, 1, ...) broadcasts to (B, N, N, ...).
    return positions.unsqueeze(1) - positions.unsqueeze(2)
head.shape[1] +# +# # i's local coordinate's y-axis (the heading) in the global coordinate +# i_local_y_wrt_global = head.reshape(B, -1, 1).expand(B, num_selected, num_selected) +# +# i_local_x_wrt_global = i_local_y_wrt_global - np.pi / 2 +# +# # rotated_pos = rel_pos +# rotated_pos = rotate(rel_pos[..., 0], rel_pos[..., 1], angle=-i_local_x_wrt_global) +# +# rotated_vel = rotate(rel_vel[..., 0], rel_vel[..., 1], angle=-i_local_x_wrt_global) +# +# relation_matrix = torch.concatenate([ +# rotated_pos, +# rotated_vel, +# rel_head, +# rel_time +# ], dim=-1) +# +# relation_matrix[..., 0] /= 400 +# relation_matrix[..., 1] /= 400 +# relation_matrix[..., 2] /= 25 +# relation_matrix[..., 3] /= 25 +# relation_matrix[..., 4] /= 3.1415 +# relation_matrix[..., 5] /= 90 +# +# return relation_matrix, mask + + +def prepro(data, num_samples=1024): + B, T, N, _ = data["encoder/agent_position"].shape + + # pos = data["encoder/agent_position"][..., :2].reshape(B, T * N, -1) + # vel = data["encoder/agent_velocity"].reshape(B, T * N, -1) + # head = data["encoder/agent_heading"].reshape(B, T * N, -1) + # time = torch.arange(T).reshape(1, T, 1).expand(B, T, N).to(pos.device).reshape(B, T * N, -1) + # mask = data["encoder/agent_valid_mask"].reshape(B, T * N) + + T = 1 + pos = data["encoder/agent_position"][..., :2][:, 10].reshape(B, T * N, -1) + vel = data["encoder/agent_velocity"][:, 10].reshape(B, T * N, -1) + head = data["encoder/agent_heading"][:, 10].reshape(B, T * N, -1) + time = torch.arange(1).reshape(1, 1, 1).expand(B, T, N).to(pos.device).reshape(B, T * N, -1) + mask = data["encoder/agent_valid_mask"][:, 10].reshape(B, T * N) + + if num_samples is not None: + indices = torch.randint(high=T * N, size=(B, num_samples, 1)).to(pos.device) # (B, 1024, 1) + pos = torch.gather(pos, index=indices.expand(B, num_samples, 2), dim=1) + vel = torch.gather(vel, index=indices.expand(B, num_samples, 2), dim=1) + head = torch.gather(head, index=indices.expand(B, num_samples, 1), dim=1) + time = 
torch.gather(time, index=indices.expand(B, num_samples, 1), dim=1) + mask = torch.gather(mask, index=indices.reshape(B, num_samples), dim=1) + + # compute pairwise relative position: (B, N, N, D) + rel_pos = pairwise_relative_diff(pos) + rel_vel = pairwise_relative_diff(vel) + rel_head = wrap_to_pi(pairwise_relative_diff(head)) + rel_time = pairwise_relative_diff(time) + + # rotated to local coordinate + + num_selected = head.shape[1] + + # i's local coordinate's y-axis (the heading) in the global coordinate + i_local_y_wrt_global = head.reshape(B, -1, 1).expand(B, num_selected, num_selected) + + i_local_x_wrt_global = i_local_y_wrt_global - np.pi / 2 + + rotated_pos = rel_pos + # rotated_pos = rel_pos.norm(dim=-1, keepdim=True) + # rotated_pos = rotate(rel_pos[..., 0], rel_pos[..., 1], angle=-i_local_x_wrt_global) + + rotated_vel = rel_vel.norm(dim=-1, keepdim=True) + # rotated_vel = rotate(rel_vel[..., 0], rel_vel[..., 1], angle=-i_local_x_wrt_global) + + rel_dir = wrap_to_pi(torch.arctan2(rel_pos[..., 1], rel_pos[..., 0]) - i_local_x_wrt_global) + + relation_matrix = torch.concatenate( + [ + rotated_pos, + # rotated_vel, + # rel_head, + # rel_dir[..., None], + # rel_time + ], + dim=-1 + ) + + relation_matrix[..., 0] /= 400 + relation_matrix[..., 1] /= 400 + # relation_matrix[..., 1] /= 25 + # # relation_matrix[..., 3] /= 25 + # relation_matrix[..., 2] /= 3.1415 + # relation_matrix[..., 3] /= 3.1415 + # relation_matrix[..., 4] /= 90 + + return relation_matrix, mask + + +class VectorQuantizer(nn.Module): + """ + PZH: From huggingface + + + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. 
+ def __init__( + self, + n_e: int, + vq_embed_dim: int, + beta: float = 0.25, + remap=None, + unknown_index: str = "random", + sane_index_shape: bool = False, + # legacy: bool = True, + legacy: bool = False, + ): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.used: torch.Tensor + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + self.register_buffer('data_initialized', torch.zeros(1)) + + def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + 
def forward(self, z: torch.FloatTensor, disable=False) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]: + # reshape z -> (batch, height, width, channel) and flatten + # z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.vq_embed_dim) + + # PZH: https://github.com/karpathy/deep-vector-quantization/blob/c3c026a1ccea369bc892ad6dde5e6d6cd5a508a4/dvq/model/quantize.py + # DeepMind def does not do this but I find I have to... ;\ + if self.training and self.data_initialized.item() == 0: + print('running kmeans!!') # data driven initialization for the embeddings + rp = torch.randperm(z_flattened.size(0)) + kd = kmeans2(z_flattened[rp[:20000]].data.cpu().numpy(), self.n_e, minit='points') + self.embedding.weight.data.copy_(torch.from_numpy(kd[0])) + self.data_initialized.fill_(1) + # TODO: this won't work in multi-GPU setups + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z)**2) + torch.mean((z_q - z.detach())**2) + else: + loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2) + + # preserve gradients + z_q: torch.FloatTensor = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + if disable: + return z, loss, (perplexity, min_encodings, 
min_encoding_indices) + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + # return z, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q: torch.FloatTensor = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class RelationEncoder(nn.Module): + def __init__(self, d_model=128, num_layers=2): #, num_heads=4): + super().__init__() + self.num_layers = 3 + nhead = 1 + self.d_model = d_model + self_attn_layers = [] + for _ in range(self.num_layers): + self_attn_layers.append( + NativeTransformerEncoderLayer( + d_model=self.d_model, + nhead=nhead, + dim_feedforward=self.d_model * 4, + # dropout=dropout, + batch_first=True + ) + ) + self.self_attn_layers = nn.ModuleList(self_attn_layers) + + # # TODO: Add config + self.agent_pe = nn.Embedding(128, self.d_model) + self.pre_proj = common_layers.build_mlps( + c_in=RELATION_DIM, + mlp_channels=[d_model], # * (num_layers - 1) + [d_model], + ret_before_act=True, + ) + self.proj = common_layers.build_mlps( + c_in=d_model * 2, + mlp_channels=[d_model], # * (num_layers - 1) + [d_model], + ret_before_act=True, + ) + self.out = common_layers.build_mlps( + c_in=d_model, + mlp_channels=[d_model, d_model], # * (num_layers - 1) + [d_model], + ret_before_act=True, + ) + + def forward(self, rel_matrix, mask, batch_dict): + B, N, _, D = rel_matrix.shape + x = self.pre_proj(rel_matrix.reshape(-1, D)).reshape(B, N, N, -1) + # pooled = x.max(dim=-2)[0] + pooled = masked_average(x, mask=get_mask(mask).reshape(B, N, N, 
1).expand(B, N, N, self.d_model), dim=-2) + x = torch.cat([x, pooled[:, :, None].repeat(1, 1, N, 1)], dim=-1) + x = self.proj(x) + # x = x.max(dim=-2)[0] + x = masked_average(x, mask=get_mask(mask).reshape(B, N, N, 1).expand(B, N, N, self.d_model), dim=-2) + x = self.out(x) + + # x = batch_dict["encoder/agent_position"][:, 10][..., :2] + # B, N, D = x.shape + # x = self.pre_proj(x.reshape(-1, RELATION_DIM)).reshape(B, -1, self.d_model) + # x = self.out(x.reshape(-1, self.d_model)).reshape(B, N, -1) + + agent_pe = self.agent_pe(batch_dict["encoder/agent_id"]) + x += agent_pe + for k in range(len(self.self_attn_layers)): + x = self.self_attn_layers[k](src=x, src_key_padding_mask=~mask) + return x, mask + + +class RelationDecoder(nn.Module): + def __init__(self, d_model=128, num_layers=2): + super(RelationDecoder, self).__init__() + self.num_layers = 3 + nhead = 1 + self.d_model = d_model + self_attn_layers = [] + for _ in range(self.num_layers): + self_attn_layers.append( + NativeTransformerEncoderLayer( + d_model=self.d_model, + nhead=nhead, + dim_feedforward=self.d_model * 4, + # dropout=dropout, + batch_first=True + ) + ) + self.self_attn_layers = nn.ModuleList(self_attn_layers) + self.prediction_head = common_layers.build_mlps( + c_in=d_model, mlp_channels=[d_model, d_model, RELATION_DIM], ret_before_act=True + ) + self.agent_pe = nn.Embedding(128, self.d_model) + + def forward(self, latent, mask, batch_dict): + + B, N, D = latent.shape + + # FIXME: TODO: + x = self.prediction_head(latent.reshape(-1, self.d_model)).reshape(B, N, -1) + return x + + x = latent + agent_pe = self.agent_pe(batch_dict["encoder/agent_id"]) + x += agent_pe + for k in range(len(self.self_attn_layers)): + x = self.self_attn_layers[k](src=x, src_key_padding_mask=~mask) + x = self.prediction_head(x.reshape(-1, self.d_model)).reshape(B, N, -1) + return x + + +class RelationDecoderDEPRECATED(nn.Module): + def __init__(self, d_model=128, num_layers=2, num_heads=4): + super().__init__() + 
self.d_model = d_model + nhead = 1 + self_attn_layers = [] + self.num_layers = 3 + for _ in range(self.num_layers): + self_attn_layers.append( + NativeTransformerEncoderLayer( + d_model=self.d_model, + nhead=nhead, + dim_feedforward=self.d_model * 4, + # dropout=dropout, + batch_first=True + ) + ) + self.self_attn_layers = nn.ModuleList(self_attn_layers) + self.proj1 = common_layers.build_mlps( + c_in=d_model, # TODO: or 6? + # mlp_channels=[d_out] * (num_layers - 1) + [6], + mlp_channels=[d_model], + ret_before_act=True, + ) + self.proj2 = common_layers.build_mlps( + c_in=d_model, # TODO: or 6? + # mlp_channels=[d_out] * (num_layers - 1) + [6], + mlp_channels=[d_model], + ret_before_act=True, + ) + self.proj3 = common_layers.build_mlps( + c_in=d_model, # TODO: or 6? + mlp_channels=[d_model] * (num_layers - 1) + [RELATION_DIM], + # mlp_channels=[RELATION_DIM], + ret_before_act=True, + ) + self.norm1 = torch.nn.LayerNorm(d_model, eps=1e-5, bias=True) + self.norm2 = torch.nn.LayerNorm(d_model, eps=1e-5, bias=True) + + def forward(self, latent, mask=None): + B, N1, D = latent.shape + x = latent + for k in range(len(self.self_attn_layers)): + x = self.self_attn_layers[k](src=x, src_key_padding_mask=~mask) + + q = self.norm1(self.proj1(x.reshape(-1, D)).reshape(B, N1, -1) + x) + k = self.norm2(self.proj2(x.reshape(-1, D)).reshape(B, N1, -1) + x) + x = torch.einsum("bnd,bmd->bnmd", q, k) + x = self.proj3(x.reshape(-1, D)).reshape(B, N1, N1, -1) + + # B, N1, _, D = latent.shape + # x = latent + # for k in range(len(self.self_attn_layers)): + # x = self.self_attn_layers[k](src=x, src_key_padding_mask=~mask) + + # x = latent + # x = self.norm1(self.proj1(x.reshape(-1, D)).reshape(B, N1, N1, -1) + x) + # k = self.norm2(self.proj2(x.reshape(-1, D)).reshape(B, N1, N1, -1) + x) + # x = torch.einsum("bnmd,bnmd->bnd", q, k) + # x = self.proj3(x.reshape(-1, D)).reshape(B, N1, N1, -1) + + return x + + +class Reltok(nn.Module): + def __init__(self, config): + super().__init__() + 
self.config = config + + d_model = 128 + self.enc = RelationEncoder(num_layers=3, d_model=d_model) + + # from scenestreamer.models.scene_encoder import SceneEncoder + # self.config = config + # self.scene_encoder = SceneEncoder(config=self.config) + + self.dec = RelationDecoder(num_layers=3, d_model=d_model) + + self.quantizer = VectorQuantizer(1024, d_model) + + def forward(self, batch_dict): + data, mask = prepro(batch_dict, num_samples=None) + # latent, agent_pe, new_mask = self.enc(data, mask, batch_dict) + latent, mask = self.enc(data, mask, batch_dict) + z, quant_loss, (perplexity, min_encodings, min_encoding_indices) = self.quantizer(latent, disable=False) + + emask = get_mask(mask) + count = emask.sum(-1, keepdims=True) + count = torch.masked_fill(count, count == 0, 1) + target = (data * emask[..., None]).sum(-2) / count + + return { + "output": self.dec(z, mask=mask, batch_dict=batch_dict), + "target": target, + "rel_matrix": data, + "quant_loss": quant_loss, + # "dist": posterior, + "data": batch_dict, + "valid_mask": mask, + "quant_idxs": min_encoding_indices, + } + + +class ReltokLightning(pl.LightningModule): + def __init__(self, config): + if "SEED" in config: + pl.seed_everything(config.SEED) + print("Everything is seeded to: ", config.SEED) + super().__init__() + self.config = config + + # self.enc = RelationEncoder() + # self.dec = RelationDecoder() + + self.reltok = Reltok(config) + + self.save_hyperparameters() + self.validation_outputs = [] + self.validation_ground_truth = [] + + def forward(self, batch_dict): + # data = prepro(batch_dict,num_samples=256) + return self.reltok(batch_dict) + + def get_loss(self, data_dict): + output_logit = data_dict["output"] + + target_action = data_dict["target"] + mask = data_dict["valid_mask"] # (B, N) + + # Masking + output_logit = output_logit[mask] + target_action = target_action[mask] + + mse = nn.functional.mse_loss(input=output_logit, target=target_action) + loss = (mse * 1 + data_dict["quant_loss"] * 
0.1) + + output_logit_scaled = output_logit.clone() + output_logit_scaled[..., 0] *= 400 + output_logit_scaled[..., 1] *= 400 + # output_logit_scaled[..., 1] *= 25 + # # output_logit_scaled[..., 3] *= 25 + # output_logit_scaled[..., 2] *= 3.1415 + # output_logit_scaled[..., 3] *= 3.1415 + # output_logit_scaled[..., 4] *= 90 + + target_action_scaled = target_action.clone() + target_action_scaled[..., 0] *= 400 + target_action_scaled[..., 1] *= 400 + # target_action_scaled[..., 1] *= 25 + # # target_action_scaled[..., 3] *= 25 + # target_action_scaled[..., 2] *= 3.1415 + # target_action_scaled[..., 3] *= 3.1415 + # target_action_scaled[..., 4] *= 90 + + recon_rel_matrix = pairwise_relative_diff(data_dict["output"]) + rel_matrix = data_dict["rel_matrix"] + emask = get_mask(mask) + # recon_loss1 = nn.functional.l1_loss(input=recon_rel_matrix[emask], target=rel_matrix[emask]) + recon_loss2 = nn.functional.l1_loss(input=-recon_rel_matrix[emask], target=rel_matrix[emask]) + + # debugging: cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally + # encodings = F.one_hot(data_dict["quant_idxs"][data_dict["valid_mask"].flatten()], self.reltok.quantizer.n_e).float().reshape(-1, self.reltok.quantizer.n_e) + # flat_mask = get_mask(data_dict["valid_mask"]).flatten() + flat_mask = data_dict["valid_mask"].flatten() + encodings = F.one_hot(data_dict["quant_idxs"][flat_mask], + self.reltok.quantizer.n_e).float().reshape(-1, self.reltok.quantizer.n_e) + + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + # self.log('val_perplexity', perplexity, prog_bar=True) + # self.log('val_cluster_use', cluster_use, prog_bar=True) + + scaled_mse = nn.functional.mse_loss(input=output_logit_scaled, target=target_action_scaled) + scaled_norm = (output_logit_scaled[..., :1] - target_action_scaled[..., :1]).norm(dim=-1).mean() + + loss_stat = { + # "recon/loss1": recon_loss1, + "recon/loss2": recon_loss2, + "loss/total_loss": loss, + "loss/mse": mse, + "mse": mse, + "perplexity": perplexity, + "cluster_use": cluster_use, + "scaled_mse": scaled_mse, + "scaled_norm": scaled_norm, + "loss/quant_loss": data_dict["quant_loss"], # ["codebook_loss"], + # "loss/commitment_loss": data_dict["quant_loss"]["commitment_loss"], + "output/output_mean": output_logit.mean(), + "output/output_max": output_logit.max(), + "output/output_min": output_logit.min(), + "output/target_mean": target_action.mean(), + "output/target_max": target_action.max(), + "output/target_min": target_action.min(), + "quant/quant_idxs_mean": data_dict["quant_idxs"][flat_mask].float().mean(), + "quant/quant_idxs_max": data_dict["quant_idxs"][flat_mask].float().max(), + "quant/quant_idxs_min": data_dict["quant_idxs"][flat_mask].float().min(), + } + try: + loss_stat["lr"] = self.trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0] + except RuntimeError: + # When debugging, the model might not be 
attached to a trainer. + pass + return loss, loss_stat + + def training_step(self, data_dict, batch_idx): + data_dict = self(data_dict) + loss, loss_stat = self.get_loss(data_dict) + self.log_dict( + {f"train/{k}": float(v) + for k, v in loss_stat.items()}, + batch_size=data_dict["data"]["encoder/agent_feature"].shape[0], + # on_epoch=True, + prog_bar=True, + ) + self.log('monitoring_step', float(self.global_step)) + return loss + + def configure_optimizers(self): + """Required by Lightning.""" + opt_cfg = self.config.OPTIMIZATION + optimizer = torch.optim.AdamW( + self.parameters(), lr=opt_cfg.LR, weight_decay=opt_cfg.get('WEIGHT_DECAY', 0), betas=(0.9, 0.95), eps=1e-5 + ) + scheduler = lr_schedule.get_cosine_schedule_with_warmup( + optimizer=optimizer, + # num_warmup_steps=opt_cfg.WARMUP_STEPS, + # num_training_steps=opt_cfg.TRAINING_STEPS, + num_warmup_steps=200, # TODO + num_training_steps=opt_cfg.TRAINING_STEPS, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step" + }, + } + + +if __name__ == '__main__': + pass diff --git a/scenestreamer/tokenization/test_tokenization.py b/scenestreamer/tokenization/test_tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..9f66f55e38603e3cde2ecaf5bb6eb560a40491fa --- /dev/null +++ b/scenestreamer/tokenization/test_tokenization.py @@ -0,0 +1,824 @@ +import hydra +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm + +from scenestreamer.dataset.datamodule import SceneStreamerDataModule +from scenestreamer.tokenization.motion_tokenizers import DeltaDeltaTokenizer +from scenestreamer.utils import REPO_ROOT +from scenestreamer.utils import debug_tools, get_time_str +from scenestreamer.utils import utils + + +def _unbatch_to_numpy(tensor_dict, index=0): + ret = {} + for k, v in tensor_dict.items(): + if isinstance(v[index], (np.ndarray, str, np.str_)): + ret[k] = v[index] + else: + ret[k] = v[index].numpy() + return ret + 
+ +def _batch_to_tensor(array_list): + return torch.from_numpy(np.array(array_list)) + + +def evaluate_tokenizer(tokenizer, dataloader, num_scenarios_limit=100, test_tokenizer=False): + # error_list = [] + ade_list = [] + fde_list = [] + num_scenarios = 0 + num_objects = 0 + + pbar = tqdm(total=num_scenarios_limit) + + data_dict_list = [] + + for input_dict in dataloader: + + B = input_dict["in_evaluation"].shape[0] + num_scenarios += B + if test_tokenizer: + # Do the tokenization (already done by dataloader and filled decoder/target_action. + output_dict, stats = tokenizer.tokenize(input_dict) + # error_list.append(stats['reconstruction_error']) + + output_dict["decoder/output_action"] = output_dict["decoder/target_action"] + fill_zero = ~output_dict["decoder/target_action_valid_mask"] + output_dict["decoder/input_action_valid_mask"][fill_zero] = False + + else: + raise ValueError + input_dict["decoder/output_action"] = input_dict["decoder/target_action"] + + with torch.no_grad(): + input_dict = tokenizer.detokenize(output_dict, detokenizing_gt=True) + + # recon_mask = input_dict['decoder/interpolated_target_action_valid_mask'][:, 4::5] # 80 -> 16 + # input_act_mask = input_dict['decoder/input_action_valid_mask'][:, 2:] # 18 -> 16 + # assert (recon_mask == input_act_mask).all() + # assert (recon_mask >= input_dict['decoder/agent_valid_mask'][:, 14::5]).all() + + pred = input_dict["decoder/reconstructed_position"] + pred_head = input_dict["decoder/reconstructed_heading"] + if pred.shape[1] == 96: + pred = pred[:, 11:] + pred_head = pred_head[:, 11:] + elif pred.shape[1] == 91: + pred = pred[:, 11:] + pred_head = pred_head[:, 11:] + else: + raise ValueError + gt = input_dict["decoder/agent_position"][:, 11:, :, :2] + gt_head = input_dict["decoder/agent_heading"][:, 11:] + + target_action_valid_mask = input_dict["decoder/target_action_valid_mask"] + gt_action = input_dict["decoder/target_action"].clone() + gt_action[~target_action_valid_mask] = 0 + # encodings = 
torch.nn.functional.one_hot(gt_action, num_classes=tokenizer.num_actions) + + future_valid_mask = input_dict["decoder/agent_valid_mask"][:, 11:] + current_valid_mask = input_dict["decoder/current_agent_valid_mask"] + raw_valid_mask = torch.logical_and(future_valid_mask, current_valid_mask[:, None]) + + T_pred = pred.shape[1] + T_gt = gt.shape[1] + T_compare = min(T_pred, T_gt) + pred = pred[:, :T_compare] + gt = gt[:, :T_compare] + pred_head = pred_head[:, :T_compare] + error = (pred - gt).norm(dim=-1) + error_head = utils.wrap_to_pi(pred_head[:, :T_compare] - gt_head[:, :T_compare]).abs() + + contours = utils.cal_polygon_contour_torch( + x=pred[..., 0], + y=pred[..., 1], + theta=pred_head, + length=input_dict["decoder/current_agent_shape"][..., 0][:, None], + width=input_dict["decoder/current_agent_shape"][..., 1][:, None] + ) + gt_contours = utils.cal_polygon_contour_torch( + x=gt[..., 0], + y=gt[..., 1], + theta=gt_head, + length=input_dict["decoder/current_agent_shape"][..., 0][:, None], + width=input_dict["decoder/current_agent_shape"][..., 1][:, None] + ) + + contour_error = (contours - gt_contours).norm(dim=-1).mean(dim=-1) + + error = error[:, 4::5] + error_head = error_head[:, 4::5] + contour_error = contour_error[:, 4::5] + + stat = {} + + at = input_dict["decoder/agent_type"].unsqueeze(1).expand(-1, 80, -1) + + # === all === + if input_dict["decoder/reconstructed_valid_mask"].shape[1] == 96: + input_mask = input_dict["decoder/reconstructed_valid_mask"][:, :-5] + elif input_dict["decoder/reconstructed_valid_mask"].shape[1] == 91: + input_mask = input_dict["decoder/reconstructed_valid_mask"][:, :] + else: + raise ValueError + formal_valid_mask = input_dict["decoder/agent_valid_mask"] + assert input_mask.shape == formal_valid_mask.shape + formal_valid_mask = torch.logical_and(formal_valid_mask, input_mask) + formal_valid_mask = formal_valid_mask[:, ::5] + + num_all_objects = formal_valid_mask.sum().item() + if error.shape[1] == 17: + error = error[:, :-1] + 
if error_head.shape[1] == 17: + error_head = error_head[:, :-1] + if contour_error.shape[1] == 17: + contour_error = contour_error[:, :-1] + if formal_valid_mask.shape[1] == 17: + formal_valid_mask = formal_valid_mask[:, :-1] + if formal_valid_mask.shape[1] == 19: + formal_valid_mask = formal_valid_mask[:, 3:] + + # tmp = (error * formal_valid_mask)[0].sum(0) / formal_valid_mask[0].sum(0) + error_masked = error * formal_valid_mask + ade = (error_masked).sum() # / valid_mask.sum() + ade_head = (error_head * formal_valid_mask).sum() # / valid_mask.sum() + ade_contour = (contour_error * formal_valid_mask).sum() # / valid_mask.sum() + # ade_list.append(ade) + + valid_mask_any_step = formal_valid_mask.any(dim=1) + fde = torch.masked_fill(error, ~formal_valid_mask, float("-inf")).max(dim=1)[0] + fde = torch.masked_fill(fde, ~valid_mask_any_step, 0).sum() # / valid_mask_any_step.sum() + + # fde_list.append(fde) + + stat["all/ade_contour_sum"] = ade_contour.item() + stat["all/ade_head_sum"] = ade_head.item() + stat["all/ade_sum"] = ade.item() + stat["all/ade_count"] = formal_valid_mask.sum().item() + stat["all/fde_sum"] = fde.item() + stat["all/fde_count"] = valid_mask_any_step.sum().item() + stat["all/num_objects"] = num_all_objects + stat["all/num_scenarios"] = B + + # stat["all/cluster_use"] = torch.sum(encodings[target_action_valid_mask].float().mean(0) > 0).item() + + for ot in at.unique(): + if ot == -1: + continue + + is_type = at == ot + valid_mask = torch.logical_and(formal_valid_mask, is_type[:, ::5]) + N = int(valid_mask.sum()) + if N == 0: + stat["obj{}/ade".format(ot)] = -1 + stat["obj{}/fde".format(ot)] = -1 + stat["obj{}/num_objects".format(ot)] = 0 + stat["obj{}/num_scenarios".format(ot)] = 0 + continue + + ade_contour = (contour_error * valid_mask).sum() # / valid_mask.sum() + ade_head = (error_head * valid_mask).sum() # / valid_mask.sum() + ade = (error * valid_mask).sum() # / valid_mask.sum() + + real_ade = ade / valid_mask.sum() + + valid_mask_any_step 
= valid_mask.any(dim=1) + fde = torch.masked_fill(error, ~valid_mask, float("-inf")).max(dim=1)[0] + fde = torch.masked_fill(fde, ~valid_mask_any_step, 0).sum() # / valid_mask_any_step.sum() + stat["obj{}/ade_sum".format(ot)] = ade.item() + stat["obj{}/ade_contour_sum".format(ot)] = ade_contour.item() + stat["obj{}/ade_head_sum".format(ot)] = ade_head.item() + stat["obj{}/ade_count".format(ot)] = valid_mask.sum().item() + stat["obj{}/fde_sum".format(ot)] = fde.item() + stat["obj{}/fde_count".format(ot)] = valid_mask_any_step.sum().item() + stat["obj{}/num_objects".format(ot)] = N + stat["obj{}/num_scenarios".format(ot)] = B + # stat["obj{}/cluster_use".format(ot)] = torch.sum(encodings[target_action_valid_mask].float().mean(0) > 0 + # ).item() + + data_dict_list.append(stat) + num_objects += num_all_objects + pbar.update(B) + if num_scenarios_limit is not None and num_scenarios > num_scenarios_limit: + break + + pbar.close() + + return data_dict_list, num_objects, num_scenarios + + +def evaluate_tokenizer_gpt(tokenizer, dataloader, num_scenarios_limit=100, test_tokenizer=False): + # error_list = [] + ade_list = [] + fde_list = [] + num_scenarios = 0 + num_objects = 0 + + pbar = tqdm(total=num_scenarios_limit) + + data_dict_list = [] + + for input_dict in dataloader: + + B = input_dict["in_evaluation"].shape[0] + num_scenarios += B + if test_tokenizer: + # Do the tokenization (already done by dataloader and filled decoder/target_action. 
+ output_dict, stats = tokenizer.tokenize_gpt_style(input_dict) + # error_list.append(stats['reconstruction_error']) + input_dict["decoder/output_action"] = output_dict["decoder/target_action"] + + fill_zero = ((input_dict["decoder/output_action"] == -1) & input_dict["decoder/input_action_valid_mask"]) + input_dict["decoder/output_action"][fill_zero] = tokenizer.default_action + + # else: + # input_dict["decoder/output_action"] = input_dict["decoder/target_action"] + + with torch.no_grad(): + input_dict = tokenizer.detokenize_gpt_style(input_dict) + + recon_mask = input_dict['decoder/interpolated_target_action_valid_mask'][:, ::5][:, :-1] # 91 -> 18 + input_act_mask = input_dict['decoder/input_action_valid_mask'][:, :] + assert (recon_mask == input_act_mask).all() + + pred = input_dict["decoder/reconstructed_position"] + gt = input_dict["decoder/agent_position"][..., :2] + + target_action_valid_mask = input_dict["decoder/target_action_valid_mask"] + gt_action = input_dict["decoder/target_action"].clone() + gt_action[~target_action_valid_mask] = 0 + encodings = torch.nn.functional.one_hot(gt_action, num_classes=tokenizer.num_actions) + + # future_valid_mask = input_dict["decoder/future_agent_valid_mask"] + # current_valid_mask = input_dict["decoder/current_agent_valid_mask"] + # raw_valid_mask = torch.logical_and(future_valid_mask, current_valid_mask[:, None]) + + # T_pred = pred.shape[1] + # T_gt = gt.shape[1] + # T_compare = min(T_pred, T_gt) + # pred = pred[:, :T_compare] + # gt = gt[:, :T_compare] + error = (pred - gt).norm(dim=-1) + + error = error[:, ::5][:, 1:] + + stat = {} + + at = input_dict["decoder/agent_type"].unsqueeze(1).expand(-1, 91, -1) + + action_valid_mask = input_dict["decoder/input_action_valid_mask"] + next_action_valid_mask = input_dict["decoder/target_action_valid_mask"] + + # === all === + formal_valid_mask = action_valid_mask & next_action_valid_mask + num_all_objects = formal_valid_mask.sum().item() + + tmp = (error * 
formal_valid_mask)[0].sum(0) / formal_valid_mask[0].sum(0) + + ade = (error * formal_valid_mask).sum() # / valid_mask.sum() + # ade_list.append(ade) + + valid_mask_any_step = formal_valid_mask.any(dim=1) + fde = torch.masked_fill(error, ~formal_valid_mask, float("-inf")).max(dim=1)[0] + fde = torch.masked_fill(fde, ~valid_mask_any_step, 0).sum() # / valid_mask_any_step.sum() + # fde_list.append(fde) + + stat["all/ade_sum"] = ade.item() + stat["all/ade_count"] = formal_valid_mask.sum().item() + stat["all/fde_sum"] = fde.item() + stat["all/fde_count"] = valid_mask_any_step.sum().item() + stat["all/num_objects"] = num_all_objects + stat["all/num_scenarios"] = B + + stat["all/cluster_use"] = torch.sum(encodings[target_action_valid_mask].float().mean(0) > 0).item() + + for ot in at.unique(): + if ot == -1: + continue + + is_type = at == ot + valid_mask = torch.logical_and(formal_valid_mask, is_type[:, ::5][:, 1:]) + N = int(valid_mask.sum()) + if N == 0: + stat["obj{}/ade".format(ot)] = -1 + stat["obj{}/fde".format(ot)] = -1 + stat["obj{}/num_objects".format(ot)] = 0 + stat["obj{}/num_scenarios".format(ot)] = 0 + continue + + ade = (error * valid_mask).sum() # / valid_mask.sum() + valid_mask_any_step = valid_mask.any(dim=1) + fde = torch.masked_fill(error, ~valid_mask, float("-inf")).max(dim=1)[0] + fde = torch.masked_fill(fde, ~valid_mask_any_step, 0).sum() # / valid_mask_any_step.sum() + stat["obj{}/ade_sum".format(ot)] = ade.item() + stat["obj{}/ade_count".format(ot)] = valid_mask.sum().item() + stat["obj{}/fde_sum".format(ot)] = fde.item() + stat["obj{}/fde_count".format(ot)] = valid_mask_any_step.sum().item() + stat["obj{}/num_objects".format(ot)] = N + stat["obj{}/num_scenarios".format(ot)] = B + stat["obj{}/cluster_use".format(ot)] = torch.sum(encodings[target_action_valid_mask].float().mean(0) > 0 + ).item() + + data_dict_list.append(stat) + num_objects += num_all_objects + pbar.update(B) + if num_scenarios_limit is not None and num_scenarios > 
num_scenarios_limit: + break + + pbar.close() + + return data_dict_list, num_objects, num_scenarios + + +@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="0220_midgpt.yaml") +def test_delta_delta(config): + from omegaconf import OmegaConf + + OmegaConf.set_struct(config, False) + OmegaConf.set_struct(config, True) + + datamodule = SceneStreamerDataModule( + config, + train_batch_size=1, + train_num_workers=config.num_workers, + train_prefetch_factor=config.prefetch_factor, + val_batch_size=1, + val_num_workers=config.val_num_workers, + val_prefetch_factor=config.prefetch_factor, + ) + datamodule.setup("fit") + dataloader = datamodule.val_dataloader() + + # config.TOKENIZATION.TOKENIZATION_METHOD = "delta_delta" + + file_name = str(config.TOKENIZATION.TOKENIZATION_METHOD) + + from scenestreamer.tokenization import get_tokenizer + tokenizer = get_tokenizer(config=config) + + num_scenarios_limit = min(500, len(datamodule.val_dataset)) + + stat, num_objects, num_scenarios = evaluate_tokenizer( + tokenizer, dataloader, test_tokenizer=True, num_scenarios_limit=num_scenarios_limit + ) + exp_state = { + "total_num_scenarios": num_scenarios, + "total_num_objects": num_objects, + "file_name": file_name, + } + stat = pd.DataFrame(stat) + for name in ["all", "obj1", "obj2", "obj3"]: + + # if name in ['obj2', 'obj3']: + # continue + + if f"{name}/ade_sum" not in stat: + continue + + ade = stat[f"{name}/ade_sum"].sum() / stat[f"{name}/ade_count"].sum() + ade_head = stat[f"{name}/ade_head_sum"].sum() / stat[f"{name}/ade_count"].sum() + ade_contour = stat[f"{name}/ade_contour_sum"].sum() / stat[f"{name}/ade_count"].sum() + fde = stat[f"{name}/fde_sum"].sum() / stat[f"{name}/fde_count"].sum() + obj_count = stat[f"{name}/num_objects"].sum() + scenario_count = stat[f"{name}/num_scenarios"].sum() + exp_state.update( + { + f"{name}/ade": ade, + f"{name}/ade_head": ade_head, + f"{name}/ade_contour": ade_contour, + f"{name}/fde": fde, + 
f"{name}/num_objects": obj_count, + f"{name}/num_scenarios": scenario_count, + # f"{name}/cluster_use": stat[f"{name}/cluster_use"].mean() + } + ) + print(f"{num_scenarios=}, {num_objects=}.\n" f"{exp_state}" f"\n==========\n") + print({k: round(v, 4) for k, v in exp_state.items() if "obj1" in k}) + print({k: round(v, 4) for k, v in exp_state.items() if "obj2" in k}) + print({k: round(v, 4) for k, v in exp_state.items() if "obj3" in k}) + print(f"\n==========\n") + print({k: round(v, 4) for k, v in exp_state.items() if "all" in k}) + print(f"\n==========\n") + # Print average for obj1/2/3 + keys = [k.split("obj3/")[-1] for k, v in exp_state.items() if "obj3" in k] + res = {} + for k in keys: + res[k] = np.mean([v for kee, v in exp_state.items() if kee.endswith(k) and (not kee.startswith("all"))]) + print({"avg/{}".format(k): round(v, 4) for k, v in res.items()}) + print(f"\n==========\n") + df = pd.DataFrame([exp_state]) + + def applytab(row): + s = "" + for v in row.values: + if isinstance(v, float): + s += f"{v:.3f}" + '\t' + else: + s += str(v) + '\t' + print(s) + + # print('\t'.join(map(str,df.columns))) # to print the column names if required + df.apply(applytab, axis=1) + + df.to_csv(f"{file_name}_EVAL.csv") + + +@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="motion_default.yaml") +def test_delta(config): + from omegaconf import OmegaConf + + OmegaConf.set_struct(config, False) + config.DATA.TRAINING_DATA_DIR = 'data/20scenarios' + config.DATA.TEST_DATA_DIR = 'data/20scenarios' + OmegaConf.set_struct(config, True) + + num_scenarios_limit = 100 + + datamodule = SceneStreamerDataModule( + config, + train_batch_size=config.batch_size, + train_num_workers=config.num_workers, + train_prefetch_factor=config.prefetch_factor, + val_batch_size=config.val_batch_size, + val_num_workers=config.val_num_workers, + val_prefetch_factor=config.prefetch_factor, + ) + datamodule.setup("fit") + dataloader = datamodule.val_dataloader() + + 
config.TOKENIZATION.TOKENIZATION_METHOD = "delta" + + file_name = "delta" + + from scenestreamer.tokenization import get_tokenizer + tokenizer = get_tokenizer(config=config) + + stat, num_objects, num_scenarios = evaluate_tokenizer( + tokenizer, dataloader, test_tokenizer=True, num_scenarios_limit=num_scenarios_limit + ) + exp_state = { + "total_num_scenarios": num_scenarios, + "total_num_objects": num_objects, + "file_name": file_name, + } + stat = pd.DataFrame(stat) + for name in ["all", "obj1", "obj2", "obj3"]: + + # if name in ['obj2', 'obj3']: + # continue + + ade = stat[f"{name}/ade_sum"].sum() / stat[f"{name}/ade_count"].sum() + fde = stat[f"{name}/fde_sum"].sum() / stat[f"{name}/fde_count"].sum() + obj_count = stat[f"{name}/num_objects"].sum() + scenario_count = stat[f"{name}/num_scenarios"].sum() + exp_state.update( + { + f"{name}/ade": ade, + f"{name}/fde": fde, + f"{name}/num_objects": obj_count, + f"{name}/num_scenarios": scenario_count, + f"{name}/cluster_use": stat[f"{name}/cluster_use"].mean() + } + ) + print(f"{num_scenarios=}, {num_objects=}.\n" f"{exp_state}" f"\n==========\n") + print({k: v for k, v in exp_state.items() if "obj1" in k}) + print({k: v for k, v in exp_state.items() if "obj2" in k}) + print({k: v for k, v in exp_state.items() if "obj3" in k}) + print(f"\n==========\n") + df = pd.DataFrame([exp_state]) + + def applytab(row): + s = "" + for v in row.values: + if isinstance(v, float): + s += f"{v:.3f}" + '\t' + else: + s += str(v) + '\t' + print(s) + + # print('\t'.join(map(str,df.columns))) # to print the column names if required + df.apply(applytab, axis=1) + + df.to_csv(f"{file_name}_EVAL.csv") + + +def test_precomputed(num_scenarios_limit=100): + cfg_file = "cfgs/motion_debug.yaml" + config = debug_tools.get_debug_config(cfg_file=cfg_file) + + large_dataset_file = 'data/metadrive_processed_waymo/validation' + config.DATA["TRAINING_DATA_DIR"] = large_dataset_file + config.DATA["TEST_DATA_DIR"] = large_dataset_file + + dataloader = 
debug_tools.get_debug_dataloader( + cfg_file=cfg_file, in_evaluation=False, train_num_workers=8, train_batch_size=16 + ) + + # config.TOKENIZATION.FILE_NAME = "test_0308-2125.json" + # config.TOKENIZATION.FILE_NAME = "test_0308-2210.json" + # config.TOKENIZATION.FILE_NAME = "test_0308-2221.json" + file_name = "precomputed_delta_delta_0309sol1.json" + + config.TOKENIZATION.FILE_NAME = file_name + + config.TOKENIZATION.TOKENIZATION_METHOD = None + + stat, num_objects, num_scenarios = evaluate_tokenizer( + PrecomputedDeltaDeltaTokenizer(config), + dataloader, + test_tokenizer=True, + num_scenarios_limit=num_scenarios_limit + ) + exp_state = { + "total_num_scenarios": num_scenarios, + "total_num_objects": num_objects, + "file_name": file_name, + } + stat = pd.DataFrame(stat) + for name in ["all", "obj1", "obj2", "obj3"]: + ade = stat[f"{name}/ade_sum"].sum() / stat[f"{name}/ade_count"].sum() + fde = stat[f"{name}/fde_sum"].sum() / stat[f"{name}/fde_count"].sum() + obj_count = stat[f"{name}/num_objects"].sum() + scenario_count = stat[f"{name}/num_scenarios"].sum() + exp_state.update( + { + f"{name}/ade": ade, + f"{name}/fde": fde, + f"{name}/num_objects": obj_count, + f"{name}/num_scenarios": scenario_count, + f"{name}/cluster_use": stat[f"{name}/cluster_use"].mean() + } + ) + print(f"{num_scenarios=}, {num_objects=}.\n" f"{exp_state}" f"\n==========\n") + print({k: v for k, v in exp_state.items() if "obj1" in k}) + print({k: v for k, v in exp_state.items() if "obj2" in k}) + print({k: v for k, v in exp_state.items() if "obj3" in k}) + print(f"\n==========\n") + df = pd.DataFrame([exp_state]) + + def applytab(row): + s = "" + for v in row.values: + if isinstance(v, float): + s += f"{v:.3f}" + '\t' + else: + s += str(v) + '\t' + print(s) + + # print('\t'.join(map(str,df.columns))) # to print the column names if required + df.apply(applytab, axis=1) + + df.to_csv(f"{file_name}_EVAL.csv") + + +def grid_search_delta_tokenizer(num_scenarios_limit=5000): + # 
large_dataset_file = 'data/waymo_8s_debug' + # large_dataset_file = '/data1/datasets/metadrive_processed_waymo/validation' + # large_dataset_file = '/home/zhenghao/Datasets/metadrive_processed_waymo/validation' + large_dataset_file = '/data1/datasets/metadrive_processed_waymo/validation' + tokenizer_class = DeltaTokenizer + config = debug_tools.get_debug_config() + config.DATA["TRAINING_DATA_DIR"] = large_dataset_file + config.DATA["TEST_DATA_DIR"] = large_dataset_file + datamodule = SceneStreamerDataModule( + config, + train_batch_size=1, + train_num_workers=0, + val_batch_size=4, + val_num_workers=4, + train_prefetch_factor=2, + val_prefetch_factor=2 + ) + datamodule.setup("fit") + # dataloader = datamodule.val_dataloader() + dataloader = datamodule.train_dataloader() + file_name = "tokenizer_test_{}.csv".format(get_time_str()) + result = [] + + for nbins in [13]: # 17, 21, 25]: + config.TOKENIZATION["NUM_BINS"] = nbins + ymax = 35 + ymin = -3 + xmax = 2 + xmin = -2 + config.TOKENIZATION["X_MAX"] = xmax + config.TOKENIZATION["X_MIN"] = xmin + config.TOKENIZATION["Y_MAX"] = ymax + config.TOKENIZATION["Y_MIN"] = ymin + config.TOKENIZATION["NUM_BINS"] = nbins + # config.TOKENIZATION["NUM_SKIPPED_STEPS"] = 1 + + _, ade_list, fde_list, num_scenarios, num_objects = evaluate_tokenizer( + tokenizer_class(config), dataloader, num_scenarios_limit=num_scenarios_limit + ) + + print( + f"{xmax=}, {xmin=}, {ymax=}, {ymin=}, {nbins=}. " + f"Reconstruction ADE: {np.mean(ade_list)}, FDE: {np.mean(fde_list)}. " + f"Num scenarios: {num_scenarios}, Num objects: {num_objects}." 
+ ) + + result.append( + dict( + X_MAX=xmax, + X_MIN=xmin, + Y_MAX=ymax, + Y_MIN=ymin, + NUM_BINS=nbins, + # error=np.mean(error_list), + ade=np.mean(ade_list), + fde=np.mean(fde_list), + num_scenarios=num_scenarios + ) + ) + # pd.DataFrame(result).to_csv("tmp_" + file_name) + pd.DataFrame(result).to_csv(file_name) + + +def grid_search_delta_delta_tokenizer(num_scenarios_limit=5000, batch_size=32): + # large_dataset_file = 'data/waymo_8s_debug' + # large_dataset_file = '/data/datasets/scenarionet/waymo/training' + large_dataset_file = 'data/metadrive_processed_waymo/validation' + # large_dataset_file = '/data1/datasets/metadrive_processed_waymo/validation' + # large_dataset_file = '/home/zhenghao/Datasets/metadrive_processed_waymo/validation' + + tokenizer_class = DeltaDeltaTokenizer + + config = debug_tools.get_debug_config() + config.DATA["TRAINING_DATA_DIR"] = large_dataset_file + config.DATA["TEST_DATA_DIR"] = large_dataset_file + config.TOKENIZATION.TOKENIZATION_METHOD = "delta_delta" + # config.TRAINING["PREDICT_ALL_AGENTS"] = True + + datamodule = SceneStreamerDataModule( + config, + train_batch_size=batch_size, + train_num_workers=0, + val_batch_size=4, + val_num_workers=4, + train_prefetch_factor=2, + val_prefetch_factor=2 + ) + datamodule.setup("fit") + # dataloader = datamodule.val_dataloader() + dataloader = datamodule.train_dataloader() + file_name = "tokenizer_test_{}.csv".format(get_time_str()) + result = [] + + for nbins in [21]: # 17, 21, 25]: + for xlimit in [3.5]: + for ymax in [3.5]: + for ymin in [-3.5]: + + config.TOKENIZATION["NUM_BINS"] = nbins + + xmax = xlimit + xmin = -xlimit + + config.TOKENIZATION["X_MAX"] = xmax + config.TOKENIZATION["X_MIN"] = xmin + config.TOKENIZATION["Y_MAX"] = ymax + config.TOKENIZATION["Y_MIN"] = ymin + config.TOKENIZATION["NUM_BINS"] = nbins + + stat, num_objects, num_scenarios = evaluate_tokenizer( + tokenizer_class(config), + dataloader, + num_scenarios_limit=num_scenarios_limit, + # object_type=object_type, + 
) + exp_state = { + "X_MAX": xmax, + "X_MIN": xmin, + "Y_MAX": ymax, + "Y_MIN": ymin, + "NUM_BINS": nbins, + "total_num_scenarios": num_scenarios, + "total_num_objects": num_objects + } + stat = pd.DataFrame(stat) + for name in ["all", "obj1", "obj2", "obj3"]: + ade = stat[f"{name}/ade_sum"].sum() / stat[f"{name}/ade_count"].sum() + fde = stat[f"{name}/fde_sum"].sum() / stat[f"{name}/fde_count"].sum() + obj_count = stat[f"{name}/num_objects"].sum() + scenario_count = stat[f"{name}/num_scenarios"].sum() + exp_state.update( + { + f"{name}/ade": ade, + f"{name}/fde": fde, + f"{name}/num_objects": obj_count, + f"{name}/num_scenarios": scenario_count, + f"{name}/cluster_use": stat[f"{name}/cluster_use"].mean() + } + ) + cu = {k: v for k, v in exp_state.items() if "obj3" in k} + print( + f"\n\n==========\n{xmax=}, {xmin=}, {ymax=}, {ymin=}, {nbins=}.\n" + f"{num_scenarios=}, {num_objects=}.\n" + f"{exp_state}" + f"\n==========\n" + f"{cu}" + f"\n==========\n" + ) + result.append(exp_state) + pd.DataFrame(result).to_csv("tmp_" + file_name) + df = pd.DataFrame(result) + df.to_csv(file_name) + print("Data saved to", file_name) + print(df) + return df + + +@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="motion_default.yaml") +def test_bicycle_model(config): + + from omegaconf import OmegaConf + + OmegaConf.set_struct(config, False) + # config.DATA.TRAINING_DATA_DIR = 'data/20scenarios' + # config.DATA.TEST_DATA_DIR = 'data/20scenarios' + OmegaConf.set_struct(config, True) + + num_scenarios_limit = 100 + + datamodule = SceneStreamerDataModule( + config, + train_batch_size=config.batch_size, + train_num_workers=config.num_workers, + train_prefetch_factor=config.prefetch_factor, + val_batch_size=config.val_batch_size, + val_num_workers=config.val_num_workers, + val_prefetch_factor=config.prefetch_factor, + ) + datamodule.setup("fit") + dataloader = datamodule.val_dataloader() + + config.TOKENIZATION.TOKENIZATION_METHOD = "bicycle" + + file_name 
= "bicycle" + + from scenestreamer.tokenization import get_tokenizer + tokenizer = get_tokenizer(config=config) + + stat, num_objects, num_scenarios = evaluate_tokenizer( + tokenizer, dataloader, test_tokenizer=True, num_scenarios_limit=num_scenarios_limit + ) + exp_state = { + "total_num_scenarios": num_scenarios, + "total_num_objects": num_objects, + "file_name": file_name, + } + stat = pd.DataFrame(stat) + for name in ["all", "obj1", "obj2", "obj3"]: + # if name in ['obj2', 'obj3']: + # continue + + ade = stat[f"{name}/ade_sum"].sum() / stat[f"{name}/ade_count"].sum() + ade_head = stat[f"{name}/ade_head_sum"].sum() / stat[f"{name}/ade_count"].sum() + ade_contour = stat[f"{name}/ade_contour_sum"].sum() / stat[f"{name}/ade_count"].sum() + fde = stat[f"{name}/fde_sum"].sum() / stat[f"{name}/fde_count"].sum() + obj_count = stat[f"{name}/num_objects"].sum() + scenario_count = stat[f"{name}/num_scenarios"].sum() + exp_state.update( + { + f"{name}/ade": ade, + f"{name}/ade_head": ade_head, + f"{name}/ade_contour": ade_contour, + f"{name}/fde": fde, + f"{name}/num_objects": obj_count, + f"{name}/num_scenarios": scenario_count, + f"{name}/cluster_use": stat[f"{name}/cluster_use"].mean() + } + ) + print(f"{num_scenarios=}, {num_objects=}.\n" f"{exp_state}" f"\n==========\n") + print({k: v for k, v in exp_state.items() if "obj1" in k}) + print({k: v for k, v in exp_state.items() if "obj2" in k}) + print({k: v for k, v in exp_state.items() if "obj3" in k}) + print(f"\n==========\n") + df = pd.DataFrame([exp_state]) + + def applytab(row): + s = "" + for v in row.values: + if isinstance(v, float): + s += f"{v:.3f}" + '\t' + else: + s += str(v) + '\t' + print(s) + + # print('\t'.join(map(str,df.columns))) # to print the column names if required + df.apply(applytab, axis=1) + + df.to_csv(f"{file_name}_EVAL.csv") + + +if __name__ == '__main__': + # grid_search_delta_tokenizer(num_scenarios_limit=1000) + # grid_search_delta_delta_tokenizer(num_scenarios_limit=200, batch_size=32) 
+ # test_precomputed(num_scenarios_limit=2000) + test_delta_delta() + # test_delta() + # test_bicycle_model() diff --git a/scenestreamer/tokenization/trafficgen_tokenizers.py b/scenestreamer/tokenization/trafficgen_tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..20db27be52ea62dd43e7066dfd759f8180f92859 --- /dev/null +++ b/scenestreamer/tokenization/trafficgen_tokenizers.py @@ -0,0 +1,166 @@ +import logging +import pathlib + +import numpy as np +import torch +import torch.nn.functional as F + +from scenestreamer.utils import rotate, wrap_to_pi +from scenestreamer.utils import utils + + +class TrafficGenTokenizerBaseVer: + limit = { + "position_x": (-30, 30), + "position_y": (-20, 20), + "velocity_x": (0, 30), + "velocity_y": (-10, 10), + "heading": (-np.pi / 2, np.pi / 2), + "length": (0.5, 10), + "width": (0.5, 3), + "height": (0.5, 4), + "agent_type": (None, None), + } + num_bins = { + "position_x": 121, + "position_y": 81, + "velocity_x": 61, + "velocity_y": 41, + "heading": 21, + "length": 21, + "width": 11, + "height": 11, + "agent_type": 3, + } + + def __init__(self, config): + self.config = config + self.start_action_id = config.PREPROCESSING.MAX_MAP_FEATURES + self.end_action_id = config.PREPROCESSING.MAX_MAP_FEATURES + 1 + self.INIT_START_ACTION = self.start_action_id + self.INIT_END_ACTION = self.end_action_id + + @classmethod + def bucketize(cls, value, key): + is_torch = isinstance(value, torch.Tensor) + if not is_torch: + value = torch.tensor(value) + limit_min, limit_max = cls.limit[key] + num_bins = cls.num_bins[key] + if limit_min is None: + return value + value = torch.clamp(value, limit_min, limit_max) + ret = torch.round((value - limit_min) / (limit_max - limit_min) * (num_bins - 1)) + if not is_torch: + ret = ret.numpy() + return ret + + @classmethod + def de_bucketize(cls, value, key): + is_torch = isinstance(value, torch.Tensor) + if not is_torch: + value = torch.tensor(value) + limit_min, limit_max = 
cls.limit[key] + num_bins = cls.num_bins[key] + if limit_min is None: + return value + ret = value / (num_bins - 1) * (limit_max - limit_min) + limit_min + if not is_torch: + ret = ret.numpy() + return ret + + def detokenize(self, data_dict, action, agent_type, offset_action): + B, M, _ = data_dict["encoder/map_position"].shape + action = action.clone().unsqueeze(-1) + assert action.ndim == 3 # B, T, 1 + + is_valid_action = action < self.start_action_id + action[~is_valid_action] = 0 + map_pos = torch.gather(data_dict["encoder/map_position"], dim=1, index=action.expand(-1, -1, 3))[..., :2] + map_head = torch.gather(data_dict["encoder/map_heading"][:, :M], dim=1, index=action.reshape(B, -1)) + + offset_values = {} + for k, a in offset_action.items(): + offset_values[k] = self.de_bucketize(a, k) + + pos = utils.rotate(x=offset_values["position_x"], y=offset_values["position_y"], angle=map_head) + pos = pos + map_pos + + head = wrap_to_pi(offset_values["heading"] + map_head) + + vel = utils.rotate(x=offset_values["velocity_x"], y=offset_values["velocity_y"], angle=map_head) + + shape = torch.stack([offset_values["length"], offset_values["width"], offset_values["height"]], dim=-1) + + # feature is the relative pos/head/vel + feature = torch.stack( + [ + offset_values["position_x"], offset_values["position_y"], offset_values["heading"], + offset_values["velocity_x"], offset_values["velocity_y"] + ], + dim=-1 + ) + + # I am sorry that we change the key names here... 
+ return { + "position": pos, + "velocity": vel, + "heading": head, + "shape": shape, + "agent_type": agent_type, + "feature": feature, + "offset_values": offset_values, + } + + +class TrafficGenTokenizerSpecialVer(TrafficGenTokenizerBaseVer): + limit = { + "position_x": (-10, 10), + "position_y": (-10, 10), + "velocity_x": (0, 30), + "velocity_y": (-10, 10), + "heading": (-np.pi / 4, np.pi / 4), + "length": (0.5, 10), + "width": (0.5, 3), + "height": (0.5, 4), + "agent_type": (None, None), + } + num_bins = { + "position_x": 81, + "position_y": 81, + "velocity_x": 61, + "velocity_y": 41, + "heading": 41, + "length": 41, + "width": 41, + "height": 41, + "agent_type": 3, + } + + +# TODO: Hardcoded... +TrafficGenTokenizer = TrafficGenTokenizerSpecialVer + +class TrafficGenTokenizerAutoregressive(TrafficGenTokenizerBaseVer): + limit = { + "position_x": (-10, 10), + "position_y": (-10, 10), + "velocity_x": (0, 30), + "velocity_y": (-10, 10), + "heading": (-np.pi / 2, np.pi / 2), + "length": (0.5, 10), + "width": (0.5, 3), + "height": (0.5, 4), + "agent_type": (None, None), + } + num_bins = { + "position_x": 81, + "position_y": 81, + "velocity_x": 81, + "velocity_y": 81, + "heading": 81, + "length": 81, + "width": 81, + "height": 81, + "agent_type": 3, + } diff --git a/scenestreamer/train_init.py b/scenestreamer/train_init.py new file mode 100644 index 0000000000000000000000000000000000000000..464068a37257468d91a52c48fb19692dc5f51424 --- /dev/null +++ b/scenestreamer/train_init.py @@ -0,0 +1,199 @@ +import argparse +import datetime +import os +import pathlib +from pathlib import Path + +import lightning.pytorch as pl +import torch +import wandb +from scenestreamer.models.initializer_pl import SceneStreamerInitializer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger + +from scenestreamer.dataset.datamodule import SceneStreamerDataModule +from scenestreamer.utils import global_config, 
cfg_from_list, cfg_from_yaml_file + +# torch.backends.cudnn.benchmark = True +torch.set_float32_matmul_precision("high") # Enable TF32 matrix multiplication + +REPO_ROOT = pathlib.Path(os.path.dirname(__file__)).parent + + +def get_time_str(): + return datetime.datetime.now().strftime("%Y-%m-%d_%H%M") + + +def parse_config(): + parser = argparse.ArgumentParser(description='arg parser') + parser.add_argument( + '--cfg_file', type=str, default=None, help='The config for training. Related path to the repo root.' + ) + + parser.add_argument('--batch_size', type=int, default=None, required=False, help='batch size for training') + parser.add_argument('--val_batch_size', type=int, default=2, required=False, help='batch size for training') + parser.add_argument('--num_sanity_val_steps', type=int, default=0, required=False, help='batch size for training') + parser.add_argument('--val_num_workers', type=int, default=4, required=False, help='batch size for training') + parser.add_argument('--limit_val_batches', type=int, default=-1, required=False, help='batch size for training') + parser.add_argument('--limit_train_batches', type=int, default=-1, required=False, help='batch size for training') + parser.add_argument('--train_prefetch_factor', type=int, default=2, required=False, help='batch size for training') + parser.add_argument('--epochs', type=int, default=100, required=False, help='number of epochs to train for') + parser.add_argument('--num_workers', type=int, default=0, help='number of workers for dataloader') + parser.add_argument('--exp_name', type=str, default='default', help='extra tag for this experiment') + parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to start from') + parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained_model') + parser.add_argument('--without_sync_bn', action='store_true', default=False, help='whether to use sync bn') + parser.add_argument('--debug', action='store_true', default=False, 
help='') + parser.add_argument('--eval', action='store_true', default=False, help='') + parser.add_argument('--precision', type=str, default=None, help='precision') + parser.add_argument('--ckpt_save_interval', type=int, default=2, help='number of training epochs') + parser.add_argument('--local_rank', type=int, default=None, help='local rank for distributed training') + parser.add_argument('--max_ckpt_save_num', type=int, default=5, help='max number of saved checkpoint') + parser.add_argument( + '--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER, help='set extra config keys if needed' + ) + + parser.add_argument('--max_waiting_mins', type=int, default=0, help='max waiting minutes') + parser.add_argument('--start_epoch', type=int, default=0, help='') + parser.add_argument('--save_to_file', action='store_true', default=False, help='') + parser.add_argument('--not_eval_with_train', action='store_true', default=False, help='') + + # PZH: added + parser.add_argument('--wandb', action='store_true', default=False, help='') + + parser.add_argument('--logger_iter_interval', type=int, default=50, help='') + parser.add_argument('--ckpt_save_time_interval', type=int, default=300, help='in terms of seconds') + + # parser.add_argument('--add_worker_init_fn', action='store_true', default=False, help='') + args = parser.parse_args() + + cfg_file = REPO_ROOT / args.cfg_file + cfg_from_yaml_file(cfg_file, global_config) + + global_config.TAG = Path(args.cfg_file).stem + global_config.EXP_GROUP_PATH = '/'.join(args.cfg_file.split('/')[1:-1]) # remove 'cfgs' and 'xxxx.yaml' + + if args.set_cfgs is not None: + cfg_from_list(args.set_cfgs, global_config) + + return args, global_config + + +def main(): + args, cfg = parse_config() + + exp_name = args.exp_name + max_epochs = args.epochs #or cfg.OPTIMIZATION.NUM_EPOCHS + batch_size = args.batch_size or cfg.OPTIMIZATION.BATCH_SIZE_PER_GPU + val_batch_size = args.val_batch_size or 2 + num_workers = args.num_workers + precision 
= args.precision + if precision in ["16", 16, "bf16", "bf16-mixed"]: + print("Setting torch.set_float32_matmul_precision('medium') because you are using half precision.") + torch.set_float32_matmul_precision("medium") + else: + print("Do not set torch.set_float32_matmul_precision since you are using full precision.", precision) + + model = SceneStreamerInitializer(cfg=cfg) + + # Setup wandb logger + trial_id = get_time_str() + name = "{}_{}".format(exp_name, trial_id) + if args.wandb: + with open(os.path.abspath(os.path.expanduser("~/wandb_api_key_file.txt")), "rt") as fp: + api_key = fp.readline().strip() + wandb.login(key=api_key) + save_dir = os.path.join(REPO_ROOT, "lightning_logs") + logger = WandbLogger( + name=name, + save_dir=save_dir, + id=name, + project="scenestreamer", + log_model=True, + group=exp_name, + ) + else: + save_dir = os.path.join(REPO_ROOT, "lightning_logs") + logger = TensorBoardLogger(save_dir=save_dir, name=exp_name) + + callbacks = [ + ModelCheckpoint( + filename=str(name) + "_{epoch}-{step}", + monitor="monitoring_step", + every_n_epochs=1, + save_last=True, + auto_insert_metric_name=True, + mode="max", + save_top_k=3, + save_on_train_epoch_end=True, + ), + ModelCheckpoint( + filename=str(name) + "_{epoch}-{step}", + train_time_interval=datetime.timedelta(minutes=15), + auto_insert_metric_name=True, + save_on_train_epoch_end=True, + ) + ] + + # from lightning.pytorch.profilers import PyTorchProfiler + # profiler = PyTorchProfiler(filename="profile") + + trainer_kwargs = dict( + + # Debug only: + # num_sanity_val_steps=2, + # max_epochs=max_epochs, + # profiler=profiler, + # detect_anomaly=True, + num_sanity_val_steps=args.num_sanity_val_steps, + limit_val_batches=args.limit_val_batches if args.limit_val_batches > 0 else None, + limit_train_batches=args.limit_train_batches if args.limit_train_batches > 0 else None, + gradient_clip_val=cfg.OPTIMIZATION.GRAD_NORM_CLIP or None, + max_epochs=max_epochs, + callbacks=callbacks, + 
logger=logger, + accelerator="auto", + devices="auto", + log_every_n_steps=2, + + # strategy='ddp_find_unused_parameters_true' + ) + + datamodule = SceneStreamerDataModule( + cfg.DATA_CONFIG, + train_batch_size=batch_size, + train_num_workers=num_workers, + train_prefetch_factor=args.train_prefetch_factor, + val_batch_size=val_batch_size, + val_num_workers=args.val_num_workers, + val_prefetch_factor=1, + ) + + # if args.debug: + # trainer_kwargs["limit_train_batches"] = 100 + # trainer_kwargs["limit_val_batches"] = 100 + # trainer_kwargs["max_epochs"] = 2 + + if precision is not None: + trainer_kwargs["precision"] = precision + + if torch.cuda.device_count() > 1: + # trainer_kwargs["strategy"] ='ddp_find_unused_parameters_true' + trainer_kwargs["strategy"] = 'ddp' + + trainer = pl.Trainer(**trainer_kwargs) + + ckpt_path = args.ckpt + if ckpt_path is not None: + ckpt_path = os.path.join(REPO_ROOT, ckpt_path) + assert os.path.isfile(ckpt_path) + assert ckpt_path.endswith(".ckpt") + print("==============================") + print("Loading checkpoint: ", ckpt_path) + print("==============================") + + trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + + +if __name__ == '__main__': + main() diff --git a/scenestreamer/train_motion.py b/scenestreamer/train_motion.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad56c460a3d95cb6d1d415c3af48d0e775a6f19 --- /dev/null +++ b/scenestreamer/train_motion.py @@ -0,0 +1,240 @@ +import datetime +import os +import pathlib + +import hydra +import lightning.pytorch as pl +import torch +import wandb +from lightning.pytorch.callbacks import LearningRateMonitor +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger +from lightning.pytorch.utilities.model_summary import summarize +from omegaconf import OmegaConf + +import scenestreamer.utils as utils +from scenestreamer.dataset.datamodule import 
SceneStreamerDataModule +from scenestreamer.models.motionlm_lightning import MotionLMLightning +from scenestreamer.utils import REPO_ROOT, get_time_str + +torch.set_float32_matmul_precision('high') + + +@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="motion_default.yaml") +def main(config): + # Unfreeze the config to allow modification + OmegaConf.set_struct(config, False) + config.ROOT_DIR = REPO_ROOT + OmegaConf.set_struct(config, True) + + from scenestreamer.utils.config import global_config, cfg_from_yaml_file + default_config = cfg_from_yaml_file(REPO_ROOT / "cfgs/motion_default.yaml", global_config) + + pl.seed_everything(config.seed) + print("Everything is seeded to: ", config.seed) + + # Set up config + # cfg_file = REPO_ROOT / config.cfg_file + # config = cfg_from_yaml_file(cfg_file, global_config) + exp_name = config.exp_name + max_epochs = config.epochs #or config.OPTIMIZATION.NUM_EPOCHS + batch_size = config.batch_size + val_batch_size = config.val_batch_size + num_workers = config.num_workers + val_num_workers = config.val_num_workers + log_dir = config.log_dir or None + if log_dir is not None: + log_dir = pathlib.Path(log_dir) + + # Setup wandb logger + trial_id = get_time_str(no_time=True) + name = "{}_{}".format(exp_name, trial_id) + if log_dir: + save_dir = pathlib.Path(log_dir / "lightning_logs") + else: + save_dir = pathlib.Path(os.path.join(REPO_ROOT, "lightning_logs")) + if config.wandb and not config.eval: + with open(os.path.abspath(os.path.expanduser("~/wandb_api_key_file.txt")), "rt") as fp: + api_key = fp.readline().strip() + wandb.login(key=api_key) + logger = WandbLogger( + name=name, + save_dir=save_dir, + id=name, + project="scenestreamer", + log_model=False, + group=exp_name, + ) + else: + logger = TensorBoardLogger(save_dir=save_dir / "scenestreamer", name=name) + + ckpt_save_dir = pathlib.Path(save_dir).absolute() / "scenestreamer" / name + + # Set up trainer arguments + callbacks = [ + 
ModelCheckpoint( + filename=str(name) + "_{epoch}-{step}", + monitor="monitoring_step", + every_n_epochs=1, + save_last=True, + auto_insert_metric_name=True, + mode="max", + save_top_k=-1, + save_on_train_epoch_end=True, + ), + ModelCheckpoint( + filename=str(name) + "_{epoch}-{step}", + train_time_interval=datetime.timedelta(minutes=30), + auto_insert_metric_name=True, + save_on_train_epoch_end=True, + every_n_train_steps=None, + every_n_epochs=None, + save_top_k=-1, + ), + LearningRateMonitor(logging_interval='step') + ] + device = "auto" if torch.cuda.is_available() else "cpu" + trainer_kwargs = dict( + num_sanity_val_steps=config.num_sanity_val_steps, + limit_val_batches=config.limit_val_batches if config.limit_val_batches >= 0 else None, + limit_train_batches=config.limit_train_batches if config.limit_train_batches >= 0 else None, + gradient_clip_val=config.OPTIMIZATION.GRAD_NORM_CLIP, + max_epochs=max_epochs, + callbacks=callbacks, + logger=logger, + accelerator=device, + devices="auto", + log_every_n_steps=2, + deterministic=config.deterministic, + detect_anomaly=config.detect_anomaly, + check_val_every_n_epoch=config.get("check_val_every_n_epoch", None), + val_check_interval=config.val_check_interval, + # strategy='ddp_find_unused_parameters_true' + # strategy_settings={"timeout": datetime.timedelta(seconds=7200)} + ) + + # If multi GPUs are found: + if torch.cuda.device_count() > 1: + from lightning.pytorch.strategies.ddp import DDPStrategy + trainer_kwargs["strategy"] = DDPStrategy(timeout=datetime.timedelta(seconds=7200)) + + # from lightning.pytorch.profilers import PyTorchProfiler + # profiler = PyTorchProfiler(filename="profile") + # trainer_kwargs.update( + # profiler=profiler, + # ) + + # if config.debug: + # # from lightning.pytorch.profilers import PyTorchProfiler + # # profiler = PyTorchProfiler(filename="profile") + # trainer_kwconfig.update( + # num_sanity_val_steps=0, + # # profiler=profiler, + # detect_anomaly=True, + # limit_val_batches=2, + 
# limit_train_batches=2, + # log_every_n_steps=1, + # ) + # num_workers = 0 + # val_num_workers = 0 + # if bf16: + # trainer_kwargs["precision"] = "bf16-mixed" + + datamodule = SceneStreamerDataModule( + config, + train_batch_size=batch_size, + train_num_workers=num_workers, + train_prefetch_factor=config.prefetch_factor, + val_batch_size=val_batch_size, + val_num_workers=val_num_workers, + val_prefetch_factor=config.prefetch_factor, + ) + # if torch.cuda.device_count() > 1: + # trainer_kwargs["strategy"] = 'ddp' + # trainer_kwargs["strategy"] = 'ddp_find_unused_parameters_true' + if log_dir: + trainer_kwargs["default_root_dir"] = log_dir + + # Set up trainer + trainer = pl.Trainer(**trainer_kwargs) + + # Set up model + ckpt_path = config.ckpt + if ckpt_path is not None: + ckpt_path = REPO_ROOT / pathlib.Path(ckpt_path).expanduser() + if ckpt_path.is_dir(): + ckpt_path = ckpt_path / "last.ckpt" + ckpt_path = str(ckpt_path.resolve().absolute()) + assert os.path.isfile(ckpt_path), ckpt_path + assert ckpt_path.endswith(".ckpt"), ckpt_path + print("==============================") + print("Loading checkpoint: ", ckpt_path) + print("==============================") + + pretrained_path = config.pretrain + if pretrained_path: + pretrained_path = pathlib.Path(pretrained_path).expanduser() + pretrained_path = REPO_ROOT / pretrained_path + if pretrained_path.is_dir(): + if (pretrained_path / "last.ckpt").exists(): + pretrained_path = pretrained_path / "last.ckpt" + # If only one file: + elif len(list(pretrained_path.glob("*.ckpt"))) == 1: + pretrained_path = list(pretrained_path.glob("*.ckpt"))[0] + else: + raise ValueError( + "Please provide a checkpoint file or a directory with only one checkpoint file or contains the " + "last.ckpt." 
+ ) + pretrained_path = str(pretrained_path.absolute().resolve()) + assert os.path.isfile(pretrained_path), pretrained_path + assert pretrained_path.endswith(".ckpt"), pretrained_path + print("==============================") + print("Loading pretrained model: ", pretrained_path) + print("==============================") + + map_location = None + if not torch.cuda.is_available(): + print("CUDA is not available. Loading model on CPU!") + print("CUDA is not available. Loading model on CPU!") + print("CUDA is not available. Loading model on CPU!") + map_location = "cpu" + + model = utils.load_from_checkpoint( + checkpoint_path=pretrained_path, + cls=MotionLMLightning, + config=config, + default_config=default_config, + strict=True, + checkpoint_surgery_func=utils.checkpoint_surgery_func, + map_location=map_location + ) + # model = MotionLMLightning.load_from_checkpoint(checkpoint_path=pretrained_path, strict=strict, **config) + else: + model = MotionLMLightning(config=config) + model.exp_name = name + + assert model.config == config, "The config system is not working properly! 
def main():
    """Train or evaluate the ``ReltokLightning`` tokenizer model.

    Parses the command line, merges the YAML config into ``global_config``,
    builds the logger, checkpoint callbacks, datamodule and Lightning trainer,
    then calls ``trainer.validate`` (``--eval``) or ``trainer.fit``.

    ``--ckpt`` is forwarded as ``ckpt_path`` to the trainer; it is resolved
    relative to the repository root (after ``~`` expansion).
    """
    parser = argparse.ArgumentParser(description='arg parser')

    # Experiment
    parser.add_argument(
        '--cfg_file',
        type=str,
        default="cfgs/motion_debug.yaml",
        help='The config file path, relative to the repo root.'
    )
    parser.add_argument('--exp_name', type=str, default='train_reltok', help='Experiment name.')
    parser.add_argument('--ckpt', type=str, default=None, help='Path to pretrained checkpoint.')
    parser.add_argument('--log_dir', type=str, default=None, help='Path to store all logs/ckpts/files.')
    parser.add_argument('--debug', action='store_true', default=False, help='Whether to quickly set debug config.')
    parser.add_argument('--eval', action='store_true', default=False, help='Whether to evaluate the model.')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--wandb', action='store_true', default=False, help='Whether to use wandb logging.')

    # Training
    parser.add_argument('--batch_size', type=int, default=25, required=False, help='Batch size for training.')
    parser.add_argument(
        '--prefetch_factor', type=int, default=2, required=False, help='Datamodule prefetch factor for training.'
    )
    parser.add_argument(
        '--limit_train_batches',
        type=int,
        default=-1,
        required=False,
        # The original help text was copy-pasted from the validation option.
        help='Number of training steps in each epoch. Default to whole training dataset.'
    )
    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for dataloader.')
    parser.add_argument('--epochs', type=int, default=None, required=False, help='Number of epochs for training.')

    # Validation
    parser.add_argument('--val_batch_size', type=int, default=6, required=False, help='Batch size for validation.')
    parser.add_argument(
        '--val_num_workers', type=int, default=4, help='Number of workers for dataloader in validation.'
    )
    parser.add_argument(
        '--num_sanity_val_steps',
        type=int,
        default=20,
        required=False,
        help='Number of validation steps before first training epoch.'
    )
    parser.add_argument(
        '--limit_val_batches',
        type=int,
        default=-1,
        required=False,
        help='Number of validation steps in each iteration. Default to whole validation dataset.'
    )

    args = parser.parse_args()

    pl.seed_everything(args.seed)
    print("Everything is seeded to: ", args.seed)

    # Set up config
    cfg_file = REPO_ROOT / args.cfg_file
    config = cfg_from_yaml_file(cfg_file, global_config)
    exp_name = args.exp_name
    max_epochs = args.epochs
    batch_size = args.batch_size
    val_batch_size = args.val_batch_size
    num_workers = args.num_workers
    val_num_workers = args.val_num_workers
    log_dir = args.log_dir or None
    if log_dir is not None:
        log_dir = pathlib.Path(log_dir)

    # Setup wandb logger
    trial_id = get_time_str()
    name = "{}_{}".format(exp_name, trial_id)
    if log_dir:
        save_dir = log_dir / "lightning_logs"
    else:
        save_dir = os.path.join(REPO_ROOT, "lightning_logs")
    if args.wandb and not args.eval:
        # The API key is read from a file in the user's home directory.
        with open(os.path.abspath(os.path.expanduser("~/wandb_api_key_file.txt")), "rt") as fp:
            api_key = fp.readline().strip()
        wandb.login(key=api_key)
        logger = WandbLogger(
            name=name,
            save_dir=save_dir,
            id=name,
            project="scenestreamer",
            log_model=True,
            group=exp_name,
        )
    else:
        logger = TensorBoardLogger(save_dir=save_dir, name=exp_name)

    # Set up trainer arguments
    callbacks = [
        # One checkpoint per epoch, keeping all of them plus ``last.ckpt``.
        ModelCheckpoint(
            filename=str(name) + "_{epoch}-{step}",
            monitor="monitoring_step",
            every_n_epochs=1,
            save_last=True,
            auto_insert_metric_name=True,
            mode="max",
            save_top_k=-1,
            save_on_train_epoch_end=True,
        ),
        # Wall-clock based checkpoint every 30 minutes.
        ModelCheckpoint(
            filename=str(name) + "_{epoch}-{step}",
            train_time_interval=datetime.timedelta(minutes=30),
            auto_insert_metric_name=True,
            save_on_train_epoch_end=True,
            every_n_train_steps=None,
            every_n_epochs=None,
        )
    ]
    trainer_kwargs = dict(
        num_sanity_val_steps=args.num_sanity_val_steps,
        # Negative means "no limit" (None). ``>= 0`` matches the main training
        # script so that an explicit 0 is forwarded instead of being silently
        # turned into "use everything".
        limit_val_batches=args.limit_val_batches if args.limit_val_batches >= 0 else None,
        limit_train_batches=args.limit_train_batches if args.limit_train_batches >= 0 else None,
        gradient_clip_val=config.OPTIMIZATION.GRAD_NORM_CLIP,
        max_epochs=max_epochs,
        callbacks=callbacks,
        logger=logger,
        accelerator="auto",
        devices="auto",
        log_every_n_steps=2,
    )
    if args.debug:
        # Tiny, single-process run with anomaly detection enabled.
        trainer_kwargs.update(
            num_sanity_val_steps=0,
            detect_anomaly=True,
            limit_val_batches=2,
            limit_train_batches=2,
            log_every_n_steps=1,
        )
        num_workers = 0
        val_num_workers = 0
    datamodule = SceneStreamerDataModule(
        config,
        train_batch_size=batch_size,
        train_num_workers=num_workers,
        train_prefetch_factor=args.prefetch_factor,
        val_batch_size=val_batch_size,
        val_num_workers=val_num_workers,
        val_prefetch_factor=args.prefetch_factor,
    )
    if torch.cuda.device_count() > 1:
        trainer_kwargs["strategy"] = 'ddp'
    if log_dir:
        trainer_kwargs["default_root_dir"] = log_dir

    # Set up trainer
    trainer = pl.Trainer(**trainer_kwargs)

    # Set up model
    ckpt_path = args.ckpt
    if ckpt_path is not None:
        # Expand ``~`` before anchoring at the repo root, as the main training
        # script does; an absolute path is left untouched by the join.
        ckpt_path = str(REPO_ROOT / pathlib.Path(ckpt_path).expanduser())
        assert os.path.isfile(ckpt_path), ckpt_path
        assert ckpt_path.endswith(".ckpt"), ckpt_path
        print("==============================")
        print("Loading checkpoint: ", ckpt_path)
        print("==============================")

    model = ReltokLightning(config=config)

    if args.eval:
        trainer.validate(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
    else:
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
def mask_out_invalid_actions(logits, valid_min=None, valid_max=None, valids=None):
    """Return a copy of ``logits`` where every non-allowed action is suppressed.

    The allowed set is the union of the explicit indices in ``valids`` and the
    half-open range ``[valid_min, valid_max)`` along the last dimension. Every
    other position is overwritten with a large negative value so it is
    (practically) never sampled.

    Args:
        logits: Tensor whose last dimension indexes the actions.
        valid_min: Inclusive start of the allowed index range, or ``None``.
        valid_max: Exclusive end of the allowed index range. Required whenever
            ``valid_min`` is given; ignored when ``valid_min`` is ``None``.
        valids: Optional iterable of individually allowed indices.

    Returns:
        A new tensor with the same shape and dtype as ``logits``; the input is
        not modified.

    Raises:
        ValueError: If ``valid_min`` is given without ``valid_max``.
    """
    # True: to be filled, False: good to go.
    mask = torch.ones_like(logits, dtype=torch.bool)

    if valids is not None:
        for v in valids:
            mask[..., v] = False

    if valid_min is not None:
        if valid_max is None:
            raise ValueError("valid_max must be provided together with valid_min.")
        mask[..., valid_min:valid_max] = False

    fill_value = -1e9
    if logits.is_floating_point():
        # -1e9 is not representable in low-range dtypes such as float16, so
        # clamp it to the most negative finite value of the dtype. For
        # float32/float64/bfloat16 this is still exactly -1e9.
        fill_value = max(fill_value, torch.finfo(logits.dtype).min)

    return logits.masked_fill(mask, fill_value)
use_numpy=False, + ) + + self.start_index = 0 + self.end_index = end_index + self.config = config + self.map_ids = set(map_ids) + self.agent_ids = set(agent_ids) + self.batch_id = batch_id + self.step = step + self.intra_step_start = 0 + self.intra_step_end = end_index + + def update(self, model_output): + """ + According to current state, read the model's output (do some indexing / slicing etc.), and update the state. + Return the parsed new tokens. + """ + batch_id = self.batch_id + assert model_output.ndim == 2 + + if self.state == SceneStreamerTokenizer.UPDATE_START: + # Read N actions, set state to UPDATE_END + return self.process_UPDATE_START(model_output) + + if self.state == SceneStreamerTokenizer.UPDATE_END: + # Read 1 action, set state to REMOVE_START or STEP_END according to the input + return self.process_UPDATE_END(model_output) + + elif self.state == SceneStreamerTokenizer.REMOVE_START: + # Read 1 agent_id or REMOVE_END + return self.process_REMOVE_START(model_output) + + elif self.state == SceneStreamerTokenizer.STEP_START: + # Read 1 action, set state to UPDATE_START or ADD_START according to the input + return self.process_STEP_START(model_output) + + else: + raise ValueError(f"Invalid state: {self.state}") + + def process_UPDATE_START(self, model_output): + """Read N actions, set state to UPDATE_END.""" + out = model_output[self.start_index:self.end_index] + + action_id_min, action_id_max = SceneStreamerTokenizer.get_action_id_range(self.config) + out = mask_out_invalid_actions(out, valid_min=action_id_min, valid_max=action_id_max) + + action = sample_action(out, self.config) + + length = action.shape[0] + action_tokens = Tokens.create( + ids=action, + mask=self.tokens.mask[self.start_index:self.end_index], + causal_mask_offset=self.tokens.causal_mask_offset.new_ones(length) * length, + length=length, + use_numpy=False + ) + new_tokens = Tokens.concatenate( + [action_tokens, SceneStreamerTokenizer.get_update_end_tokens(use_numpy=False, 
device=out.device)] + ) + self.tokens = Tokens.concatenate([self.tokens, new_tokens]) + + self.state = SceneStreamerTokenizer.UPDATE_END + print(f"Scenario {self.batch_id} change state from UPDATE_START to UPDATE_END") + + self.start_index += new_tokens.length + self.end_index = self.start_index + 1 # Looking for REMOVE_START or STEP_END + + self.intra_step_start = self.intra_step_end + self.intra_step_end += new_tokens.length + + return { + "tokens": new_tokens, + "step": self.tokens.causal_mask_offset.new_ones(new_tokens.length) * self.step, + "intra_step": torch.arange(self.intra_step_start, self.intra_step_end).to(out.device) + } + + def process_UPDATE_END(self, model_output): + """Read 1 action, set state to REMOVE_START or STEP_END according to the input.""" + out = model_output[self.start_index:self.end_index] + + out = mask_out_invalid_actions(out, valids=[SceneStreamerTokenizer.STEP_END, SceneStreamerTokenizer.REMOVE_START]) + + action = sample_action(out, self.config) + + a = action.item() + + # TODO ===== Fix in future + if a == SceneStreamerTokenizer.REMOVE_START: + a = SceneStreamerTokenizer.STEP_END + # TODO ===== Fix in future + + if a == SceneStreamerTokenizer.REMOVE_START: + # TODO double check + self.state = SceneStreamerTokenizer.REMOVE_START + print(f"Scenario {self.batch_id} change state from UPDATE_END to REMOVE_START") + self.start_index += 1 + self.end_index = self.start_index + 1 # Looking for AGENT_ID to be removed + out = action + + elif a == SceneStreamerTokenizer.STEP_END: + out = torch.cat([action, torch.as_tensor([SceneStreamerTokenizer.STEP_START], dtype=out.dtype, device=out.device)]) + + self.state = SceneStreamerTokenizer.STEP_START + print(f"Scenario {self.batch_id} change state from UPDATE_END to STEP_START") + + self.start_index += out.shape[0] + self.end_index = self.start_index + 1 # Looking for REMOVE_START or STEP_END + out = action + + else: + raise ValueError(f"Invalid action: {a}") + + # TODO: Make tokens here + return 
class ARRollout:
    """Organize the batched rollout of the autoregressive model.

    One ``StateMachine`` is kept per scenario in the batch. ``get_tokens``
    assembles the per-scenario token sequences into one padded batch, and
    ``update`` hands each scenario's slice of the model output to its state
    machine.
    """
    def __init__(self, init_tokens, init_valid_mask, causal_mask_offset, map_ids, agent_ids, config):
        """Create one state machine per batch entry.

        Every batched argument is indexed along its first dimension; all state
        machines start in ``SceneStreamerTokenizer.UPDATE_START``.
        """
        self.B = init_tokens.shape[0]
        self.config = config
        self.states = [
            StateMachine(
                state=SceneStreamerTokenizer.UPDATE_START,
                init_tokens=init_tokens[i],
                init_valid_mask=init_valid_mask[i],
                causal_mask_offset=causal_mask_offset[i],
                map_ids=map_ids[i],
                agent_ids=agent_ids[i],
                config=config,
                batch_id=i
            ) for i in range(self.B)
        ]

    def get_tokens(self):
        """Stack the tokens of all scenarios into one right-padded batch.

        Each scenario's ``ids``, ``mask`` and ``causal_mask_offset`` are padded
        up to the longest sequence in the batch (``ids`` and
        ``causal_mask_offset`` with -1, ``mask`` with 0) and stacked along a new
        leading dimension.

        The padding is applied to temporary tensors only. The previous version
        assigned the padded tensors back onto each state machine's ``Tokens``,
        which left the stored tensors longer than their recorded ``length`` and
        carried the padding into later concatenations.
        """
        tokens = [s.tokens for s in self.states]
        max_len = max(t.length for t in tokens)

        def pad_to_max(tensor, length, value):
            return torch.nn.functional.pad(tensor, (0, max_len - length), value=value)

        return Tokens.create(
            ids=torch.stack([pad_to_max(t.ids, t.length, -1) for t in tokens], dim=0),
            mask=torch.stack([pad_to_max(t.mask, t.length, 0) for t in tokens], dim=0),
            causal_mask_offset=torch.stack(
                [pad_to_max(t.causal_mask_offset, t.length, -1) for t in tokens], dim=0
            ),
            length=max_len,
            use_numpy=False
        )

    def update(self, logits):
        """Feed the model output back into every scenario's state machine.

        ``logits`` is the batched model output; its first dimension must equal
        the batch size. The number of new tokens can differ between scenarios
        (one scenario may be updating N agents while another adds a single
        object), so each ``StateMachine`` decides how many tokens it appends
        and this method only collects the results.

        Returns:
            A list with one entry per scenario holding the new tokens reported
            by that scenario's state machine.

        NOTE(review): this indexes the state machine's result with "tokens",
        "step" and "intra_step"; only ``process_UPDATE_START`` visibly returns
        such a dict, so other states look unsupported here — confirm.
        """
        assert logits.shape[0] == self.B

        output = []
        # Collected alongside the tokens but not returned yet.
        step = []
        intra_step = []
        for b, logits_per_scenario in enumerate(logits):
            out = self.states[b].update(logits_per_scenario)
            output.append(out["tokens"])
            step.append(out["step"])
            intra_step.append(out["intra_step"])

        return output
def cfg_from_list(cfg_list, config):
    """Set config keys via list (e.g., from command line).

    ``cfg_list`` is a flat sequence ``[key1, value1, key2, value2, ...]``.
    Keys use dots to address nested entries (``"MODEL.DIM"``) and must already
    exist in ``config``. Values are parsed with ``ast.literal_eval`` and fall
    back to the raw string when they are not a Python literal.

    If the parsed value has a different type than the existing entry:

    * for a list entry, the raw value is split on ``,`` and every item is
      converted with the type of the list's first element (``"1,2,3"`` or a
      single ``"7"``);
    * for an ``EasyDict`` entry, the raw value is read as
      ``"key:val,key:val"`` and each value is converted with the type of the
      existing sub-entry;
    * anything else is rejected.

    Raises:
        AssertionError: For an odd-length ``cfg_list``, an unknown key, or a
            value whose type does not match the existing entry.
    """
    from ast import literal_eval

    if len(cfg_list) % 2 != 0:
        raise AssertionError('cfg_list must contain an even number of items (key/value pairs).')

    for full_key, raw in zip(cfg_list[0::2], cfg_list[1::2]):
        key_list = full_key.split('.')
        d = config
        for subkey in key_list[:-1]:
            if subkey not in d:
                raise AssertionError('NotFoundKey: %s' % subkey)
            d = d[subkey]
        subkey = key_list[-1]
        if subkey not in d:
            raise AssertionError('NotFoundKey: %s' % subkey)

        try:
            value = literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a Python literal: keep the raw value (typically a string).
            value = raw

        current = d[subkey]
        if type(value) is type(current):
            d[subkey] = value
        elif isinstance(current, list):
            # Split the raw text, not the parsed value: "1,2,3" parses to a
            # tuple and "7" to an int, neither of which can be split.
            # An empty existing list gives no element type; keep strings then.
            elem_type = type(current[0]) if current else str
            d[subkey] = [elem_type(item) for item in raw.split(',')]
        elif isinstance(current, EasyDict):
            for src in raw.split(','):
                cur_key, cur_val = src.split(':')
                val_type = type(current[cur_key])
                current[cur_key] = val_type(cur_val)
        else:
            raise AssertionError(
                'type {} does not match original type {}'.format(type(value), type(current))
            )
+# torch.backends.cudnn.benchmark = True +# torch.set_float32_matmul_precision("high") # Enable TF32 matrix multiplication + +DEFAULT_CONFIG_PATH = REPO_ROOT / "cfgs/motion_default.yaml" +with open(DEFAULT_CONFIG_PATH, 'r') as f: + try: + global_config = yaml.load(f, Loader=yaml.FullLoader) + except: + global_config = yaml.load(f) +global_config = EasyDict(global_config) +global_config.ROOT_DIR = REPO_ROOT +global_config.LOCAL_RANK = 0 diff --git a/scenestreamer/utils/debug_tools.py b/scenestreamer/utils/debug_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d275611ddd57078e6e81f15b6a20d6381277d539 --- /dev/null +++ b/scenestreamer/utils/debug_tools.py @@ -0,0 +1,57 @@ +import resource + +from scenestreamer.dataset.datamodule import SceneStreamerDataModule +from scenestreamer.utils.config import global_config, cfg_from_yaml_file +from scenestreamer.utils.utils import REPO_ROOT + +DEBUG_CONFIG_FILE = "cfgs/motion_debug.yaml" + + +def get_debug_config(cfg_file=DEBUG_CONFIG_FILE): + config = cfg_from_yaml_file(REPO_ROOT / cfg_file, global_config) + return config + + +def using(point=""): + usage = resource.getrusage(resource.RUSAGE_SELF) + return '''%s: usertime=%s systime=%s mem=%s mb + ''' % (point, usage[0], usage[1], usage[2] / 1024.0) + + +def get_debug_dataloader( + cfg_file=DEBUG_CONFIG_FILE, + in_evaluation=True, + config=None, + train_batch_size=10, + train_num_workers=0, + val_batch_size=1, + val_num_workers=0, +): + if config is None: + config = get_debug_config(cfg_file=cfg_file) + datamodule = SceneStreamerDataModule( + config, + train_batch_size=train_batch_size, + train_num_workers=train_num_workers, + val_batch_size=val_batch_size, + val_num_workers=val_num_workers, + train_prefetch_factor=2, + val_prefetch_factor=2 + ) + datamodule.setup("fit") + if in_evaluation: + dataloader = datamodule.val_dataloader() + else: + dataloader = datamodule.train_dataloader() + return dataloader + + +def 
get_debug_data(cfg_file=DEBUG_CONFIG_FILE, in_evaluation=True): + dataloader = get_debug_dataloader(cfg_file, in_evaluation) + for data in dataloader: + return data + + +if __name__ == '__main__': + data = get_debug_data() + print(1) diff --git a/scenestreamer/utils/ema.py b/scenestreamer/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..0f741325adad86d8ff245e5feba57692cb01b8bf --- /dev/null +++ b/scenestreamer/utils/ema.py @@ -0,0 +1,258 @@ +""" +PZH NOTE: Code from https://github.com/nkicsl/Resfusion/blob/c84d5d790b8bc397f17e38945b2f105ea55da476/callback/ema.py +""" +# ------------------------------------------------------------------------------------------------------------------------------------- +# Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): +# ------------------------------------------------------------------------------------------------------------------------------------- + +import os +import os.path +import warnings +from typing import Any, Dict, List, Optional + +import lightning as pl +import torch +from lightning import Callback +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.utilities import rank_zero_warn, rank_zero_info +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import Tensor + +try: + import amp_C + + apex_available = True +except Exception: + apex_available = False + + +class EMA(Callback): + """ + Implements Exponential Moving Averaging (EMA). + When training a model, this callback will maintain moving averages of the trained parameters. + When evaluating, we use the moving averages copy of the trained parameters. + When saving, we save an additional set of parameters with the prefix `ema`. + Args: + decay: The exponential decay used when calculating the moving average. Has to be between 0-1. 
+ apply_ema_every_n_steps: Apply EMA every n global steps. + start_step: Start applying EMA from ``start_step`` global step onwards. + save_ema_weights_in_callback_state: Enable saving EMA weights in callback state. + evaluate_ema_weights_instead: Validate the EMA weights instead of the original weights. + Note this means that when saving the model, the validation metrics are calculated with the EMA weights. + + Adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py + """ + def __init__( + self, + decay: float = 0.999, + apply_ema_every_n_steps: int = 1, + start_step: int = 0, + # else .ckpt will save a model weights copy in key 'callback' + save_ema_weights_in_callback_state: bool = False, + evaluate_ema_weights_instead: bool = True, + ): + if not apex_available: + rank_zero_warn( + "EMA has better performance when Apex is installed: https://github.com/NVIDIA/apex#installation." + ) + if not (0 <= decay <= 1): + raise MisconfigurationException("EMA decay value must be between 0 and 1") + self._ema_model_weights: Optional[List[torch.Tensor]] = None + self._overflow_buf: Optional[torch.Tensor] = None + self._cur_step: Optional[int] = None + self._weights_buffer: Optional[List[torch.Tensor]] = None + self.apply_ema_every_n_steps = apply_ema_every_n_steps + self.start_step = start_step + self.save_ema_weights_in_callback_state = save_ema_weights_in_callback_state + self.evaluate_ema_weights_instead = evaluate_ema_weights_instead + self.decay = decay + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + rank_zero_info("Creating EMA weights copy.") + if self._ema_model_weights is None: + self._ema_model_weights = [p.detach().clone() for p in pl_module.denoising_module.state_dict().values()] + # ensure that all the weights are on the correct device + self._ema_model_weights = [p.to(pl_module.device) for p in self._ema_model_weights] + self._overflow_buf = 
torch.IntTensor([0]).to(pl_module.device) + + def ema(self, pl_module: "pl.LightningModule") -> None: + if apex_available and pl_module.device.type == "cuda": + return self.apply_multi_tensor_ema(pl_module) + return self.apply_ema(pl_module) + + def apply_multi_tensor_ema(self, pl_module: "pl.LightningModule") -> None: + model_weights = list(pl_module.denoising_module.state_dict().values()) + amp_C.multi_tensor_axpby( + 65536, + self._overflow_buf, + [self._ema_model_weights, model_weights, self._ema_model_weights], + self.decay, + 1 - self.decay, + -1, + ) + + def apply_ema(self, pl_module: "pl.LightningModule") -> None: + for orig_weight, ema_weight in zip(list(pl_module.denoising_module.state_dict().values()), + self._ema_model_weights): + if ema_weight.data.dtype != torch.long and orig_weight.data.dtype != torch.long: + # ensure that non-trainable parameters (e.g., feature distributions) are not included in EMA weight averaging + diff = ema_weight.data - orig_weight.data + diff.mul_(1.0 - self.decay) + ema_weight.sub_(diff) + + def should_apply_ema(self, step: int) -> bool: + return step != self._cur_step and step >= self.start_step and step % self.apply_ema_every_n_steps == 0 + + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + if self.should_apply_ema(trainer.global_step): + self._cur_step = trainer.global_step + self.ema(pl_module) + + def state_dict(self) -> Dict[str, Any]: + if self.save_ema_weights_in_callback_state: + return dict(cur_step=self._cur_step, ema_weights=self._ema_model_weights) + return dict(cur_step=self._cur_step) + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._cur_step = state_dict["cur_step"] + # when loading within apps such as NeMo, EMA weights will be loaded by the experiment manager separately + if self._ema_model_weights is None: + self._ema_model_weights = state_dict.get("ema_weights") + + def 
on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> None: + checkpoint_callback = trainer.checkpoint_callback + + if trainer.ckpt_path and checkpoint_callback is not None: + ext = checkpoint_callback.FILE_EXTENSION + if trainer.ckpt_path.endswith(f"-EMA{ext}"): + rank_zero_info( + "loading EMA based weights. " + "The callback will treat the loaded EMA weights as the main weights" + " and create a new EMA copy when training." + ) + return + ema_path = trainer.ckpt_path.replace(ext, f"-EMA{ext}") + if os.path.exists(ema_path): + ema_state_dict = torch.load(ema_path, map_location=torch.device("cpu")) + self._ema_model_weights = ema_state_dict["state_dict"].values() + del ema_state_dict + rank_zero_info("EMA weights have been loaded successfully. Continuing training with saved EMA weights.") + else: + warnings.warn( + "we were unable to find the associated EMA weights when re-loading, " + "training will start with new EMA weights.", + UserWarning, + ) + + def replace_model_weights(self, pl_module: "pl.LightningModule") -> None: + self._weights_buffer = [p.detach().clone().to("cpu") for p in pl_module.denoising_module.state_dict().values()] + new_state_dict = {k: v for k, v in zip(pl_module.denoising_module.state_dict().keys(), self._ema_model_weights)} + pl_module.denoising_module.load_state_dict(new_state_dict) + + def restore_original_weights(self, pl_module: "pl.LightningModule") -> None: + state_dict = pl_module.denoising_module.state_dict() + new_state_dict = {k: v for k, v in zip(state_dict.keys(), self._weights_buffer)} + pl_module.denoising_module.load_state_dict(new_state_dict) + del self._weights_buffer + + @property + def ema_initialized(self) -> bool: + return self._ema_model_weights is not None + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.ema_initialized and self.evaluate_ema_weights_instead: + 
class EMAModelCheckpoint(ModelCheckpoint):
    """
    Light wrapper around Lightning's `ModelCheckpoint` to, upon request, save an EMA copy of the model as well.

    Every checkpoint written through `_save_checkpoint` gets a sibling file whose name carries an
    `-EMA` suffix before the file extension, provided an `EMA` callback is attached to the trainer.

    Adapted from: https://github.com/NVIDIA/NeMo/blob/be0804f61e82dd0f63da7f9fe8a4d8388e330b18/nemo/utils/exp_manager.py#L744
    """
    def __init__(self, **kwargs):
        # call the parent class constructor with the provided kwargs
        super().__init__(**kwargs)

    def _get_ema_callback(self, trainer: "pl.Trainer") -> Optional[EMA]:
        """Return the `EMA` callback registered on `trainer` (the last one if several), else None."""
        ema_callback = None
        for callback in trainer.callbacks:
            if isinstance(callback, EMA):
                ema_callback = callback
        return ema_callback

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Save the regular checkpoint, then a second one holding the EMA weights."""
        super()._save_checkpoint(trainer, filepath)
        ema_callback = self._get_ema_callback(trainer)
        if ema_callback is not None:
            # save EMA copy of the model as well
            ema_callback.replace_model_weights(trainer.lightning_module)
            try:
                filepath = self._ema_format_filepath(filepath)
                if self.verbose:
                    rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
                super()._save_checkpoint(trainer, filepath)
            finally:
                # Always swap the training weights back, even if writing the EMA
                # checkpoint fails; otherwise the model would silently keep the
                # EMA weights.
                ema_callback.restore_original_weights(trainer.lightning_module)

    def _ema_format_filepath(self, filepath: str) -> str:
        """Derive the EMA checkpoint path by inserting `-EMA` before the file extension."""
        return filepath.replace(self.FILE_EXTENSION, f"-EMA{self.FILE_EXTENSION}")

    # only change the last line
    def _update_best_and_save(
        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
    ) -> None:
        """Record `current` in the top-k bookkeeping, save the checkpoint, and evict the displaced one.

        Compared with the stock bookkeeping this also removes the `-EMA` sibling of an evicted checkpoint.
        """
        k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

        del_filepath = None
        if len(self.best_k_models) == k and k > 0:
            del_filepath = self.kth_best_model_path
            self.best_k_models.pop(del_filepath)

        # do not save nan, replace with +/- inf
        if isinstance(current, Tensor) and torch.isnan(current):
            current = torch.tensor(float("inf" if self.mode == "min" else "-inf"), device=current.device)

        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, del_filepath)

        # save the current score
        self.current_score = current
        self.best_k_models[filepath] = current

        if len(self.best_k_models) == k:
            # monitor dict has reached k elements
            _op = max if self.mode == "min" else min
            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
            self.kth_value = self.best_k_models[self.kth_best_model_path]

        _op = min if self.mode == "min" else max
        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
        self.best_model_score = self.best_k_models[self.best_model_path]

        if self.verbose:
            epoch = monitor_candidates["epoch"]
            step = monitor_candidates["step"]
            rank_zero_info(
                f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
                f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
            )
        self._save_checkpoint(trainer, filepath)

        if del_filepath is not None and filepath != del_filepath:
            self._remove_checkpoint(trainer, del_filepath)
            # also drop the EMA sibling of the evicted checkpoint
            self._remove_checkpoint(trainer, self._ema_format_filepath(del_filepath))
0000000000000000000000000000000000000000..2bcbc04b07471b488edf3a0df3f9fb84e761ab36 --- /dev/null +++ b/scenestreamer/utils/lr_schedule.py @@ -0,0 +1,127 @@ +import logging +import math +from functools import partial + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +logger = logging.getLogger(__file__) + +# From: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/optimization.py#L296 + + +def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) + + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ """ + + lr_lambda = partial( + _get_linear_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + shift = timescale - num_warmup_steps + decay = 1.0 / math.sqrt((current_step + shift) / timescale) + return decay + + +def get_inverse_sqrt_schedule(optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1): + """ + Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a + warmup period which increases lr linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + timescale (`int`, *optional*, defaults to `num_warmup_steps`): + Time scale. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ """ + # Note: this implementation is adapted from + # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930 + + if timescale is None: + timescale = num_warmup_steps + + lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale) + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def _get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + progress = min(progress, 1.0) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ """ + + lr_lambda = partial( + _get_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/scenestreamer/utils/safety_critical_generation_utils.py b/scenestreamer/utils/safety_critical_generation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d03316c163db7f194573ea87263a88a851c8e0fa --- /dev/null +++ b/scenestreamer/utils/safety_critical_generation_utils.py @@ -0,0 +1,451 @@ +import PIL +import hydra +import matplotlib.pyplot as plt +import numpy as np +import omegaconf +from omegaconf import DictConfig +from omegaconf import OmegaConf +import seaborn as sns +from matplotlib.animation import FFMpegWriter +from matplotlib.patches import Polygon, Circle, Rectangle + +from scenestreamer.dataset.dataset import SceneStreamerDataset +from scenestreamer.utils import REPO_ROOT +import torch +import copy +import pdb +import pathlib + + +def _overwrite_data_given_agents_not_ooi(original_data_dict, data_dict, ooi): + new_data_dict = copy.deepcopy(original_data_dict) + + B, T, N, _ = data_dict["decoder/reconstructed_position"].shape + + assert B == 1 + for b in range(B): + for aid in range(N): + if aid in ooi: + continue + traj = data_dict["decoder/reconstructed_position"][b, :91, aid, ] + traj_mask = data_dict["decoder/reconstructed_valid_mask"][b, :91, aid] + vel = data_dict['decoder/reconstructed_velocity'][b, :91, aid] + theta = data_dict['decoder/reconstructed_heading'][b, :91, aid] + + new_data_dict["decoder/agent_position"][b, :, aid, :2] = traj + new_data_dict["decoder/agent_position"][b, :, aid, 2] = 0.0 + new_data_dict["decoder/agent_valid_mask"][b, :, aid] = traj_mask + new_data_dict["decoder/agent_heading"][b, :, aid] = theta + new_data_dict["decoder/agent_velocity"][b, :, aid] = vel + new_data_dict["decoder/agent_shape"][b, :, aid] = 
def get_ego_edge_points(x, y, theta, width, length):
    """Return 12 contour points of an oriented rectangle centred at (x, y).

    `length` extends along the heading `theta` (radians) and `width`
    perpendicular to it. Points are ordered left-front, front edge, right-front,
    right edge, right-back, back edge, left-back, left edge; each edge
    contributes two interior points at 1/3 and 2/3 of its length.
    """
    # Half-extent projections of the box onto the world axes.
    len_cos = 0.5 * length * np.cos(theta)
    len_sin = 0.5 * length * np.sin(theta)
    wid_cos = 0.5 * width * np.cos(theta)
    wid_sin = 0.5 * width * np.sin(theta)

    corners = [
        np.array([x + len_cos - wid_sin, y + len_sin + wid_cos]),  # left front
        np.array([x + len_cos + wid_sin, y + len_sin - wid_cos]),  # right front
        np.array([x - len_cos + wid_sin, y - len_sin - wid_cos]),  # right back
        np.array([x - len_cos - wid_sin, y - len_sin + wid_cos]),  # left back
    ]

    samples_per_edge = 2
    contour = []
    for corner_idx, start in enumerate(corners):
        end = corners[(corner_idx + 1) % len(corners)]
        contour.append(start)
        # Interior samples along the edge running from this corner to the next.
        for k in range(1, samples_per_edge + 1):
            contour.append(start + (end - start) * (k / (samples_per_edge + 1)))

    return np.array(contour)
def post_process_adv_traj(data_dict, adv_id, sdc_id=0):
    """
    Invalidate the adversary's trajectory from its first collision with the SDC onward.

    Returns `(data_dict, True)` after editing the adversary's column of
    `decoder/agent_valid_mask` in place. Returns `(data_dict, False)` with no
    edit when no collision step is flagged or when the first flagged step is
    earlier than step 5.
    """
    from scenestreamer.dataset.preprocess_action_label import get_safety_action_from_sdc_adv, cal_polygon_contour

    collision_flags = np.array(get_safety_action_from_sdc_adv(data_dict, adv_id, sdc_id))
    if not np.any(collision_flags):
        return data_dict, False

    # argmax over the flags gives the index of the first flagged step
    first_collision_step = np.argmax(collision_flags)
    if first_collision_step < 5:
        return data_dict, False

    print("first_collision_step", first_collision_step)

    valid_mask_key = "decoder/agent_valid_mask"
    truncated_mask = data_dict[valid_mask_key][:, adv_id]
    truncated_mask[first_collision_step:] = False
    data_dict[valid_mask_key][:, adv_id] = truncated_mask

    return data_dict, True
data_dict['decoder/reconstructed_velocity'][b, :91, aid] + theta = data_dict['decoder/reconstructed_heading'][b, :91, aid] + + new_data_dict["decoder/agent_position"][b, :, aid, :2] = traj + new_data_dict["decoder/agent_position"][b, :, aid, 2] = 0.0 + new_data_dict["decoder/agent_valid_mask"][b, :, aid] = traj_mask + new_data_dict["decoder/agent_heading"][b, :, aid] = theta + new_data_dict["decoder/agent_velocity"][b, :, aid] = vel + new_data_dict["decoder/agent_shape"][b, :, aid] = new_data_dict["decoder/current_agent_shape"][b, aid] + + return new_data_dict + + +def _overwrite_data_given_agents(original_data_dict, data_dict, sdc_id, adv_id): + new_data_dict = copy.deepcopy(original_data_dict) + + T, N, _ = data_dict["decoder/reconstructed_position"].shape + + # for id in ooi_arr: # overwrite all agents + traj = data_dict["decoder/reconstructed_position"][:91, sdc_id, ] + traj_mask = data_dict["decoder/reconstructed_valid_mask"][:91, sdc_id] + theta = data_dict['decoder/reconstructed_heading'][:91, sdc_id] + + new_data_dict["decoder/agent_position"][:, sdc_id, :2] = traj + new_data_dict["decoder/agent_position"][:, sdc_id, 2] = 0.0 + new_data_dict["decoder/agent_valid_mask"][:, sdc_id] = traj_mask + new_data_dict["decoder/agent_heading"][:, sdc_id] = theta + + adv_traj = data_dict["decoder/reconstructed_position"][:91, adv_id][:, None] + new_dim = np.zeros((adv_traj.shape[0], adv_traj.shape[1], 1)) + adv_traj = np.concatenate([adv_traj, new_dim], axis=-1) + adv_traj_mask = data_dict["decoder/reconstructed_valid_mask"][:91, adv_id][:, None] + adv_theta = data_dict['decoder/reconstructed_heading'][:91, adv_id][:, None] + + new_data_dict["decoder/agent_position"] = np.concatenate( + [new_data_dict["decoder/agent_position"], adv_traj], axis=1 + ) + new_data_dict["decoder/agent_valid_mask"] = np.concatenate( + [new_data_dict["decoder/agent_valid_mask"], adv_traj_mask], axis=1 + ) + new_data_dict["decoder/agent_heading"] = 
np.concatenate([new_data_dict["decoder/agent_heading"], adv_theta], axis=1) + + return new_data_dict + + +def check_last_step_sdc_adv_collision(data_dict, sdc_id, sdc_pos, sdc_heading, adv_id, adv_pos, adv_heading): + + from scenestreamer.dataset.preprocess_action_label import cal_polygon_contour, detect_collision + contours = [] + + sdc_contour = cal_polygon_contour( + sdc_pos[0], sdc_pos[1], sdc_heading, data_dict["decoder/agent_shape"][10, sdc_id, 1], + data_dict["decoder/agent_shape"][10, sdc_id, 0] + ) + adv_contour = cal_polygon_contour( + adv_pos[0], adv_pos[1], adv_heading, data_dict["decoder/agent_shape"][10, adv_id, 1], + data_dict["decoder/agent_shape"][10, adv_id, 0] + ) + + collision_tags = detect_collision(adv_contour, [True], sdc_contour, [True]) + collision_detected = np.array(collision_tags) + + if np.any(collision_detected): + print("collision") + return True + + return False + + +def set_adv(data_dict): + """ + here is the current design: from existing agents, choose the one with its lastest step having nearest distance among all + """ + ego_id = data_dict["decoder/sdc_index"] + ego_traj = data_dict["decoder/agent_position"][:, ego_id] + ego_heading = data_dict["decoder/agent_heading"][:, ego_id] + ego_velocity = data_dict["decoder/agent_velocity"][:, ego_id] + ego_shape = data_dict["decoder/agent_shape"][:, ego_id] + ego_mask = data_dict["decoder/agent_valid_mask"][:, ego_id] + + adv_id, adv_pos, adv_heading, adv_vel, last_valid_step = choose_nearest_adv(data_dict) + last_valid_step = np.where(ego_mask)[0][-1] # force setting the last valid step + + ego_last_pos = data_dict["decoder/agent_position"][last_valid_step, ego_id, :2] + ego_last_heading = data_dict["decoder/agent_heading"][last_valid_step, ego_id] + + # begin to search + alphas = np.arange(0, 1.02, 0.05) + collision_point = ego_last_pos #- np.random.normal(loc=0.0, scale=1, size=ego_last_pos.shape[0]) + + for alpha in alphas: + cand_adv_pos = (1 - alpha) * adv_pos + alpha * 
def choose_nearest_adv(data_dict):
    """Pick the agent that comes closest to the SDC at any jointly valid step.

    Only the first 91 steps are considered, and for every candidate only the
    steps where both the SDC and the candidate are valid. Candidates that never
    overlap with the SDC in time are skipped.

    Args:
        data_dict: Scenario dict providing "decoder/sdc_index",
            "decoder/agent_id", "decoder/agent_valid_mask",
            "decoder/agent_position", "decoder/agent_heading" and
            "decoder/agent_velocity".

    Returns:
        Tuple ``(adv_id, adv_pos, adv_heading, adv_vel, adv_closest_step)``
        where the state entries are read at the step of closest approach.

    Raises:
        ValueError: If no agent shares a valid step with the SDC.
    """
    sdc_id = data_dict["decoder/sdc_index"]
    candidate_ids = data_dict["decoder/agent_id"]
    sdc_mask = data_dict["decoder/agent_valid_mask"][:91, sdc_id]
    # The SDC track must be indexed with sdc_id; indexing it with the candidate
    # id made every distance zero.
    sdc_pos = data_dict["decoder/agent_position"][:91, sdc_id, :2]

    min_dist = float('inf')
    adv_id = None
    adv_closes_step = None

    for agent_id in candidate_ids:
        if agent_id == sdc_id:
            continue
        agent_mask = data_dict["decoder/agent_valid_mask"][:91, agent_id]

        mask = sdc_mask & agent_mask
        valid_steps = np.where(mask)[0]  # original step indices where both are valid
        if valid_steps.size == 0:
            # No temporal overlap with the SDC: nothing to compare.
            continue

        agent_pos = data_dict["decoder/agent_position"][:91, agent_id, :2]
        # Per-step Euclidean distance (axis=-1); without the axis the norm
        # collapses all steps into one scalar.
        distances = np.linalg.norm(sdc_pos[mask] - agent_pos[mask], axis=-1)
        closest_index = np.argmin(distances)
        dist = distances[closest_index]

        if dist < min_dist:
            adv_id = agent_id
            min_dist = dist
            adv_closes_step = valid_steps[closest_index]

    if adv_id is None:
        raise ValueError("No agent shares a valid step with the SDC; cannot choose an adversary.")

    # read the adversary's state at the step of closest approach
    adv_pos = data_dict["decoder/agent_position"][adv_closes_step, adv_id, :2]
    adv_heading = data_dict["decoder/agent_heading"][adv_closes_step, adv_id]
    adv_vel = data_dict["decoder/agent_velocity"][adv_closes_step, adv_id]

    return adv_id, adv_pos, adv_heading, adv_vel, adv_closes_step
+ + adv_mask = np.zeros_like(ego_mask) + adv_mask[:last_valid_step + 1] = True + + adv_traj = np.zeros_like(ego_traj) + adv_heading = np.zeros_like(ego_heading) + adv_velocity = np.zeros_like(ego_velocity) + adv_shape = np.zeros_like(ego_shape) + + # Copy the final pos/head/vel/shape of ego + # ===== Position ===== + adv_traj[last_valid_step] = ego_traj[last_valid_step] + np.random.normal(loc=0.0, scale=0.5, size=3) + print("Ego position: ", ego_traj[last_valid_step]) + print("Adv position: ", adv_traj[last_valid_step]) + # ==================== + + # ===== Heading ===== + adv_heading[last_valid_step] = ego_heading[last_valid_step] + np.random.normal(loc=0.0, scale=0.1, size=1) + print("Ego heading: ", ego_heading[last_valid_step]) + print("Adv heading: ", adv_heading[last_valid_step]) + # =================== + + # ===== Velocity ===== + # adv_velocity[last_valid_step] = ego_velocity[last_valid_step] + np.random.normal(loc=0.0, scale=0.5, size=2) + adv_vel = 0.5 * (ego_velocity[last_valid_step] + np.random.normal(loc=0.0, scale=0.1, size=2)) + adv_velocity[last_valid_step] = adv_vel + ego_vel = 0.5 * (ego_velocity[last_valid_step] + np.random.normal(loc=0.0, scale=0.1, size=2)) + print("Ego velocity: ", ego_vel, ego_velocity[last_valid_step]) + print("Adv velocity: ", adv_velocity[last_valid_step]) + data_dict["decoder/agent_velocity"][last_valid_step, ego_id] = ego_vel + # ==================== + + # ===== Shape ===== + for i in range(data_dict["decoder/agent_shape"].shape[0]): + adv_shape[i] = ego_shape[last_valid_step] + # ================= + + # Insert data back: + data_dict["decoder/agent_position"] = np.concatenate( + [data_dict["decoder/agent_position"], adv_traj[:, None]], axis=1 + ) + data_dict["decoder/agent_heading"] = np.concatenate( + [data_dict["decoder/agent_heading"], adv_heading[:, None]], axis=1 + ) + data_dict["decoder/agent_velocity"] = np.concatenate( + [data_dict["decoder/agent_velocity"], adv_velocity[:, None]], axis=1 + ) + # 
def run_backward_prediction_with_teacher_forcing(
    model, config, backward_input_dict, tokenizer, not_teacher_forcing_ids
):
    """Tokenize the input for backward prediction, run the replay rollout, and detokenize.

    `backward_input_dict` is updated in place with the tokenizer output and the
    two boolean flag tensors `in_evaluation` and `in_backward_prediction`.
    The rollout runs under `torch.no_grad()` using the sampling method and
    temperature from `config.SAMPLING`; `not_teacher_forcing_ids` is forwarded
    to the rollout unchanged.

    Returns:
        The detokenized rollout output.
    """
    target_device = backward_input_dict["decoder/agent_position"].device

    # Force to run backward prediction first to make sure the data is tokenized correctly.
    tokenized, _ = tokenizer.tokenize(backward_input_dict, backward_prediction=True)
    backward_input_dict.update(tokenized)

    for flag_name in ("in_evaluation", "in_backward_prediction"):
        backward_input_dict[flag_name] = torch.tensor([1], dtype=bool).to(target_device)

    with torch.no_grad():
        rollout = model.model.autoregressive_rollout_backward_prediction_with_replay
        rollout_output = rollout(
            backward_input_dict,
            num_decode_steps=None,
            sampling_method=config.SAMPLING.SAMPLING_METHOD,
            temperature=config.SAMPLING.TEMPERATURE,
            not_teacher_forcing_ids=not_teacher_forcing_ids,
        )

    return tokenizer.detokenize(
        rollout_output,
        detokenizing_gt=False,
        backward_prediction=True,
        flip_wrong_heading=True,
    )
def average_heading(heading1, heading2):
    """Average two headings (radians) on the unit circle.

    Each heading is turned into a unit vector, the vectors are averaged, and
    the angle of the mean vector is returned, so headings on either side of
    +/-pi average correctly. The backend (NumPy or PyTorch) is chosen from the
    type of `heading1`.

    Raises:
        ValueError: If `heading1` is neither a NumPy array nor a PyTorch tensor.
    """
    if isinstance(heading1, np.ndarray):
        cos_fn, sin_fn, atan2_fn = np.cos, np.sin, np.arctan2
    elif isinstance(heading1, torch.Tensor):
        cos_fn, sin_fn, atan2_fn = torch.cos, torch.sin, torch.atan2
    else:
        raise ValueError("Input must be a NumPy array or PyTorch tensor")

    # Mean of the two unit vectors
    mean_x = (cos_fn(heading1) + cos_fn(heading2)) / 2
    mean_y = (sin_fn(heading1) + sin_fn(heading2)) / 2

    # Angle of the mean vector
    return atan2_fn(mean_y, mean_x)
def padding_1st_and_2nd_dim(tensor_list, max_1st_dim=None, max_2nd_dim=None, fill=None):
    """Pad every tensor along its first two dimensions and stack the results.

    Args:
        tensor_list: Tensors that agree on all dimensions after the second.
        max_1st_dim: Target size of dim 0; defaults to the largest in the list.
        max_2nd_dim: Target size of dim 1; defaults to the largest in the list.
        fill: Value for the padded area; zeros when None.

    Returns:
        Tensor of shape (len(tensor_list), target_dim0, target_dim1, ...).
    """
    if max_2nd_dim is None:
        target_dim1 = max([t.shape[1] for t in tensor_list])
    else:
        target_dim1 = max_2nd_dim
    if max_1st_dim is None:
        target_dim0 = max([t.shape[0] for t in tensor_list])
    else:
        target_dim0 = max_1st_dim

    def _pad(source):
        # Allocate with the source's dtype/device, then copy into the top-left corner.
        canvas = source.new_zeros(target_dim0, target_dim1, *source.shape[2:])
        if fill is not None:
            canvas.fill_(fill)
        rows, cols = source.shape[0], source.shape[1]
        canvas[:rows, :cols] = source
        return canvas

    return torch.stack([_pad(t) for t in tensor_list], dim=0)
+ """ + # Determine the number of dimensions from the first tensor + ndim = tensor_list[0].dim() + + if tensor_list[0].dtype == torch.bool: + fill = False + + # Compute maximum dimensions if not provided + if max_dims is None: + max_dims = [max(t.shape[d] for t in tensor_list) for d in range(ndim)] + else: + max_dims = list(max_dims) + + padded_list = [] + for t in tensor_list: + # Create a new tensor of the computed target shape with the same type and device as t + new_tensor = t.new_zeros(*max_dims) + if fill is not None: + new_tensor.fill_(fill) + # Create a slice object for each dimension to copy over the data from t + slices = tuple(slice(0, t.shape[d]) for d in range(ndim)) + new_tensor[slices] = t + padded_list.append(new_tensor) + + # Stack along a new first dimension + return torch.stack(padded_list, dim=0) + + +def get_distribution(dist_parameter): + weight = dist_parameter[..., 0].clamp(-100, 100) + + log_std_range = (-1.609, 5.0) + para = dist_parameter[..., 1:].clamp(-100, 100) + + if para.shape[-1] == 5: + loc, tril, diag = para[..., :2], para[..., 2], para[..., 3:] + sigma_1 = torch.exp(diag[..., 0].clamp(log_std_range[0], log_std_range[1])) + sigma_2 = torch.exp(diag[..., 1].clamp(log_std_range[0], log_std_range[1])) + rho = torch.tanh(tril).clamp(-0.5, 0.5) + cov = torch.stack([sigma_1**2, rho * sigma_1 * sigma_2, rho * sigma_1 * sigma_2, sigma_2**2], + dim=-1).view(*loc.shape[:-1], 2, 2) + dist = torch.distributions.multivariate_normal.MultivariateNormal(loc=loc, covariance_matrix=cov) + pos_weight = torch.distributions.Categorical(logits=weight) + dist = torch.distributions.mixture_same_family.MixtureSameFamily(pos_weight, dist) + + elif para.shape[-1] == 6: + loc, log_scale = para[..., :3], para[..., 3:] + scale = torch.exp(log_scale.clamp(log_std_range[0], log_std_range[1])) + dist = torch.distributions.independent.Independent( + torch.distributions.normal.Normal(loc=loc, scale=scale), reinterpreted_batch_ndims=1 + ) + pos_weight = 
torch.distributions.Categorical(logits=weight) + dist = torch.distributions.mixture_same_family.MixtureSameFamily(pos_weight, dist) + + elif para.shape[-1] == 2: + loc, scale = para[..., 0], para[..., 1] + scale = torch.exp(scale.clamp(log_std_range[0], log_std_range[1])) + dist = torch.distributions.Normal(loc, scale) + pos_weight = torch.distributions.Categorical(logits=weight) + dist = torch.distributions.mixture_same_family.MixtureSameFamily(pos_weight, dist) + + else: + raise ValueError(para.shape) + + return dist + + +def unwrap(flatten_array, valid_mask, existing=None, fill=None): + assert valid_mask.sum() == flatten_array.shape[0] + if existing is None: + ret = flatten_array.new_zeros(valid_mask.shape + (flatten_array.shape[-1], )) + else: + ret = existing + if fill is not None: + ret.fill_(fill) + ret[valid_mask] = flatten_array + return ret + + +def pack_sequences(seqs) -> (np.ndarray, np.ndarray): + values = np.concatenate(seqs, axis=0) + offsets = np.cumsum([len(s) for s in seqs]) + return values, offsets + + +def wrap_to_pi(radians_array): + """ + Wrap all input radians to range [-pi, pi] + """ + if isinstance(radians_array, np.ndarray): + wrapped_radians_array = np.mod(radians_array, 2 * np.pi) + wrapped_radians_array[wrapped_radians_array > np.pi] -= 2 * np.pi + elif isinstance(radians_array, torch.Tensor): + wrapped_radians_array = radians_array % (2 * torch.tensor(np.pi)) + wrapped_radians_array[wrapped_radians_array > torch.tensor(np.pi)] -= 2 * np.pi + elif isinstance(radians_array, (float, np.float32)): + wrapped_radians_array = radians_array % (2 * np.pi) + if wrapped_radians_array > np.pi: + wrapped_radians_array -= 2 * np.pi + else: + raise ValueError("Input must be a NumPy array or PyTorch tensor") + + return wrapped_radians_array + + +def unpack_sequence(values: np.ndarray, offsets: np.ndarray, index: int) -> np.ndarray: + off1 = offsets[index] + if index > 0: + off0 = offsets[index - 1] + elif index == 0: + off0 = 0 + else: + raise 
ValueError(index) + return values[off0:off1] + + +def string_to_sequence(s: str, dtype=np.int32) -> np.ndarray: + return np.array([ord(c) for c in s], dtype=dtype) + + +def sequence_to_string(seq: np.ndarray) -> str: + return ''.join([chr(c) for c in seq]) + + +def check_numpy_to_torch(x): + if isinstance(x, np.ndarray): + return torch.from_numpy(x), True # .float(), True + return x, False + + +def rotate_points_along_z(points, angle): + """ + Args: + points: (B, N, 3 + C) + angle: (B), angle along z-axis, angle increases x ==> y + Returns: + + """ + points, is_numpy = check_numpy_to_torch(points) + angle, _ = check_numpy_to_torch(angle) + + cosa = torch.cos(angle) + sina = torch.sin(angle) + zeros = angle.new_zeros(points.shape[0]) + if points.shape[-1] == 2: + rot_matrix = torch.stack((cosa, sina, -sina, cosa), dim=1).view(-1, 2, 2) # .float() + points_rot = torch.matmul(points, rot_matrix) + else: + ones = angle.new_ones(points.shape[0]) + rot_matrix = torch.stack((cosa, sina, zeros, -sina, cosa, zeros, zeros, zeros, ones), + dim=1).view(-1, 3, 3) # .float() + points_rot = torch.matmul(points[:, :, 0:3], rot_matrix) + points_rot = torch.cat((points_rot, points[:, :, 3:]), dim=-1) + return points_rot.numpy() if is_numpy else points_rot + + +def absolute_to_relative(abs_pos, map_head): + relative_y = map_head + relative_x = relative_y - np.pi / 2 + object_heading_to_rotate = relative_x + if abs_pos.shape[-1] == 3: + z = abs_pos[..., 2] + else: + z = None + rel_pos = rotate(abs_pos[..., 0], abs_pos[..., 1], -object_heading_to_rotate.squeeze(-1), z=z) + return rel_pos + + +def relative_to_absolute(rel_pos, map_head): + relative_y = map_head + relative_x = relative_y - np.pi / 2 + object_heading_to_rotate = relative_x + if rel_pos.shape[-1] == 3: + z = rel_pos[..., 2] + else: + z = None + abs_pos = rotate(rel_pos[..., 0], rel_pos[..., 1], object_heading_to_rotate.squeeze(-1), z=z) + return abs_pos + + +def rotate(x, y, angle, z=None, assert_shape=True): + # 
TODO(pzh): Repeat function, remove one. + if assert_shape: + assert angle.shape == x.shape == y.shape, (angle.shape, x.shape, y.shape) + if z is not None: + assert x.shape == z.shape + if isinstance(x, torch.Tensor): + other_x_trans = torch.cos(angle) * x - torch.sin(angle) * y + other_y_trans = torch.cos(angle) * y + torch.sin(angle) * x + if z is None: + output_coords = torch.stack((other_x_trans, other_y_trans), dim=-1) + else: + output_coords = torch.stack((other_x_trans, other_y_trans, z), dim=-1) + else: + other_x_trans = np.cos(angle) * x - np.sin(angle) * y + other_y_trans = np.cos(angle) * y + np.sin(angle) * x + if z is None: + output_coords = np.stack((other_x_trans, other_y_trans), axis=-1) + else: + output_coords = np.stack((other_x_trans, other_y_trans, z), axis=-1) + return output_coords + + +# def translate_pos_to_ego_centric(xyz, center, heading): +# assert center.shape[-1] == 3 +# assert heading.shape[-1] == 1 +# # assert xyz.shape[0] == center.shape[0] == heading.shape[0] +# assert xyz.ndim == 3 +# assert center.ndim == 2 +# assert heading.ndim == 2 +# xyz = xyz - center[:, None] +# xyz = rotate_points_along_z(xyz, -heading) +# return xyz + + +def merge_batch_by_padding_2nd_dim(tensor_list, return_pad_mask=False): + assert len(tensor_list[0].shape) in [3, 4] + only_3d_tensor = False + if len(tensor_list[0].shape) == 3: + tensor_list = [x.unsqueeze(dim=-1) for x in tensor_list] + only_3d_tensor = True + maxt_feat0 = max([x.shape[1] for x in tensor_list]) + + _, _, num_feat1, num_feat2 = tensor_list[0].shape + + ret_tensor_list = [] + ret_mask_list = [] + for k in range(len(tensor_list)): + cur_tensor = tensor_list[k] + assert cur_tensor.shape[2] == num_feat1 and cur_tensor.shape[3] == num_feat2 + + new_tensor = cur_tensor.new_zeros(cur_tensor.shape[0], maxt_feat0, num_feat1, num_feat2) + new_tensor[:, :cur_tensor.shape[1], :, :] = cur_tensor + ret_tensor_list.append(new_tensor) + + new_mask_tensor = cur_tensor.new_zeros(cur_tensor.shape[0], 
maxt_feat0) + new_mask_tensor[:, :cur_tensor.shape[1]] = 1 + ret_mask_list.append(new_mask_tensor.bool()) + + ret_tensor = torch.cat(ret_tensor_list, dim=0) # (num_stacked_samples, num_feat0_maxt, num_feat1, num_feat2) + ret_mask = torch.cat(ret_mask_list, dim=0) + + if only_3d_tensor: + ret_tensor = ret_tensor.squeeze(dim=-1) + + if return_pad_mask: + return ret_tensor, ret_mask + return ret_tensor + + +def create_logger(log_file=None, rank=0, log_level=logging.INFO): + logger = logging.getLogger(__name__) + logger.setLevel(log_level if rank == 0 else 'ERROR') + formatter = logging.Formatter('%(asctime)s %(levelname)5s %(message)s') + console = logging.StreamHandler() + console.setLevel(log_level if rank == 0 else 'ERROR') + console.setFormatter(formatter) + logger.addHandler(console) + if log_file is not None: + file_handler = logging.FileHandler(filename=log_file) + file_handler.setLevel(log_level if rank == 0 else 'ERROR') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + logger.propagate = False + return logger + + +def get_dist_info(return_gpu_per_machine=False): + if torch.__version__ < '1.0': + initialized = dist._initialized + else: + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + + if return_gpu_per_machine: + gpu_per_machine = torch.cuda.device_count() + return rank, world_size, gpu_per_machine + + return rank, world_size + + +def get_batch_offsets(batch_idxs, bs, device): + ''' + :param batch_idxs: (N), int + :param bs: int + :return: batch_offsets: (bs + 1) + ''' + batch_offsets = torch.zeros([ + bs + 1, + ], device=device).int() + for i in range(bs): + batch_offsets[i + 1] = batch_offsets[i] + (batch_idxs == i).sum() + assert batch_offsets[-1] == batch_idxs.shape[0] + return batch_offsets + + +# def set_random_seed(seed): +# random.seed(seed) +# 
np.random.seed(seed) +# torch.manual_seed(seed) +# torch.backends.cudnn.deterministic = True +# torch.backends.cudnn.benchmark = False + + +def init_dist_slurm(tcp_port, local_rank, backend='nccl'): + """ + modified from https://github.com/open-mmlab/mmdetection + Args: + tcp_port: + backend: + + Returns: + + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput('scontrol show hostname {} | head -n1'.format(node_list)) + os.environ['MASTER_PORT'] = str(tcp_port) + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + total_gpus = dist.get_world_size() + rank = dist.get_rank() + return total_gpus, rank + + +def init_dist_pytorch(tcp_port, local_rank, backend='nccl'): + # if mp.get_start_method(allow_none=True) is None: + # mp.set_start_method('spawn') + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(local_rank % num_gpus) + + dist.init_process_group( + backend=backend, + # init_method='tcp://127.0.0.1:%d' % tcp_port, + # rank=local_rank, + # world_size=num_gpus + ) + rank = dist.get_rank() + return num_gpus, rank + + +def merge_results_dist(result_part, size, tmpdir): + rank, world_size = get_dist_info() + os.makedirs(tmpdir, exist_ok=True) + + dist.barrier() + pickle.dump(result_part, open(os.path.join(tmpdir, 'result_part_{}.pkl'.format(rank)), 'wb')) + dist.barrier() + + if rank != 0: + return None + + part_list = [] + for i in range(world_size): + part_file = os.path.join(tmpdir, 'result_part_{}.pkl'.format(i)) + part_list.append(pickle.load(open(part_file, 'rb'))) + + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + ordered_results = ordered_results[:size] + shutil.rmtree(tmpdir) + return ordered_results + + +def 
calculate_trajectory_probabilities(logits, sampled_actions, mask): + # Apply softmax to convert logits to probabilities + probs = F.softmax(logits, dim=-1) + + # Remove invalid actions: + invalid_action_mask = torch.logical_or(sampled_actions < 0, sampled_actions >= probs.shape[-1]) + assert (logits[invalid_action_mask] == 0).all() + sampled_actions = sampled_actions.masked_fill(invalid_action_mask, 0) + + # Gather the probabilities of the sampled actions + gathered_probs = torch.gather(probs, -1, sampled_actions.unsqueeze(-1)).squeeze(-1) + + gathered_probs = gathered_probs.masked_fill(invalid_action_mask, 0) + + # Multiply probabilities across the time dimension for each trajectory + # Use log probabilities for numerical stability + log_probs = torch.log(gathered_probs) + trajectory_log_probs = torch.sum(log_probs, dim=1) + + # Convert back from log probabilities if needed + trajectory_probs = torch.exp(trajectory_log_probs) + + # Aggregate to get final shape (B, N) + mask = mask.reshape(trajectory_probs.shape) + trajectory_probs = trajectory_probs.masked_fill(~mask, float("-inf")) + return trajectory_probs + + +def calculate_trajectory_probabilities_new(logits, sampled_actions, mask): + """ + Compared to the old version, we allow the mask to have temporal dimension. 
+ """ + # Apply softmax to convert logits to probabilities + probs = F.softmax(logits, dim=-1) + + # Remove invalid actions: + invalid_action_mask = torch.logical_or(sampled_actions < 0, sampled_actions >= probs.shape[-1]) + assert (logits[invalid_action_mask] == 0).all() + sampled_actions = sampled_actions.masked_fill(invalid_action_mask, 0) + + # Gather the probabilities of the sampled actions + gathered_probs = torch.gather(probs, -1, sampled_actions.unsqueeze(-1)).squeeze(-1) + + gathered_probs = gathered_probs.masked_fill(invalid_action_mask, 0) + + # Multiply probabilities across the time dimension for each trajectory + # Use log probabilities for numerical stability + log_probs = torch.log(gathered_probs) + + log_probs = log_probs.masked_fill(~mask, 0) + + trajectory_log_probs = torch.sum(log_probs, dim=1) / mask.sum(dim=1) + + # Convert back from log probabilities if needed + trajectory_probs = torch.exp(trajectory_log_probs) + trajectory_probs = trajectory_probs.masked_fill(~mask.any(dim=1), 0) + + trajectory_log_probs = trajectory_log_probs.masked_fill(~mask.any(dim=1), float("-inf")) + return trajectory_log_probs, trajectory_probs + + +def masked_average(tensor, mask, dim): + """ + Compute the average of tensor along the specified dimension, ignoring masked elements. 
+ """ + assert tensor.shape == mask.shape + count = mask.sum(dim=dim) + count = torch.max(count, torch.ones_like(count)) + return (tensor * mask).sum(dim=dim) / count + + +def weight_init(m: nn.Module) -> None: + raise ValueError + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + fan_in = m.in_channels / m.groups + fan_out = m.out_channels / m.groups + bound = (6.0 / (fan_in + fan_out))**0.5 + nn.init.uniform_(m.weight, -bound, bound) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + if m.elementwise_affine: + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.MultiheadAttention): + if m.in_proj_weight is not None: + fan_in = m.embed_dim + fan_out = m.embed_dim + bound = (6.0 / (fan_in + fan_out))**0.5 + nn.init.uniform_(m.in_proj_weight, -bound, bound) + else: + nn.init.xavier_uniform_(m.q_proj_weight) + nn.init.xavier_uniform_(m.k_proj_weight) + nn.init.xavier_uniform_(m.v_proj_weight) + if m.in_proj_bias is not None: + nn.init.zeros_(m.in_proj_bias) + nn.init.xavier_uniform_(m.out_proj.weight) + if m.out_proj.bias is not None: + nn.init.zeros_(m.out_proj.bias) + if m.bias_k is not None: + nn.init.normal_(m.bias_k, mean=0.0, std=0.02) + if m.bias_v is not None: + nn.init.normal_(m.bias_v, mean=0.0, std=0.02) + elif isinstance(m, (nn.LSTM, nn.LSTMCell)): + for name, param in m.named_parameters(): + if 'weight_ih' in name: + for ih in param.chunk(4, 0): + nn.init.xavier_uniform_(ih) + elif 'weight_hh' in name: + for hh in param.chunk(4, 0): + nn.init.orthogonal_(hh) + elif 'weight_hr' in name: + nn.init.xavier_uniform_(param) + elif 'bias_ih' in name: + nn.init.zeros_(param) + 
elif 'bias_hh' in name: + nn.init.zeros_(param) + nn.init.ones_(param.chunk(4, 0)[1]) + elif isinstance(m, (nn.GRU, nn.GRUCell)): + for name, param in m.named_parameters(): + if 'weight_ih' in name: + for ih in param.chunk(3, 0): + nn.init.xavier_uniform_(ih) + elif 'weight_hh' in name: + for hh in param.chunk(3, 0): + nn.init.orthogonal_(hh) + elif 'bias_ih' in name: + nn.init.zeros_(param) + elif 'bias_hh' in name: + nn.init.zeros_(param) + + +def masked_average_numpy(tensor, mask, dim): + """ + Compute the average of tensor along the specified dimension, ignoring masked elements. + """ + assert tensor.shape == mask.shape + count = mask.sum(axis=dim) + count = np.maximum(count, np.ones_like(count)) + return (tensor * mask).sum(axis=dim) / count + + +def extract_data_by_agent_indices(data, agent_indices, agent_dim, fill=None): + agent_indices = np.asarray(agent_indices, dtype=int) + new_shape = [ + 1, + ] * data.ndim + new_shape[agent_dim] = agent_indices.shape[0] + agent_indices = agent_indices.reshape(*new_shape) + data = np.take_along_axis(data, agent_indices, axis=agent_dim) + data = np.where(agent_indices != -1, data, np.zeros_like(data)) + if fill is not None: + data[agent_indices == -1] = fill + return data + + +def average_angles(angles): + # Convert angles to Cartesian coordinates + sum_sin = np.mean(np.sin(angles)) + sum_cos = np.mean(np.cos(angles)) + + # Convert the average coordinates back to angles (in radians) + avg_angle_rad = np.arctan2(sum_sin, sum_cos) + return avg_angle_rad + + +def masked_average_angles(angles, mask, axis): + assert angles.shape == mask.shape + + # Convert masked angles to Cartesian coordinates + sum_sin = np.sum(np.sin(angles) * mask, axis=axis) + sum_cos = np.sum(np.cos(angles) * mask, axis=axis) + + # Calculate the number of valid entries along the specified axis + count_valid = np.sum(mask, axis=axis) + + # Compute the mean of the sine and cosine, avoiding division by zero + mean_sin = np.divide(sum_sin, count_valid, 
where=count_valid != 0) + mean_cos = np.divide(sum_cos, count_valid, where=count_valid != 0) + + # Convert the average coordinates back to angles (in radians) + avg_angle_rad = np.arctan2(mean_sin, mean_cos) + return avg_angle_rad + + +def modulate(x, shift, scale): + assert x.shape == shift.shape == scale.shape + return x * (1 + scale) + shift + + +def _to_dict(d): + if isinstance(d, easydict.EasyDict): + return {k: _to_dict(v) for k, v in d.items()} + return d + + +def load_from_checkpoint( + cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], + checkpoint_path: Union[_PATH, IO], + config, + default_config=None, + map_location: _MAP_LOCATION_TYPE = None, + hparams_file: Optional[_PATH] = None, + strict: Optional[bool] = None, + checkpoint_surgery_func=None, +) -> Union["pl.LightningModule", "pl.LightningDataModule"]: + if checkpoint_path is None: + if default_config is not None: + # Merge config and default config + if isinstance(default_config, easydict.EasyDict): + default_config = _to_dict(default_config) + default_config = OmegaConf.create(default_config) + config = OmegaConf.merge(default_config, config) + return cls(config) + + with pl_legacy_patch(): + # PyTorch >= 2.6 defaults torch.load(weights_only=True), which can fail for Lightning checkpoints + # that include non-tensor objects (e.g. pathlib.PosixPath). We explicitly request weights_only=False. + try: + checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=False) + except TypeError: + # Older lightning versions may not support weights_only in pl_load. + checkpoint = pl_load(checkpoint_path, map_location=map_location) + except RuntimeError as e: + # Common when loading CUDA-saved checkpoints on CPU-only machines (e.g. macOS): + # retry with CPU mapping if caller didn't request a specific map_location. 
+ msg = str(e) + if map_location is None and ( + "Attempting to deserialize object on a CUDA device" in msg or "torch.cuda.is_available() is False" in msg + ): + try: + checkpoint = pl_load(checkpoint_path, map_location=torch.device("cpu"), weights_only=False) + except TypeError: + checkpoint = pl_load(checkpoint_path, map_location=torch.device("cpu")) + else: + raise + + # convert legacy checkpoints to the new format + checkpoint = _pl_migrate_checkpoint( + checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None) + ) + + from lightning.pytorch.core.saving import load_hparams_from_yaml, load_hparams_from_tags_csv + + if hparams_file is not None: + extension = str(hparams_file).split(".")[-1] + if extension.lower() == "csv": + hparams = load_hparams_from_tags_csv(hparams_file) + elif extension.lower() in ("yml", "yaml"): + hparams = load_hparams_from_yaml(hparams_file) + else: + raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") + + # overwrite hparams by the given file + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + + # TODO: make this a migration: + # for past checkpoint need to add the new key + checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}) + # override the hparams with values that were passed in + + # PZH: Change + # checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(config) + if isinstance(default_config, omegaconf.DictConfig): + default_config = OmegaConf.to_container(default_config) + if default_config: + default_config = copy.deepcopy(default_config) + if "LOCAL_RANK" in default_config: + default_config.pop("LOCAL_RANK") + + if "defaults" in default_config: + default_config.pop("defaults") + + if "config" in checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]: + + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["config"] + if isinstance(checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], easydict.EasyDict): + + 
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = _to_dict(checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]) + + if "ROOT_DIR" in checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["ROOT_DIR"] = str(REPO_ROOT) + if "SAMPLE_INTERVAL" in checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["DATA"]: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["DATA"].pop("SAMPLE_INTERVAL") + if "LOCAL_RANK" in checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].pop("LOCAL_RANK") + + if isinstance(default_config, easydict.EasyDict): + default_config = _to_dict(default_config) + if "ROOT_DIR" in default_config: + default_config["ROOT_DIR"] = str(REPO_ROOT) + default_config = OmegaConf.merge(default_config, checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]) + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = default_config + + if checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = OmegaConf.create(checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]) + + if config: + # if "SAMPLE_INTERVAL" in config.DATA: + # config.DATA.pop("SAMPLE_INTERVAL") + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY + ] = OmegaConf.merge(checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], config) + + config = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] + + from lightning.pytorch.core.saving import _load_state + + if checkpoint_surgery_func is not None: + checkpoint = checkpoint_surgery_func(checkpoint, cls, config) + + if issubclass(cls, pl.LightningDataModule): + return _load_state(cls, checkpoint, **config) + if issubclass(cls, pl.LightningModule): + storage = _load_state(cls, checkpoint, strict=strict, config=config) + state_dict = checkpoint["state_dict"] + if not state_dict: + raise ValueError(f"The state dict in {checkpoint_path!r} contains no parameters.") + map_location = list(state_dict.values())[0].device + assert isinstance(storage, pl.LightningModule) + return storage.to(map_location) + + raise NotImplementedError(f"Unsupported 
{cls}") + + +def cal_polygon_contour(x, y, theta, width, length): + left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta) + left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta) + left_front = np.column_stack((left_front_x, left_front_y)) + + right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta) + right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta) + right_front = np.column_stack((right_front_x, right_front_y)) + + right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta) + right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta) + right_back = np.column_stack((right_back_x, right_back_y)) + + left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta) + left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta) + left_back = np.column_stack((left_back_x, left_back_y)) + + polygon_contour = np.concatenate( + (left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1 + ) + + return polygon_contour + + +def cal_polygon_contour_torch(x, y, theta, width, length): + # Calculate corner points using torch operations + left_front_x = x + 0.5 * length * torch.cos(theta) - 0.5 * width * torch.sin(theta) + left_front_y = y + 0.5 * length * torch.sin(theta) + 0.5 * width * torch.cos(theta) + left_front = torch.stack((left_front_x, left_front_y), dim=-1) + + right_front_x = x + 0.5 * length * torch.cos(theta) + 0.5 * width * torch.sin(theta) + right_front_y = y + 0.5 * length * torch.sin(theta) - 0.5 * width * torch.cos(theta) + right_front = torch.stack((right_front_x, right_front_y), dim=-1) + + right_back_x = x - 0.5 * length * torch.cos(theta) + 0.5 * width * torch.sin(theta) + right_back_y = y - 0.5 * length * torch.sin(theta) - 0.5 * width * torch.cos(theta) + right_back = torch.stack((right_back_x, right_back_y), dim=-1) + + left_back_x = x - 0.5 * 
length * torch.cos(theta) - 0.5 * width * torch.sin(theta) + left_back_y = y - 0.5 * length * torch.sin(theta) + 0.5 * width * torch.cos(theta) + left_back = torch.stack((left_back_x, left_back_y), dim=-1) + + # Stack all corner points into the desired shape (N, 4, 2) + polygon_contour = torch.stack((left_front, right_front, right_back, left_back), dim=-2) + + return polygon_contour + + +def checkpoint_surgery_func(checkpoint, model_class, config): + m = None + + # Update 2025-05-05: + if config.MODEL.NAME == "scenestreamer" and config.get("SCENESTREAMER_NO_TG", False) is False: + if "model.trafficgen_head.offset_token_embedding.tokens.weight" in checkpoint["state_dict"]: + if checkpoint["state_dict"]["model.trafficgen_head.offset_token_embedding.tokens.weight"].shape[0] == 81: + if m is None: + m = model_class(config) + default_params = m.state_dict() + new_weight = default_params["model.trafficgen_head.offset_token_embedding.tokens.weight"] + # Remove the 81 (-1 id) + new_weight[:81] = checkpoint["state_dict"]["model.trafficgen_head.offset_token_embedding.tokens.weight"] + checkpoint["state_dict"]["model.trafficgen_head.offset_token_embedding.tokens.weight"] = new_weight + print("====================================") + print("[WARNING] The trafficgen_head.offset_token_embedding.tokens is not found in the checkpoint.") + print("Using initial parameters for trafficgen_head.offset_token_embedding.tokens.") + print("Writing: ", "model.trafficgen_head.offset_token_embedding.tokens.weight", new_weight.shape) + print("====================================") + + # Update 2025-04-28: + if config.MODEL.NAME == "scenestreamer" and config.MODEL.USE_MOTION_HEAD_PRENORM is True: + if "model.motion_prenorm.weight" not in checkpoint["state_dict"]: + print("====================================") + print("[WARNING] The motion_prenorm is not found in the checkpoint.") + print("Using initial parameters for motion_prenorm.") + if m is None: + m = model_class(config) + default_params 
= m.state_dict() + for k, v in default_params.items(): + if "motion_prenorm" in k: + assert k not in checkpoint["state_dict"], k + print("Writing: ", k, v.shape) + checkpoint["state_dict"][k] = v + print("====================================") + + if config.get("SCENESTREAMER_NO_TG", True) is False: + if "model.trafficgen_prenorm.weight" not in checkpoint["state_dict"]: + print("====================================") + print("[WARNING] The trafficgen_prenorm is not found in the checkpoint.") + print("Using initial parameters for trafficgen_prenorm.") + if m is None: + m = model_class(config) + default_params = m.state_dict() + for k, v in default_params.items(): + if "trafficgen_prenorm" in k: + assert k not in checkpoint["state_dict"], k + print("Writing: ", k, v.shape) + checkpoint["state_dict"][k] = v + print("====================================") + + # Update 2025-04-25: + # If the pretrained checkpoint is SCENESTREAMER_NO_TG=True, we need to add the trafficgen decoder. + if config.MODEL.NAME == "scenestreamer" and config.get("SCENESTREAMER_NO_TG", False) is False: + if "model.trafficgen_intra_step.tokens.weight" not in checkpoint["state_dict"]: + print("====================================") + print("[WARNING] The trafficgen decoder is not found in the checkpoint.") + print("Using initial parameters for trafficgen decoder.") + if m is None: + m = model_class(config) + default_params = m.state_dict() + for k, v in default_params.items(): + if "trafficgen_intra_step" in k or "model.trafficgen_head" in k or "model.trafficgen_feat_embed" in k: + assert k not in checkpoint["state_dict"], k + print("Writing: ", k, v.shape) + checkpoint["state_dict"][k] = v + print("====================================") + + # # 2025-04-22: Add model.relation_embed_3d + # if config.MODEL.NAME == "scenestreamer": + # if "model.relation_embed_3d.freqs.weight" not in checkpoint["state_dict"]: + # print("====================================") + # print("[WARNING] The relation_embed_3d 
is not found in the checkpoint.") + # print("Using initial parameters for relation_embed_3d.") + # if m is None: + # m = model_class(config) + # default_params = m.state_dict() + # for k, v in default_params.items(): + # if "relation_embed_3d" in k: + # assert k not in checkpoint["state_dict"], k + # print("Writing: ", k, v.shape) + # checkpoint["state_dict"][k] = v + # print("====================================") + # + # if config.MODEL.NAME == "scenestreamer" and config.get("SCENESTREAMER_NO_TG", False) is True: + # if "model.relation_embed_1d.freqs.weight" not in checkpoint["state_dict"]: + # print("====================================") + # print("[WARNING] The relation_embed_1d is not found in the checkpoint.") + # print("Using initial parameters for relation_embed_1d.") + # if m is None: + # m = model_class(config) + # default_params = m.state_dict() + # for k, v in default_params.items(): + # if "relation_embed_1d" in k: + # assert k not in checkpoint["state_dict"], k + # print("Writing: ", k, v.shape) + # checkpoint["state_dict"][k] = v + # print("====================================") + # if checkpoint["state_dict"]["model.map_id_embed.tokens.weight"].shape[0] == 3007: + # print("====================================") + # print("[WARNING] The map_id_embed is not found in the checkpoint.") + # print("Using initial parameters for map_id_embed.") + # if m is None: + # m = model_class(config) + # default_params = m.state_dict() + # new_weight = default_params["model.map_id_embed.tokens.weight"] + # # Remove the 3007 (-1 id) + # new_weight[:3006] = checkpoint["state_dict"]["model.map_id_embed.tokens.weight"][:-1] + # checkpoint["state_dict"]["model.map_id_embed.tokens.weight"] = new_weight + # print("====================================") + # + # # Update 2025-04-16: + # # If the pretrained checkpoint is SCENESTREAMER_NO_TG=True, we need to add the relation_embed_1d. 
+ # if config.MODEL.NAME == "scenestreamer" and config.get("SCENESTREAMER_NO_TG", False) is False: + # if "model.relation_embed_1d.freqs.weight" not in checkpoint["state_dict"]: + # print("====================================") + # print("[WARNING] The relation_embed_1d is not found in the checkpoint.") + # print("Using initial parameters for relation_embed_1d.") + # if m is None: + # m = model_class(config) + # default_params = m.state_dict() + # for k, v in default_params.items(): + # if "relation_embed_1d" in k: + # assert k not in checkpoint["state_dict"], k + # print("Writing: ", k, v.shape) + # checkpoint["state_dict"][k] = v + # print("====================================") + # + # # Update 2025-04-13: + # # If the pretrained checkpoint is SCENESTREAMER_NO_TG=True, we need to add the trafficgen decoder. + # if config.MODEL.NAME == "scenestreamer" and config.get("SCENESTREAMER_NO_TG", False) is False: + # if "model.trafficgen_intra_step.tokens.weight" not in checkpoint["state_dict"]: + # print("====================================") + # print("[WARNING] The trafficgen decoder is not found in the checkpoint.") + # print("Using initial parameters for trafficgen decoder.") + # if m is None: + # m = model_class(config) + # default_params = m.state_dict() + # for k, v in default_params.items(): + # if "trafficgen_intra_step" in k or "model.trafficgen_head" in k or "model.trafficgen_feat_embed" in k: + # assert k not in checkpoint["state_dict"], k + # print("Writing: ", k, v.shape) + # checkpoint["state_dict"][k] = v + # print("====================================") + # # Update 2025-04-14: + # # The map id vocabulary lacks a "trafficgen_action_sos_id". 
+ # if checkpoint["state_dict"]["model.map_id_embed.tokens.weight"].shape[0] == 3007: + # print("====================================") + # print("[WARNING] The map_id_embed is not found in the checkpoint.") + # print("Using initial parameters for map_id_embed.") + # if m is None: + # m = model_class(config) + # default_params = m.state_dict() + # new_weight = default_params["model.map_id_embed.tokens.weight"] + # # Remove the 3007 (-1 id) + # new_weight[:3006] = checkpoint["state_dict"]["model.map_id_embed.tokens.weight"][:-1] + # checkpoint["state_dict"]["model.map_id_embed.tokens.weight"] = new_weight + # print("====================================") + # + # # Update 2025-04-06: + # # We make a bug that the max agents is not set to 128. Fix it here. + # if "model.agent_id_embed.tokens.weight" in checkpoint["state_dict"]: + # if checkpoint["state_dict"]["model.agent_id_embed.tokens.weight"].shape[0] != 128 + 1: + # print("====================================") + # print("[WARNING] The agent_id_embed is not set to 128.") + # print("Changing it to 128.") + # old_weight = checkpoint["state_dict"]["model.agent_id_embed.tokens.weight"] + # new_weight = old_weight.new_zeros(128 + 1, old_weight.shape[1]) + # new_weight[:old_weight.shape[0], :] = old_weight + # checkpoint["state_dict"]["model.agent_id_embed.tokens.weight"] = new_weight + # print("====================================") + # if "model.trafficgen_intra_step.tokens.weight" in checkpoint["state_dict"]: + # if checkpoint["state_dict"]["model.trafficgen_intra_step.tokens.weight"].shape[0] != 512 + 2: + # print("====================================") + # print("[WARNING] The trafficgen_intra_step is not set to 514.") + # print("Changing it to 514.") + # old_weight = checkpoint["state_dict"]["model.trafficgen_intra_step.tokens.weight"] + # new_weight = old_weight.new_zeros(512 + 2, old_weight.shape[1]) + # new_weight[:old_weight.shape[0], :] = old_weight + # 
checkpoint["state_dict"]["model.trafficgen_intra_step.tokens.weight"] = new_weight + # print("====================================") + # + # # Update 2025-03-11: + # # Another update, we have a checkpoint trained with TG but now want to finetune with destination without TG. + # if not config["USE_TRAFFICGEN"]: + # if "model.trafficgen_decoder.decoder.layers.0.cross_a2a.to_v.bias" in checkpoint["state_dict"]: + # print("====================================") + # print("[WARNING] The trafficgen decoder is found in the checkpoint.") + # print("Removing trafficgen decoder from the checkpoint.") + # for k in list(checkpoint["state_dict"].keys()): + # if "trafficgen_decoder" in k: + # print("Removing: ", k) + # checkpoint["state_dict"].pop(k) + # print("====================================") + # + # # Update 2025-03-11: + # if config["USE_DESTINATION"]: + # if "model.motion_decoder.decoder.layers.0.a2s_norm.gamma_proj.weight" not in checkpoint["state_dict"]: + # print("====================================") + # print("[WARNING] The AdaLN is not found in the checkpoint!!!") + # print("Using initial parameters for step_embed.") + # if m is None: + # m = model_class(config) + # default_params = m.state_dict() + # for k, v in default_params.items(): + # if ("gamma_proj" in k) or ("beta_proj" in k) or ("map_id_embed" in k): + # assert k not in checkpoint["state_dict"], k + # print("Writing: ", k, v.shape) + # checkpoint["state_dict"][k] = v + # + # # Delete some weights + # for k in list(checkpoint["state_dict"].keys()): + # if "motion_decoder" in k and (("a2s_norm.weight" in k) or ("a2s_norm.bias" in k) or ( + # "a2t_norm.weight" in k) or ("a2t_norm.bias" in k) or ("a2a_norm.weight" in k) or ( + # "a2a_norm.bias" in k) or ("mlp_prenorm.weight" in k) or ("mlp_prenorm.bias" in k) or ("prediction_prenorm.weight" in k) or ("prediction_prenorm.bias" in k)): + # assert k in checkpoint["state_dict"], k + # print("Removing: ", k) + # checkpoint["state_dict"].pop(k) + # 
print("====================================") + # + # # Update 2025-02-28: + # # If model is TG but there is no step_embed, use initial parameters for step_embed. + # if config["USE_TRAFFICGEN"]: + # if "model.trafficgen_decoder.step_embed.tokens.weight" not in checkpoint["state_dict"]: + # print("====================================") + # print("[WARNING] The step_embed is not found in the checkpoint.") + # print("Using initial parameters for step_embed.") + # if m is None: + # m = model_class(config) + # default_params = m.state_dict() + # for k, v in default_params.items(): + # if "trafficgen_decoder.step_embed" in k: + # assert k not in checkpoint["state_dict"], k + # print("Writing: ", k, v.shape) + # checkpoint["state_dict"][k] = v + # print("====================================") + # + # # Update 2025-02-26: + # # If finetune with trafficgen, use initial parameters for trafficgen related modules. + # if config["USE_TRAFFICGEN"]: + # if "model.trafficgen_decoder.action_embed.freqs.weight" not in checkpoint["state_dict"]: + # print("====================================") + # print("[WARNING] The trafficgen decoder is not found in the checkpoint.") + # print("Using initial parameters for trafficgen decoder.") + # if m is None: + # m = model_class(config) + # default_params = m.state_dict() + # for k, v in default_params.items(): + # if "trafficgen_decoder" in k: + # assert k not in checkpoint["state_dict"], k + # print("Writing: ", k, v.shape) + # checkpoint["state_dict"][k] = v + # print("====================================") + # + # # Update 2024-11-02: + # # Use initial parameters for backward prediction related modules. 
+ # if config["BACKWARD_PREDICTION"]: + # if "model.motion_decoder.prediction_backward_head.0.weight" not in checkpoint["state_dict"]: + # print("====================================") + # print("[WARNING] The backward prediction head is not found in the checkpoint.") + # print("Using initial parameters for backward prediction head.") + # if m is None: + # m = model_class(config) + # default_params = m.state_dict() + # for k, v in default_params.items(): + # if "backward" in k: + # assert k not in checkpoint["state_dict"], k + # print("Writing: ", k, v.shape) + # checkpoint["state_dict"][k] = v + # print("====================================") + # + # # Update 2024-10-26: + # # This is used for 1026 model which use 1.75/1.75/1.75 delta-delta tokenizer but + # # the "motion_features" still using type-specific delta-delta. + # if (checkpoint["hyper_parameters"]["TOKENIZATION"]["VEH_LIMIT"] == 3.5 + # and checkpoint["hyper_parameters"]["TOKENIZATION"]["CYC_LIMIT"] == 3.5 + # and checkpoint["hyper_parameters"]["TOKENIZATION"]["PED_LIMIT"] == 3.5): + # if "model.motion_decoder.motion_features" in checkpoint["state_dict"]: + # motion_features = checkpoint["state_dict"]["model.motion_decoder.motion_features"] + # if motion_features.ndim == 3 and motion_features.shape[1] == 3: + # assert (motion_features[:, 0] == motion_features[:, 1]).all() + # assert (motion_features[:, 0] == motion_features[:, 2]).all() + # motion_features = motion_features[:, 0] + # checkpoint["state_dict"]["model.motion_decoder.motion_features"] = motion_features + + return checkpoint + + +def get_model( + *, + huggingface_repo=None, + huggingface_file=None, + config=None, + checkpoint_path=None, + device=None, + default_config="motion_default.yaml" +): + if device is None: + if torch.cuda.is_available(): + device = torch.device("cuda") + elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + 
+ if huggingface_repo is not None: + assert config is None, "config should not be provided when huggingface_path is provided." + assert checkpoint_path is None, "checkpoint_path should not be provided when huggingface_path is provided." + assert huggingface_file is not None, "huggingface_file must be provided when huggingface_repo is provided." + try: + import huggingface_hub + except ImportError: + raise ValueError("Please install huggingface_hub via: pip install --upgrade huggingface_hub") + print("Downloading checkpoint from Huggingface Hub: ", huggingface_repo, huggingface_file) + checkpoint_path = huggingface_hub.hf_hub_download(repo_id=huggingface_repo, filename=huggingface_file) + print("Downloaded checkpoint from Huggingface Hub: ", checkpoint_path) + + assert config is not None or checkpoint_path is not None, "Either config or checkpoint_path must be provided." + from scenestreamer.models.motionlm_lightning import MotionLMLightning + + from scenestreamer.utils.config import global_config, cfg_from_yaml_file + default_config = cfg_from_yaml_file(REPO_ROOT / "cfgs" / default_config, global_config) + + pretrained_path_from_config = config.pretrain if config is not None else None + pretrained_path_from_arg = checkpoint_path + assert pretrained_path_from_config is None or pretrained_path_from_arg is None, ( + "Both pretrained path from config and from argument are provided." 
+ ) + pretrained_path = pretrained_path_from_config or pretrained_path_from_arg + if pretrained_path: + pretrained_path = pathlib.Path(pretrained_path).expanduser() + pretrained_path = REPO_ROOT / pretrained_path + if pretrained_path.is_dir(): + if (pretrained_path / "last.ckpt").exists(): + pretrained_path = pretrained_path / "last.ckpt" + # If only one file: + elif len(list(pretrained_path.glob("*.ckpt"))) == 1: + pretrained_path = list(pretrained_path.glob("*.ckpt"))[0] + else: + raise ValueError( + "Please provide a checkpoint file or a directory with only one checkpoint file or contains the " + "last.ckpt." + ) + pretrained_path = str(pretrained_path.absolute()) # Don't call resolve() here as huggingface_hub will fail. + assert os.path.isfile(pretrained_path), pretrained_path + assert pretrained_path.endswith(".ckpt"), pretrained_path + print("==============================") + print("Loading pretrained model: ", pretrained_path) + print("==============================") + + model = load_from_checkpoint( + checkpoint_path=pretrained_path, + cls=MotionLMLightning, + config=config, + default_config=default_config, + strict=True, + checkpoint_surgery_func=checkpoint_surgery_func, + # For non-CUDA targets (CPU/MPS), first load weights onto CPU and then move. + # This avoids torch.load failing when the checkpoint contains CUDA storages. 
+ map_location=device if (isinstance(device, torch.device) and device.type == "cuda") or str(device).startswith("cuda") else torch.device("cpu"), + ) + + else: + model = MotionLMLightning(config=config) + + if device is not None: + model.to(device) + + model.eval() + + return model + + +def repeat_for_modes(v, num_modes): + if isinstance(v, list): + return v + d = v.ndim + if d > 1: + v = v[:, None] + if isinstance(v, np.ndarray): + shape = v.shape + v = v.repeat(num_modes, axis=1) + v = v.reshape(-1, *(shape[2:])) + else: + v = v.repeat(1, num_modes, *((1, ) * (d - 1))) + v = v.flatten(0, 1) + else: + v = v.reshape(-1, 1) + if isinstance(v, np.ndarray): + v = v.repeat(num_modes, axis=1) + else: + v = v.repeat(1, num_modes) + v = v.reshape(-1) + return v + + +def expand_for_modes(v, num_modes): + + if isinstance(v, dict): + ret = {} + for k, vv in v.items(): + if isinstance(vv, torch.Tensor): + ret[k] = expand_for_modes(vv, num_modes) + else: + ret[k] = vv + return ret + + assert isinstance(v, torch.Tensor), "Only torch.Tensor is supported. Found: {}".format(type(v)) + d = v.ndim + if d > 1: + v = v[:, None] + v = v.expand(-1, num_modes, *((-1, ) * (d - 1))) + v = v.flatten(0, 1) + else: + v = v.reshape(-1, 1) + v = v.expand(-1, num_modes) + v = v.reshape(-1) + return v + + +def numpy_to_torch(v, device=None): + if isinstance(v, dict): + return {k: numpy_to_torch(vv, device) for k, vv in v.items()} + + if isinstance(v, list) or isinstance(v, (float, int)): + v = np.array(v) + + # Skip conversion for strings + if isinstance(v, str): + return v + + # Convert numpy arrays to torch tensors + if isinstance(v, np.ndarray): + if np.issubdtype(v.dtype, np.number) or v.dtype == bool: + v = torch.from_numpy(v) + + # Move tensor to the specified device if provided + if isinstance(v, torch.Tensor) and device is not None: + # MPS doesn't support float64; cast to float32 before moving. 
+ if (isinstance(device, torch.device) and device.type == "mps") or str(device).startswith("mps"): + if v.dtype == torch.float64: + v = v.float() + v = v.to(device) + + return v + + +def batch_data(data_dict): + """Add one additional dimension to all values in the dictionary.""" + ret = {} + for k, v in data_dict.items(): + if isinstance(v, np.ndarray): + ret[k] = v[None] + elif isinstance(v, torch.Tensor): + ret[k] = v[None] + else: + ret[k] = v + return ret + + +def unbatch_data(data_dict): + """Remove the additional dimension from all values in the dictionary.""" + ret = {} + first_key = list(data_dict.keys())[0] + B = data_dict[first_key].shape[0] + for k, v in data_dict.items(): + if "track_name" in k and v.ndim == 1: + ret[k] = v + elif isinstance(v, np.ndarray): + assert v.shape[0] == B, f"Shape mismatch for {k}: {v.shape[0]} vs {B}" + ret[k] = v[0] + elif isinstance(v, torch.Tensor): + assert v.shape[0] == B, f"Shape mismatch for {k}: {v.shape[0]} vs {B}" + ret[k] = v[0] + else: + ret[k] = v + return ret + + +def torch_to_numpy(v): + if isinstance(v, dict): + return {k: torch_to_numpy(vv) for k, vv in v.items()} + + if isinstance(v, list): + v = np.array([torch_to_numpy(vv) for vv in v]) + + if isinstance(v, (float, int)): + v = np.array(v) + + # Convert torch tensors to numpy arrays + if isinstance(v, torch.Tensor): + v = v.detach().cpu().numpy() + + return v + + +class SafeFallbackEncoder(json.JSONEncoder): + def __init__(self, nan_str="null", **kwargs): + super(SafeFallbackEncoder, self).__init__(**kwargs) + self.nan_str = nan_str + + def default(self, value): + try: + if np.isnan(value): + return self.nan_str + + if (type(value).__module__ == np.__name__ and isinstance(value, np.ndarray)): + return value.tolist() + + if issubclass(type(value), numbers.Integral): + return int(value) + if issubclass(type(value), numbers.Number): + return float(value) + + return super(SafeFallbackEncoder, self).default(value) + + except Exception: + return str(value) # 
give up, just stringify it (ok for logs) + + +def pretty_print(result, prefix=""): + """ + Should call print(pretty_print(result)) to print the result in a human-readable format. + """ + result = result.copy() + result = {prefix + k: v for k, v in result.items()} + cleaned = json.dumps(result, cls=SafeFallbackEncoder) + return yaml.safe_dump(json.loads(cleaned), default_flow_style=False) + + +rank_zero_print = rank_zero_only(print) + +def get_relative_velocity(vel, heading): + return rotate(vel[..., 0], vel[..., 1], angle=-heading) + + +def safe_entropy(logits, epsilon=1e-5): + """ + Computes the entropy of the given logits safely by replacing NaN and Inf values. + :param logits: Input logits tensor. + :param epsilon: A small value to add to the logits to avoid log(0) which results in NaN. + :return: Mean entropy of the logits. + """ + # Replace NaN and Inf values in logits to avoid errors in entropy computation + logits = torch.where(torch.isnan(logits), torch.zeros_like(logits), logits) + logits = torch.where(torch.isinf(logits), torch.zeros_like(logits), logits) + + # Adding a small epsilon to logits to avoid log(0) + logits = logits + epsilon + + # Compute softmax to get probabilities + probs = F.softmax(logits, dim=-1) + + # Compute entropy + entropy = -(probs * torch.log(probs)).sum(-1) + + # Return the mean entropy + return entropy.mean() \ No newline at end of file