diff --git a/third_party/PointFlowMatch/bash/collect_data.sh b/third_party/PointFlowMatch/bash/collect_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3c777c378b707c2325d8096fb569ad00369f9370
--- /dev/null
+++ b/third_party/PointFlowMatch/bash/collect_data.sh
@@ -0,0 +1,23 @@
+python scripts/collect_demos.py --config-name=collect_demos_train save_data=True env_config.vis=False env_config.task_name=unplug_charger
+python scripts/collect_demos.py --config-name=collect_demos_valid save_data=True env_config.vis=False env_config.task_name=unplug_charger
+
+python scripts/collect_demos.py --config-name=collect_demos_train save_data=True env_config.vis=False env_config.task_name=close_door
+python scripts/collect_demos.py --config-name=collect_demos_valid save_data=True env_config.vis=False env_config.task_name=close_door
+
+python scripts/collect_demos.py --config-name=collect_demos_train save_data=True env_config.vis=False env_config.task_name=open_box
+python scripts/collect_demos.py --config-name=collect_demos_valid save_data=True env_config.vis=False env_config.task_name=open_box
+
+python scripts/collect_demos.py --config-name=collect_demos_train save_data=True env_config.vis=False env_config.task_name=open_fridge
+python scripts/collect_demos.py --config-name=collect_demos_valid save_data=True env_config.vis=False env_config.task_name=open_fridge
+
+python scripts/collect_demos.py --config-name=collect_demos_train save_data=True env_config.vis=False env_config.task_name=take_frame_off_hanger
+python scripts/collect_demos.py --config-name=collect_demos_valid save_data=True env_config.vis=False env_config.task_name=take_frame_off_hanger
+
+python scripts/collect_demos.py --config-name=collect_demos_train save_data=True env_config.vis=False env_config.task_name=open_oven
+python scripts/collect_demos.py --config-name=collect_demos_valid save_data=True env_config.vis=False env_config.task_name=open_oven
+
+python scripts/collect_demos.py --config-name=collect_demos_train save_data=True env_config.vis=False env_config.task_name=put_books_on_bookshelf
+python scripts/collect_demos.py --config-name=collect_demos_valid save_data=True env_config.vis=False env_config.task_name=put_books_on_bookshelf
+
+python scripts/collect_demos.py --config-name=collect_demos_train save_data=True env_config.vis=False env_config.task_name=take_shoes_out_of_box
+python scripts/collect_demos.py --config-name=collect_demos_valid save_data=True env_config.vis=False env_config.task_name=take_shoes_out_of_box
\ No newline at end of file
diff --git a/third_party/PointFlowMatch/bash/install_deps.sh b/third_party/PointFlowMatch/bash/install_deps.sh
new file mode 100644
index 0000000000000000000000000000000000000000..379863433fe9313a2e57fb233b04ce3e77e5d9c0
--- /dev/null
+++ b/third_party/PointFlowMatch/bash/install_deps.sh
@@ -0,0 +1,15 @@
+if ! [[ -n "${CONDA_PREFIX}" ]]; then
+ echo "You are not inside a conda environment. Please activate your environment first."
+ exit 1
+fi
+
+pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
+pip install fvcore iopath
+pip install --no-index --no-cache-dir pytorch3d==0.7.5 -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt212/download.html
+# Or if on cpu:
+# pip3 install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cpu
+pip install -e .
+rm -R *.egg-info
+
+# Pypose
+pip install --no-deps pypose
diff --git a/third_party/PointFlowMatch/bash/install_rlbench.sh b/third_party/PointFlowMatch/bash/install_rlbench.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c944e3c89af8045191ec458c4ca8d493d58b916a
--- /dev/null
+++ b/third_party/PointFlowMatch/bash/install_rlbench.sh
@@ -0,0 +1,21 @@
+if ! [[ -n "${CONDA_PREFIX}" ]]; then
+ echo "You are not inside a conda environment. Please activate your environment first."
+ exit 1
+fi
+
+if ! [[ -n "${COPPELIASIM_ROOT}" ]]; then
+ echo "COPPELIASIM_ROOT is not defined."
+ exit 1
+fi
+
+# Download Coppelia sim if not present
+if ! [[-e $COPPELIASIM_ROOT]]; then
+ wget https://downloads.coppeliarobotics.com/V4_1_0/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz
+ mkdir -p $COPPELIASIM_ROOT && tar -xf CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz -C $COPPELIASIM_ROOT --strip-components 1
+ rm -rf CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz
+fi
+
+# Install PyRep and RLBench
+pip install -r https://raw.githubusercontent.com/stepjam/PyRep/master/requirements.txt
+pip install git+https://github.com/stepjam/PyRep.git
+pip install git+https://github.com/stepjam/RLBench.git
diff --git a/third_party/PointFlowMatch/bash/start_eval.sh b/third_party/PointFlowMatch/bash/start_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4e0b77713fc42d1343994b9742fda204a1616e19
--- /dev/null
+++ b/third_party/PointFlowMatch/bash/start_eval.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+gpu_id=$1
+ckpt_name=$2
+k_steps=${3:-50} # Default k_steps is 50
+seed=${4:-0} # Default seed is 0
+if [ $seed -ne 0 ]; then
+ tmux new-session -d -s "eval_${ckpt_name}_k${k_steps}_${seed}"
+ tmux send-keys -t "eval_${ckpt_name}_k${k_steps}_${seed}" "conda activate pfp_env && CUDA_VISIBLE_DEVICES=$gpu_id WANDB__SERVICE_WAIT=300 xvfb-run -a python scripts/evaluate.py log_wandb=True env_runner.env_config.vis=False policy.ckpt_name=$ckpt_name seed=$seed policy.num_k_infer=$k_steps" Enter
+else
+ for seed in 5678 2468 1357; do
+ tmux new-session -d -s "eval_${ckpt_name}_k${k_steps}_${seed}"
+ tmux send-keys -t "eval_${ckpt_name}_k${k_steps}_${seed}" "conda activate pfp_env && CUDA_VISIBLE_DEVICES=$gpu_id WANDB__SERVICE_WAIT=300 xvfb-run -a python scripts/evaluate.py log_wandb=True env_runner.env_config.vis=False policy.ckpt_name=$ckpt_name seed=$seed policy.num_k_infer=$k_steps" Enter
+ done
+fi
\ No newline at end of file
diff --git a/third_party/PointFlowMatch/bash/start_eval_ksteps.sh b/third_party/PointFlowMatch/bash/start_eval_ksteps.sh
new file mode 100644
index 0000000000000000000000000000000000000000..944a5bf920860f5c5b836ccae1261941dcd21fcb
--- /dev/null
+++ b/third_party/PointFlowMatch/bash/start_eval_ksteps.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+gpu_id=$1
+ckpt_name=$2
+seed=${4:-0} # Default seed is 0
+for k_steps in 1 2 4 8; do
+ if [ $seed -ne 0 ]; then
+ tmux new-session -d -s "eval_${ckpt_name}_k${k_steps}_${seed}"
+ tmux send-keys -t "eval_${ckpt_name}_k${k_steps}_${seed}" "conda activate pfp_env && CUDA_VISIBLE_DEVICES=$gpu_id WANDB__SERVICE_WAIT=300 xvfb-run -a python scripts/evaluate.py log_wandb=True env_runner.env_config.vis=False policy.ckpt_name=$ckpt_name seed=$seed policy.num_k_infer=$k_steps" Enter
+ else
+ for seed in 5678 2468 1357; do
+ tmux new-session -d -s "eval_${ckpt_name}_k${k_steps}_${seed}"
+ tmux send-keys -t "eval_${ckpt_name}_k${k_steps}_${seed}" "conda activate pfp_env && CUDA_VISIBLE_DEVICES=$gpu_id WANDB__SERVICE_WAIT=300 xvfb-run -a python scripts/evaluate.py log_wandb=True env_runner.env_config.vis=False policy.ckpt_name=$ckpt_name seed=$seed policy.num_k_infer=$k_steps" Enter
+ done
+ seed=0
+ fi
+done
\ No newline at end of file
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/evaluate.log b/third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/evaluate.log
new file mode 100644
index 0000000000000000000000000000000000000000..f13aae1f5cc9ef247464383b124cad322af05c60
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/evaluate.log
@@ -0,0 +1 @@
+[2026-04-03 00:25:24,802][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/.hydra/config.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/.hydra/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b3727115127431584445aa3d29f0349a1d40819d
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/.hydra/config.yaml
@@ -0,0 +1,14 @@
+seed: 5678
+log_wandb: false
+env_runner:
+ num_episodes: 1
+ max_episode_length: 200
+ verbose: true
+ env_config:
+ voxel_size: 0.01
+ headless: true
+ vis: false
+policy:
+ ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
+ ckpt_episode: ep1500
+ num_k_infer: 50
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/.hydra/hydra.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/.hydra/hydra.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2eea6988110385a8f8719b2049b01b55baad898e
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/.hydra/hydra.yaml
@@ -0,0 +1,159 @@
+hydra:
+ run:
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ sweep:
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ subdir: ${hydra.job.num}
+ launcher:
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
+ sweeper:
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
+ max_batch_size: null
+ params: null
+ help:
+ app_name: ${hydra.job.name}
+ header: '${hydra.help.app_name} is powered by Hydra.
+
+ '
+ footer: 'Powered by Hydra (https://hydra.cc)
+
+ Use --hydra-help to view Hydra specific help
+
+ '
+ template: '${hydra.help.header}
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (group=option)
+
+
+ $APP_CONFIG_GROUPS
+
+
+ == Config ==
+
+ Override anything in the config (foo.bar=value)
+
+
+ $CONFIG
+
+
+ ${hydra.help.footer}
+
+ '
+ hydra_help:
+ template: 'Hydra (${hydra.runtime.version})
+
+ See https://hydra.cc for more info.
+
+
+ == Flags ==
+
+ $FLAGS_HELP
+
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
+ to command line)
+
+
+ $HYDRA_CONFIG_GROUPS
+
+
+ Use ''--cfg hydra'' to Show the Hydra config.
+
+ '
+ hydra_help: ???
+ hydra_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][HYDRA] %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ root:
+ level: INFO
+ handlers:
+ - console
+ loggers:
+ logging_example:
+ level: DEBUG
+ disable_existing_loggers: false
+ job_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ file:
+ class: logging.FileHandler
+ formatter: simple
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
+ root:
+ level: INFO
+ handlers:
+ - console
+ - file
+ disable_existing_loggers: false
+ env: {}
+ mode: RUN
+ searchpath: []
+ callbacks: {}
+ output_subdir: .hydra
+ overrides:
+ hydra:
+ - hydra.mode=RUN
+ task:
+ - log_wandb=False
+ - env_runner.env_config.vis=False
+ - env_runner.num_episodes=1
+ - env_runner.max_episode_length=200
+ - policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+ job:
+ name: evaluate
+ chdir: null
+ override_dirname: env_runner.env_config.vis=False,env_runner.max_episode_length=200,env_runner.num_episodes=1,log_wandb=False,policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+ id: ???
+ num: ???
+ config_name: eval
+ env_set: {}
+ env_copy: []
+ config:
+ override_dirname:
+ kv_sep: '='
+ item_sep: ','
+ exclude_keys: []
+ runtime:
+ version: 1.3.2
+ version_base: '1.3'
+ cwd: /workspace/third_party/PointFlowMatch
+ config_sources:
+ - path: hydra.conf
+ schema: pkg
+ provider: hydra
+ - path: /workspace/third_party/PointFlowMatch/conf
+ schema: file
+ provider: main
+ - path: ''
+ schema: structured
+ provider: schema
+ output_dir: /workspace/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06
+ choices:
+ hydra/env: default
+ hydra/callbacks: null
+ hydra/job_logging: default
+ hydra/hydra_logging: default
+ hydra/hydra_help: default
+ hydra/help: default
+ hydra/sweeper: basic
+ hydra/launcher: basic
+ hydra/output: default
+ verbose: false
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/.hydra/overrides.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/.hydra/overrides.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..159e6d765fd53783c3a7cb4293b974109693445c
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/.hydra/overrides.yaml
@@ -0,0 +1,5 @@
+- log_wandb=False
+- env_runner.env_config.vis=False
+- env_runner.num_episodes=1
+- env_runner.max_episode_length=200
+- policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/evaluate.log b/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/evaluate.log
new file mode 100644
index 0000000000000000000000000000000000000000..a9dff43f1611e20869a470421563160136b83637
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-26-06/evaluate.log
@@ -0,0 +1 @@
+[2026-04-03 00:26:10,208][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/.hydra/config.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/.hydra/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b3727115127431584445aa3d29f0349a1d40819d
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/.hydra/config.yaml
@@ -0,0 +1,14 @@
+seed: 5678
+log_wandb: false
+env_runner:
+ num_episodes: 1
+ max_episode_length: 200
+ verbose: true
+ env_config:
+ voxel_size: 0.01
+ headless: true
+ vis: false
+policy:
+ ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
+ ckpt_episode: ep1500
+ num_k_infer: 50
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/.hydra/hydra.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/.hydra/hydra.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6fd2ae0afb3fd79562503230ebb229efcbb0aaa6
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/.hydra/hydra.yaml
@@ -0,0 +1,159 @@
+hydra:
+ run:
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ sweep:
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ subdir: ${hydra.job.num}
+ launcher:
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
+ sweeper:
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
+ max_batch_size: null
+ params: null
+ help:
+ app_name: ${hydra.job.name}
+ header: '${hydra.help.app_name} is powered by Hydra.
+
+ '
+ footer: 'Powered by Hydra (https://hydra.cc)
+
+ Use --hydra-help to view Hydra specific help
+
+ '
+ template: '${hydra.help.header}
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (group=option)
+
+
+ $APP_CONFIG_GROUPS
+
+
+ == Config ==
+
+ Override anything in the config (foo.bar=value)
+
+
+ $CONFIG
+
+
+ ${hydra.help.footer}
+
+ '
+ hydra_help:
+ template: 'Hydra (${hydra.runtime.version})
+
+ See https://hydra.cc for more info.
+
+
+ == Flags ==
+
+ $FLAGS_HELP
+
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
+ to command line)
+
+
+ $HYDRA_CONFIG_GROUPS
+
+
+ Use ''--cfg hydra'' to Show the Hydra config.
+
+ '
+ hydra_help: ???
+ hydra_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][HYDRA] %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ root:
+ level: INFO
+ handlers:
+ - console
+ loggers:
+ logging_example:
+ level: DEBUG
+ disable_existing_loggers: false
+ job_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ file:
+ class: logging.FileHandler
+ formatter: simple
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
+ root:
+ level: INFO
+ handlers:
+ - console
+ - file
+ disable_existing_loggers: false
+ env: {}
+ mode: RUN
+ searchpath: []
+ callbacks: {}
+ output_subdir: .hydra
+ overrides:
+ hydra:
+ - hydra.mode=RUN
+ task:
+ - log_wandb=False
+ - env_runner.env_config.vis=False
+ - env_runner.num_episodes=1
+ - env_runner.max_episode_length=200
+ - policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+ job:
+ name: evaluate
+ chdir: null
+ override_dirname: env_runner.env_config.vis=False,env_runner.max_episode_length=200,env_runner.num_episodes=1,log_wandb=False,policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+ id: ???
+ num: ???
+ config_name: eval
+ env_set: {}
+ env_copy: []
+ config:
+ override_dirname:
+ kv_sep: '='
+ item_sep: ','
+ exclude_keys: []
+ runtime:
+ version: 1.3.2
+ version_base: '1.3'
+ cwd: /workspace/third_party/PointFlowMatch
+ config_sources:
+ - path: hydra.conf
+ schema: pkg
+ provider: hydra
+ - path: /workspace/third_party/PointFlowMatch/conf
+ schema: file
+ provider: main
+ - path: ''
+ schema: structured
+ provider: schema
+ output_dir: /workspace/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35
+ choices:
+ hydra/env: default
+ hydra/callbacks: null
+ hydra/job_logging: default
+ hydra/hydra_logging: default
+ hydra/hydra_help: default
+ hydra/help: default
+ hydra/sweeper: basic
+ hydra/launcher: basic
+ hydra/output: default
+ verbose: false
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/.hydra/overrides.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/.hydra/overrides.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..159e6d765fd53783c3a7cb4293b974109693445c
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/.hydra/overrides.yaml
@@ -0,0 +1,5 @@
+- log_wandb=False
+- env_runner.env_config.vis=False
+- env_runner.num_episodes=1
+- env_runner.max_episode_length=200
+- policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/evaluate.log b/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/evaluate.log
new file mode 100644
index 0000000000000000000000000000000000000000..ff0cc703cdf78ad46cfdf81ce0ae96678cbf9d12
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-28-35/evaluate.log
@@ -0,0 +1,2 @@
+[2026-04-03 00:28:38,193][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
+[2026-04-03 00:28:40,034][root][WARNING] - single robot
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/.hydra/config.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/.hydra/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b3727115127431584445aa3d29f0349a1d40819d
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/.hydra/config.yaml
@@ -0,0 +1,14 @@
+seed: 5678
+log_wandb: false
+env_runner:
+ num_episodes: 1
+ max_episode_length: 200
+ verbose: true
+ env_config:
+ voxel_size: 0.01
+ headless: true
+ vis: false
+policy:
+ ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
+ ckpt_episode: ep1500
+ num_k_infer: 50
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/.hydra/hydra.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/.hydra/hydra.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..30aea46b1ae8442620e9fe151d9f3c5452c12da3
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/.hydra/hydra.yaml
@@ -0,0 +1,159 @@
+hydra:
+ run:
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ sweep:
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ subdir: ${hydra.job.num}
+ launcher:
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
+ sweeper:
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
+ max_batch_size: null
+ params: null
+ help:
+ app_name: ${hydra.job.name}
+ header: '${hydra.help.app_name} is powered by Hydra.
+
+ '
+ footer: 'Powered by Hydra (https://hydra.cc)
+
+ Use --hydra-help to view Hydra specific help
+
+ '
+ template: '${hydra.help.header}
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (group=option)
+
+
+ $APP_CONFIG_GROUPS
+
+
+ == Config ==
+
+ Override anything in the config (foo.bar=value)
+
+
+ $CONFIG
+
+
+ ${hydra.help.footer}
+
+ '
+ hydra_help:
+ template: 'Hydra (${hydra.runtime.version})
+
+ See https://hydra.cc for more info.
+
+
+ == Flags ==
+
+ $FLAGS_HELP
+
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
+ to command line)
+
+
+ $HYDRA_CONFIG_GROUPS
+
+
+ Use ''--cfg hydra'' to Show the Hydra config.
+
+ '
+ hydra_help: ???
+ hydra_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][HYDRA] %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ root:
+ level: INFO
+ handlers:
+ - console
+ loggers:
+ logging_example:
+ level: DEBUG
+ disable_existing_loggers: false
+ job_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ file:
+ class: logging.FileHandler
+ formatter: simple
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
+ root:
+ level: INFO
+ handlers:
+ - console
+ - file
+ disable_existing_loggers: false
+ env: {}
+ mode: RUN
+ searchpath: []
+ callbacks: {}
+ output_subdir: .hydra
+ overrides:
+ hydra:
+ - hydra.mode=RUN
+ task:
+ - log_wandb=False
+ - env_runner.env_config.vis=False
+ - env_runner.num_episodes=1
+ - env_runner.max_episode_length=200
+ - policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+ job:
+ name: evaluate
+ chdir: null
+ override_dirname: env_runner.env_config.vis=False,env_runner.max_episode_length=200,env_runner.num_episodes=1,log_wandb=False,policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+ id: ???
+ num: ???
+ config_name: eval
+ env_set: {}
+ env_copy: []
+ config:
+ override_dirname:
+ kv_sep: '='
+ item_sep: ','
+ exclude_keys: []
+ runtime:
+ version: 1.3.2
+ version_base: '1.3'
+ cwd: /workspace/third_party/PointFlowMatch
+ config_sources:
+ - path: hydra.conf
+ schema: pkg
+ provider: hydra
+ - path: /workspace/third_party/PointFlowMatch/conf
+ schema: file
+ provider: main
+ - path: ''
+ schema: structured
+ provider: schema
+ output_dir: /workspace/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54
+ choices:
+ hydra/env: default
+ hydra/callbacks: null
+ hydra/job_logging: default
+ hydra/hydra_logging: default
+ hydra/hydra_help: default
+ hydra/help: default
+ hydra/sweeper: basic
+ hydra/launcher: basic
+ hydra/output: default
+ verbose: false
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/.hydra/overrides.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/.hydra/overrides.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..159e6d765fd53783c3a7cb4293b974109693445c
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/.hydra/overrides.yaml
@@ -0,0 +1,5 @@
+- log_wandb=False
+- env_runner.env_config.vis=False
+- env_runner.num_episodes=1
+- env_runner.max_episode_length=200
+- policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/evaluate.log b/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/evaluate.log
new file mode 100644
index 0000000000000000000000000000000000000000..622591bce93d042ba7b533a9d7303e7f01f813b5
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-29-54/evaluate.log
@@ -0,0 +1,2 @@
+[2026-04-03 00:29:57,368][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
+[2026-04-03 00:29:59,155][root][WARNING] - single robot
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/.hydra/config.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/.hydra/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..63d457630bc542bd55c8e49a19df1e1bca0bf905
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/.hydra/config.yaml
@@ -0,0 +1,14 @@
+seed: 5678
+log_wandb: false
+env_runner:
+ num_episodes: 1
+ max_episode_length: 200
+ verbose: true
+ env_config:
+ voxel_size: 0.01
+ headless: true
+ vis: false
+policy:
+ ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
+ ckpt_episode: ep1500
+ num_k_infer: 10
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/.hydra/hydra.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/.hydra/hydra.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..efb6f61aa5adc1d13ca591b8a08a46d1c0895d35
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/.hydra/hydra.yaml
@@ -0,0 +1,160 @@
+hydra:
+ run:
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ sweep:
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ subdir: ${hydra.job.num}
+ launcher:
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
+ sweeper:
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
+ max_batch_size: null
+ params: null
+ help:
+ app_name: ${hydra.job.name}
+ header: '${hydra.help.app_name} is powered by Hydra.
+
+ '
+ footer: 'Powered by Hydra (https://hydra.cc)
+
+ Use --hydra-help to view Hydra specific help
+
+ '
+ template: '${hydra.help.header}
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (group=option)
+
+
+ $APP_CONFIG_GROUPS
+
+
+ == Config ==
+
+ Override anything in the config (foo.bar=value)
+
+
+ $CONFIG
+
+
+ ${hydra.help.footer}
+
+ '
+ hydra_help:
+ template: 'Hydra (${hydra.runtime.version})
+
+ See https://hydra.cc for more info.
+
+
+ == Flags ==
+
+ $FLAGS_HELP
+
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
+ to command line)
+
+
+ $HYDRA_CONFIG_GROUPS
+
+
+ Use ''--cfg hydra'' to Show the Hydra config.
+
+ '
+ hydra_help: ???
+ hydra_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][HYDRA] %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ root:
+ level: INFO
+ handlers:
+ - console
+ loggers:
+ logging_example:
+ level: DEBUG
+ disable_existing_loggers: false
+ job_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ file:
+ class: logging.FileHandler
+ formatter: simple
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
+ root:
+ level: INFO
+ handlers:
+ - console
+ - file
+ disable_existing_loggers: false
+ env: {}
+ mode: RUN
+ searchpath: []
+ callbacks: {}
+ output_subdir: .hydra
+ overrides:
+ hydra:
+ - hydra.mode=RUN
+ task:
+ - log_wandb=False
+ - env_runner.env_config.vis=False
+ - env_runner.num_episodes=1
+ - env_runner.max_episode_length=200
+ - policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+ - policy.num_k_infer=10
+ job:
+ name: evaluate
+ chdir: null
+ override_dirname: env_runner.env_config.vis=False,env_runner.max_episode_length=200,env_runner.num_episodes=1,log_wandb=False,policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka,policy.num_k_infer=10
+ id: ???
+ num: ???
+ config_name: eval
+ env_set: {}
+ env_copy: []
+ config:
+ override_dirname:
+ kv_sep: '='
+ item_sep: ','
+ exclude_keys: []
+ runtime:
+ version: 1.3.2
+ version_base: '1.3'
+ cwd: /workspace/third_party/PointFlowMatch
+ config_sources:
+ - path: hydra.conf
+ schema: pkg
+ provider: hydra
+ - path: /workspace/third_party/PointFlowMatch/conf
+ schema: file
+ provider: main
+ - path: ''
+ schema: structured
+ provider: schema
+ output_dir: /workspace/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05
+ choices:
+ hydra/env: default
+ hydra/callbacks: null
+ hydra/job_logging: default
+ hydra/hydra_logging: default
+ hydra/hydra_help: default
+ hydra/help: default
+ hydra/sweeper: basic
+ hydra/launcher: basic
+ hydra/output: default
+ verbose: false
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/.hydra/overrides.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/.hydra/overrides.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5eeeac32ef56858ba8fe6c42a28cd74d9bed5698
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/.hydra/overrides.yaml
@@ -0,0 +1,6 @@
+- log_wandb=False
+- env_runner.env_config.vis=False
+- env_runner.num_episodes=1
+- env_runner.max_episode_length=200
+- policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+- policy.num_k_infer=10
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/evaluate.log b/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/evaluate.log
new file mode 100644
index 0000000000000000000000000000000000000000..e79bfca314a12359961d7f5af5a8e31d0b1325b8
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-36-05/evaluate.log
@@ -0,0 +1,2 @@
+[2026-04-03 00:36:09,107][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
+[2026-04-03 00:36:11,021][root][WARNING] - single robot
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/.hydra/config.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/.hydra/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..63d457630bc542bd55c8e49a19df1e1bca0bf905
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/.hydra/config.yaml
@@ -0,0 +1,14 @@
+seed: 5678
+log_wandb: false
+env_runner:
+ num_episodes: 1
+ max_episode_length: 200
+ verbose: true
+ env_config:
+ voxel_size: 0.01
+ headless: true
+ vis: false
+policy:
+ ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
+ ckpt_episode: ep1500
+ num_k_infer: 10
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/.hydra/hydra.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/.hydra/hydra.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3134f942214aaeb6b1c0f037e7904007979d4494
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/.hydra/hydra.yaml
@@ -0,0 +1,160 @@
+hydra:
+ run:
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ sweep:
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ subdir: ${hydra.job.num}
+ launcher:
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
+ sweeper:
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
+ max_batch_size: null
+ params: null
+ help:
+ app_name: ${hydra.job.name}
+ header: '${hydra.help.app_name} is powered by Hydra.
+
+ '
+ footer: 'Powered by Hydra (https://hydra.cc)
+
+ Use --hydra-help to view Hydra specific help
+
+ '
+ template: '${hydra.help.header}
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (group=option)
+
+
+ $APP_CONFIG_GROUPS
+
+
+ == Config ==
+
+ Override anything in the config (foo.bar=value)
+
+
+ $CONFIG
+
+
+ ${hydra.help.footer}
+
+ '
+ hydra_help:
+ template: 'Hydra (${hydra.runtime.version})
+
+ See https://hydra.cc for more info.
+
+
+ == Flags ==
+
+ $FLAGS_HELP
+
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
+ to command line)
+
+
+ $HYDRA_CONFIG_GROUPS
+
+
+ Use ''--cfg hydra'' to Show the Hydra config.
+
+ '
+ hydra_help: ???
+ hydra_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][HYDRA] %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ root:
+ level: INFO
+ handlers:
+ - console
+ loggers:
+ logging_example:
+ level: DEBUG
+ disable_existing_loggers: false
+ job_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ file:
+ class: logging.FileHandler
+ formatter: simple
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
+ root:
+ level: INFO
+ handlers:
+ - console
+ - file
+ disable_existing_loggers: false
+ env: {}
+ mode: RUN
+ searchpath: []
+ callbacks: {}
+ output_subdir: .hydra
+ overrides:
+ hydra:
+ - hydra.mode=RUN
+ task:
+ - log_wandb=False
+ - env_runner.env_config.vis=False
+ - env_runner.num_episodes=1
+ - env_runner.max_episode_length=200
+ - policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+ - policy.num_k_infer=10
+ job:
+ name: evaluate
+ chdir: null
+ override_dirname: env_runner.env_config.vis=False,env_runner.max_episode_length=200,env_runner.num_episodes=1,log_wandb=False,policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka,policy.num_k_infer=10
+ id: ???
+ num: ???
+ config_name: eval
+ env_set: {}
+ env_copy: []
+ config:
+ override_dirname:
+ kv_sep: '='
+ item_sep: ','
+ exclude_keys: []
+ runtime:
+ version: 1.3.2
+ version_base: '1.3'
+ cwd: /workspace/third_party/PointFlowMatch
+ config_sources:
+ - path: hydra.conf
+ schema: pkg
+ provider: hydra
+ - path: /workspace/third_party/PointFlowMatch/conf
+ schema: file
+ provider: main
+ - path: ''
+ schema: structured
+ provider: schema
+ output_dir: /workspace/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03
+ choices:
+ hydra/env: default
+ hydra/callbacks: null
+ hydra/job_logging: default
+ hydra/hydra_logging: default
+ hydra/hydra_help: default
+ hydra/help: default
+ hydra/sweeper: basic
+ hydra/launcher: basic
+ hydra/output: default
+ verbose: false
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/.hydra/overrides.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/.hydra/overrides.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5eeeac32ef56858ba8fe6c42a28cd74d9bed5698
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/.hydra/overrides.yaml
@@ -0,0 +1,6 @@
+- log_wandb=False
+- env_runner.env_config.vis=False
+- env_runner.num_episodes=1
+- env_runner.max_episode_length=200
+- policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+- policy.num_k_infer=10
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/evaluate.log b/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/evaluate.log
new file mode 100644
index 0000000000000000000000000000000000000000..69a04f1ca8ad0d528681b6722f9b6a955174c24b
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-39-03/evaluate.log
@@ -0,0 +1,2 @@
+[2026-04-03 00:39:06,341][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
+[2026-04-03 00:39:08,232][root][WARNING] - single robot
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/.hydra/config.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/.hydra/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b3727115127431584445aa3d29f0349a1d40819d
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/.hydra/config.yaml
@@ -0,0 +1,14 @@
+seed: 5678
+log_wandb: false
+env_runner:
+ num_episodes: 1
+ max_episode_length: 200
+ verbose: true
+ env_config:
+ voxel_size: 0.01
+ headless: true
+ vis: false
+policy:
+ ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
+ ckpt_episode: ep1500
+ num_k_infer: 50
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/.hydra/hydra.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/.hydra/hydra.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ae731ee1c1c926db649ccb612c919eb01541d373
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/.hydra/hydra.yaml
@@ -0,0 +1,160 @@
+hydra:
+ run:
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ sweep:
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ subdir: ${hydra.job.num}
+ launcher:
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
+ sweeper:
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
+ max_batch_size: null
+ params: null
+ help:
+ app_name: ${hydra.job.name}
+ header: '${hydra.help.app_name} is powered by Hydra.
+
+ '
+ footer: 'Powered by Hydra (https://hydra.cc)
+
+ Use --hydra-help to view Hydra specific help
+
+ '
+ template: '${hydra.help.header}
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (group=option)
+
+
+ $APP_CONFIG_GROUPS
+
+
+ == Config ==
+
+ Override anything in the config (foo.bar=value)
+
+
+ $CONFIG
+
+
+ ${hydra.help.footer}
+
+ '
+ hydra_help:
+ template: 'Hydra (${hydra.runtime.version})
+
+ See https://hydra.cc for more info.
+
+
+ == Flags ==
+
+ $FLAGS_HELP
+
+
+ == Configuration groups ==
+
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
+ to command line)
+
+
+ $HYDRA_CONFIG_GROUPS
+
+
+ Use ''--cfg hydra'' to Show the Hydra config.
+
+ '
+ hydra_help: ???
+ hydra_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][HYDRA] %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ root:
+ level: INFO
+ handlers:
+ - console
+ loggers:
+ logging_example:
+ level: DEBUG
+ disable_existing_loggers: false
+ job_logging:
+ version: 1
+ formatters:
+ simple:
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
+ handlers:
+ console:
+ class: logging.StreamHandler
+ formatter: simple
+ stream: ext://sys.stdout
+ file:
+ class: logging.FileHandler
+ formatter: simple
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
+ root:
+ level: INFO
+ handlers:
+ - console
+ - file
+ disable_existing_loggers: false
+ env: {}
+ mode: RUN
+ searchpath: []
+ callbacks: {}
+ output_subdir: .hydra
+ overrides:
+ hydra:
+ - hydra.mode=RUN
+ task:
+ - log_wandb=False
+ - env_runner.env_config.vis=False
+ - env_runner.num_episodes=1
+ - env_runner.max_episode_length=200
+ - policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+ - policy.num_k_infer=50
+ job:
+ name: evaluate
+ chdir: null
+ override_dirname: env_runner.env_config.vis=False,env_runner.max_episode_length=200,env_runner.num_episodes=1,log_wandb=False,policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka,policy.num_k_infer=50
+ id: ???
+ num: ???
+ config_name: eval
+ env_set: {}
+ env_copy: []
+ config:
+ override_dirname:
+ kv_sep: '='
+ item_sep: ','
+ exclude_keys: []
+ runtime:
+ version: 1.3.2
+ version_base: '1.3'
+ cwd: /workspace/third_party/PointFlowMatch
+ config_sources:
+ - path: hydra.conf
+ schema: pkg
+ provider: hydra
+ - path: /workspace/third_party/PointFlowMatch/conf
+ schema: file
+ provider: main
+ - path: ''
+ schema: structured
+ provider: schema
+ output_dir: /workspace/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26
+ choices:
+ hydra/env: default
+ hydra/callbacks: null
+ hydra/job_logging: default
+ hydra/hydra_logging: default
+ hydra/hydra_help: default
+ hydra/help: default
+ hydra/sweeper: basic
+ hydra/launcher: basic
+ hydra/output: default
+ verbose: false
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/.hydra/overrides.yaml b/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/.hydra/overrides.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8548d973d2404c1f21a0ef23479d4098b6dd4939
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/.hydra/overrides.yaml
@@ -0,0 +1,6 @@
+- log_wandb=False
+- env_runner.env_config.vis=False
+- env_runner.num_episodes=1
+- env_runner.max_episode_length=200
+- policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
+- policy.num_k_infer=50
diff --git a/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/evaluate.log b/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/evaluate.log
new file mode 100644
index 0000000000000000000000000000000000000000..a8078c93b68bac45549d66d80d87e03250fc3fe4
--- /dev/null
+++ b/third_party/PointFlowMatch/outputs/2026-04-03/00-48-26/evaluate.log
@@ -0,0 +1,2 @@
+[2026-04-03 00:48:33,178][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
+[2026-04-03 00:48:35,019][root][WARNING] - single robot
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/__init__.py b/third_party/diffusion_policy/diffusion_policy/env/kitchen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f553f35d19d476ef9f0bcd549f10dfa12606ed5b
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/__init__.py
@@ -0,0 +1,30 @@
+"""Environments using kitchen and Franka robot."""
+from gym.envs.registration import register
+
+register(
+ id="kitchen-microwave-kettle-light-slider-v0",
+ entry_point="diffusion_policy.env.kitchen.v0:KitchenMicrowaveKettleLightSliderV0",
+ max_episode_steps=280,
+ reward_threshold=1.0,
+)
+
+register(
+ id="kitchen-microwave-kettle-burner-light-v0",
+ entry_point="diffusion_policy.env.kitchen.v0:KitchenMicrowaveKettleBottomBurnerLightV0",
+ max_episode_steps=280,
+ reward_threshold=1.0,
+)
+
+register(
+ id="kitchen-kettle-microwave-light-slider-v0",
+ entry_point="diffusion_policy.env.kitchen.v0:KitchenKettleMicrowaveLightSliderV0",
+ max_episode_steps=280,
+ reward_threshold=1.0,
+)
+
+register(
+ id="kitchen-all-v0",
+ entry_point="diffusion_policy.env.kitchen.v0:KitchenAllV0",
+ max_episode_steps=280,
+ reward_threshold=1.0,
+)
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/base.py b/third_party/diffusion_policy/diffusion_policy/env/kitchen/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..61c9650bad3daeb456785f6493fbb4d154937400
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/base.py
@@ -0,0 +1,153 @@
+import sys
+import os
+# hack to import adept envs
+ADEPT_DIR = os.path.join(os.path.dirname(__file__), 'relay_policy_learning', 'adept_envs')
+sys.path.append(ADEPT_DIR)
+
+import logging
+import numpy as np
+import adept_envs
+from adept_envs.franka.kitchen_multitask_v0 import KitchenTaskRelaxV1
+
+OBS_ELEMENT_INDICES = {
+ "bottom burner": np.array([11, 12]),
+ "top burner": np.array([15, 16]),
+ "light switch": np.array([17, 18]),
+ "slide cabinet": np.array([19]),
+ "hinge cabinet": np.array([20, 21]),
+ "microwave": np.array([22]),
+ "kettle": np.array([23, 24, 25, 26, 27, 28, 29]),
+}
+OBS_ELEMENT_GOALS = {
+ "bottom burner": np.array([-0.88, -0.01]),
+ "top burner": np.array([-0.92, -0.01]),
+ "light switch": np.array([-0.69, -0.05]),
+ "slide cabinet": np.array([0.37]),
+ "hinge cabinet": np.array([0.0, 1.45]),
+ "microwave": np.array([-0.75]),
+ "kettle": np.array([-0.23, 0.75, 1.62, 0.99, 0.0, 0.0, -0.06]),
+}
+BONUS_THRESH = 0.3
+logger = logging.getLogger()
+
+
+class KitchenBase(KitchenTaskRelaxV1):
+ # A string of element names. The robot's task is then to modify each of
+ # these elements appropriately.
+ TASK_ELEMENTS = []
+ ALL_TASKS = [
+ "bottom burner",
+ "top burner",
+ "light switch",
+ "slide cabinet",
+ "hinge cabinet",
+ "microwave",
+ "kettle",
+ ]
+ REMOVE_TASKS_WHEN_COMPLETE = True
+ TERMINATE_ON_TASK_COMPLETE = True
+ TERMINATE_ON_WRONG_COMPLETE = False
+ COMPLETE_IN_ANY_ORDER = (
+ True # This allows for the tasks to be completed in arbitrary order.
+ )
+
+ def __init__(
+ self, dataset_url=None, ref_max_score=None, ref_min_score=None,
+ use_abs_action=False,
+ **kwargs
+ ):
+ self.tasks_to_complete = list(self.TASK_ELEMENTS)
+ self.goal_masking = True
+ super(KitchenBase, self).__init__(use_abs_action=use_abs_action, **kwargs)
+
+ def set_goal_masking(self, goal_masking=True):
+ """Sets goal masking for goal-conditioned approaches (like RPL)."""
+ self.goal_masking = goal_masking
+
+ def _get_task_goal(self, task=None, actually_return_goal=False):
+ if task is None:
+ task = ["microwave", "kettle", "bottom burner", "light switch"]
+ new_goal = np.zeros_like(self.goal)
+ if self.goal_masking and not actually_return_goal:
+ return new_goal
+ for element in task:
+ element_idx = OBS_ELEMENT_INDICES[element]
+ element_goal = OBS_ELEMENT_GOALS[element]
+ new_goal[element_idx] = element_goal
+
+ return new_goal
+
+ def reset_model(self):
+ self.tasks_to_complete = list(self.TASK_ELEMENTS)
+ return super(KitchenBase, self).reset_model()
+
+ def _get_reward_n_score(self, obs_dict):
+ reward_dict, score = super(KitchenBase, self)._get_reward_n_score(obs_dict)
+ reward = 0.0
+ next_q_obs = obs_dict["qp"]
+ next_obj_obs = obs_dict["obj_qp"]
+ next_goal = self._get_task_goal(
+ task=self.TASK_ELEMENTS, actually_return_goal=True
+ ) # obs_dict['goal']
+ idx_offset = len(next_q_obs)
+ completions = []
+ all_completed_so_far = True
+ for element in self.tasks_to_complete:
+ element_idx = OBS_ELEMENT_INDICES[element]
+ distance = np.linalg.norm(
+ next_obj_obs[..., element_idx - idx_offset] - next_goal[element_idx]
+ )
+ complete = distance < BONUS_THRESH
+ condition = (
+ complete and all_completed_so_far
+ if not self.COMPLETE_IN_ANY_ORDER
+ else complete
+ )
+ if condition: # element == self.tasks_to_complete[0]:
+ print("Task {} completed!".format(element))
+ completions.append(element)
+ all_completed_so_far = all_completed_so_far and complete
+ if self.REMOVE_TASKS_WHEN_COMPLETE:
+ [self.tasks_to_complete.remove(element) for element in completions]
+ bonus = float(len(completions))
+ reward_dict["bonus"] = bonus
+ reward_dict["r_total"] = bonus
+ score = bonus
+ return reward_dict, score
+
+ def step(self, a, b=None):
+ obs, reward, done, env_info = super(KitchenBase, self).step(a, b=b)
+ if self.TERMINATE_ON_TASK_COMPLETE:
+ done = not self.tasks_to_complete
+ if self.TERMINATE_ON_WRONG_COMPLETE:
+ all_goal = self._get_task_goal(task=self.ALL_TASKS)
+ for wrong_task in list(set(self.ALL_TASKS) - set(self.TASK_ELEMENTS)):
+ element_idx = OBS_ELEMENT_INDICES[wrong_task]
+ distance = np.linalg.norm(obs[..., element_idx] - all_goal[element_idx])
+ complete = distance < BONUS_THRESH
+ if complete:
+ done = True
+ break
+ env_info["completed_tasks"] = set(self.TASK_ELEMENTS) - set(
+ self.tasks_to_complete
+ )
+ return obs, reward, done, env_info
+
+ def get_goal(self):
+ """Loads goal state from dataset for goal-conditioned approaches (like RPL)."""
+ raise NotImplementedError
+
+ def _split_data_into_seqs(self, data):
+ """Splits dataset object into list of sequence dicts."""
+ seq_end_idxs = np.where(data["terminals"])[0]
+ start = 0
+ seqs = []
+ for end_idx in seq_end_idxs:
+ seqs.append(
+ dict(
+ states=data["observations"][start : end_idx + 1],
+ actions=data["actions"][start : end_idx + 1],
+ )
+ )
+ start = end_idx + 1
+ return seqs
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/kitchen_lowdim_wrapper.py b/third_party/diffusion_policy/diffusion_policy/env/kitchen/kitchen_lowdim_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6908a425eda6cf82d2e4db87c6e3d864938faf1
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/kitchen_lowdim_wrapper.py
@@ -0,0 +1,49 @@
+from typing import List, Dict, Optional, Optional
+import numpy as np
+import gym
+from gym.spaces import Box
+from diffusion_policy.env.kitchen.base import KitchenBase
+
+class KitchenLowdimWrapper(gym.Env):
+ def __init__(self,
+ env: KitchenBase,
+ init_qpos: Optional[np.ndarray]=None,
+ init_qvel: Optional[np.ndarray]=None,
+ render_hw = (240,360)
+ ):
+ self.env = env
+ self.init_qpos = init_qpos
+ self.init_qvel = init_qvel
+ self.render_hw = render_hw
+
+ @property
+ def action_space(self):
+ return self.env.action_space
+
+ @property
+ def observation_space(self):
+ return self.env.observation_space
+
+ def seed(self, seed=None):
+ return self.env.seed(seed)
+
+ def reset(self):
+ if self.init_qpos is not None:
+ # reset anyway to be safe, not very expensive
+ _ = self.env.reset()
+ # start from known state
+ self.env.set_state(self.init_qpos, self.init_qvel)
+ obs = self.env._get_obs()
+ return obs
+ # obs, _, _, _ = self.env.step(np.zeros_like(
+ # self.action_space.sample()))
+ # return obs
+ else:
+ return self.env.reset()
+
+ def render(self, mode='rgb_array'):
+ h, w = self.render_hw
+ return self.env.render(mode=mode, width=w, height=h)
+
+ def step(self, a):
+ return self.env.step(a)
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/kitchen_util.py b/third_party/diffusion_policy/diffusion_policy/env/kitchen/kitchen_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..59852c1946c8cc1eea17dbd9bd7bca73df095ebe
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/kitchen_util.py
@@ -0,0 +1,51 @@
+import struct
+import numpy as np
+
+def parse_mjl_logs(read_filename, skipamount):
+ with open(read_filename, mode='rb') as file:
+ fileContent = file.read()
+ headers = struct.unpack('iiiiiii', fileContent[:28])
+ nq = headers[0]
+ nv = headers[1]
+ nu = headers[2]
+ nmocap = headers[3]
+ nsensordata = headers[4]
+ nuserdata = headers[5]
+ name_len = headers[6]
+ name = struct.unpack(str(name_len) + 's', fileContent[28:28+name_len])[0]
+ rem_size = len(fileContent[28 + name_len:])
+ num_floats = int(rem_size/4)
+ dat = np.asarray(struct.unpack(str(num_floats) + 'f', fileContent[28+name_len:]))
+ recsz = 1 + nq + nv + nu + 7*nmocap + nsensordata + nuserdata
+ if rem_size % recsz != 0:
+ print("ERROR")
+ else:
+ dat = np.reshape(dat, (int(len(dat)/recsz), recsz))
+ dat = dat.T
+
+ time = dat[0,:][::skipamount] - 0*dat[0, 0]
+ qpos = dat[1:nq + 1, :].T[::skipamount, :]
+ qvel = dat[nq+1:nq+nv+1,:].T[::skipamount, :]
+ ctrl = dat[nq+nv+1:nq+nv+nu+1,:].T[::skipamount,:]
+ mocap_pos = dat[nq+nv+nu+1:nq+nv+nu+3*nmocap+1,:].T[::skipamount, :]
+ mocap_quat = dat[nq+nv+nu+3*nmocap+1:nq+nv+nu+7*nmocap+1,:].T[::skipamount, :]
+ sensordata = dat[nq+nv+nu+7*nmocap+1:nq+nv+nu+7*nmocap+nsensordata+1,:].T[::skipamount,:]
+ userdata = dat[nq+nv+nu+7*nmocap+nsensordata+1:,:].T[::skipamount,:]
+
+ data = dict(nq=nq,
+ nv=nv,
+ nu=nu,
+ nmocap=nmocap,
+ nsensordata=nsensordata,
+ name=name,
+ time=time,
+ qpos=qpos,
+ qvel=qvel,
+ ctrl=ctrl,
+ mocap_pos=mocap_pos,
+ mocap_quat=mocap_quat,
+ sensordata=sensordata,
+ userdata=userdata,
+ logName = read_filename
+ )
+ return data
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/adept_models/__init__.py b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/adept_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/LICENSE b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/README.md b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..45d22577322991774d22b337ad504921fce55045
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/README.md
@@ -0,0 +1,9 @@
+# franka
+Franka panda mujoco models
+
+
+# Environment
+
+franka_panda.xml | coming soon
+:-------------------------:|:-------------------------:
+ | coming soon
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/actuator0.xml b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/actuator0.xml
new file mode 100644
index 0000000000000000000000000000000000000000..86ee47c136768ead19ffa37e15cbab94e9067d1b
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/actuator0.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/actuator1.xml b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/actuator1.xml
new file mode 100644
index 0000000000000000000000000000000000000000..a8eda4e44bb858f58d643b57a0c1ef837904af6b
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/actuator1.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/assets.xml b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/assets.xml
new file mode 100644
index 0000000000000000000000000000000000000000..4f2cdedb587f8d1ab48e4e163d767ded17ccbe81
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/assets.xml
@@ -0,0 +1,63 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/chain0_overlay.xml b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/chain0_overlay.xml
new file mode 100644
index 0000000000000000000000000000000000000000..e64f497e68de70c5781f89a2ba18ddb62af8bb32
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/chain0_overlay.xml
@@ -0,0 +1,62 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/teleop_actuator.xml b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/teleop_actuator.xml
new file mode 100644
index 0000000000000000000000000000000000000000..e5e46db5d33431e86a187b18a8a1c7986239fc67
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/assets/teleop_actuator.xml
@@ -0,0 +1,24 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/bi-franka_panda.xml b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/bi-franka_panda.xml
new file mode 100644
index 0000000000000000000000000000000000000000..c3072697a4411e7726a621bc47ae1bd03d478d42
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/bi-franka_panda.xml
@@ -0,0 +1,81 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ /
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/franka_panda.xml b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/franka_panda.xml
new file mode 100644
index 0000000000000000000000000000000000000000..07c519380a8e0f9c44142a2f0edcc55318a0e980
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/franka_panda.xml
@@ -0,0 +1,38 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/franka_panda_teleop.xml b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/franka_panda_teleop.xml
new file mode 100644
index 0000000000000000000000000000000000000000..cdbf8cd45c5563b4876efbc73d3b21bb95511777
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/franka_panda_teleop.xml
@@ -0,0 +1,54 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/finger.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/finger.stl
new file mode 100644
index 0000000000000000000000000000000000000000..3b87289fea8128bcec3e0b4d174b169124e8e444
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/finger.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/hand.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/hand.stl
new file mode 100644
index 0000000000000000000000000000000000000000..4e820902eba7b9d959a2e0cc8091f4b0f09ed77a
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/hand.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link0.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..def070c7077c0ddb33bbe16cd6c75f19dd318734
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link0.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link1.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..426bcf2d7a04e067e01ab198d0ccfef63c6846e8
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link1.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link2.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..b369f1599a3c1356611716621f998bd8b5a8863b
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link2.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link3.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link3.stl
new file mode 100644
index 0000000000000000000000000000000000000000..25162eeedf286d1e27fdd4ba38950ae90678bb0f
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link3.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link4.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link4.stl
new file mode 100644
index 0000000000000000000000000000000000000000..76c8c33c3e1e6c184f8c3693b390892c25b179e3
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link4.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link5.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link5.stl
new file mode 100644
index 0000000000000000000000000000000000000000..3006a0b9a695f020e1887128d805d15aaa7fd342
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link5.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link6.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link6.stl
new file mode 100644
index 0000000000000000000000000000000000000000..2e9594a873f97e572ec68a4e0ab6d65f41f5007e
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link6.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link7.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link7.stl
new file mode 100644
index 0000000000000000000000000000000000000000..0532d057507637533d25f4b3ed451f213685a61d
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link7.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/finger.stl b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/finger.stl
new file mode 100644
index 0000000000000000000000000000000000000000..2a5a2567dc89f46967556f6f2d0250b4cc955dd4
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/finger.stl differ
diff --git a/third_party/diffusion_policy/diffusion_policy/env/kitchen/v0.py b/third_party/diffusion_policy/diffusion_policy/env/kitchen/v0.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7a5b4ecf3da1292a873b93c5f1cd1dec21fbfb9
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/kitchen/v0.py
@@ -0,0 +1,20 @@
+from diffusion_policy.env.kitchen.base import KitchenBase
+
+
+class KitchenMicrowaveKettleBottomBurnerLightV0(KitchenBase):
+ TASK_ELEMENTS = ["microwave", "kettle", "bottom burner", "light switch"]
+ COMPLETE_IN_ANY_ORDER = False
+
+
+class KitchenMicrowaveKettleLightSliderV0(KitchenBase):
+ TASK_ELEMENTS = ["microwave", "kettle", "light switch", "slide cabinet"]
+ COMPLETE_IN_ANY_ORDER = False
+
+
+class KitchenKettleMicrowaveLightSliderV0(KitchenBase):
+ TASK_ELEMENTS = ["kettle", "microwave", "light switch", "slide cabinet"]
+ COMPLETE_IN_ANY_ORDER = False
+
+
+class KitchenAllV0(KitchenBase):
+ TASK_ELEMENTS = KitchenBase.ALL_TASKS
diff --git a/third_party/diffusion_policy/diffusion_policy/env/pusht/__init__.py b/third_party/diffusion_policy/diffusion_policy/env/pusht/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..945de59a42234839697140f35e1681c81e9505c7
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/pusht/__init__.py
@@ -0,0 +1,9 @@
+from gym.envs.registration import register
+import diffusion_policy.env.pusht
+
+register(
+ id='pusht-keypoints-v0',
+ entry_point='envs.pusht.pusht_keypoints_env:PushTKeypointsEnv',
+ max_episode_steps=200,
+ reward_threshold=1.0
+)
\ No newline at end of file
diff --git a/third_party/diffusion_policy/diffusion_policy/env/pusht/pusht_env.py b/third_party/diffusion_policy/diffusion_policy/env/pusht/pusht_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a25cd663ddbe96c66dc3f29e1882570c234abe2
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/pusht/pusht_env.py
@@ -0,0 +1,367 @@
+import gym
+from gym import spaces
+
+import collections
+import numpy as np
+import pygame
+import pymunk
+import pymunk.pygame_util
+from pymunk.vec2d import Vec2d
+import shapely.geometry as sg
+import cv2
+import skimage.transform as st
+from diffusion_policy.env.pusht.pymunk_override import DrawOptions
+
+
+def pymunk_to_shapely(body, shapes):
+ geoms = list()
+ for shape in shapes:
+ if isinstance(shape, pymunk.shapes.Poly):
+ verts = [body.local_to_world(v) for v in shape.get_vertices()]
+ verts += [verts[0]]
+ geoms.append(sg.Polygon(verts))
+ else:
+ raise RuntimeError(f'Unsupported shape type {type(shape)}')
+ geom = sg.MultiPolygon(geoms)
+ return geom
+
+class PushTEnv(gym.Env):
+ metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
+ reward_range = (0., 1.)
+
+ def __init__(self,
+ legacy=False,
+ block_cog=None, damping=None,
+ render_action=True,
+ render_size=96,
+ reset_to_state=None
+ ):
+ self._seed = None
+ self.seed()
+ self.window_size = ws = 512 # The size of the PyGame window
+ self.render_size = render_size
+ self.sim_hz = 100
+ # Local controller params.
+ self.k_p, self.k_v = 100, 20 # PD control.z
+ self.control_hz = self.metadata['video.frames_per_second']
+ # legcay set_state for data compatibility
+ self.legacy = legacy
+
+ # agent_pos, block_pos, block_angle
+ self.observation_space = spaces.Box(
+ low=np.array([0,0,0,0,0], dtype=np.float64),
+ high=np.array([ws,ws,ws,ws,np.pi*2], dtype=np.float64),
+ shape=(5,),
+ dtype=np.float64
+ )
+
+ # positional goal for agent
+ self.action_space = spaces.Box(
+ low=np.array([0,0], dtype=np.float64),
+ high=np.array([ws,ws], dtype=np.float64),
+ shape=(2,),
+ dtype=np.float64
+ )
+
+ self.block_cog = block_cog
+ self.damping = damping
+ self.render_action = render_action
+
+ """
+ If human-rendering is used, `self.window` will be a reference
+ to the window that we draw to. `self.clock` will be a clock that is used
+ to ensure that the environment is rendered at the correct framerate in
+ human-mode. They will remain `None` until human-mode is used for the
+ first time.
+ """
+ self.window = None
+ self.clock = None
+ self.screen = None
+
+ self.space = None
+ self.teleop = None
+ self.render_buffer = None
+ self.latest_action = None
+ self.reset_to_state = reset_to_state
+
+ def reset(self):
+ seed = self._seed
+ self._setup()
+ if self.block_cog is not None:
+ self.block.center_of_gravity = self.block_cog
+ if self.damping is not None:
+ self.space.damping = self.damping
+
+ # use legacy RandomState for compatibility
+ state = self.reset_to_state
+ if state is None:
+ rs = np.random.RandomState(seed=seed)
+ state = np.array([
+ rs.randint(50, 450), rs.randint(50, 450),
+ rs.randint(100, 400), rs.randint(100, 400),
+ rs.randn() * 2 * np.pi - np.pi
+ ])
+ self._set_state(state)
+
+ observation = self._get_obs()
+ return observation
+
+ def step(self, action):
+ dt = 1.0 / self.sim_hz
+ self.n_contact_points = 0
+ n_steps = self.sim_hz // self.control_hz
+ if action is not None:
+ self.latest_action = action
+ for i in range(n_steps):
+ # Step PD control.
+ # self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too.
+ acceleration = self.k_p * (action - self.agent.position) + self.k_v * (Vec2d(0, 0) - self.agent.velocity)
+ self.agent.velocity += acceleration * dt
+
+ # Step physics.
+ self.space.step(dt)
+
+ # compute reward
+ goal_body = self._get_goal_pose_body(self.goal_pose)
+ goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
+ block_geom = pymunk_to_shapely(self.block, self.block.shapes)
+
+ intersection_area = goal_geom.intersection(block_geom).area
+ goal_area = goal_geom.area
+ coverage = intersection_area / goal_area
+ reward = np.clip(coverage / self.success_threshold, 0, 1)
+ done = coverage > self.success_threshold
+
+ observation = self._get_obs()
+ info = self._get_info()
+
+ return observation, reward, done, info
+
+ def render(self, mode):
+ return self._render_frame(mode)
+
+ def teleop_agent(self):
+ TeleopAgent = collections.namedtuple('TeleopAgent', ['act'])
+ def act(obs):
+ act = None
+ mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
+ if self.teleop or (mouse_position - self.agent.position).length < 30:
+ self.teleop = True
+ act = mouse_position
+ return act
+ return TeleopAgent(act)
+
+ def _get_obs(self):
+ obs = np.array(
+ tuple(self.agent.position) \
+ + tuple(self.block.position) \
+ + (self.block.angle % (2 * np.pi),))
+ return obs
+
+ def _get_goal_pose_body(self, pose):
+ mass = 1
+ inertia = pymunk.moment_for_box(mass, (50, 100))
+ body = pymunk.Body(mass, inertia)
+ # preserving the legacy assignment order for compatibility
+ # the order here doesn't matter somehow, maybe because CoM is aligned with body origin
+ body.position = pose[:2].tolist()
+ body.angle = pose[2]
+ return body
+
+ def _get_info(self):
+ n_steps = self.sim_hz // self.control_hz
+ n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
+ info = {
+ 'pos_agent': np.array(self.agent.position),
+ 'vel_agent': np.array(self.agent.velocity),
+ 'block_pose': np.array(list(self.block.position) + [self.block.angle]),
+ 'goal_pose': self.goal_pose,
+ 'n_contacts': n_contact_points_per_step}
+ return info
+
+ def _render_frame(self, mode):
+
+ if self.window is None and mode == "human":
+ pygame.init()
+ pygame.display.init()
+ self.window = pygame.display.set_mode((self.window_size, self.window_size))
+ if self.clock is None and mode == "human":
+ self.clock = pygame.time.Clock()
+
+ canvas = pygame.Surface((self.window_size, self.window_size))
+ canvas.fill((255, 255, 255))
+ self.screen = canvas
+
+ draw_options = DrawOptions(canvas)
+
+ # Draw goal pose.
+ goal_body = self._get_goal_pose_body(self.goal_pose)
+ for shape in self.block.shapes:
+ goal_points = [pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface) for v in shape.get_vertices()]
+ goal_points += [goal_points[0]]
+ pygame.draw.polygon(canvas, self.goal_color, goal_points)
+
+ # Draw agent and block.
+ self.space.debug_draw(draw_options)
+
+ if mode == "human":
+ # The following line copies our drawings from `canvas` to the visible window
+ self.window.blit(canvas, canvas.get_rect())
+ pygame.event.pump()
+ pygame.display.update()
+
+ # the clock is already ticked during in step for "human"
+
+
+ img = np.transpose(
+ np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
+ )
+ img = cv2.resize(img, (self.render_size, self.render_size))
+ if self.render_action:
+ if self.render_action and (self.latest_action is not None):
+ action = np.array(self.latest_action)
+ coord = (action / 512 * 96).astype(np.int32)
+ marker_size = int(8/96*self.render_size)
+ thickness = int(1/96*self.render_size)
+ cv2.drawMarker(img, coord,
+ color=(255,0,0), markerType=cv2.MARKER_CROSS,
+ markerSize=marker_size, thickness=thickness)
+ return img
+
+
+ def close(self):
+ if self.window is not None:
+ pygame.display.quit()
+ pygame.quit()
+
+ def seed(self, seed=None):
+ if seed is None:
+ seed = np.random.randint(0,25536)
+ self._seed = seed
+ self.np_random = np.random.default_rng(seed)
+
+ def _handle_collision(self, arbiter, space, data):
+ self.n_contact_points += len(arbiter.contact_point_set.points)
+
+ def _set_state(self, state):
+ if isinstance(state, np.ndarray):
+ state = state.tolist()
+ pos_agent = state[:2]
+ pos_block = state[2:4]
+ rot_block = state[4]
+ self.agent.position = pos_agent
+ # setting angle rotates with respect to center of mass
+ # therefore will modify the geometric position
+ # if not the same as CoM
+ # therefore should be modified first.
+ if self.legacy:
+ # for compatibility with legacy data
+ self.block.position = pos_block
+ self.block.angle = rot_block
+ else:
+ self.block.angle = rot_block
+ self.block.position = pos_block
+
+ # Run physics to take effect
+ self.space.step(1.0 / self.sim_hz)
+
+ def _set_state_local(self, state_local):
+ agent_pos_local = state_local[:2]
+ block_pose_local = state_local[2:]
+ tf_img_obj = st.AffineTransform(
+ translation=self.goal_pose[:2],
+ rotation=self.goal_pose[2])
+ tf_obj_new = st.AffineTransform(
+ translation=block_pose_local[:2],
+ rotation=block_pose_local[2]
+ )
+ tf_img_new = st.AffineTransform(
+ matrix=tf_img_obj.params @ tf_obj_new.params
+ )
+ agent_pos_new = tf_img_new(agent_pos_local)
+ new_state = np.array(
+ list(agent_pos_new[0]) + list(tf_img_new.translation) \
+ + [tf_img_new.rotation])
+ self._set_state(new_state)
+ return new_state
+
+ def _setup(self):
+ self.space = pymunk.Space()
+ self.space.gravity = 0, 0
+ self.space.damping = 0
+ self.teleop = False
+ self.render_buffer = list()
+
+ # Add walls.
+ walls = [
+ self._add_segment((5, 506), (5, 5), 2),
+ self._add_segment((5, 5), (506, 5), 2),
+ self._add_segment((506, 5), (506, 506), 2),
+ self._add_segment((5, 506), (506, 506), 2)
+ ]
+ self.space.add(*walls)
+
+ # Add agent, block, and goal zone.
+ self.agent = self.add_circle((256, 400), 15)
+ self.block = self.add_tee((256, 300), 0)
+ self.goal_color = pygame.Color('LightGreen')
+ self.goal_pose = np.array([256,256,np.pi/4]) # x, y, theta (in radians)
+
+ # Add collision handling
+ self.collision_handeler = self.space.add_collision_handler(0, 0)
+ self.collision_handeler.post_solve = self._handle_collision
+ self.n_contact_points = 0
+
+ self.max_score = 50 * 100
+ self.success_threshold = 0.95 # 95% coverage.
+
+ def _add_segment(self, a, b, radius):
+ shape = pymunk.Segment(self.space.static_body, a, b, radius)
+ shape.color = pygame.Color('LightGray') # https://htmlcolorcodes.com/color-names
+ return shape
+
+ def add_circle(self, position, radius):
+ body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
+ body.position = position
+ body.friction = 1
+ shape = pymunk.Circle(body, radius)
+ shape.color = pygame.Color('RoyalBlue')
+ self.space.add(body, shape)
+ return body
+
+ def add_box(self, position, height, width):
+ mass = 1
+ inertia = pymunk.moment_for_box(mass, (height, width))
+ body = pymunk.Body(mass, inertia)
+ body.position = position
+ shape = pymunk.Poly.create_box(body, (height, width))
+ shape.color = pygame.Color('LightSlateGray')
+ self.space.add(body, shape)
+ return body
+
+ def add_tee(self, position, angle, scale=30, color='LightSlateGray', mask=pymunk.ShapeFilter.ALL_MASKS()):
+ mass = 1
+ length = 4
+ vertices1 = [(-length*scale/2, scale),
+ ( length*scale/2, scale),
+ ( length*scale/2, 0),
+ (-length*scale/2, 0)]
+ inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
+ vertices2 = [(-scale/2, scale),
+ (-scale/2, length*scale),
+ ( scale/2, length*scale),
+ ( scale/2, scale)]
+ inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
+ body = pymunk.Body(mass, inertia1 + inertia2)
+ shape1 = pymunk.Poly(body, vertices1)
+ shape2 = pymunk.Poly(body, vertices2)
+ shape1.color = pygame.Color(color)
+ shape2.color = pygame.Color(color)
+ shape1.filter = pymunk.ShapeFilter(mask=mask)
+ shape2.filter = pymunk.ShapeFilter(mask=mask)
+ body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
+ body.position = position
+ body.angle = angle
+ body.friction = 1
+ self.space.add(body, shape1, shape2)
+ return body
diff --git a/third_party/diffusion_policy/diffusion_policy/env/pusht/pusht_image_env.py b/third_party/diffusion_policy/diffusion_policy/env/pusht/pusht_image_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..db6bcfbf78b0d20db1780c04834857079fb7ada4
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/pusht/pusht_image_env.py
@@ -0,0 +1,66 @@
+from gym import spaces
+from diffusion_policy.env.pusht.pusht_env import PushTEnv
+import numpy as np
+import cv2
+
+class PushTImageEnv(PushTEnv):
+ metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
+
+ def __init__(self,
+ legacy=False,
+ block_cog=None,
+ damping=None,
+ render_size=96):
+ super().__init__(
+ legacy=legacy,
+ block_cog=block_cog,
+ damping=damping,
+ render_size=render_size,
+ render_action=False)
+ ws = self.window_size
+ self.observation_space = spaces.Dict({
+ 'image': spaces.Box(
+ low=0,
+ high=1,
+ shape=(3,render_size,render_size),
+ dtype=np.float32
+ ),
+ 'agent_pos': spaces.Box(
+ low=0,
+ high=ws,
+ shape=(2,),
+ dtype=np.float32
+ )
+ })
+ self.render_cache = None
+
+ def _get_obs(self):
+ img = super()._render_frame(mode='rgb_array')
+
+ agent_pos = np.array(self.agent.position)
+ img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
+ obs = {
+ 'image': img_obs,
+ 'agent_pos': agent_pos
+ }
+
+ # draw action
+ if self.latest_action is not None:
+ action = np.array(self.latest_action)
+ coord = (action / 512 * 96).astype(np.int32)
+ marker_size = int(8/96*self.render_size)
+ thickness = int(1/96*self.render_size)
+ cv2.drawMarker(img, coord,
+ color=(255,0,0), markerType=cv2.MARKER_CROSS,
+ markerSize=marker_size, thickness=thickness)
+ self.render_cache = img
+
+ return obs
+
+ def render(self, mode):
+ assert mode == 'rgb_array'
+
+ if self.render_cache is None:
+ self._get_obs()
+
+ return self.render_cache
diff --git a/third_party/diffusion_policy/diffusion_policy/env/pusht/pusht_keypoints_env.py b/third_party/diffusion_policy/diffusion_policy/env/pusht/pusht_keypoints_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb42686574228d226049c3d403e09049448331b3
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/pusht/pusht_keypoints_env.py
@@ -0,0 +1,131 @@
+from typing import Dict, Sequence, Union, Optional
+from gym import spaces
+from diffusion_policy.env.pusht.pusht_env import PushTEnv
+from diffusion_policy.env.pusht.pymunk_keypoint_manager import PymunkKeypointManager
+import numpy as np
+
+class PushTKeypointsEnv(PushTEnv):
+ def __init__(self,
+ legacy=False,
+ block_cog=None,
+ damping=None,
+ render_size=96,
+ keypoint_visible_rate=1.0,
+ agent_keypoints=False,
+ draw_keypoints=False,
+ reset_to_state=None,
+ render_action=True,
+ local_keypoint_map: Dict[str, np.ndarray]=None,
+ color_map: Optional[Dict[str, np.ndarray]]=None):
+ super().__init__(
+ legacy=legacy,
+ block_cog=block_cog,
+ damping=damping,
+ render_size=render_size,
+ reset_to_state=reset_to_state,
+ render_action=render_action)
+ ws = self.window_size
+
+ if local_keypoint_map is None:
+ # create default keypoint definition
+ kp_kwargs = self.genenerate_keypoint_manager_params()
+ local_keypoint_map = kp_kwargs['local_keypoint_map']
+ color_map = kp_kwargs['color_map']
+
+ # create observation spaces
+ Dblockkps = np.prod(local_keypoint_map['block'].shape)
+ Dagentkps = np.prod(local_keypoint_map['agent'].shape)
+ Dagentpos = 2
+
+ Do = Dblockkps
+ if agent_keypoints:
+ # blockkp + agnet_pos
+ Do += Dagentkps
+ else:
+ # blockkp + agnet_kp
+ Do += Dagentpos
+ # obs + obs_mask
+ Dobs = Do * 2
+
+ low = np.zeros((Dobs,), dtype=np.float64)
+ high = np.full_like(low, ws)
+ # mask range 0-1
+ high[Do:] = 1.
+
+ # (block_kps+agent_kps, xy+confidence)
+ self.observation_space = spaces.Box(
+ low=low,
+ high=high,
+ shape=low.shape,
+ dtype=np.float64
+ )
+
+ self.keypoint_visible_rate = keypoint_visible_rate
+ self.agent_keypoints = agent_keypoints
+ self.draw_keypoints = draw_keypoints
+ self.kp_manager = PymunkKeypointManager(
+ local_keypoint_map=local_keypoint_map,
+ color_map=color_map)
+ self.draw_kp_map = None
+
+ @classmethod
+ def genenerate_keypoint_manager_params(cls):
+ env = PushTEnv()
+ kp_manager = PymunkKeypointManager.create_from_pusht_env(env)
+ kp_kwargs = kp_manager.kwargs
+ return kp_kwargs
+
+ def _get_obs(self):
+ # get keypoints
+ obj_map = {
+ 'block': self.block
+ }
+ if self.agent_keypoints:
+ obj_map['agent'] = self.agent
+
+ kp_map = self.kp_manager.get_keypoints_global(
+ pose_map=obj_map, is_obj=True)
+ # python dict guerentee order of keys and values
+ kps = np.concatenate(list(kp_map.values()), axis=0)
+
+ # select keypoints to drop
+ n_kps = kps.shape[0]
+ visible_kps = self.np_random.random(size=(n_kps,)) < self.keypoint_visible_rate
+ kps_mask = np.repeat(visible_kps[:,None], 2, axis=1)
+
+ # save keypoints for rendering
+ vis_kps = kps.copy()
+ vis_kps[~visible_kps] = 0
+ draw_kp_map = {
+ 'block': vis_kps[:len(kp_map['block'])]
+ }
+ if self.agent_keypoints:
+ draw_kp_map['agent'] = vis_kps[len(kp_map['block']):]
+ self.draw_kp_map = draw_kp_map
+
+ # construct obs
+ obs = kps.flatten()
+ obs_mask = kps_mask.flatten()
+ if not self.agent_keypoints:
+ # passing agent position when keypoints are not available
+ agent_pos = np.array(self.agent.position)
+ obs = np.concatenate([
+ obs, agent_pos
+ ])
+ obs_mask = np.concatenate([
+ obs_mask, np.ones((2,), dtype=bool)
+ ])
+
+ # obs, obs_mask
+ obs = np.concatenate([
+ obs, obs_mask.astype(obs.dtype)
+ ], axis=0)
+ return obs
+
+
+ def _render_frame(self, mode):
+ img = super()._render_frame(mode)
+ if self.draw_keypoints:
+ self.kp_manager.draw_keypoints(
+ img, self.draw_kp_map, radius=int(img.shape[0]/96))
+ return img
diff --git a/third_party/diffusion_policy/diffusion_policy/env/pusht/pymunk_keypoint_manager.py b/third_party/diffusion_policy/diffusion_policy/env/pusht/pymunk_keypoint_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fd5530efb727a4358931c2f7062cc0d4e84fae1
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/pusht/pymunk_keypoint_manager.py
@@ -0,0 +1,146 @@
+from typing import Dict, Sequence, Union, Optional
+import numpy as np
+import skimage.transform as st
+import pymunk
+import pygame
+from matplotlib import cm
+import cv2
+from diffusion_policy.env.pusht.pymunk_override import DrawOptions
+
+
+def farthest_point_sampling(points: np.ndarray, n_points: int, init_idx: int):
+ """
+ Naive O(N^2)
+ """
+ assert(n_points >= 1)
+ chosen_points = [points[init_idx]]
+ for _ in range(n_points-1):
+ cpoints = np.array(chosen_points)
+ all_dists = np.linalg.norm(points[:,None,:] - cpoints[None,:,:], axis=-1)
+ min_dists = all_dists.min(axis=1)
+ next_idx = np.argmax(min_dists)
+ next_pt = points[next_idx]
+ chosen_points.append(next_pt)
+ result = np.array(chosen_points)
+ return result
+
+
+class PymunkKeypointManager:
+ def __init__(self,
+ local_keypoint_map: Dict[str, np.ndarray],
+ color_map: Optional[Dict[str, np.ndarray]]=None):
+ """
+ local_keypoint_map:
+ "": (N,2) floats in object local coordinate
+ """
+ if color_map is None:
+ cmap = cm.get_cmap('tab10')
+ color_map = dict()
+ for i, key in enumerate(local_keypoint_map.keys()):
+ color_map[key] = (np.array(cmap.colors[i]) * 255).astype(np.uint8)
+
+ self.local_keypoint_map = local_keypoint_map
+ self.color_map = color_map
+
+ @property
+ def kwargs(self):
+ return {
+ 'local_keypoint_map': self.local_keypoint_map,
+ 'color_map': self.color_map
+ }
+
+ @classmethod
+ def create_from_pusht_env(cls, env, n_block_kps=9, n_agent_kps=3, seed=0, **kwargs):
+ rng = np.random.default_rng(seed=seed)
+ local_keypoint_map = dict()
+ for name in ['block','agent']:
+ self = env
+ self.space = pymunk.Space()
+ if name == 'agent':
+ self.agent = obj = self.add_circle((256, 400), 15)
+ n_kps = n_agent_kps
+ else:
+ self.block = obj = self.add_tee((256, 300), 0)
+ n_kps = n_block_kps
+
+ self.screen = pygame.Surface((512,512))
+ self.screen.fill(pygame.Color("white"))
+ draw_options = DrawOptions(self.screen)
+ self.space.debug_draw(draw_options)
+ # pygame.display.flip()
+ img = np.uint8(pygame.surfarray.array3d(self.screen).transpose(1, 0, 2))
+ obj_mask = (img != np.array([255,255,255],dtype=np.uint8)).any(axis=-1)
+
+ tf_img_obj = cls.get_tf_img_obj(obj)
+ xy_img = np.moveaxis(np.array(np.indices((512,512))), 0, -1)[:,:,::-1]
+ local_coord_img = tf_img_obj.inverse(xy_img.reshape(-1,2)).reshape(xy_img.shape)
+ obj_local_coords = local_coord_img[obj_mask]
+
+ # furthest point sampling
+ init_idx = rng.choice(len(obj_local_coords))
+ obj_local_kps = farthest_point_sampling(obj_local_coords, n_kps, init_idx)
+ small_shift = rng.uniform(0, 1, size=obj_local_kps.shape)
+ obj_local_kps += small_shift
+
+ local_keypoint_map[name] = obj_local_kps
+
+ return cls(local_keypoint_map=local_keypoint_map, **kwargs)
+
+ @staticmethod
+ def get_tf_img(pose: Sequence):
+ pos = pose[:2]
+ rot = pose[2]
+ tf_img_obj = st.AffineTransform(
+ translation=pos, rotation=rot)
+ return tf_img_obj
+
+ @classmethod
+ def get_tf_img_obj(cls, obj: pymunk.Body):
+ pose = tuple(obj.position) + (obj.angle,)
+ return cls.get_tf_img(pose)
+
+ def get_keypoints_global(self,
+ pose_map: Dict[set, Union[Sequence, pymunk.Body]],
+ is_obj=False):
+ kp_map = dict()
+ for key, value in pose_map.items():
+ if is_obj:
+ tf_img_obj = self.get_tf_img_obj(value)
+ else:
+ tf_img_obj = self.get_tf_img(value)
+ kp_local = self.local_keypoint_map[key]
+ kp_global = tf_img_obj(kp_local)
+ kp_map[key] = kp_global
+ return kp_map
+
+ def draw_keypoints(self, img, kps_map, radius=1):
+ scale = np.array(img.shape[:2]) / np.array([512,512])
+ for key, value in kps_map.items():
+ color = self.color_map[key].tolist()
+ coords = (value * scale).astype(np.int32)
+ for coord in coords:
+ cv2.circle(img, coord, radius=radius, color=color, thickness=-1)
+ return img
+
+ def draw_keypoints_pose(self, img, pose_map, is_obj=False, **kwargs):
+ kp_map = self.get_keypoints_global(pose_map, is_obj=is_obj)
+ return self.draw_keypoints(img, kps_map=kp_map, **kwargs)
+
+
+def test():
+ from diffusion_policy.environment.push_t_env import PushTEnv
+ from matplotlib import pyplot as plt
+
+ env = PushTEnv(headless=True, obs_state=False, draw_action=False)
+ kp_manager = PymunkKeypointManager.create_from_pusht_env(env=env)
+ env.reset()
+ obj_map = {
+ 'block': env.block,
+ 'agent': env.agent
+ }
+
+ obs = env.render()
+ img = obs.astype(np.uint8)
+ kp_manager.draw_keypoints_pose(img=img, pose_map=obj_map, is_obj=True)
+
+ plt.imshow(img)
diff --git a/third_party/diffusion_policy/diffusion_policy/env/pusht/pymunk_override.py b/third_party/diffusion_policy/diffusion_policy/env/pusht/pymunk_override.py
new file mode 100644
index 0000000000000000000000000000000000000000..2439020a13e01ad48f3677919157bb7e49e50569
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/pusht/pymunk_override.py
@@ -0,0 +1,248 @@
+# ----------------------------------------------------------------------------
+# pymunk
+# Copyright (c) 2007-2016 Victor Blomqvist
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ----------------------------------------------------------------------------
+
+"""This submodule contains helper functions to help with quick prototyping
+using pymunk together with pygame.
+
+Intended to help with debugging and prototyping, not for actual production use
+in a full application. The methods contained in this module is opinionated
+about your coordinate system and not in any way optimized.
+"""
+
+__docformat__ = "reStructuredText"
+
+__all__ = [
+ "DrawOptions",
+ "get_mouse_pos",
+ "to_pygame",
+ "from_pygame",
+ "lighten",
+ "positive_y_is_up",
+]
+
+from typing import List, Sequence, Tuple
+
+import pygame
+
+import numpy as np
+
+import pymunk
+from pymunk.space_debug_draw_options import SpaceDebugColor
+from pymunk.vec2d import Vec2d
+
+positive_y_is_up: bool = False
+"""Make increasing values of y point upwards.
+
+When True::
+
+ y
+ ^
+ | . (3, 3)
+ |
+ | . (2, 2)
+ |
+ +------ > x
+
+When False::
+
+ +------ > x
+ |
+ | . (2, 2)
+ |
+ | . (3, 3)
+ v
+ y
+
+"""
+
+
+class DrawOptions(pymunk.SpaceDebugDrawOptions):
+ def __init__(self, surface: pygame.Surface) -> None:
+ """Draw a pymunk.Space on a pygame.Surface object.
+
+ Typical usage::
+
+ >>> import pymunk
+ >>> surface = pygame.Surface((10,10))
+ >>> space = pymunk.Space()
+ >>> options = pymunk.pygame_util.DrawOptions(surface)
+ >>> space.debug_draw(options)
+
+ You can control the color of a shape by setting shape.color to the color
+ you want it drawn in::
+
+ >>> c = pymunk.Circle(None, 10)
+ >>> c.color = pygame.Color("pink")
+
+ See pygame_util.demo.py for a full example
+
+ Since pygame uses a coordinate system where y points down (in contrast
+ to many other cases), you either have to make the physics simulation
+ with Pymunk also behave in that way, or flip everything when you draw.
+
+ The easiest is probably to just make the simulation behave the same
+ way as Pygame does. In that way all coordinates used are in the same
+ orientation and easy to reason about::
+
+ >>> space = pymunk.Space()
+ >>> space.gravity = (0, -1000)
+ >>> body = pymunk.Body()
+ >>> body.position = (0, 0) # will be positioned in the top left corner
+ >>> space.debug_draw(options)
+
+ To flip the drawing its possible to set the module property
+ :py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
+ the simulation upside down before drawing::
+
+ >>> positive_y_is_up = True
+ >>> body = pymunk.Body()
+ >>> body.position = (0, 0)
+ >>> # Body will be position in bottom left corner
+
+ :Parameters:
+ surface : pygame.Surface
+ Surface that the objects will be drawn on
+ """
+ self.surface = surface
+ super(DrawOptions, self).__init__()
+
+ def draw_circle(
+ self,
+ pos: Vec2d,
+ angle: float,
+ radius: float,
+ outline_color: SpaceDebugColor,
+ fill_color: SpaceDebugColor,
+ ) -> None:
+ p = to_pygame(pos, self.surface)
+
+ pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
+ pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius-4), 0)
+
+ circle_edge = pos + Vec2d(radius, 0).rotated(angle)
+ p2 = to_pygame(circle_edge, self.surface)
+ line_r = 2 if radius > 20 else 1
+ # pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)
+
+ def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
+ p1 = to_pygame(a, self.surface)
+ p2 = to_pygame(b, self.surface)
+
+ pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])
+
+ def draw_fat_segment(
+ self,
+ a: Tuple[float, float],
+ b: Tuple[float, float],
+ radius: float,
+ outline_color: SpaceDebugColor,
+ fill_color: SpaceDebugColor,
+ ) -> None:
+ p1 = to_pygame(a, self.surface)
+ p2 = to_pygame(b, self.surface)
+
+ r = round(max(1, radius * 2))
+ pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
+ if r > 2:
+ orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
+ if orthog[0] == 0 and orthog[1] == 0:
+ return
+ scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
+ orthog[0] = round(orthog[0] * scale)
+ orthog[1] = round(orthog[1] * scale)
+ points = [
+ (p1[0] - orthog[0], p1[1] - orthog[1]),
+ (p1[0] + orthog[0], p1[1] + orthog[1]),
+ (p2[0] + orthog[0], p2[1] + orthog[1]),
+ (p2[0] - orthog[0], p2[1] - orthog[1]),
+ ]
+ pygame.draw.polygon(self.surface, fill_color.as_int(), points)
+ pygame.draw.circle(
+ self.surface,
+ fill_color.as_int(),
+ (round(p1[0]), round(p1[1])),
+ round(radius),
+ )
+ pygame.draw.circle(
+ self.surface,
+ fill_color.as_int(),
+ (round(p2[0]), round(p2[1])),
+ round(radius),
+ )
+
+ def draw_polygon(
+ self,
+ verts: Sequence[Tuple[float, float]],
+ radius: float,
+ outline_color: SpaceDebugColor,
+ fill_color: SpaceDebugColor,
+ ) -> None:
+ ps = [to_pygame(v, self.surface) for v in verts]
+ ps += [ps[0]]
+
+ radius = 2
+ pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)
+
+ if radius > 0:
+ for i in range(len(verts)):
+ a = verts[i]
+ b = verts[(i + 1) % len(verts)]
+ self.draw_fat_segment(a, b, radius, fill_color, fill_color)
+
+ def draw_dot(
+ self, size: float, pos: Tuple[float, float], color: SpaceDebugColor
+ ) -> None:
+ p = to_pygame(pos, self.surface)
+ pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
+
+
+def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
+ """Get position of the mouse pointer in pymunk coordinates."""
+ p = pygame.mouse.get_pos()
+ return from_pygame(p, surface)
+
+
+def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
+ """Convenience method to convert pymunk coordinates to pygame surface
+ local coordinates.
+
+ Note that in case positive_y_is_up is False, this function won't actually do
+ anything except converting the point to integers.
+ """
+ if positive_y_is_up:
+ return round(p[0]), surface.get_height() - round(p[1])
+ else:
+ return round(p[0]), round(p[1])
+
+
+def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
+ """Convenience method to convert pygame surface local coordinates to
+ pymunk coordinates
+ """
+ return to_pygame(p, surface)
+
+
+def light_color(color: SpaceDebugColor):
+ color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
+ color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
+ return color
diff --git a/third_party/diffusion_policy/diffusion_policy/env/robomimic/robomimic_image_wrapper.py b/third_party/diffusion_policy/diffusion_policy/env/robomimic/robomimic_image_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8c250601d373147aad6532401f64fd8957d2f21
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/robomimic/robomimic_image_wrapper.py
@@ -0,0 +1,164 @@
+from typing import List, Optional
+from matplotlib.pyplot import fill
+import numpy as np
+import gym
+from gym import spaces
+from omegaconf import OmegaConf
+from robomimic.envs.env_robosuite import EnvRobosuite
+
+class RobomimicImageWrapper(gym.Env):
+ def __init__(self,
+ env: EnvRobosuite,
+ shape_meta: dict,
+ init_state: Optional[np.ndarray]=None,
+ render_obs_key='agentview_image',
+ ):
+
+ self.env = env
+ self.render_obs_key = render_obs_key
+ self.init_state = init_state
+ self.seed_state_map = dict()
+ self._seed = None
+ self.shape_meta = shape_meta
+ self.render_cache = None
+ self.has_reset_before = False
+
+ # setup spaces
+ action_shape = shape_meta['action']['shape']
+ action_space = spaces.Box(
+ low=-1,
+ high=1,
+ shape=action_shape,
+ dtype=np.float32
+ )
+ self.action_space = action_space
+
+ observation_space = spaces.Dict()
+ for key, value in shape_meta['obs'].items():
+ shape = value['shape']
+ min_value, max_value = -1, 1
+ if key.endswith('image'):
+ min_value, max_value = 0, 1
+ elif key.endswith('quat'):
+ min_value, max_value = -1, 1
+ elif key.endswith('qpos'):
+ min_value, max_value = -1, 1
+ elif key.endswith('pos'):
+ # better range?
+ min_value, max_value = -1, 1
+ else:
+ raise RuntimeError(f"Unsupported type {key}")
+
+ this_space = spaces.Box(
+ low=min_value,
+ high=max_value,
+ shape=shape,
+ dtype=np.float32
+ )
+ observation_space[key] = this_space
+ self.observation_space = observation_space
+
+
+ def get_observation(self, raw_obs=None):
+ if raw_obs is None:
+ raw_obs = self.env.get_observation()
+
+ self.render_cache = raw_obs[self.render_obs_key]
+
+ obs = dict()
+ for key in self.observation_space.keys():
+ obs[key] = raw_obs[key]
+ return obs
+
+ def seed(self, seed=None):
+ np.random.seed(seed=seed)
+ self._seed = seed
+
+ def reset(self):
+ if self.init_state is not None:
+ if not self.has_reset_before:
+ # the env must be fully reset at least once to ensure correct rendering
+ self.env.reset()
+ self.has_reset_before = True
+
+ # always reset to the same state
+ # to be compatible with gym
+ raw_obs = self.env.reset_to({'states': self.init_state})
+ elif self._seed is not None:
+ # reset to a specific seed
+ seed = self._seed
+ if seed in self.seed_state_map:
+ # env.reset is expensive, use cache
+ raw_obs = self.env.reset_to({'states': self.seed_state_map[seed]})
+ else:
+ # robosuite's initializes all use numpy global random state
+ np.random.seed(seed=seed)
+ raw_obs = self.env.reset()
+ state = self.env.get_state()['states']
+ self.seed_state_map[seed] = state
+ self._seed = None
+ else:
+ # random reset
+ raw_obs = self.env.reset()
+
+ # return obs
+ obs = self.get_observation(raw_obs)
+ return obs
+
+ def step(self, action):
+ raw_obs, reward, done, info = self.env.step(action)
+ obs = self.get_observation(raw_obs)
+ return obs, reward, done, info
+
+ def render(self, mode='rgb_array'):
+ if self.render_cache is None:
+ raise RuntimeError('Must run reset or step before render.')
+ img = np.moveaxis(self.render_cache, 0, -1)
+ img = (img * 255).astype(np.uint8)
+ return img
+
+
+def test():
+ import os
+ from omegaconf import OmegaConf
+ cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml')
+ cfg = OmegaConf.load(cfg_path)
+ shape_meta = cfg['shape_meta']
+
+
+ import robomimic.utils.file_utils as FileUtils
+ import robomimic.utils.env_utils as EnvUtils
+ from matplotlib import pyplot as plt
+
+ dataset_path = os.path.expanduser('~/dev/diffusion_policy/data/robomimic/datasets/square/ph/image.hdf5')
+ env_meta = FileUtils.get_env_metadata_from_dataset(
+ dataset_path)
+
+ env = EnvUtils.create_env_from_metadata(
+ env_meta=env_meta,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=True,
+ )
+
+ wrapper = RobomimicImageWrapper(
+ env=env,
+ shape_meta=shape_meta
+ )
+ wrapper.seed(0)
+ obs = wrapper.reset()
+ img = wrapper.render()
+ plt.imshow(img)
+
+
+ # states = list()
+ # for _ in range(2):
+ # wrapper.seed(0)
+ # wrapper.reset()
+ # states.append(wrapper.env.get_state()['states'])
+ # assert np.allclose(states[0], states[1])
+
+ # img = wrapper.render()
+ # plt.imshow(img)
+ # wrapper.seed()
+ # states.append(wrapper.env.get_state()['states'])
diff --git a/third_party/diffusion_policy/diffusion_policy/env/robomimic/robomimic_lowdim_wrapper.py b/third_party/diffusion_policy/diffusion_policy/env/robomimic/robomimic_lowdim_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..383ae327fd1d79049f5a1b6d23a8657fa8720711
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env/robomimic/robomimic_lowdim_wrapper.py
@@ -0,0 +1,133 @@
+from typing import List, Dict, Optional
+import numpy as np
+import gym
+from gym.spaces import Box
+from robomimic.envs.env_robosuite import EnvRobosuite
+
+class RobomimicLowdimWrapper(gym.Env):
+ def __init__(self,
+ env: EnvRobosuite,
+ obs_keys: List[str]=[
+ 'object',
+ 'robot0_eef_pos',
+ 'robot0_eef_quat',
+ 'robot0_gripper_qpos'],
+ init_state: Optional[np.ndarray]=None,
+ render_hw=(256,256),
+ render_camera_name='agentview'
+ ):
+
+ self.env = env
+ self.obs_keys = obs_keys
+ self.init_state = init_state
+ self.render_hw = render_hw
+ self.render_camera_name = render_camera_name
+ self.seed_state_map = dict()
+ self._seed = None
+
+ # setup spaces
+ low = np.full(env.action_dimension, fill_value=-1)
+ high = np.full(env.action_dimension, fill_value=1)
+ self.action_space = Box(
+ low=low,
+ high=high,
+ shape=low.shape,
+ dtype=low.dtype
+ )
+ obs_example = self.get_observation()
+ low = np.full_like(obs_example, fill_value=-1)
+ high = np.full_like(obs_example, fill_value=1)
+ self.observation_space = Box(
+ low=low,
+ high=high,
+ shape=low.shape,
+ dtype=low.dtype
+ )
+
+ def get_observation(self):
+ raw_obs = self.env.get_observation()
+ obs = np.concatenate([
+ raw_obs[key] for key in self.obs_keys
+ ], axis=0)
+ return obs
+
+ def seed(self, seed=None):
+ np.random.seed(seed=seed)
+ self._seed = seed
+
+ def reset(self):
+ if self.init_state is not None:
+ # always reset to the same state
+ # to be compatible with gym
+ self.env.reset_to({'states': self.init_state})
+ elif self._seed is not None:
+ # reset to a specific seed
+ seed = self._seed
+ if seed in self.seed_state_map:
+ # env.reset is expensive, use cache
+ self.env.reset_to({'states': self.seed_state_map[seed]})
+ else:
+ # robosuite's initializes all use numpy global random state
+ np.random.seed(seed=seed)
+ self.env.reset()
+ state = self.env.get_state()['states']
+ self.seed_state_map[seed] = state
+ self._seed = None
+ else:
+ # random reset
+ self.env.reset()
+
+ # return obs
+ obs = self.get_observation()
+ return obs
+
+ def step(self, action):
+ raw_obs, reward, done, info = self.env.step(action)
+ obs = np.concatenate([
+ raw_obs[key] for key in self.obs_keys
+ ], axis=0)
+ return obs, reward, done, info
+
+ def render(self, mode='rgb_array'):
+ h, w = self.render_hw
+ return self.env.render(mode=mode,
+ height=h, width=w,
+ camera_name=self.render_camera_name)
+
+
+def test():
+ import robomimic.utils.file_utils as FileUtils
+ import robomimic.utils.env_utils as EnvUtils
+ from matplotlib import pyplot as plt
+
+ dataset_path = '/home/cchi/dev/diffusion_policy/data/robomimic/datasets/square/ph/low_dim.hdf5'
+ env_meta = FileUtils.get_env_metadata_from_dataset(
+ dataset_path)
+
+ env = EnvUtils.create_env_from_metadata(
+ env_meta=env_meta,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=False,
+ )
+ wrapper = RobomimicLowdimWrapper(
+ env=env,
+ obs_keys=[
+ 'object',
+ 'robot0_eef_pos',
+ 'robot0_eef_quat',
+ 'robot0_gripper_qpos'
+ ]
+ )
+
+ states = list()
+ for _ in range(2):
+ wrapper.seed(0)
+ wrapper.reset()
+ states.append(wrapper.env.get_state()['states'])
+ assert np.allclose(states[0], states[1])
+
+ img = wrapper.render()
+ plt.imshow(img)
+ # wrapper.seed()
+ # states.append(wrapper.env.get_state()['states'])
diff --git a/third_party/diffusion_policy/diffusion_policy/env_runner/base_image_runner.py b/third_party/diffusion_policy/diffusion_policy/env_runner/base_image_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..065200716037fb8c1d1baf1ab891d2d4e52e6fbe
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env_runner/base_image_runner.py
@@ -0,0 +1,9 @@
+from typing import Dict
+from diffusion_policy.policy.base_image_policy import BaseImagePolicy
+
+class BaseImageRunner:
+ def __init__(self, output_dir):
+ self.output_dir = output_dir
+
+ def run(self, policy: BaseImagePolicy) -> Dict:
+ raise NotImplementedError()
diff --git a/third_party/diffusion_policy/diffusion_policy/env_runner/base_lowdim_runner.py b/third_party/diffusion_policy/diffusion_policy/env_runner/base_lowdim_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..45437ec88d4d5b6f711fa953a1b54166aac9b530
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env_runner/base_lowdim_runner.py
@@ -0,0 +1,9 @@
+from typing import Dict
+from diffusion_policy.policy.base_lowdim_policy import BaseLowdimPolicy
+
+class BaseLowdimRunner:
+ def __init__(self, output_dir):
+ self.output_dir = output_dir
+
+ def run(self, policy: BaseLowdimPolicy) -> Dict:
+ raise NotImplementedError()
diff --git a/third_party/diffusion_policy/diffusion_policy/env_runner/blockpush_lowdim_runner.py b/third_party/diffusion_policy/diffusion_policy/env_runner/blockpush_lowdim_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..5193dafbb12b92c7e381a9b4ecba966dbff5bfd4
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env_runner/blockpush_lowdim_runner.py
@@ -0,0 +1,293 @@
+import wandb
+import numpy as np
+import torch
+import collections
+import pathlib
+import tqdm
+import dill
+import math
+import wandb.sdk.data_types.video as wv
+from diffusion_policy.env.block_pushing.block_pushing_multimodal import BlockPushMultimodal
+from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
+from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
+from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
+from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
+from gym.wrappers import FlattenObservation
+
+from diffusion_policy.policy.base_lowdim_policy import BaseLowdimPolicy
+from diffusion_policy.common.pytorch_util import dict_apply
+from diffusion_policy.env_runner.base_lowdim_runner import BaseLowdimRunner
+
+class BlockPushLowdimRunner(BaseLowdimRunner):
+ def __init__(self,
+ output_dir,
+ n_train=10,
+ n_train_vis=3,
+ train_start_seed=0,
+ n_test=22,
+ n_test_vis=6,
+ test_start_seed=10000,
+ max_steps=200,
+ n_obs_steps=8,
+ n_action_steps=8,
+ fps=5,
+ crf=22,
+ past_action=False,
+ abs_action=False,
+ obs_eef_target=True,
+ tqdm_interval_sec=5.0,
+ n_envs=None
+ ):
+ super().__init__(output_dir)
+
+ if n_envs is None:
+ n_envs = n_train + n_test
+
+ task_fps = 10
+ steps_per_render = max(10 // fps, 1)
+
+ def env_fn():
+ return MultiStepWrapper(
+ VideoRecordingWrapper(
+ FlattenObservation(
+ BlockPushMultimodal(
+ control_frequency=task_fps,
+ shared_memory=False,
+ seed=seed,
+ abs_action=abs_action
+ )
+ ),
+ video_recoder=VideoRecorder.create_h264(
+ fps=fps,
+ codec='h264',
+ input_pix_fmt='rgb24',
+ crf=crf,
+ thread_type='FRAME',
+ thread_count=1
+ ),
+ file_path=None,
+ steps_per_render=steps_per_render
+ ),
+ n_obs_steps=n_obs_steps,
+ n_action_steps=n_action_steps,
+ max_episode_steps=max_steps
+ )
+
+ env_fns = [env_fn] * n_envs
+ env_seeds = list()
+ env_prefixs = list()
+ env_init_fn_dills = list()
+ # train
+ for i in range(n_train):
+ seed = train_start_seed + i
+ enable_render = i < n_train_vis
+
+ def init_fn(env, seed=seed, enable_render=enable_render):
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # set seed
+ assert isinstance(env, MultiStepWrapper)
+ env.seed(seed)
+
+ env_seeds.append(seed)
+ env_prefixs.append('train/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ # test
+ for i in range(n_test):
+ seed = test_start_seed + i
+ enable_render = i < n_test_vis
+
+ def init_fn(env, seed=seed, enable_render=enable_render):
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # set seed
+ assert isinstance(env, MultiStepWrapper)
+ env.seed(seed)
+
+ env_seeds.append(seed)
+ env_prefixs.append('test/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ env = AsyncVectorEnv(env_fns)
+ # env = SyncVectorEnv(env_fns)
+
+ self.env = env
+ self.env_fns = env_fns
+ self.env_seeds = env_seeds
+ self.env_prefixs = env_prefixs
+ self.env_init_fn_dills = env_init_fn_dills
+ self.fps = fps
+ self.crf = crf
+ self.n_obs_steps = n_obs_steps
+ self.n_action_steps = n_action_steps
+ self.past_action = past_action
+ self.max_steps = max_steps
+ self.tqdm_interval_sec = tqdm_interval_sec
+ self.obs_eef_target = obs_eef_target
+
+
+ def run(self, policy: BaseLowdimPolicy):
+ device = policy.device
+ dtype = policy.dtype
+ env = self.env
+
+ # plan for rollout
+ n_envs = len(self.env_fns)
+ n_inits = len(self.env_init_fn_dills)
+ n_chunks = math.ceil(n_inits / n_envs)
+
+ # allocate data
+ all_video_paths = [None] * n_inits
+ all_rewards = [None] * n_inits
+ last_info = [None] * n_inits
+
+ for chunk_idx in range(n_chunks):
+ start = chunk_idx * n_envs
+ end = min(n_inits, start + n_envs)
+ this_global_slice = slice(start, end)
+ this_n_active_envs = end - start
+ this_local_slice = slice(0,this_n_active_envs)
+
+ this_init_fns = self.env_init_fn_dills[this_global_slice]
+ n_diff = n_envs - len(this_init_fns)
+ if n_diff > 0:
+ this_init_fns.extend([self.env_init_fn_dills[0]]*n_diff)
+ assert len(this_init_fns) == n_envs
+
+ # init envs
+ env.call_each('run_dill_function',
+ args_list=[(x,) for x in this_init_fns])
+
+ # start rollout
+ obs = env.reset()
+ past_action = None
+ policy.reset()
+
+ pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval BlockPushLowdimRunner {chunk_idx+1}/{n_chunks}",
+ leave=False, mininterval=self.tqdm_interval_sec)
+ done = False
+ while not done:
+ # create obs dict
+ if not self.obs_eef_target:
+ obs[...,8:10] = 0
+ np_obs_dict = {
+ 'obs': obs.astype(np.float32)
+ }
+ if self.past_action and (past_action is not None):
+ # TODO: not tested
+ np_obs_dict['past_action'] = past_action[
+ :,-(self.n_obs_steps-1):].astype(np.float32)
+ # device transfer
+ obs_dict = dict_apply(np_obs_dict,
+ lambda x: torch.from_numpy(x).to(
+ device=device))
+
+ # run policy
+ with torch.no_grad():
+ action_dict = policy.predict_action(obs_dict)
+
+ # device_transfer
+ np_action_dict = dict_apply(action_dict,
+ lambda x: x.detach().to('cpu').numpy())
+
+ action = np_action_dict['action']
+
+ # step env
+ obs, reward, done, info = env.step(action)
+ done = np.all(done)
+ past_action = action
+
+ # update pbar
+ pbar.update(action.shape[1])
+ pbar.close()
+
+ # collect data for this round
+ all_video_paths[this_global_slice] = env.render()[this_local_slice]
+ all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
+ last_info[this_global_slice] = [dict((k,v[-1]) for k, v in x.items()) for x in info][this_local_slice]
+
+ # log
+ total_rewards = collections.defaultdict(list)
+ total_p1 = collections.defaultdict(list)
+ total_p2 = collections.defaultdict(list)
+ prefix_event_counts = collections.defaultdict(lambda :collections.defaultdict(lambda : 0))
+ prefix_counts = collections.defaultdict(lambda : 0)
+
+ log_data = dict()
+ # results reported in the paper are generated using the commented out line below
+ # which will only report and average metrics from first n_envs initial condition and seeds
+ # fortunately this won't invalidate our conclusion since
+ # 1. This bug only affects the variance of metrics, not their mean
+ # 2. All baseline methods are evaluated using the same code
+ # to completely reproduce reported numbers, uncomment this line:
+ # for i in range(len(self.env_fns)):
+ # and comment out this line
+ for i in range(n_inits):
+ seed = self.env_seeds[i]
+ prefix = self.env_prefixs[i]
+ this_rewards = all_rewards[i]
+ total_reward = np.unique(this_rewards).sum() # (0, 0.49, 0.51)
+ p1 = total_reward > 0.4
+ p2 = total_reward > 0.9
+
+ total_rewards[prefix].append(total_reward)
+ total_p1[prefix].append(p1)
+ total_p2[prefix].append(p2)
+ log_data[prefix+f'sim_max_reward_{seed}'] = total_reward
+
+ # aggregate event counts
+ prefix_counts[prefix] += 1
+ for key, value in last_info[i].items():
+ delta_count = 1 if value > 0 else 0
+ prefix_event_counts[prefix][key] += delta_count
+
+ # visualize sim
+ video_path = all_video_paths[i]
+ if video_path is not None:
+ sim_video = wandb.Video(video_path)
+ log_data[prefix+f'sim_video_{seed}'] = sim_video
+
+ # log aggregate metrics
+ for prefix, value in total_rewards.items():
+ name = prefix+'mean_score'
+ value = np.mean(value)
+ log_data[name] = value
+ for prefix, value in total_p1.items():
+ name = prefix+'p1'
+ value = np.mean(value)
+ log_data[name] = value
+ for prefix, value in total_p2.items():
+ name = prefix+'p2'
+ value = np.mean(value)
+ log_data[name] = value
+
+ # summarize probabilities
+ for prefix, events in prefix_event_counts.items():
+ prefix_count = prefix_counts[prefix]
+ for event, count in events.items():
+ prob = count / prefix_count
+ key = prefix + event
+ log_data[key] = prob
+
+ return log_data
diff --git a/third_party/diffusion_policy/diffusion_policy/env_runner/kitchen_lowdim_runner.py b/third_party/diffusion_policy/diffusion_policy/env_runner/kitchen_lowdim_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..39f562333a34a7d5f357b0134732eeb403565b97
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env_runner/kitchen_lowdim_runner.py
@@ -0,0 +1,319 @@
+import wandb
+import numpy as np
+import torch
+import collections
+import pathlib
+import tqdm
+import dill
+import math
+import logging
+import wandb.sdk.data_types.video as wv
+import gym
+import gym.spaces
+import multiprocessing as mp
+from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
+from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
+from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
+from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
+
+from diffusion_policy.policy.base_lowdim_policy import BaseLowdimPolicy
+from diffusion_policy.common.pytorch_util import dict_apply
+from diffusion_policy.env_runner.base_lowdim_runner import BaseLowdimRunner
+
+module_logger = logging.getLogger(__name__)
+
+class KitchenLowdimRunner(BaseLowdimRunner):
+ def __init__(self,
+ output_dir,
+ dataset_dir,
+ n_train=10,
+ n_train_vis=3,
+ train_start_seed=0,
+ n_test=22,
+ n_test_vis=6,
+ test_start_seed=10000,
+ max_steps=280,
+ n_obs_steps=2,
+ n_action_steps=8,
+ render_hw=(240,360),
+ fps=12.5,
+ crf=22,
+ past_action=False,
+ tqdm_interval_sec=5.0,
+ abs_action=False,
+ robot_noise_ratio=0.1,
+ n_envs=None
+ ):
+ super().__init__(output_dir)
+
+ if n_envs is None:
+ n_envs = n_train + n_test
+
+ task_fps = 12.5
+ steps_per_render = int(max(task_fps // fps, 1))
+
+ def env_fn():
+ from diffusion_policy.env.kitchen.v0 import KitchenAllV0
+ from diffusion_policy.env.kitchen.kitchen_lowdim_wrapper import KitchenLowdimWrapper
+ env = KitchenAllV0(use_abs_action=abs_action)
+ env.robot_noise_ratio = robot_noise_ratio
+ return MultiStepWrapper(
+ VideoRecordingWrapper(
+ KitchenLowdimWrapper(
+ env=env,
+ init_qpos=None,
+ init_qvel=None,
+ render_hw=tuple(render_hw)
+ ),
+ video_recoder=VideoRecorder.create_h264(
+ fps=fps,
+ codec='h264',
+ input_pix_fmt='rgb24',
+ crf=crf,
+ thread_type='FRAME',
+ thread_count=1
+ ),
+ file_path=None,
+ steps_per_render=steps_per_render
+ ),
+ n_obs_steps=n_obs_steps,
+ n_action_steps=n_action_steps,
+ max_episode_steps=max_steps
+ )
+
+ all_init_qpos = np.load(pathlib.Path(dataset_dir) / "all_init_qpos.npy")
+ all_init_qvel = np.load(pathlib.Path(dataset_dir) / "all_init_qvel.npy")
+ module_logger.info(f'Loaded {len(all_init_qpos)} known initial conditions.')
+
+ env_fns = [env_fn] * n_envs
+ env_seeds = list()
+ env_prefixs = list()
+ env_init_fn_dills = list()
+ # train
+ for i in range(n_train):
+ seed = train_start_seed + i
+ enable_render = i < n_train_vis
+ init_qpos = None
+ init_qvel = None
+ if i < len(all_init_qpos):
+ init_qpos = all_init_qpos[i]
+ init_qvel = all_init_qvel[i]
+
+ def init_fn(env, init_qpos=init_qpos, init_qvel=init_qvel, enable_render=enable_render):
+ from diffusion_policy.env.kitchen.kitchen_lowdim_wrapper import KitchenLowdimWrapper
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # set initial condition
+ assert isinstance(env.env.env, KitchenLowdimWrapper)
+ env.env.env.init_qpos = init_qpos
+ env.env.env.init_qvel = init_qvel
+
+ env_seeds.append(seed)
+ env_prefixs.append('train/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ # test
+ for i in range(n_test):
+ seed = test_start_seed + i
+ enable_render = i < n_test_vis
+
+ def init_fn(env, seed=seed, enable_render=enable_render):
+ from diffusion_policy.env.kitchen.kitchen_lowdim_wrapper import KitchenLowdimWrapper
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # set initial condition
+ assert isinstance(env.env.env, KitchenLowdimWrapper)
+ env.env.env.init_qpos = None
+ env.env.env.init_qvel = None
+
+ # set seed
+ assert isinstance(env, MultiStepWrapper)
+ env.seed(seed)
+
+ env_seeds.append(seed)
+ env_prefixs.append('test/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ def dummy_env_fn():
+ # Avoid importing or using env in the main process
+ # to prevent OpenGL context issue with fork.
+ # Create a fake env whose sole purpos is to provide
+ # obs/action spaces and metadata.
+ env = gym.Env()
+ env.observation_space = gym.spaces.Box(
+ -8, 8, shape=(60,), dtype=np.float32)
+ env.action_space = gym.spaces.Box(
+ -8, 8, shape=(9,), dtype=np.float32)
+ env.metadata = {
+ 'render.modes': ['human', 'rgb_array', 'depth_array'],
+ 'video.frames_per_second': 12
+ }
+ env = MultiStepWrapper(
+ env=env,
+ n_obs_steps=n_obs_steps,
+ n_action_steps=n_action_steps,
+ max_episode_steps=max_steps
+ )
+ return env
+
+ env = AsyncVectorEnv(env_fns, dummy_env_fn=dummy_env_fn)
+ # env = SyncVectorEnv(env_fns)
+
+ self.env = env
+ self.env_fns = env_fns
+ self.env_seeds = env_seeds
+ self.env_prefixs = env_prefixs
+ self.env_init_fn_dills = env_init_fn_dills
+ self.fps = fps
+ self.crf = crf
+ self.n_obs_steps = n_obs_steps
+ self.n_action_steps = n_action_steps
+ self.past_action = past_action
+ self.max_steps = max_steps
+ self.tqdm_interval_sec = tqdm_interval_sec
+
+
+ def run(self, policy: BaseLowdimPolicy):
+ device = policy.device
+ dtype = policy.dtype
+ env = self.env
+
+ # plan for rollout
+ n_envs = len(self.env_fns)
+ n_inits = len(self.env_init_fn_dills)
+ n_chunks = math.ceil(n_inits / n_envs)
+
+ # allocate data
+ all_video_paths = [None] * n_inits
+ all_rewards = [None] * n_inits
+ last_info = [None] * n_inits
+
+ for chunk_idx in range(n_chunks):
+ start = chunk_idx * n_envs
+ end = min(n_inits, start + n_envs)
+ this_global_slice = slice(start, end)
+ this_n_active_envs = end - start
+ this_local_slice = slice(0,this_n_active_envs)
+
+ this_init_fns = self.env_init_fn_dills[this_global_slice]
+ n_diff = n_envs - len(this_init_fns)
+ if n_diff > 0:
+ this_init_fns.extend([self.env_init_fn_dills[0]]*n_diff)
+ assert len(this_init_fns) == n_envs
+
+ # init envs
+ env.call_each('run_dill_function',
+ args_list=[(x,) for x in this_init_fns])
+
+ # start rollout
+ obs = env.reset()
+ past_action = None
+ policy.reset()
+
+ pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval BlockPushLowdimRunner {chunk_idx+1}/{n_chunks}",
+ leave=False, mininterval=self.tqdm_interval_sec)
+ done = False
+ while not done:
+ # create obs dict
+ np_obs_dict = {
+ 'obs': obs.astype(np.float32)
+ }
+ if self.past_action and (past_action is not None):
+ # TODO: not tested
+ np_obs_dict['past_action'] = past_action[
+ :,-(self.n_obs_steps-1):].astype(np.float32)
+ # device transfer
+ obs_dict = dict_apply(np_obs_dict,
+ lambda x: torch.from_numpy(x).to(
+ device=device))
+
+ # run policy
+ with torch.no_grad():
+ action_dict = policy.predict_action(obs_dict)
+
+ # device_transfer
+ np_action_dict = dict_apply(action_dict,
+ lambda x: x.detach().to('cpu').numpy())
+
+ action = np_action_dict['action']
+
+ # step env
+ obs, reward, done, info = env.step(action)
+ done = np.all(done)
+ past_action = action
+
+ # update pbar
+ pbar.update(action.shape[1])
+ pbar.close()
+
+ # collect data for this round
+ all_video_paths[this_global_slice] = env.render()[this_local_slice]
+ all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
+ last_info[this_global_slice] = [dict((k,v[-1]) for k, v in x.items()) for x in info][this_local_slice]
+
+ # reward is number of tasks completed, max 7
+ # use info to record the order of task completion?
+ # also report the probably to completing n tasks (different aggregation of reward).
+
+ # log
+ log_data = dict()
+ prefix_total_reward_map = collections.defaultdict(list)
+ prefix_n_completed_map = collections.defaultdict(list)
+ # results reported in the paper are generated using the commented out line below
+ # which will only report and average metrics from first n_envs initial condition and seeds
+ # fortunately this won't invalidate our conclusion since
+ # 1. This bug only affects the variance of metrics, not their mean
+ # 2. All baseline methods are evaluated using the same code
+ # to completely reproduce reported numbers, uncomment this line:
+ # for i in range(len(self.env_fns)):
+ # and comment out this line
+ for i in range(n_inits):
+ seed = self.env_seeds[i]
+ prefix = self.env_prefixs[i]
+ this_rewards = all_rewards[i]
+ total_reward = np.sum(this_rewards) / 7
+ prefix_total_reward_map[prefix].append(total_reward)
+
+ n_completed_tasks = len(last_info[i]['completed_tasks'])
+ prefix_n_completed_map[prefix].append(n_completed_tasks)
+
+ # visualize sim
+ video_path = all_video_paths[i]
+ if video_path is not None:
+ sim_video = wandb.Video(video_path)
+ log_data[prefix+f'sim_video_{seed}'] = sim_video
+
+ # log aggregate metrics
+ for prefix, value in prefix_total_reward_map.items():
+ name = prefix+'mean_score'
+ value = np.mean(value)
+ log_data[name] = value
+ for prefix, value in prefix_n_completed_map.items():
+ n_completed = np.array(value)
+ for i in range(7):
+ n = i + 1
+ p_n = np.mean(n_completed >= n)
+ name = prefix + f'p_{n}'
+ log_data[name] = p_n
+
+ return log_data
diff --git a/third_party/diffusion_policy/diffusion_policy/env_runner/pusht_image_runner.py b/third_party/diffusion_policy/diffusion_policy/env_runner/pusht_image_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..f65c06a81c365746f679e25fae1f91eb329a95a7
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env_runner/pusht_image_runner.py
@@ -0,0 +1,251 @@
+import wandb
+import numpy as np
+import torch
+import collections
+import pathlib
+import tqdm
+import dill
+import math
+import wandb.sdk.data_types.video as wv
+from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
+from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
+# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
+from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
+from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
+
+from diffusion_policy.policy.base_image_policy import BaseImagePolicy
+from diffusion_policy.common.pytorch_util import dict_apply
+from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
+
+class PushTImageRunner(BaseImageRunner):
+ def __init__(self,
+ output_dir,
+ n_train=10,
+ n_train_vis=3,
+ train_start_seed=0,
+ n_test=22,
+ n_test_vis=6,
+ legacy_test=False,
+ test_start_seed=10000,
+ max_steps=200,
+ n_obs_steps=8,
+ n_action_steps=8,
+ fps=10,
+ crf=22,
+ render_size=96,
+ past_action=False,
+ tqdm_interval_sec=5.0,
+ n_envs=None
+ ):
+ super().__init__(output_dir)
+ if n_envs is None:
+ n_envs = n_train + n_test
+
+ steps_per_render = max(10 // fps, 1)
+ def env_fn():
+ return MultiStepWrapper(
+ VideoRecordingWrapper(
+ PushTImageEnv(
+ legacy=legacy_test,
+ render_size=render_size
+ ),
+ video_recoder=VideoRecorder.create_h264(
+ fps=fps,
+ codec='h264',
+ input_pix_fmt='rgb24',
+ crf=crf,
+ thread_type='FRAME',
+ thread_count=1
+ ),
+ file_path=None,
+ steps_per_render=steps_per_render
+ ),
+ n_obs_steps=n_obs_steps,
+ n_action_steps=n_action_steps,
+ max_episode_steps=max_steps
+ )
+
+ env_fns = [env_fn] * n_envs
+ env_seeds = list()
+ env_prefixs = list()
+ env_init_fn_dills = list()
+ # train
+ for i in range(n_train):
+ seed = train_start_seed + i
+ enable_render = i < n_train_vis
+
+ def init_fn(env, seed=seed, enable_render=enable_render):
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # set seed
+ assert isinstance(env, MultiStepWrapper)
+ env.seed(seed)
+
+ env_seeds.append(seed)
+ env_prefixs.append('train/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ # test
+ for i in range(n_test):
+ seed = test_start_seed + i
+ enable_render = i < n_test_vis
+
+ def init_fn(env, seed=seed, enable_render=enable_render):
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # set seed
+ assert isinstance(env, MultiStepWrapper)
+ env.seed(seed)
+
+ env_seeds.append(seed)
+ env_prefixs.append('test/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ env = AsyncVectorEnv(env_fns)
+
+ # test env
+ # env.reset(seed=env_seeds)
+ # x = env.step(env.action_space.sample())
+ # imgs = env.call('render')
+ # import pdb; pdb.set_trace()
+
+ self.env = env
+ self.env_fns = env_fns
+ self.env_seeds = env_seeds
+ self.env_prefixs = env_prefixs
+ self.env_init_fn_dills = env_init_fn_dills
+ self.fps = fps
+ self.crf = crf
+ self.n_obs_steps = n_obs_steps
+ self.n_action_steps = n_action_steps
+ self.past_action = past_action
+ self.max_steps = max_steps
+ self.tqdm_interval_sec = tqdm_interval_sec
+
+ def run(self, policy: BaseImagePolicy):
+ device = policy.device
+ dtype = policy.dtype
+ env = self.env
+
+ # plan for rollout
+ n_envs = len(self.env_fns)
+ n_inits = len(self.env_init_fn_dills)
+ n_chunks = math.ceil(n_inits / n_envs)
+
+ # allocate data
+ all_video_paths = [None] * n_inits
+ all_rewards = [None] * n_inits
+
+ for chunk_idx in range(n_chunks):
+ start = chunk_idx * n_envs
+ end = min(n_inits, start + n_envs)
+ this_global_slice = slice(start, end)
+ this_n_active_envs = end - start
+ this_local_slice = slice(0,this_n_active_envs)
+
+ this_init_fns = self.env_init_fn_dills[this_global_slice]
+ n_diff = n_envs - len(this_init_fns)
+ if n_diff > 0:
+ this_init_fns.extend([self.env_init_fn_dills[0]]*n_diff)
+ assert len(this_init_fns) == n_envs
+
+ # init envs
+ env.call_each('run_dill_function',
+ args_list=[(x,) for x in this_init_fns])
+
+ # start rollout
+ obs = env.reset()
+ past_action = None
+ policy.reset()
+
+ pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval PushtImageRunner {chunk_idx+1}/{n_chunks}",
+ leave=False, mininterval=self.tqdm_interval_sec)
+ done = False
+ while not done:
+ # create obs dict
+ np_obs_dict = dict(obs)
+ if self.past_action and (past_action is not None):
+ # TODO: not tested
+ np_obs_dict['past_action'] = past_action[
+ :,-(self.n_obs_steps-1):].astype(np.float32)
+
+ # device transfer
+ obs_dict = dict_apply(np_obs_dict,
+ lambda x: torch.from_numpy(x).to(
+ device=device))
+
+ # run policy
+ with torch.no_grad():
+ action_dict = policy.predict_action(obs_dict)
+
+ # device_transfer
+ np_action_dict = dict_apply(action_dict,
+ lambda x: x.detach().to('cpu').numpy())
+
+ action = np_action_dict['action']
+
+ # step env
+ obs, reward, done, info = env.step(action)
+ done = np.all(done)
+ past_action = action
+
+ # update pbar
+ pbar.update(action.shape[1])
+ pbar.close()
+
+ all_video_paths[this_global_slice] = env.render()[this_local_slice]
+ all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
+ # clear out video buffer
+ _ = env.reset()
+
+ # log
+ max_rewards = collections.defaultdict(list)
+ log_data = dict()
+ # results reported in the paper are generated using the commented out line below
+ # which will only report and average metrics from first n_envs initial condition and seeds
+ # fortunately this won't invalidate our conclusion since
+ # 1. This bug only affects the variance of metrics, not their mean
+ # 2. All baseline methods are evaluated using the same code
+ # to completely reproduce reported numbers, uncomment this line:
+ # for i in range(len(self.env_fns)):
+ # and comment out this line
+ for i in range(n_inits):
+ seed = self.env_seeds[i]
+ prefix = self.env_prefixs[i]
+ max_reward = np.max(all_rewards[i])
+ max_rewards[prefix].append(max_reward)
+ log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
+
+ # visualize sim
+ video_path = all_video_paths[i]
+ if video_path is not None:
+ sim_video = wandb.Video(video_path)
+ log_data[prefix+f'sim_video_{seed}'] = sim_video
+
+ # log aggregate metrics
+ for prefix, value in max_rewards.items():
+ name = prefix+'mean_score'
+ value = np.mean(value)
+ log_data[name] = value
+
+ return log_data
diff --git a/third_party/diffusion_policy/diffusion_policy/env_runner/pusht_keypoints_runner.py b/third_party/diffusion_policy/diffusion_policy/env_runner/pusht_keypoints_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..a16bd58ed17fe3a8bd23e4f3f22e209fcc195d00
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env_runner/pusht_keypoints_runner.py
@@ -0,0 +1,273 @@
+import wandb
+import numpy as np
+import torch
+import collections
+import pathlib
+import tqdm
+import dill
+import math
+import wandb.sdk.data_types.video as wv
+from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
+from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
+# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
+from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
+from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
+
+from diffusion_policy.policy.base_lowdim_policy import BaseLowdimPolicy
+from diffusion_policy.common.pytorch_util import dict_apply
+from diffusion_policy.env_runner.base_lowdim_runner import BaseLowdimRunner
+
+class PushTKeypointsRunner(BaseLowdimRunner):
+ def __init__(self,
+ output_dir,
+ keypoint_visible_rate=1.0,
+ n_train=10,
+ n_train_vis=3,
+ train_start_seed=0,
+ n_test=22,
+ n_test_vis=6,
+ legacy_test=False,
+ test_start_seed=10000,
+ max_steps=200,
+ n_obs_steps=8,
+ n_action_steps=8,
+ n_latency_steps=0,
+ fps=10,
+ crf=22,
+ agent_keypoints=False,
+ past_action=False,
+ tqdm_interval_sec=5.0,
+ n_envs=None
+ ):
+ super().__init__(output_dir)
+
+ if n_envs is None:
+ n_envs = n_train + n_test
+
+ # handle latency step
+ # to mimic latency, we request n_latency_steps additional steps
+ # of past observations, and the discard the last n_latency_steps
+ env_n_obs_steps = n_obs_steps + n_latency_steps
+ env_n_action_steps = n_action_steps
+
+ # assert n_obs_steps <= n_action_steps
+ kp_kwargs = PushTKeypointsEnv.genenerate_keypoint_manager_params()
+
+ def env_fn():
+ return MultiStepWrapper(
+ VideoRecordingWrapper(
+ PushTKeypointsEnv(
+ legacy=legacy_test,
+ keypoint_visible_rate=keypoint_visible_rate,
+ agent_keypoints=agent_keypoints,
+ **kp_kwargs
+ ),
+ video_recoder=VideoRecorder.create_h264(
+ fps=fps,
+ codec='h264',
+ input_pix_fmt='rgb24',
+ crf=crf,
+ thread_type='FRAME',
+ thread_count=1
+ ),
+ file_path=None,
+ ),
+ n_obs_steps=env_n_obs_steps,
+ n_action_steps=env_n_action_steps,
+ max_episode_steps=max_steps
+ )
+
+ env_fns = [env_fn] * n_envs
+ env_seeds = list()
+ env_prefixs = list()
+ env_init_fn_dills = list()
+ # train
+ for i in range(n_train):
+ seed = train_start_seed + i
+ enable_render = i < n_train_vis
+
+ def init_fn(env, seed=seed, enable_render=enable_render):
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # set seed
+ assert isinstance(env, MultiStepWrapper)
+ env.seed(seed)
+
+ env_seeds.append(seed)
+ env_prefixs.append('train/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ # test
+ for i in range(n_test):
+ seed = test_start_seed + i
+ enable_render = i < n_test_vis
+
+ def init_fn(env, seed=seed, enable_render=enable_render):
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # set seed
+ assert isinstance(env, MultiStepWrapper)
+ env.seed(seed)
+
+ env_seeds.append(seed)
+ env_prefixs.append('test/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ env = AsyncVectorEnv(env_fns)
+
+ # test env
+ # env.reset(seed=env_seeds)
+ # x = env.step(env.action_space.sample())
+ # imgs = env.call('render')
+ # import pdb; pdb.set_trace()
+
+ self.env = env
+ self.env_fns = env_fns
+ self.env_seeds = env_seeds
+ self.env_prefixs = env_prefixs
+ self.env_init_fn_dills = env_init_fn_dills
+ self.fps = fps
+ self.crf = crf
+ self.agent_keypoints = agent_keypoints
+ self.n_obs_steps = n_obs_steps
+ self.n_action_steps = n_action_steps
+ self.n_latency_steps = n_latency_steps
+ self.past_action = past_action
+ self.max_steps = max_steps
+ self.tqdm_interval_sec = tqdm_interval_sec
+
+ def run(self, policy: BaseLowdimPolicy):
+ device = policy.device
+ dtype = policy.dtype
+
+ env = self.env
+
+ # plan for rollout
+ n_envs = len(self.env_fns)
+ n_inits = len(self.env_init_fn_dills)
+ n_chunks = math.ceil(n_inits / n_envs)
+
+ # allocate data
+ all_video_paths = [None] * n_inits
+ all_rewards = [None] * n_inits
+
+ for chunk_idx in range(n_chunks):
+ start = chunk_idx * n_envs
+ end = min(n_inits, start + n_envs)
+ this_global_slice = slice(start, end)
+ this_n_active_envs = end - start
+ this_local_slice = slice(0,this_n_active_envs)
+
+ this_init_fns = self.env_init_fn_dills[this_global_slice]
+ n_diff = n_envs - len(this_init_fns)
+ if n_diff > 0:
+ this_init_fns.extend([self.env_init_fn_dills[0]]*n_diff)
+ assert len(this_init_fns) == n_envs
+
+ # init envs
+ env.call_each('run_dill_function',
+ args_list=[(x,) for x in this_init_fns])
+
+ # start rollout
+ obs = env.reset()
+ past_action = None
+ policy.reset()
+
+ pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval PushtKeypointsRunner {chunk_idx+1}/{n_chunks}",
+ leave=False, mininterval=self.tqdm_interval_sec)
+ done = False
+ while not done:
+ Do = obs.shape[-1] // 2
+ # create obs dict
+ np_obs_dict = {
+ # handle n_latency_steps by discarding the last n_latency_steps
+ 'obs': obs[...,:self.n_obs_steps,:Do].astype(np.float32),
+ 'obs_mask': obs[...,:self.n_obs_steps,Do:] > 0.5
+ }
+ if self.past_action and (past_action is not None):
+ # TODO: not tested
+ np_obs_dict['past_action'] = past_action[
+ :,-(self.n_obs_steps-1):].astype(np.float32)
+
+ # device transfer
+ obs_dict = dict_apply(np_obs_dict,
+ lambda x: torch.from_numpy(x).to(
+ device=device))
+
+ # run policy
+ with torch.no_grad():
+ action_dict = policy.predict_action(obs_dict)
+
+ # device_transfer
+ np_action_dict = dict_apply(action_dict,
+ lambda x: x.detach().to('cpu').numpy())
+
+ # handle latency_steps, we discard the first n_latency_steps actions
+ # to simulate latency
+ action = np_action_dict['action'][:,self.n_latency_steps:]
+
+ # step env
+ obs, reward, done, info = env.step(action)
+ done = np.all(done)
+ past_action = action
+
+ # update pbar
+ pbar.update(action.shape[1])
+ pbar.close()
+
+ # collect data for this round
+ all_video_paths[this_global_slice] = env.render()[this_local_slice]
+ all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
+ # import pdb; pdb.set_trace()
+
+ # log
+ max_rewards = collections.defaultdict(list)
+ log_data = dict()
+ # results reported in the paper are generated using the commented out line below
+ # which will only report and average metrics from first n_envs initial condition and seeds
+ # fortunately this won't invalidate our conclusion since
+ # 1. This bug only affects the variance of metrics, not their mean
+ # 2. All baseline methods are evaluated using the same code
+ # to completely reproduce reported numbers, uncomment this line:
+ # for i in range(len(self.env_fns)):
+ # and comment out this line
+ for i in range(n_inits):
+ seed = self.env_seeds[i]
+ prefix = self.env_prefixs[i]
+ max_reward = np.max(all_rewards[i])
+ max_rewards[prefix].append(max_reward)
+ log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
+
+ # visualize sim
+ video_path = all_video_paths[i]
+ if video_path is not None:
+ sim_video = wandb.Video(video_path)
+ log_data[prefix+f'sim_video_{seed}'] = sim_video
+
+ # log aggregate metrics
+ for prefix, value in max_rewards.items():
+ name = prefix+'mean_score'
+ value = np.mean(value)
+ log_data[name] = value
+
+ return log_data
diff --git a/third_party/diffusion_policy/diffusion_policy/env_runner/real_pusht_image_runner.py b/third_party/diffusion_policy/diffusion_policy/env_runner/real_pusht_image_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b58780ce669a23722ccbbb0d13071ee25f6751f
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env_runner/real_pusht_image_runner.py
@@ -0,0 +1,10 @@
+from diffusion_policy.policy.base_image_policy import BaseImagePolicy
+from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
+
+class RealPushTImageRunner(BaseImageRunner):
+ def __init__(self,
+ output_dir):
+ super().__init__(output_dir)
+
+ def run(self, policy: BaseImagePolicy):
+ return dict()
diff --git a/third_party/diffusion_policy/diffusion_policy/env_runner/robomimic_image_runner.py b/third_party/diffusion_policy/diffusion_policy/env_runner/robomimic_image_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbae74e86ce67e4470c42ba645a9f19f59e33e3a
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env_runner/robomimic_image_runner.py
@@ -0,0 +1,375 @@
+import os
+import wandb
+import numpy as np
+import torch
+import collections
+import pathlib
+import tqdm
+import h5py
+import math
+import dill
+import wandb.sdk.data_types.video as wv
+from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
+from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
+from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
+from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
+from diffusion_policy.model.common.rotation_transformer import RotationTransformer
+
+from diffusion_policy.policy.base_image_policy import BaseImagePolicy
+from diffusion_policy.common.pytorch_util import dict_apply
+from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
+from diffusion_policy.env.robomimic.robomimic_image_wrapper import RobomimicImageWrapper
+import robomimic.utils.file_utils as FileUtils
+import robomimic.utils.env_utils as EnvUtils
+import robomimic.utils.obs_utils as ObsUtils
+
+
+def create_env(env_meta, shape_meta, enable_render=True):
+ modality_mapping = collections.defaultdict(list)
+ for key, attr in shape_meta['obs'].items():
+ modality_mapping[attr.get('type', 'low_dim')].append(key)
+ ObsUtils.initialize_obs_modality_mapping_from_dict(modality_mapping)
+
+ env = EnvUtils.create_env_from_metadata(
+ env_meta=env_meta,
+ render=False,
+ render_offscreen=enable_render,
+ use_image_obs=enable_render,
+ )
+ return env
+
+
+class RobomimicImageRunner(BaseImageRunner):
+ """
+ Robomimic envs already enforces number of steps.
+ """
+
+ def __init__(self,
+ output_dir,
+ dataset_path,
+ shape_meta:dict,
+ n_train=10,
+ n_train_vis=3,
+ train_start_idx=0,
+ n_test=22,
+ n_test_vis=6,
+ test_start_seed=10000,
+ max_steps=400,
+ n_obs_steps=2,
+ n_action_steps=8,
+ render_obs_key='agentview_image',
+ fps=10,
+ crf=22,
+ past_action=False,
+ abs_action=False,
+ tqdm_interval_sec=5.0,
+ n_envs=None
+ ):
+ super().__init__(output_dir)
+
+ if n_envs is None:
+ n_envs = n_train + n_test
+
+ # assert n_obs_steps <= n_action_steps
+ dataset_path = os.path.expanduser(dataset_path)
+ robosuite_fps = 20
+ steps_per_render = max(robosuite_fps // fps, 1)
+
+ # read from dataset
+ env_meta = FileUtils.get_env_metadata_from_dataset(
+ dataset_path)
+ # disable object state observation
+ env_meta['env_kwargs']['use_object_obs'] = False
+
+ rotation_transformer = None
+ if abs_action:
+ env_meta['env_kwargs']['controller_configs']['control_delta'] = False
+ rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')
+
+ def env_fn():
+ robomimic_env = create_env(
+ env_meta=env_meta,
+ shape_meta=shape_meta
+ )
+ # Robosuite's hard reset causes excessive memory consumption.
+ # Disabled to run more envs.
+ # https://github.com/ARISE-Initiative/robosuite/blob/92abf5595eddb3a845cd1093703e5a3ccd01e77e/robosuite/environments/base.py#L247-L248
+ robomimic_env.env.hard_reset = False
+ return MultiStepWrapper(
+ VideoRecordingWrapper(
+ RobomimicImageWrapper(
+ env=robomimic_env,
+ shape_meta=shape_meta,
+ init_state=None,
+ render_obs_key=render_obs_key
+ ),
+ video_recoder=VideoRecorder.create_h264(
+ fps=fps,
+ codec='h264',
+ input_pix_fmt='rgb24',
+ crf=crf,
+ thread_type='FRAME',
+ thread_count=1
+ ),
+ file_path=None,
+ steps_per_render=steps_per_render
+ ),
+ n_obs_steps=n_obs_steps,
+ n_action_steps=n_action_steps,
+ max_episode_steps=max_steps
+ )
+
+ # For each process the OpenGL context can only be initialized once
+ # Since AsyncVectorEnv uses fork to create worker process,
+ # a separate env_fn that does not create OpenGL context (enable_render=False)
+ # is needed to initialize spaces.
+ def dummy_env_fn():
+ robomimic_env = create_env(
+ env_meta=env_meta,
+ shape_meta=shape_meta,
+ enable_render=False
+ )
+ return MultiStepWrapper(
+ VideoRecordingWrapper(
+ RobomimicImageWrapper(
+ env=robomimic_env,
+ shape_meta=shape_meta,
+ init_state=None,
+ render_obs_key=render_obs_key
+ ),
+ video_recoder=VideoRecorder.create_h264(
+ fps=fps,
+ codec='h264',
+ input_pix_fmt='rgb24',
+ crf=crf,
+ thread_type='FRAME',
+ thread_count=1
+ ),
+ file_path=None,
+ steps_per_render=steps_per_render
+ ),
+ n_obs_steps=n_obs_steps,
+ n_action_steps=n_action_steps,
+ max_episode_steps=max_steps
+ )
+
+ env_fns = [env_fn] * n_envs
+ env_seeds = list()
+ env_prefixs = list()
+ env_init_fn_dills = list()
+
+ # train
+ with h5py.File(dataset_path, 'r') as f:
+ for i in range(n_train):
+ train_idx = train_start_idx + i
+ enable_render = i < n_train_vis
+ init_state = f[f'data/demo_{train_idx}/states'][0]
+
+ def init_fn(env, init_state=init_state,
+ enable_render=enable_render):
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # switch to init_state reset
+ assert isinstance(env.env.env, RobomimicImageWrapper)
+ env.env.env.init_state = init_state
+
+ env_seeds.append(train_idx)
+ env_prefixs.append('train/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ # test
+ for i in range(n_test):
+ seed = test_start_seed + i
+ enable_render = i < n_test_vis
+
+ def init_fn(env, seed=seed,
+ enable_render=enable_render):
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # switch to seed reset
+ assert isinstance(env.env.env, RobomimicImageWrapper)
+ env.env.env.init_state = None
+ env.seed(seed)
+
+ env_seeds.append(seed)
+ env_prefixs.append('test/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ env = AsyncVectorEnv(env_fns, dummy_env_fn=dummy_env_fn)
+ # env = SyncVectorEnv(env_fns)
+
+
+ self.env_meta = env_meta
+ self.env = env
+ self.env_fns = env_fns
+ self.env_seeds = env_seeds
+ self.env_prefixs = env_prefixs
+ self.env_init_fn_dills = env_init_fn_dills
+ self.fps = fps
+ self.crf = crf
+ self.n_obs_steps = n_obs_steps
+ self.n_action_steps = n_action_steps
+ self.past_action = past_action
+ self.max_steps = max_steps
+ self.rotation_transformer = rotation_transformer
+ self.abs_action = abs_action
+ self.tqdm_interval_sec = tqdm_interval_sec
+
+ def run(self, policy: BaseImagePolicy):
+ device = policy.device
+ dtype = policy.dtype
+ env = self.env
+
+ # plan for rollout
+ n_envs = len(self.env_fns)
+ n_inits = len(self.env_init_fn_dills)
+ n_chunks = math.ceil(n_inits / n_envs)
+
+ # allocate data
+ all_video_paths = [None] * n_inits
+ all_rewards = [None] * n_inits
+
+ for chunk_idx in range(n_chunks):
+ start = chunk_idx * n_envs
+ end = min(n_inits, start + n_envs)
+ this_global_slice = slice(start, end)
+ this_n_active_envs = end - start
+ this_local_slice = slice(0,this_n_active_envs)
+
+ this_init_fns = self.env_init_fn_dills[this_global_slice]
+ n_diff = n_envs - len(this_init_fns)
+ if n_diff > 0:
+ this_init_fns.extend([self.env_init_fn_dills[0]]*n_diff)
+ assert len(this_init_fns) == n_envs
+
+ # init envs
+ env.call_each('run_dill_function',
+ args_list=[(x,) for x in this_init_fns])
+
+ # start rollout
+ obs = env.reset()
+ past_action = None
+ policy.reset()
+
+ env_name = self.env_meta['env_name']
+ pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval {env_name}Image {chunk_idx+1}/{n_chunks}",
+ leave=False, mininterval=self.tqdm_interval_sec)
+
+ done = False
+ while not done:
+ # create obs dict
+ np_obs_dict = dict(obs)
+ if self.past_action and (past_action is not None):
+ # TODO: not tested
+ np_obs_dict['past_action'] = past_action[
+ :,-(self.n_obs_steps-1):].astype(np.float32)
+
+ # device transfer
+ obs_dict = dict_apply(np_obs_dict,
+ lambda x: torch.from_numpy(x).to(
+ device=device))
+
+ # run policy
+ with torch.no_grad():
+ action_dict = policy.predict_action(obs_dict)
+
+ # device_transfer
+ np_action_dict = dict_apply(action_dict,
+ lambda x: x.detach().to('cpu').numpy())
+
+ action = np_action_dict['action']
+ if not np.all(np.isfinite(action)):
+ print(action)
+ raise RuntimeError("Nan or Inf action")
+
+ # step env
+ env_action = action
+ if self.abs_action:
+ env_action = self.undo_transform_action(action)
+
+ obs, reward, done, info = env.step(env_action)
+ done = np.all(done)
+ past_action = action
+
+ # update pbar
+ pbar.update(action.shape[1])
+ pbar.close()
+
+ # collect data for this round
+ all_video_paths[this_global_slice] = env.render()[this_local_slice]
+ all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
+ # clear out video buffer
+ _ = env.reset()
+
+ # log
+ max_rewards = collections.defaultdict(list)
+ log_data = dict()
+ # results reported in the paper are generated using the commented out line below
+ # which will only report and average metrics from first n_envs initial condition and seeds
+ # fortunately this won't invalidate our conclusion since
+ # 1. This bug only affects the variance of metrics, not their mean
+ # 2. All baseline methods are evaluated using the same code
+ # to completely reproduce reported numbers, uncomment this line:
+ # for i in range(len(self.env_fns)):
+ # and comment out this line
+ for i in range(n_inits):
+ seed = self.env_seeds[i]
+ prefix = self.env_prefixs[i]
+ max_reward = np.max(all_rewards[i])
+ max_rewards[prefix].append(max_reward)
+ log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
+
+ # visualize sim
+ video_path = all_video_paths[i]
+ if video_path is not None:
+ sim_video = wandb.Video(video_path)
+ log_data[prefix+f'sim_video_{seed}'] = sim_video
+
+ # log aggregate metrics
+ for prefix, value in max_rewards.items():
+ name = prefix+'mean_score'
+ value = np.mean(value)
+ log_data[name] = value
+
+ return log_data
+
+ def undo_transform_action(self, action):
+ raw_shape = action.shape
+ if raw_shape[-1] == 20:
+ # dual arm
+ action = action.reshape(-1,2,10)
+
+ d_rot = action.shape[-1] - 4
+ pos = action[...,:3]
+ rot = action[...,3:3+d_rot]
+ gripper = action[...,[-1]]
+ rot = self.rotation_transformer.inverse(rot)
+ uaction = np.concatenate([
+ pos, rot, gripper
+ ], axis=-1)
+
+ if raw_shape[-1] == 20:
+ # dual arm
+ uaction = uaction.reshape(*raw_shape[:-1], 14)
+
+ return uaction
diff --git a/third_party/diffusion_policy/diffusion_policy/env_runner/robomimic_lowdim_runner.py b/third_party/diffusion_policy/diffusion_policy/env_runner/robomimic_lowdim_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ba642e47f97774a6cc1606b0c8e8d1f952fb1c
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/env_runner/robomimic_lowdim_runner.py
@@ -0,0 +1,368 @@
+import os
+import wandb
+import numpy as np
+import torch
+import collections
+import pathlib
+import tqdm
+import h5py
+import dill
+import math
+import wandb.sdk.data_types.video as wv
+from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
+# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
+from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
+from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
+from diffusion_policy.model.common.rotation_transformer import RotationTransformer
+
+from diffusion_policy.policy.base_lowdim_policy import BaseLowdimPolicy
+from diffusion_policy.common.pytorch_util import dict_apply
+from diffusion_policy.env_runner.base_lowdim_runner import BaseLowdimRunner
+from diffusion_policy.env.robomimic.robomimic_lowdim_wrapper import RobomimicLowdimWrapper
+import robomimic.utils.file_utils as FileUtils
+import robomimic.utils.env_utils as EnvUtils
+import robomimic.utils.obs_utils as ObsUtils
+
+
+def create_env(env_meta, obs_keys):
+ ObsUtils.initialize_obs_modality_mapping_from_dict(
+ {'low_dim': obs_keys})
+ env = EnvUtils.create_env_from_metadata(
+ env_meta=env_meta,
+ render=False,
+ # only way to not show collision geometry
+ # is to enable render_offscreen
+ # which uses a lot of RAM.
+ render_offscreen=False,
+ use_image_obs=False,
+ )
+ return env
+
+
+class RobomimicLowdimRunner(BaseLowdimRunner):
+ """
+ Robomimic envs already enforces number of steps.
+ """
+
+ def __init__(self,
+ output_dir,
+ dataset_path,
+ obs_keys,
+ n_train=10,
+ n_train_vis=3,
+ train_start_idx=0,
+ n_test=22,
+ n_test_vis=6,
+ test_start_seed=10000,
+ max_steps=400,
+ n_obs_steps=2,
+ n_action_steps=8,
+ n_latency_steps=0,
+ render_hw=(256,256),
+ render_camera_name='agentview',
+ fps=10,
+ crf=22,
+ past_action=False,
+ abs_action=False,
+ tqdm_interval_sec=5.0,
+ n_envs=None
+ ):
+ """
+ Assuming:
+ n_obs_steps=2
+ n_latency_steps=3
+ n_action_steps=4
+ o: obs
+ i: inference
+ a: action
+ Batch t:
+ |o|o| | | | | | |
+ | |i|i|i| | | | |
+ | | | | |a|a|a|a|
+ Batch t+1
+ | | | | |o|o| | | | | | |
+ | | | | | |i|i|i| | | | |
+ | | | | | | | | |a|a|a|a|
+ """
+
+ super().__init__(output_dir)
+
+ if n_envs is None:
+ n_envs = n_train + n_test
+
+ # handle latency step
+ # to mimic latency, we request n_latency_steps additional steps
+ # of past observations, and the discard the last n_latency_steps
+ env_n_obs_steps = n_obs_steps + n_latency_steps
+ env_n_action_steps = n_action_steps
+
+ # assert n_obs_steps <= n_action_steps
+ dataset_path = os.path.expanduser(dataset_path)
+ robosuite_fps = 20
+ steps_per_render = max(robosuite_fps // fps, 1)
+
+ # read from dataset
+ env_meta = FileUtils.get_env_metadata_from_dataset(
+ dataset_path)
+ rotation_transformer = None
+ if abs_action:
+ env_meta['env_kwargs']['controller_configs']['control_delta'] = False
+ rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')
+
+ def env_fn():
+ robomimic_env = create_env(
+ env_meta=env_meta,
+ obs_keys=obs_keys
+ )
+ # hard reset doesn't influence lowdim env
+ # robomimic_env.env.hard_reset = False
+ return MultiStepWrapper(
+ VideoRecordingWrapper(
+ RobomimicLowdimWrapper(
+ env=robomimic_env,
+ obs_keys=obs_keys,
+ init_state=None,
+ render_hw=render_hw,
+ render_camera_name=render_camera_name
+ ),
+ video_recoder=VideoRecorder.create_h264(
+ fps=fps,
+ codec='h264',
+ input_pix_fmt='rgb24',
+ crf=crf,
+ thread_type='FRAME',
+ thread_count=1
+ ),
+ file_path=None,
+ steps_per_render=steps_per_render
+ ),
+ n_obs_steps=env_n_obs_steps,
+ n_action_steps=env_n_action_steps,
+ max_episode_steps=max_steps
+ )
+
+ env_fns = [env_fn] * n_envs
+ env_seeds = list()
+ env_prefixs = list()
+ env_init_fn_dills = list()
+
+ # train
+ with h5py.File(dataset_path, 'r') as f:
+ for i in range(n_train):
+ train_idx = train_start_idx + i
+ enable_render = i < n_train_vis
+ init_state = f[f'data/demo_{train_idx}/states'][0]
+
+ def init_fn(env, init_state=init_state,
+ enable_render=enable_render):
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # switch to init_state reset
+ assert isinstance(env.env.env, RobomimicLowdimWrapper)
+ env.env.env.init_state = init_state
+
+ env_seeds.append(train_idx)
+ env_prefixs.append('train/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ # test
+ for i in range(n_test):
+ seed = test_start_seed + i
+ enable_render = i < n_test_vis
+
+ def init_fn(env, seed=seed,
+ enable_render=enable_render):
+ # setup rendering
+ # video_wrapper
+ assert isinstance(env.env, VideoRecordingWrapper)
+ env.env.video_recoder.stop()
+ env.env.file_path = None
+ if enable_render:
+ filename = pathlib.Path(output_dir).joinpath(
+ 'media', wv.util.generate_id() + ".mp4")
+ filename.parent.mkdir(parents=False, exist_ok=True)
+ filename = str(filename)
+ env.env.file_path = filename
+
+ # switch to seed reset
+ assert isinstance(env.env.env, RobomimicLowdimWrapper)
+ env.env.env.init_state = None
+ env.seed(seed)
+
+ env_seeds.append(seed)
+ env_prefixs.append('test/')
+ env_init_fn_dills.append(dill.dumps(init_fn))
+
+ env = AsyncVectorEnv(env_fns)
+ # env = SyncVectorEnv(env_fns)
+
+ self.env_meta = env_meta
+ self.env = env
+ self.env_fns = env_fns
+ self.env_seeds = env_seeds
+ self.env_prefixs = env_prefixs
+ self.env_init_fn_dills = env_init_fn_dills
+ self.fps = fps
+ self.crf = crf
+ self.n_obs_steps = n_obs_steps
+ self.n_action_steps = n_action_steps
+ self.n_latency_steps = n_latency_steps
+ self.env_n_obs_steps = env_n_obs_steps
+ self.env_n_action_steps = env_n_action_steps
+ self.past_action = past_action
+ self.max_steps = max_steps
+ self.rotation_transformer = rotation_transformer
+ self.abs_action = abs_action
+ self.tqdm_interval_sec = tqdm_interval_sec
+
+ def run(self, policy: BaseLowdimPolicy):
+ device = policy.device
+ dtype = policy.dtype
+ env = self.env
+
+ # plan for rollout
+ n_envs = len(self.env_fns)
+ n_inits = len(self.env_init_fn_dills)
+ n_chunks = math.ceil(n_inits / n_envs)
+
+ # allocate data
+ all_video_paths = [None] * n_inits
+ all_rewards = [None] * n_inits
+
+ for chunk_idx in range(n_chunks):
+ start = chunk_idx * n_envs
+ end = min(n_inits, start + n_envs)
+ this_global_slice = slice(start, end)
+ this_n_active_envs = end - start
+ this_local_slice = slice(0,this_n_active_envs)
+
+ this_init_fns = self.env_init_fn_dills[this_global_slice]
+ n_diff = n_envs - len(this_init_fns)
+ if n_diff > 0:
+ this_init_fns.extend([self.env_init_fn_dills[0]]*n_diff)
+ assert len(this_init_fns) == n_envs
+
+ # init envs
+ env.call_each('run_dill_function',
+ args_list=[(x,) for x in this_init_fns])
+
+ # start rollout
+ obs = env.reset()
+ past_action = None
+ policy.reset()
+
+ env_name = self.env_meta['env_name']
+ pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval {env_name}Lowdim {chunk_idx+1}/{n_chunks}",
+ leave=False, mininterval=self.tqdm_interval_sec)
+
+ done = False
+ while not done:
+ # create obs dict
+ np_obs_dict = {
+ # handle n_latency_steps by discarding the last n_latency_steps
+ 'obs': obs[:,:self.n_obs_steps].astype(np.float32)
+ }
+ if self.past_action and (past_action is not None):
+ # TODO: not tested
+ np_obs_dict['past_action'] = past_action[
+ :,-(self.n_obs_steps-1):].astype(np.float32)
+
+ # device transfer
+ obs_dict = dict_apply(np_obs_dict,
+ lambda x: torch.from_numpy(x).to(
+ device=device))
+
+ # run policy
+ with torch.no_grad():
+ action_dict = policy.predict_action(obs_dict)
+
+ # device_transfer
+ np_action_dict = dict_apply(action_dict,
+ lambda x: x.detach().to('cpu').numpy())
+
+ # handle latency_steps, we discard the first n_latency_steps actions
+ # to simulate latency
+ action = np_action_dict['action'][:,self.n_latency_steps:]
+ if not np.all(np.isfinite(action)):
+ print(action)
+ raise RuntimeError("Nan or Inf action")
+
+ # step env
+ env_action = action
+ if self.abs_action:
+ env_action = self.undo_transform_action(action)
+
+ obs, reward, done, info = env.step(env_action)
+ done = np.all(done)
+ past_action = action
+
+ # update pbar
+ pbar.update(action.shape[1])
+ pbar.close()
+
+ # collect data for this round
+ all_video_paths[this_global_slice] = env.render()[this_local_slice]
+ all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
+
+ # log
+ max_rewards = collections.defaultdict(list)
+ log_data = dict()
+ # results reported in the paper are generated using the commented out line below
+ # which will only report and average metrics from first n_envs initial condition and seeds
+ # fortunately this won't invalidate our conclusion since
+ # 1. This bug only affects the variance of metrics, not their mean
+ # 2. All baseline methods are evaluated using the same code
+ # to completely reproduce reported numbers, uncomment this line:
+ # for i in range(len(self.env_fns)):
+ # and comment out this line
+ for i in range(n_inits):
+ seed = self.env_seeds[i]
+ prefix = self.env_prefixs[i]
+ max_reward = np.max(all_rewards[i])
+ max_rewards[prefix].append(max_reward)
+ log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
+
+ # visualize sim
+ video_path = all_video_paths[i]
+ if video_path is not None:
+ sim_video = wandb.Video(video_path)
+ log_data[prefix+f'sim_video_{seed}'] = sim_video
+
+ # log aggregate metrics
+ for prefix, value in max_rewards.items():
+ name = prefix+'mean_score'
+ value = np.mean(value)
+ log_data[name] = value
+
+ return log_data
+
+ def undo_transform_action(self, action):
+ raw_shape = action.shape
+ if raw_shape[-1] == 20:
+ # dual arm
+ action = action.reshape(-1,2,10)
+
+ d_rot = action.shape[-1] - 4
+ pos = action[...,:3]
+ rot = action[...,3:3+d_rot]
+ gripper = action[...,[-1]]
+ rot = self.rotation_transformer.inverse(rot)
+ uaction = np.concatenate([
+ pos, rot, gripper
+ ], axis=-1)
+
+ if raw_shape[-1] == 20:
+ # dual arm
+ uaction = uaction.reshape(*raw_shape[:-1], 14)
+
+ return uaction
diff --git a/third_party/diffusion_policy/diffusion_policy/gym_util/async_vector_env.py b/third_party/diffusion_policy/diffusion_policy/gym_util/async_vector_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfb0f620ac3d39861e0dfc1a3a58fb22d443f244
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/gym_util/async_vector_env.py
@@ -0,0 +1,671 @@
+"""
+Back ported methods: call, set_attr from v0.26
+Disabled auto-reset after done
+Added render method.
+"""
+
+
+import numpy as np
+import multiprocessing as mp
+import time
+import sys
+from enum import Enum
+from copy import deepcopy
+
+from gym import logger
+from gym.vector.vector_env import VectorEnv
+from gym.error import (
+ AlreadyPendingCallError,
+ NoAsyncCallError,
+ ClosedEnvironmentError,
+ CustomSpaceError,
+)
+from gym.vector.utils import (
+ create_shared_memory,
+ create_empty_array,
+ write_to_shared_memory,
+ read_from_shared_memory,
+ concatenate,
+ CloudpickleWrapper,
+ clear_mpi_env_vars,
+)
+
+__all__ = ["AsyncVectorEnv"]
+
+
+class AsyncState(Enum):
+ DEFAULT = "default"
+ WAITING_RESET = "reset"
+ WAITING_STEP = "step"
+ WAITING_CALL = "call"
+
+
+class AsyncVectorEnv(VectorEnv):
+ """Vectorized environment that runs multiple environments in parallel. It
+ uses `multiprocessing` processes, and pipes for communication.
+ Parameters
+ ----------
+ env_fns : iterable of callable
+ Functions that create the environments.
+ observation_space : `gym.spaces.Space` instance, optional
+ Observation space of a single environment. If `None`, then the
+ observation space of the first environment is taken.
+ action_space : `gym.spaces.Space` instance, optional
+ Action space of a single environment. If `None`, then the action space
+ of the first environment is taken.
+ shared_memory : bool (default: `True`)
+ If `True`, then the observations from the worker processes are
+ communicated back through shared variables. This can improve the
+ efficiency if the observations are large (e.g. images).
+ copy : bool (default: `True`)
+ If `True`, then the `reset` and `step` methods return a copy of the
+ observations.
+ context : str, optional
+ Context for multiprocessing. If `None`, then the default context is used.
+ Only available in Python 3.
+ daemon : bool (default: `True`)
+ If `True`, then subprocesses have `daemon` flag turned on; that is, they
+ will quit if the head process quits. However, `daemon=True` prevents
+ subprocesses to spawn children, so for some environments you may want
+ to have it set to `False`
+ worker : function, optional
+ WARNING - advanced mode option! If set, then use that worker in a subprocess
+ instead of a default one. Can be useful to override some inner vector env
+ logic, for instance, how resets on done are handled. Provides high
+ degree of flexibility and a high chance to shoot yourself in the foot; thus,
+ if you are writing your own worker, it is recommended to start from the code
+ for `_worker` (or `_worker_shared_memory`) method below, and add changes
+ """
+
+ def __init__(
+ self,
+ env_fns,
+ dummy_env_fn=None,
+ observation_space=None,
+ action_space=None,
+ shared_memory=True,
+ copy=True,
+ context=None,
+ daemon=True,
+ worker=None,
+ ):
+ ctx = mp.get_context(context)
+ self.env_fns = env_fns
+ self.shared_memory = shared_memory
+ self.copy = copy
+
+ # Added dummy_env_fn to fix OpenGL error in Mujoco
+ # disable any OpenGL rendering in dummy_env_fn, since it
+ # will conflict with OpenGL context in the forked child process
+ if dummy_env_fn is None:
+ dummy_env_fn = env_fns[0]
+ dummy_env = dummy_env_fn()
+ self.metadata = dummy_env.metadata
+
+ if (observation_space is None) or (action_space is None):
+ observation_space = observation_space or dummy_env.observation_space
+ action_space = action_space or dummy_env.action_space
+ dummy_env.close()
+ del dummy_env
+ super(AsyncVectorEnv, self).__init__(
+ num_envs=len(env_fns),
+ observation_space=observation_space,
+ action_space=action_space,
+ )
+
+ if self.shared_memory:
+ try:
+ _obs_buffer = create_shared_memory(
+ self.single_observation_space, n=self.num_envs, ctx=ctx
+ )
+ self.observations = read_from_shared_memory(
+ _obs_buffer, self.single_observation_space, n=self.num_envs
+ )
+ except CustomSpaceError:
+ raise ValueError(
+ "Using `shared_memory=True` in `AsyncVectorEnv` "
+ "is incompatible with non-standard Gym observation spaces "
+ "(i.e. custom spaces inheriting from `gym.Space`), and is "
+ "only compatible with default Gym spaces (e.g. `Box`, "
+ "`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
+ "if you use custom observation spaces."
+ )
+ else:
+ _obs_buffer = None
+ self.observations = create_empty_array(
+ self.single_observation_space, n=self.num_envs, fn=np.zeros
+ )
+
+ self.parent_pipes, self.processes = [], []
+ self.error_queue = ctx.Queue()
+ target = _worker_shared_memory if self.shared_memory else _worker
+ target = worker or target
+ with clear_mpi_env_vars():
+ for idx, env_fn in enumerate(self.env_fns):
+ parent_pipe, child_pipe = ctx.Pipe()
+ process = ctx.Process(
+ target=target,
+ name="Worker<{0}>-{1}".format(type(self).__name__, idx),
+ args=(
+ idx,
+ CloudpickleWrapper(env_fn),
+ child_pipe,
+ parent_pipe,
+ _obs_buffer,
+ self.error_queue,
+ ),
+ )
+
+ self.parent_pipes.append(parent_pipe)
+ self.processes.append(process)
+
+ process.daemon = daemon
+ process.start()
+ child_pipe.close()
+
+ self._state = AsyncState.DEFAULT
+ self._check_observation_spaces()
+
+ def seed(self, seeds=None):
+ self._assert_is_running()
+ if seeds is None:
+ seeds = [None for _ in range(self.num_envs)]
+ if isinstance(seeds, int):
+ seeds = [seeds + i for i in range(self.num_envs)]
+ assert len(seeds) == self.num_envs
+
+ if self._state != AsyncState.DEFAULT:
+ raise AlreadyPendingCallError(
+ "Calling `seed` while waiting "
+ "for a pending call to `{0}` to complete.".format(self._state.value),
+ self._state.value,
+ )
+
+ for pipe, seed in zip(self.parent_pipes, seeds):
+ pipe.send(("seed", seed))
+ _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
+ self._raise_if_errors(successes)
+
+ def reset_async(self):
+ self._assert_is_running()
+ if self._state != AsyncState.DEFAULT:
+ raise AlreadyPendingCallError(
+ "Calling `reset_async` while waiting "
+ "for a pending call to `{0}` to complete".format(self._state.value),
+ self._state.value,
+ )
+
+ for pipe in self.parent_pipes:
+ pipe.send(("reset", None))
+ self._state = AsyncState.WAITING_RESET
+
+ def reset_wait(self, timeout=None):
+ """
+ Parameters
+ ----------
+ timeout : int or float, optional
+ Number of seconds before the call to `reset_wait` times out. If
+ `None`, the call to `reset_wait` never times out.
+ Returns
+ -------
+ observations : sample from `observation_space`
+ A batch of observations from the vectorized environment.
+ """
+ self._assert_is_running()
+ if self._state != AsyncState.WAITING_RESET:
+ raise NoAsyncCallError(
+ "Calling `reset_wait` without any prior " "call to `reset_async`.",
+ AsyncState.WAITING_RESET.value,
+ )
+
+ if not self._poll(timeout):
+ self._state = AsyncState.DEFAULT
+ raise mp.TimeoutError(
+ "The call to `reset_wait` has timed out after "
+ "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
+ )
+
+ results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
+ self._raise_if_errors(successes)
+ self._state = AsyncState.DEFAULT
+
+ if not self.shared_memory:
+ self.observations = concatenate(
+ results, self.observations, self.single_observation_space
+ )
+
+ return deepcopy(self.observations) if self.copy else self.observations
+
+ def step_async(self, actions):
+ """
+ Parameters
+ ----------
+ actions : iterable of samples from `action_space`
+ List of actions.
+ """
+ self._assert_is_running()
+ if self._state != AsyncState.DEFAULT:
+ raise AlreadyPendingCallError(
+ "Calling `step_async` while waiting "
+ "for a pending call to `{0}` to complete.".format(self._state.value),
+ self._state.value,
+ )
+
+ for pipe, action in zip(self.parent_pipes, actions):
+ pipe.send(("step", action))
+ self._state = AsyncState.WAITING_STEP
+
+ def step_wait(self, timeout=None):
+ """
+ Parameters
+ ----------
+ timeout : int or float, optional
+ Number of seconds before the call to `step_wait` times out. If
+ `None`, the call to `step_wait` never times out.
+ Returns
+ -------
+ observations : sample from `observation_space`
+ A batch of observations from the vectorized environment.
+ rewards : `np.ndarray` instance (dtype `np.float_`)
+ A vector of rewards from the vectorized environment.
+ dones : `np.ndarray` instance (dtype `np.bool_`)
+ A vector whose entries indicate whether the episode has ended.
+ infos : list of dict
+ A list of auxiliary diagnostic information.
+ """
+ self._assert_is_running()
+ if self._state != AsyncState.WAITING_STEP:
+ raise NoAsyncCallError(
+ "Calling `step_wait` without any prior call " "to `step_async`.",
+ AsyncState.WAITING_STEP.value,
+ )
+
+ if not self._poll(timeout):
+ self._state = AsyncState.DEFAULT
+ raise mp.TimeoutError(
+ "The call to `step_wait` has timed out after "
+ "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
+ )
+
+ results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
+ self._raise_if_errors(successes)
+ self._state = AsyncState.DEFAULT
+ observations_list, rewards, dones, infos = zip(*results)
+
+ if not self.shared_memory:
+ self.observations = concatenate(
+ observations_list, self.observations, self.single_observation_space
+ )
+
+ return (
+ deepcopy(self.observations) if self.copy else self.observations,
+ np.array(rewards),
+ np.array(dones, dtype=np.bool_),
+ infos,
+ )
+
+ def close_extras(self, timeout=None, terminate=False):
+ """
+ Parameters
+ ----------
+ timeout : int or float, optional
+ Number of seconds before the call to `close` times out. If `None`,
+ the call to `close` never times out. If the call to `close` times
+ out, then all processes are terminated.
+ terminate : bool (default: `False`)
+ If `True`, then the `close` operation is forced and all processes
+ are terminated.
+ """
+ timeout = 0 if terminate else timeout
+ try:
+ if self._state != AsyncState.DEFAULT:
+ logger.warn(
+ "Calling `close` while waiting for a pending "
+ "call to `{0}` to complete.".format(self._state.value)
+ )
+ function = getattr(self, "{0}_wait".format(self._state.value))
+ function(timeout)
+ except mp.TimeoutError:
+ terminate = True
+
+ if terminate:
+ for process in self.processes:
+ if process.is_alive():
+ process.terminate()
+ else:
+ for pipe in self.parent_pipes:
+ if (pipe is not None) and (not pipe.closed):
+ pipe.send(("close", None))
+ for pipe in self.parent_pipes:
+ if (pipe is not None) and (not pipe.closed):
+ pipe.recv()
+
+ for pipe in self.parent_pipes:
+ if pipe is not None:
+ pipe.close()
+ for process in self.processes:
+ process.join()
+
+ def _poll(self, timeout=None):
+ self._assert_is_running()
+ if timeout is None:
+ return True
+ end_time = time.perf_counter() + timeout
+ delta = None
+ for pipe in self.parent_pipes:
+ delta = max(end_time - time.perf_counter(), 0)
+ if pipe is None:
+ return False
+ if pipe.closed or (not pipe.poll(delta)):
+ return False
+ return True
+
+ def _check_observation_spaces(self):
+ self._assert_is_running()
+ for pipe in self.parent_pipes:
+ pipe.send(("_check_observation_space", self.single_observation_space))
+ same_spaces, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
+ self._raise_if_errors(successes)
+ if not all(same_spaces):
+ raise RuntimeError(
+ "Some environments have an observation space "
+ "different from `{0}`. In order to batch observations, the "
+ "observation spaces from all environments must be "
+ "equal.".format(self.single_observation_space)
+ )
+
+ def _assert_is_running(self):
+ if self.closed:
+ raise ClosedEnvironmentError(
+ "Trying to operate on `{0}`, after a "
+ "call to `close()`.".format(type(self).__name__)
+ )
+
+ def _raise_if_errors(self, successes):
+ if all(successes):
+ return
+
+ num_errors = self.num_envs - sum(successes)
+ assert num_errors > 0
+ for _ in range(num_errors):
+ index, exctype, value = self.error_queue.get()
+ logger.error(
+ "Received the following error from Worker-{0}: "
+ "{1}: {2}".format(index, exctype.__name__, value)
+ )
+ logger.error("Shutting down Worker-{0}.".format(index))
+ self.parent_pipes[index].close()
+ self.parent_pipes[index] = None
+
+ logger.error("Raising the last exception back to the main process.")
+ raise exctype(value)
+
+ def call_async(self, name: str, *args, **kwargs):
+ """Calls the method with name asynchronously and apply args and kwargs to the method.
+
+ Args:
+ name: Name of the method or property to call.
+ *args: Arguments to apply to the method call.
+ **kwargs: Keyword arguments to apply to the method call.
+
+ Raises:
+ ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
+ AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
+ """
+ self._assert_is_running()
+ if self._state != AsyncState.DEFAULT:
+ raise AlreadyPendingCallError(
+ "Calling `call_async` while waiting "
+ f"for a pending call to `{self._state.value}` to complete.",
+ self._state.value,
+ )
+
+ for pipe in self.parent_pipes:
+ pipe.send(("_call", (name, args, kwargs)))
+ self._state = AsyncState.WAITING_CALL
+
+ def call_wait(self, timeout = None) -> list:
+ """Calls all parent pipes and waits for the results.
+
+ Args:
+ timeout: Number of seconds before the call to `step_wait` times out.
+ If `None` (default), the call to `step_wait` never times out.
+
+ Returns:
+ List of the results of the individual calls to the method or property for each environment.
+
+ Raises:
+ NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
+ TimeoutError: The call to `call_wait` has timed out after timeout second(s).
+ """
+ self._assert_is_running()
+ if self._state != AsyncState.WAITING_CALL:
+ raise NoAsyncCallError(
+ "Calling `call_wait` without any prior call to `call_async`.",
+ AsyncState.WAITING_CALL.value,
+ )
+
+ if not self._poll(timeout):
+ self._state = AsyncState.DEFAULT
+ raise mp.TimeoutError(
+ f"The call to `call_wait` has timed out after {timeout} second(s)."
+ )
+
+ results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
+ self._raise_if_errors(successes)
+ self._state = AsyncState.DEFAULT
+
+ return results
+
+ def call(self, name: str, *args, **kwargs):
+ """Call a method, or get a property, from each parallel environment.
+
+ Args:
+ name (str): Name of the method or property to call.
+ *args: Arguments to apply to the method call.
+ **kwargs: Keyword arguments to apply to the method call.
+
+ Returns:
+ List of the results of the individual calls to the method or property for each environment.
+ """
+ self.call_async(name, *args, **kwargs)
+ return self.call_wait()
+
+
+ def call_each(self, name: str,
+ args_list: list=None,
+ kwargs_list: list=None,
+ timeout = None):
+ n_envs = len(self.parent_pipes)
+ if args_list is None:
+ args_list = [[]] * n_envs
+ assert len(args_list) == n_envs
+
+ if kwargs_list is None:
+ kwargs_list = [dict()] * n_envs
+ assert len(kwargs_list) == n_envs
+
+ # send
+ self._assert_is_running()
+ if self._state != AsyncState.DEFAULT:
+ raise AlreadyPendingCallError(
+ "Calling `call_async` while waiting "
+ f"for a pending call to `{self._state.value}` to complete.",
+ self._state.value,
+ )
+
+ for i, pipe in enumerate(self.parent_pipes):
+ pipe.send(("_call", (name, args_list[i], kwargs_list[i])))
+ self._state = AsyncState.WAITING_CALL
+
+ # receive
+ self._assert_is_running()
+ if self._state != AsyncState.WAITING_CALL:
+ raise NoAsyncCallError(
+ "Calling `call_wait` without any prior call to `call_async`.",
+ AsyncState.WAITING_CALL.value,
+ )
+
+ if not self._poll(timeout):
+ self._state = AsyncState.DEFAULT
+ raise mp.TimeoutError(
+ f"The call to `call_wait` has timed out after {timeout} second(s)."
+ )
+
+ results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
+ self._raise_if_errors(successes)
+ self._state = AsyncState.DEFAULT
+
+ return results
+
+
+ def set_attr(self, name: str, values):
+ """Sets an attribute of the sub-environments.
+
+ Args:
+ name: Name of the property to be set in each individual environment.
+ values: Values of the property to be set to. If ``values`` is a list or
+ tuple, then it corresponds to the values for each individual
+ environment, otherwise a single value is set for all environments.
+
+ Raises:
+ ValueError: Values must be a list or tuple with length equal to the number of environments.
+ AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
+ """
+ self._assert_is_running()
+ if not isinstance(values, (list, tuple)):
+ values = [values for _ in range(self.num_envs)]
+ if len(values) != self.num_envs:
+ raise ValueError(
+ "Values must be a list or tuple with length equal to the "
+ f"number of environments. Got `{len(values)}` values for "
+ f"{self.num_envs} environments."
+ )
+
+ if self._state != AsyncState.DEFAULT:
+ raise AlreadyPendingCallError(
+ "Calling `set_attr` while waiting "
+ f"for a pending call to `{self._state.value}` to complete.",
+ self._state.value,
+ )
+
+ for pipe, value in zip(self.parent_pipes, values):
+ pipe.send(("_setattr", (name, value)))
+ _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
+ self._raise_if_errors(successes)
+
+ def render(self, *args, **kwargs):
+ return self.call('render', *args, **kwargs)
+
+
+
+def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
+ assert shared_memory is None
+ env = env_fn()
+ parent_pipe.close()
+ try:
+ while True:
+ command, data = pipe.recv()
+ if command == "reset":
+ observation = env.reset()
+ pipe.send((observation, True))
+ elif command == "step":
+ observation, reward, done, info = env.step(data)
+ # if done:
+ # observation = env.reset()
+ pipe.send(((observation, reward, done, info), True))
+ elif command == "seed":
+ env.seed(data)
+ pipe.send((None, True))
+ elif command == "close":
+ pipe.send((None, True))
+ break
+ elif command == "_call":
+ name, args, kwargs = data
+ if name in ["reset", "step", "seed", "close"]:
+ raise ValueError(
+ f"Trying to call function `{name}` with "
+ f"`_call`. Use `{name}` directly instead."
+ )
+ function = getattr(env, name)
+ if callable(function):
+ pipe.send((function(*args, **kwargs), True))
+ else:
+ pipe.send((function, True))
+ elif command == "_setattr":
+ name, value = data
+ setattr(env, name, value)
+ pipe.send((None, True))
+
+ elif command == "_check_observation_space":
+ pipe.send((data == env.observation_space, True))
+ else:
+ raise RuntimeError(
+ "Received unknown command `{0}`. Must "
+ "be one of {`reset`, `step`, `seed`, `close`, "
+ "`_check_observation_space`}.".format(command)
+ )
+ except (KeyboardInterrupt, Exception):
+ error_queue.put((index,) + sys.exc_info()[:2])
+ pipe.send((None, False))
+ finally:
+ env.close()
+
+
+def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
+ assert shared_memory is not None
+ env = env_fn()
+ observation_space = env.observation_space
+ parent_pipe.close()
+ try:
+ while True:
+ command, data = pipe.recv()
+ if command == "reset":
+ observation = env.reset()
+ write_to_shared_memory(
+ index, observation, shared_memory, observation_space
+ )
+ pipe.send((None, True))
+ elif command == "step":
+ observation, reward, done, info = env.step(data)
+ # if done:
+ # observation = env.reset()
+ write_to_shared_memory(
+ index, observation, shared_memory, observation_space
+ )
+ pipe.send(((None, reward, done, info), True))
+ elif command == "seed":
+ env.seed(data)
+ pipe.send((None, True))
+ elif command == "close":
+ pipe.send((None, True))
+ break
+ elif command == "_call":
+ name, args, kwargs = data
+ if name in ["reset", "step", "seed", "close"]:
+ raise ValueError(
+ f"Trying to call function `{name}` with "
+ f"`_call`. Use `{name}` directly instead."
+ )
+ function = getattr(env, name)
+ if callable(function):
+ pipe.send((function(*args, **kwargs), True))
+ else:
+ pipe.send((function, True))
+ elif command == "_setattr":
+ name, value = data
+ setattr(env, name, value)
+ pipe.send((None, True))
+ elif command == "_check_observation_space":
+ pipe.send((data == observation_space, True))
+ else:
+ raise RuntimeError(
+ "Received unknown command `{0}`. Must "
+ "be one of {`reset`, `step`, `seed`, `close`, "
+ "`_check_observation_space`}.".format(command)
+ )
+ except (KeyboardInterrupt, Exception):
+ error_queue.put((index,) + sys.exc_info()[:2])
+ pipe.send((None, False))
+ finally:
+ env.close()
\ No newline at end of file
diff --git a/third_party/diffusion_policy/diffusion_policy/gym_util/multistep_wrapper.py b/third_party/diffusion_policy/diffusion_policy/gym_util/multistep_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf2cfdac13fe478485fac90acb4ca1b0e8b8c524
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/gym_util/multistep_wrapper.py
@@ -0,0 +1,162 @@
+import gym
+from gym import spaces
+import numpy as np
+from collections import defaultdict, deque
+import dill
+
+def stack_repeated(x, n):
+ return np.repeat(np.expand_dims(x,axis=0),n,axis=0)
+
+def repeated_box(box_space, n):
+ return spaces.Box(
+ low=stack_repeated(box_space.low, n),
+ high=stack_repeated(box_space.high, n),
+ shape=(n,) + box_space.shape,
+ dtype=box_space.dtype
+ )
+
+def repeated_space(space, n):
+ if isinstance(space, spaces.Box):
+ return repeated_box(space, n)
+ elif isinstance(space, spaces.Dict):
+ result_space = spaces.Dict()
+ for key, value in space.items():
+ result_space[key] = repeated_space(value, n)
+ return result_space
+ else:
+ raise RuntimeError(f'Unsupported space type {type(space)}')
+
+def take_last_n(x, n):
+ x = list(x)
+ n = min(len(x), n)
+ return np.array(x[-n:])
+
+def dict_take_last_n(x, n):
+ result = dict()
+ for key, value in x.items():
+ result[key] = take_last_n(value, n)
+ return result
+
+def aggregate(data, method='max'):
+ if method == 'max':
+ # equivalent to any
+ return np.max(data)
+ elif method == 'min':
+ # equivalent to all
+ return np.min(data)
+ elif method == 'mean':
+ return np.mean(data)
+ elif method == 'sum':
+ return np.sum(data)
+ else:
+ raise NotImplementedError()
+
+def stack_last_n_obs(all_obs, n_steps):
+ assert(len(all_obs) > 0)
+ all_obs = list(all_obs)
+ result = np.zeros((n_steps,) + all_obs[-1].shape,
+ dtype=all_obs[-1].dtype)
+ start_idx = -min(n_steps, len(all_obs))
+ result[start_idx:] = np.array(all_obs[start_idx:])
+ if n_steps > len(all_obs):
+ # pad
+ result[:start_idx] = result[start_idx]
+ return result
+
+
+class MultiStepWrapper(gym.Wrapper):
+ def __init__(self,
+ env,
+ n_obs_steps,
+ n_action_steps,
+ max_episode_steps=None,
+ reward_agg_method='max'
+ ):
+ super().__init__(env)
+ self._action_space = repeated_space(env.action_space, n_action_steps)
+ self._observation_space = repeated_space(env.observation_space, n_obs_steps)
+ self.max_episode_steps = max_episode_steps
+ self.n_obs_steps = n_obs_steps
+ self.n_action_steps = n_action_steps
+ self.reward_agg_method = reward_agg_method
+ self.n_obs_steps = n_obs_steps
+
+ self.obs = deque(maxlen=n_obs_steps+1)
+ self.reward = list()
+ self.done = list()
+ self.info = defaultdict(lambda : deque(maxlen=n_obs_steps+1))
+
+ def reset(self):
+ """Resets the environment using kwargs."""
+ obs = super().reset()
+
+ self.obs = deque([obs], maxlen=self.n_obs_steps+1)
+ self.reward = list()
+ self.done = list()
+ self.info = defaultdict(lambda : deque(maxlen=self.n_obs_steps+1))
+
+ obs = self._get_obs(self.n_obs_steps)
+ return obs
+
+ def step(self, action):
+ """
+ actions: (n_action_steps,) + action_shape
+ """
+ for act in action:
+ if len(self.done) > 0 and self.done[-1]:
+ # termination
+ break
+ observation, reward, done, info = super().step(act)
+
+ self.obs.append(observation)
+ self.reward.append(reward)
+ if (self.max_episode_steps is not None) \
+ and (len(self.reward) >= self.max_episode_steps):
+ # truncation
+ done = True
+ self.done.append(done)
+ self._add_info(info)
+
+ observation = self._get_obs(self.n_obs_steps)
+ reward = aggregate(self.reward, self.reward_agg_method)
+ done = aggregate(self.done, 'max')
+ info = dict_take_last_n(self.info, self.n_obs_steps)
+ return observation, reward, done, info
+
+ def _get_obs(self, n_steps=1):
+ """
+ Output (n_steps,) + obs_shape
+ """
+ assert(len(self.obs) > 0)
+ if isinstance(self.observation_space, spaces.Box):
+ return stack_last_n_obs(self.obs, n_steps)
+ elif isinstance(self.observation_space, spaces.Dict):
+ result = dict()
+ for key in self.observation_space.keys():
+ result[key] = stack_last_n_obs(
+ [obs[key] for obs in self.obs],
+ n_steps
+ )
+ return result
+ else:
+ raise RuntimeError('Unsupported space type')
+
+ def _add_info(self, info):
+ for key, value in info.items():
+ self.info[key].append(value)
+
+ def get_rewards(self):
+ return self.reward
+
+ def get_attr(self, name):
+ return getattr(self, name)
+
+ def run_dill_function(self, dill_fn):
+ fn = dill.loads(dill_fn)
+ return fn(self)
+
+ def get_infos(self):
+ result = dict()
+ for k, v in self.info.items():
+ result[k] = list(v)
+ return result
diff --git a/third_party/diffusion_policy/diffusion_policy/gym_util/sync_vector_env.py b/third_party/diffusion_policy/diffusion_policy/gym_util/sync_vector_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..c85a68223ec1a1f7e3cbb406d3472795374704f8
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/gym_util/sync_vector_env.py
@@ -0,0 +1,182 @@
+import numpy as np
+from copy import deepcopy
+
+from gym import logger
+from gym.vector.vector_env import VectorEnv
+from gym.vector.utils import concatenate, create_empty_array
+
+__all__ = ["SyncVectorEnv"]
+
+
+class SyncVectorEnv(VectorEnv):
+ """Vectorized environment that serially runs multiple environments.
+ Parameters
+ ----------
+ env_fns : iterable of callable
+ Functions that create the environments.
+ observation_space : `gym.spaces.Space` instance, optional
+ Observation space of a single environment. If `None`, then the
+ observation space of the first environment is taken.
+ action_space : `gym.spaces.Space` instance, optional
+ Action space of a single environment. If `None`, then the action space
+ of the first environment is taken.
+ copy : bool (default: `True`)
+ If `True`, then the `reset` and `step` methods return a copy of the
+ observations.
+ """
+
+ def __init__(self, env_fns, observation_space=None, action_space=None, copy=True):
+ self.env_fns = env_fns
+ self.envs = [env_fn() for env_fn in env_fns]
+ self.copy = copy
+ self.metadata = self.envs[0].metadata
+
+ if (observation_space is None) or (action_space is None):
+ observation_space = observation_space or self.envs[0].observation_space
+ action_space = action_space or self.envs[0].action_space
+ super(SyncVectorEnv, self).__init__(
+ num_envs=len(env_fns),
+ observation_space=observation_space,
+ action_space=action_space,
+ )
+
+ self._check_observation_spaces()
+ self.observations = create_empty_array(
+ self.single_observation_space, n=self.num_envs, fn=np.zeros
+ )
+ self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
+ self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
+ # self._rewards = [0] * self.num_envs
+ # self._dones = [False] * self.num_envs
+ self._actions = None
+
+ def seed(self, seeds=None):
+ if seeds is None:
+ seeds = [None for _ in range(self.num_envs)]
+ if isinstance(seeds, int):
+ seeds = [seeds + i for i in range(self.num_envs)]
+ assert len(seeds) == self.num_envs
+
+ for env, seed in zip(self.envs, seeds):
+ env.seed(seed)
+
+ def reset_wait(self):
+ self._dones[:] = False
+ observations = []
+ for env in self.envs:
+ observation = env.reset()
+ observations.append(observation)
+ self.observations = concatenate(
+ observations, self.observations, self.single_observation_space
+ )
+
+ return deepcopy(self.observations) if self.copy else self.observations
+
+ def step_async(self, actions):
+ self._actions = actions
+
+ def step_wait(self):
+ observations, infos = [], []
+ for i, (env, action) in enumerate(zip(self.envs, self._actions)):
+ observation, self._rewards[i], self._dones[i], info = env.step(action)
+ # if self._dones[i]:
+ # observation = env.reset()
+ observations.append(observation)
+ infos.append(info)
+ self.observations = concatenate(
+ observations, self.observations, self.single_observation_space
+ )
+
+ return (
+ deepcopy(self.observations) if self.copy else self.observations,
+ np.copy(self._rewards),
+ np.copy(self._dones),
+ infos,
+ )
+
+ def close_extras(self, **kwargs):
+ [env.close() for env in self.envs]
+
+ def _check_observation_spaces(self):
+ for env in self.envs:
+ if not (env.observation_space == self.single_observation_space):
+ break
+ else:
+ return True
+ raise RuntimeError(
+ "Some environments have an observation space "
+ "different from `{0}`. In order to batch observations, the "
+ "observation spaces from all environments must be "
+ "equal.".format(self.single_observation_space)
+ )
+
+ def call(self, name, *args, **kwargs) -> tuple:
+ """Calls the method with name and applies args and kwargs.
+
+ Args:
+ name: The method name
+ *args: The method args
+ **kwargs: The method kwargs
+
+ Returns:
+ Tuple of results
+ """
+ results = []
+ for env in self.envs:
+ function = getattr(env, name)
+ if callable(function):
+ results.append(function(*args, **kwargs))
+ else:
+ results.append(function)
+
+ return tuple(results)
+
+ def call_each(self, name: str,
+ args_list: list=None,
+ kwargs_list: list=None):
+ n_envs = len(self.envs)
+ if args_list is None:
+ args_list = [[]] * n_envs
+ assert len(args_list) == n_envs
+
+ if kwargs_list is None:
+ kwargs_list = [dict()] * n_envs
+ assert len(kwargs_list) == n_envs
+
+ results = []
+ for i, env in enumerate(self.envs):
+ function = getattr(env, name)
+ if callable(function):
+ results.append(function(*args_list[i], **kwargs_list[i]))
+ else:
+ results.append(function)
+
+ return tuple(results)
+
+
+ def render(self, *args, **kwargs):
+ return self.call('render', *args, **kwargs)
+
+ def set_attr(self, name: str, values):
+ """Sets an attribute of the sub-environments.
+
+ Args:
+ name: The property name to change
+ values: Values of the property to be set to. If ``values`` is a list or
+ tuple, then it corresponds to the values for each individual
+ environment, otherwise, a single value is set for all environments.
+
+ Raises:
+ ValueError: Values must be a list or tuple with length equal to the number of environments.
+ """
+ if not isinstance(values, (list, tuple)):
+ values = [values for _ in range(self.num_envs)]
+ if len(values) != self.num_envs:
+ raise ValueError(
+ "Values must be a list or tuple with length equal to the "
+ f"number of environments. Got `{len(values)}` values for "
+ f"{self.num_envs} environments."
+ )
+
+ for env, value in zip(self.envs, values):
+ setattr(env, name, value)
\ No newline at end of file
diff --git a/third_party/diffusion_policy/diffusion_policy/gym_util/video_recording_wrapper.py b/third_party/diffusion_policy/diffusion_policy/gym_util/video_recording_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..c690f50dca46b151d5931eb76a1b1db58986d3c8
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/gym_util/video_recording_wrapper.py
@@ -0,0 +1,51 @@
+import gym
+import numpy as np
+from diffusion_policy.real_world.video_recorder import VideoRecorder
+
+class VideoRecordingWrapper(gym.Wrapper):
+ def __init__(self,
+ env,
+ video_recoder: VideoRecorder,
+ mode='rgb_array',
+ file_path=None,
+ steps_per_render=1,
+ **kwargs
+ ):
+ """
+ When file_path is None, don't record.
+ """
+ super().__init__(env)
+
+ self.mode = mode
+ self.render_kwargs = kwargs
+ self.steps_per_render = steps_per_render
+ self.file_path = file_path
+ self.video_recoder = video_recoder
+
+ self.step_count = 0
+
+ def reset(self, **kwargs):
+ obs = super().reset(**kwargs)
+ self.frames = list()
+ self.step_count = 1
+ self.video_recoder.stop()
+ return obs
+
+ def step(self, action):
+ result = super().step(action)
+ self.step_count += 1
+ if self.file_path is not None \
+ and ((self.step_count % self.steps_per_render) == 0):
+ if not self.video_recoder.is_ready():
+ self.video_recoder.start(self.file_path)
+
+ frame = self.env.render(
+ mode=self.mode, **self.render_kwargs)
+ assert frame.dtype == np.uint8
+ self.video_recoder.write_frame(frame)
+ return result
+
+ def render(self, mode='rgb_array', **kwargs):
+ if self.video_recoder.is_ready():
+ self.video_recoder.stop()
+ return self.file_path
diff --git a/third_party/diffusion_policy/diffusion_policy/gym_util/video_wrapper.py b/third_party/diffusion_policy/diffusion_policy/gym_util/video_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..abfebbefea52ac7e7eef9e21707c8e091b741940
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/gym_util/video_wrapper.py
@@ -0,0 +1,44 @@
+import gym
+import numpy as np
+
+class VideoWrapper(gym.Wrapper):
+ def __init__(self,
+ env,
+ mode='rgb_array',
+ enabled=True,
+ steps_per_render=1,
+ **kwargs
+ ):
+ super().__init__(env)
+
+ self.mode = mode
+ self.enabled = enabled
+ self.render_kwargs = kwargs
+ self.steps_per_render = steps_per_render
+
+ self.frames = list()
+ self.step_count = 0
+
+ def reset(self, **kwargs):
+ obs = super().reset(**kwargs)
+ self.frames = list()
+ self.step_count = 1
+ if self.enabled:
+ frame = self.env.render(
+ mode=self.mode, **self.render_kwargs)
+ assert frame.dtype == np.uint8
+ self.frames.append(frame)
+ return obs
+
+ def step(self, action):
+ result = super().step(action)
+ self.step_count += 1
+ if self.enabled and ((self.step_count % self.steps_per_render) == 0):
+ frame = self.env.render(
+ mode=self.mode, **self.render_kwargs)
+ assert frame.dtype == np.uint8
+ self.frames.append(frame)
+ return result
+
+ def render(self, mode='rgb_array', **kwargs):
+ return self.frames
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/action_ae/__init__.py b/third_party/diffusion_policy/diffusion_policy/model/bet/action_ae/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a7b88d8e6f8363f5a6d10c2bb5f4d00016d46c0
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/action_ae/__init__.py
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+import abc
+
+from typing import Optional, Union
+
+import diffusion_policy.model.bet.utils as utils
+
+
+class AbstractActionAE(utils.SaveModule, abc.ABC):
+ @abc.abstractmethod
+ def fit_model(
+ self,
+ input_dataloader: DataLoader,
+ eval_dataloader: DataLoader,
+ obs_encoding_net: Optional[nn.Module] = None,
+ ) -> None:
+ pass
+
+ @abc.abstractmethod
+ def encode_into_latent(
+ self,
+ input_action: torch.Tensor,
+ input_rep: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ """
+ Given the input action, discretize it.
+
+ Inputs:
+ input_action (shape: ... x action_dim): The input action to discretize. This can be in a batch,
+ and is generally assumed that the last dimnesion is the action dimension.
+
+ Outputs:
+ discretized_action (shape: ... x num_tokens): The discretized action.
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def decode_actions(
+ self,
+ latent_action_batch: Optional[torch.Tensor],
+ input_rep_batch: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Given a discretized action, convert it to a continuous action.
+
+ Inputs:
+ latent_action_batch (shape: ... x num_tokens): The discretized action
+ generated by the discretizer.
+
+ Outputs:
+ continuous_action (shape: ... x action_dim): The continuous action.
+ """
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def num_latents(self) -> Union[int, float]:
+ """
+ Number of possible latents for this generator, useful for state priors that use softmax.
+ """
+ return float("inf")
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/action_ae/discretizers/k_means.py b/third_party/diffusion_policy/diffusion_policy/model/bet/action_ae/discretizers/k_means.py
new file mode 100644
index 0000000000000000000000000000000000000000..9051df083413701ff7f57642761b5b1eab2ed453
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/action_ae/discretizers/k_means.py
@@ -0,0 +1,148 @@
+import torch
+import numpy as np
+
+import tqdm
+
+from typing import Optional, Tuple, Union
+from diffusion_policy.model.common.dict_of_tensor_mixin import DictOfTensorMixin
+
+
+class KMeansDiscretizer(DictOfTensorMixin):
+ """
+ Simplified and modified version of KMeans algorithm from sklearn.
+ """
+
+ def __init__(
+ self,
+ action_dim: int,
+ num_bins: int = 100,
+ predict_offsets: bool = False,
+ ):
+ super().__init__()
+ self.n_bins = num_bins
+ self.action_dim = action_dim
+ self.predict_offsets = predict_offsets
+
+ def fit_discretizer(self, input_actions: torch.Tensor) -> None:
+ assert (
+ self.action_dim == input_actions.shape[-1]
+ ), f"Input action dimension {self.action_dim} does not match fitted model {input_actions.shape[-1]}"
+
+ flattened_actions = input_actions.view(-1, self.action_dim)
+ cluster_centers = KMeansDiscretizer._kmeans(
+ flattened_actions, ncluster=self.n_bins
+ )
+ self.params_dict['bin_centers'] = cluster_centers
+
+ @property
+ def suggested_actions(self) -> torch.Tensor:
+ return self.params_dict['bin_centers']
+
+ @classmethod
+ def _kmeans(cls, x: torch.Tensor, ncluster: int = 512, niter: int = 50):
+ """
+ Simple k-means clustering algorithm adapted from Karpathy's minGPT library
+ https://github.com/karpathy/minGPT/blob/master/play_image.ipynb
+ """
+ N, D = x.size()
+ c = x[torch.randperm(N)[:ncluster]] # init clusters at random
+
+ pbar = tqdm.trange(niter)
+ pbar.set_description("K-means clustering")
+ for i in pbar:
+ # assign all pixels to the closest codebook element
+ a = ((x[:, None, :] - c[None, :, :]) ** 2).sum(-1).argmin(1)
+ # move each codebook element to be the mean of the pixels that assigned to it
+ c = torch.stack([x[a == k].mean(0) for k in range(ncluster)])
+ # re-assign any poorly positioned codebook elements
+ nanix = torch.any(torch.isnan(c), dim=1)
+ ndead = nanix.sum().item()
+ if ndead:
+ tqdm.tqdm.write(
+ "done step %d/%d, re-initialized %d dead clusters"
+ % (i + 1, niter, ndead)
+ )
+ c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
+ return c
+
+ def encode_into_latent(
+ self, input_action: torch.Tensor, input_rep: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Given the input action, discretize it using the k-Means clustering algorithm.
+
+ Inputs:
+ input_action (shape: ... x action_dim): The input action to discretize. This can be in a batch,
+ and is generally assumed that the last dimnesion is the action dimension.
+
+ Outputs:
+ discretized_action (shape: ... x num_tokens): The discretized action.
+ If self.predict_offsets is True, then the offsets are also returned.
+ """
+ assert (
+ input_action.shape[-1] == self.action_dim
+ ), "Input action dimension does not match fitted model"
+
+ # flatten the input action
+ flattened_actions = input_action.view(-1, self.action_dim)
+
+ # get the closest cluster center
+ closest_cluster_center = torch.argmin(
+ torch.sum(
+ (flattened_actions[:, None, :] - self.params_dict['bin_centers'][None, :, :]) ** 2,
+ dim=2,
+ ),
+ dim=1,
+ )
+ # Reshape to the original shape
+ discretized_action = closest_cluster_center.view(input_action.shape[:-1] + (1,))
+
+ if self.predict_offsets:
+ # decode from latent and get the difference
+ reconstructed_action = self.decode_actions(discretized_action)
+ offsets = input_action - reconstructed_action
+ return (discretized_action, offsets)
+ else:
+ # return the one-hot vector
+ return discretized_action
+
+ def decode_actions(
+ self,
+ latent_action_batch: torch.Tensor,
+ input_rep_batch: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Given the latent action, reconstruct the original action.
+
+ Inputs:
+ latent_action (shape: ... x 1): The latent action to reconstruct. This can be in a batch,
+ and is generally assumed that the last dimension is the action dimension. If the latent_action_batch
+ is a tuple, then it is assumed to be (discretized_action, offsets).
+
+ Outputs:
+ reconstructed_action (shape: ... x action_dim): The reconstructed action.
+ """
+ offsets = None
+ if type(latent_action_batch) == tuple:
+ latent_action_batch, offsets = latent_action_batch
+ # get the closest cluster center
+ closest_cluster_center = self.params_dict['bin_centers'][latent_action_batch]
+ # Reshape to the original shape
+ reconstructed_action = closest_cluster_center.view(
+ latent_action_batch.shape[:-1] + (self.action_dim,)
+ )
+ if offsets is not None:
+ reconstructed_action += offsets
+ return reconstructed_action
+
+ @property
+ def discretized_space(self) -> int:
+ return self.n_bins
+
+ @property
+ def latent_dim(self) -> int:
+ return 1
+
+ @property
+ def num_latents(self) -> int:
+ return self.n_bins
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/latent_generators/latent_generator.py b/third_party/diffusion_policy/diffusion_policy/model/bet/latent_generators/latent_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..072bdc86abe3e303ea6d8563071a41d86142b957
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/latent_generators/latent_generator.py
@@ -0,0 +1,71 @@
+import abc
+import torch
+from typing import Tuple, Optional
+
+import diffusion_policy.model.bet.utils as utils
+
+
+class AbstractLatentGenerator(abc.ABC, utils.SaveModule):
+ """
+ Abstract class for a generative model that can generate latents given observation representations.
+
+ In the probabilisitc sense, this model fits and samples from P(latent|observation) given some observation.
+ """
+
+ @abc.abstractmethod
+ def get_latent_and_loss(
+ self,
+ obs_rep: torch.Tensor,
+ target_latents: torch.Tensor,
+ seq_masks: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Given a set of observation representation and generated latents, get the encoded latent and the loss.
+
+ Inputs:
+ input_action: Batch of the actions taken in the multimodal demonstrations.
+ target_latents: Batch of the latents that the generator should learn to generate the actions from.
+ seq_masks: Batch of masks that indicate which timesteps are valid.
+
+ Outputs:
+ latent: The sampled latent from the observation.
+ loss: The loss of the latent generator.
+ """
+ pass
+
+ @abc.abstractmethod
+ def generate_latents(
+ self, seq_obses: torch.Tensor, seq_masks: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Given a batch of sequences of observations, generate a batch of sequences of latents.
+
+ Inputs:
+ seq_obses: Batch of sequences of observations, of shape seq x batch x dim, following the transformer convention.
+ seq_masks: Batch of sequences of masks, of shape seq x batch, following the transformer convention.
+
+ Outputs:
+ seq_latents: Batch of sequences of latents of shape seq x batch x latent_dim.
+ """
+ pass
+
+ def get_optimizer(
+ self, weight_decay: float, learning_rate: float, betas: Tuple[float, float]
+ ) -> torch.optim.Optimizer:
+ """
+ Default optimizer class. Override this if you want to use a different optimizer.
+ """
+ return torch.optim.Adam(
+ self.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=betas
+ )
+
+
+class LatentGeneratorDataParallel(torch.nn.DataParallel):
+ def get_latent_and_loss(self, *args, **kwargs):
+ return self.module.get_latent_and_loss(*args, **kwargs) # type: ignore
+
+ def generate_latents(self, *args, **kwargs):
+ return self.module.generate_latents(*args, **kwargs) # type: ignore
+
+ def get_optimizer(self, *args, **kwargs):
+ return self.module.get_optimizer(*args, **kwargs) # type: ignore
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/latent_generators/mingpt.py b/third_party/diffusion_policy/diffusion_policy/model/bet/latent_generators/mingpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..4836b2791c070054dfaa68f0a76d27723dd2488f
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/latent_generators/mingpt.py
@@ -0,0 +1,188 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops
+import diffusion_policy.model.bet.latent_generators.latent_generator as latent_generator
+
+import diffusion_policy.model.bet.libraries.mingpt.model as mingpt_model
+import diffusion_policy.model.bet.libraries.mingpt.trainer as mingpt_trainer
+from diffusion_policy.model.bet.libraries.loss_fn import FocalLoss, soft_cross_entropy
+
+from typing import Optional, Tuple
+
+
+class MinGPT(latent_generator.AbstractLatentGenerator):
+ def __init__(
+ self,
+ input_dim: int,
+ n_layer: int = 12,
+ n_head: int = 12,
+ n_embd: int = 768,
+ embd_pdrop: float = 0.1,
+ resid_pdrop: float = 0.1,
+ attn_pdrop: float = 0.1,
+ block_size: int = 128,
+ vocab_size: int = 50257,
+ latent_dim: int = 768, # Ignore, used for compatibility with other models.
+ action_dim: int = 0,
+ discrete_input: bool = False,
+ predict_offsets: bool = False,
+ offset_loss_scale: float = 1.0,
+ focal_loss_gamma: float = 0.0,
+ **kwargs
+ ):
+ super().__init__()
+ self.input_size = input_dim
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_embd = n_embd
+ self.embd_pdrop = embd_pdrop
+ self.resid_pdrop = resid_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.block_size = block_size
+ self.vocab_size = vocab_size
+ self.action_dim = action_dim
+ self.predict_offsets = predict_offsets
+ self.offset_loss_scale = offset_loss_scale
+ self.focal_loss_gamma = focal_loss_gamma
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ gpt_config = mingpt_model.GPTConfig(
+ input_size=self.input_size,
+ vocab_size=self.vocab_size * (1 + self.action_dim)
+ if self.predict_offsets
+ else self.vocab_size,
+ block_size=self.block_size,
+ n_layer=n_layer,
+ n_head=n_head,
+ n_embd=n_embd,
+ discrete_input=discrete_input,
+ embd_pdrop=embd_pdrop,
+ resid_pdrop=resid_pdrop,
+ attn_pdrop=attn_pdrop,
+ )
+
+ self.model = mingpt_model.GPT(gpt_config)
+
+ def get_latent_and_loss(
+ self,
+ obs_rep: torch.Tensor,
+ target_latents: torch.Tensor,
+ seq_masks: Optional[torch.Tensor] = None,
+ return_loss_components: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Unlike torch.transformers, GPT takes in batch x seq_len x embd_dim
+ # obs_rep = einops.rearrange(obs_rep, "seq batch embed -> batch seq embed")
+ # target_latents = einops.rearrange(
+ # target_latents, "seq batch embed -> batch seq embed"
+ # )
+ # While this has been trained autoregressively,
+ # there is no reason why it needs to be so.
+ # We can just use the observation as the input and the next latent as the target.
+ if self.predict_offsets:
+ target_latents, target_offsets = target_latents
+ is_soft_target = (target_latents.shape[-1] == self.vocab_size) and (
+ self.vocab_size != 1
+ )
+ if is_soft_target:
+ target_latents = target_latents.view(-1, target_latents.size(-1))
+ criterion = soft_cross_entropy
+ else:
+ target_latents = target_latents.view(-1)
+ if self.vocab_size == 1:
+ # unify k-means (target_class == 0) and GMM (target_prob == 1)
+ target_latents = torch.zeros_like(target_latents)
+ criterion = FocalLoss(gamma=self.focal_loss_gamma)
+ if self.predict_offsets:
+ output, _ = self.model(obs_rep)
+ logits = output[:, :, : self.vocab_size]
+ offsets = output[:, :, self.vocab_size :]
+ batch = logits.shape[0]
+ seq = logits.shape[1]
+ offsets = einops.rearrange(
+ offsets,
+ "N T (V A) -> (N T) V A", # N = batch, T = seq
+ V=self.vocab_size,
+ A=self.action_dim,
+ )
+ # calculate (optionally soft) cross entropy and offset losses
+ class_loss = criterion(logits.view(-1, logits.size(-1)), target_latents)
+ # offset loss is only calculated on the target class
+ # if soft targets, argmax is considered the target class
+ selected_offsets = offsets[
+ torch.arange(offsets.size(0)),
+ target_latents.argmax(dim=-1).view(-1)
+ if is_soft_target
+ else target_latents.view(-1),
+ ]
+ offset_loss = self.offset_loss_scale * F.mse_loss(
+ selected_offsets, target_offsets.view(-1, self.action_dim)
+ )
+ loss = offset_loss + class_loss
+ logits = einops.rearrange(logits, "batch seq classes -> seq batch classes")
+ offsets = einops.rearrange(
+ offsets,
+ "(N T) V A -> T N V A", # ? N, T order? Anyway does not affect loss and training (might affect visualization)
+ N=batch,
+ T=seq,
+ )
+ if return_loss_components:
+ return (
+ (logits, offsets),
+ loss,
+ {"offset": offset_loss, "class": class_loss, "total": loss},
+ )
+ else:
+ return (logits, offsets), loss
+ else:
+ logits, _ = self.model(obs_rep)
+ loss = criterion(logits.view(-1, logits.size(-1)), target_latents)
+ logits = einops.rearrange(
+ logits, "batch seq classes -> seq batch classes"
+ ) # ? N, T order? Anyway does not affect loss and training (might affect visualization)
+ if return_loss_components:
+ return logits, loss, {"class": loss, "total": loss}
+ else:
+ return logits, loss
+
+ def generate_latents(
+ self, obs_rep: torch.Tensor
+ ) -> torch.Tensor:
+ batch, seq, embed = obs_rep.shape
+
+ output, _ = self.model(obs_rep, None)
+ if self.predict_offsets:
+ logits = output[:, :, : self.vocab_size]
+ offsets = output[:, :, self.vocab_size :]
+ offsets = einops.rearrange(
+ offsets,
+ "N T (V A) -> (N T) V A", # N = batch, T = seq
+ V=self.vocab_size,
+ A=self.action_dim,
+ )
+ else:
+ logits = output
+ probs = F.softmax(logits, dim=-1)
+ batch, seq, choices = probs.shape
+ # Sample from the multinomial distribution, one per row.
+ sampled_data = torch.multinomial(probs.view(-1, choices), num_samples=1)
+ sampled_data = einops.rearrange(
+ sampled_data, "(batch seq) 1 -> batch seq 1", batch=batch, seq=seq
+ )
+ if self.predict_offsets:
+ sampled_offsets = offsets[
+ torch.arange(offsets.shape[0]), sampled_data.flatten()
+ ].view(batch, seq, self.action_dim)
+
+ return (sampled_data, sampled_offsets)
+ else:
+ return sampled_data
+
+ def get_optimizer(
+ self, weight_decay: float, learning_rate: float, betas: Tuple[float, float]
+ ) -> torch.optim.Optimizer:
+ trainer_cfg = mingpt_trainer.TrainerConfig(
+ weight_decay=weight_decay, learning_rate=learning_rate, betas=betas
+ )
+ return self.model.configure_optimizers(trainer_cfg)
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/latent_generators/transformer.py b/third_party/diffusion_policy/diffusion_policy/model/bet/latent_generators/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..76af7ccedfe7db8fcc9f2802b41de780fd088d2c
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/latent_generators/transformer.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops
+import diffusion_policy.model.bet.latent_generators.latent_generator as latent_generator
+
+from diffusion_policy.model.diffusion.transformer_for_diffusion import TransformerForDiffusion
+from diffusion_policy.model.bet.libraries.loss_fn import FocalLoss, soft_cross_entropy
+
+from typing import Optional, Tuple
+
+class Transformer(latent_generator.AbstractLatentGenerator):
+ def __init__(
+ self,
+ input_dim: int,
+ num_bins: int,
+ action_dim: int,
+ horizon: int,
+ focal_loss_gamma: float,
+ offset_loss_scale: float,
+ **kwargs
+ ):
+ super().__init__()
+ self.model = TransformerForDiffusion(
+ input_dim=input_dim,
+ output_dim=num_bins * (1 + action_dim),
+ horizon=horizon,
+ **kwargs
+ )
+ self.vocab_size = num_bins
+ self.focal_loss_gamma = focal_loss_gamma
+ self.offset_loss_scale = offset_loss_scale
+ self.action_dim = action_dim
+
+ def get_optimizer(self, **kwargs) -> torch.optim.Optimizer:
+ return self.model.configure_optimizers(**kwargs)
+
+ def get_latent_and_loss(self,
+ obs_rep: torch.Tensor,
+ target_latents: torch.Tensor,
+ return_loss_components=True,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ target_latents, target_offsets = target_latents
+ target_latents = target_latents.view(-1)
+ criterion = FocalLoss(gamma=self.focal_loss_gamma)
+
+ t = torch.tensor(0, device=self.model.device)
+ output = self.model(obs_rep, t)
+ logits = output[:, :, : self.vocab_size]
+ offsets = output[:, :, self.vocab_size :]
+ batch = logits.shape[0]
+ seq = logits.shape[1]
+ offsets = einops.rearrange(
+ offsets,
+ "N T (V A) -> (N T) V A", # N = batch, T = seq
+ V=self.vocab_size,
+ A=self.action_dim,
+ )
+ # calculate (optionally soft) cross entropy and offset losses
+ class_loss = criterion(logits.view(-1, logits.size(-1)), target_latents)
+ # offset loss is only calculated on the target class
+ # if soft targets, argmax is considered the target class
+ selected_offsets = offsets[
+ torch.arange(offsets.size(0)),
+ target_latents.view(-1),
+ ]
+ offset_loss = self.offset_loss_scale * F.mse_loss(
+ selected_offsets, target_offsets.view(-1, self.action_dim)
+ )
+ loss = offset_loss + class_loss
+ logits = einops.rearrange(logits, "batch seq classes -> seq batch classes")
+ offsets = einops.rearrange(
+ offsets,
+ "(N T) V A -> T N V A", # ? N, T order? Anyway does not affect loss and training (might affect visualization)
+ N=batch,
+ T=seq,
+ )
+ return (
+ (logits, offsets),
+ loss,
+ {"offset": offset_loss, "class": class_loss, "total": loss},
+ )
+
+ def generate_latents(
+ self, obs_rep: torch.Tensor
+ ) -> torch.Tensor:
+ t = torch.tensor(0, device=self.model.device)
+ output = self.model(obs_rep, t)
+ logits = output[:, :, : self.vocab_size]
+ offsets = output[:, :, self.vocab_size :]
+ offsets = einops.rearrange(
+ offsets,
+ "N T (V A) -> (N T) V A", # N = batch, T = seq
+ V=self.vocab_size,
+ A=self.action_dim,
+ )
+
+ probs = F.softmax(logits, dim=-1)
+ batch, seq, choices = probs.shape
+ # Sample from the multinomial distribution, one per row.
+ sampled_data = torch.multinomial(probs.view(-1, choices), num_samples=1)
+ sampled_data = einops.rearrange(
+ sampled_data, "(batch seq) 1 -> batch seq 1", batch=batch, seq=seq
+ )
+ sampled_offsets = offsets[
+ torch.arange(offsets.shape[0]), sampled_data.flatten()
+ ].view(batch, seq, self.action_dim)
+ return (sampled_data, sampled_offsets)
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/loss_fn.py b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/loss_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7873ae1b483a711087e1483c3637853bf50cc4c6
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/loss_fn.py
@@ -0,0 +1,168 @@
+from typing import Optional, Sequence
+
+import torch
+from torch import Tensor
+from torch import nn
+from torch.nn import functional as F
+
+# Reference: https://github.com/pytorch/pytorch/issues/11959
+def soft_cross_entropy(
+ input: torch.Tensor,
+ target: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Args:
+ input: (batch_size, num_classes): tensor of raw logits
+ target: (batch_size, num_classes): tensor of class probability; sum(target) == 1
+
+ Returns:
+ loss: (batch_size,)
+ """
+ log_probs = torch.log_softmax(input, dim=-1)
+ # target is a distribution
+ loss = F.kl_div(log_probs, target, reduction="batchmean")
+ return loss
+
+
+# Focal loss implementation
+# Source: https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py
+# MIT License
+#
+# Copyright (c) 2020 Adeel Hassan
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+class FocalLoss(nn.Module):
+ """Focal Loss, as described in https://arxiv.org/abs/1708.02002.
+ It is essentially an enhancement to cross entropy loss and is
+ useful for classification tasks when there is a large class imbalance.
+ x is expected to contain raw, unnormalized scores for each class.
+ y is expected to contain class labels.
+ Shape:
+ - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
+ - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
+ """
+
+ def __init__(
+ self,
+ alpha: Optional[Tensor] = None,
+ gamma: float = 0.0,
+ reduction: str = "mean",
+ ignore_index: int = -100,
+ ):
+ """Constructor.
+ Args:
+ alpha (Tensor, optional): Weights for each class. Defaults to None.
+ gamma (float, optional): A constant, as described in the paper.
+ Defaults to 0.
+ reduction (str, optional): 'mean', 'sum' or 'none'.
+ Defaults to 'mean'.
+ ignore_index (int, optional): class label to ignore.
+ Defaults to -100.
+ """
+ if reduction not in ("mean", "sum", "none"):
+ raise ValueError('Reduction must be one of: "mean", "sum", "none".')
+
+ super().__init__()
+ self.alpha = alpha
+ self.gamma = gamma
+ self.ignore_index = ignore_index
+ self.reduction = reduction
+
+ self.nll_loss = nn.NLLLoss(
+ weight=alpha, reduction="none", ignore_index=ignore_index
+ )
+
+ def __repr__(self):
+ arg_keys = ["alpha", "gamma", "ignore_index", "reduction"]
+ arg_vals = [self.__dict__[k] for k in arg_keys]
+ arg_strs = [f"{k}={v}" for k, v in zip(arg_keys, arg_vals)]
+ arg_str = ", ".join(arg_strs)
+ return f"{type(self).__name__}({arg_str})"
+
+ def forward(self, x: Tensor, y: Tensor) -> Tensor:
+ if x.ndim > 2:
+ # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
+ c = x.shape[1]
+ x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
+ # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
+ y = y.view(-1)
+
+ unignored_mask = y != self.ignore_index
+ y = y[unignored_mask]
+ if len(y) == 0:
+ return 0.0
+ x = x[unignored_mask]
+
+ # compute weighted cross entropy term: -alpha * log(pt)
+ # (alpha is already part of self.nll_loss)
+ log_p = F.log_softmax(x, dim=-1)
+ ce = self.nll_loss(log_p, y)
+
+ # get true class column from each row
+ all_rows = torch.arange(len(x))
+ log_pt = log_p[all_rows, y]
+
+ # compute focal term: (1 - pt)^gamma
+ pt = log_pt.exp()
+ focal_term = (1 - pt) ** self.gamma
+
+ # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
+ loss = focal_term * ce
+
+ if self.reduction == "mean":
+ loss = loss.mean()
+ elif self.reduction == "sum":
+ loss = loss.sum()
+
+ return loss
+
+
+def focal_loss(
+ alpha: Optional[Sequence] = None,
+ gamma: float = 0.0,
+ reduction: str = "mean",
+ ignore_index: int = -100,
+ device="cpu",
+ dtype=torch.float32,
+) -> FocalLoss:
+ """Factory function for FocalLoss.
+ Args:
+ alpha (Sequence, optional): Weights for each class. Will be converted
+ to a Tensor if not None. Defaults to None.
+ gamma (float, optional): A constant, as described in the paper.
+ Defaults to 0.
+ reduction (str, optional): 'mean', 'sum' or 'none'.
+ Defaults to 'mean'.
+ ignore_index (int, optional): class label to ignore.
+ Defaults to -100.
+ device (str, optional): Device to move alpha to. Defaults to 'cpu'.
+ dtype (torch.dtype, optional): dtype to cast alpha to.
+ Defaults to torch.float32.
+ Returns:
+ A FocalLoss object
+ """
+ if alpha is not None:
+ if not isinstance(alpha, Tensor):
+ alpha = torch.tensor(alpha)
+ alpha = alpha.to(device=device, dtype=dtype)
+
+ fl = FocalLoss(
+ alpha=alpha, gamma=gamma, reduction=reduction, ignore_index=ignore_index
+ )
+ return fl
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/LICENSE b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..87bfb153c9a87537d1e21114b9fcaacdebe6a761
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/LICENSE
@@ -0,0 +1,8 @@
+The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/__init__.py b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/model.py b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..06308e3919b7f7d2b968551ac613ccac4e2df7ff
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/model.py
@@ -0,0 +1,250 @@
+"""
+GPT model:
+- the initial stem consists of a combination of token encoding and a positional encoding
+- the meat of it is a uniform sequence of Transformer blocks
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
+ - all blocks feed into a central residual pathway similar to resnets
+- the final decoder is a linear projection into a vanilla Softmax classifier
+"""
+
+import math
+import logging
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+logger = logging.getLogger(__name__)
+
+
+class GPTConfig:
+ """base GPT config, params common to all GPT versions"""
+
+ embd_pdrop = 0.1
+ resid_pdrop = 0.1
+ attn_pdrop = 0.1
+ discrete_input = False
+ input_size = 10
+ n_embd = 768
+ n_layer = 12
+
+ def __init__(self, vocab_size, block_size, **kwargs):
+ self.vocab_size = vocab_size
+ self.block_size = block_size
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+
+class GPT1Config(GPTConfig):
+ """GPT-1 like network roughly 125M params"""
+
+ n_layer = 12
+ n_head = 12
+ n_embd = 768
+
+
+class CausalSelfAttention(nn.Module):
+ """
+ A vanilla multi-head masked self-attention layer with a projection at the end.
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
+ explicit implementation here to show that there is nothing too scary here.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ assert config.n_embd % config.n_head == 0
+ # key, query, value projections for all heads
+ self.key = nn.Linear(config.n_embd, config.n_embd)
+ self.query = nn.Linear(config.n_embd, config.n_embd)
+ self.value = nn.Linear(config.n_embd, config.n_embd)
+ # regularization
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
+ # output projection
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
+ # causal mask to ensure that attention is only applied to the left in the input sequence
+ self.register_buffer(
+ "mask",
+ torch.tril(torch.ones(config.block_size, config.block_size)).view(
+ 1, 1, config.block_size, config.block_size
+ ),
+ )
+ self.n_head = config.n_head
+
+ def forward(self, x):
+ (
+ B,
+ T,
+ C,
+ ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ k = (
+ self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+ ) # (B, nh, T, hs)
+ q = (
+ self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+ ) # (B, nh, T, hs)
+ v = (
+ self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+ ) # (B, nh, T, hs)
+
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+ att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
+ att = F.softmax(att, dim=-1)
+ att = self.attn_drop(att)
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+ y = (
+ y.transpose(1, 2).contiguous().view(B, T, C)
+ ) # re-assemble all head outputs side by side
+
+ # output projection
+ y = self.resid_drop(self.proj(y))
+ return y
+
+
+class Block(nn.Module):
+ """an unassuming Transformer block"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(config.n_embd)
+ self.ln2 = nn.LayerNorm(config.n_embd)
+ self.attn = CausalSelfAttention(config)
+ self.mlp = nn.Sequential(
+ nn.Linear(config.n_embd, 4 * config.n_embd),
+ nn.GELU(),
+ nn.Linear(4 * config.n_embd, config.n_embd),
+ nn.Dropout(config.resid_pdrop),
+ )
+
+ def forward(self, x):
+ x = x + self.attn(self.ln1(x))
+ x = x + self.mlp(self.ln2(x))
+ return x
+
+
+class GPT(nn.Module):
+ """the full GPT language model, with a context size of block_size"""
+
+ def __init__(self, config: GPTConfig):
+ super().__init__()
+
+ # input embedding stem
+ if config.discrete_input:
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
+ else:
+ self.tok_emb = nn.Linear(config.input_size, config.n_embd)
+ self.discrete_input = config.discrete_input
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+ self.block_size = config.block_size
+ self.apply(self._init_weights)
+
+ logger.info(
+ "number of parameters: %e", sum(p.numel() for p in self.parameters())
+ )
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ torch.nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.LayerNorm):
+ torch.nn.init.zeros_(module.bias)
+ torch.nn.init.ones_(module.weight)
+ elif isinstance(module, GPT):
+ torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
+
+ def configure_optimizers(self, train_config):
+ """
+ This long function is unfortunately doing something very simple and is being very defensive:
+ We are separating out all parameters of the model into two buckets: those that will experience
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+ We are then returning the PyTorch optimizer object.
+ """
+
+ # separate out all parameters to those that will and won't experience regularizing weight decay
+ decay = set()
+ no_decay = set()
+ whitelist_weight_modules = (torch.nn.Linear,)
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+ for mn, m in self.named_modules():
+ for pn, p in m.named_parameters():
+ fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
+
+ if pn.endswith("bias"):
+ # all biases will not be decayed
+ no_decay.add(fpn)
+ elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
+ # weights of whitelist modules will be weight decayed
+ decay.add(fpn)
+ elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
+ # weights of blacklist modules will NOT be weight decayed
+ no_decay.add(fpn)
+
+ # special case the position embedding parameter in the root GPT module as not decayed
+ no_decay.add("pos_emb")
+
+ # validate that we considered every parameter
+ param_dict = {pn: p for pn, p in self.named_parameters()}
+ inter_params = decay & no_decay
+ union_params = decay | no_decay
+ assert (
+ len(inter_params) == 0
+ ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
+ assert (
+ len(param_dict.keys() - union_params) == 0
+ ), "parameters %s were not separated into either decay/no_decay set!" % (
+ str(param_dict.keys() - union_params),
+ )
+
+ # create the pytorch optimizer object
+ optim_groups = [
+ {
+ "params": [param_dict[pn] for pn in sorted(list(decay))],
+ "weight_decay": train_config.weight_decay,
+ },
+ {
+ "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+ "weight_decay": 0.0,
+ },
+ ]
+ optimizer = torch.optim.AdamW(
+ optim_groups, lr=train_config.learning_rate, betas=train_config.betas
+ )
+ return optimizer
+
+ def forward(self, idx, targets=None):
+ if self.discrete_input:
+ b, t = idx.size()
+ else:
+ b, t, dim = idx.size()
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+
+ # forward the GPT model
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+ position_embeddings = self.pos_emb[
+ :, :t, :
+ ] # each position maps to a (learnable) vector
+ x = self.drop(token_embeddings + position_embeddings)
+ x = self.blocks(x)
+ x = self.ln_f(x)
+ logits = self.head(x)
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/trainer.py b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..64fec8ee98a2674c92523ff331fbfa7c54c091cd
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/trainer.py
@@ -0,0 +1,162 @@
+"""
+Simple training loop; Boilerplate that could apply to any arbitrary neural network,
+so nothing in this file really has anything to do with GPT specifically.
+"""
+
+import math
+import logging
+
+from tqdm import tqdm
+import numpy as np
+
+import torch
+import torch.optim as optim
+from torch.optim.lr_scheduler import LambdaLR
+from torch.utils.data.dataloader import DataLoader
+
+logger = logging.getLogger(__name__)
+
+
+class TrainerConfig:
+ # optimization parameters
+ max_epochs = 10
+ batch_size = 64
+ learning_rate = 3e-4
+ betas = (0.9, 0.95)
+ grad_norm_clip = 1.0
+ weight_decay = 0.1 # only applied on matmul weights
+ # learning rate decay params: linear warmup followed by cosine decay to 10% of original
+ lr_decay = False
+ warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
+ final_tokens = 260e9 # (at what point we reach 10% of original LR)
+ # checkpoint settings
+ ckpt_path = None
+ num_workers = 0 # for DataLoader
+
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+
+class Trainer:
+ def __init__(self, model, train_dataset, test_dataset, config):
+ self.model = model
+ self.train_dataset = train_dataset
+ self.test_dataset = test_dataset
+ self.config = config
+
+ # take over whatever gpus are on the system
+ self.device = "cpu"
+ if torch.cuda.is_available():
+ self.device = torch.cuda.current_device()
+ self.model = torch.nn.DataParallel(self.model).to(self.device)
+
+ def save_checkpoint(self):
+ # DataParallel wrappers keep raw model object in .module attribute
+ raw_model = self.model.module if hasattr(self.model, "module") else self.model
+ logger.info("saving %s", self.config.ckpt_path)
+ torch.save(raw_model.state_dict(), self.config.ckpt_path)
+
+ def train(self):
+ model, config = self.model, self.config
+ raw_model = model.module if hasattr(self.model, "module") else model
+ optimizer = raw_model.configure_optimizers(config)
+
+ def run_epoch(loader, is_train):
+ model.train(is_train)
+
+ losses = []
+ pbar = (
+ tqdm(enumerate(loader), total=len(loader))
+ if is_train
+ else enumerate(loader)
+ )
+ for it, (x, y) in pbar:
+
+ # place data on the correct device
+ x = x.to(self.device)
+ y = y.to(self.device)
+
+ # forward the model
+ with torch.set_grad_enabled(is_train):
+ logits, loss = model(x, y)
+ loss = (
+ loss.mean()
+ ) # collapse all losses if they are scattered on multiple gpus
+ losses.append(loss.item())
+
+ if is_train:
+
+ # backprop and update the parameters
+ model.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(
+ model.parameters(), config.grad_norm_clip
+ )
+ optimizer.step()
+
+ # decay the learning rate based on our progress
+ if config.lr_decay:
+ self.tokens += (
+ y >= 0
+ ).sum() # number of tokens processed this step (i.e. label is not -100)
+ if self.tokens < config.warmup_tokens:
+ # linear warmup
+ lr_mult = float(self.tokens) / float(
+ max(1, config.warmup_tokens)
+ )
+ else:
+ # cosine learning rate decay
+ progress = float(
+ self.tokens - config.warmup_tokens
+ ) / float(
+ max(1, config.final_tokens - config.warmup_tokens)
+ )
+ lr_mult = max(
+ 0.1, 0.5 * (1.0 + math.cos(math.pi * progress))
+ )
+ lr = config.learning_rate * lr_mult
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+ else:
+ lr = config.learning_rate
+
+ # report progress
+ pbar.set_description( # type: ignore
+ f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}"
+ )
+
+ if not is_train:
+ test_loss = float(np.mean(losses))
+ logger.info("test loss: %f", test_loss)
+ return test_loss
+
+ best_loss = float("inf")
+ self.tokens = 0 # counter used for learning rate decay
+
+ train_loader = DataLoader(
+ self.train_dataset,
+ shuffle=True,
+ pin_memory=True,
+ batch_size=config.batch_size,
+ num_workers=config.num_workers,
+ )
+ if self.test_dataset is not None:
+ test_loader = DataLoader(
+ self.test_dataset,
+ shuffle=True,
+ pin_memory=True,
+ batch_size=config.batch_size,
+ num_workers=config.num_workers,
+ )
+
+ for epoch in range(config.max_epochs):
+ run_epoch(train_loader, is_train=True)
+ if self.test_dataset is not None:
+ test_loss = run_epoch(test_loader, is_train=False)
+
+ # supports early stopping based on the test loss, or just save always if no test set is provided
+ good_model = self.test_dataset is None or test_loss < best_loss
+ if self.config.ckpt_path is not None and good_model:
+ best_loss = test_loss
+ self.save_checkpoint()
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/utils.py b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..620b9d8c376354c72849e508f9cbe1dfc74aa7fc
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/libraries/mingpt/utils.py
@@ -0,0 +1,52 @@
+import random
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def top_k_logits(logits, k):
+ v, ix = torch.topk(logits, k)
+ out = logits.clone()
+ out[out < v[:, [-1]]] = -float("Inf")
+ return out
+
+
+@torch.no_grad()
+def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
+ """
+ take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
+ the sequence, feeding the predictions back into the model each time. Clearly the sampling
+ has quadratic complexity unlike an RNN that is only linear, and has a finite context window
+ of block_size, unlike an RNN that has an infinite context window.
+ """
+ block_size = model.get_block_size()
+ model.eval()
+ for k in range(steps):
+ x_cond = (
+ x if x.size(1) <= block_size else x[:, -block_size:]
+ ) # crop context if needed
+ logits, _ = model(x_cond)
+ # pluck the logits at the final step and scale by temperature
+ logits = logits[:, -1, :] / temperature
+ # optionally crop probabilities to only the top k options
+ if top_k is not None:
+ logits = top_k_logits(logits, top_k)
+ # apply softmax to convert to probabilities
+ probs = F.softmax(logits, dim=-1)
+ # sample from the distribution or take the most likely
+ if sample:
+ ix = torch.multinomial(probs, num_samples=1)
+ else:
+ _, ix = torch.topk(probs, k=1, dim=-1)
+ # append to the sequence and continue
+ x = torch.cat((x, ix), dim=1)
+
+ return x
diff --git a/third_party/diffusion_policy/diffusion_policy/model/bet/utils.py b/third_party/diffusion_policy/diffusion_policy/model/bet/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e021a0b22a735ac8f7d6e20e835bdd0429734e59
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/bet/utils.py
@@ -0,0 +1,131 @@
+import os
+import random
+from collections import OrderedDict
+from typing import List, Optional
+
+import einops
+import numpy as np
+import torch
+import torch.nn as nn
+
+from torch.utils.data import random_split
+import wandb
+
+
+def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
+ if hidden_depth == 0:
+ mods = [nn.Linear(input_dim, output_dim)]
+ else:
+ mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
+ for i in range(hidden_depth - 1):
+ mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
+ mods.append(nn.Linear(hidden_dim, output_dim))
+ if output_mod is not None:
+ mods.append(output_mod)
+ trunk = nn.Sequential(*mods)
+ return trunk
+
+
+class eval_mode:
+ def __init__(self, *models, no_grad=False):
+ self.models = models
+ self.no_grad = no_grad
+ self.no_grad_context = torch.no_grad()
+
+ def __enter__(self):
+ self.prev_states = []
+ for model in self.models:
+ self.prev_states.append(model.training)
+ model.train(False)
+ if self.no_grad:
+ self.no_grad_context.__enter__()
+
+ def __exit__(self, *args):
+ if self.no_grad:
+ self.no_grad_context.__exit__(*args)
+ for model, state in zip(self.models, self.prev_states):
+ model.train(state)
+ return False
+
+
+def freeze_module(module: nn.Module) -> nn.Module:
+ for param in module.parameters():
+ param.requires_grad = False
+ module.eval()
+ return module
+
+
+def set_seed_everywhere(seed):
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def shuffle_along_axis(a, axis):
+ idx = np.random.rand(*a.shape).argsort(axis=axis)
+ return np.take_along_axis(a, idx, axis=axis)
+
+
+def transpose_batch_timestep(*args):
+ return (einops.rearrange(arg, "b t ... -> t b ...") for arg in args)
+
+
+class TrainWithLogger:
+ def reset_log(self):
+ self.log_components = OrderedDict()
+
+ def log_append(self, log_key, length, loss_components):
+ for key, value in loss_components.items():
+ key_name = f"{log_key}/{key}"
+ count, sum = self.log_components.get(key_name, (0, 0.0))
+ self.log_components[key_name] = (
+ count + length,
+ sum + (length * value.detach().cpu().item()),
+ )
+
+ def flush_log(self, epoch, iterator=None):
+ log_components = OrderedDict()
+ iterator_log_component = OrderedDict()
+ for key, value in self.log_components.items():
+ count, sum = value
+ to_log = sum / count
+ log_components[key] = to_log
+ # Set the iterator status
+ log_key, name_key = key.split("/")
+ iterator_log_name = f"{log_key[0]}{name_key[0]}".upper()
+ iterator_log_component[iterator_log_name] = to_log
+ postfix = ",".join(
+ "{}:{:.2e}".format(key, iterator_log_component[key])
+ for key in iterator_log_component.keys()
+ )
+ if iterator is not None:
+ iterator.set_postfix_str(postfix)
+ wandb.log(log_components, step=epoch)
+ self.log_components = OrderedDict()
+
+
+class SaveModule(nn.Module):
+ def set_snapshot_path(self, path):
+ self.snapshot_path = path
+ print(f"Setting snapshot path to {self.snapshot_path}")
+
+ def save_snapshot(self):
+ os.makedirs(self.snapshot_path, exist_ok=True)
+ torch.save(self.state_dict(), self.snapshot_path / "snapshot.pth")
+
+ def load_snapshot(self):
+ self.load_state_dict(torch.load(self.snapshot_path / "snapshot.pth"))
+
+
+def split_datasets(dataset, train_fraction=0.95, random_seed=42):
+ dataset_length = len(dataset)
+ lengths = [
+ int(train_fraction * dataset_length),
+ dataset_length - int(train_fraction * dataset_length),
+ ]
+ train_set, val_set = random_split(
+ dataset, lengths, generator=torch.Generator().manual_seed(random_seed)
+ )
+ return train_set, val_set
diff --git a/third_party/diffusion_policy/diffusion_policy/model/common/dict_of_tensor_mixin.py b/third_party/diffusion_policy/diffusion_policy/model/common/dict_of_tensor_mixin.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d08bc13561d43079fcb62050d2ad7c3a94b4c18
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/common/dict_of_tensor_mixin.py
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+
+class DictOfTensorMixin(nn.Module):
+ def __init__(self, params_dict=None):
+ super().__init__()
+ if params_dict is None:
+ params_dict = nn.ParameterDict()
+ self.params_dict = params_dict
+
+ @property
+ def device(self):
+ return next(iter(self.parameters())).device
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+ def dfs_add(dest, keys, value: torch.Tensor):
+ if len(keys) == 1:
+ dest[keys[0]] = value
+ return
+
+ if keys[0] not in dest:
+ dest[keys[0]] = nn.ParameterDict()
+ dfs_add(dest[keys[0]], keys[1:], value)
+
+ def load_dict(state_dict, prefix):
+ out_dict = nn.ParameterDict()
+ for key, value in state_dict.items():
+ value: torch.Tensor
+ if key.startswith(prefix):
+ param_keys = key[len(prefix):].split('.')[1:]
+ # if len(param_keys) == 0:
+ # import pdb; pdb.set_trace()
+ dfs_add(out_dict, param_keys, value.clone())
+ return out_dict
+
+ self.params_dict = load_dict(state_dict, prefix + 'params_dict')
+ self.params_dict.requires_grad_(False)
+ return
diff --git a/third_party/diffusion_policy/diffusion_policy/model/common/lr_scheduler.py b/third_party/diffusion_policy/diffusion_policy/model/common/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4c053363788826e65e894a33d3e52651a41d864
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/common/lr_scheduler.py
@@ -0,0 +1,46 @@
+from diffusers.optimization import (
+ Union, SchedulerType, Optional,
+ Optimizer, TYPE_TO_SCHEDULER_FUNCTION
+)
+
+def get_scheduler(
+ name: Union[str, SchedulerType],
+ optimizer: Optimizer,
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+ **kwargs
+):
+ """
+ Added kwargs vs diffuser's original implementation
+
+ Unified API to get any scheduler from its name.
+
+ Args:
+ name (`str` or `SchedulerType`):
+ The name of the scheduler to use.
+ optimizer (`torch.optim.Optimizer`):
+ The optimizer that will be used during training.
+ num_warmup_steps (`int`, *optional*):
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ num_training_steps (`int``, *optional*):
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ """
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+ if name == SchedulerType.CONSTANT:
+ return schedule_func(optimizer, **kwargs)
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs)
diff --git a/third_party/diffusion_policy/diffusion_policy/model/common/module_attr_mixin.py b/third_party/diffusion_policy/diffusion_policy/model/common/module_attr_mixin.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cbdf709927984de04126c30aed349e38df9f85b
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/common/module_attr_mixin.py
@@ -0,0 +1,14 @@
+import torch.nn as nn
+
+class ModuleAttrMixin(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self._dummy_variable = nn.Parameter()
+
+ @property
+ def device(self):
+ return next(iter(self.parameters())).device
+
+ @property
+ def dtype(self):
+ return next(iter(self.parameters())).dtype
diff --git a/third_party/diffusion_policy/diffusion_policy/model/common/normalizer.py b/third_party/diffusion_policy/diffusion_policy/model/common/normalizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a26469290e7670d435257a1215f77d8e295cd8df
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/common/normalizer.py
@@ -0,0 +1,353 @@
+from typing import Union, Dict
+
+import unittest
+import zarr
+import numpy as np
+import torch
+import torch.nn as nn
+from diffusion_policy.common.pytorch_util import dict_apply
+from diffusion_policy.model.common.dict_of_tensor_mixin import DictOfTensorMixin
+
+
+class LinearNormalizer(DictOfTensorMixin):
+ avaliable_modes = ['limits', 'gaussian']
+
+ @torch.no_grad()
+ def fit(self,
+ data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
+ last_n_dims=1,
+ dtype=torch.float32,
+ mode='limits',
+ output_max=1.,
+ output_min=-1.,
+ range_eps=1e-4,
+ fit_offset=True):
+ if isinstance(data, dict):
+ for key, value in data.items():
+ self.params_dict[key] = _fit(value,
+ last_n_dims=last_n_dims,
+ dtype=dtype,
+ mode=mode,
+ output_max=output_max,
+ output_min=output_min,
+ range_eps=range_eps,
+ fit_offset=fit_offset)
+ else:
+ self.params_dict['_default'] = _fit(data,
+ last_n_dims=last_n_dims,
+ dtype=dtype,
+ mode=mode,
+ output_max=output_max,
+ output_min=output_min,
+ range_eps=range_eps,
+ fit_offset=fit_offset)
+
+ def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
+ return self.normalize(x)
+
+ def __getitem__(self, key: str):
+ return SingleFieldLinearNormalizer(self.params_dict[key])
+
+ def __setitem__(self, key: str , value: 'SingleFieldLinearNormalizer'):
+ self.params_dict[key] = value.params_dict
+
+ def _normalize_impl(self, x, forward=True):
+ if isinstance(x, dict):
+ result = dict()
+ for key, value in x.items():
+ params = self.params_dict[key]
+ result[key] = _normalize(value, params, forward=forward)
+ return result
+ else:
+ if '_default' not in self.params_dict:
+ raise RuntimeError("Not initialized")
+ params = self.params_dict['_default']
+ return _normalize(x, params, forward=forward)
+
+ def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
+ return self._normalize_impl(x, forward=True)
+
+ def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
+ return self._normalize_impl(x, forward=False)
+
+ def get_input_stats(self) -> Dict:
+ if len(self.params_dict) == 0:
+ raise RuntimeError("Not initialized")
+ if len(self.params_dict) == 1 and '_default' in self.params_dict:
+ return self.params_dict['_default']['input_stats']
+
+ result = dict()
+ for key, value in self.params_dict.items():
+ if key != '_default':
+ result[key] = value['input_stats']
+ return result
+
+
+ def get_output_stats(self, key='_default'):
+ input_stats = self.get_input_stats()
+ if 'min' in input_stats:
+ # no dict
+ return dict_apply(input_stats, self.normalize)
+
+ result = dict()
+ for key, group in input_stats.items():
+ this_dict = dict()
+ for name, value in group.items():
+ this_dict[name] = self.normalize({key:value})[key]
+ result[key] = this_dict
+ return result
+
+
+class SingleFieldLinearNormalizer(DictOfTensorMixin):
+ avaliable_modes = ['limits', 'gaussian']
+
+ @torch.no_grad()
+ def fit(self,
+ data: Union[torch.Tensor, np.ndarray, zarr.Array],
+ last_n_dims=1,
+ dtype=torch.float32,
+ mode='limits',
+ output_max=1.,
+ output_min=-1.,
+ range_eps=1e-4,
+ fit_offset=True):
+ self.params_dict = _fit(data,
+ last_n_dims=last_n_dims,
+ dtype=dtype,
+ mode=mode,
+ output_max=output_max,
+ output_min=output_min,
+ range_eps=range_eps,
+ fit_offset=fit_offset)
+
+ @classmethod
+ def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
+ obj = cls()
+ obj.fit(data, **kwargs)
+ return obj
+
+ @classmethod
+ def create_manual(cls,
+ scale: Union[torch.Tensor, np.ndarray],
+ offset: Union[torch.Tensor, np.ndarray],
+ input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]]):
+ def to_tensor(x):
+ if not isinstance(x, torch.Tensor):
+ x = torch.from_numpy(x)
+ x = x.flatten()
+ return x
+
+ # check
+ for x in [offset] + list(input_stats_dict.values()):
+ assert x.shape == scale.shape
+ assert x.dtype == scale.dtype
+
+ params_dict = nn.ParameterDict({
+ 'scale': to_tensor(scale),
+ 'offset': to_tensor(offset),
+ 'input_stats': nn.ParameterDict(
+ dict_apply(input_stats_dict, to_tensor))
+ })
+ return cls(params_dict)
+
+ @classmethod
+ def create_identity(cls, dtype=torch.float32):
+ scale = torch.tensor([1], dtype=dtype)
+ offset = torch.tensor([0], dtype=dtype)
+ input_stats_dict = {
+ 'min': torch.tensor([-1], dtype=dtype),
+ 'max': torch.tensor([1], dtype=dtype),
+ 'mean': torch.tensor([0], dtype=dtype),
+ 'std': torch.tensor([1], dtype=dtype)
+ }
+ return cls.create_manual(scale, offset, input_stats_dict)
+
+ def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
+ return _normalize(x, self.params_dict, forward=True)
+
+ def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
+ return _normalize(x, self.params_dict, forward=False)
+
+ def get_input_stats(self):
+ return self.params_dict['input_stats']
+
+ def get_output_stats(self):
+ return dict_apply(self.params_dict['input_stats'], self.normalize)
+
+ def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
+ return self.normalize(x)
+
+
+
+def _fit(data: Union[torch.Tensor, np.ndarray, zarr.Array],
+ last_n_dims=1,
+ dtype=torch.float32,
+ mode='limits',
+ output_max=1.,
+ output_min=-1.,
+ range_eps=1e-4,
+ fit_offset=True):
+ assert mode in ['limits', 'gaussian']
+ assert last_n_dims >= 0
+ assert output_max > output_min
+
+ # convert data to torch and type
+ if isinstance(data, zarr.Array):
+ data = data[:]
+ if isinstance(data, np.ndarray):
+ data = torch.from_numpy(data)
+ if dtype is not None:
+ data = data.type(dtype)
+
+ # convert shape
+ dim = 1
+ if last_n_dims > 0:
+ dim = np.prod(data.shape[-last_n_dims:])
+ data = data.reshape(-1,dim)
+
+ # compute input stats min max mean std
+ input_min, _ = data.min(axis=0)
+ input_max, _ = data.max(axis=0)
+ input_mean = data.mean(axis=0)
+ input_std = data.std(axis=0)
+
+ # compute scale and offset
+ if mode == 'limits':
+ if fit_offset:
+ # unit scale
+ input_range = input_max - input_min
+ ignore_dim = input_range < range_eps
+ input_range[ignore_dim] = output_max - output_min
+ scale = (output_max - output_min) / input_range
+ offset = output_min - scale * input_min
+ offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
+ # ignore dims scaled to mean of output max and min
+ else:
+ # use this when data is pre-zero-centered.
+ assert output_max > 0
+ assert output_min < 0
+ # unit abs
+ output_abs = min(abs(output_min), abs(output_max))
+ input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
+ ignore_dim = input_abs < range_eps
+ input_abs[ignore_dim] = output_abs
+ # don't scale constant channels
+ scale = output_abs / input_abs
+ offset = torch.zeros_like(input_mean)
+ elif mode == 'gaussian':
+ ignore_dim = input_std < range_eps
+ scale = input_std.clone()
+ scale[ignore_dim] = 1
+ scale = 1 / scale
+
+ if fit_offset:
+ offset = - input_mean * scale
+ else:
+ offset = torch.zeros_like(input_mean)
+
+ # save
+ this_params = nn.ParameterDict({
+ 'scale': scale,
+ 'offset': offset,
+ 'input_stats': nn.ParameterDict({
+ 'min': input_min,
+ 'max': input_max,
+ 'mean': input_mean,
+ 'std': input_std
+ })
+ })
+ for p in this_params.parameters():
+ p.requires_grad_(False)
+ return this_params
+
+
+def _normalize(x, params, forward=True):
+ assert 'scale' in params
+ if isinstance(x, np.ndarray):
+ x = torch.from_numpy(x)
+ scale = params['scale']
+ offset = params['offset']
+ x = x.to(device=scale.device, dtype=scale.dtype)
+ src_shape = x.shape
+ x = x.reshape(-1, scale.shape[0])
+ if forward:
+ x = x * scale + offset
+ else:
+ x = (x - offset) / scale
+ x = x.reshape(src_shape)
+ return x
+
+
+def test():
+ data = torch.zeros((100,10,9,2)).uniform_()
+ data[...,0,0] = 0
+
+ normalizer = SingleFieldLinearNormalizer()
+ normalizer.fit(data, mode='limits', last_n_dims=2)
+ datan = normalizer.normalize(data)
+ assert datan.shape == data.shape
+ assert np.allclose(datan.max(), 1.)
+ assert np.allclose(datan.min(), -1.)
+ dataun = normalizer.unnormalize(datan)
+ assert torch.allclose(data, dataun, atol=1e-7)
+
+ input_stats = normalizer.get_input_stats()
+ output_stats = normalizer.get_output_stats()
+
+ normalizer = SingleFieldLinearNormalizer()
+ normalizer.fit(data, mode='limits', last_n_dims=1, fit_offset=False)
+ datan = normalizer.normalize(data)
+ assert datan.shape == data.shape
+ assert np.allclose(datan.max(), 1., atol=1e-3)
+ assert np.allclose(datan.min(), 0., atol=1e-3)
+ dataun = normalizer.unnormalize(datan)
+ assert torch.allclose(data, dataun, atol=1e-7)
+
+ data = torch.zeros((100,10,9,2)).uniform_()
+ normalizer = SingleFieldLinearNormalizer()
+ normalizer.fit(data, mode='gaussian', last_n_dims=0)
+ datan = normalizer.normalize(data)
+ assert datan.shape == data.shape
+ assert np.allclose(datan.mean(), 0., atol=1e-3)
+ assert np.allclose(datan.std(), 1., atol=1e-3)
+ dataun = normalizer.unnormalize(datan)
+ assert torch.allclose(data, dataun, atol=1e-7)
+
+
+ # dict
+ data = torch.zeros((100,10,9,2)).uniform_()
+ data[...,0,0] = 0
+
+ normalizer = LinearNormalizer()
+ normalizer.fit(data, mode='limits', last_n_dims=2)
+ datan = normalizer.normalize(data)
+ assert datan.shape == data.shape
+ assert np.allclose(datan.max(), 1.)
+ assert np.allclose(datan.min(), -1.)
+ dataun = normalizer.unnormalize(datan)
+ assert torch.allclose(data, dataun, atol=1e-7)
+
+ input_stats = normalizer.get_input_stats()
+ output_stats = normalizer.get_output_stats()
+
+ data = {
+ 'obs': torch.zeros((1000,128,9,2)).uniform_() * 512,
+ 'action': torch.zeros((1000,128,2)).uniform_() * 512
+ }
+ normalizer = LinearNormalizer()
+ normalizer.fit(data)
+ datan = normalizer.normalize(data)
+ dataun = normalizer.unnormalize(datan)
+ for key in data:
+ assert torch.allclose(data[key], dataun[key], atol=1e-4)
+
+ input_stats = normalizer.get_input_stats()
+ output_stats = normalizer.get_output_stats()
+
+ state_dict = normalizer.state_dict()
+ n = LinearNormalizer()
+ n.load_state_dict(state_dict)
+ datan = n.normalize(data)
+ dataun = n.unnormalize(datan)
+ for key in data:
+ assert torch.allclose(data[key], dataun[key], atol=1e-4)
diff --git a/third_party/diffusion_policy/diffusion_policy/model/common/rotation_transformer.py b/third_party/diffusion_policy/diffusion_policy/model/common/rotation_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a801d285fae5086adf9f15004a55429b00a9c9da
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/common/rotation_transformer.py
@@ -0,0 +1,103 @@
+from typing import Union
+import pytorch3d.transforms as pt
+import torch
+import numpy as np
+import functools
+
+class RotationTransformer:
+ valid_reps = [
+ 'axis_angle',
+ 'euler_angles',
+ 'quaternion',
+ 'rotation_6d',
+ 'matrix'
+ ]
+
+ def __init__(self,
+ from_rep='axis_angle',
+ to_rep='rotation_6d',
+ from_convention=None,
+ to_convention=None):
+ """
+ Valid representations
+
+ Always use matrix as intermediate representation.
+ """
+ assert from_rep != to_rep
+ assert from_rep in self.valid_reps
+ assert to_rep in self.valid_reps
+ if from_rep == 'euler_angles':
+ assert from_convention is not None
+ if to_rep == 'euler_angles':
+ assert to_convention is not None
+
+ forward_funcs = list()
+ inverse_funcs = list()
+
+ if from_rep != 'matrix':
+ funcs = [
+ getattr(pt, f'{from_rep}_to_matrix'),
+ getattr(pt, f'matrix_to_{from_rep}')
+ ]
+ if from_convention is not None:
+ funcs = [functools.partial(func, convention=from_convention)
+ for func in funcs]
+ forward_funcs.append(funcs[0])
+ inverse_funcs.append(funcs[1])
+
+ if to_rep != 'matrix':
+ funcs = [
+ getattr(pt, f'matrix_to_{to_rep}'),
+ getattr(pt, f'{to_rep}_to_matrix')
+ ]
+ if to_convention is not None:
+ funcs = [functools.partial(func, convention=to_convention)
+ for func in funcs]
+ forward_funcs.append(funcs[0])
+ inverse_funcs.append(funcs[1])
+
+ inverse_funcs = inverse_funcs[::-1]
+
+ self.forward_funcs = forward_funcs
+ self.inverse_funcs = inverse_funcs
+
+ @staticmethod
+ def _apply_funcs(x: Union[np.ndarray, torch.Tensor], funcs: list) -> Union[np.ndarray, torch.Tensor]:
+ x_ = x
+ if isinstance(x, np.ndarray):
+ x_ = torch.from_numpy(x)
+ x_: torch.Tensor
+ for func in funcs:
+ x_ = func(x_)
+ y = x_
+ if isinstance(x, np.ndarray):
+ y = x_.numpy()
+ return y
+
+ def forward(self, x: Union[np.ndarray, torch.Tensor]
+ ) -> Union[np.ndarray, torch.Tensor]:
+ return self._apply_funcs(x, self.forward_funcs)
+
+ def inverse(self, x: Union[np.ndarray, torch.Tensor]
+ ) -> Union[np.ndarray, torch.Tensor]:
+ return self._apply_funcs(x, self.inverse_funcs)
+
+
+def test():
+ tf = RotationTransformer()
+
+ rotvec = np.random.uniform(-2*np.pi,2*np.pi,size=(1000,3))
+ rot6d = tf.forward(rotvec)
+ new_rotvec = tf.inverse(rot6d)
+
+ from scipy.spatial.transform import Rotation
+ diff = Rotation.from_rotvec(rotvec) * Rotation.from_rotvec(new_rotvec).inv()
+ dist = diff.magnitude()
+ assert dist.max() < 1e-7
+
+ tf = RotationTransformer('rotation_6d', 'matrix')
+ rot6d_wrong = rot6d + np.random.normal(scale=0.1, size=rot6d.shape)
+ mat = tf.forward(rot6d_wrong)
+ mat_det = np.linalg.det(mat)
+ assert np.allclose(mat_det, 1)
+ # rotaiton_6d will be normalized to rotation matrix
diff --git a/third_party/diffusion_policy/diffusion_policy/model/common/shape_util.py b/third_party/diffusion_policy/diffusion_policy/model/common/shape_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1786c174e6f22794faa90ef0c3d2a7d29bae873
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/common/shape_util.py
@@ -0,0 +1,23 @@
+from typing import Dict, List, Tuple, Callable
+import torch
+import torch.nn as nn
+
+def get_module_device(m: nn.Module):
+ device = torch.device('cpu')
+ try:
+ param = next(iter(m.parameters()))
+ device = param.device
+ except StopIteration:
+ pass
+ return device
+
+@torch.no_grad()
+def get_output_shape(
+ input_shape: Tuple[int],
+ net: Callable[[torch.Tensor], torch.Tensor]
+ ):
+ device = get_module_device(net)
+ test_input = torch.zeros((1,)+tuple(input_shape), device=device)
+ test_output = net(test_input)
+ output_shape = tuple(test_output.shape[1:])
+ return output_shape
diff --git a/third_party/diffusion_policy/diffusion_policy/model/common/tensor_util.py b/third_party/diffusion_policy/diffusion_policy/model/common/tensor_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d6cbffc5e8197445cbc1933ccaa2cebe2c5a063
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/common/tensor_util.py
@@ -0,0 +1,960 @@
+"""
+A collection of utilities for working with nested tensor structures consisting
+of numpy arrays and torch tensors.
+"""
+import collections
+import numpy as np
+import torch
+
+
+def recursive_dict_list_tuple_apply(x, type_func_dict):
+ """
+ Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
+ {data_type: function_to_apply}.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ type_func_dict (dict): a mapping from data types to the functions to be
+ applied for each data type.
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ assert(list not in type_func_dict)
+ assert(tuple not in type_func_dict)
+ assert(dict not in type_func_dict)
+
+ if isinstance(x, (dict, collections.OrderedDict)):
+ new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else dict()
+ for k, v in x.items():
+ new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
+ return new_x
+ elif isinstance(x, (list, tuple)):
+ ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
+ if isinstance(x, tuple):
+ ret = tuple(ret)
+ return ret
+ else:
+ for t, f in type_func_dict.items():
+ if isinstance(x, t):
+ return f(x)
+ else:
+ raise NotImplementedError(
+ 'Cannot handle data type %s' % str(type(x)))
+
+
+def map_tensor(x, func):
+ """
+ Apply function @func to torch.Tensor objects in a nested dictionary or
+ list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ func (function): function to apply to each tensor
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: func,
+ type(None): lambda x: x,
+ }
+ )
+
+
+def map_ndarray(x, func):
+ """
+ Apply function @func to np.ndarray objects in a nested dictionary or
+ list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ func (function): function to apply to each array
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ np.ndarray: func,
+ type(None): lambda x: x,
+ }
+ )
+
+
+def map_tensor_ndarray(x, tensor_func, ndarray_func):
+ """
+ Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
+ np.ndarray objects in a nested dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ tensor_func (function): function to apply to each tensor
+ ndarray_Func (function): function to apply to each array
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: tensor_func,
+ np.ndarray: ndarray_func,
+ type(None): lambda x: x,
+ }
+ )
+
+
+def clone(x):
+ """
+ Clones all torch tensors and numpy arrays in nested dictionary or list
+ or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.clone(),
+ np.ndarray: lambda x: x.copy(),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def detach(x):
+ """
+ Detaches all torch tensors in nested dictionary or list
+ or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.detach(),
+ }
+ )
+
+
+def to_batch(x):
+ """
+ Introduces a leading batch dimension of 1 for all torch tensors and numpy
+ arrays in nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x[None, ...],
+ np.ndarray: lambda x: x[None, ...],
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_sequence(x):
+ """
+ Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
+ arrays in nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x[:, None, ...],
+ np.ndarray: lambda x: x[:, None, ...],
+ type(None): lambda x: x,
+ }
+ )
+
+
+def index_at_time(x, ind):
+ """
+ Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
+ nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ ind (int): index
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x[:, ind, ...],
+ np.ndarray: lambda x: x[:, ind, ...],
+ type(None): lambda x: x,
+ }
+ )
+
+
+def unsqueeze(x, dim):
+ """
+ Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
+ in nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ dim (int): dimension
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.unsqueeze(dim=dim),
+ np.ndarray: lambda x: np.expand_dims(x, axis=dim),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def contiguous(x):
+ """
+ Makes all torch tensors and numpy arrays contiguous in nested dictionary or
+ list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.contiguous(),
+ np.ndarray: lambda x: np.ascontiguousarray(x),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_device(x, device):
+ """
+ Sends all torch tensors in nested dictionary or list or tuple to device
+ @device, and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ device (torch.Device): device to send tensors to
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x, d=device: x.to(d),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_tensor(x):
+ """
+ Converts all numpy arrays in nested dictionary or list or tuple to
+ torch tensors (and leaves existing torch Tensors as-is), and returns
+ a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x,
+ np.ndarray: lambda x: torch.from_numpy(x),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_numpy(x):
+ """
+ Converts all torch tensors in nested dictionary or list or tuple to
+ numpy (and leaves existing numpy arrays as-is), and returns
+ a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ def f(tensor):
+ if tensor.is_cuda:
+ return tensor.detach().cpu().numpy()
+ else:
+ return tensor.detach().numpy()
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: f,
+ np.ndarray: lambda x: x,
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_list(x):
+ """
+ Converts all torch tensors and numpy arrays in nested dictionary or list
+ or tuple to a list, and returns a new nested structure. Useful for
+ json encoding.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ def f(tensor):
+ if tensor.is_cuda:
+ return tensor.detach().cpu().numpy().tolist()
+ else:
+ return tensor.detach().numpy().tolist()
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: f,
+ np.ndarray: lambda x: x.tolist(),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_float(x):
+ """
+ Converts all torch tensors and numpy arrays in nested dictionary or list
+ or tuple to float type entries, and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.float(),
+ np.ndarray: lambda x: x.astype(np.float32),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_uint8(x):
+ """
+ Converts all torch tensors and numpy arrays in nested dictionary or list
+ or tuple to uint8 type entries, and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.byte(),
+ np.ndarray: lambda x: x.astype(np.uint8),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_torch(x, device):
+ """
+ Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
+ torch tensors on device @device and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ device (torch.Device): device to send tensors to
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return to_device(to_float(to_tensor(x)), device)
+
+
+def to_one_hot_single(tensor, num_class):
+ """
+ Convert tensor to one-hot representation, assuming a certain number of total class labels.
+
+ Args:
+ tensor (torch.Tensor): tensor containing integer labels
+ num_class (int): number of classes
+
+ Returns:
+ x (torch.Tensor): tensor containing one-hot representation of labels
+ """
+ x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device)
+ x.scatter_(-1, tensor.unsqueeze(-1), 1)
+ return x
+
+
+def to_one_hot(tensor, num_class):
+ """
+ Convert all tensors in nested dictionary or list or tuple to one-hot representation,
+ assuming a certain number of total class labels.
+
+ Args:
+ tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
+ num_class (int): number of classes
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc))
+
+
+def flatten_single(x, begin_axis=1):
+ """
+ Flatten a tensor in all dimensions from @begin_axis onwards.
+
+ Args:
+ x (torch.Tensor): tensor to flatten
+ begin_axis (int): which axis to flatten from
+
+ Returns:
+ y (torch.Tensor): flattened tensor
+ """
+ fixed_size = x.size()[:begin_axis]
+ _s = list(fixed_size) + [-1]
+ return x.reshape(*_s)
+
+
+def flatten(x, begin_axis=1):
+ """
+ Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ begin_axis (int): which axis to flatten from
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
+ }
+ )
+
+
+def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
+ """
+ Reshape selected dimensions in a tensor to a target dimension.
+
+ Args:
+ x (torch.Tensor): tensor to reshape
+ begin_axis (int): begin dimension
+ end_axis (int): end dimension
+ target_dims (tuple or list): target shape for the range of dimensions
+ (@begin_axis, @end_axis)
+
+ Returns:
+ y (torch.Tensor): reshaped tensor
+ """
+ assert(begin_axis <= end_axis)
+ assert(begin_axis >= 0)
+ assert(end_axis < len(x.shape))
+ assert(isinstance(target_dims, (tuple, list)))
+ s = x.shape
+ final_s = []
+ for i in range(len(s)):
+ if i == begin_axis:
+ final_s.extend(target_dims)
+ elif i < begin_axis or i > end_axis:
+ final_s.append(s[i])
+ return x.reshape(*final_s)
+
+
+def reshape_dimensions(x, begin_axis, end_axis, target_dims):
+ """
+ Reshape selected dimensions for all tensors in nested dictionary or list or tuple
+ to a target dimension.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ begin_axis (int): begin dimension
+ end_axis (int): end dimension
+ target_dims (tuple or list): target shape for the range of dimensions
+ (@begin_axis, @end_axis)
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=t),
+ np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=t),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def join_dimensions(x, begin_axis, end_axis):
+ """
+ Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
+ all tensors in nested dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ begin_axis (int): begin dimension
+ end_axis (int): end dimension
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=[-1]),
+ np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=[-1]),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def expand_at_single(x, size, dim):
+ """
+ Expand a tensor at a single dimension @dim by @size
+
+ Args:
+ x (torch.Tensor): input tensor
+ size (int): size to expand
+ dim (int): dimension to expand
+
+ Returns:
+ y (torch.Tensor): expanded tensor
+ """
+ assert dim < x.ndimension()
+ assert x.shape[dim] == 1
+ expand_dims = [-1] * x.ndimension()
+ expand_dims[dim] = size
+ return x.expand(*expand_dims)
+
+
+def expand_at(x, size, dim):
+ """
+ Expand all tensors in nested dictionary or list or tuple at a single
+ dimension @dim by @size.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ size (int): size to expand
+ dim (int): dimension to expand
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
+
+
+def unsqueeze_expand_at(x, size, dim):
+ """
+ Unsqueeze and expand a tensor at a dimension @dim by @size.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ size (int): size to expand
+ dim (int): dimension to unsqueeze and expand
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ x = unsqueeze(x, dim)
+ return expand_at(x, size, dim)
+
+
+def repeat_by_expand_at(x, repeats, dim):
+ """
+ Repeat a dimension by combining expand and reshape operations.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ repeats (int): number of times to repeat the target dimension
+ dim (int): dimension to repeat on
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ x = unsqueeze_expand_at(x, repeats, dim + 1)
+ return join_dimensions(x, dim, dim + 1)
+
+
+def named_reduce_single(x, reduction, dim):
+ """
+ Reduce tensor at a dimension by named reduction functions.
+
+ Args:
+ x (torch.Tensor): tensor to be reduced
+ reduction (str): one of ["sum", "max", "mean", "flatten"]
+ dim (int): dimension to be reduced (or begin axis for flatten)
+
+ Returns:
+ y (torch.Tensor): reduced tensor
+ """
+ assert x.ndimension() > dim
+ assert reduction in ["sum", "max", "mean", "flatten"]
+ if reduction == "flatten":
+ x = flatten(x, begin_axis=dim)
+ elif reduction == "max":
+ x = torch.max(x, dim=dim)[0] # [B, D]
+ elif reduction == "sum":
+ x = torch.sum(x, dim=dim)
+ else:
+ x = torch.mean(x, dim=dim)
+ return x
+
+
+def named_reduce(x, reduction, dim):
+ """
+ Reduces all tensors in nested dictionary or list or tuple at a dimension
+ using a named reduction function.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ reduction (str): one of ["sum", "max", "mean", "flatten"]
+ dim (int): dimension to be reduced (or begin axis for flatten)
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
+
+
+def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
+ """
+ This function indexes out a target dimension of a tensor in a structured way,
+ by allowing a different value to be selected for each member of a flat index
+ tensor (@indices) corresponding to a source dimension. This can be interpreted
+ as moving along the source dimension, using the corresponding index value
+ in @indices to select values for all other dimensions outside of the
+ source and target dimensions. A common use case is to gather values
+ in target dimension 1 for each batch member (target dimension 0).
+
+ Args:
+ x (torch.Tensor): tensor to gather values for
+ target_dim (int): dimension to gather values along
+ source_dim (int): dimension to hold constant and use for gathering values
+ from the other dimensions
+ indices (torch.Tensor): flat index tensor with same shape as tensor @x along
+ @source_dim
+
+ Returns:
+ y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
+ """
+ assert len(indices.shape) == 1
+ assert x.shape[source_dim] == indices.shape[0]
+
+ # unsqueeze in all dimensions except the source dimension
+ new_shape = [1] * x.ndimension()
+ new_shape[source_dim] = -1
+ indices = indices.reshape(*new_shape)
+
+ # repeat in all dimensions - but preserve shape of source dimension,
+ # and make sure target_dimension has singleton dimension
+ expand_shape = list(x.shape)
+ expand_shape[source_dim] = -1
+ expand_shape[target_dim] = 1
+ indices = indices.expand(*expand_shape)
+
+ out = x.gather(dim=target_dim, index=indices)
+ return out.squeeze(target_dim)
+
+
+def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
+ """
+ Apply @gather_along_dim_with_dim_single to all tensors in a nested
+ dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ target_dim (int): dimension to gather values along
+ source_dim (int): dimension to hold constant and use for gathering values
+ from the other dimensions
+ indices (torch.Tensor): flat index tensor with same shape as tensor @x along
+ @source_dim
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(x,
+ lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i))
+
+
+def gather_sequence_single(seq, indices):
+ """
+ Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
+ the batch given an index for each sequence.
+
+ Args:
+ seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
+ indices (torch.Tensor): tensor indices of shape [B]
+
+ Return:
+ y (torch.Tensor): indexed tensor of shape [B, ....]
+ """
+ return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices)
+
+
+def gather_sequence(seq, indices):
+ """
+ Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
+ for tensors with leading dimensions [B, T, ...].
+
+ Args:
+ seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
+ of leading dimensions [B, T, ...]
+ indices (torch.Tensor): tensor indices of shape [B]
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
+ """
+ return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices)
+
+
+def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
+ """
+ Pad input tensor or array @seq in the time dimension (dimension 1).
+
+ Args:
+ seq (np.ndarray or torch.Tensor): sequence to be padded
+ padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
+ batched (bool): if sequence has the batch dimension
+ pad_same (bool): if pad by duplicating
+ pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
+
+ Returns:
+ padded sequence (np.ndarray or torch.Tensor)
+ """
+ assert isinstance(seq, (np.ndarray, torch.Tensor))
+ assert pad_same or pad_values is not None
+ if pad_values is not None:
+ assert isinstance(pad_values, float)
+ repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave
+ concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
+ ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like
+ seq_dim = 1 if batched else 0
+
+ begin_pad = []
+ end_pad = []
+
+ if padding[0] > 0:
+ pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
+ begin_pad.append(repeat_func(pad, padding[0], seq_dim))
+ if padding[1] > 0:
+ pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
+ end_pad.append(repeat_func(pad, padding[1], seq_dim))
+
+ return concat_func(begin_pad + [seq] + end_pad, seq_dim)
+
+
+def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
+ """
+ Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
+
+ Args:
+ seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
+ of leading dimensions [B, T, ...]
+ padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
+ batched (bool): if sequence has the batch dimension
+ pad_same (bool): if pad by duplicating
+ pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
+
+ Returns:
+ padded sequence (dict or list or tuple)
+ """
+ return recursive_dict_list_tuple_apply(
+ seq,
+ {
+ torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
+ pad_sequence_single(x, p, b, ps, pv),
+ np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
+ pad_sequence_single(x, p, b, ps, pv),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def assert_size_at_dim_single(x, size, dim, msg):
+ """
+ Ensure that array or tensor @x has size @size in dim @dim.
+
+ Args:
+ x (np.ndarray or torch.Tensor): input array or tensor
+ size (int): size that tensors should have at @dim
+ dim (int): dimension to check
+ msg (str): text to display if assertion fails
+ """
+ assert x.shape[dim] == size, msg
+
+
+def assert_size_at_dim(x, size, dim, msg):
+ """
+ Ensure that arrays and tensors in nested dictionary or list or tuple have
+ size @size in dim @dim.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ size (int): size that tensors should have at @dim
+ dim (int): dimension to check
+ """
+ map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
+
+
+def get_shape(x):
+ """
+ Get all shapes of arrays and tensors in nested dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple that contains each array or
+ tensor's shape
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.shape,
+ np.ndarray: lambda x: x.shape,
+ type(None): lambda x: x,
+ }
+ )
+
+
+def list_of_flat_dict_to_dict_of_list(list_of_dict):
+ """
+ Helper function to go from a list of flat dictionaries to a dictionary of lists.
+ By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
+ floats, etc.
+
+ Args:
+ list_of_dict (list): list of flat dictionaries
+
+ Returns:
+ dict_of_list (dict): dictionary of lists
+ """
+ assert isinstance(list_of_dict, list)
+ dic = collections.OrderedDict()
+ for i in range(len(list_of_dict)):
+ for k in list_of_dict[i]:
+ if k not in dic:
+ dic[k] = []
+ dic[k].append(list_of_dict[i][k])
+ return dic
+
+
+def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''):
+ """
+ Flatten a nested dict or list to a list.
+
+ For example, given a dict
+ {
+ a: 1
+ b: {
+ c: 2
+ }
+ c: 3
+ }
+
+ the function would return [(a, 1), (b_c, 2), (c, 3)]
+
+ Args:
+ d (dict, list): a nested dict or list to be flattened
+ parent_key (str): recursion helper
+ sep (str): separator for nesting keys
+ item_key (str): recursion helper
+ Returns:
+ list: a list of (key, value) tuples
+ """
+ items = []
+ if isinstance(d, (tuple, list)):
+ new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
+ for i, v in enumerate(d):
+ items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
+ return items
+ elif isinstance(d, dict):
+ new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
+ for k, v in d.items():
+ assert isinstance(k, str)
+ items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
+ return items
+ else:
+ new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
+ return [(new_key, d)]
+
+
+def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
+ """
+ Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
+ batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
+ Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
+ outputs to [B, T, ...].
+
+ Args:
+ inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
+ of leading dimensions [B, T, ...]
+ op: a layer op that accepts inputs
+ activation: activation to apply at the output
+ inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
+ inputs_as_args (bool) whether to feed input as a args list to the op
+ kwargs (dict): other kwargs to supply to the op
+
+ Returns:
+ outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
+ """
+ batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
+ inputs = join_dimensions(inputs, 0, 1)
+ if inputs_as_kwargs:
+ outputs = op(**inputs, **kwargs)
+ elif inputs_as_args:
+ outputs = op(*inputs, **kwargs)
+ else:
+ outputs = op(inputs, **kwargs)
+
+ if activation is not None:
+ outputs = map_tensor(outputs, activation)
+ outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
+ return outputs
diff --git a/third_party/diffusion_policy/diffusion_policy/model/diffusion/__pycache__/conditional_unet1d.cpython-310.pyc b/third_party/diffusion_policy/diffusion_policy/model/diffusion/__pycache__/conditional_unet1d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b9cb640e14ee24e4bca666a7e4c38c2364c9dd8
Binary files /dev/null and b/third_party/diffusion_policy/diffusion_policy/model/diffusion/__pycache__/conditional_unet1d.cpython-310.pyc differ
diff --git a/third_party/diffusion_policy/diffusion_policy/model/diffusion/conv1d_components.py b/third_party/diffusion_policy/diffusion_policy/model/diffusion/conv1d_components.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c4cfc92550a753ab6bb0e9de958ba31f08f5146
--- /dev/null
+++ b/third_party/diffusion_policy/diffusion_policy/model/diffusion/conv1d_components.py
@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from einops.layers.torch import Rearrange
+
+
+class Downsample1d(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+class Upsample1d(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+class Conv1dBlock(nn.Module):
+ '''
+ Conv1d --> GroupNorm --> Mish
+ '''
+
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+ super().__init__()
+
+ self.block = nn.Sequential(
+ nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
+ # Rearrange('batch channels horizon -> batch channels 1 horizon'),
+ nn.GroupNorm(n_groups, out_channels),
+ # Rearrange('batch channels 1 horizon -> batch channels horizon'),
+ nn.Mish(),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+def test():
+ cb = Conv1dBlock(256, 128, kernel_size=3)
+ x = torch.zeros((1,256,16))
+ o = cb(x)