Add files using upload-large-folder tool
Browse files- README.md +11 -0
- docs/ENVIRONMENT_NOTES.md +42 -0
- models/pointflowmatch_take_shoes_out_of_box/1717447341-indigo-quokka/1717447341-indigo-quokka/config.yaml +80 -0
- reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/run.log +129 -0
- reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/summary.json +16 -0
- scripts/run_pointflowmatch_take_shoes_out_of_box.sh +25 -0
- third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/config.yaml +14 -0
- third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/hydra.yaml +159 -0
- third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/overrides.yaml +5 -0
- third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/config.yaml +14 -0
- third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/hydra.yaml +160 -0
- third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/overrides.yaml +6 -0
- third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/evaluate.log +3 -0
- third_party/diffusion_policy/.gitignore +140 -0
- third_party/diffusion_policy/LICENSE +21 -0
- third_party/diffusion_policy/README.md +437 -0
- third_party/diffusion_policy/conda_environment.yaml +65 -0
- third_party/diffusion_policy/conda_environment_macos.yaml +55 -0
- third_party/diffusion_policy/conda_environment_real.yaml +73 -0
- third_party/diffusion_policy/demo_pusht.py +120 -0
- third_party/diffusion_policy/demo_real_robot.py +160 -0
- third_party/diffusion_policy/diffusion_policy.egg-info/PKG-INFO +5 -0
- third_party/diffusion_policy/diffusion_policy.egg-info/SOURCES.txt +19 -0
- third_party/diffusion_policy/diffusion_policy.egg-info/dependency_links.txt +1 -0
- third_party/diffusion_policy/diffusion_policy.egg-info/top_level.txt +1 -0
- third_party/diffusion_policy/diffusion_policy/workspace/train_diffusion_unet_lowdim_workspace.py +306 -0
- third_party/diffusion_policy/diffusion_policy/workspace/train_ibc_dfo_hybrid_workspace.py +283 -0
- third_party/diffusion_policy/diffusion_policy/workspace/train_ibc_dfo_lowdim_workspace.py +282 -0
- third_party/diffusion_policy/diffusion_policy/workspace/train_robomimic_image_workspace.py +254 -0
- third_party/diffusion_policy/diffusion_policy/workspace/train_robomimic_lowdim_workspace.py +221 -0
- third_party/diffusion_policy/eval.py +64 -0
- third_party/diffusion_policy/eval_real_robot.py +418 -0
- third_party/diffusion_policy/multirun_metrics.py +267 -0
- third_party/diffusion_policy/pyrightconfig.json +7 -0
- third_party/diffusion_policy/ray_exec.py +121 -0
- third_party/diffusion_policy/ray_train_multirun.py +271 -0
- third_party/diffusion_policy/setup.py +6 -0
- third_party/diffusion_policy/tests/test_block_pushing.py +44 -0
- third_party/diffusion_policy/tests/test_cv2_util.py +22 -0
- third_party/diffusion_policy/tests/test_multi_realsense.py +82 -0
- third_party/diffusion_policy/tests/test_pose_trajectory_interpolator.py +126 -0
- third_party/diffusion_policy/tests/test_precise_sleep.py +56 -0
- third_party/diffusion_policy/tests/test_replay_buffer.py +62 -0
- third_party/diffusion_policy/tests/test_ring_buffer.py +188 -0
- third_party/diffusion_policy/tests/test_robomimic_image_runner.py +38 -0
- third_party/diffusion_policy/tests/test_robomimic_lowdim_runner.py +34 -0
- third_party/diffusion_policy/tests/test_shared_queue.py +67 -0
- third_party/diffusion_policy/tests/test_single_realsense.py +87 -0
- third_party/diffusion_policy/tests/test_timestamp_accumulator.py +151 -0
- third_party/diffusion_policy/train.py +35 -0
README.md
CHANGED
|
@@ -7,11 +7,14 @@ iterations, with emphasis on RLBench2 `take tray out of oven`.
|
|
| 7 |
|
| 8 |
- VLAarchTests benchmark code and generated public benchmark manifest
|
| 9 |
- Patched AnyBimanual RLBench runtime used to execute the public oven benchmark
|
|
|
|
| 10 |
- Official `katefgroup/3d_flowmatch_actor` PerAct2 checkpoint and public test data
|
|
|
|
| 11 |
- Public AnyBimanual LF baseline weights and comparison logs
|
| 12 |
- Verified benchmark reports:
|
| 13 |
- oven subset run: `9/10`
|
| 14 |
- oven full official run: `95/100 = 0.95`
|
|
|
|
| 15 |
- hybrid public benchmark smoke outputs
|
| 16 |
- DexGarmentLab benchmark-related validation scripts and validation logs
|
| 17 |
|
|
@@ -26,12 +29,20 @@ Official oven result artifacts:
|
|
| 26 |
- `reports/3dfa_peract2_take_tray_out_of_oven_subset10/eval_after_official_ttm.json`
|
| 27 |
- `reports/3dfa_peract2_take_tray_out_of_oven_full100/eval.json`
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
## Important Code Paths
|
| 30 |
|
| 31 |
- `code/VLAarchtests4/code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/public_benchmark_package.py`
|
| 32 |
- `code/VLAarchtests4/code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_rlbench_hybrid_smoke.py`
|
| 33 |
- `third_party/AnyBimanual/third_party/RLBench/rlbench/bimanual_tasks/bimanual_take_tray_out_of_oven.py`
|
| 34 |
- `third_party/AnyBimanual/third_party/RLBench/rlbench/task_ttms/bimanual_take_tray_out_of_oven.ttm`
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
## External Dependencies Not Mirrored Here
|
| 37 |
|
|
|
|
| 7 |
|
| 8 |
- VLAarchTests benchmark code and generated public benchmark manifest
|
| 9 |
- Patched AnyBimanual RLBench runtime used to execute the public oven benchmark
|
| 10 |
+
- Patched PointFlowMatch + `diffusion_policy` source snapshot used to execute `take shoes out of box`
|
| 11 |
- Official `katefgroup/3d_flowmatch_actor` PerAct2 checkpoint and public test data
|
| 12 |
+
- Official PointFlowMatch `1717447341-indigo-quokka` checkpoint for `take_shoes_out_of_box`
|
| 13 |
- Public AnyBimanual LF baseline weights and comparison logs
|
| 14 |
- Verified benchmark reports:
|
| 15 |
- oven subset run: `9/10`
|
| 16 |
- oven full official run: `95/100 = 0.95`
|
| 17 |
+
- shoes GPU search: non-zero success verified before later simulator crash
|
| 18 |
- hybrid public benchmark smoke outputs
|
| 19 |
- DexGarmentLab benchmark-related validation scripts and validation logs
|
| 20 |
|
|
|
|
| 29 |
- `reports/3dfa_peract2_take_tray_out_of_oven_subset10/eval_after_official_ttm.json`
|
| 30 |
- `reports/3dfa_peract2_take_tray_out_of_oven_full100/eval.json`
|
| 31 |
|
| 32 |
+
Shoes result artifacts:
|
| 33 |
+
|
| 34 |
+
- `reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/summary.json`
|
| 35 |
+
- `reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/run.log`
|
| 36 |
+
|
| 37 |
## Important Code Paths
|
| 38 |
|
| 39 |
- `code/VLAarchtests4/code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/public_benchmark_package.py`
|
| 40 |
- `code/VLAarchtests4/code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_rlbench_hybrid_smoke.py`
|
| 41 |
- `third_party/AnyBimanual/third_party/RLBench/rlbench/bimanual_tasks/bimanual_take_tray_out_of_oven.py`
|
| 42 |
- `third_party/AnyBimanual/third_party/RLBench/rlbench/task_ttms/bimanual_take_tray_out_of_oven.ttm`
|
| 43 |
+
- `third_party/PointFlowMatch/pfp/envs/rlbench_env.py`
|
| 44 |
+
- `third_party/PointFlowMatch/pfp/policy/fm_policy.py`
|
| 45 |
+
- `scripts/run_pointflowmatch_take_shoes_out_of_box.sh`
|
| 46 |
|
| 47 |
## External Dependencies Not Mirrored Here
|
| 48 |
|
docs/ENVIRONMENT_NOTES.md
CHANGED
|
@@ -45,3 +45,45 @@ xvfb-run -a -s "-screen 0 1400x900x24" python \
|
|
| 45 |
--denoise_timesteps 5 \
|
| 46 |
--denoise_model rectified_flow
|
| 47 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
--denoise_timesteps 5 \
|
| 46 |
--denoise_model rectified_flow
|
| 47 |
```
|
| 48 |
+
|
| 49 |
+
The `take shoes out of box` validation path was run from the bundled
|
| 50 |
+
PointFlowMatch source on a Blackwell GPU machine after upgrading the RLBench
|
| 51 |
+
environment to Torch `2.11.0+cu128` / torchvision `0.26.0+cu128` /
|
| 52 |
+
torchaudio `2.11.0+cu128`. The bundled PointFlowMatch tree also contains two
|
| 53 |
+
local compatibility fixes required for this workspace:
|
| 54 |
+
|
| 55 |
+
- `pfp/envs/rlbench_env.py`
|
| 56 |
+
- adapts PointFlowMatch to the local RLBench camera naming and observation API
|
| 57 |
+
- broadens motion-planning failure recovery to handle simulator-side runtime failures
|
| 58 |
+
- `pfp/policy/fm_policy.py`
|
| 59 |
+
- adds an inference-only fallback when legacy `composer` cannot import on modern Torch
|
| 60 |
+
- loads checkpoints with `weights_only=False` for PyTorch 2.6+
|
| 61 |
+
|
| 62 |
+
The shoes evaluation command used here was:
|
| 63 |
+
|
| 64 |
+
```bash
|
| 65 |
+
scripts/run_pointflowmatch_take_shoes_out_of_box.sh 10 50
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
That wrapper expands to the equivalent raw command:
|
| 69 |
+
|
| 70 |
+
```bash
|
| 71 |
+
export PYTHONPATH=third_party/diffusion_policy:third_party/PointFlowMatch:${PYTHONPATH:-}
|
| 72 |
+
export COPPELIASIM_ROOT=/path/to/CoppeliaSim
|
| 73 |
+
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:$COPPELIASIM_ROOT
|
| 74 |
+
export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
|
| 75 |
+
xvfb-run -a -s "-screen 0 1400x900x24" python \
|
| 76 |
+
third_party/PointFlowMatch/scripts/evaluate.py \
|
| 77 |
+
log_wandb=False \
|
| 78 |
+
env_runner.env_config.vis=False \
|
| 79 |
+
env_runner.num_episodes=10 \
|
| 80 |
+
env_runner.max_episode_length=200 \
|
| 81 |
+
policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka \
|
| 82 |
+
policy.num_k_infer=50
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Result note for shoes:
|
| 86 |
+
|
| 87 |
+
- `reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/summary.json`
|
| 88 |
+
records a verified non-zero result before a later RLBench/PyRep crash in the
|
| 89 |
+
same longer rollout.
|
models/pointflowmatch_take_shoes_out_of_box/1717447341-indigo-quokka/1717447341-indigo-quokka/config.yaml
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
_target_: pfp.policy.fm_policy.FMPolicy
|
| 3 |
+
x_dim: 266
|
| 4 |
+
y_dim: 10
|
| 5 |
+
n_obs_steps: 2
|
| 6 |
+
n_pred_steps: 32
|
| 7 |
+
num_k_infer: 10
|
| 8 |
+
time_conditioning: true
|
| 9 |
+
norm_pcd_center:
|
| 10 |
+
- 0.4
|
| 11 |
+
- 0.0
|
| 12 |
+
- 1.4
|
| 13 |
+
augment_data: false
|
| 14 |
+
noise_type: gaussian
|
| 15 |
+
noise_scale: 1.0
|
| 16 |
+
loss_type: l2
|
| 17 |
+
flow_schedule: exp
|
| 18 |
+
exp_scale: 4.0
|
| 19 |
+
snr_sampler: uniform
|
| 20 |
+
obs_encoder:
|
| 21 |
+
_target_: pfp.backbones.pointnet.PointNetBackbone
|
| 22 |
+
embed_dim: 256
|
| 23 |
+
input_channels: 3
|
| 24 |
+
input_transform: false
|
| 25 |
+
use_group_norm: false
|
| 26 |
+
diffusion_net:
|
| 27 |
+
_target_: diffusion_policy.model.diffusion.conditional_unet1d.ConditionalUnet1D
|
| 28 |
+
input_dim: 10
|
| 29 |
+
global_cond_dim: 532
|
| 30 |
+
diffusion_step_embed_dim: 256
|
| 31 |
+
down_dims:
|
| 32 |
+
- 256
|
| 33 |
+
- 512
|
| 34 |
+
- 1024
|
| 35 |
+
kernel_size: 5
|
| 36 |
+
n_groups: 8
|
| 37 |
+
cond_predict_scale: true
|
| 38 |
+
loss_weights:
|
| 39 |
+
xyz: 10.0
|
| 40 |
+
rot6d: 10.0
|
| 41 |
+
grip: 1.0
|
| 42 |
+
backbone:
|
| 43 |
+
_target_: pfp.backbones.pointnet.PointNetBackbone
|
| 44 |
+
embed_dim: 256
|
| 45 |
+
input_channels: 3
|
| 46 |
+
input_transform: false
|
| 47 |
+
use_group_norm: false
|
| 48 |
+
seed: 1234
|
| 49 |
+
epochs: 1500
|
| 50 |
+
log_wandb: true
|
| 51 |
+
task_name: take_shoes_out_of_box
|
| 52 |
+
obs_features_dim: 256
|
| 53 |
+
y_dim: 10
|
| 54 |
+
x_dim: 266
|
| 55 |
+
n_obs_steps: 2
|
| 56 |
+
n_pred_steps: 32
|
| 57 |
+
use_ema: true
|
| 58 |
+
save_each_n_epochs: 500
|
| 59 |
+
obs_mode: pcd
|
| 60 |
+
run_name: null
|
| 61 |
+
dataset:
|
| 62 |
+
n_obs_steps: 2
|
| 63 |
+
n_pred_steps: 32
|
| 64 |
+
subs_factor: 3
|
| 65 |
+
use_pc_color: false
|
| 66 |
+
n_points: 4096
|
| 67 |
+
dataloader:
|
| 68 |
+
batch_size: 128
|
| 69 |
+
num_workers: 8
|
| 70 |
+
optimizer:
|
| 71 |
+
_target_: torch.optim.AdamW
|
| 72 |
+
lr: 3.0e-05
|
| 73 |
+
betas:
|
| 74 |
+
- 0.95
|
| 75 |
+
- 0.999
|
| 76 |
+
eps: 1.0e-08
|
| 77 |
+
weight_decay: 1.0e-06
|
| 78 |
+
lr_scheduler:
|
| 79 |
+
name: cosine
|
| 80 |
+
num_warmup_steps: 5000
|
reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/run.log
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 0 |
0%| | 0/10 [00:00<?, ?it/s][2026-04-03 00:44:29,033][root][WARNING] - single robot
|
|
|
|
| 1 |
10%|█ | 1/10 [02:31<22:41, 151.33s/it]'NoneType' object has no attribute 'step'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
10%|█ | 1/10 [03:32<31:48, 212.03s/it]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/workspace/envs/rlbench/lib/python3.10/site-packages/requests/__init__.py:109: RequestsDependencyWarning: urllib3 (1.26.13) or chardet (7.4.0.post2)/charset_normalizer (2.1.1) doesn't match a supported version!
|
| 2 |
+
warnings.warn(
|
| 3 |
+
pytorch3d not installed
|
| 4 |
+
WARNING: Rerun not installed. Visualization will not work.
|
| 5 |
+
WARNING: Rerun not installed. Visualization will not work.
|
| 6 |
+
seed: 5678
|
| 7 |
+
log_wandb: false
|
| 8 |
+
env_runner:
|
| 9 |
+
num_episodes: 10
|
| 10 |
+
max_episode_length: 200
|
| 11 |
+
verbose: true
|
| 12 |
+
env_config:
|
| 13 |
+
voxel_size: 0.01
|
| 14 |
+
headless: true
|
| 15 |
+
vis: false
|
| 16 |
+
policy:
|
| 17 |
+
ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
|
| 18 |
+
ckpt_episode: ep1500
|
| 19 |
+
num_k_infer: 50
|
| 20 |
+
|
| 21 |
+
seed: 5678
|
| 22 |
+
log_wandb: false
|
| 23 |
+
env_runner:
|
| 24 |
+
num_episodes: 10
|
| 25 |
+
max_episode_length: 200
|
| 26 |
+
verbose: true
|
| 27 |
+
env_config:
|
| 28 |
+
voxel_size: 0.01
|
| 29 |
+
headless: true
|
| 30 |
+
vis: false
|
| 31 |
+
task_name: take_shoes_out_of_box
|
| 32 |
+
obs_mode: pcd
|
| 33 |
+
use_pc_color: false
|
| 34 |
+
n_points: 4096
|
| 35 |
+
policy:
|
| 36 |
+
ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
|
| 37 |
+
ckpt_episode: ep1500
|
| 38 |
+
num_k_infer: 50
|
| 39 |
+
_target_: pfp.policy.fm_policy.FMPolicy.load_from_checkpoint
|
| 40 |
+
model:
|
| 41 |
+
_target_: pfp.policy.fm_policy.FMPolicy
|
| 42 |
+
x_dim: 266
|
| 43 |
+
y_dim: 10
|
| 44 |
+
n_obs_steps: 2
|
| 45 |
+
n_pred_steps: 32
|
| 46 |
+
num_k_infer: 10
|
| 47 |
+
time_conditioning: true
|
| 48 |
+
norm_pcd_center:
|
| 49 |
+
- 0.4
|
| 50 |
+
- 0.0
|
| 51 |
+
- 1.4
|
| 52 |
+
augment_data: false
|
| 53 |
+
noise_type: gaussian
|
| 54 |
+
noise_scale: 1.0
|
| 55 |
+
loss_type: l2
|
| 56 |
+
flow_schedule: exp
|
| 57 |
+
exp_scale: 4.0
|
| 58 |
+
snr_sampler: uniform
|
| 59 |
+
obs_encoder:
|
| 60 |
+
_target_: pfp.backbones.pointnet.PointNetBackbone
|
| 61 |
+
embed_dim: 256
|
| 62 |
+
input_channels: 3
|
| 63 |
+
input_transform: false
|
| 64 |
+
use_group_norm: false
|
| 65 |
+
diffusion_net:
|
| 66 |
+
_target_: diffusion_policy.model.diffusion.conditional_unet1d.ConditionalUnet1D
|
| 67 |
+
input_dim: 10
|
| 68 |
+
global_cond_dim: 532
|
| 69 |
+
diffusion_step_embed_dim: 256
|
| 70 |
+
down_dims:
|
| 71 |
+
- 256
|
| 72 |
+
- 512
|
| 73 |
+
- 1024
|
| 74 |
+
kernel_size: 5
|
| 75 |
+
n_groups: 8
|
| 76 |
+
cond_predict_scale: true
|
| 77 |
+
loss_weights:
|
| 78 |
+
xyz: 10.0
|
| 79 |
+
rot6d: 10.0
|
| 80 |
+
grip: 1.0
|
| 81 |
+
|
| 82 |
+
output_dim 10
|
| 83 |
+
[2026-04-03 00:44:27,184][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
|
| 84 |
+
QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
|
| 85 |
+
|
| 86 |
0%| | 0/10 [00:00<?, ?it/s][2026-04-03 00:44:29,033][root][WARNING] - single robot
|
| 87 |
+
|
| 88 |
10%|█ | 1/10 [02:31<22:41, 151.33s/it]'NoneType' object has no attribute 'step'
|
| 89 |
+
'NoneType' object has no attribute 'step'
|
| 90 |
+
'NoneType' object has no attribute 'step'
|
| 91 |
+
'NoneType' object has no attribute 'step'
|
| 92 |
+
'NoneType' object has no attribute 'step'
|
| 93 |
+
'NoneType' object has no attribute 'step'
|
| 94 |
+
Steps: 182
|
| 95 |
+
Success: True
|
| 96 |
+
[2026-04-03 00:47:00,210][root][WARNING] - single robot
|
| 97 |
+
|
| 98 |
10%|█ | 1/10 [03:32<31:48, 212.03s/it]
|
| 99 |
+
'NoneType' object has no attribute 'step'
|
| 100 |
+
'NoneType' object has no attribute 'step'
|
| 101 |
+
Error executing job with overrides: ['log_wandb=False', 'env_runner.env_config.vis=False', 'env_runner.num_episodes=10', 'env_runner.max_episode_length=200', 'policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka', 'policy.num_k_infer=50']
|
| 102 |
+
Traceback (most recent call last):
|
| 103 |
+
File "/workspace/third_party/PointFlowMatch/scripts/evaluate.py", line 52, in main
|
| 104 |
+
_ = env_runner.run(policy)
|
| 105 |
+
File "/workspace/third_party/PointFlowMatch/pfp/envs/rlbench_runner.py", line 35, in run
|
| 106 |
+
reward, terminate = self.env.step(next_robot_state)
|
| 107 |
+
File "/workspace/third_party/PointFlowMatch/pfp/envs/rlbench_env.py", line 105, in step
|
| 108 |
+
reward, terminate = self._step_safe(action)
|
| 109 |
+
File "/workspace/third_party/PointFlowMatch/pfp/envs/rlbench_env.py", line 113, in _step_safe
|
| 110 |
+
_, reward, terminate = self.task.step(action)
|
| 111 |
+
File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/task_environment.py", line 109, in step
|
| 112 |
+
self._action_mode.action(self._scene, action)
|
| 113 |
+
File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/action_modes/action_mode.py", line 47, in action
|
| 114 |
+
self.arm_action_mode.action(scene, arm_action, ignore_collisions)
|
| 115 |
+
File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/action_modes/arm_action_modes.py", line 304, in action
|
| 116 |
+
success, terminate = scene.task.success()
|
| 117 |
+
File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/backend/task.py", line 303, in success
|
| 118 |
+
[cond.condition_met()[0] for cond in self._success_conditions])
|
| 119 |
+
File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/backend/task.py", line 303, in <listcomp>
|
| 120 |
+
[cond.condition_met()[0] for cond in self._success_conditions])
|
| 121 |
+
File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/backend/conditions.py", line 51, in condition_met
|
| 122 |
+
met = self._detector.is_detected(self._obj)
|
| 123 |
+
File "/workspace/third_party/PyRep/pyrep/objects/proximity_sensor.py", line 36, in is_detected
|
| 124 |
+
state, point = sim.simCheckProximitySensor(
|
| 125 |
+
File "/workspace/third_party/PyRep/pyrep/backend/sim.py", line 345, in simCheckProximitySensor
|
| 126 |
+
_check_return(state)
|
| 127 |
+
File "/workspace/third_party/PyRep/pyrep/backend/sim.py", line 27, in _check_return
|
| 128 |
+
raise RuntimeError(
|
| 129 |
+
RuntimeError: The call failed on the V-REP side. Return value: -1
|
| 130 |
+
|
| 131 |
+
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
|
| 132 |
+
QMutex: destroying locked mutex
|
reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/summary.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"task_name": "take_shoes_out_of_box",
|
| 3 |
+
"model_family": "PointFlowMatch",
|
| 4 |
+
"checkpoint_name": "1717447341-indigo-quokka",
|
| 5 |
+
"checkpoint_source": "http://pointflowmatch.cs.uni-freiburg.de/download/1717447341-indigo-quokka.zip",
|
| 6 |
+
"num_k_infer": 50,
|
| 7 |
+
"requested_episodes": 10,
|
| 8 |
+
"completed_episodes_before_crash": 1,
|
| 9 |
+
"success_count_before_crash": 1,
|
| 10 |
+
"first_success_episode_index": 0,
|
| 11 |
+
"first_success_steps": 182,
|
| 12 |
+
"verified_non_zero_performance": true,
|
| 13 |
+
"run_status": "partial_success_then_crash",
|
| 14 |
+
"crash_reason": "RLBench/PyRep RuntimeError after first successful episode during continued rollout",
|
| 15 |
+
"report_log": "run.log"
|
| 16 |
+
}
|
scripts/run_pointflowmatch_take_shoes_out_of_box.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
EPISODES="${1:-10}"
|
| 5 |
+
NUM_K_INFER="${2:-50}"
|
| 6 |
+
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-python}"
|
| 8 |
+
COPPELIASIM_ROOT="${COPPELIASIM_ROOT:?set COPPELIASIM_ROOT to your CoppeliaSim root}"
|
| 9 |
+
REPORT_DIR="${REPORT_DIR:-$ROOT/reports/pointflowmatch_take_shoes_out_of_box_ep${EPISODES}_k${NUM_K_INFER}_gpu}"
|
| 10 |
+
|
| 11 |
+
export PYTHONPATH="$ROOT/third_party/diffusion_policy:$ROOT/third_party/PointFlowMatch:${PYTHONPATH:-}"
|
| 12 |
+
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-}:$COPPELIASIM_ROOT"
|
| 13 |
+
export QT_QPA_PLATFORM_PLUGIN_PATH="$COPPELIASIM_ROOT"
|
| 14 |
+
|
| 15 |
+
mkdir -p "$REPORT_DIR"
|
| 16 |
+
cd "$ROOT/third_party/PointFlowMatch"
|
| 17 |
+
|
| 18 |
+
xvfb-run -a -s "-screen 0 1400x900x24" "$PYTHON_BIN" scripts/evaluate.py \
|
| 19 |
+
log_wandb=False \
|
| 20 |
+
env_runner.env_config.vis=False \
|
| 21 |
+
env_runner.num_episodes="$EPISODES" \
|
| 22 |
+
env_runner.max_episode_length=200 \
|
| 23 |
+
policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka \
|
| 24 |
+
policy.num_k_infer="$NUM_K_INFER" \
|
| 25 |
+
2>&1 | tee "$REPORT_DIR/run.log"
|
third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/config.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed: 5678
|
| 2 |
+
log_wandb: false
|
| 3 |
+
env_runner:
|
| 4 |
+
num_episodes: 1
|
| 5 |
+
max_episode_length: 200
|
| 6 |
+
verbose: true
|
| 7 |
+
env_config:
|
| 8 |
+
voxel_size: 0.01
|
| 9 |
+
headless: true
|
| 10 |
+
vis: false
|
| 11 |
+
policy:
|
| 12 |
+
ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
|
| 13 |
+
ckpt_episode: ep1500
|
| 14 |
+
num_k_infer: 50
|
third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/hydra.yaml
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
run:
|
| 3 |
+
dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 4 |
+
sweep:
|
| 5 |
+
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 6 |
+
subdir: ${hydra.job.num}
|
| 7 |
+
launcher:
|
| 8 |
+
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
|
| 9 |
+
sweeper:
|
| 10 |
+
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
|
| 11 |
+
max_batch_size: null
|
| 12 |
+
params: null
|
| 13 |
+
help:
|
| 14 |
+
app_name: ${hydra.job.name}
|
| 15 |
+
header: '${hydra.help.app_name} is powered by Hydra.
|
| 16 |
+
|
| 17 |
+
'
|
| 18 |
+
footer: 'Powered by Hydra (https://hydra.cc)
|
| 19 |
+
|
| 20 |
+
Use --hydra-help to view Hydra specific help
|
| 21 |
+
|
| 22 |
+
'
|
| 23 |
+
template: '${hydra.help.header}
|
| 24 |
+
|
| 25 |
+
== Configuration groups ==
|
| 26 |
+
|
| 27 |
+
Compose your configuration from those groups (group=option)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
$APP_CONFIG_GROUPS
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
== Config ==
|
| 34 |
+
|
| 35 |
+
Override anything in the config (foo.bar=value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
$CONFIG
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
${hydra.help.footer}
|
| 42 |
+
|
| 43 |
+
'
|
| 44 |
+
hydra_help:
|
| 45 |
+
template: 'Hydra (${hydra.runtime.version})
|
| 46 |
+
|
| 47 |
+
See https://hydra.cc for more info.
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
== Flags ==
|
| 51 |
+
|
| 52 |
+
$FLAGS_HELP
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
== Configuration groups ==
|
| 56 |
+
|
| 57 |
+
Compose your configuration from those groups (For example, append hydra/job_logging=disabled
|
| 58 |
+
to command line)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
$HYDRA_CONFIG_GROUPS
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
Use ''--cfg hydra'' to Show the Hydra config.
|
| 65 |
+
|
| 66 |
+
'
|
| 67 |
+
hydra_help: ???
|
| 68 |
+
hydra_logging:
|
| 69 |
+
version: 1
|
| 70 |
+
formatters:
|
| 71 |
+
simple:
|
| 72 |
+
format: '[%(asctime)s][HYDRA] %(message)s'
|
| 73 |
+
handlers:
|
| 74 |
+
console:
|
| 75 |
+
class: logging.StreamHandler
|
| 76 |
+
formatter: simple
|
| 77 |
+
stream: ext://sys.stdout
|
| 78 |
+
root:
|
| 79 |
+
level: INFO
|
| 80 |
+
handlers:
|
| 81 |
+
- console
|
| 82 |
+
loggers:
|
| 83 |
+
logging_example:
|
| 84 |
+
level: DEBUG
|
| 85 |
+
disable_existing_loggers: false
|
| 86 |
+
job_logging:
|
| 87 |
+
version: 1
|
| 88 |
+
formatters:
|
| 89 |
+
simple:
|
| 90 |
+
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
|
| 91 |
+
handlers:
|
| 92 |
+
console:
|
| 93 |
+
class: logging.StreamHandler
|
| 94 |
+
formatter: simple
|
| 95 |
+
stream: ext://sys.stdout
|
| 96 |
+
file:
|
| 97 |
+
class: logging.FileHandler
|
| 98 |
+
formatter: simple
|
| 99 |
+
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
|
| 100 |
+
root:
|
| 101 |
+
level: INFO
|
| 102 |
+
handlers:
|
| 103 |
+
- console
|
| 104 |
+
- file
|
| 105 |
+
disable_existing_loggers: false
|
| 106 |
+
env: {}
|
| 107 |
+
mode: RUN
|
| 108 |
+
searchpath: []
|
| 109 |
+
callbacks: {}
|
| 110 |
+
output_subdir: .hydra
|
| 111 |
+
overrides:
|
| 112 |
+
hydra:
|
| 113 |
+
- hydra.mode=RUN
|
| 114 |
+
task:
|
| 115 |
+
- log_wandb=False
|
| 116 |
+
- env_runner.env_config.vis=False
|
| 117 |
+
- env_runner.num_episodes=1
|
| 118 |
+
- env_runner.max_episode_length=200
|
| 119 |
+
- policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
|
| 120 |
+
job:
|
| 121 |
+
name: evaluate
|
| 122 |
+
chdir: null
|
| 123 |
+
override_dirname: env_runner.env_config.vis=False,env_runner.max_episode_length=200,env_runner.num_episodes=1,log_wandb=False,policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
|
| 124 |
+
id: ???
|
| 125 |
+
num: ???
|
| 126 |
+
config_name: eval
|
| 127 |
+
env_set: {}
|
| 128 |
+
env_copy: []
|
| 129 |
+
config:
|
| 130 |
+
override_dirname:
|
| 131 |
+
kv_sep: '='
|
| 132 |
+
item_sep: ','
|
| 133 |
+
exclude_keys: []
|
| 134 |
+
runtime:
|
| 135 |
+
version: 1.3.2
|
| 136 |
+
version_base: '1.3'
|
| 137 |
+
cwd: /workspace/third_party/PointFlowMatch
|
| 138 |
+
config_sources:
|
| 139 |
+
- path: hydra.conf
|
| 140 |
+
schema: pkg
|
| 141 |
+
provider: hydra
|
| 142 |
+
- path: /workspace/third_party/PointFlowMatch/conf
|
| 143 |
+
schema: file
|
| 144 |
+
provider: main
|
| 145 |
+
- path: ''
|
| 146 |
+
schema: structured
|
| 147 |
+
provider: schema
|
| 148 |
+
output_dir: /workspace/third_party/PointFlowMatch/outputs/2026-04-03/00-25-21
|
| 149 |
+
choices:
|
| 150 |
+
hydra/env: default
|
| 151 |
+
hydra/callbacks: null
|
| 152 |
+
hydra/job_logging: default
|
| 153 |
+
hydra/hydra_logging: default
|
| 154 |
+
hydra/hydra_help: default
|
| 155 |
+
hydra/help: default
|
| 156 |
+
hydra/sweeper: basic
|
| 157 |
+
hydra/launcher: basic
|
| 158 |
+
hydra/output: default
|
| 159 |
+
verbose: false
|
third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/overrides.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- log_wandb=False
|
| 2 |
+
- env_runner.env_config.vis=False
|
| 3 |
+
- env_runner.num_episodes=1
|
| 4 |
+
- env_runner.max_episode_length=200
|
| 5 |
+
- policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
|
third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/config.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed: 5678
|
| 2 |
+
log_wandb: false
|
| 3 |
+
env_runner:
|
| 4 |
+
num_episodes: 10
|
| 5 |
+
max_episode_length: 200
|
| 6 |
+
verbose: true
|
| 7 |
+
env_config:
|
| 8 |
+
voxel_size: 0.01
|
| 9 |
+
headless: true
|
| 10 |
+
vis: false
|
| 11 |
+
policy:
|
| 12 |
+
ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
|
| 13 |
+
ckpt_episode: ep1500
|
| 14 |
+
num_k_infer: 50
|
third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/hydra.yaml
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
run:
|
| 3 |
+
dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 4 |
+
sweep:
|
| 5 |
+
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 6 |
+
subdir: ${hydra.job.num}
|
| 7 |
+
launcher:
|
| 8 |
+
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
|
| 9 |
+
sweeper:
|
| 10 |
+
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
|
| 11 |
+
max_batch_size: null
|
| 12 |
+
params: null
|
| 13 |
+
help:
|
| 14 |
+
app_name: ${hydra.job.name}
|
| 15 |
+
header: '${hydra.help.app_name} is powered by Hydra.
|
| 16 |
+
|
| 17 |
+
'
|
| 18 |
+
footer: 'Powered by Hydra (https://hydra.cc)
|
| 19 |
+
|
| 20 |
+
Use --hydra-help to view Hydra specific help
|
| 21 |
+
|
| 22 |
+
'
|
| 23 |
+
template: '${hydra.help.header}
|
| 24 |
+
|
| 25 |
+
== Configuration groups ==
|
| 26 |
+
|
| 27 |
+
Compose your configuration from those groups (group=option)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
$APP_CONFIG_GROUPS
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
== Config ==
|
| 34 |
+
|
| 35 |
+
Override anything in the config (foo.bar=value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
$CONFIG
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
${hydra.help.footer}
|
| 42 |
+
|
| 43 |
+
'
|
| 44 |
+
hydra_help:
|
| 45 |
+
template: 'Hydra (${hydra.runtime.version})
|
| 46 |
+
|
| 47 |
+
See https://hydra.cc for more info.
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
== Flags ==
|
| 51 |
+
|
| 52 |
+
$FLAGS_HELP
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
== Configuration groups ==
|
| 56 |
+
|
| 57 |
+
Compose your configuration from those groups (For example, append hydra/job_logging=disabled
|
| 58 |
+
to command line)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
$HYDRA_CONFIG_GROUPS
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
Use ''--cfg hydra'' to Show the Hydra config.
|
| 65 |
+
|
| 66 |
+
'
|
| 67 |
+
hydra_help: ???
|
| 68 |
+
hydra_logging:
|
| 69 |
+
version: 1
|
| 70 |
+
formatters:
|
| 71 |
+
simple:
|
| 72 |
+
format: '[%(asctime)s][HYDRA] %(message)s'
|
| 73 |
+
handlers:
|
| 74 |
+
console:
|
| 75 |
+
class: logging.StreamHandler
|
| 76 |
+
formatter: simple
|
| 77 |
+
stream: ext://sys.stdout
|
| 78 |
+
root:
|
| 79 |
+
level: INFO
|
| 80 |
+
handlers:
|
| 81 |
+
- console
|
| 82 |
+
loggers:
|
| 83 |
+
logging_example:
|
| 84 |
+
level: DEBUG
|
| 85 |
+
disable_existing_loggers: false
|
| 86 |
+
job_logging:
|
| 87 |
+
version: 1
|
| 88 |
+
formatters:
|
| 89 |
+
simple:
|
| 90 |
+
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
|
| 91 |
+
handlers:
|
| 92 |
+
console:
|
| 93 |
+
class: logging.StreamHandler
|
| 94 |
+
formatter: simple
|
| 95 |
+
stream: ext://sys.stdout
|
| 96 |
+
file:
|
| 97 |
+
class: logging.FileHandler
|
| 98 |
+
formatter: simple
|
| 99 |
+
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
|
| 100 |
+
root:
|
| 101 |
+
level: INFO
|
| 102 |
+
handlers:
|
| 103 |
+
- console
|
| 104 |
+
- file
|
| 105 |
+
disable_existing_loggers: false
|
| 106 |
+
env: {}
|
| 107 |
+
mode: RUN
|
| 108 |
+
searchpath: []
|
| 109 |
+
callbacks: {}
|
| 110 |
+
output_subdir: .hydra
|
| 111 |
+
overrides:
|
| 112 |
+
hydra:
|
| 113 |
+
- hydra.mode=RUN
|
| 114 |
+
task:
|
| 115 |
+
- log_wandb=False
|
| 116 |
+
- env_runner.env_config.vis=False
|
| 117 |
+
- env_runner.num_episodes=10
|
| 118 |
+
- env_runner.max_episode_length=200
|
| 119 |
+
- policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
|
| 120 |
+
- policy.num_k_infer=50
|
| 121 |
+
job:
|
| 122 |
+
name: evaluate
|
| 123 |
+
chdir: null
|
| 124 |
+
override_dirname: env_runner.env_config.vis=False,env_runner.max_episode_length=200,env_runner.num_episodes=10,log_wandb=False,policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka,policy.num_k_infer=50
|
| 125 |
+
id: ???
|
| 126 |
+
num: ???
|
| 127 |
+
config_name: eval
|
| 128 |
+
env_set: {}
|
| 129 |
+
env_copy: []
|
| 130 |
+
config:
|
| 131 |
+
override_dirname:
|
| 132 |
+
kv_sep: '='
|
| 133 |
+
item_sep: ','
|
| 134 |
+
exclude_keys: []
|
| 135 |
+
runtime:
|
| 136 |
+
version: 1.3.2
|
| 137 |
+
version_base: '1.3'
|
| 138 |
+
cwd: /workspace/third_party/PointFlowMatch
|
| 139 |
+
config_sources:
|
| 140 |
+
- path: hydra.conf
|
| 141 |
+
schema: pkg
|
| 142 |
+
provider: hydra
|
| 143 |
+
- path: /workspace/third_party/PointFlowMatch/conf
|
| 144 |
+
schema: file
|
| 145 |
+
provider: main
|
| 146 |
+
- path: ''
|
| 147 |
+
schema: structured
|
| 148 |
+
provider: schema
|
| 149 |
+
output_dir: /workspace/third_party/PointFlowMatch/outputs/2026-04-03/00-44-21
|
| 150 |
+
choices:
|
| 151 |
+
hydra/env: default
|
| 152 |
+
hydra/callbacks: null
|
| 153 |
+
hydra/job_logging: default
|
| 154 |
+
hydra/hydra_logging: default
|
| 155 |
+
hydra/hydra_help: default
|
| 156 |
+
hydra/help: default
|
| 157 |
+
hydra/sweeper: basic
|
| 158 |
+
hydra/launcher: basic
|
| 159 |
+
hydra/output: default
|
| 160 |
+
verbose: false
|
third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/overrides.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- log_wandb=False
|
| 2 |
+
- env_runner.env_config.vis=False
|
| 3 |
+
- env_runner.num_episodes=10
|
| 4 |
+
- env_runner.max_episode_length=200
|
| 5 |
+
- policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
|
| 6 |
+
- policy.num_k_infer=50
|
third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/evaluate.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[2026-04-03 00:44:27,184][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
|
| 2 |
+
[2026-04-03 00:44:29,033][root][WARNING] - single robot
|
| 3 |
+
[2026-04-03 00:47:00,210][root][WARNING] - single robot
|
third_party/diffusion_policy/.gitignore
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
bin
|
| 2 |
+
logs
|
| 3 |
+
wandb
|
| 4 |
+
outputs
|
| 5 |
+
data
|
| 6 |
+
data_local
|
| 7 |
+
.vscode
|
| 8 |
+
_wandb
|
| 9 |
+
|
| 10 |
+
**/.DS_Store
|
| 11 |
+
|
| 12 |
+
fuse.cfg
|
| 13 |
+
|
| 14 |
+
*.ai
|
| 15 |
+
|
| 16 |
+
# Generation results
|
| 17 |
+
results/
|
| 18 |
+
|
| 19 |
+
ray/auth.json
|
| 20 |
+
|
| 21 |
+
# Byte-compiled / optimized / DLL files
|
| 22 |
+
__pycache__/
|
| 23 |
+
*.py[cod]
|
| 24 |
+
*$py.class
|
| 25 |
+
|
| 26 |
+
# C extensions
|
| 27 |
+
*.so
|
| 28 |
+
|
| 29 |
+
# Distribution / packaging
|
| 30 |
+
.Python
|
| 31 |
+
build/
|
| 32 |
+
develop-eggs/
|
| 33 |
+
dist/
|
| 34 |
+
downloads/
|
| 35 |
+
eggs/
|
| 36 |
+
.eggs/
|
| 37 |
+
lib/
|
| 38 |
+
lib64/
|
| 39 |
+
parts/
|
| 40 |
+
sdist/
|
| 41 |
+
var/
|
| 42 |
+
wheels/
|
| 43 |
+
pip-wheel-metadata/
|
| 44 |
+
share/python-wheels/
|
| 45 |
+
*.egg-info/
|
| 46 |
+
.installed.cfg
|
| 47 |
+
*.egg
|
| 48 |
+
MANIFEST
|
| 49 |
+
|
| 50 |
+
# PyInstaller
|
| 51 |
+
# Usually these files are written by a python script from a template
|
| 52 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 53 |
+
*.manifest
|
| 54 |
+
*.spec
|
| 55 |
+
|
| 56 |
+
# Installer logs
|
| 57 |
+
pip-log.txt
|
| 58 |
+
pip-delete-this-directory.txt
|
| 59 |
+
|
| 60 |
+
# Unit test / coverage reports
|
| 61 |
+
htmlcov/
|
| 62 |
+
.tox/
|
| 63 |
+
.nox/
|
| 64 |
+
.coverage
|
| 65 |
+
.coverage.*
|
| 66 |
+
.cache
|
| 67 |
+
nosetests.xml
|
| 68 |
+
coverage.xml
|
| 69 |
+
*.cover
|
| 70 |
+
*.py,cover
|
| 71 |
+
.hypothesis/
|
| 72 |
+
.pytest_cache/
|
| 73 |
+
|
| 74 |
+
# Translations
|
| 75 |
+
*.mo
|
| 76 |
+
*.pot
|
| 77 |
+
|
| 78 |
+
# Django stuff:
|
| 79 |
+
*.log
|
| 80 |
+
local_settings.py
|
| 81 |
+
db.sqlite3
|
| 82 |
+
db.sqlite3-journal
|
| 83 |
+
|
| 84 |
+
# Flask stuff:
|
| 85 |
+
instance/
|
| 86 |
+
.webassets-cache
|
| 87 |
+
|
| 88 |
+
# Scrapy stuff:
|
| 89 |
+
.scrapy
|
| 90 |
+
|
| 91 |
+
# Sphinx documentation
|
| 92 |
+
docs/_build/
|
| 93 |
+
|
| 94 |
+
# PyBuilder
|
| 95 |
+
target/
|
| 96 |
+
|
| 97 |
+
# Jupyter Notebook
|
| 98 |
+
.ipynb_checkpoints
|
| 99 |
+
|
| 100 |
+
# IPython
|
| 101 |
+
profile_default/
|
| 102 |
+
ipython_config.py
|
| 103 |
+
|
| 104 |
+
# pyenv
|
| 105 |
+
.python-version
|
| 106 |
+
|
| 107 |
+
# pipenv
|
| 108 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 109 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 110 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 111 |
+
# install all needed dependencies.
|
| 112 |
+
#Pipfile.lock
|
| 113 |
+
|
| 114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 115 |
+
__pypackages__/
|
| 116 |
+
|
| 117 |
+
# Celery stuff
|
| 118 |
+
celerybeat-schedule
|
| 119 |
+
celerybeat.pid
|
| 120 |
+
|
| 121 |
+
# SageMath parsed files
|
| 122 |
+
*.sage.py
|
| 123 |
+
|
| 124 |
+
# Spyder project settings
|
| 125 |
+
.spyderproject
|
| 126 |
+
.spyproject
|
| 127 |
+
|
| 128 |
+
# Rope project settings
|
| 129 |
+
.ropeproject
|
| 130 |
+
|
| 131 |
+
# mkdocs documentation
|
| 132 |
+
/site
|
| 133 |
+
|
| 134 |
+
# mypy
|
| 135 |
+
.mypy_cache/
|
| 136 |
+
.dmypy.json
|
| 137 |
+
dmypy.json
|
| 138 |
+
|
| 139 |
+
# Pyre type checker
|
| 140 |
+
.pyre/
|
third_party/diffusion_policy/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Columbia Artificial Intelligence and Robotics Lab
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
third_party/diffusion_policy/README.md
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Diffusion Policy
|
| 2 |
+
|
| 3 |
+
[[Project page]](https://diffusion-policy.cs.columbia.edu/)
|
| 4 |
+
[[Paper]](https://diffusion-policy.cs.columbia.edu/#paper)
|
| 5 |
+
[[Data]](https://diffusion-policy.cs.columbia.edu/data/)
|
| 6 |
+
[[Colab (state)]](https://colab.research.google.com/drive/1gxdkgRVfM55zihY9TFLja97cSVZOZq2B?usp=sharing)
|
| 7 |
+
[[Colab (vision)]](https://colab.research.google.com/drive/18GIHeOQ5DyjMN8iIRZL2EKZ0745NLIpg?usp=sharing)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
[Cheng Chi](http://cheng-chi.github.io/)<sup>1</sup>,
|
| 11 |
+
[Siyuan Feng](https://www.cs.cmu.edu/~sfeng/)<sup>2</sup>,
|
| 12 |
+
[Yilun Du](https://yilundu.github.io/)<sup>3</sup>,
|
| 13 |
+
[Zhenjia Xu](https://www.zhenjiaxu.com/)<sup>1</sup>,
|
| 14 |
+
[Eric Cousineau](https://www.eacousineau.com/)<sup>2</sup>,
|
| 15 |
+
[Benjamin Burchfiel](http://www.benburchfiel.com/)<sup>2</sup>,
|
| 16 |
+
[Shuran Song](https://www.cs.columbia.edu/~shurans/)<sup>1</sup>
|
| 17 |
+
|
| 18 |
+
<sup>1</sup>Columbia University,
|
| 19 |
+
<sup>2</sup>Toyota Research Institute,
|
| 20 |
+
<sup>3</sup>MIT
|
| 21 |
+
|
| 22 |
+
<img src="media/teaser.png" alt="drawing" width="100%"/>
|
| 23 |
+
<img src="media/multimodal_sim.png" alt="drawing" width="100%"/>
|
| 24 |
+
|
| 25 |
+
## 🛝 Try it out!
|
| 26 |
+
Our self-contained Google Colab notebooks is the easiest way to play with Diffusion Policy. We provide separate notebooks for [state-based environment](https://colab.research.google.com/drive/1gxdkgRVfM55zihY9TFLja97cSVZOZq2B?usp=sharing) and [vision-based environment](https://colab.research.google.com/drive/18GIHeOQ5DyjMN8iIRZL2EKZ0745NLIpg?usp=sharing).
|
| 27 |
+
|
| 28 |
+
## 🧾 Checkout our experiment logs!
|
| 29 |
+
For each experiment used to generate Table I,II and IV in the [paper](https://diffusion-policy.cs.columbia.edu/#paper), we provide:
|
| 30 |
+
1. A `config.yaml` that contains all parameters needed to reproduce the experiment.
|
| 31 |
+
2. Detailed training/eval `logs.json.txt` for every training step.
|
| 32 |
+
3. Checkpoints for the best `epoch=*-test_mean_score=*.ckpt` and last `latest.ckpt` epoch of each run.
|
| 33 |
+
|
| 34 |
+
Experiment logs are hosted on our website as nested directories in format:
|
| 35 |
+
`https://diffusion-policy.cs.columbia.edu/data/experiments/<image|low_dim>/<task>/<method>/`
|
| 36 |
+
|
| 37 |
+
Within each experiment directory you may find:
|
| 38 |
+
```
|
| 39 |
+
.
|
| 40 |
+
├── config.yaml
|
| 41 |
+
├── metrics
|
| 42 |
+
│ └── logs.json.txt
|
| 43 |
+
├── train_0
|
| 44 |
+
│ ├── checkpoints
|
| 45 |
+
│ │ ├── epoch=0300-test_mean_score=1.000.ckpt
|
| 46 |
+
│ │ └── latest.ckpt
|
| 47 |
+
│ └── logs.json.txt
|
| 48 |
+
├── train_1
|
| 49 |
+
│ ├── checkpoints
|
| 50 |
+
│ │ ├── epoch=0250-test_mean_score=1.000.ckpt
|
| 51 |
+
│ │ └── latest.ckpt
|
| 52 |
+
│ └── logs.json.txt
|
| 53 |
+
└── train_2
|
| 54 |
+
├── checkpoints
|
| 55 |
+
│ ├── epoch=0250-test_mean_score=1.000.ckpt
|
| 56 |
+
│ └── latest.ckpt
|
| 57 |
+
└── logs.json.txt
|
| 58 |
+
```
|
| 59 |
+
The `metrics/logs.json.txt` file aggregates evaluation metrics from all 3 training runs every 50 epochs using `multirun_metrics.py`. The numbers reported in the paper correspond to `max` and `k_min_train_loss` aggregation keys.
|
| 60 |
+
|
| 61 |
+
To download all files in a subdirectory, use:
|
| 62 |
+
|
| 63 |
+
```console
|
| 64 |
+
$ wget --recursive --no-parent --no-host-directories --relative --reject="index.html*" https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/square_ph/diffusion_policy_cnn/
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
## 🛠️ Installation
|
| 68 |
+
### 🖥️ Simulation
|
| 69 |
+
To reproduce our simulation benchmark results, install our conda environment on a Linux machine with Nvidia GPU. On Ubuntu 20.04 you need to install the following apt packages for mujoco:
|
| 70 |
+
```console
|
| 71 |
+
$ sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
We recommend [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) instead of the standard anaconda distribution for faster installation:
|
| 75 |
+
```console
|
| 76 |
+
$ mamba env create -f conda_environment.yaml
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
but you can use conda as well:
|
| 80 |
+
```console
|
| 81 |
+
$ conda env create -f conda_environment.yaml
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
The `conda_environment_macos.yaml` file is only for development on MacOS and does not have full support for benchmarks.
|
| 85 |
+
|
| 86 |
+
### 🦾 Real Robot
|
| 87 |
+
Hardware (for Push-T):
|
| 88 |
+
* 1x [UR5-CB3](https://www.universal-robots.com/cb3) or [UR5e](https://www.universal-robots.com/products/ur5-robot/) ([RTDE Interface](https://www.universal-robots.com/articles/ur/interface-communication/real-time-data-exchange-rtde-guide/) is required)
|
| 89 |
+
* 2x [RealSense D415](https://www.intelrealsense.com/depth-camera-d415/)
|
| 90 |
+
* 1x [3Dconnexion SpaceMouse](https://3dconnexion.com/us/product/spacemouse-wireless/) (for teleop)
|
| 91 |
+
* 1x [Millibar Robotics Manual Tool Changer](https://www.millibar.com/manual-tool-changer/) (only need robot side)
|
| 92 |
+
* 1x 3D printed [End effector](https://cad.onshape.com/documents/a818888644a15afa6cc68ee5/w/2885b48b018cda84f425beca/e/3e8771c2124cee024edd2fed?renderMode=0&uiState=63ffcba6631ca919895e64e5)
|
| 93 |
+
* 1x 3D printed [T-block](https://cad.onshape.com/documents/f1140134e38f6ed6902648d5/w/a78cf81827600e4ff4058d03/e/f35f57fb7589f72e05c76caf?renderMode=0&uiState=63ffcbc9af4a881b344898ee)
|
| 94 |
+
* USB-C cables and screws for RealSense
|
| 95 |
+
|
| 96 |
+
Software:
|
| 97 |
+
* Ubuntu 20.04.3 (tested)
|
| 98 |
+
* Mujoco dependencies:
|
| 99 |
+
`sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf`
|
| 100 |
+
* [RealSense SDK](https://github.com/IntelRealSense/librealsense/blob/master/doc/distribution_linux.md)
|
| 101 |
+
* Spacemouse dependencies:
|
| 102 |
+
`sudo apt install libspnav-dev spacenavd; sudo systemctl start spacenavd`
|
| 103 |
+
* Conda environment `mamba env create -f conda_environment_real.yaml`
|
| 104 |
+
|
| 105 |
+
## 🖥️ Reproducing Simulation Benchmark Results
|
| 106 |
+
### Download Training Data
|
| 107 |
+
Under the repo root, create data subdirectory:
|
| 108 |
+
```console
|
| 109 |
+
[diffusion_policy]$ mkdir data && cd data
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
Download the corresponding zip file from [https://diffusion-policy.cs.columbia.edu/data/training/](https://diffusion-policy.cs.columbia.edu/data/training/)
|
| 113 |
+
```console
|
| 114 |
+
[data]$ wget https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
Extract training data:
|
| 118 |
+
```console
|
| 119 |
+
[data]$ unzip pusht.zip && rm -f pusht.zip && cd ..
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
Grab config file for the corresponding experiment:
|
| 123 |
+
```console
|
| 124 |
+
[diffusion_policy]$ wget -O image_pusht_diffusion_policy_cnn.yaml https://diffusion-policy.cs.columbia.edu/data/experiments/image/pusht/diffusion_policy_cnn/config.yaml
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### Running for a single seed
|
| 128 |
+
Activate conda environment and login to [wandb](https://wandb.ai) (if you haven't already).
|
| 129 |
+
```console
|
| 130 |
+
[diffusion_policy]$ conda activate robodiff
|
| 131 |
+
(robodiff)[diffusion_policy]$ wandb login
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
Launch training with seed 42 on GPU 0.
|
| 135 |
+
```console
|
| 136 |
+
(robodiff)[diffusion_policy]$ python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml training.seed=42 training.device=cuda:0 hydra.run.dir='data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}'
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
This will create a directory in format `data/outputs/yyyy.mm.dd/hh.mm.ss_<method_name>_<task_name>` where configs, logs and checkpoints are written to. The policy will be evaluated every 50 epochs with the success rate logged as `test/mean_score` on wandb, as well as videos for some rollouts.
|
| 140 |
+
```console
|
| 141 |
+
(robodiff)[diffusion_policy]$ tree data/outputs/2023.03.01/20.02.03_train_diffusion_unet_hybrid_pusht_image -I wandb
|
| 142 |
+
data/outputs/2023.03.01/20.02.03_train_diffusion_unet_hybrid_pusht_image
|
| 143 |
+
├── checkpoints
|
| 144 |
+
│ ├── epoch=0000-test_mean_score=0.134.ckpt
|
| 145 |
+
│ └── latest.ckpt
|
| 146 |
+
├── .hydra
|
| 147 |
+
│ ├── config.yaml
|
| 148 |
+
│ ├── hydra.yaml
|
| 149 |
+
│ └── overrides.yaml
|
| 150 |
+
├── logs.json.txt
|
| 151 |
+
├── media
|
| 152 |
+
│ ├── 2k5u6wli.mp4
|
| 153 |
+
│ ├── 2kvovxms.mp4
|
| 154 |
+
│ ├── 2pxd9f6b.mp4
|
| 155 |
+
│ ├── 2q5gjt5f.mp4
|
| 156 |
+
│ ├── 2sawbf6m.mp4
|
| 157 |
+
│ └── 538ubl79.mp4
|
| 158 |
+
└── train.log
|
| 159 |
+
|
| 160 |
+
3 directories, 13 files
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
### Running for multiple seeds
|
| 164 |
+
Launch local ray cluster. For large scale experiments, you might want to setup an [AWS cluster with autoscaling](https://docs.ray.io/en/master/cluster/vms/user-guides/launching-clusters/aws.html). All other commands remain the same.
|
| 165 |
+
```console
|
| 166 |
+
(robodiff)[diffusion_policy]$ export CUDA_VISIBLE_DEVICES=0,1,2 # select GPUs to be managed by the ray cluster
|
| 167 |
+
(robodiff)[diffusion_policy]$ ray start --head --num-gpus=3
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
Launch a ray client which will start 3 training workers (3 seeds) and 1 metrics monitor worker.
|
| 171 |
+
```console
|
| 172 |
+
(robodiff)[diffusion_policy]$ python ray_train_multirun.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml --seeds=42,43,44 --monitor_key=test/mean_score -- multi_run.run_dir='data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' multi_run.wandb_name_base='${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}'
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
In addition to the wandb log written by each training worker individually, the metrics monitor worker will log to wandb project `diffusion_policy_metrics` for the metrics aggregated from all 3 training runs. Local config, logs and checkpoints will be written to `data/outputs/yyyy.mm.dd/hh.mm.ss_<method_name>_<task_name>` in a directory structure identical to our [training logs](https://diffusion-policy.cs.columbia.edu/data/experiments/):
|
| 176 |
+
```console
|
| 177 |
+
(robodiff)[diffusion_policy]$ tree data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image -I 'wandb|media'
|
| 178 |
+
data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image
|
| 179 |
+
├── config.yaml
|
| 180 |
+
├── metrics
|
| 181 |
+
│ ├── logs.json.txt
|
| 182 |
+
│ ├── metrics.json
|
| 183 |
+
│ └── metrics.log
|
| 184 |
+
├── train_0
|
| 185 |
+
│ ├── checkpoints
|
| 186 |
+
│ │ ├── epoch=0000-test_mean_score=0.174.ckpt
|
| 187 |
+
│ │ └── latest.ckpt
|
| 188 |
+
│ ├── logs.json.txt
|
| 189 |
+
│ └── train.log
|
| 190 |
+
├── train_1
|
| 191 |
+
│ ├── checkpoints
|
| 192 |
+
│ │ ├── epoch=0000-test_mean_score=0.131.ckpt
|
| 193 |
+
│ │ └── latest.ckpt
|
| 194 |
+
│ ├── logs.json.txt
|
| 195 |
+
│ └── train.log
|
| 196 |
+
└── train_2
|
| 197 |
+
├── checkpoints
|
| 198 |
+
│ ├── epoch=0000-test_mean_score=0.105.ckpt
|
| 199 |
+
│ └── latest.ckpt
|
| 200 |
+
├── logs.json.txt
|
| 201 |
+
└── train.log
|
| 202 |
+
|
| 203 |
+
7 directories, 16 files
|
| 204 |
+
```
|
| 205 |
+
### 🆕 Evaluate Pre-trained Checkpoints
|
| 206 |
+
Download a checkpoint from the published training log folders, such as [https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt](https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt).
|
| 207 |
+
|
| 208 |
+
Run the evaluation script:
|
| 209 |
+
```console
|
| 210 |
+
(robodiff)[diffusion_policy]$ python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
This will generate the following directory structure:
|
| 214 |
+
```console
|
| 215 |
+
(robodiff)[diffusion_policy]$ tree data/pusht_eval_output
|
| 216 |
+
data/pusht_eval_output
|
| 217 |
+
├── eval_log.json
|
| 218 |
+
└── media
|
| 219 |
+
├── 1fxtno84.mp4
|
| 220 |
+
├── 224l7jqd.mp4
|
| 221 |
+
├── 2fo4btlf.mp4
|
| 222 |
+
├── 2in4cn7a.mp4
|
| 223 |
+
├── 34b3o2qq.mp4
|
| 224 |
+
└── 3p7jqn32.mp4
|
| 225 |
+
|
| 226 |
+
1 directory, 7 files
|
| 227 |
+
```
|
| 228 |
+
|
| 229 |
+
`eval_log.json` contains metrics that is logged to wandb during training:
|
| 230 |
+
```console
|
| 231 |
+
(robodiff)[diffusion_policy]$ cat data/pusht_eval_output/eval_log.json
|
| 232 |
+
{
|
| 233 |
+
"test/mean_score": 0.9150393806777066,
|
| 234 |
+
"test/sim_max_reward_4300000": 1.0,
|
| 235 |
+
"test/sim_max_reward_4300001": 0.9872969750774386,
|
| 236 |
+
...
|
| 237 |
+
"train/sim_video_1": "data/pusht_eval_output//media/2fo4btlf.mp4"
|
| 238 |
+
}
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
## 🦾 Demo, Training and Eval on a Real Robot
|
| 242 |
+
Make sure your UR5 robot is running and accepting command from its network interface (emergency stop button within reach at all time), your RealSense cameras plugged in to your workstation (tested with `realsense-viewer`) and your SpaceMouse connected with the `spacenavd` daemon running (verify with `systemctl status spacenavd`).
|
| 243 |
+
|
| 244 |
+
Start the demonstration collection script. Press "C" to start recording. Use SpaceMouse to move the robot. Press "S" to stop recording.
|
| 245 |
+
```console
|
| 246 |
+
(robodiff)[diffusion_policy]$ python demo_real_robot.py -o data/demo_pusht_real --robot_ip 192.168.0.204
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
This should result in a demonstration dataset in `data/demo_pusht_real` with in the same structure as our example [real Push-T training dataset](https://diffusion-policy.cs.columbia.edu/data/training/pusht_real.zip).
|
| 250 |
+
|
| 251 |
+
To train a Diffusion Policy, launch training with config:
|
| 252 |
+
```console
|
| 253 |
+
(robodiff)[diffusion_policy]$ python train.py --config-name=train_diffusion_unet_real_image_workspace task.dataset_path=data/demo_pusht_real
|
| 254 |
+
```
|
| 255 |
+
Edit [`diffusion_policy/config/task/real_pusht_image.yaml`](./diffusion_policy/config/task/real_pusht_image.yaml) if your camera setup is different.
|
| 256 |
+
|
| 257 |
+
Assuming the training has finished and you have a checkpoint at `data/outputs/blah/checkpoints/latest.ckpt`, launch the evaluation script with:
|
| 258 |
+
```console
|
| 259 |
+
python eval_real_robot.py -i data/outputs/blah/checkpoints/latest.ckpt -o data/eval_pusht_real --robot_ip 192.168.0.204
|
| 260 |
+
```
|
| 261 |
+
Press "C" to start evaluation (handing control over to the policy). Press "S" to stop the current episode.
|
| 262 |
+
|
| 263 |
+
## 🗺️ Codebase Tutorial
|
| 264 |
+
This codebase is structured under the requirement that:
|
| 265 |
+
1. implementing `N` tasks and `M` methods will only require `O(N+M)` amount of code instead of `O(N*M)`
|
| 266 |
+
2. while retaining maximum flexibility.
|
| 267 |
+
|
| 268 |
+
To achieve this requirement, we
|
| 269 |
+
1. maintained a simple unified interface between tasks and methods and
|
| 270 |
+
2. made the implementation of the tasks and the methods independent of each other.
|
| 271 |
+
|
| 272 |
+
These design decisions come at the cost of code repetition between the tasks and the methods. However, we believe that the benefit of being able to add/modify task/methods without affecting the remainder and being able understand a task/method by reading the code linearly outweighs the cost of copying and pasting 😊.
|
| 273 |
+
|
| 274 |
+
### The Split
|
| 275 |
+
On the task side, we have:
|
| 276 |
+
* `Dataset`: adapts a (third-party) dataset to the interface.
|
| 277 |
+
* `EnvRunner`: executes a `Policy` that accepts the interface and produce logs and metrics.
|
| 278 |
+
* `config/task/<task_name>.yaml`: contains all information needed to construct `Dataset` and `EnvRunner`.
|
| 279 |
+
* (optional) `Env`: an `gym==0.21.0` compatible class that encapsulates the task environment.
|
| 280 |
+
|
| 281 |
+
On the policy side, we have:
|
| 282 |
+
* `Policy`: implements inference according to the interface and part of the training process.
|
| 283 |
+
* `Workspace`: manages the life-cycle of training and evaluation (interleaved) of a method.
|
| 284 |
+
* `config/<workspace_name>.yaml`: contains all information needed to construct `Policy` and `Workspace`.
|
| 285 |
+
|
| 286 |
+
### The Interface
|
| 287 |
+
#### Low Dim
|
| 288 |
+
A [`LowdimPolicy`](./diffusion_policy/policy/base_lowdim_policy.py) takes observation dictionary:
|
| 289 |
+
- `"obs":` Tensor of shape `(B,To,Do)`
|
| 290 |
+
|
| 291 |
+
and predicts action dictionary:
|
| 292 |
+
- `"action": ` Tensor of shape `(B,Ta,Da)`
|
| 293 |
+
|
| 294 |
+
A [`LowdimDataset`](./diffusion_policy/dataset/base_dataset.py) returns a sample of dictionary:
|
| 295 |
+
- `"obs":` Tensor of shape `(To, Do)`
|
| 296 |
+
- `"action":` Tensor of shape `(Ta, Da)`
|
| 297 |
+
|
| 298 |
+
Its `get_normalizer` method returns a [`LinearNormalizer`](./diffusion_policy/model/common/normalizer.py) with keys `"obs","action"`.
|
| 299 |
+
|
| 300 |
+
The `Policy` handles normalization on GPU with its copy of the `LinearNormalizer`. The parameters of the `LinearNormalizer` is saved as part of the `Policy`'s weights checkpoint.
|
| 301 |
+
|
| 302 |
+
#### Image
|
| 303 |
+
A [`ImagePolicy`](./diffusion_policy/policy/base_image_policy.py) takes observation dictionary:
|
| 304 |
+
- `"key0":` Tensor of shape `(B,To,*)`
|
| 305 |
+
- `"key1":` Tensor of shape e.g. `(B,To,H,W,3)` ([0,1] float32)
|
| 306 |
+
|
| 307 |
+
and predicts action dictionary:
|
| 308 |
+
- `"action": ` Tensor of shape `(B,Ta,Da)`
|
| 309 |
+
|
| 310 |
+
A [`ImageDataset`](./diffusion_policy/dataset/base_dataset.py) returns a sample of dictionary:
|
| 311 |
+
- `"obs":` Dict of
|
| 312 |
+
- `"key0":` Tensor of shape `(To, *)`
|
| 313 |
+
- `"key1":` Tensor fo shape `(To,H,W,3)`
|
| 314 |
+
- `"action":` Tensor of shape `(Ta, Da)`
|
| 315 |
+
|
| 316 |
+
Its `get_normalizer` method returns a [`LinearNormalizer`](./diffusion_policy/model/common/normalizer.py) with keys `"key0","key1","action"`.
|
| 317 |
+
|
| 318 |
+
#### Example
|
| 319 |
+
```
|
| 320 |
+
To = 3
|
| 321 |
+
Ta = 4
|
| 322 |
+
T = 6
|
| 323 |
+
|o|o|o|
|
| 324 |
+
| | |a|a|a|a|
|
| 325 |
+
|o|o|
|
| 326 |
+
| |a|a|a|a|a|
|
| 327 |
+
| | | | |a|a|
|
| 328 |
+
```
|
| 329 |
+
Terminology in the paper: `varname` in the codebase
|
| 330 |
+
- Observation Horizon: `To|n_obs_steps`
|
| 331 |
+
- Action Horizon: `Ta|n_action_steps`
|
| 332 |
+
- Prediction Horizon: `T|horizon`
|
| 333 |
+
|
| 334 |
+
The classical (e.g. MDP) single step observation/action formulation is included as a special case where `To=1` and `Ta=1`.
|
| 335 |
+
|
| 336 |
+
## 🔩 Key Components
|
| 337 |
+
### `Workspace`
|
| 338 |
+
A `Workspace` object encapsulates all states and code needed to run an experiment.
|
| 339 |
+
* Inherits from [`BaseWorkspace`](./diffusion_policy/workspace/base_workspace.py).
|
| 340 |
+
* A single `OmegaConf` config object generated by `hydra` should contain all information needed to construct the Workspace object and running experiments. This config correspond to `config/<workspace_name>.yaml` + hydra overrides.
|
| 341 |
+
* The `run` method contains the entire pipeline for the experiment.
|
| 342 |
+
* Checkpoints happen at the `Workspace` level. All training states implemented as object attributes are automatically saved by the `save_checkpoint` method.
|
| 343 |
+
* All other states for the experiment should be implemented as local variables in the `run` method.
|
| 344 |
+
|
| 345 |
+
The entrypoint for training is `train.py` which uses `@hydra.main` decorator. Read [hydra](https://hydra.cc/)'s official documentation for command line arguments and config overrides. For example, the argument `task=<task_name>` will replace the `task` subtree of the config with the content of `config/task/<task_name>.yaml`, thereby selecting the task to run for this experiment.
|
| 346 |
+
|
| 347 |
+
### `Dataset`
|
| 348 |
+
A `Dataset` object:
|
| 349 |
+
* Inherits from `torch.utils.data.Dataset`.
|
| 350 |
+
* Returns a sample conforming to [the interface](#the-interface) depending on whether the task has Low Dim or Image observations.
|
| 351 |
+
* Has a method `get_normalizer` that returns a `LinearNormalizer` conforming to [the interface](#the-interface).
|
| 352 |
+
|
| 353 |
+
Normalization is a very common source of bugs during project development. It is sometimes helpful to print out the specific `scale` and `bias` vectors used for each key in the `LinearNormalizer`.
|
| 354 |
+
|
| 355 |
+
Most of our implementations of `Dataset` uses a combination of [`ReplayBuffer`](#replaybuffer) and [`SequenceSampler`](./diffusion_policy/common/sampler.py) to generate samples. Correctly handling padding at the beginning and the end of each demonstration episode according to `To` and `Ta` is important for good performance. Please read our [`SequenceSampler`](./diffusion_policy/common/sampler.py) before implementing your own sampling method.
|
| 356 |
+
|
| 357 |
+
### `Policy`
|
| 358 |
+
A `Policy` object:
|
| 359 |
+
* Inherits from `BaseLowdimPolicy` or `BaseImagePolicy`.
|
| 360 |
+
* Has a method `predict_action` that given observation dict, predicts actions conforming to [the interface](#the-interface).
|
| 361 |
+
* Has a method `set_normalizer` that takes in a `LinearNormalizer` and handles observation/action normalization internally in the policy.
|
| 362 |
+
* (optional) Might has a method `compute_loss` that takes in a batch and returns the loss to be optimized.
|
| 363 |
+
* (optional) Usually each `Policy` class correspond to a `Workspace` class due to the differences of training and evaluation process between methods.
|
| 364 |
+
|
| 365 |
+
### `EnvRunner`
|
| 366 |
+
A `EnvRunner` object abstracts away the subtle differences between different task environments.
|
| 367 |
+
* Has a method `run` that takes a `Policy` object for evaluation, and returns a dict of logs and metrics. Each value should be compatible with `wandb.log`.
|
| 368 |
+
|
| 369 |
+
To maximize evaluation speed, we usually vectorize environments using our modification of [`gym.vector.AsyncVectorEnv`](./diffusion_policy/gym_util/async_vector_env.py) which runs each individual environment in a separate process (workaround python GIL).
|
| 370 |
+
|
| 371 |
+
⚠️ Since subprocesses are launched using `fork` on linux, you need to be specially careful for environments that creates its OpenGL context during initialization (e.g. robosuite) which, once inherited by the child process memory space, often causes obscure bugs like segmentation fault. As a workaround, you can provide a `dummy_env_fn` that constructs an environment without initializing OpenGL.
|
| 372 |
+
|
| 373 |
+
### `ReplayBuffer`
|
| 374 |
+
The [`ReplayBuffer`](./diffusion_policy/common/replay_buffer.py) is a key data structure for storing a demonstration dataset both in-memory and on-disk with chunking and compression. It makes heavy use of the [`zarr`](https://zarr.readthedocs.io/en/stable/index.html) format but also has a `numpy` backend for lower access overhead.
|
| 375 |
+
|
| 376 |
+
On disk, it can be stored as a nested directory (e.g. `data/pusht_cchi_v7_replay.zarr`) or a zip file (e.g. `data/robomimic/datasets/square/mh/image_abs.hdf5.zarr.zip`).
|
| 377 |
+
|
| 378 |
+
Due to the relative small size of our datasets, it's often possible to store the entire image-based dataset in RAM with [`Jpeg2000` compression](./diffusion_policy/codecs/imagecodecs_numcodecs.py) which eliminates disk IO during training at the expense increasing of CPU workload.
|
| 379 |
+
|
| 380 |
+
Example:
|
| 381 |
+
```
|
| 382 |
+
data/pusht_cchi_v7_replay.zarr
|
| 383 |
+
├── data
|
| 384 |
+
│ ├── action (25650, 2) float32
|
| 385 |
+
│ ├── img (25650, 96, 96, 3) float32
|
| 386 |
+
│ ├── keypoint (25650, 9, 2) float32
|
| 387 |
+
│ ├── n_contacts (25650, 1) float32
|
| 388 |
+
│ └── state (25650, 5) float32
|
| 389 |
+
└── meta
|
| 390 |
+
└── episode_ends (206,) int64
|
| 391 |
+
```
|
| 392 |
+
|
| 393 |
+
Each array in `data` stores one data field from all episodes concatenated along the first dimension (time). The `meta/episode_ends` array stores the end index for each episode along the fist dimension.
|
| 394 |
+
|
| 395 |
+
### `SharedMemoryRingBuffer`
|
| 396 |
+
The [`SharedMemoryRingBuffer`](./diffusion_policy/shared_memory/shared_memory_ring_buffer.py) is a lock-free FILO data structure used extensively in our [real robot implementation](./diffusion_policy/real_world) to utilize multiple CPU cores while avoiding pickle serialization and locking overhead for `multiprocessing.Queue`.
|
| 397 |
+
|
| 398 |
+
As an example, we would like to get the most recent `To` frames from 5 RealSense cameras. We launch 1 realsense SDK/pipeline per process using [`SingleRealsense`](./diffusion_policy/real_world/single_realsense.py), each continuously writes the captured images into a `SharedMemoryRingBuffer` shared with the main process. We can very quickly get the last `To` frames in the main process due to the FILO nature of `SharedMemoryRingBuffer`.
|
| 399 |
+
|
| 400 |
+
We also implemented [`SharedMemoryQueue`](./diffusion_policy/shared_memory/shared_memory_queue.py) for FIFO, which is used in [`RTDEInterpolationController`](./diffusion_policy/real_world/rtde_interpolation_controller.py).
|
| 401 |
+
|
| 402 |
+
### `RealEnv`
|
| 403 |
+
In contrast to [OpenAI Gym](https://gymnasium.farama.org/), our polices interact with the environment asynchronously. In [`RealEnv`](./diffusion_policy/real_world/real_env.py), the `step` method in `gym` is split into two methods: `get_obs` and `exec_actions`.
|
| 404 |
+
|
| 405 |
+
The `get_obs` method returns the latest observation from `SharedMemoryRingBuffer` as well as their corresponding timestamps. This method can be call at any time during an evaluation episode.
|
| 406 |
+
|
| 407 |
+
The `exec_actions` method accepts a sequence of actions and timestamps for the expected time of execution for each step. Once called, the actions are simply enqueued to the `RTDEInterpolationController`, and the method returns without blocking for execution.
|
| 408 |
+
|
| 409 |
+
## 🩹 Adding a Task
|
| 410 |
+
Read and imitate:
|
| 411 |
+
* `diffusion_policy/dataset/pusht_image_dataset.py`
|
| 412 |
+
* `diffusion_policy/env_runner/pusht_image_runner.py`
|
| 413 |
+
* `diffusion_policy/config/task/pusht_image.yaml`
|
| 414 |
+
|
| 415 |
+
Make sure that `shape_meta` correspond to input and output shapes for your task. Make sure `env_runner._target_` and `dataset._target_` point to the new classes you have added. When training, add `task=<your_task_name>` to `train.py`'s arguments.
|
| 416 |
+
|
| 417 |
+
## 🩹 Adding a Method
|
| 418 |
+
Read and imitate:
|
| 419 |
+
* `diffusion_policy/workspace/train_diffusion_unet_image_workspace.py`
|
| 420 |
+
* `diffusion_policy/policy/diffusion_unet_image_policy.py`
|
| 421 |
+
* `diffusion_policy/config/train_diffusion_unet_image_workspace.yaml`
|
| 422 |
+
|
| 423 |
+
Make sure your workspace yaml's `_target_` points to the new workspace class you created.
|
| 424 |
+
|
| 425 |
+
## 🏷️ License
|
| 426 |
+
This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.
|
| 427 |
+
|
| 428 |
+
## 🙏 Acknowledgement
|
| 429 |
+
* Our [`ConditionalUnet1D`](./diffusion_policy/model/diffusion/conditional_unet1d.py) implementation is adapted from [Planning with Diffusion](https://github.com/jannerm/diffuser).
|
| 430 |
+
* Our [`TransformerForDiffusion`](./diffusion_policy/model/diffusion/transformer_for_diffusion.py) implementation is adapted from [MinGPT](https://github.com/karpathy/minGPT).
|
| 431 |
+
* The [BET](./diffusion_policy/model/bet) baseline is adapted from [its original repo](https://github.com/notmahi/bet).
|
| 432 |
+
* The [IBC](./diffusion_policy/policy/ibc_dfo_lowdim_policy.py) baseline is adapted from [Kevin Zakka's reimplementation](https://github.com/kevinzakka/ibc).
|
| 433 |
+
* The [Robomimic](https://github.com/ARISE-Initiative/robomimic) tasks and [`ObservationEncoder`](https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/models/obs_nets.py) are used extensively in this project.
|
| 434 |
+
* The [Push-T](./diffusion_policy/env/pusht) task is adapted from [IBC](https://github.com/google-research/ibc).
|
| 435 |
+
* The [Block Pushing](./diffusion_policy/env/block_pushing) task is adapted from [BET](https://github.com/notmahi/bet) and [IBC](https://github.com/google-research/ibc).
|
| 436 |
+
* The [Kitchen](./diffusion_policy/env/kitchen) task is adapted from [BET](https://github.com/notmahi/bet) and [Relay Policy Learning](https://github.com/google-research/relay-policy-learning).
|
| 437 |
+
* Our [shared_memory](./diffusion_policy/shared_memory) data structures are heavily inspired by [shared-ndarray2](https://gitlab.com/osu-nrsg/shared-ndarray2).
|
third_party/diffusion_policy/conda_environment.yaml
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: robodiff
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- pytorch3d
|
| 5 |
+
- nvidia
|
| 6 |
+
- conda-forge
|
| 7 |
+
dependencies:
|
| 8 |
+
- python=3.9
|
| 9 |
+
- pip=22.2.2
|
| 10 |
+
- cudatoolkit=11.6
|
| 11 |
+
- pytorch=1.12.1
|
| 12 |
+
- torchvision=0.13.1
|
| 13 |
+
- pytorch3d=0.7.0
|
| 14 |
+
- numpy=1.23.3
|
| 15 |
+
- numba==0.56.4
|
| 16 |
+
- scipy==1.9.1
|
| 17 |
+
- py-opencv=4.6.0
|
| 18 |
+
- cffi=1.15.1
|
| 19 |
+
- ipykernel=6.16
|
| 20 |
+
- matplotlib=3.6.1
|
| 21 |
+
- zarr=2.12.0
|
| 22 |
+
- numcodecs=0.10.2
|
| 23 |
+
- h5py=3.7.0
|
| 24 |
+
- hydra-core=1.2.0
|
| 25 |
+
- einops=0.4.1
|
| 26 |
+
- tqdm=4.64.1
|
| 27 |
+
- dill=0.3.5.1
|
| 28 |
+
- scikit-video=1.1.11
|
| 29 |
+
- scikit-image=0.19.3
|
| 30 |
+
- gym=0.21.0
|
| 31 |
+
- pymunk=6.2.1
|
| 32 |
+
- wandb=0.13.3
|
| 33 |
+
- threadpoolctl=3.1.0
|
| 34 |
+
- shapely=1.8.4
|
| 35 |
+
- cython=0.29.32
|
| 36 |
+
- imageio=2.22.0
|
| 37 |
+
- imageio-ffmpeg=0.4.7
|
| 38 |
+
- termcolor=2.0.1
|
| 39 |
+
- tensorboard=2.10.1
|
| 40 |
+
- tensorboardx=2.5.1
|
| 41 |
+
- psutil=5.9.2
|
| 42 |
+
- click=8.0.4
|
| 43 |
+
- boto3=1.24.96
|
| 44 |
+
- accelerate=0.13.2
|
| 45 |
+
- datasets=2.6.1
|
| 46 |
+
- diffusers=0.11.1
|
| 47 |
+
- av=10.0.0
|
| 48 |
+
- cmake=3.24.3
|
| 49 |
+
# trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
|
| 50 |
+
- llvm-openmp=14
|
| 51 |
+
# trick to force reinstall imagecodecs via pip
|
| 52 |
+
- imagecodecs==2022.8.8
|
| 53 |
+
- pip:
|
| 54 |
+
- ray[default,tune]==2.2.0
|
| 55 |
+
# requires mujoco py dependencies libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
|
| 56 |
+
- free-mujoco-py==2.1.6
|
| 57 |
+
- pygame==2.1.2
|
| 58 |
+
- pybullet-svl==3.1.6.4
|
| 59 |
+
- robosuite @ https://github.com/cheng-chi/robosuite/archive/277ab9588ad7a4f4b55cf75508b44aa67ec171f0.tar.gz
|
| 60 |
+
- robomimic==0.2.0
|
| 61 |
+
- pytorchvideo==0.1.5
|
| 62 |
+
# pip package required for jpeg-xl
|
| 63 |
+
- imagecodecs==2022.9.26
|
| 64 |
+
- r3m @ https://github.com/facebookresearch/r3m/archive/b2334e726887fa0206962d7984c69c5fb09cceab.tar.gz
|
| 65 |
+
- dm-control==1.0.9
|
third_party/diffusion_policy/conda_environment_macos.yaml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: robodiff
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- conda-forge
|
| 5 |
+
dependencies:
|
| 6 |
+
- python=3.9
|
| 7 |
+
- pip=22.2.2
|
| 8 |
+
- pytorch=1.12.1
|
| 9 |
+
- torchvision=0.13.1
|
| 10 |
+
- numpy=1.23.3
|
| 11 |
+
- numba==0.56.4
|
| 12 |
+
- scipy==1.9.1
|
| 13 |
+
- py-opencv=4.6.0
|
| 14 |
+
- cffi=1.15.1
|
| 15 |
+
- ipykernel=6.16
|
| 16 |
+
- matplotlib=3.6.1
|
| 17 |
+
- zarr=2.12.0
|
| 18 |
+
- numcodecs=0.10.2
|
| 19 |
+
- h5py=3.7.0
|
| 20 |
+
- hydra-core=1.2.0
|
| 21 |
+
- einops=0.4.1
|
| 22 |
+
- tqdm=4.64.1
|
| 23 |
+
- dill=0.3.5.1
|
| 24 |
+
- scikit-video=1.1.11
|
| 25 |
+
- scikit-image=0.19.3
|
| 26 |
+
- gym=0.21.0
|
| 27 |
+
- pymunk=6.2.1
|
| 28 |
+
- wandb=0.13.3
|
| 29 |
+
- threadpoolctl=3.1.0
|
| 30 |
+
- shapely=1.8.4
|
| 31 |
+
- cython=0.29.32
|
| 32 |
+
- imageio=2.22.0
|
| 33 |
+
- imageio-ffmpeg=0.4.7
|
| 34 |
+
- termcolor=2.0.1
|
| 35 |
+
- tensorboard=2.10.1
|
| 36 |
+
- tensorboardx=2.5.1
|
| 37 |
+
- psutil=5.9.2
|
| 38 |
+
- click=8.0.4
|
| 39 |
+
- boto3=1.24.96
|
| 40 |
+
- accelerate=0.13.2
|
| 41 |
+
- datasets=2.6.1
|
| 42 |
+
- diffusers=0.11.1
|
| 43 |
+
- av=10.0.0
|
| 44 |
+
- cmake=3.24.3
|
| 45 |
+
# trick to force reinstall imagecodecs via pip
|
| 46 |
+
- imagecodecs==2022.8.8
|
| 47 |
+
- pip:
|
| 48 |
+
- ray[default,tune]==2.2.0
|
| 49 |
+
- pygame==2.1.2
|
| 50 |
+
- robomimic==0.2.0
|
| 51 |
+
- pytorchvideo==0.1.5
|
| 52 |
+
- atomics==1.0.2
|
| 53 |
+
# No support for jpeg-xl for MacOS
|
| 54 |
+
- imagecodecs==2022.9.26
|
| 55 |
+
- r3m @ https://github.com/facebookresearch/r3m/archive/b2334e726887fa0206962d7984c69c5fb09cceab.tar.gz
|
third_party/diffusion_policy/conda_environment_real.yaml
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: robodiff
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- pytorch3d
|
| 5 |
+
- nvidia
|
| 6 |
+
- conda-forge
|
| 7 |
+
dependencies:
|
| 8 |
+
- python=3.9
|
| 9 |
+
- pip=22.2.2
|
| 10 |
+
- cudatoolkit=11.6
|
| 11 |
+
- pytorch=1.12.1
|
| 12 |
+
- torchvision=0.13.1
|
| 13 |
+
- pytorch3d=0.7.0
|
| 14 |
+
- numpy=1.23.3
|
| 15 |
+
- numba==0.56.4
|
| 16 |
+
- scipy==1.9.1
|
| 17 |
+
- py-opencv=4.6.0
|
| 18 |
+
- cffi=1.15.1
|
| 19 |
+
- ipykernel=6.16
|
| 20 |
+
- matplotlib=3.6.1
|
| 21 |
+
- zarr=2.12.0
|
| 22 |
+
- numcodecs=0.10.2
|
| 23 |
+
- h5py=3.7.0
|
| 24 |
+
- hydra-core=1.2.0
|
| 25 |
+
- einops=0.4.1
|
| 26 |
+
- tqdm=4.64.1
|
| 27 |
+
- dill=0.3.5.1
|
| 28 |
+
- scikit-video=1.1.11
|
| 29 |
+
- scikit-image=0.19.3
|
| 30 |
+
- gym=0.21.0
|
| 31 |
+
- pymunk=6.2.1
|
| 32 |
+
- wandb=0.13.3
|
| 33 |
+
- threadpoolctl=3.1.0
|
| 34 |
+
- shapely=1.8.4
|
| 35 |
+
- cython=0.29.32
|
| 36 |
+
- imageio=2.22.0
|
| 37 |
+
- imageio-ffmpeg=0.4.7
|
| 38 |
+
- termcolor=2.0.1
|
| 39 |
+
- tensorboard=2.10.1
|
| 40 |
+
- tensorboardx=2.5.1
|
| 41 |
+
- psutil=5.9.2
|
| 42 |
+
- click=8.0.4
|
| 43 |
+
- boto3=1.24.96
|
| 44 |
+
- accelerate=0.13.2
|
| 45 |
+
- datasets=2.6.1
|
| 46 |
+
- diffusers=0.11.1
|
| 47 |
+
- av=10.0.0
|
| 48 |
+
- cmake=3.24.3
|
| 49 |
+
# trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
|
| 50 |
+
- llvm-openmp=14
|
| 51 |
+
# trick to force reinstall imagecodecs via pip
|
| 52 |
+
- imagecodecs==2022.8.8
|
| 53 |
+
- pip:
|
| 54 |
+
- ray[default,tune]==2.2.0
|
| 55 |
+
# requires mujoco py dependencies libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
|
| 56 |
+
- free-mujoco-py==2.1.6
|
| 57 |
+
- pygame==2.1.2
|
| 58 |
+
- pybullet-svl==3.1.6.4
|
| 59 |
+
- robosuite @ https://github.com/cheng-chi/robosuite/archive/277ab9588ad7a4f4b55cf75508b44aa67ec171f0.tar.gz
|
| 60 |
+
- robomimic==0.2.0
|
| 61 |
+
- pytorchvideo==0.1.5
|
| 62 |
+
# requires librealsense https://github.com/IntelRealSense/librealsense/blob/master/doc/distribution_linux.md
|
| 63 |
+
- pyrealsense2==2.51.1.4348
|
| 64 |
+
# requires apt install libspnav-dev spacenavd; systemctl start spacenavd
|
| 65 |
+
- spnav @ https://github.com/cheng-chi/spnav/archive/c1c938ebe3cc542db4685e0d13850ff1abfdb943.tar.gz
|
| 66 |
+
- ur-rtde==1.5.5
|
| 67 |
+
- atomics==1.0.2
|
| 68 |
+
# pip package required for jpeg-xl
|
| 69 |
+
- imagecodecs==2022.9.26
|
| 70 |
+
- r3m @ https://github.com/facebookresearch/r3m/archive/b2334e726887fa0206962d7984c69c5fb09cceab.tar.gz
|
| 71 |
+
- dm-control==1.0.9
|
| 72 |
+
- pynput==1.7.6
|
| 73 |
+
|
third_party/diffusion_policy/demo_pusht.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import click
|
| 3 |
+
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
| 4 |
+
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
|
| 5 |
+
import pygame
|
| 6 |
+
|
| 7 |
+
@click.command()
|
| 8 |
+
@click.option('-o', '--output', required=True)
|
| 9 |
+
@click.option('-rs', '--render_size', default=96, type=int)
|
| 10 |
+
@click.option('-hz', '--control_hz', default=10, type=int)
|
| 11 |
+
def main(output, render_size, control_hz):
|
| 12 |
+
"""
|
| 13 |
+
Collect demonstration for the Push-T task.
|
| 14 |
+
|
| 15 |
+
Usage: python demo_pusht.py -o data/pusht_demo.zarr
|
| 16 |
+
|
| 17 |
+
This script is compatible with both Linux and MacOS.
|
| 18 |
+
Hover mouse close to the blue circle to start.
|
| 19 |
+
Push the T block into the green area.
|
| 20 |
+
The episode will automatically terminate if the task is succeeded.
|
| 21 |
+
Press "Q" to exit.
|
| 22 |
+
Press "R" to retry.
|
| 23 |
+
Hold "Space" to pause.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
# create replay buffer in read-write mode
|
| 27 |
+
replay_buffer = ReplayBuffer.create_from_path(output, mode='a')
|
| 28 |
+
|
| 29 |
+
# create PushT env with keypoints
|
| 30 |
+
kp_kwargs = PushTKeypointsEnv.genenerate_keypoint_manager_params()
|
| 31 |
+
env = PushTKeypointsEnv(render_size=render_size, render_action=False, **kp_kwargs)
|
| 32 |
+
agent = env.teleop_agent()
|
| 33 |
+
clock = pygame.time.Clock()
|
| 34 |
+
|
| 35 |
+
# episode-level while loop
|
| 36 |
+
while True:
|
| 37 |
+
episode = list()
|
| 38 |
+
# record in seed order, starting with 0
|
| 39 |
+
seed = replay_buffer.n_episodes
|
| 40 |
+
print(f'starting seed {seed}')
|
| 41 |
+
|
| 42 |
+
# set seed for env
|
| 43 |
+
env.seed(seed)
|
| 44 |
+
|
| 45 |
+
# reset env and get observations (including info and render for recording)
|
| 46 |
+
obs = env.reset()
|
| 47 |
+
info = env._get_info()
|
| 48 |
+
img = env.render(mode='human')
|
| 49 |
+
|
| 50 |
+
# loop state
|
| 51 |
+
retry = False
|
| 52 |
+
pause = False
|
| 53 |
+
done = False
|
| 54 |
+
plan_idx = 0
|
| 55 |
+
pygame.display.set_caption(f'plan_idx:{plan_idx}')
|
| 56 |
+
# step-level while loop
|
| 57 |
+
while not done:
|
| 58 |
+
# process keypress events
|
| 59 |
+
for event in pygame.event.get():
|
| 60 |
+
if event.type == pygame.KEYDOWN:
|
| 61 |
+
if event.key == pygame.K_SPACE:
|
| 62 |
+
# hold Space to pause
|
| 63 |
+
plan_idx += 1
|
| 64 |
+
pygame.display.set_caption(f'plan_idx:{plan_idx}')
|
| 65 |
+
pause = True
|
| 66 |
+
elif event.key == pygame.K_r:
|
| 67 |
+
# press "R" to retry
|
| 68 |
+
retry=True
|
| 69 |
+
elif event.key == pygame.K_q:
|
| 70 |
+
# press "Q" to exit
|
| 71 |
+
exit(0)
|
| 72 |
+
if event.type == pygame.KEYUP:
|
| 73 |
+
if event.key == pygame.K_SPACE:
|
| 74 |
+
pause = False
|
| 75 |
+
|
| 76 |
+
# handle control flow
|
| 77 |
+
if retry:
|
| 78 |
+
break
|
| 79 |
+
if pause:
|
| 80 |
+
continue
|
| 81 |
+
|
| 82 |
+
# get action from mouse
|
| 83 |
+
# None if mouse is not close to the agent
|
| 84 |
+
act = agent.act(obs)
|
| 85 |
+
if not act is None:
|
| 86 |
+
# teleop started
|
| 87 |
+
# state dim 2+3
|
| 88 |
+
state = np.concatenate([info['pos_agent'], info['block_pose']])
|
| 89 |
+
# discard unused information such as visibility mask and agent pos
|
| 90 |
+
# for compatibility
|
| 91 |
+
keypoint = obs.reshape(2,-1)[0].reshape(-1,2)[:9]
|
| 92 |
+
data = {
|
| 93 |
+
'img': img,
|
| 94 |
+
'state': np.float32(state),
|
| 95 |
+
'keypoint': np.float32(keypoint),
|
| 96 |
+
'action': np.float32(act),
|
| 97 |
+
'n_contacts': np.float32([info['n_contacts']])
|
| 98 |
+
}
|
| 99 |
+
episode.append(data)
|
| 100 |
+
|
| 101 |
+
# step env and render
|
| 102 |
+
obs, reward, done, info = env.step(act)
|
| 103 |
+
img = env.render(mode='human')
|
| 104 |
+
|
| 105 |
+
# regulate control frequency
|
| 106 |
+
clock.tick(control_hz)
|
| 107 |
+
if not retry:
|
| 108 |
+
# save episode buffer to replay buffer (on disk)
|
| 109 |
+
data_dict = dict()
|
| 110 |
+
for key in episode[0].keys():
|
| 111 |
+
data_dict[key] = np.stack(
|
| 112 |
+
[x[key] for x in episode])
|
| 113 |
+
replay_buffer.add_episode(data_dict, compressors='disk')
|
| 114 |
+
print(f'saved seed {seed}')
|
| 115 |
+
else:
|
| 116 |
+
print(f'retry seed {seed}')
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
if __name__ == "__main__":
|
| 120 |
+
main()
|
third_party/diffusion_policy/demo_real_robot.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
(robodiff)$ python demo_real_robot.py -o <demo_save_dir> --robot_ip <ip_of_ur5>
|
| 4 |
+
|
| 5 |
+
Robot movement:
|
| 6 |
+
Move your SpaceMouse to move the robot EEF (locked in xy plane).
|
| 7 |
+
Press SpaceMouse right button to unlock z axis.
|
| 8 |
+
Press SpaceMouse left button to enable rotation axes.
|
| 9 |
+
|
| 10 |
+
Recording control:
|
| 11 |
+
Click the opencv window (make sure it's in focus).
|
| 12 |
+
Press "C" to start recording.
|
| 13 |
+
Press "S" to stop recording.
|
| 14 |
+
Press "Q" to exit program.
|
| 15 |
+
Press "Backspace" to delete the previously recorded episode.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
# %%
|
| 19 |
+
import time
|
| 20 |
+
from multiprocessing.managers import SharedMemoryManager
|
| 21 |
+
import click
|
| 22 |
+
import cv2
|
| 23 |
+
import numpy as np
|
| 24 |
+
import scipy.spatial.transform as st
|
| 25 |
+
from diffusion_policy.real_world.real_env import RealEnv
|
| 26 |
+
from diffusion_policy.real_world.spacemouse_shared_memory import Spacemouse
|
| 27 |
+
from diffusion_policy.common.precise_sleep import precise_wait
|
| 28 |
+
from diffusion_policy.real_world.keystroke_counter import (
|
| 29 |
+
KeystrokeCounter, Key, KeyCode
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
@click.command()
|
| 33 |
+
@click.option('--output', '-o', required=True, help="Directory to save demonstration dataset.")
|
| 34 |
+
@click.option('--robot_ip', '-ri', required=True, help="UR5's IP address e.g. 192.168.0.204")
|
| 35 |
+
@click.option('--vis_camera_idx', default=0, type=int, help="Which RealSense camera to visualize.")
|
| 36 |
+
@click.option('--init_joints', '-j', is_flag=True, default=False, help="Whether to initialize robot joint configuration in the beginning.")
|
| 37 |
+
@click.option('--frequency', '-f', default=10, type=float, help="Control frequency in Hz.")
|
| 38 |
+
@click.option('--command_latency', '-cl', default=0.01, type=float, help="Latency between receiving SapceMouse command to executing on Robot in Sec.")
|
| 39 |
+
def main(output, robot_ip, vis_camera_idx, init_joints, frequency, command_latency):
|
| 40 |
+
dt = 1/frequency
|
| 41 |
+
with SharedMemoryManager() as shm_manager:
|
| 42 |
+
with KeystrokeCounter() as key_counter, \
|
| 43 |
+
Spacemouse(shm_manager=shm_manager) as sm, \
|
| 44 |
+
RealEnv(
|
| 45 |
+
output_dir=output,
|
| 46 |
+
robot_ip=robot_ip,
|
| 47 |
+
# recording resolution
|
| 48 |
+
obs_image_resolution=(1280,720),
|
| 49 |
+
frequency=frequency,
|
| 50 |
+
init_joints=init_joints,
|
| 51 |
+
enable_multi_cam_vis=True,
|
| 52 |
+
record_raw_video=True,
|
| 53 |
+
# number of threads per camera view for video recording (H.264)
|
| 54 |
+
thread_per_video=3,
|
| 55 |
+
# video recording quality, lower is better (but slower).
|
| 56 |
+
video_crf=21,
|
| 57 |
+
shm_manager=shm_manager
|
| 58 |
+
) as env:
|
| 59 |
+
cv2.setNumThreads(1)
|
| 60 |
+
|
| 61 |
+
# realsense exposure
|
| 62 |
+
env.realsense.set_exposure(exposure=120, gain=0)
|
| 63 |
+
# realsense white balance
|
| 64 |
+
env.realsense.set_white_balance(white_balance=5900)
|
| 65 |
+
|
| 66 |
+
time.sleep(1.0)
|
| 67 |
+
print('Ready!')
|
| 68 |
+
state = env.get_robot_state()
|
| 69 |
+
target_pose = state['TargetTCPPose']
|
| 70 |
+
t_start = time.monotonic()
|
| 71 |
+
iter_idx = 0
|
| 72 |
+
stop = False
|
| 73 |
+
is_recording = False
|
| 74 |
+
while not stop:
|
| 75 |
+
# calculate timing
|
| 76 |
+
t_cycle_end = t_start + (iter_idx + 1) * dt
|
| 77 |
+
t_sample = t_cycle_end - command_latency
|
| 78 |
+
t_command_target = t_cycle_end + dt
|
| 79 |
+
|
| 80 |
+
# pump obs
|
| 81 |
+
obs = env.get_obs()
|
| 82 |
+
|
| 83 |
+
# handle key presses
|
| 84 |
+
press_events = key_counter.get_press_events()
|
| 85 |
+
for key_stroke in press_events:
|
| 86 |
+
if key_stroke == KeyCode(char='q'):
|
| 87 |
+
# Exit program
|
| 88 |
+
stop = True
|
| 89 |
+
elif key_stroke == KeyCode(char='c'):
|
| 90 |
+
# Start recording
|
| 91 |
+
env.start_episode(t_start + (iter_idx + 2) * dt - time.monotonic() + time.time())
|
| 92 |
+
key_counter.clear()
|
| 93 |
+
is_recording = True
|
| 94 |
+
print('Recording!')
|
| 95 |
+
elif key_stroke == KeyCode(char='s'):
|
| 96 |
+
# Stop recording
|
| 97 |
+
env.end_episode()
|
| 98 |
+
key_counter.clear()
|
| 99 |
+
is_recording = False
|
| 100 |
+
print('Stopped.')
|
| 101 |
+
elif key_stroke == Key.backspace:
|
| 102 |
+
# Delete the most recent recorded episode
|
| 103 |
+
if click.confirm('Are you sure to drop an episode?'):
|
| 104 |
+
env.drop_episode()
|
| 105 |
+
key_counter.clear()
|
| 106 |
+
is_recording = False
|
| 107 |
+
# delete
|
| 108 |
+
stage = key_counter[Key.space]
|
| 109 |
+
|
| 110 |
+
# visualize
|
| 111 |
+
vis_img = obs[f'camera_{vis_camera_idx}'][-1,:,:,::-1].copy()
|
| 112 |
+
episode_id = env.replay_buffer.n_episodes
|
| 113 |
+
text = f'Episode: {episode_id}, Stage: {stage}'
|
| 114 |
+
if is_recording:
|
| 115 |
+
text += ', Recording!'
|
| 116 |
+
cv2.putText(
|
| 117 |
+
vis_img,
|
| 118 |
+
text,
|
| 119 |
+
(10,30),
|
| 120 |
+
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
|
| 121 |
+
fontScale=1,
|
| 122 |
+
thickness=2,
|
| 123 |
+
color=(255,255,255)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
cv2.imshow('default', vis_img)
|
| 127 |
+
cv2.pollKey()
|
| 128 |
+
|
| 129 |
+
precise_wait(t_sample)
|
| 130 |
+
# get teleop command
|
| 131 |
+
sm_state = sm.get_motion_state_transformed()
|
| 132 |
+
# print(sm_state)
|
| 133 |
+
dpos = sm_state[:3] * (env.max_pos_speed / frequency)
|
| 134 |
+
drot_xyz = sm_state[3:] * (env.max_rot_speed / frequency)
|
| 135 |
+
|
| 136 |
+
if not sm.is_button_pressed(0):
|
| 137 |
+
# translation mode
|
| 138 |
+
drot_xyz[:] = 0
|
| 139 |
+
else:
|
| 140 |
+
dpos[:] = 0
|
| 141 |
+
if not sm.is_button_pressed(1):
|
| 142 |
+
# 2D translation mode
|
| 143 |
+
dpos[2] = 0
|
| 144 |
+
|
| 145 |
+
drot = st.Rotation.from_euler('xyz', drot_xyz)
|
| 146 |
+
target_pose[:3] += dpos
|
| 147 |
+
target_pose[3:] = (drot * st.Rotation.from_rotvec(
|
| 148 |
+
target_pose[3:])).as_rotvec()
|
| 149 |
+
|
| 150 |
+
# execute teleop command
|
| 151 |
+
env.exec_actions(
|
| 152 |
+
actions=[target_pose],
|
| 153 |
+
timestamps=[t_command_target-time.monotonic()+time.time()],
|
| 154 |
+
stages=[stage])
|
| 155 |
+
precise_wait(t_cycle_end)
|
| 156 |
+
iter_idx += 1
|
| 157 |
+
|
| 158 |
+
# %%
|
| 159 |
+
if __name__ == '__main__':
|
| 160 |
+
main()
|
third_party/diffusion_policy/diffusion_policy.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: diffusion_policy
|
| 3 |
+
Version: 0.0.0
|
| 4 |
+
License-File: LICENSE
|
| 5 |
+
Dynamic: license-file
|
third_party/diffusion_policy/diffusion_policy.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
setup.py
|
| 4 |
+
diffusion_policy.egg-info/PKG-INFO
|
| 5 |
+
diffusion_policy.egg-info/SOURCES.txt
|
| 6 |
+
diffusion_policy.egg-info/dependency_links.txt
|
| 7 |
+
diffusion_policy.egg-info/top_level.txt
|
| 8 |
+
tests/test_block_pushing.py
|
| 9 |
+
tests/test_cv2_util.py
|
| 10 |
+
tests/test_multi_realsense.py
|
| 11 |
+
tests/test_pose_trajectory_interpolator.py
|
| 12 |
+
tests/test_precise_sleep.py
|
| 13 |
+
tests/test_replay_buffer.py
|
| 14 |
+
tests/test_ring_buffer.py
|
| 15 |
+
tests/test_robomimic_image_runner.py
|
| 16 |
+
tests/test_robomimic_lowdim_runner.py
|
| 17 |
+
tests/test_shared_queue.py
|
| 18 |
+
tests/test_single_realsense.py
|
| 19 |
+
tests/test_timestamp_accumulator.py
|
third_party/diffusion_policy/diffusion_policy.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
third_party/diffusion_policy/diffusion_policy.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
third_party/diffusion_policy/diffusion_policy/workspace/train_diffusion_unet_lowdim_workspace.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if __name__ == "__main__":
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
|
| 6 |
+
ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
|
| 7 |
+
sys.path.append(ROOT_DIR)
|
| 8 |
+
os.chdir(ROOT_DIR)
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import hydra
|
| 12 |
+
import torch
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
import pathlib
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
import copy
|
| 17 |
+
import numpy as np
|
| 18 |
+
import random
|
| 19 |
+
import wandb
|
| 20 |
+
import tqdm
|
| 21 |
+
import shutil
|
| 22 |
+
|
| 23 |
+
from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
|
| 24 |
+
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
| 25 |
+
from diffusion_policy.policy.diffusion_unet_lowdim_policy import DiffusionUnetLowdimPolicy
|
| 26 |
+
from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
|
| 27 |
+
from diffusion_policy.env_runner.base_lowdim_runner import BaseLowdimRunner
|
| 28 |
+
from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
|
| 29 |
+
from diffusion_policy.common.json_logger import JsonLogger
|
| 30 |
+
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
| 31 |
+
from diffusers.training_utils import EMAModel
|
| 32 |
+
|
| 33 |
+
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
| 34 |
+
|
| 35 |
+
# %%
|
| 36 |
+
class TrainDiffusionUnetLowdimWorkspace(BaseWorkspace):
|
| 37 |
+
include_keys = ['global_step', 'epoch']
|
| 38 |
+
|
| 39 |
+
def __init__(self, cfg: OmegaConf, output_dir=None):
|
| 40 |
+
super().__init__(cfg, output_dir=output_dir)
|
| 41 |
+
|
| 42 |
+
# set seed
|
| 43 |
+
seed = cfg.training.seed
|
| 44 |
+
torch.manual_seed(seed)
|
| 45 |
+
np.random.seed(seed)
|
| 46 |
+
random.seed(seed)
|
| 47 |
+
|
| 48 |
+
# configure model
|
| 49 |
+
self.model: DiffusionUnetLowdimPolicy
|
| 50 |
+
self.model = hydra.utils.instantiate(cfg.policy)
|
| 51 |
+
|
| 52 |
+
self.ema_model: DiffusionUnetLowdimPolicy = None
|
| 53 |
+
if cfg.training.use_ema:
|
| 54 |
+
self.ema_model = copy.deepcopy(self.model)
|
| 55 |
+
|
| 56 |
+
# configure training state
|
| 57 |
+
self.optimizer = hydra.utils.instantiate(
|
| 58 |
+
cfg.optimizer, params=self.model.parameters())
|
| 59 |
+
|
| 60 |
+
self.global_step = 0
|
| 61 |
+
self.epoch = 0
|
| 62 |
+
|
| 63 |
+
def run(self):
|
| 64 |
+
cfg = copy.deepcopy(self.cfg)
|
| 65 |
+
|
| 66 |
+
# resume training
|
| 67 |
+
if cfg.training.resume:
|
| 68 |
+
lastest_ckpt_path = self.get_checkpoint_path()
|
| 69 |
+
if lastest_ckpt_path.is_file():
|
| 70 |
+
print(f"Resuming from checkpoint {lastest_ckpt_path}")
|
| 71 |
+
self.load_checkpoint(path=lastest_ckpt_path)
|
| 72 |
+
|
| 73 |
+
# configure dataset
|
| 74 |
+
dataset: BaseLowdimDataset
|
| 75 |
+
dataset = hydra.utils.instantiate(cfg.task.dataset)
|
| 76 |
+
assert isinstance(dataset, BaseLowdimDataset)
|
| 77 |
+
train_dataloader = DataLoader(dataset, **cfg.dataloader)
|
| 78 |
+
normalizer = dataset.get_normalizer()
|
| 79 |
+
|
| 80 |
+
# configure validation dataset
|
| 81 |
+
val_dataset = dataset.get_validation_dataset()
|
| 82 |
+
val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
|
| 83 |
+
|
| 84 |
+
self.model.set_normalizer(normalizer)
|
| 85 |
+
if cfg.training.use_ema:
|
| 86 |
+
self.ema_model.set_normalizer(normalizer)
|
| 87 |
+
|
| 88 |
+
# configure lr scheduler
|
| 89 |
+
lr_scheduler = get_scheduler(
|
| 90 |
+
cfg.training.lr_scheduler,
|
| 91 |
+
optimizer=self.optimizer,
|
| 92 |
+
num_warmup_steps=cfg.training.lr_warmup_steps,
|
| 93 |
+
num_training_steps=(
|
| 94 |
+
len(train_dataloader) * cfg.training.num_epochs) \
|
| 95 |
+
// cfg.training.gradient_accumulate_every,
|
| 96 |
+
# pytorch assumes stepping LRScheduler every epoch
|
| 97 |
+
# however huggingface diffusers steps it every batch
|
| 98 |
+
last_epoch=self.global_step-1
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# configure ema
|
| 102 |
+
ema: EMAModel = None
|
| 103 |
+
if cfg.training.use_ema:
|
| 104 |
+
ema = hydra.utils.instantiate(
|
| 105 |
+
cfg.ema,
|
| 106 |
+
model=self.ema_model)
|
| 107 |
+
|
| 108 |
+
# configure env runner
|
| 109 |
+
env_runner: BaseLowdimRunner
|
| 110 |
+
env_runner = hydra.utils.instantiate(
|
| 111 |
+
cfg.task.env_runner,
|
| 112 |
+
output_dir=self.output_dir)
|
| 113 |
+
assert isinstance(env_runner, BaseLowdimRunner)
|
| 114 |
+
|
| 115 |
+
# configure logging
|
| 116 |
+
wandb_run = wandb.init(
|
| 117 |
+
dir=str(self.output_dir),
|
| 118 |
+
config=OmegaConf.to_container(cfg, resolve=True),
|
| 119 |
+
**cfg.logging
|
| 120 |
+
)
|
| 121 |
+
wandb.config.update(
|
| 122 |
+
{
|
| 123 |
+
"output_dir": self.output_dir,
|
| 124 |
+
}
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# configure checkpoint
|
| 128 |
+
topk_manager = TopKCheckpointManager(
|
| 129 |
+
save_dir=os.path.join(self.output_dir, 'checkpoints'),
|
| 130 |
+
**cfg.checkpoint.topk
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# device transfer
|
| 134 |
+
device = torch.device(cfg.training.device)
|
| 135 |
+
self.model.to(device)
|
| 136 |
+
if self.ema_model is not None:
|
| 137 |
+
self.ema_model.to(device)
|
| 138 |
+
optimizer_to(self.optimizer, device)
|
| 139 |
+
|
| 140 |
+
# save batch for sampling
|
| 141 |
+
train_sampling_batch = None
|
| 142 |
+
|
| 143 |
+
if cfg.training.debug:
|
| 144 |
+
cfg.training.num_epochs = 2
|
| 145 |
+
cfg.training.max_train_steps = 3
|
| 146 |
+
cfg.training.max_val_steps = 3
|
| 147 |
+
cfg.training.rollout_every = 1
|
| 148 |
+
cfg.training.checkpoint_every = 1
|
| 149 |
+
cfg.training.val_every = 1
|
| 150 |
+
cfg.training.sample_every = 1
|
| 151 |
+
|
| 152 |
+
# training loop
|
| 153 |
+
log_path = os.path.join(self.output_dir, 'logs.json.txt')
|
| 154 |
+
with JsonLogger(log_path) as json_logger:
|
| 155 |
+
for local_epoch_idx in range(cfg.training.num_epochs):
|
| 156 |
+
step_log = dict()
|
| 157 |
+
# ========= train for this epoch ==========
|
| 158 |
+
train_losses = list()
|
| 159 |
+
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
|
| 160 |
+
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
| 161 |
+
for batch_idx, batch in enumerate(tepoch):
|
| 162 |
+
# device transfer
|
| 163 |
+
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
| 164 |
+
if train_sampling_batch is None:
|
| 165 |
+
train_sampling_batch = batch
|
| 166 |
+
|
| 167 |
+
# compute loss
|
| 168 |
+
raw_loss = self.model.compute_loss(batch)
|
| 169 |
+
loss = raw_loss / cfg.training.gradient_accumulate_every
|
| 170 |
+
loss.backward()
|
| 171 |
+
|
| 172 |
+
# step optimizer
|
| 173 |
+
if self.global_step % cfg.training.gradient_accumulate_every == 0:
|
| 174 |
+
self.optimizer.step()
|
| 175 |
+
self.optimizer.zero_grad()
|
| 176 |
+
lr_scheduler.step()
|
| 177 |
+
|
| 178 |
+
# update ema
|
| 179 |
+
if cfg.training.use_ema:
|
| 180 |
+
ema.step(self.model)
|
| 181 |
+
|
| 182 |
+
# logging
|
| 183 |
+
raw_loss_cpu = raw_loss.item()
|
| 184 |
+
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
|
| 185 |
+
train_losses.append(raw_loss_cpu)
|
| 186 |
+
step_log = {
|
| 187 |
+
'train_loss': raw_loss_cpu,
|
| 188 |
+
'global_step': self.global_step,
|
| 189 |
+
'epoch': self.epoch,
|
| 190 |
+
'lr': lr_scheduler.get_last_lr()[0]
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
is_last_batch = (batch_idx == (len(train_dataloader)-1))
|
| 194 |
+
if not is_last_batch:
|
| 195 |
+
# log of last step is combined with validation and rollout
|
| 196 |
+
wandb_run.log(step_log, step=self.global_step)
|
| 197 |
+
json_logger.log(step_log)
|
| 198 |
+
self.global_step += 1
|
| 199 |
+
|
| 200 |
+
if (cfg.training.max_train_steps is not None) \
|
| 201 |
+
and batch_idx >= (cfg.training.max_train_steps-1):
|
| 202 |
+
break
|
| 203 |
+
|
| 204 |
+
# at the end of each epoch
|
| 205 |
+
# replace train_loss with epoch average
|
| 206 |
+
train_loss = np.mean(train_losses)
|
| 207 |
+
step_log['train_loss'] = train_loss
|
| 208 |
+
|
| 209 |
+
# ========= eval for this epoch ==========
|
| 210 |
+
policy = self.model
|
| 211 |
+
if cfg.training.use_ema:
|
| 212 |
+
policy = self.ema_model
|
| 213 |
+
policy.eval()
|
| 214 |
+
|
| 215 |
+
# run rollout
|
| 216 |
+
if (self.epoch % cfg.training.rollout_every) == 0:
|
| 217 |
+
runner_log = env_runner.run(policy)
|
| 218 |
+
# log all
|
| 219 |
+
step_log.update(runner_log)
|
| 220 |
+
|
| 221 |
+
# run validation
|
| 222 |
+
if (self.epoch % cfg.training.val_every) == 0:
|
| 223 |
+
with torch.no_grad():
|
| 224 |
+
val_losses = list()
|
| 225 |
+
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
|
| 226 |
+
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
| 227 |
+
for batch_idx, batch in enumerate(tepoch):
|
| 228 |
+
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
| 229 |
+
loss = self.model.compute_loss(batch)
|
| 230 |
+
val_losses.append(loss)
|
| 231 |
+
if (cfg.training.max_val_steps is not None) \
|
| 232 |
+
and batch_idx >= (cfg.training.max_val_steps-1):
|
| 233 |
+
break
|
| 234 |
+
if len(val_losses) > 0:
|
| 235 |
+
val_loss = torch.mean(torch.tensor(val_losses)).item()
|
| 236 |
+
# log epoch average validation loss
|
| 237 |
+
step_log['val_loss'] = val_loss
|
| 238 |
+
|
| 239 |
+
# run diffusion sampling on a training batch
|
| 240 |
+
if (self.epoch % cfg.training.sample_every) == 0:
|
| 241 |
+
with torch.no_grad():
|
| 242 |
+
# sample trajectory from training set, and evaluate difference
|
| 243 |
+
batch = train_sampling_batch
|
| 244 |
+
obs_dict = {'obs': batch['obs']}
|
| 245 |
+
gt_action = batch['action']
|
| 246 |
+
|
| 247 |
+
result = policy.predict_action(obs_dict)
|
| 248 |
+
if cfg.pred_action_steps_only:
|
| 249 |
+
pred_action = result['action']
|
| 250 |
+
start = cfg.n_obs_steps - 1
|
| 251 |
+
end = start + cfg.n_action_steps
|
| 252 |
+
gt_action = gt_action[:,start:end]
|
| 253 |
+
else:
|
| 254 |
+
pred_action = result['action_pred']
|
| 255 |
+
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
|
| 256 |
+
# log
|
| 257 |
+
step_log['train_action_mse_error'] = mse.item()
|
| 258 |
+
# release RAM
|
| 259 |
+
del batch
|
| 260 |
+
del obs_dict
|
| 261 |
+
del gt_action
|
| 262 |
+
del result
|
| 263 |
+
del pred_action
|
| 264 |
+
del mse
|
| 265 |
+
|
| 266 |
+
# checkpoint
|
| 267 |
+
if (self.epoch % cfg.training.checkpoint_every) == 0:
|
| 268 |
+
# checkpointing
|
| 269 |
+
if cfg.checkpoint.save_last_ckpt:
|
| 270 |
+
self.save_checkpoint()
|
| 271 |
+
if cfg.checkpoint.save_last_snapshot:
|
| 272 |
+
self.save_snapshot()
|
| 273 |
+
|
| 274 |
+
# sanitize metric names
|
| 275 |
+
metric_dict = dict()
|
| 276 |
+
for key, value in step_log.items():
|
| 277 |
+
new_key = key.replace('/', '_')
|
| 278 |
+
metric_dict[new_key] = value
|
| 279 |
+
|
| 280 |
+
# We can't copy the last checkpoint here
|
| 281 |
+
# since save_checkpoint uses threads.
|
| 282 |
+
# therefore at this point the file might have been empty!
|
| 283 |
+
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
|
| 284 |
+
|
| 285 |
+
if topk_ckpt_path is not None:
|
| 286 |
+
self.save_checkpoint(path=topk_ckpt_path)
|
| 287 |
+
# ========= eval end for this epoch ==========
|
| 288 |
+
policy.train()
|
| 289 |
+
|
| 290 |
+
# end of epoch
|
| 291 |
+
# log of last step is combined with validation and rollout
|
| 292 |
+
wandb_run.log(step_log, step=self.global_step)
|
| 293 |
+
json_logger.log(step_log)
|
| 294 |
+
self.global_step += 1
|
| 295 |
+
self.epoch += 1
|
| 296 |
+
|
| 297 |
+
@hydra.main(
|
| 298 |
+
version_base=None,
|
| 299 |
+
config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
|
| 300 |
+
config_name=pathlib.Path(__file__).stem)
|
| 301 |
+
def main(cfg):
|
| 302 |
+
workspace = TrainDiffusionUnetLowdimWorkspace(cfg)
|
| 303 |
+
workspace.run()
|
| 304 |
+
|
| 305 |
+
if __name__ == "__main__":
|
| 306 |
+
main()
|
third_party/diffusion_policy/diffusion_policy/workspace/train_ibc_dfo_hybrid_workspace.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if __name__ == "__main__":
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
|
| 6 |
+
ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
|
| 7 |
+
sys.path.append(ROOT_DIR)
|
| 8 |
+
os.chdir(ROOT_DIR)
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import hydra
|
| 12 |
+
import torch
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
import pathlib
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
import copy
|
| 17 |
+
import random
|
| 18 |
+
import wandb
|
| 19 |
+
import tqdm
|
| 20 |
+
import numpy as np
|
| 21 |
+
import shutil
|
| 22 |
+
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
| 23 |
+
from diffusion_policy.policy.ibc_dfo_hybrid_image_policy import IbcDfoHybridImagePolicy
|
| 24 |
+
from diffusion_policy.dataset.base_dataset import BaseImageDataset
|
| 25 |
+
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
|
| 26 |
+
from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
|
| 27 |
+
from diffusion_policy.common.json_logger import JsonLogger
|
| 28 |
+
from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
|
| 29 |
+
from diffusion_policy.model.diffusion.ema_model import EMAModel
|
| 30 |
+
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
| 31 |
+
|
| 32 |
+
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
| 33 |
+
|
| 34 |
+
class TrainIbcDfoHybridWorkspace(BaseWorkspace):
|
| 35 |
+
include_keys = ['global_step', 'epoch']
|
| 36 |
+
|
| 37 |
+
def __init__(self, cfg: OmegaConf, output_dir=None):
|
| 38 |
+
super().__init__(cfg, output_dir=output_dir)
|
| 39 |
+
|
| 40 |
+
# set seed
|
| 41 |
+
seed = cfg.training.seed
|
| 42 |
+
torch.manual_seed(seed)
|
| 43 |
+
np.random.seed(seed)
|
| 44 |
+
random.seed(seed)
|
| 45 |
+
|
| 46 |
+
# configure model
|
| 47 |
+
self.model: IbcDfoHybridImagePolicy= hydra.utils.instantiate(cfg.policy)
|
| 48 |
+
|
| 49 |
+
# configure training state
|
| 50 |
+
self.optimizer = hydra.utils.instantiate(
|
| 51 |
+
cfg.optimizer, params=self.model.parameters())
|
| 52 |
+
|
| 53 |
+
# configure training state
|
| 54 |
+
self.global_step = 0
|
| 55 |
+
self.epoch = 0
|
| 56 |
+
|
| 57 |
+
def run(self):
|
| 58 |
+
cfg = copy.deepcopy(self.cfg)
|
| 59 |
+
|
| 60 |
+
# resume training
|
| 61 |
+
if cfg.training.resume:
|
| 62 |
+
lastest_ckpt_path = self.get_checkpoint_path()
|
| 63 |
+
if lastest_ckpt_path.is_file():
|
| 64 |
+
print(f"Resuming from checkpoint {lastest_ckpt_path}")
|
| 65 |
+
self.load_checkpoint(path=lastest_ckpt_path)
|
| 66 |
+
|
| 67 |
+
# configure dataset
|
| 68 |
+
dataset: BaseImageDataset
|
| 69 |
+
dataset = hydra.utils.instantiate(cfg.task.dataset)
|
| 70 |
+
assert isinstance(dataset, BaseImageDataset)
|
| 71 |
+
train_dataloader = DataLoader(dataset, **cfg.dataloader)
|
| 72 |
+
normalizer = dataset.get_normalizer()
|
| 73 |
+
|
| 74 |
+
# configure validation dataset
|
| 75 |
+
val_dataset = dataset.get_validation_dataset()
|
| 76 |
+
val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
|
| 77 |
+
|
| 78 |
+
self.model.set_normalizer(normalizer)
|
| 79 |
+
|
| 80 |
+
# configure lr scheduler
|
| 81 |
+
lr_scheduler = get_scheduler(
|
| 82 |
+
cfg.training.lr_scheduler,
|
| 83 |
+
optimizer=self.optimizer,
|
| 84 |
+
num_warmup_steps=cfg.training.lr_warmup_steps,
|
| 85 |
+
num_training_steps=(
|
| 86 |
+
len(train_dataloader) * cfg.training.num_epochs) \
|
| 87 |
+
// cfg.training.gradient_accumulate_every,
|
| 88 |
+
# pytorch assumes stepping LRScheduler every epoch
|
| 89 |
+
# however huggingface diffusers steps it every batch
|
| 90 |
+
last_epoch=self.global_step-1
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# configure env
|
| 94 |
+
env_runner: BaseImageRunner
|
| 95 |
+
env_runner = hydra.utils.instantiate(
|
| 96 |
+
cfg.task.env_runner,
|
| 97 |
+
output_dir=self.output_dir)
|
| 98 |
+
assert isinstance(env_runner, BaseImageRunner)
|
| 99 |
+
|
| 100 |
+
# configure logging
|
| 101 |
+
wandb_run = wandb.init(
|
| 102 |
+
dir=str(self.output_dir),
|
| 103 |
+
config=OmegaConf.to_container(cfg, resolve=True),
|
| 104 |
+
**cfg.logging
|
| 105 |
+
)
|
| 106 |
+
wandb.config.update(
|
| 107 |
+
{
|
| 108 |
+
"output_dir": self.output_dir,
|
| 109 |
+
}
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# configure checkpoint
|
| 113 |
+
topk_manager = TopKCheckpointManager(
|
| 114 |
+
save_dir=os.path.join(self.output_dir, 'checkpoints'),
|
| 115 |
+
**cfg.checkpoint.topk
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# device transfer
|
| 119 |
+
device = torch.device(cfg.training.device)
|
| 120 |
+
self.model.to(device)
|
| 121 |
+
optimizer_to(self.optimizer, device)
|
| 122 |
+
|
| 123 |
+
# save batch for sampling
|
| 124 |
+
train_sampling_batch = None
|
| 125 |
+
|
| 126 |
+
if cfg.training.debug:
|
| 127 |
+
cfg.training.num_epochs = 2
|
| 128 |
+
cfg.training.max_train_steps = 3
|
| 129 |
+
cfg.training.max_val_steps = 3
|
| 130 |
+
cfg.training.rollout_every = 1
|
| 131 |
+
cfg.training.checkpoint_every = 1
|
| 132 |
+
cfg.training.val_every = 1
|
| 133 |
+
cfg.training.sample_every = 1
|
| 134 |
+
|
| 135 |
+
# training loop
|
| 136 |
+
log_path = os.path.join(self.output_dir, 'logs.json.txt')
|
| 137 |
+
with JsonLogger(log_path) as json_logger:
|
| 138 |
+
for local_epoch_idx in range(cfg.training.num_epochs):
|
| 139 |
+
step_log = dict()
|
| 140 |
+
# ========= train for this epoch ==========
|
| 141 |
+
train_losses = list()
|
| 142 |
+
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
|
| 143 |
+
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
| 144 |
+
for batch_idx, batch in enumerate(tepoch):
|
| 145 |
+
# device transfer
|
| 146 |
+
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
| 147 |
+
if train_sampling_batch is None:
|
| 148 |
+
train_sampling_batch = batch
|
| 149 |
+
|
| 150 |
+
# compute loss
|
| 151 |
+
raw_loss = self.model.compute_loss(batch)
|
| 152 |
+
loss = raw_loss / cfg.training.gradient_accumulate_every
|
| 153 |
+
loss.backward()
|
| 154 |
+
|
| 155 |
+
# step optimizer
|
| 156 |
+
if self.global_step % cfg.training.gradient_accumulate_every == 0:
|
| 157 |
+
self.optimizer.step()
|
| 158 |
+
self.optimizer.zero_grad()
|
| 159 |
+
lr_scheduler.step()
|
| 160 |
+
|
| 161 |
+
# logging
|
| 162 |
+
raw_loss_cpu = raw_loss.item()
|
| 163 |
+
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
|
| 164 |
+
train_losses.append(raw_loss_cpu)
|
| 165 |
+
step_log = {
|
| 166 |
+
'train_loss': raw_loss_cpu,
|
| 167 |
+
'global_step': self.global_step,
|
| 168 |
+
'epoch': self.epoch,
|
| 169 |
+
'lr': lr_scheduler.get_last_lr()[0]
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
is_last_batch = (batch_idx == (len(train_dataloader)-1))
|
| 173 |
+
if not is_last_batch:
|
| 174 |
+
# log of last step is combined with validation and rollout
|
| 175 |
+
wandb_run.log(step_log, step=self.global_step)
|
| 176 |
+
json_logger.log(step_log)
|
| 177 |
+
self.global_step += 1
|
| 178 |
+
|
| 179 |
+
if (cfg.training.max_train_steps is not None) \
|
| 180 |
+
and batch_idx >= (cfg.training.max_train_steps-1):
|
| 181 |
+
break
|
| 182 |
+
|
| 183 |
+
# at the end of each epoch
|
| 184 |
+
# replace train_loss with epoch average
|
| 185 |
+
train_loss = np.mean(train_losses)
|
| 186 |
+
step_log['train_loss'] = train_loss
|
| 187 |
+
|
| 188 |
+
# ========= eval for this epoch ==========
|
| 189 |
+
policy = self.model
|
| 190 |
+
policy.eval()
|
| 191 |
+
|
| 192 |
+
# run rollout
|
| 193 |
+
if (self.epoch % cfg.training.rollout_every) == 0:
|
| 194 |
+
runner_log = env_runner.run(policy)
|
| 195 |
+
# log all
|
| 196 |
+
step_log.update(runner_log)
|
| 197 |
+
|
| 198 |
+
# run validation
|
| 199 |
+
if (self.epoch % cfg.training.val_every) == 0:
|
| 200 |
+
with torch.no_grad():
|
| 201 |
+
val_losses = list()
|
| 202 |
+
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
|
| 203 |
+
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
| 204 |
+
for batch_idx, batch in enumerate(tepoch):
|
| 205 |
+
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
| 206 |
+
loss = self.model.compute_loss(batch)
|
| 207 |
+
val_losses.append(loss)
|
| 208 |
+
if (cfg.training.max_val_steps is not None) \
|
| 209 |
+
and batch_idx >= (cfg.training.max_val_steps-1):
|
| 210 |
+
break
|
| 211 |
+
if len(val_losses) > 0:
|
| 212 |
+
val_loss = torch.mean(torch.tensor(val_losses)).item()
|
| 213 |
+
# log epoch average validation loss
|
| 214 |
+
step_log['val_loss'] = val_loss
|
| 215 |
+
|
| 216 |
+
# run diffusion sampling on a training batch
|
| 217 |
+
if (self.epoch % cfg.training.sample_every) == 0:
|
| 218 |
+
with torch.no_grad():
|
| 219 |
+
# sample trajectory from training set, and evaluate difference
|
| 220 |
+
batch = train_sampling_batch
|
| 221 |
+
n_samples = cfg.training.sample_max_batch
|
| 222 |
+
batch = dict_apply(train_sampling_batch,
|
| 223 |
+
lambda x: x.to(device, non_blocking=True))
|
| 224 |
+
obs_dict = dict_apply(batch['obs'], lambda x: x[:n_samples])
|
| 225 |
+
gt_action = batch['action']
|
| 226 |
+
|
| 227 |
+
result = policy.predict_action(obs_dict)
|
| 228 |
+
pred_action = result['action']
|
| 229 |
+
start = cfg.n_obs_steps - 1
|
| 230 |
+
end = start + cfg.n_action_steps
|
| 231 |
+
gt_action = gt_action[:,start:end]
|
| 232 |
+
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
|
| 233 |
+
# log
|
| 234 |
+
step_log['train_action_mse_error'] = mse.item()
|
| 235 |
+
# release RAM
|
| 236 |
+
del batch
|
| 237 |
+
del obs_dict
|
| 238 |
+
del gt_action
|
| 239 |
+
del result
|
| 240 |
+
del pred_action
|
| 241 |
+
del mse
|
| 242 |
+
|
| 243 |
+
# checkpoint
|
| 244 |
+
if (self.epoch % cfg.training.checkpoint_every) == 0:
|
| 245 |
+
# checkpointing
|
| 246 |
+
if cfg.checkpoint.save_last_ckpt:
|
| 247 |
+
self.save_checkpoint()
|
| 248 |
+
if cfg.checkpoint.save_last_snapshot:
|
| 249 |
+
self.save_snapshot()
|
| 250 |
+
|
| 251 |
+
# sanitize metric names
|
| 252 |
+
metric_dict = dict()
|
| 253 |
+
for key, value in step_log.items():
|
| 254 |
+
new_key = key.replace('/', '_')
|
| 255 |
+
metric_dict[new_key] = value
|
| 256 |
+
|
| 257 |
+
# We can't copy the last checkpoint here
|
| 258 |
+
# since save_checkpoint uses threads.
|
| 259 |
+
# therefore at this point the file might have been empty!
|
| 260 |
+
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
|
| 261 |
+
|
| 262 |
+
if topk_ckpt_path is not None:
|
| 263 |
+
self.save_checkpoint(path=topk_ckpt_path)
|
| 264 |
+
# ========= eval end for this epoch ==========
|
| 265 |
+
policy.train()
|
| 266 |
+
|
| 267 |
+
# end of epoch
|
| 268 |
+
# log of last step is combined with validation and rollout
|
| 269 |
+
wandb_run.log(step_log, step=self.global_step)
|
| 270 |
+
json_logger.log(step_log)
|
| 271 |
+
self.global_step += 1
|
| 272 |
+
self.epoch += 1
|
| 273 |
+
|
| 274 |
+
@hydra.main(
|
| 275 |
+
version_base=None,
|
| 276 |
+
config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
|
| 277 |
+
config_name=pathlib.Path(__file__).stem)
|
| 278 |
+
def main(cfg):
|
| 279 |
+
workspace = TrainIbcDfoHybridWorkspace(cfg)
|
| 280 |
+
workspace.run()
|
| 281 |
+
|
| 282 |
+
if __name__ == "__main__":
|
| 283 |
+
main()
|
third_party/diffusion_policy/diffusion_policy/workspace/train_ibc_dfo_lowdim_workspace.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if __name__ == "__main__":
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
|
| 6 |
+
ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
|
| 7 |
+
sys.path.append(ROOT_DIR)
|
| 8 |
+
os.chdir(ROOT_DIR)
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import hydra
|
| 12 |
+
import torch
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
import pathlib
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
import copy
|
| 17 |
+
import numpy as np
|
| 18 |
+
import random
|
| 19 |
+
import wandb
|
| 20 |
+
import tqdm
|
| 21 |
+
import shutil
|
| 22 |
+
|
| 23 |
+
from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
|
| 24 |
+
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
| 25 |
+
from diffusion_policy.policy.ibc_dfo_lowdim_policy import IbcDfoLowdimPolicy
|
| 26 |
+
from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
|
| 27 |
+
from diffusion_policy.env_runner.base_lowdim_runner import BaseLowdimRunner
|
| 28 |
+
from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
|
| 29 |
+
from diffusion_policy.common.json_logger import JsonLogger
|
| 30 |
+
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
| 31 |
+
|
| 32 |
+
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
| 33 |
+
|
| 34 |
+
# %%
|
| 35 |
+
class TrainIbcDfoLowdimWorkspace(BaseWorkspace):
|
| 36 |
+
include_keys = ['global_step', 'epoch']
|
| 37 |
+
|
| 38 |
+
def __init__(self, cfg: OmegaConf, output_dir=None):
|
| 39 |
+
super().__init__(cfg, output_dir=output_dir)
|
| 40 |
+
|
| 41 |
+
# set seed
|
| 42 |
+
seed = cfg.training.seed
|
| 43 |
+
torch.manual_seed(seed)
|
| 44 |
+
np.random.seed(seed)
|
| 45 |
+
random.seed(seed)
|
| 46 |
+
|
| 47 |
+
# configure model
|
| 48 |
+
self.model: IbcDfoLowdimPolicy
|
| 49 |
+
self.model = hydra.utils.instantiate(cfg.policy)
|
| 50 |
+
|
| 51 |
+
# configure training state
|
| 52 |
+
self.optimizer = hydra.utils.instantiate(
|
| 53 |
+
cfg.optimizer, params=self.model.parameters())
|
| 54 |
+
|
| 55 |
+
self.global_step = 0
|
| 56 |
+
self.epoch = 0
|
| 57 |
+
|
| 58 |
+
def run(self):
|
| 59 |
+
cfg = copy.deepcopy(self.cfg)
|
| 60 |
+
|
| 61 |
+
# resume training
|
| 62 |
+
if cfg.training.resume:
|
| 63 |
+
lastest_ckpt_path = self.get_checkpoint_path()
|
| 64 |
+
if lastest_ckpt_path.is_file():
|
| 65 |
+
print(f"Resuming from checkpoint {lastest_ckpt_path}")
|
| 66 |
+
self.load_checkpoint(path=lastest_ckpt_path)
|
| 67 |
+
|
| 68 |
+
# configure dataset
|
| 69 |
+
dataset: BaseLowdimDataset
|
| 70 |
+
dataset = hydra.utils.instantiate(cfg.task.dataset)
|
| 71 |
+
assert isinstance(dataset, BaseLowdimDataset)
|
| 72 |
+
train_dataloader = DataLoader(dataset, **cfg.dataloader)
|
| 73 |
+
normalizer = dataset.get_normalizer()
|
| 74 |
+
|
| 75 |
+
# configure validation dataset
|
| 76 |
+
val_dataset = dataset.get_validation_dataset()
|
| 77 |
+
val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
|
| 78 |
+
|
| 79 |
+
self.model.set_normalizer(normalizer)
|
| 80 |
+
|
| 81 |
+
# configure lr scheduler
|
| 82 |
+
lr_scheduler = get_scheduler(
|
| 83 |
+
cfg.training.lr_scheduler,
|
| 84 |
+
optimizer=self.optimizer,
|
| 85 |
+
num_warmup_steps=cfg.training.lr_warmup_steps,
|
| 86 |
+
num_training_steps=(
|
| 87 |
+
len(train_dataloader) * cfg.training.num_epochs) \
|
| 88 |
+
// cfg.training.gradient_accumulate_every,
|
| 89 |
+
# pytorch assumes stepping LRScheduler every epoch
|
| 90 |
+
# however huggingface diffusers steps it every batch
|
| 91 |
+
last_epoch=self.global_step-1
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# configure env runner
|
| 95 |
+
env_runner: BaseLowdimRunner
|
| 96 |
+
env_runner = hydra.utils.instantiate(
|
| 97 |
+
cfg.task.env_runner,
|
| 98 |
+
output_dir=self.output_dir)
|
| 99 |
+
assert isinstance(env_runner, BaseLowdimRunner)
|
| 100 |
+
|
| 101 |
+
# configure logging
|
| 102 |
+
wandb_run = wandb.init(
|
| 103 |
+
dir=str(self.output_dir),
|
| 104 |
+
config=OmegaConf.to_container(cfg, resolve=True),
|
| 105 |
+
**cfg.logging
|
| 106 |
+
)
|
| 107 |
+
wandb.config.update(
|
| 108 |
+
{
|
| 109 |
+
"output_dir": self.output_dir,
|
| 110 |
+
}
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# configure checkpoint
|
| 114 |
+
topk_manager = TopKCheckpointManager(
|
| 115 |
+
save_dir=os.path.join(self.output_dir, 'checkpoints'),
|
| 116 |
+
**cfg.checkpoint.topk
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# device transfer
|
| 120 |
+
device = torch.device(cfg.training.device)
|
| 121 |
+
self.model.to(device)
|
| 122 |
+
optimizer_to(self.optimizer, device)
|
| 123 |
+
|
| 124 |
+
# save batch for sampling
|
| 125 |
+
train_sampling_batch = None
|
| 126 |
+
|
| 127 |
+
if cfg.training.debug:
|
| 128 |
+
cfg.training.num_epochs = 2
|
| 129 |
+
cfg.training.max_train_steps = 3
|
| 130 |
+
cfg.training.max_val_steps = 3
|
| 131 |
+
cfg.training.rollout_every = 1
|
| 132 |
+
cfg.training.checkpoint_every = 1
|
| 133 |
+
cfg.training.val_every = 1
|
| 134 |
+
cfg.training.sample_every = 1
|
| 135 |
+
|
| 136 |
+
# training loop
|
| 137 |
+
log_path = os.path.join(self.output_dir, 'logs.json.txt')
|
| 138 |
+
with JsonLogger(log_path) as json_logger:
|
| 139 |
+
for local_epoch_idx in range(cfg.training.num_epochs):
|
| 140 |
+
step_log = dict()
|
| 141 |
+
# ========= train for this epoch ==========
|
| 142 |
+
train_losses = list()
|
| 143 |
+
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
|
| 144 |
+
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
| 145 |
+
for batch_idx, batch in enumerate(tepoch):
|
| 146 |
+
# device transfer
|
| 147 |
+
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
| 148 |
+
if train_sampling_batch is None:
|
| 149 |
+
train_sampling_batch = batch
|
| 150 |
+
|
| 151 |
+
# compute loss
|
| 152 |
+
raw_loss = self.model.compute_loss(batch)
|
| 153 |
+
loss = raw_loss / cfg.training.gradient_accumulate_every
|
| 154 |
+
loss.backward()
|
| 155 |
+
|
| 156 |
+
# step optimizer
|
| 157 |
+
if self.global_step % cfg.training.gradient_accumulate_every == 0:
|
| 158 |
+
self.optimizer.step()
|
| 159 |
+
self.optimizer.zero_grad()
|
| 160 |
+
lr_scheduler.step()
|
| 161 |
+
|
| 162 |
+
# logging
|
| 163 |
+
raw_loss_cpu = raw_loss.item()
|
| 164 |
+
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
|
| 165 |
+
train_losses.append(raw_loss_cpu)
|
| 166 |
+
step_log = {
|
| 167 |
+
'train_loss': raw_loss_cpu,
|
| 168 |
+
'global_step': self.global_step,
|
| 169 |
+
'epoch': self.epoch,
|
| 170 |
+
'lr': lr_scheduler.get_last_lr()[0]
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
is_last_batch = (batch_idx == (len(train_dataloader)-1))
|
| 174 |
+
if not is_last_batch:
|
| 175 |
+
# log of last step is combined with validation and rollout
|
| 176 |
+
wandb_run.log(step_log, step=self.global_step)
|
| 177 |
+
json_logger.log(step_log)
|
| 178 |
+
self.global_step += 1
|
| 179 |
+
|
| 180 |
+
if (cfg.training.max_train_steps is not None) \
|
| 181 |
+
and batch_idx >= (cfg.training.max_train_steps-1):
|
| 182 |
+
break
|
| 183 |
+
|
| 184 |
+
# at the end of each epoch
|
| 185 |
+
# replace train_loss with epoch average
|
| 186 |
+
train_loss = np.mean(train_losses)
|
| 187 |
+
step_log['train_loss'] = train_loss
|
| 188 |
+
|
| 189 |
+
# ========= eval for this epoch ==========
|
| 190 |
+
policy = self.model
|
| 191 |
+
policy.eval()
|
| 192 |
+
|
| 193 |
+
# run rollout
|
| 194 |
+
if (self.epoch % cfg.training.rollout_every) == 0:
|
| 195 |
+
runner_log = env_runner.run(policy)
|
| 196 |
+
# log all
|
| 197 |
+
step_log.update(runner_log)
|
| 198 |
+
|
| 199 |
+
# run validation
|
| 200 |
+
if (self.epoch % cfg.training.val_every) == 0:
|
| 201 |
+
with torch.no_grad():
|
| 202 |
+
val_losses = list()
|
| 203 |
+
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
|
| 204 |
+
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
| 205 |
+
for batch_idx, batch in enumerate(tepoch):
|
| 206 |
+
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
| 207 |
+
loss = self.model.compute_loss(batch)
|
| 208 |
+
val_losses.append(loss)
|
| 209 |
+
if (cfg.training.max_val_steps is not None) \
|
| 210 |
+
and batch_idx >= (cfg.training.max_val_steps-1):
|
| 211 |
+
break
|
| 212 |
+
if len(val_losses) > 0:
|
| 213 |
+
val_loss = torch.mean(torch.tensor(val_losses)).item()
|
| 214 |
+
# log epoch average validation loss
|
| 215 |
+
step_log['val_loss'] = val_loss
|
| 216 |
+
|
| 217 |
+
# run diffusion sampling on a training batch
|
| 218 |
+
if (self.epoch % cfg.training.sample_every) == 0:
|
| 219 |
+
with torch.no_grad():
|
| 220 |
+
# sample trajectory from training set, and evaluate difference
|
| 221 |
+
batch = train_sampling_batch
|
| 222 |
+
n_samples = cfg.training.sample_max_batch
|
| 223 |
+
obs_dict = {'obs': batch['obs'][:n_samples]}
|
| 224 |
+
gt_action = batch['action'][:n_samples]
|
| 225 |
+
|
| 226 |
+
result = policy.predict_action(obs_dict)
|
| 227 |
+
pred_action = result['action']
|
| 228 |
+
start = cfg.n_obs_steps - 1
|
| 229 |
+
end = start + cfg.n_action_steps
|
| 230 |
+
gt_action = gt_action[:,start:end]
|
| 231 |
+
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
|
| 232 |
+
# log
|
| 233 |
+
step_log['train_action_mse_error'] = mse.item()
|
| 234 |
+
# release RAM
|
| 235 |
+
del batch
|
| 236 |
+
del obs_dict
|
| 237 |
+
del gt_action
|
| 238 |
+
del result
|
| 239 |
+
del pred_action
|
| 240 |
+
del mse
|
| 241 |
+
|
| 242 |
+
# checkpoint
|
| 243 |
+
if (self.epoch % cfg.training.checkpoint_every) == 0:
|
| 244 |
+
# checkpointing
|
| 245 |
+
if cfg.checkpoint.save_last_ckpt:
|
| 246 |
+
self.save_checkpoint()
|
| 247 |
+
if cfg.checkpoint.save_last_snapshot:
|
| 248 |
+
self.save_snapshot()
|
| 249 |
+
|
| 250 |
+
# sanitize metric names
|
| 251 |
+
metric_dict = dict()
|
| 252 |
+
for key, value in step_log.items():
|
| 253 |
+
new_key = key.replace('/', '_')
|
| 254 |
+
metric_dict[new_key] = value
|
| 255 |
+
|
| 256 |
+
# We can't copy the last checkpoint here
|
| 257 |
+
# since save_checkpoint uses threads.
|
| 258 |
+
# therefore at this point the file might have been empty!
|
| 259 |
+
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
|
| 260 |
+
|
| 261 |
+
if topk_ckpt_path is not None:
|
| 262 |
+
self.save_checkpoint(path=topk_ckpt_path)
|
| 263 |
+
# ========= eval end for this epoch ==========
|
| 264 |
+
policy.train()
|
| 265 |
+
|
| 266 |
+
# end of epoch
|
| 267 |
+
# log of last step is combined with validation and rollout
|
| 268 |
+
wandb_run.log(step_log, step=self.global_step)
|
| 269 |
+
json_logger.log(step_log)
|
| 270 |
+
self.global_step += 1
|
| 271 |
+
self.epoch += 1
|
| 272 |
+
|
| 273 |
+
@hydra.main(
|
| 274 |
+
version_base=None,
|
| 275 |
+
config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
|
| 276 |
+
config_name=pathlib.Path(__file__).stem)
|
| 277 |
+
def main(cfg):
|
| 278 |
+
workspace = TrainIbcDfoLowdimWorkspace(cfg)
|
| 279 |
+
workspace.run()
|
| 280 |
+
|
| 281 |
+
if __name__ == "__main__":
|
| 282 |
+
main()
|
third_party/diffusion_policy/diffusion_policy/workspace/train_robomimic_image_workspace.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if __name__ == "__main__":
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
|
| 6 |
+
ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
|
| 7 |
+
sys.path.append(ROOT_DIR)
|
| 8 |
+
os.chdir(ROOT_DIR)
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import hydra
|
| 12 |
+
import torch
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
import pathlib
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
import copy
|
| 17 |
+
import random
|
| 18 |
+
import wandb
|
| 19 |
+
import tqdm
|
| 20 |
+
import numpy as np
|
| 21 |
+
import shutil
|
| 22 |
+
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
| 23 |
+
from diffusion_policy.policy.robomimic_image_policy import RobomimicImagePolicy
|
| 24 |
+
from diffusion_policy.dataset.base_dataset import BaseImageDataset
|
| 25 |
+
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
|
| 26 |
+
from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
|
| 27 |
+
from diffusion_policy.common.json_logger import JsonLogger
|
| 28 |
+
from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
| 32 |
+
|
| 33 |
+
class TrainRobomimicImageWorkspace(BaseWorkspace):
|
| 34 |
+
include_keys = ['global_step', 'epoch']
|
| 35 |
+
|
| 36 |
+
def __init__(self, cfg: OmegaConf, output_dir=None):
|
| 37 |
+
super().__init__(cfg, output_dir=output_dir)
|
| 38 |
+
|
| 39 |
+
# set seed
|
| 40 |
+
seed = cfg.training.seed
|
| 41 |
+
torch.manual_seed(seed)
|
| 42 |
+
np.random.seed(seed)
|
| 43 |
+
random.seed(seed)
|
| 44 |
+
|
| 45 |
+
# configure model
|
| 46 |
+
self.model: RobomimicImagePolicy = hydra.utils.instantiate(cfg.policy)
|
| 47 |
+
|
| 48 |
+
# configure training state
|
| 49 |
+
self.global_step = 0
|
| 50 |
+
self.epoch = 0
|
| 51 |
+
|
| 52 |
+
def run(self):
|
| 53 |
+
cfg = copy.deepcopy(self.cfg)
|
| 54 |
+
|
| 55 |
+
# resume training
|
| 56 |
+
if cfg.training.resume:
|
| 57 |
+
lastest_ckpt_path = self.get_checkpoint_path()
|
| 58 |
+
if lastest_ckpt_path.is_file():
|
| 59 |
+
print(f"Resuming from checkpoint {lastest_ckpt_path}")
|
| 60 |
+
self.load_checkpoint(path=lastest_ckpt_path)
|
| 61 |
+
|
| 62 |
+
# configure dataset
|
| 63 |
+
dataset: BaseImageDataset
|
| 64 |
+
dataset = hydra.utils.instantiate(cfg.task.dataset)
|
| 65 |
+
assert isinstance(dataset, BaseImageDataset)
|
| 66 |
+
train_dataloader = DataLoader(dataset, **cfg.dataloader)
|
| 67 |
+
normalizer = dataset.get_normalizer()
|
| 68 |
+
|
| 69 |
+
# configure validation dataset
|
| 70 |
+
val_dataset = dataset.get_validation_dataset()
|
| 71 |
+
val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
|
| 72 |
+
|
| 73 |
+
self.model.set_normalizer(normalizer)
|
| 74 |
+
|
| 75 |
+
# configure env
|
| 76 |
+
env_runner: BaseImageRunner
|
| 77 |
+
env_runner = hydra.utils.instantiate(
|
| 78 |
+
cfg.task.env_runner,
|
| 79 |
+
output_dir=self.output_dir)
|
| 80 |
+
assert isinstance(env_runner, BaseImageRunner)
|
| 81 |
+
|
| 82 |
+
# configure logging
|
| 83 |
+
wandb_run = wandb.init(
|
| 84 |
+
dir=str(self.output_dir),
|
| 85 |
+
config=OmegaConf.to_container(cfg, resolve=True),
|
| 86 |
+
**cfg.logging
|
| 87 |
+
)
|
| 88 |
+
wandb.config.update(
|
| 89 |
+
{
|
| 90 |
+
"output_dir": self.output_dir,
|
| 91 |
+
}
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# configure checkpoint
|
| 95 |
+
topk_manager = TopKCheckpointManager(
|
| 96 |
+
save_dir=os.path.join(self.output_dir, 'checkpoints'),
|
| 97 |
+
**cfg.checkpoint.topk
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# device transfer
|
| 101 |
+
device = torch.device(cfg.training.device)
|
| 102 |
+
self.model.to(device)
|
| 103 |
+
|
| 104 |
+
# save batch for sampling
|
| 105 |
+
train_sampling_batch = None
|
| 106 |
+
|
| 107 |
+
if cfg.training.debug:
|
| 108 |
+
cfg.training.num_epochs = 2
|
| 109 |
+
cfg.training.max_train_steps = 3
|
| 110 |
+
cfg.training.max_val_steps = 3
|
| 111 |
+
cfg.training.rollout_every = 1
|
| 112 |
+
cfg.training.checkpoint_every = 1
|
| 113 |
+
cfg.training.val_every = 1
|
| 114 |
+
cfg.training.sample_every = 1
|
| 115 |
+
|
| 116 |
+
# training loop
|
| 117 |
+
log_path = os.path.join(self.output_dir, 'logs.json.txt')
|
| 118 |
+
with JsonLogger(log_path) as json_logger:
|
| 119 |
+
for local_epoch_idx in range(cfg.training.num_epochs):
|
| 120 |
+
step_log = dict()
|
| 121 |
+
# ========= train for this epoch ==========
|
| 122 |
+
train_losses = list()
|
| 123 |
+
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
|
| 124 |
+
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
| 125 |
+
for batch_idx, batch in enumerate(tepoch):
|
| 126 |
+
# device transfer
|
| 127 |
+
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
| 128 |
+
if train_sampling_batch is None:
|
| 129 |
+
train_sampling_batch = batch
|
| 130 |
+
|
| 131 |
+
info = self.model.train_on_batch(batch, epoch=self.epoch)
|
| 132 |
+
|
| 133 |
+
# logging
|
| 134 |
+
loss_cpu = info['losses']['action_loss'].item()
|
| 135 |
+
tepoch.set_postfix(loss=loss_cpu, refresh=False)
|
| 136 |
+
train_losses.append(loss_cpu)
|
| 137 |
+
step_log = {
|
| 138 |
+
'train_loss': loss_cpu,
|
| 139 |
+
'global_step': self.global_step,
|
| 140 |
+
'epoch': self.epoch
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
is_last_batch = (batch_idx == (len(train_dataloader)-1))
|
| 144 |
+
if not is_last_batch:
|
| 145 |
+
# log of last step is combined with validation and rollout
|
| 146 |
+
wandb_run.log(step_log, step=self.global_step)
|
| 147 |
+
json_logger.log(step_log)
|
| 148 |
+
self.global_step += 1
|
| 149 |
+
|
| 150 |
+
if (cfg.training.max_train_steps is not None) \
|
| 151 |
+
and batch_idx >= (cfg.training.max_train_steps-1):
|
| 152 |
+
break
|
| 153 |
+
|
| 154 |
+
# at the end of each epoch
|
| 155 |
+
# replace train_loss with epoch average
|
| 156 |
+
train_loss = np.mean(train_losses)
|
| 157 |
+
step_log['train_loss'] = train_loss
|
| 158 |
+
|
| 159 |
+
# ========= eval for this epoch ==========
|
| 160 |
+
self.model.eval()
|
| 161 |
+
|
| 162 |
+
# run rollout
|
| 163 |
+
if (self.epoch % cfg.training.rollout_every) == 0:
|
| 164 |
+
runner_log = env_runner.run(self.model)
|
| 165 |
+
# log all
|
| 166 |
+
step_log.update(runner_log)
|
| 167 |
+
|
| 168 |
+
# run validation
|
| 169 |
+
if (self.epoch % cfg.training.val_every) == 0:
|
| 170 |
+
with torch.no_grad():
|
| 171 |
+
val_losses = list()
|
| 172 |
+
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
|
| 173 |
+
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
| 174 |
+
for batch_idx, batch in enumerate(tepoch):
|
| 175 |
+
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
| 176 |
+
info = self.model.train_on_batch(batch, epoch=self.epoch, validate=True)
|
| 177 |
+
loss = info['losses']['action_loss']
|
| 178 |
+
val_losses.append(loss)
|
| 179 |
+
if (cfg.training.max_val_steps is not None) \
|
| 180 |
+
and batch_idx >= (cfg.training.max_val_steps-1):
|
| 181 |
+
break
|
| 182 |
+
if len(val_losses) > 0:
|
| 183 |
+
val_loss = torch.mean(torch.tensor(val_losses)).item()
|
| 184 |
+
# log epoch average validation loss
|
| 185 |
+
step_log['val_loss'] = val_loss
|
| 186 |
+
|
| 187 |
+
# run diffusion sampling on a training batch
|
| 188 |
+
if (self.epoch % cfg.training.sample_every) == 0:
|
| 189 |
+
with torch.no_grad():
|
| 190 |
+
# sample trajectory from training set, and evaluate difference
|
| 191 |
+
batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
|
| 192 |
+
obs_dict = batch['obs']
|
| 193 |
+
gt_action = batch['action']
|
| 194 |
+
T = gt_action.shape[1]
|
| 195 |
+
|
| 196 |
+
pred_actions = list()
|
| 197 |
+
self.model.reset()
|
| 198 |
+
for i in range(T):
|
| 199 |
+
result = self.model.predict_action(
|
| 200 |
+
dict_apply(obs_dict, lambda x: x[:,[i]])
|
| 201 |
+
)
|
| 202 |
+
pred_actions.append(result['action'])
|
| 203 |
+
pred_actions = torch.cat(pred_actions, dim=1)
|
| 204 |
+
mse = torch.nn.functional.mse_loss(pred_actions, gt_action)
|
| 205 |
+
step_log['train_action_mse_error'] = mse.item()
|
| 206 |
+
del batch
|
| 207 |
+
del obs_dict
|
| 208 |
+
del gt_action
|
| 209 |
+
del result
|
| 210 |
+
del pred_actions
|
| 211 |
+
del mse
|
| 212 |
+
|
| 213 |
+
# checkpoint
|
| 214 |
+
if (self.epoch % cfg.training.checkpoint_every) == 0:
|
| 215 |
+
# checkpointing
|
| 216 |
+
if cfg.checkpoint.save_last_ckpt:
|
| 217 |
+
self.save_checkpoint()
|
| 218 |
+
if cfg.checkpoint.save_last_snapshot:
|
| 219 |
+
self.save_snapshot()
|
| 220 |
+
|
| 221 |
+
# sanitize metric names
|
| 222 |
+
metric_dict = dict()
|
| 223 |
+
for key, value in step_log.items():
|
| 224 |
+
new_key = key.replace('/', '_')
|
| 225 |
+
metric_dict[new_key] = value
|
| 226 |
+
|
| 227 |
+
# We can't copy the last checkpoint here
|
| 228 |
+
# since save_checkpoint uses threads.
|
| 229 |
+
# therefore at this point the file might have been empty!
|
| 230 |
+
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
|
| 231 |
+
|
| 232 |
+
if topk_ckpt_path is not None:
|
| 233 |
+
self.save_checkpoint(path=topk_ckpt_path)
|
| 234 |
+
# ========= eval end for this epoch ==========
|
| 235 |
+
self.model.train()
|
| 236 |
+
|
| 237 |
+
# end of epoch
|
| 238 |
+
# log of last step is combined with validation and rollout
|
| 239 |
+
wandb_run.log(step_log, step=self.global_step)
|
| 240 |
+
json_logger.log(step_log)
|
| 241 |
+
self.global_step += 1
|
| 242 |
+
self.epoch += 1
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@hydra.main(
|
| 246 |
+
version_base=None,
|
| 247 |
+
config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
|
| 248 |
+
config_name=pathlib.Path(__file__).stem)
|
| 249 |
+
def main(cfg):
|
| 250 |
+
workspace = TrainRobomimicImageWorkspace(cfg)
|
| 251 |
+
workspace.run()
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
|
| 254 |
+
main()
|
third_party/diffusion_policy/diffusion_policy/workspace/train_robomimic_lowdim_workspace.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if __name__ == "__main__":
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
|
| 6 |
+
ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
|
| 7 |
+
sys.path.append(ROOT_DIR)
|
| 8 |
+
os.chdir(ROOT_DIR)
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import hydra
|
| 12 |
+
import torch
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
import pathlib
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
import copy
|
| 17 |
+
import random
|
| 18 |
+
import wandb
|
| 19 |
+
import tqdm
|
| 20 |
+
import numpy as np
|
| 21 |
+
import shutil
|
| 22 |
+
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
| 23 |
+
from diffusion_policy.policy.robomimic_lowdim_policy import RobomimicLowdimPolicy
|
| 24 |
+
from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
|
| 25 |
+
from diffusion_policy.env_runner.base_lowdim_runner import BaseLowdimRunner
|
| 26 |
+
from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
|
| 27 |
+
from diffusion_policy.common.json_logger import JsonLogger
|
| 28 |
+
from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
| 32 |
+
|
| 33 |
+
class TrainRobomimicLowdimWorkspace(BaseWorkspace):
|
| 34 |
+
include_keys = ['global_step', 'epoch']
|
| 35 |
+
|
| 36 |
+
def __init__(self, cfg: OmegaConf):
|
| 37 |
+
super().__init__(cfg)
|
| 38 |
+
|
| 39 |
+
# set seed
|
| 40 |
+
seed = cfg.training.seed
|
| 41 |
+
torch.manual_seed(seed)
|
| 42 |
+
np.random.seed(seed)
|
| 43 |
+
random.seed(seed)
|
| 44 |
+
|
| 45 |
+
# configure model
|
| 46 |
+
self.model: RobomimicLowdimPolicy = hydra.utils.instantiate(cfg.policy)
|
| 47 |
+
|
| 48 |
+
# configure training state
|
| 49 |
+
self.global_step = 0
|
| 50 |
+
self.epoch = 0
|
| 51 |
+
|
| 52 |
+
def run(self):
|
| 53 |
+
cfg = copy.deepcopy(self.cfg)
|
| 54 |
+
|
| 55 |
+
# resume training
|
| 56 |
+
if cfg.training.resume:
|
| 57 |
+
lastest_ckpt_path = self.get_checkpoint_path()
|
| 58 |
+
if lastest_ckpt_path.is_file():
|
| 59 |
+
print(f"Resuming from checkpoint {lastest_ckpt_path}")
|
| 60 |
+
self.load_checkpoint(path=lastest_ckpt_path)
|
| 61 |
+
|
| 62 |
+
# configure dataset
|
| 63 |
+
dataset: BaseLowdimDataset
|
| 64 |
+
dataset = hydra.utils.instantiate(cfg.task.dataset)
|
| 65 |
+
assert isinstance(dataset, BaseLowdimDataset)
|
| 66 |
+
train_dataloader = DataLoader(dataset, **cfg.dataloader)
|
| 67 |
+
normalizer = dataset.get_normalizer()
|
| 68 |
+
|
| 69 |
+
# configure validation dataset
|
| 70 |
+
val_dataset = dataset.get_validation_dataset()
|
| 71 |
+
val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
|
| 72 |
+
|
| 73 |
+
self.model.set_normalizer(normalizer)
|
| 74 |
+
|
| 75 |
+
# configure env
|
| 76 |
+
env_runner: BaseLowdimRunner
|
| 77 |
+
env_runner = hydra.utils.instantiate(
|
| 78 |
+
cfg.task.env_runner,
|
| 79 |
+
output_dir=self.output_dir)
|
| 80 |
+
assert isinstance(env_runner, BaseLowdimRunner)
|
| 81 |
+
|
| 82 |
+
# configure logging
|
| 83 |
+
wandb_run = wandb.init(
|
| 84 |
+
dir=str(self.output_dir),
|
| 85 |
+
config=OmegaConf.to_container(cfg, resolve=True),
|
| 86 |
+
**cfg.logging
|
| 87 |
+
)
|
| 88 |
+
wandb.config.update(
|
| 89 |
+
{
|
| 90 |
+
"output_dir": self.output_dir,
|
| 91 |
+
}
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# configure checkpoint
|
| 95 |
+
topk_manager = TopKCheckpointManager(
|
| 96 |
+
save_dir=os.path.join(self.output_dir, 'checkpoints'),
|
| 97 |
+
**cfg.checkpoint.topk
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# device transfer
|
| 101 |
+
device = torch.device(cfg.training.device)
|
| 102 |
+
self.model.to(device)
|
| 103 |
+
|
| 104 |
+
if cfg.training.debug:
|
| 105 |
+
cfg.training.num_epochs = 2
|
| 106 |
+
cfg.training.max_train_steps = 3
|
| 107 |
+
cfg.training.max_val_steps = 3
|
| 108 |
+
cfg.training.rollout_every = 1
|
| 109 |
+
cfg.training.checkpoint_every = 1
|
| 110 |
+
cfg.training.val_every = 1
|
| 111 |
+
|
| 112 |
+
# training loop
|
| 113 |
+
log_path = os.path.join(self.output_dir, 'logs.json.txt')
|
| 114 |
+
with JsonLogger(log_path) as json_logger:
|
| 115 |
+
for local_epoch_idx in range(cfg.training.num_epochs):
|
| 116 |
+
step_log = dict()
|
| 117 |
+
# ========= train for this epoch ==========
|
| 118 |
+
train_losses = list()
|
| 119 |
+
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
|
| 120 |
+
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
| 121 |
+
for batch_idx, batch in enumerate(tepoch):
|
| 122 |
+
# device transfer
|
| 123 |
+
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
| 124 |
+
info = self.model.train_on_batch(batch, epoch=self.epoch)
|
| 125 |
+
|
| 126 |
+
# logging
|
| 127 |
+
loss_cpu = info['losses']['action_loss'].item()
|
| 128 |
+
tepoch.set_postfix(loss=loss_cpu, refresh=False)
|
| 129 |
+
train_losses.append(loss_cpu)
|
| 130 |
+
step_log = {
|
| 131 |
+
'train_loss': loss_cpu,
|
| 132 |
+
'global_step': self.global_step,
|
| 133 |
+
'epoch': self.epoch
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
is_last_batch = (batch_idx == (len(train_dataloader)-1))
|
| 137 |
+
if not is_last_batch:
|
| 138 |
+
# log of last step is combined with validation and rollout
|
| 139 |
+
wandb_run.log(step_log, step=self.global_step)
|
| 140 |
+
json_logger.log(step_log)
|
| 141 |
+
self.global_step += 1
|
| 142 |
+
|
| 143 |
+
if (cfg.training.max_train_steps is not None) \
|
| 144 |
+
and batch_idx >= (cfg.training.max_train_steps-1):
|
| 145 |
+
break
|
| 146 |
+
|
| 147 |
+
# at the end of each epoch
|
| 148 |
+
# replace train_loss with epoch average
|
| 149 |
+
train_loss = np.mean(train_losses)
|
| 150 |
+
step_log['train_loss'] = train_loss
|
| 151 |
+
|
| 152 |
+
# ========= eval for this epoch ==========
|
| 153 |
+
self.model.eval()
|
| 154 |
+
|
| 155 |
+
# run rollout
|
| 156 |
+
if (self.epoch % cfg.training.rollout_every) == 0:
|
| 157 |
+
runner_log = env_runner.run(self.model)
|
| 158 |
+
# log all
|
| 159 |
+
step_log.update(runner_log)
|
| 160 |
+
|
| 161 |
+
# run validation
|
| 162 |
+
if (self.epoch % cfg.training.val_every) == 0:
|
| 163 |
+
with torch.no_grad():
|
| 164 |
+
val_losses = list()
|
| 165 |
+
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
|
| 166 |
+
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
| 167 |
+
for batch_idx, batch in enumerate(tepoch):
|
| 168 |
+
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
| 169 |
+
info = self.model.train_on_batch(batch, epoch=self.epoch, validate=True)
|
| 170 |
+
loss = info['losses']['action_loss']
|
| 171 |
+
val_losses.append(loss)
|
| 172 |
+
if (cfg.training.max_val_steps is not None) \
|
| 173 |
+
and batch_idx >= (cfg.training.max_val_steps-1):
|
| 174 |
+
break
|
| 175 |
+
if len(val_losses) > 0:
|
| 176 |
+
val_loss = torch.mean(torch.tensor(val_losses)).item()
|
| 177 |
+
# log epoch average validation loss
|
| 178 |
+
step_log['val_loss'] = val_loss
|
| 179 |
+
|
| 180 |
+
# checkpoint
|
| 181 |
+
if (self.epoch % cfg.training.checkpoint_every) == 0:
|
| 182 |
+
# checkpointing
|
| 183 |
+
if cfg.checkpoint.save_last_ckpt:
|
| 184 |
+
self.save_checkpoint()
|
| 185 |
+
if cfg.checkpoint.save_last_snapshot:
|
| 186 |
+
self.save_snapshot()
|
| 187 |
+
|
| 188 |
+
# sanitize metric names
|
| 189 |
+
metric_dict = dict()
|
| 190 |
+
for key, value in step_log.items():
|
| 191 |
+
new_key = key.replace('/', '_')
|
| 192 |
+
metric_dict[new_key] = value
|
| 193 |
+
|
| 194 |
+
# We can't copy the last checkpoint here
|
| 195 |
+
# since save_checkpoint uses threads.
|
| 196 |
+
# therefore at this point the file might have been empty!
|
| 197 |
+
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
|
| 198 |
+
|
| 199 |
+
if topk_ckpt_path is not None:
|
| 200 |
+
self.save_checkpoint(path=topk_ckpt_path)
|
| 201 |
+
# ========= eval end for this epoch ==========
|
| 202 |
+
self.model.train()
|
| 203 |
+
|
| 204 |
+
# end of epoch
|
| 205 |
+
# log of last step is combined with validation and rollout
|
| 206 |
+
wandb_run.log(step_log, step=self.global_step)
|
| 207 |
+
json_logger.log(step_log)
|
| 208 |
+
self.global_step += 1
|
| 209 |
+
self.epoch += 1
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@hydra.main(
|
| 213 |
+
version_base=None,
|
| 214 |
+
config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
|
| 215 |
+
config_name=pathlib.Path(__file__).stem)
|
| 216 |
+
def main(cfg):
|
| 217 |
+
workspace = TrainRobomimicLowdimWorkspace(cfg)
|
| 218 |
+
workspace.run()
|
| 219 |
+
|
| 220 |
+
if __name__ == "__main__":
|
| 221 |
+
main()
|
third_party/diffusion_policy/eval.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
# use line-buffering for both stdout and stderr
|
| 8 |
+
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
|
| 9 |
+
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import pathlib
|
| 13 |
+
import click
|
| 14 |
+
import hydra
|
| 15 |
+
import torch
|
| 16 |
+
import dill
|
| 17 |
+
import wandb
|
| 18 |
+
import json
|
| 19 |
+
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
| 20 |
+
|
| 21 |
+
@click.command()
|
| 22 |
+
@click.option('-c', '--checkpoint', required=True)
|
| 23 |
+
@click.option('-o', '--output_dir', required=True)
|
| 24 |
+
@click.option('-d', '--device', default='cuda:0')
|
| 25 |
+
def main(checkpoint, output_dir, device):
|
| 26 |
+
if os.path.exists(output_dir):
|
| 27 |
+
click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
|
| 28 |
+
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
|
| 29 |
+
|
| 30 |
+
# load checkpoint
|
| 31 |
+
payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
|
| 32 |
+
cfg = payload['cfg']
|
| 33 |
+
cls = hydra.utils.get_class(cfg._target_)
|
| 34 |
+
workspace = cls(cfg, output_dir=output_dir)
|
| 35 |
+
workspace: BaseWorkspace
|
| 36 |
+
workspace.load_payload(payload, exclude_keys=None, include_keys=None)
|
| 37 |
+
|
| 38 |
+
# get policy from workspace
|
| 39 |
+
policy = workspace.model
|
| 40 |
+
if cfg.training.use_ema:
|
| 41 |
+
policy = workspace.ema_model
|
| 42 |
+
|
| 43 |
+
device = torch.device(device)
|
| 44 |
+
policy.to(device)
|
| 45 |
+
policy.eval()
|
| 46 |
+
|
| 47 |
+
# run eval
|
| 48 |
+
env_runner = hydra.utils.instantiate(
|
| 49 |
+
cfg.task.env_runner,
|
| 50 |
+
output_dir=output_dir)
|
| 51 |
+
runner_log = env_runner.run(policy)
|
| 52 |
+
|
| 53 |
+
# dump log to json
|
| 54 |
+
json_log = dict()
|
| 55 |
+
for key, value in runner_log.items():
|
| 56 |
+
if isinstance(value, wandb.sdk.data_types.video.Video):
|
| 57 |
+
json_log[key] = value._path
|
| 58 |
+
else:
|
| 59 |
+
json_log[key] = value
|
| 60 |
+
out_path = os.path.join(output_dir, 'eval_log.json')
|
| 61 |
+
json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)
|
| 62 |
+
|
| 63 |
+
if __name__ == '__main__':
|
| 64 |
+
main()
|
third_party/diffusion_policy/eval_real_robot.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
(robodiff)$ python eval_real_robot.py -i <ckpt_path> -o <save_dir> --robot_ip <ip_of_ur5>
|
| 4 |
+
|
| 5 |
+
================ Human in control ==============
|
| 6 |
+
Robot movement:
|
| 7 |
+
Move your SpaceMouse to move the robot EEF (locked in xy plane).
|
| 8 |
+
Press SpaceMouse right button to unlock z axis.
|
| 9 |
+
Press SpaceMouse left button to enable rotation axes.
|
| 10 |
+
|
| 11 |
+
Recording control:
|
| 12 |
+
Click the opencv window (make sure it's in focus).
|
| 13 |
+
Press "C" to start evaluation (hand control over to policy).
|
| 14 |
+
Press "Q" to exit program.
|
| 15 |
+
|
| 16 |
+
================ Policy in control ==============
|
| 17 |
+
Make sure you can hit the robot hardware emergency-stop button quickly!
|
| 18 |
+
|
| 19 |
+
Recording control:
|
| 20 |
+
Press "S" to stop evaluation and gain control back.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
# %%
|
| 24 |
+
import time
|
| 25 |
+
from multiprocessing.managers import SharedMemoryManager
|
| 26 |
+
import click
|
| 27 |
+
import cv2
|
| 28 |
+
import numpy as np
|
| 29 |
+
import torch
|
| 30 |
+
import dill
|
| 31 |
+
import hydra
|
| 32 |
+
import pathlib
|
| 33 |
+
import skvideo.io
|
| 34 |
+
from omegaconf import OmegaConf
|
| 35 |
+
import scipy.spatial.transform as st
|
| 36 |
+
from diffusion_policy.real_world.real_env import RealEnv
|
| 37 |
+
from diffusion_policy.real_world.spacemouse_shared_memory import Spacemouse
|
| 38 |
+
from diffusion_policy.common.precise_sleep import precise_wait
|
| 39 |
+
from diffusion_policy.real_world.real_inference_util import (
|
| 40 |
+
get_real_obs_resolution,
|
| 41 |
+
get_real_obs_dict)
|
| 42 |
+
from diffusion_policy.common.pytorch_util import dict_apply
|
| 43 |
+
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
| 44 |
+
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
| 45 |
+
from diffusion_policy.common.cv2_util import get_image_transform
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
| 49 |
+
|
| 50 |
+
@click.command()
|
| 51 |
+
@click.option('--input', '-i', required=True, help='Path to checkpoint')
|
| 52 |
+
@click.option('--output', '-o', required=True, help='Directory to save recording')
|
| 53 |
+
@click.option('--robot_ip', '-ri', required=True, help="UR5's IP address e.g. 192.168.0.204")
|
| 54 |
+
@click.option('--match_dataset', '-m', default=None, help='Dataset used to overlay and adjust initial condition')
|
| 55 |
+
@click.option('--match_episode', '-me', default=None, type=int, help='Match specific episode from the match dataset')
|
| 56 |
+
@click.option('--vis_camera_idx', default=0, type=int, help="Which RealSense camera to visualize.")
|
| 57 |
+
@click.option('--init_joints', '-j', is_flag=True, default=False, help="Whether to initialize robot joint configuration in the beginning.")
|
| 58 |
+
@click.option('--steps_per_inference', '-si', default=6, type=int, help="Action horizon for inference.")
|
| 59 |
+
@click.option('--max_duration', '-md', default=60, help='Max duration for each epoch in seconds.')
|
| 60 |
+
@click.option('--frequency', '-f', default=10, type=float, help="Control frequency in Hz.")
|
| 61 |
+
@click.option('--command_latency', '-cl', default=0.01, type=float, help="Latency between receiving SapceMouse command to executing on Robot in Sec.")
|
| 62 |
+
def main(input, output, robot_ip, match_dataset, match_episode,
|
| 63 |
+
vis_camera_idx, init_joints,
|
| 64 |
+
steps_per_inference, max_duration,
|
| 65 |
+
frequency, command_latency):
|
| 66 |
+
# load match_dataset
|
| 67 |
+
match_camera_idx = 0
|
| 68 |
+
episode_first_frame_map = dict()
|
| 69 |
+
if match_dataset is not None:
|
| 70 |
+
match_dir = pathlib.Path(match_dataset)
|
| 71 |
+
match_video_dir = match_dir.joinpath('videos')
|
| 72 |
+
for vid_dir in match_video_dir.glob("*/"):
|
| 73 |
+
episode_idx = int(vid_dir.stem)
|
| 74 |
+
match_video_path = vid_dir.joinpath(f'{match_camera_idx}.mp4')
|
| 75 |
+
if match_video_path.exists():
|
| 76 |
+
frames = skvideo.io.vread(
|
| 77 |
+
str(match_video_path), num_frames=1)
|
| 78 |
+
episode_first_frame_map[episode_idx] = frames[0]
|
| 79 |
+
print(f"Loaded initial frame for {len(episode_first_frame_map)} episodes")
|
| 80 |
+
|
| 81 |
+
# load checkpoint
|
| 82 |
+
ckpt_path = input
|
| 83 |
+
payload = torch.load(open(ckpt_path, 'rb'), pickle_module=dill)
|
| 84 |
+
cfg = payload['cfg']
|
| 85 |
+
cls = hydra.utils.get_class(cfg._target_)
|
| 86 |
+
workspace = cls(cfg)
|
| 87 |
+
workspace: BaseWorkspace
|
| 88 |
+
workspace.load_payload(payload, exclude_keys=None, include_keys=None)
|
| 89 |
+
|
| 90 |
+
# hacks for method-specific setup.
|
| 91 |
+
action_offset = 0
|
| 92 |
+
delta_action = False
|
| 93 |
+
if 'diffusion' in cfg.name:
|
| 94 |
+
# diffusion model
|
| 95 |
+
policy: BaseImagePolicy
|
| 96 |
+
policy = workspace.model
|
| 97 |
+
if cfg.training.use_ema:
|
| 98 |
+
policy = workspace.ema_model
|
| 99 |
+
|
| 100 |
+
device = torch.device('cuda')
|
| 101 |
+
policy.eval().to(device)
|
| 102 |
+
|
| 103 |
+
# set inference params
|
| 104 |
+
policy.num_inference_steps = 16 # DDIM inference iterations
|
| 105 |
+
policy.n_action_steps = policy.horizon - policy.n_obs_steps + 1
|
| 106 |
+
|
| 107 |
+
elif 'robomimic' in cfg.name:
|
| 108 |
+
# BCRNN model
|
| 109 |
+
policy: BaseImagePolicy
|
| 110 |
+
policy = workspace.model
|
| 111 |
+
|
| 112 |
+
device = torch.device('cuda')
|
| 113 |
+
policy.eval().to(device)
|
| 114 |
+
|
| 115 |
+
# BCRNN always has action horizon of 1
|
| 116 |
+
steps_per_inference = 1
|
| 117 |
+
action_offset = cfg.n_latency_steps
|
| 118 |
+
delta_action = cfg.task.dataset.get('delta_action', False)
|
| 119 |
+
|
| 120 |
+
elif 'ibc' in cfg.name:
|
| 121 |
+
policy: BaseImagePolicy
|
| 122 |
+
policy = workspace.model
|
| 123 |
+
policy.pred_n_iter = 5
|
| 124 |
+
policy.pred_n_samples = 4096
|
| 125 |
+
|
| 126 |
+
device = torch.device('cuda')
|
| 127 |
+
policy.eval().to(device)
|
| 128 |
+
steps_per_inference = 1
|
| 129 |
+
action_offset = 1
|
| 130 |
+
delta_action = cfg.task.dataset.get('delta_action', False)
|
| 131 |
+
else:
|
| 132 |
+
raise RuntimeError("Unsupported policy type: ", cfg.name)
|
| 133 |
+
|
| 134 |
+
# setup experiment
|
| 135 |
+
dt = 1/frequency
|
| 136 |
+
|
| 137 |
+
obs_res = get_real_obs_resolution(cfg.task.shape_meta)
|
| 138 |
+
n_obs_steps = cfg.n_obs_steps
|
| 139 |
+
print("n_obs_steps: ", n_obs_steps)
|
| 140 |
+
print("steps_per_inference:", steps_per_inference)
|
| 141 |
+
print("action_offset:", action_offset)
|
| 142 |
+
|
| 143 |
+
with SharedMemoryManager() as shm_manager:
|
| 144 |
+
with Spacemouse(shm_manager=shm_manager) as sm, RealEnv(
|
| 145 |
+
output_dir=output,
|
| 146 |
+
robot_ip=robot_ip,
|
| 147 |
+
frequency=frequency,
|
| 148 |
+
n_obs_steps=n_obs_steps,
|
| 149 |
+
obs_image_resolution=obs_res,
|
| 150 |
+
obs_float32=True,
|
| 151 |
+
init_joints=init_joints,
|
| 152 |
+
enable_multi_cam_vis=True,
|
| 153 |
+
record_raw_video=True,
|
| 154 |
+
# number of threads per camera view for video recording (H.264)
|
| 155 |
+
thread_per_video=3,
|
| 156 |
+
# video recording quality, lower is better (but slower).
|
| 157 |
+
video_crf=21,
|
| 158 |
+
shm_manager=shm_manager) as env:
|
| 159 |
+
cv2.setNumThreads(1)
|
| 160 |
+
|
| 161 |
+
# Should be the same as demo
|
| 162 |
+
# realsense exposure
|
| 163 |
+
env.realsense.set_exposure(exposure=120, gain=0)
|
| 164 |
+
# realsense white balance
|
| 165 |
+
env.realsense.set_white_balance(white_balance=5900)
|
| 166 |
+
|
| 167 |
+
print("Waiting for realsense")
|
| 168 |
+
time.sleep(1.0)
|
| 169 |
+
|
| 170 |
+
print("Warming up policy inference")
|
| 171 |
+
obs = env.get_obs()
|
| 172 |
+
with torch.no_grad():
|
| 173 |
+
policy.reset()
|
| 174 |
+
obs_dict_np = get_real_obs_dict(
|
| 175 |
+
env_obs=obs, shape_meta=cfg.task.shape_meta)
|
| 176 |
+
obs_dict = dict_apply(obs_dict_np,
|
| 177 |
+
lambda x: torch.from_numpy(x).unsqueeze(0).to(device))
|
| 178 |
+
result = policy.predict_action(obs_dict)
|
| 179 |
+
action = result['action'][0].detach().to('cpu').numpy()
|
| 180 |
+
assert action.shape[-1] == 2
|
| 181 |
+
del result
|
| 182 |
+
|
| 183 |
+
print('Ready!')
|
| 184 |
+
while True:
|
| 185 |
+
# ========= human control loop ==========
|
| 186 |
+
print("Human in control!")
|
| 187 |
+
state = env.get_robot_state()
|
| 188 |
+
target_pose = state['TargetTCPPose']
|
| 189 |
+
t_start = time.monotonic()
|
| 190 |
+
iter_idx = 0
|
| 191 |
+
while True:
|
| 192 |
+
# calculate timing
|
| 193 |
+
t_cycle_end = t_start + (iter_idx + 1) * dt
|
| 194 |
+
t_sample = t_cycle_end - command_latency
|
| 195 |
+
t_command_target = t_cycle_end + dt
|
| 196 |
+
|
| 197 |
+
# pump obs
|
| 198 |
+
obs = env.get_obs()
|
| 199 |
+
|
| 200 |
+
# visualize
|
| 201 |
+
episode_id = env.replay_buffer.n_episodes
|
| 202 |
+
vis_img = obs[f'camera_{vis_camera_idx}'][-1]
|
| 203 |
+
match_episode_id = episode_id
|
| 204 |
+
if match_episode is not None:
|
| 205 |
+
match_episode_id = match_episode
|
| 206 |
+
if match_episode_id in episode_first_frame_map:
|
| 207 |
+
match_img = episode_first_frame_map[match_episode_id]
|
| 208 |
+
ih, iw, _ = match_img.shape
|
| 209 |
+
oh, ow, _ = vis_img.shape
|
| 210 |
+
tf = get_image_transform(
|
| 211 |
+
input_res=(iw, ih),
|
| 212 |
+
output_res=(ow, oh),
|
| 213 |
+
bgr_to_rgb=False)
|
| 214 |
+
match_img = tf(match_img).astype(np.float32) / 255
|
| 215 |
+
vis_img = np.minimum(vis_img, match_img)
|
| 216 |
+
|
| 217 |
+
text = f'Episode: {episode_id}'
|
| 218 |
+
cv2.putText(
|
| 219 |
+
vis_img,
|
| 220 |
+
text,
|
| 221 |
+
(10,20),
|
| 222 |
+
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
|
| 223 |
+
fontScale=0.5,
|
| 224 |
+
thickness=1,
|
| 225 |
+
color=(255,255,255)
|
| 226 |
+
)
|
| 227 |
+
cv2.imshow('default', vis_img[...,::-1])
|
| 228 |
+
key_stroke = cv2.pollKey()
|
| 229 |
+
if key_stroke == ord('q'):
|
| 230 |
+
# Exit program
|
| 231 |
+
env.end_episode()
|
| 232 |
+
exit(0)
|
| 233 |
+
elif key_stroke == ord('c'):
|
| 234 |
+
# Exit human control loop
|
| 235 |
+
# hand control over to the policy
|
| 236 |
+
break
|
| 237 |
+
|
| 238 |
+
precise_wait(t_sample)
|
| 239 |
+
# get teleop command
|
| 240 |
+
sm_state = sm.get_motion_state_transformed()
|
| 241 |
+
# print(sm_state)
|
| 242 |
+
dpos = sm_state[:3] * (env.max_pos_speed / frequency)
|
| 243 |
+
drot_xyz = sm_state[3:] * (env.max_rot_speed / frequency)
|
| 244 |
+
|
| 245 |
+
if not sm.is_button_pressed(0):
|
| 246 |
+
# translation mode
|
| 247 |
+
drot_xyz[:] = 0
|
| 248 |
+
else:
|
| 249 |
+
dpos[:] = 0
|
| 250 |
+
if not sm.is_button_pressed(1):
|
| 251 |
+
# 2D translation mode
|
| 252 |
+
dpos[2] = 0
|
| 253 |
+
|
| 254 |
+
drot = st.Rotation.from_euler('xyz', drot_xyz)
|
| 255 |
+
target_pose[:3] += dpos
|
| 256 |
+
target_pose[3:] = (drot * st.Rotation.from_rotvec(
|
| 257 |
+
target_pose[3:])).as_rotvec()
|
| 258 |
+
# clip target pose
|
| 259 |
+
target_pose[:2] = np.clip(target_pose[:2], [0.25, -0.45], [0.77, 0.40])
|
| 260 |
+
|
| 261 |
+
# execute teleop command
|
| 262 |
+
env.exec_actions(
|
| 263 |
+
actions=[target_pose],
|
| 264 |
+
timestamps=[t_command_target-time.monotonic()+time.time()])
|
| 265 |
+
precise_wait(t_cycle_end)
|
| 266 |
+
iter_idx += 1
|
| 267 |
+
|
| 268 |
+
# ========== policy control loop ==============
|
| 269 |
+
try:
|
| 270 |
+
# start episode
|
| 271 |
+
policy.reset()
|
| 272 |
+
start_delay = 1.0
|
| 273 |
+
eval_t_start = time.time() + start_delay
|
| 274 |
+
t_start = time.monotonic() + start_delay
|
| 275 |
+
env.start_episode(eval_t_start)
|
| 276 |
+
# wait for 1/30 sec to get the closest frame actually
|
| 277 |
+
# reduces overall latency
|
| 278 |
+
frame_latency = 1/30
|
| 279 |
+
precise_wait(eval_t_start - frame_latency, time_func=time.time)
|
| 280 |
+
print("Started!")
|
| 281 |
+
iter_idx = 0
|
| 282 |
+
term_area_start_timestamp = float('inf')
|
| 283 |
+
perv_target_pose = None
|
| 284 |
+
while True:
|
| 285 |
+
# calculate timing
|
| 286 |
+
t_cycle_end = t_start + (iter_idx + steps_per_inference) * dt
|
| 287 |
+
|
| 288 |
+
# get obs
|
| 289 |
+
print('get_obs')
|
| 290 |
+
obs = env.get_obs()
|
| 291 |
+
obs_timestamps = obs['timestamp']
|
| 292 |
+
print(f'Obs latency {time.time() - obs_timestamps[-1]}')
|
| 293 |
+
|
| 294 |
+
# run inference
|
| 295 |
+
with torch.no_grad():
|
| 296 |
+
s = time.time()
|
| 297 |
+
obs_dict_np = get_real_obs_dict(
|
| 298 |
+
env_obs=obs, shape_meta=cfg.task.shape_meta)
|
| 299 |
+
obs_dict = dict_apply(obs_dict_np,
|
| 300 |
+
lambda x: torch.from_numpy(x).unsqueeze(0).to(device))
|
| 301 |
+
result = policy.predict_action(obs_dict)
|
| 302 |
+
# this action starts from the first obs step
|
| 303 |
+
action = result['action'][0].detach().to('cpu').numpy()
|
| 304 |
+
print('Inference latency:', time.time() - s)
|
| 305 |
+
|
| 306 |
+
# convert policy action to env actions
|
| 307 |
+
if delta_action:
|
| 308 |
+
assert len(action) == 1
|
| 309 |
+
if perv_target_pose is None:
|
| 310 |
+
perv_target_pose = obs['robot_eef_pose'][-1]
|
| 311 |
+
this_target_pose = perv_target_pose.copy()
|
| 312 |
+
this_target_pose[[0,1]] += action[-1]
|
| 313 |
+
perv_target_pose = this_target_pose
|
| 314 |
+
this_target_poses = np.expand_dims(this_target_pose, axis=0)
|
| 315 |
+
else:
|
| 316 |
+
this_target_poses = np.zeros((len(action), len(target_pose)), dtype=np.float64)
|
| 317 |
+
this_target_poses[:] = target_pose
|
| 318 |
+
this_target_poses[:,[0,1]] = action
|
| 319 |
+
|
| 320 |
+
# deal with timing
|
| 321 |
+
# the same step actions are always the target for
|
| 322 |
+
action_timestamps = (np.arange(len(action), dtype=np.float64) + action_offset
|
| 323 |
+
) * dt + obs_timestamps[-1]
|
| 324 |
+
action_exec_latency = 0.01
|
| 325 |
+
curr_time = time.time()
|
| 326 |
+
is_new = action_timestamps > (curr_time + action_exec_latency)
|
| 327 |
+
if np.sum(is_new) == 0:
|
| 328 |
+
# exceeded time budget, still do something
|
| 329 |
+
this_target_poses = this_target_poses[[-1]]
|
| 330 |
+
# schedule on next available step
|
| 331 |
+
next_step_idx = int(np.ceil((curr_time - eval_t_start) / dt))
|
| 332 |
+
action_timestamp = eval_t_start + (next_step_idx) * dt
|
| 333 |
+
print('Over budget', action_timestamp - curr_time)
|
| 334 |
+
action_timestamps = np.array([action_timestamp])
|
| 335 |
+
else:
|
| 336 |
+
this_target_poses = this_target_poses[is_new]
|
| 337 |
+
action_timestamps = action_timestamps[is_new]
|
| 338 |
+
|
| 339 |
+
# clip actions
|
| 340 |
+
this_target_poses[:,:2] = np.clip(
|
| 341 |
+
this_target_poses[:,:2], [0.25, -0.45], [0.77, 0.40])
|
| 342 |
+
|
| 343 |
+
# execute actions
|
| 344 |
+
env.exec_actions(
|
| 345 |
+
actions=this_target_poses,
|
| 346 |
+
timestamps=action_timestamps
|
| 347 |
+
)
|
| 348 |
+
print(f"Submitted {len(this_target_poses)} steps of actions.")
|
| 349 |
+
|
| 350 |
+
# visualize
|
| 351 |
+
episode_id = env.replay_buffer.n_episodes
|
| 352 |
+
vis_img = obs[f'camera_{vis_camera_idx}'][-1]
|
| 353 |
+
text = 'Episode: {}, Time: {:.1f}'.format(
|
| 354 |
+
episode_id, time.monotonic() - t_start
|
| 355 |
+
)
|
| 356 |
+
cv2.putText(
|
| 357 |
+
vis_img,
|
| 358 |
+
text,
|
| 359 |
+
(10,20),
|
| 360 |
+
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
|
| 361 |
+
fontScale=0.5,
|
| 362 |
+
thickness=1,
|
| 363 |
+
color=(255,255,255)
|
| 364 |
+
)
|
| 365 |
+
cv2.imshow('default', vis_img[...,::-1])
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
key_stroke = cv2.pollKey()
|
| 369 |
+
if key_stroke == ord('s'):
|
| 370 |
+
# Stop episode
|
| 371 |
+
# Hand control back to human
|
| 372 |
+
env.end_episode()
|
| 373 |
+
print('Stopped.')
|
| 374 |
+
break
|
| 375 |
+
|
| 376 |
+
# auto termination
|
| 377 |
+
terminate = False
|
| 378 |
+
if time.monotonic() - t_start > max_duration:
|
| 379 |
+
terminate = True
|
| 380 |
+
print('Terminated by the timeout!')
|
| 381 |
+
|
| 382 |
+
term_pose = np.array([ 3.40948500e-01, 2.17721816e-01, 4.59076878e-02, 2.22014183e+00, -2.22184883e+00, -4.07186655e-04])
|
| 383 |
+
curr_pose = obs['robot_eef_pose'][-1]
|
| 384 |
+
dist = np.linalg.norm((curr_pose - term_pose)[:2], axis=-1)
|
| 385 |
+
if dist < 0.03:
|
| 386 |
+
# in termination area
|
| 387 |
+
curr_timestamp = obs['timestamp'][-1]
|
| 388 |
+
if term_area_start_timestamp > curr_timestamp:
|
| 389 |
+
term_area_start_timestamp = curr_timestamp
|
| 390 |
+
else:
|
| 391 |
+
term_area_time = curr_timestamp - term_area_start_timestamp
|
| 392 |
+
if term_area_time > 0.5:
|
| 393 |
+
terminate = True
|
| 394 |
+
print('Terminated by the policy!')
|
| 395 |
+
else:
|
| 396 |
+
# out of the area
|
| 397 |
+
term_area_start_timestamp = float('inf')
|
| 398 |
+
|
| 399 |
+
if terminate:
|
| 400 |
+
env.end_episode()
|
| 401 |
+
break
|
| 402 |
+
|
| 403 |
+
# wait for execution
|
| 404 |
+
precise_wait(t_cycle_end - frame_latency)
|
| 405 |
+
iter_idx += steps_per_inference
|
| 406 |
+
|
| 407 |
+
except KeyboardInterrupt:
|
| 408 |
+
print("Interrupted!")
|
| 409 |
+
# stop robot.
|
| 410 |
+
env.end_episode()
|
| 411 |
+
|
| 412 |
+
print("Stopped.")
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
# %%
|
| 417 |
+
if __name__ == '__main__':
|
| 418 |
+
main()
|
third_party/diffusion_policy/multirun_metrics.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
import pathlib
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import numba
|
| 6 |
+
import click
|
| 7 |
+
import time
|
| 8 |
+
import collections
|
| 9 |
+
import json
|
| 10 |
+
import wandb
|
| 11 |
+
import yaml
|
| 12 |
+
import numbers
|
| 13 |
+
import scipy.ndimage as sn
|
| 14 |
+
from diffusion_policy.common.json_logger import read_json_log, JsonLogger
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
@numba.jit(nopython=True)
|
| 18 |
+
def get_indexed_window_average(
|
| 19 |
+
arr: np.ndarray, idxs: np.ndarray, window_size: int):
|
| 20 |
+
result = np.zeros(idxs.shape, dtype=arr.dtype)
|
| 21 |
+
length = arr.shape[0]
|
| 22 |
+
for i in range(len(idxs)):
|
| 23 |
+
idx = idxs[i]
|
| 24 |
+
start = max(idx - window_size, 0)
|
| 25 |
+
end = min(start + window_size, length)
|
| 26 |
+
result[i] = np.mean(arr[start:end])
|
| 27 |
+
return result
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def compute_metrics(log_df: pd.DataFrame, key: str,
|
| 31 |
+
end_step: Optional[int]=None,
|
| 32 |
+
k_min_loss: int=10,
|
| 33 |
+
k_around_max: int=10,
|
| 34 |
+
max_k_window: int=10,
|
| 35 |
+
replace_slash: int=True,
|
| 36 |
+
):
|
| 37 |
+
if key not in log_df:
|
| 38 |
+
return dict()
|
| 39 |
+
|
| 40 |
+
# prepare data
|
| 41 |
+
if end_step is not None:
|
| 42 |
+
log_df = log_df.iloc[:end_step]
|
| 43 |
+
is_key = ~pd.isnull(log_df[key])
|
| 44 |
+
is_key_idxs = is_key.index[is_key].to_numpy()
|
| 45 |
+
if len(is_key_idxs) == 0:
|
| 46 |
+
return dict()
|
| 47 |
+
|
| 48 |
+
key_data = log_df[key][is_key].to_numpy()
|
| 49 |
+
# after adding validation to workspace
|
| 50 |
+
# rollout happens at the last step of each epoch
|
| 51 |
+
# where the reported train_loss and val_loss
|
| 52 |
+
# are already the average for that epoch
|
| 53 |
+
train_loss = log_df['train_loss'][is_key].to_numpy()
|
| 54 |
+
val_loss = log_df['val_loss'][is_key].to_numpy()
|
| 55 |
+
|
| 56 |
+
result = dict()
|
| 57 |
+
|
| 58 |
+
log_key = key
|
| 59 |
+
if replace_slash:
|
| 60 |
+
log_key = key.replace('/', '_')
|
| 61 |
+
# max
|
| 62 |
+
max_value = np.max(key_data)
|
| 63 |
+
result['max/'+log_key] = max_value
|
| 64 |
+
|
| 65 |
+
# k_around_max
|
| 66 |
+
max_idx = np.argmax(key_data)
|
| 67 |
+
end = min(max_idx + k_around_max // 2, len(key_data))
|
| 68 |
+
start = max(end - k_around_max, 0)
|
| 69 |
+
k_around_max_value = np.mean(key_data[start:end])
|
| 70 |
+
result['k_around_max/'+log_key] = k_around_max_value
|
| 71 |
+
|
| 72 |
+
# max_k_window
|
| 73 |
+
k_window_value = sn.uniform_filter1d(key_data, size=max_k_window, axis=0, mode='nearest')
|
| 74 |
+
max_k_window_value = np.max(k_window_value)
|
| 75 |
+
result['max_k_window/'+log_key] = max_k_window_value
|
| 76 |
+
|
| 77 |
+
# min_train_loss
|
| 78 |
+
min_idx = np.argmin(train_loss)
|
| 79 |
+
min_train_loss_value = key_data[min_idx]
|
| 80 |
+
result['min_train_loss/'+log_key] = min_train_loss_value
|
| 81 |
+
|
| 82 |
+
# min_val_loss
|
| 83 |
+
min_idx = np.argmin(val_loss)
|
| 84 |
+
min_val_loss_value = key_data[min_idx]
|
| 85 |
+
result['min_val_loss/'+log_key] = min_val_loss_value
|
| 86 |
+
|
| 87 |
+
# k_min_train_loss
|
| 88 |
+
min_loss_idxs = np.argsort(train_loss)[:k_min_loss]
|
| 89 |
+
k_min_train_loss_value = np.mean(key_data[min_loss_idxs])
|
| 90 |
+
result['k_min_train_loss/'+log_key] = k_min_train_loss_value
|
| 91 |
+
|
| 92 |
+
# k_min_val_loss
|
| 93 |
+
min_loss_idxs = np.argsort(val_loss)[:k_min_loss]
|
| 94 |
+
k_min_val_loss_value = np.mean(key_data[min_loss_idxs])
|
| 95 |
+
result['k_min_val_loss/'+log_key] = k_min_val_loss_value
|
| 96 |
+
|
| 97 |
+
# last
|
| 98 |
+
result['last/'+log_key] = key_data[-1]
|
| 99 |
+
|
| 100 |
+
# global step for visualization
|
| 101 |
+
result['metric_global_step/'+log_key] = is_key_idxs[-1]
|
| 102 |
+
return result
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def compute_metrics_agg(
|
| 106 |
+
log_dfs: List[pd.DataFrame],
|
| 107 |
+
key: str, end_step:int,
|
| 108 |
+
**kwargs):
|
| 109 |
+
|
| 110 |
+
# compute metrics
|
| 111 |
+
results = collections.defaultdict(list)
|
| 112 |
+
for log_df in log_dfs:
|
| 113 |
+
result = compute_metrics(log_df, key=key, end_step=end_step, **kwargs)
|
| 114 |
+
for k, v in result.items():
|
| 115 |
+
results[k].append(v)
|
| 116 |
+
# agg
|
| 117 |
+
agg_result = dict()
|
| 118 |
+
for k, v in results.items():
|
| 119 |
+
value = np.mean(v)
|
| 120 |
+
if k.startswith('metric_global_step'):
|
| 121 |
+
value = int(value)
|
| 122 |
+
agg_result[k] = value
|
| 123 |
+
return agg_result
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@click.command()
|
| 127 |
+
@click.option('--input', '-i', required=True, help='Root logging dir, contains train_* dirs')
|
| 128 |
+
@click.option('--key', '-k', multiple=True, default=['test/mean_score'])
|
| 129 |
+
@click.option('--interval', default=10, type=float)
|
| 130 |
+
@click.option('--replace_slash', default=True, type=bool)
|
| 131 |
+
@click.option('--index_key', '-ik', default='global_step')
|
| 132 |
+
@click.option('--use_wandb', '-w', is_flag=True, default=False)
|
| 133 |
+
@click.option('--project', default=None)
|
| 134 |
+
@click.option('--name', default=None)
|
| 135 |
+
@click.option('--id', default=None)
|
| 136 |
+
@click.option('--group', default=None)
|
| 137 |
+
def main(
|
| 138 |
+
input,
|
| 139 |
+
key,
|
| 140 |
+
interval,
|
| 141 |
+
replace_slash,
|
| 142 |
+
index_key,
|
| 143 |
+
use_wandb,
|
| 144 |
+
# wandb args
|
| 145 |
+
project,
|
| 146 |
+
name,
|
| 147 |
+
id,
|
| 148 |
+
group):
|
| 149 |
+
root_dir = pathlib.Path(input)
|
| 150 |
+
assert root_dir.is_dir()
|
| 151 |
+
metrics_dir = root_dir.joinpath('metrics')
|
| 152 |
+
metrics_dir.mkdir(exist_ok=True)
|
| 153 |
+
|
| 154 |
+
logging.basicConfig(
|
| 155 |
+
level=logging.INFO,
|
| 156 |
+
format="%(asctime)s [%(levelname)s] %(message)s",
|
| 157 |
+
handlers=[
|
| 158 |
+
logging.FileHandler(str(metrics_dir.joinpath("metrics.log"))),
|
| 159 |
+
logging.StreamHandler()
|
| 160 |
+
]
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
train_dirs = list(root_dir.glob('train_*'))
|
| 164 |
+
log_files = [x.joinpath('logs.json.txt') for x in train_dirs]
|
| 165 |
+
logging.info("Monitor waiting for log files!")
|
| 166 |
+
while True:
|
| 167 |
+
# wait for files to show up
|
| 168 |
+
files_exist = True
|
| 169 |
+
for log_file in log_files:
|
| 170 |
+
if not log_file.is_file():
|
| 171 |
+
files_exist = False
|
| 172 |
+
if files_exist:
|
| 173 |
+
break
|
| 174 |
+
time.sleep(1.0)
|
| 175 |
+
logging.info("All log files ready!")
|
| 176 |
+
|
| 177 |
+
# init path
|
| 178 |
+
metric_log_path = metrics_dir.joinpath('logs.json.txt')
|
| 179 |
+
metric_path = metrics_dir.joinpath('metrics.json')
|
| 180 |
+
config_path = root_dir.joinpath('config.yaml')
|
| 181 |
+
|
| 182 |
+
# load config
|
| 183 |
+
config = yaml.safe_load(config_path.open('r'))
|
| 184 |
+
|
| 185 |
+
# init wandb
|
| 186 |
+
wandb_run = None
|
| 187 |
+
if use_wandb:
|
| 188 |
+
wandb_kwargs = config['logging']
|
| 189 |
+
if project is not None:
|
| 190 |
+
wandb_kwargs['project'] = project
|
| 191 |
+
if id is not None:
|
| 192 |
+
wandb_kwargs['id'] = id
|
| 193 |
+
if name is not None:
|
| 194 |
+
wandb_kwargs['name'] = name
|
| 195 |
+
if group is not None:
|
| 196 |
+
wandb_kwargs['group'] = group
|
| 197 |
+
wandb_kwargs['resume'] = True
|
| 198 |
+
wandb_run = wandb.init(
|
| 199 |
+
dir=str(metrics_dir),
|
| 200 |
+
config=config,
|
| 201 |
+
# auto-resume run, automatically load id
|
| 202 |
+
# as long as using the same dir.
|
| 203 |
+
# https://docs.wandb.ai/guides/track/advanced/resuming#resuming-guidance
|
| 204 |
+
**wandb_kwargs
|
| 205 |
+
)
|
| 206 |
+
wandb.config.update(
|
| 207 |
+
{
|
| 208 |
+
"output_dir": str(root_dir),
|
| 209 |
+
}
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
with JsonLogger(metric_log_path) as json_logger:
|
| 213 |
+
last_log = json_logger.get_last_log()
|
| 214 |
+
while True:
|
| 215 |
+
# read json files
|
| 216 |
+
log_dfs = [read_json_log(str(x), required_keys=key) for x in log_files]
|
| 217 |
+
|
| 218 |
+
# previously logged data point
|
| 219 |
+
last_log_idx = -1
|
| 220 |
+
if last_log is not None:
|
| 221 |
+
last_log_idx = log_dfs[0].index[log_dfs[0][index_key] <= last_log[index_key]][-1]
|
| 222 |
+
|
| 223 |
+
start_idx = last_log_idx + 1
|
| 224 |
+
# last idx where we have a data point from all logs
|
| 225 |
+
end_idx = min(*[len(x) for x in log_dfs])
|
| 226 |
+
|
| 227 |
+
# log every position
|
| 228 |
+
for this_idx in range(start_idx, end_idx):
|
| 229 |
+
# compute metrics
|
| 230 |
+
all_metrics = dict()
|
| 231 |
+
global_step = log_dfs[0]['global_step'][this_idx]
|
| 232 |
+
epoch = log_dfs[0]['epoch'][this_idx]
|
| 233 |
+
all_metrics['global_step'] = global_step
|
| 234 |
+
all_metrics['epoch'] = epoch
|
| 235 |
+
for k in key:
|
| 236 |
+
metrics = compute_metrics_agg(
|
| 237 |
+
log_dfs=log_dfs, key=k, end_step=this_idx+1,
|
| 238 |
+
replace_slash=replace_slash)
|
| 239 |
+
all_metrics.update(metrics)
|
| 240 |
+
|
| 241 |
+
# sanitize metrics
|
| 242 |
+
old_metrics = all_metrics
|
| 243 |
+
all_metrics = dict()
|
| 244 |
+
for k, v in old_metrics.items():
|
| 245 |
+
if isinstance(v, numbers.Integral):
|
| 246 |
+
all_metrics[k] = int(v)
|
| 247 |
+
elif isinstance(v, numbers.Number):
|
| 248 |
+
all_metrics[k] = float(v)
|
| 249 |
+
|
| 250 |
+
has_update = all_metrics != last_log
|
| 251 |
+
if has_update:
|
| 252 |
+
last_log = all_metrics
|
| 253 |
+
json_logger.log(all_metrics)
|
| 254 |
+
|
| 255 |
+
with metric_path.open('w') as f:
|
| 256 |
+
json.dump(all_metrics, f, sort_keys=True, indent=2)
|
| 257 |
+
|
| 258 |
+
if wandb_run is not None:
|
| 259 |
+
wandb_run.log(all_metrics, step=all_metrics[index_key])
|
| 260 |
+
|
| 261 |
+
logging.info(f"Metrics logged at step {all_metrics[index_key]}")
|
| 262 |
+
|
| 263 |
+
time.sleep(interval)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
if __name__ == "__main__":
|
| 267 |
+
main()
|
third_party/diffusion_policy/pyrightconfig.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"exclude": [
|
| 3 |
+
"data/**",
|
| 4 |
+
"data_local/**",
|
| 5 |
+
"outputs/**"
|
| 6 |
+
]
|
| 7 |
+
}
|
third_party/diffusion_policy/ray_exec.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
Training:
|
| 4 |
+
python train.py --config-name=train_diffusion_lowdim_workspace -- logger.mode=online
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
import ray
|
| 8 |
+
import click
|
| 9 |
+
|
| 10 |
+
def worker_fn(command_args, data_src=None, unbuffer_python=False, use_shell=False):
|
| 11 |
+
import os
|
| 12 |
+
import subprocess
|
| 13 |
+
import signal
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
# setup data symlink
|
| 17 |
+
if data_src is not None:
|
| 18 |
+
cwd = os.getcwd()
|
| 19 |
+
src = os.path.expanduser(data_src)
|
| 20 |
+
dst = os.path.join(cwd, 'data')
|
| 21 |
+
try:
|
| 22 |
+
os.symlink(src=src, dst=dst)
|
| 23 |
+
except FileExistsError:
|
| 24 |
+
# it's fine if it already exists
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
# run command
|
| 28 |
+
process_env = os.environ.copy()
|
| 29 |
+
if unbuffer_python:
|
| 30 |
+
# disable stdout/stderr buffering for subprocess (if python)
|
| 31 |
+
# to remove latency between print statement and receiving printed result
|
| 32 |
+
process_env['PYTHONUNBUFFERED'] = 'TRUE'
|
| 33 |
+
|
| 34 |
+
# ray worker masks out Ctrl-C signal (ie SIGINT)
|
| 35 |
+
# here we unblock this signal for the child process
|
| 36 |
+
def preexec_function():
|
| 37 |
+
import signal
|
| 38 |
+
signal.pthread_sigmask(signal.SIG_UNBLOCK, {signal.SIGINT})
|
| 39 |
+
|
| 40 |
+
if use_shell:
|
| 41 |
+
command_args = ' '.join(command_args)
|
| 42 |
+
|
| 43 |
+
# stdout passthrough to ray worker, which is then passed to ray driver
|
| 44 |
+
process = subprocess.Popen(
|
| 45 |
+
args=command_args,
|
| 46 |
+
env=process_env,
|
| 47 |
+
preexec_fn=preexec_function,
|
| 48 |
+
shell=use_shell)
|
| 49 |
+
|
| 50 |
+
while process.poll() is None:
|
| 51 |
+
try:
|
| 52 |
+
# sleep to ensure that monitor thread can acquire gil
|
| 53 |
+
# and raise KeyboardInterrupt here.
|
| 54 |
+
time.sleep(0.01)
|
| 55 |
+
except KeyboardInterrupt:
|
| 56 |
+
process.send_signal(signal.SIGINT)
|
| 57 |
+
print('SIGINT sent to subprocess')
|
| 58 |
+
except Exception as e:
|
| 59 |
+
process.terminate()
|
| 60 |
+
raise e
|
| 61 |
+
|
| 62 |
+
if process.returncode not in (0, -2):
|
| 63 |
+
print("Failed execution!")
|
| 64 |
+
raise RuntimeError("Failed execution.")
|
| 65 |
+
return process.returncode
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@click.command()
|
| 69 |
+
@click.option('--ray_address', '-ra', default='auto')
|
| 70 |
+
@click.option('--num_cpus', '-nc', default=7, type=float)
|
| 71 |
+
@click.option('--num_gpus', '-ng', default=1, type=float)
|
| 72 |
+
@click.option('--max_retries', '-mr', default=0, type=int)
|
| 73 |
+
@click.option('--data_src', '-d', default='./data', type=str)
|
| 74 |
+
@click.option('--unbuffer_python', '-u', is_flag=True, default=False)
|
| 75 |
+
@click.argument('command_args', nargs=-1, type=str)
|
| 76 |
+
def main(ray_address,
|
| 77 |
+
num_cpus, num_gpus, max_retries,
|
| 78 |
+
data_src, unbuffer_python,
|
| 79 |
+
command_args):
|
| 80 |
+
# expand path
|
| 81 |
+
if data_src is not None:
|
| 82 |
+
data_src = os.path.abspath(os.path.expanduser(data_src))
|
| 83 |
+
|
| 84 |
+
# init ray
|
| 85 |
+
root_dir = os.path.dirname(__file__)
|
| 86 |
+
runtime_env = {
|
| 87 |
+
'working_dir': root_dir,
|
| 88 |
+
'excludes': ['.git']
|
| 89 |
+
}
|
| 90 |
+
ray.init(
|
| 91 |
+
address=ray_address,
|
| 92 |
+
runtime_env=runtime_env
|
| 93 |
+
)
|
| 94 |
+
# remote worker func
|
| 95 |
+
worker_ray = ray.remote(worker_fn).options(
|
| 96 |
+
num_cpus=num_cpus,
|
| 97 |
+
num_gpus=num_gpus,
|
| 98 |
+
max_retries=max_retries,
|
| 99 |
+
# resources=resources,
|
| 100 |
+
retry_exceptions=True
|
| 101 |
+
)
|
| 102 |
+
# run
|
| 103 |
+
task_ref = worker_ray.remote(command_args, data_src, unbuffer_python)
|
| 104 |
+
|
| 105 |
+
try:
|
| 106 |
+
# normal case
|
| 107 |
+
result = ray.get(task_ref)
|
| 108 |
+
print('Return code: ', result)
|
| 109 |
+
except KeyboardInterrupt:
|
| 110 |
+
# a KeyboardInterrupt will be raised in worker
|
| 111 |
+
ray.cancel(task_ref, force=False)
|
| 112 |
+
result = ray.get(task_ref)
|
| 113 |
+
print('Return code: ', result)
|
| 114 |
+
except Exception as e:
|
| 115 |
+
# worker will be terminated
|
| 116 |
+
ray.cancel(task_ref, force=True)
|
| 117 |
+
raise e
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
if __name__ == '__main__':
|
| 121 |
+
main()
|
third_party/diffusion_policy/ray_train_multirun.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Start local ray cluster
|
| 3 |
+
(robodiff)$ export CUDA_VISIBLE_DEVICES=0,1,2 # select GPUs to be managed by the ray cluster
|
| 4 |
+
(robodiff)$ ray start --head --num-gpus=3
|
| 5 |
+
|
| 6 |
+
Training:
|
| 7 |
+
python ray_train_multirun.py --config-name=train_diffusion_unet_lowdim_workspace --seeds=42,43,44 --monitor_key=test/mean_score -- logger.mode=online training.eval_first=True
|
| 8 |
+
"""
|
| 9 |
+
import os
|
| 10 |
+
import ray
|
| 11 |
+
import click
|
| 12 |
+
import hydra
|
| 13 |
+
import yaml
|
| 14 |
+
import wandb
|
| 15 |
+
import pathlib
|
| 16 |
+
import collections
|
| 17 |
+
from pprint import pprint
|
| 18 |
+
from omegaconf import OmegaConf
|
| 19 |
+
from ray_exec import worker_fn
|
| 20 |
+
from ray.util.placement_group import (
|
| 21 |
+
placement_group,
|
| 22 |
+
)
|
| 23 |
+
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
| 24 |
+
|
| 25 |
+
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
| 26 |
+
|
| 27 |
+
@click.command()
|
| 28 |
+
@click.option('--config-name', '-cn', required=True, type=str)
|
| 29 |
+
@click.option('--config-dir', '-cd', default=None, type=str)
|
| 30 |
+
@click.option('--seeds', '-s', default='42,43,44', type=str)
|
| 31 |
+
@click.option('--monitor_key', '-k', multiple=True, default=['test/mean_score'])
|
| 32 |
+
@click.option('--ray_address', '-ra', default='auto')
|
| 33 |
+
@click.option('--num_cpus', '-nc', default=7, type=float)
|
| 34 |
+
@click.option('--num_gpus', '-ng', default=1, type=float)
|
| 35 |
+
@click.option('--max_retries', '-mr', default=0, type=int)
|
| 36 |
+
@click.option('--monitor_max_retires', default=3, type=int)
|
| 37 |
+
@click.option('--data_src', '-d', default='./data', type=str)
|
| 38 |
+
@click.option('--unbuffer_python', '-u', is_flag=True, default=False)
|
| 39 |
+
@click.option('--single_node', '-sn', is_flag=True, default=False, help='run all experiments on a single machine')
|
| 40 |
+
@click.argument('command_args', nargs=-1, type=str)
|
| 41 |
+
def main(config_name, config_dir, seeds, monitor_key, ray_address,
|
| 42 |
+
num_cpus, num_gpus, max_retries, monitor_max_retires,
|
| 43 |
+
data_src, unbuffer_python,
|
| 44 |
+
single_node, command_args):
|
| 45 |
+
# parse args
|
| 46 |
+
seeds = [int(x) for x in seeds.split(',')]
|
| 47 |
+
# expand path
|
| 48 |
+
if data_src is not None:
|
| 49 |
+
data_src = os.path.abspath(os.path.expanduser(data_src))
|
| 50 |
+
|
| 51 |
+
# initialize hydra
|
| 52 |
+
if config_dir is None:
|
| 53 |
+
config_path_abs = pathlib.Path(__file__).parent.joinpath(
|
| 54 |
+
'diffusion_policy','config')
|
| 55 |
+
config_path_rel = str(config_path_abs.relative_to(pathlib.Path.cwd()))
|
| 56 |
+
else:
|
| 57 |
+
config_path_rel = config_dir
|
| 58 |
+
|
| 59 |
+
run_command_args = list()
|
| 60 |
+
monitor_command_args = list()
|
| 61 |
+
with hydra.initialize(
|
| 62 |
+
version_base=None,
|
| 63 |
+
config_path=config_path_rel):
|
| 64 |
+
|
| 65 |
+
# generate raw config
|
| 66 |
+
cfg = hydra.compose(
|
| 67 |
+
config_name=config_name,
|
| 68 |
+
overrides=command_args)
|
| 69 |
+
OmegaConf.resolve(cfg)
|
| 70 |
+
|
| 71 |
+
# manually create output dir
|
| 72 |
+
output_dir = pathlib.Path(cfg.multi_run.run_dir)
|
| 73 |
+
output_dir.mkdir(parents=True, exist_ok=False)
|
| 74 |
+
config_path = output_dir.joinpath('config.yaml')
|
| 75 |
+
print(output_dir)
|
| 76 |
+
|
| 77 |
+
# save current config
|
| 78 |
+
yaml.dump(OmegaConf.to_container(cfg, resolve=True),
|
| 79 |
+
config_path.open('w'), default_flow_style=False)
|
| 80 |
+
|
| 81 |
+
# wandb
|
| 82 |
+
wandb_group_id = wandb.util.generate_id()
|
| 83 |
+
name_base = cfg.multi_run.wandb_name_base
|
| 84 |
+
|
| 85 |
+
# create monitor command args
|
| 86 |
+
monitor_command_args = [
|
| 87 |
+
'python',
|
| 88 |
+
'multirun_metrics.py',
|
| 89 |
+
'--input', str(output_dir),
|
| 90 |
+
'--use_wandb',
|
| 91 |
+
'--project', 'diffusion_policy_metrics',
|
| 92 |
+
'--group', wandb_group_id
|
| 93 |
+
]
|
| 94 |
+
for k in monitor_key:
|
| 95 |
+
monitor_command_args.extend([
|
| 96 |
+
'--key', k
|
| 97 |
+
])
|
| 98 |
+
|
| 99 |
+
# generate command args
|
| 100 |
+
run_command_args = list()
|
| 101 |
+
for i, seed in enumerate(seeds):
|
| 102 |
+
test_start_seed = (seed + 1) * 100000
|
| 103 |
+
this_output_dir = output_dir.joinpath(f'train_{i}')
|
| 104 |
+
this_output_dir.mkdir()
|
| 105 |
+
wandb_name = name_base + f'_train_{i}'
|
| 106 |
+
wandb_run_id = wandb_group_id + f'_train_{i}'
|
| 107 |
+
|
| 108 |
+
this_command_args = [
|
| 109 |
+
'python',
|
| 110 |
+
'train.py',
|
| 111 |
+
'--config-name='+config_name,
|
| 112 |
+
'--config-dir='+config_path_rel
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
this_command_args.extend(command_args)
|
| 116 |
+
this_command_args.extend([
|
| 117 |
+
f'training.seed={seed}',
|
| 118 |
+
f'task.env_runner.test_start_seed={test_start_seed}',
|
| 119 |
+
f'logging.name={wandb_name}',
|
| 120 |
+
f'logging.id={wandb_run_id}',
|
| 121 |
+
f'logging.group={wandb_group_id}',
|
| 122 |
+
f'hydra.run.dir={this_output_dir}'
|
| 123 |
+
])
|
| 124 |
+
run_command_args.append(this_command_args)
|
| 125 |
+
|
| 126 |
+
# init ray
|
| 127 |
+
root_dir = os.path.dirname(__file__)
|
| 128 |
+
runtime_env = {
|
| 129 |
+
'working_dir': root_dir,
|
| 130 |
+
'excludes': ['.git'],
|
| 131 |
+
'pip': ['dm-control==1.0.9']
|
| 132 |
+
}
|
| 133 |
+
ray.init(
|
| 134 |
+
address=ray_address,
|
| 135 |
+
runtime_env=runtime_env
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# create resources for train
|
| 139 |
+
train_resources = dict()
|
| 140 |
+
|
| 141 |
+
train_bundle = dict(train_resources)
|
| 142 |
+
train_bundle['CPU'] = num_cpus
|
| 143 |
+
train_bundle['GPU'] = num_gpus
|
| 144 |
+
|
| 145 |
+
# create resources for monitor
|
| 146 |
+
monitor_resources = dict()
|
| 147 |
+
monitor_resources['CPU'] = 1
|
| 148 |
+
|
| 149 |
+
monitor_bundle = dict(monitor_resources)
|
| 150 |
+
|
| 151 |
+
# aggregate bundle
|
| 152 |
+
bundle = collections.defaultdict(lambda:0)
|
| 153 |
+
n_train_bundles = 1
|
| 154 |
+
if single_node:
|
| 155 |
+
n_train_bundles = len(seeds)
|
| 156 |
+
for _ in range(n_train_bundles):
|
| 157 |
+
for k, v in train_bundle.items():
|
| 158 |
+
bundle[k] += v
|
| 159 |
+
for k, v in monitor_bundle.items():
|
| 160 |
+
bundle[k] += v
|
| 161 |
+
bundle = dict(bundle)
|
| 162 |
+
|
| 163 |
+
# create placement group
|
| 164 |
+
print("Creating placement group with resources:")
|
| 165 |
+
pprint(bundle)
|
| 166 |
+
pg = placement_group([bundle])
|
| 167 |
+
|
| 168 |
+
# run
|
| 169 |
+
task_name_map = dict()
|
| 170 |
+
task_refs = list()
|
| 171 |
+
for i, this_command_args in enumerate(run_command_args):
|
| 172 |
+
if single_node or i == (len(run_command_args) - 1):
|
| 173 |
+
print(f'Training worker {i} with placement group.')
|
| 174 |
+
ray.get(pg.ready())
|
| 175 |
+
print("Placement Group created!")
|
| 176 |
+
worker_ray = ray.remote(worker_fn).options(
|
| 177 |
+
num_cpus=num_cpus,
|
| 178 |
+
num_gpus=num_gpus,
|
| 179 |
+
max_retries=max_retries,
|
| 180 |
+
resources=train_resources,
|
| 181 |
+
retry_exceptions=True,
|
| 182 |
+
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
| 183 |
+
placement_group=pg)
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
print(f'Training worker {i} without placement group.')
|
| 187 |
+
worker_ray = ray.remote(worker_fn).options(
|
| 188 |
+
num_cpus=num_cpus,
|
| 189 |
+
num_gpus=num_gpus,
|
| 190 |
+
max_retries=max_retries,
|
| 191 |
+
resources=train_resources,
|
| 192 |
+
retry_exceptions=True,
|
| 193 |
+
)
|
| 194 |
+
task_ref = worker_ray.remote(
|
| 195 |
+
this_command_args, data_src, unbuffer_python)
|
| 196 |
+
task_refs.append(task_ref)
|
| 197 |
+
task_name_map[task_ref] = f'train_{i}'
|
| 198 |
+
|
| 199 |
+
# monitor worker is always packed on the same node
|
| 200 |
+
# as training worker 0
|
| 201 |
+
ray.get(pg.ready())
|
| 202 |
+
monitor_worker_ray = ray.remote(worker_fn).options(
|
| 203 |
+
num_cpus=1,
|
| 204 |
+
num_gpus=0,
|
| 205 |
+
max_retries=monitor_max_retires,
|
| 206 |
+
# resources=monitor_resources,
|
| 207 |
+
retry_exceptions=True,
|
| 208 |
+
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
| 209 |
+
placement_group=pg)
|
| 210 |
+
)
|
| 211 |
+
monitor_ref = monitor_worker_ray.remote(
|
| 212 |
+
monitor_command_args, data_src, unbuffer_python)
|
| 213 |
+
task_name_map[monitor_ref] = 'metrics'
|
| 214 |
+
|
| 215 |
+
try:
|
| 216 |
+
# normal case
|
| 217 |
+
ready_refs = list()
|
| 218 |
+
rest_refs = task_refs
|
| 219 |
+
while len(ready_refs) < len(task_refs):
|
| 220 |
+
this_ready_refs, rest_refs = ray.wait(rest_refs,
|
| 221 |
+
num_returns=1, timeout=None, fetch_local=True)
|
| 222 |
+
cancel_other_tasks = False
|
| 223 |
+
for ref in this_ready_refs:
|
| 224 |
+
task_name = task_name_map[ref]
|
| 225 |
+
try:
|
| 226 |
+
result = ray.get(ref)
|
| 227 |
+
print(f"Task {task_name} finished with result: {result}")
|
| 228 |
+
except KeyboardInterrupt as e:
|
| 229 |
+
# skip to outer try catch
|
| 230 |
+
raise KeyboardInterrupt
|
| 231 |
+
except Exception as e:
|
| 232 |
+
print(f"Task {task_name} raised exception: {e}")
|
| 233 |
+
this_cancel_other_tasks = True
|
| 234 |
+
if isinstance(e, ray.exceptions.RayTaskError):
|
| 235 |
+
if isinstance(e.cause, ray.exceptions.TaskCancelledError):
|
| 236 |
+
this_cancel_other_tasks = False
|
| 237 |
+
cancel_other_tasks = cancel_other_tasks or this_cancel_other_tasks
|
| 238 |
+
ready_refs.append(ref)
|
| 239 |
+
if cancel_other_tasks:
|
| 240 |
+
print('Exception! Cancelling all other tasks.')
|
| 241 |
+
# cancel all other refs
|
| 242 |
+
for _ref in rest_refs:
|
| 243 |
+
ray.cancel(_ref, force=False)
|
| 244 |
+
print("Training tasks done.")
|
| 245 |
+
ray.cancel(monitor_ref, force=False)
|
| 246 |
+
except KeyboardInterrupt:
|
| 247 |
+
print('KeyboardInterrupt received in the driver.')
|
| 248 |
+
# a KeyboardInterrupt will be raised in worker
|
| 249 |
+
_ = [ray.cancel(x, force=False) for x in task_refs + [monitor_ref]]
|
| 250 |
+
print('KeyboardInterrupt sent to workers.')
|
| 251 |
+
except Exception as e:
|
| 252 |
+
# worker will be terminated
|
| 253 |
+
_ = [ray.cancel(x, force=True) for x in task_refs + [monitor_ref]]
|
| 254 |
+
raise e
|
| 255 |
+
|
| 256 |
+
for ref in task_refs + [monitor_ref]:
|
| 257 |
+
task_name = task_name_map[ref]
|
| 258 |
+
try:
|
| 259 |
+
result = ray.get(ref)
|
| 260 |
+
print(f"Task {task_name} finished with result: {result}")
|
| 261 |
+
except KeyboardInterrupt as e:
|
| 262 |
+
# force kill everything.
|
| 263 |
+
print("Force killing all workers")
|
| 264 |
+
_ = [ray.cancel(x, force=True) for x in task_refs]
|
| 265 |
+
ray.cancel(monitor_ref, force=True)
|
| 266 |
+
except Exception as e:
|
| 267 |
+
print(f"Task {task_name} raised exception: {e}")
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
if __name__ == "__main__":
|
| 271 |
+
main()
|
third_party/diffusion_policy/setup.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages
|
| 2 |
+
|
| 3 |
+
setup(
|
| 4 |
+
name = 'diffusion_policy',
|
| 5 |
+
packages = find_packages(),
|
| 6 |
+
)
|
third_party/diffusion_policy/tests/test_block_pushing.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
from diffusion_policy.env.block_pushing.block_pushing_multimodal import BlockPushMultimodal
|
| 9 |
+
from gym.wrappers import FlattenObservation
|
| 10 |
+
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
| 11 |
+
from diffusion_policy.gym_util.video_wrapper import VideoWrapper
|
| 12 |
+
|
| 13 |
+
def test():
|
| 14 |
+
env = MultiStepWrapper(
|
| 15 |
+
VideoWrapper(
|
| 16 |
+
FlattenObservation(
|
| 17 |
+
BlockPushMultimodal()
|
| 18 |
+
),
|
| 19 |
+
enabled=True,
|
| 20 |
+
steps_per_render=2
|
| 21 |
+
),
|
| 22 |
+
n_obs_steps=2,
|
| 23 |
+
n_action_steps=8,
|
| 24 |
+
max_episode_steps=16
|
| 25 |
+
)
|
| 26 |
+
env = BlockPushMultimodal()
|
| 27 |
+
obs = env.reset()
|
| 28 |
+
import pdb; pdb.set_trace()
|
| 29 |
+
|
| 30 |
+
env = FlattenObservation(BlockPushMultimodal())
|
| 31 |
+
obs = env.reset()
|
| 32 |
+
action = env.action_space.sample()
|
| 33 |
+
next_obs, reward, done, info = env.step(action)
|
| 34 |
+
print(obs[8:10] + action - next_obs[8:10])
|
| 35 |
+
import pdb; pdb.set_trace()
|
| 36 |
+
|
| 37 |
+
for i in range(3):
|
| 38 |
+
obs, reward, done, info = env.step(env.action_space.sample())
|
| 39 |
+
img = env.render()
|
| 40 |
+
import pdb; pdb.set_trace()
|
| 41 |
+
print("Done!", done)
|
| 42 |
+
|
| 43 |
+
if __name__ == '__main__':
|
| 44 |
+
test()
|
third_party/diffusion_policy/tests/test_cv2_util.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from diffusion_policy.common.cv2_util import get_image_transform
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test():
|
| 13 |
+
tf = get_image_transform((1280,720), (640,480), bgr_to_rgb=False)
|
| 14 |
+
in_img = np.zeros((720,1280,3), dtype=np.uint8)
|
| 15 |
+
out_img = tf(in_img)
|
| 16 |
+
# print(out_img.shape)
|
| 17 |
+
assert out_img.shape == (480,640,3)
|
| 18 |
+
|
| 19 |
+
# import pdb; pdb.set_trace()
|
| 20 |
+
|
| 21 |
+
if __name__ == '__main__':
|
| 22 |
+
test()
|
third_party/diffusion_policy/tests/test_multi_realsense.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
import numpy as np
|
| 12 |
+
from diffusion_policy.real_world.multi_realsense import MultiRealsense
|
| 13 |
+
from diffusion_policy.real_world.video_recorder import VideoRecorder
|
| 14 |
+
|
| 15 |
+
def test():
|
| 16 |
+
config = json.load(open('/home/cchi/dev/diffusion_policy/diffusion_policy/real_world/realsense_config/415_high_accuracy_mode.json', 'r'))
|
| 17 |
+
|
| 18 |
+
def transform(data):
|
| 19 |
+
color = data['color']
|
| 20 |
+
h,w,_ = color.shape
|
| 21 |
+
factor = 4
|
| 22 |
+
color = cv2.resize(color, (w//factor,h//factor), interpolation=cv2.INTER_AREA)
|
| 23 |
+
# color = color[:,140:500]
|
| 24 |
+
data['color'] = color
|
| 25 |
+
return data
|
| 26 |
+
|
| 27 |
+
from diffusion_policy.common.cv2_util import get_image_transform
|
| 28 |
+
color_transform = get_image_transform(
|
| 29 |
+
input_res=(1280,720),
|
| 30 |
+
output_res=(640,480),
|
| 31 |
+
bgr_to_rgb=False)
|
| 32 |
+
def transform(data):
|
| 33 |
+
data['color'] = color_transform(data['color'])
|
| 34 |
+
return data
|
| 35 |
+
|
| 36 |
+
# one thread per camera
|
| 37 |
+
video_recorder = VideoRecorder.create_h264(
|
| 38 |
+
fps=30,
|
| 39 |
+
codec='h264',
|
| 40 |
+
thread_type='FRAME'
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
with MultiRealsense(
|
| 44 |
+
resolution=(1280,720),
|
| 45 |
+
capture_fps=30,
|
| 46 |
+
record_fps=15,
|
| 47 |
+
enable_color=True,
|
| 48 |
+
# advanced_mode_config=config,
|
| 49 |
+
transform=transform,
|
| 50 |
+
# recording_transform=transform,
|
| 51 |
+
# video_recorder=video_recorder,
|
| 52 |
+
verbose=True
|
| 53 |
+
) as realsense:
|
| 54 |
+
realsense.set_exposure(exposure=150, gain=5)
|
| 55 |
+
intr = realsense.get_intrinsics()
|
| 56 |
+
print(intr)
|
| 57 |
+
|
| 58 |
+
video_path = 'data_local/test'
|
| 59 |
+
rec_start_time = time.time() + 1
|
| 60 |
+
realsense.start_recording(video_path, start_time=rec_start_time)
|
| 61 |
+
realsense.restart_put(rec_start_time)
|
| 62 |
+
|
| 63 |
+
out = None
|
| 64 |
+
vis_img = None
|
| 65 |
+
while True:
|
| 66 |
+
out = realsense.get(out=out)
|
| 67 |
+
|
| 68 |
+
# bgr = out['color']
|
| 69 |
+
# print(bgr.shape)
|
| 70 |
+
# vis_img = np.concatenate(list(bgr), axis=0, out=vis_img)
|
| 71 |
+
# cv2.imshow('default', vis_img)
|
| 72 |
+
# key = cv2.pollKey()
|
| 73 |
+
# if key == ord('q'):
|
| 74 |
+
# break
|
| 75 |
+
|
| 76 |
+
time.sleep(1/60)
|
| 77 |
+
if time.time() > (rec_start_time + 20.0):
|
| 78 |
+
break
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
if __name__ == "__main__":
|
| 82 |
+
test()
|
third_party/diffusion_policy/tests/test_pose_trajectory_interpolator.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import numpy as np
|
| 10 |
+
import scipy.interpolate as si
|
| 11 |
+
import scipy.spatial.transform as st
|
| 12 |
+
from diffusion_policy.common.pose_trajectory_interpolator import (
|
| 13 |
+
rotation_distance,
|
| 14 |
+
pose_distance,
|
| 15 |
+
PoseTrajectoryInterpolator)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_rotation_distance():
|
| 19 |
+
def rotation_distance_align(a: st.Rotation, b: st.Rotation) -> float:
|
| 20 |
+
return st.Rotation.align_vectors(b.as_matrix().T, a.as_matrix().T)[0].magnitude()
|
| 21 |
+
|
| 22 |
+
for i in range(10000):
|
| 23 |
+
a = st.Rotation.from_euler('xyz', np.random.uniform(-7,7,size=3))
|
| 24 |
+
b = st.Rotation.from_euler('xyz', np.random.uniform(-7,7,size=3))
|
| 25 |
+
x = rotation_distance(a, b)
|
| 26 |
+
y = rotation_distance_align(a, b)
|
| 27 |
+
assert abs(x-y) < 1e-7
|
| 28 |
+
|
| 29 |
+
def test_pose_trajectory_interpolator():
|
| 30 |
+
t = np.linspace(-1,5,100)
|
| 31 |
+
interp = PoseTrajectoryInterpolator(
|
| 32 |
+
[0,1,3],
|
| 33 |
+
np.zeros((3,6))
|
| 34 |
+
)
|
| 35 |
+
times = interp.times
|
| 36 |
+
poses = interp.poses
|
| 37 |
+
|
| 38 |
+
trimmed_interp = interp.trim(-1,4)
|
| 39 |
+
assert len(trimmed_interp.times) == 5
|
| 40 |
+
trimmed_interp(t)
|
| 41 |
+
|
| 42 |
+
trimmed_interp = interp.trim(-1,4)
|
| 43 |
+
assert len(trimmed_interp.times) == 5
|
| 44 |
+
trimmed_interp(t)
|
| 45 |
+
|
| 46 |
+
trimmed_interp = interp.trim(0.5, 3.5)
|
| 47 |
+
assert len(trimmed_interp.times) == 4
|
| 48 |
+
trimmed_interp(t)
|
| 49 |
+
|
| 50 |
+
trimmed_interp = interp.trim(0.5, 2.5)
|
| 51 |
+
assert len(trimmed_interp.times) == 3
|
| 52 |
+
trimmed_interp(t)
|
| 53 |
+
|
| 54 |
+
trimmed_interp = interp.trim(0.5, 1.5)
|
| 55 |
+
assert len(trimmed_interp.times) == 3
|
| 56 |
+
trimmed_interp(t)
|
| 57 |
+
|
| 58 |
+
trimmed_interp = interp.trim(1.2, 1.5)
|
| 59 |
+
assert len(trimmed_interp.times) == 2
|
| 60 |
+
trimmed_interp(t)
|
| 61 |
+
|
| 62 |
+
trimmed_interp = interp.trim(1.3, 1.3)
|
| 63 |
+
assert len(trimmed_interp.times) == 1
|
| 64 |
+
trimmed_interp(t)
|
| 65 |
+
|
| 66 |
+
# import pdb; pdb.set_trace()
|
| 67 |
+
|
| 68 |
+
def test_add_waypoint():
|
| 69 |
+
# fuzz testing
|
| 70 |
+
for i in tqdm(range(10000)):
|
| 71 |
+
rng = np.random.default_rng(i)
|
| 72 |
+
n_waypoints = rng.integers(1, 5)
|
| 73 |
+
waypoint_times = np.sort(rng.uniform(0, 1, size=n_waypoints))
|
| 74 |
+
last_waypoint_time = waypoint_times[-1]
|
| 75 |
+
insert_time = rng.uniform(-0.1, 1.1)
|
| 76 |
+
curr_time = rng.uniform(-0.1, 1.1)
|
| 77 |
+
max_pos_speed = rng.poisson(3) + 1e-3
|
| 78 |
+
max_rot_speed = rng.poisson(3) + 1e-3
|
| 79 |
+
waypoint_poses = rng.normal(0, 3, size=(n_waypoints, 6))
|
| 80 |
+
new_pose = rng.normal(0, 3, size=6)
|
| 81 |
+
|
| 82 |
+
if rng.random() < 0.1:
|
| 83 |
+
last_waypoint_time = None
|
| 84 |
+
if rng.random() < 0.1:
|
| 85 |
+
curr_time = None
|
| 86 |
+
|
| 87 |
+
interp = PoseTrajectoryInterpolator(
|
| 88 |
+
times=waypoint_times,
|
| 89 |
+
poses=waypoint_poses)
|
| 90 |
+
new_interp = interp.add_waypoint(
|
| 91 |
+
pose=new_pose,
|
| 92 |
+
time=insert_time,
|
| 93 |
+
max_pos_speed=max_pos_speed,
|
| 94 |
+
max_rot_speed=max_rot_speed,
|
| 95 |
+
curr_time=curr_time,
|
| 96 |
+
last_waypoint_time=last_waypoint_time
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def test_drive_to_waypoint():
|
| 100 |
+
# fuzz testing
|
| 101 |
+
for i in tqdm(range(10000)):
|
| 102 |
+
rng = np.random.default_rng(i)
|
| 103 |
+
n_waypoints = rng.integers(1, 5)
|
| 104 |
+
waypoint_times = np.sort(rng.uniform(0, 1, size=n_waypoints))
|
| 105 |
+
insert_time = rng.uniform(-0.1, 1.1)
|
| 106 |
+
curr_time = rng.uniform(-0.1, 1.1)
|
| 107 |
+
max_pos_speed = rng.poisson(3) + 1e-3
|
| 108 |
+
max_rot_speed = rng.poisson(3) + 1e-3
|
| 109 |
+
waypoint_poses = rng.normal(0, 3, size=(n_waypoints, 6))
|
| 110 |
+
new_pose = rng.normal(0, 3, size=6)
|
| 111 |
+
|
| 112 |
+
interp = PoseTrajectoryInterpolator(
|
| 113 |
+
times=waypoint_times,
|
| 114 |
+
poses=waypoint_poses)
|
| 115 |
+
new_interp = interp.drive_to_waypoint(
|
| 116 |
+
pose=new_pose,
|
| 117 |
+
time=insert_time,
|
| 118 |
+
curr_time=curr_time,
|
| 119 |
+
max_pos_speed=max_pos_speed,
|
| 120 |
+
max_rot_speed=max_rot_speed
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if __name__ == '__main__':
|
| 126 |
+
test_drive_to_waypoint()
|
third_party/diffusion_policy/tests/test_precise_sleep.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
import numpy as np
|
| 10 |
+
from diffusion_policy.common.precise_sleep import precise_sleep, precise_wait
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_sleep():
|
| 14 |
+
dt = 0.1
|
| 15 |
+
tol = 1e-3
|
| 16 |
+
time_samples = list()
|
| 17 |
+
for i in range(100):
|
| 18 |
+
precise_sleep(dt)
|
| 19 |
+
# time.sleep(dt)
|
| 20 |
+
time_samples.append(time.monotonic())
|
| 21 |
+
time_deltas = np.diff(time_samples)
|
| 22 |
+
|
| 23 |
+
from matplotlib import pyplot as plt
|
| 24 |
+
plt.plot(time_deltas)
|
| 25 |
+
plt.ylim((dt-tol,dt+tol))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def test_wait():
|
| 29 |
+
dt = 0.1
|
| 30 |
+
tol = 1e-3
|
| 31 |
+
errors = list()
|
| 32 |
+
t_start = time.monotonic()
|
| 33 |
+
for i in range(1,100):
|
| 34 |
+
t_end_desired = t_start + i * dt
|
| 35 |
+
time.sleep(t_end_desired - time.monotonic())
|
| 36 |
+
t_end = time.monotonic()
|
| 37 |
+
errors.append(t_end - t_end_desired)
|
| 38 |
+
|
| 39 |
+
new_errors = list()
|
| 40 |
+
t_start = time.monotonic()
|
| 41 |
+
for i in range(1,100):
|
| 42 |
+
t_end_desired = t_start + i * dt
|
| 43 |
+
precise_wait(t_end_desired)
|
| 44 |
+
t_end = time.monotonic()
|
| 45 |
+
new_errors.append(t_end - t_end_desired)
|
| 46 |
+
|
| 47 |
+
from matplotlib import pyplot as plt
|
| 48 |
+
plt.plot(errors, label='time.sleep')
|
| 49 |
+
plt.plot(new_errors, label='sleep/spin hybrid')
|
| 50 |
+
plt.ylim((-tol,+tol))
|
| 51 |
+
plt.title('0.1 sec sleep error')
|
| 52 |
+
plt.legend()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == '__main__':
|
| 56 |
+
test_sleep()
|
third_party/diffusion_policy/tests/test_replay_buffer.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
import zarr
|
| 9 |
+
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
| 10 |
+
|
| 11 |
+
def test():
|
| 12 |
+
import numpy as np
|
| 13 |
+
buff = ReplayBuffer.create_empty_numpy()
|
| 14 |
+
buff.add_episode({
|
| 15 |
+
'obs': np.zeros((100,10), dtype=np.float16)
|
| 16 |
+
})
|
| 17 |
+
buff.add_episode({
|
| 18 |
+
'obs': np.ones((50,10)),
|
| 19 |
+
'action': np.ones((50,2))
|
| 20 |
+
})
|
| 21 |
+
# buff.rechunk(256)
|
| 22 |
+
obs = buff.get_episode(0)
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
buff = ReplayBuffer.create_empty_zarr()
|
| 26 |
+
buff.add_episode({
|
| 27 |
+
'obs': np.zeros((100,10), dtype=np.float16)
|
| 28 |
+
})
|
| 29 |
+
buff.add_episode({
|
| 30 |
+
'obs': np.ones((50,10)),
|
| 31 |
+
'action': np.ones((50,2))
|
| 32 |
+
})
|
| 33 |
+
obs = buff.get_episode(0)
|
| 34 |
+
buff.set_chunks({
|
| 35 |
+
'obs': (100,10),
|
| 36 |
+
'action': (100,2)
|
| 37 |
+
})
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_real():
|
| 41 |
+
import os
|
| 42 |
+
dist_group = zarr.open(
|
| 43 |
+
os.path.expanduser('~/dev/diffusion_policy/data/pusht/pusht_cchi_v2.zarr'), 'r')
|
| 44 |
+
|
| 45 |
+
buff = ReplayBuffer.create_empty_numpy()
|
| 46 |
+
key, group = next(iter(dist_group.items()))
|
| 47 |
+
for key, group in dist_group.items():
|
| 48 |
+
buff.add_episode(group)
|
| 49 |
+
|
| 50 |
+
# out_path = os.path.expanduser('~/dev/diffusion_policy/data/pusht_cchi2_v2_replay.zarr')
|
| 51 |
+
out_path = os.path.expanduser('~/dev/diffusion_policy/data/test.zarr')
|
| 52 |
+
out_store = zarr.DirectoryStore(out_path)
|
| 53 |
+
buff.save_to_store(out_store)
|
| 54 |
+
|
| 55 |
+
buff = ReplayBuffer.copy_from_path(out_path, store=zarr.MemoryStore())
|
| 56 |
+
buff.pop_episode()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def test_pop():
|
| 60 |
+
buff = ReplayBuffer.create_from_path(
|
| 61 |
+
'/home/chengchi/dev/diffusion_policy/data/pusht_cchi_v3_replay.zarr',
|
| 62 |
+
mode='rw')
|
third_party/diffusion_policy/tests/test_ring_buffer.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
import numpy as np
|
| 10 |
+
import multiprocessing as mp
|
| 11 |
+
from multiprocessing.managers import SharedMemoryManager
|
| 12 |
+
from diffusion_policy.shared_memory.shared_memory_ring_buffer import (
|
| 13 |
+
SharedMemoryRingBuffer,
|
| 14 |
+
SharedAtomicCounter)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def test():
|
| 18 |
+
shm_manager = SharedMemoryManager()
|
| 19 |
+
shm_manager.start()
|
| 20 |
+
ring_buffer = SharedMemoryRingBuffer.create_from_examples(
|
| 21 |
+
shm_manager,
|
| 22 |
+
{'timestamp': np.array(0, dtype=np.float64)},
|
| 23 |
+
buffer_size=128
|
| 24 |
+
)
|
| 25 |
+
for i in range(30):
|
| 26 |
+
ring_buffer.put({
|
| 27 |
+
'timestamp': np.array(
|
| 28 |
+
time.perf_counter(),
|
| 29 |
+
dtype=np.float64)
|
| 30 |
+
})
|
| 31 |
+
print(ring_buffer.get())
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _timestamp_worker(ring_buffer, start_event, stop_event):
|
| 35 |
+
while not stop_event.is_set():
|
| 36 |
+
start_event.set()
|
| 37 |
+
ring_buffer.put({
|
| 38 |
+
'timestamp': np.array(
|
| 39 |
+
time.time(),
|
| 40 |
+
dtype=np.float64)
|
| 41 |
+
})
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def test_mp():
|
| 45 |
+
shm_manager = SharedMemoryManager()
|
| 46 |
+
shm_manager.start()
|
| 47 |
+
ring_buffer = SharedMemoryRingBuffer.create_from_examples(
|
| 48 |
+
shm_manager,
|
| 49 |
+
{'timestamp': np.array(0, dtype=np.float64)},
|
| 50 |
+
get_max_k=1,
|
| 51 |
+
get_time_budget=0.01,
|
| 52 |
+
put_desired_frequency=1000
|
| 53 |
+
)
|
| 54 |
+
start_event = mp.Event()
|
| 55 |
+
stop_event = mp.Event()
|
| 56 |
+
worker = mp.Process(target=_timestamp_worker, args=(
|
| 57 |
+
ring_buffer, start_event, stop_event))
|
| 58 |
+
worker.start()
|
| 59 |
+
start_event.wait()
|
| 60 |
+
for i in range(1000):
|
| 61 |
+
t = float(ring_buffer.get()['timestamp'])
|
| 62 |
+
curr_t = time.time()
|
| 63 |
+
print('latency', curr_t - t)
|
| 64 |
+
stop_event.set()
|
| 65 |
+
worker.join()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def test_get_last_k():
|
| 69 |
+
shm_manager = SharedMemoryManager()
|
| 70 |
+
shm_manager.start()
|
| 71 |
+
ring_buffer = SharedMemoryRingBuffer.create_from_examples(
|
| 72 |
+
shm_manager,
|
| 73 |
+
{'counter': np.array(0, dtype=np.int64)},
|
| 74 |
+
buffer_size=8
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
from collections import deque
|
| 78 |
+
k = 4
|
| 79 |
+
last_k = deque(maxlen=k)
|
| 80 |
+
for i in range(100):
|
| 81 |
+
ring_buffer.put({
|
| 82 |
+
'counter': np.array(i, dtype=np.int64)
|
| 83 |
+
})
|
| 84 |
+
last_k.append(i)
|
| 85 |
+
if i > k:
|
| 86 |
+
result = ring_buffer.get_last_k(k)['counter']
|
| 87 |
+
assert np.allclose(result, last_k)
|
| 88 |
+
|
| 89 |
+
print(ring_buffer.shared_arrays['counter'].get())
|
| 90 |
+
result = ring_buffer.get_last_k(4)
|
| 91 |
+
print(result)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def test_timing():
|
| 95 |
+
shm_manager = SharedMemoryManager()
|
| 96 |
+
shm_manager.start()
|
| 97 |
+
ring_buffer = SharedMemoryRingBuffer.create_from_examples(
|
| 98 |
+
shm_manager,
|
| 99 |
+
{'counter': np.array(0, dtype=np.int64)},
|
| 100 |
+
get_max_k=8,
|
| 101 |
+
get_time_budget=0.1,
|
| 102 |
+
put_desired_frequency=100
|
| 103 |
+
)
|
| 104 |
+
# print(ring_buffer.timestamp_array.get())
|
| 105 |
+
print('buffer_size', ring_buffer.buffer_size)
|
| 106 |
+
|
| 107 |
+
dt = 1 / 150
|
| 108 |
+
t_init = time.monotonic()
|
| 109 |
+
for i in range(1000):
|
| 110 |
+
t_start = time.monotonic()
|
| 111 |
+
ring_buffer.put({
|
| 112 |
+
'counter': np.array(i, dtype=np.int64)
|
| 113 |
+
}, wait=False)
|
| 114 |
+
if (i % 10 == 0) and (i > 0):
|
| 115 |
+
result = ring_buffer.get_last_k(8)
|
| 116 |
+
|
| 117 |
+
t_end =time.monotonic()
|
| 118 |
+
desired_t = (i+1) * dt + t_init
|
| 119 |
+
if desired_t > t_end:
|
| 120 |
+
time.sleep(desired_t - t_end)
|
| 121 |
+
hz = 1 / (time.monotonic() - t_start)
|
| 122 |
+
print(f'{hz}Hz')
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _timestamp_image_worker(ring_buffer, img_shape, dt, start_event, stop_event):
|
| 126 |
+
i = 0
|
| 127 |
+
t_init = time.monotonic()
|
| 128 |
+
image = np.ones(img_shape, dtype=np.uint8)
|
| 129 |
+
while not stop_event.is_set():
|
| 130 |
+
t_start = time.monotonic()
|
| 131 |
+
start_event.set()
|
| 132 |
+
ring_buffer.put({
|
| 133 |
+
'img': image,
|
| 134 |
+
'timestamp': time.time(),
|
| 135 |
+
'counter': i
|
| 136 |
+
})
|
| 137 |
+
t_end = time.monotonic()
|
| 138 |
+
desired_t = (i+1) * dt + t_init
|
| 139 |
+
# print('alive')
|
| 140 |
+
if desired_t > t_end:
|
| 141 |
+
time.sleep(desired_t - t_end)
|
| 142 |
+
# hz = 1 / (time.monotonic() - t_start)
|
| 143 |
+
i += 1
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def test_timing_mp():
|
| 147 |
+
shm_manager = SharedMemoryManager()
|
| 148 |
+
shm_manager.start()
|
| 149 |
+
|
| 150 |
+
hz = 200
|
| 151 |
+
img_shape = (1920,1080,3)
|
| 152 |
+
ring_buffer = SharedMemoryRingBuffer.create_from_examples(
|
| 153 |
+
shm_manager,
|
| 154 |
+
examples={
|
| 155 |
+
'img': np.zeros(img_shape, dtype=np.uint8),
|
| 156 |
+
'timestamp': time.time(),
|
| 157 |
+
'counter': 0
|
| 158 |
+
},
|
| 159 |
+
get_max_k=60,
|
| 160 |
+
get_time_budget=0.02,
|
| 161 |
+
put_desired_frequency=hz
|
| 162 |
+
)
|
| 163 |
+
start_event = mp.Event()
|
| 164 |
+
stop_event = mp.Event()
|
| 165 |
+
worker = mp.Process(target=_timestamp_image_worker, args=(
|
| 166 |
+
ring_buffer, img_shape, 1/hz, start_event, stop_event))
|
| 167 |
+
worker.start()
|
| 168 |
+
start_event.wait()
|
| 169 |
+
out = None
|
| 170 |
+
t_start = time.monotonic()
|
| 171 |
+
k = 1
|
| 172 |
+
for i in range(1000):
|
| 173 |
+
if ring_buffer.count < k:
|
| 174 |
+
time.sleep(0)
|
| 175 |
+
continue
|
| 176 |
+
out = ring_buffer.get_last_k(k=k, out=out)
|
| 177 |
+
t = float(out['timestamp'][-1])
|
| 178 |
+
curr_t = time.time()
|
| 179 |
+
print('latency', curr_t - t)
|
| 180 |
+
t_end = time.monotonic()
|
| 181 |
+
print('Get Hz', 1/(t_end-t_start)*1000)
|
| 182 |
+
stop_event.set()
|
| 183 |
+
worker.join()
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
if __name__ == '__main__':
|
| 187 |
+
# test_mp()
|
| 188 |
+
test_timing_mp()
|
third_party/diffusion_policy/tests/test_robomimic_image_runner.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
from diffusion_policy.env_runner.robomimic_image_runner import RobomimicImageRunner
|
| 9 |
+
|
| 10 |
+
def test():
|
| 11 |
+
import os
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml')
|
| 14 |
+
cfg = OmegaConf.load(cfg_path)
|
| 15 |
+
cfg['n_obs_steps'] = 1
|
| 16 |
+
cfg['n_action_steps'] = 1
|
| 17 |
+
cfg['past_action_visible'] = False
|
| 18 |
+
runner_cfg = cfg['env_runner']
|
| 19 |
+
runner_cfg['n_train'] = 1
|
| 20 |
+
runner_cfg['n_test'] = 1
|
| 21 |
+
del runner_cfg['_target_']
|
| 22 |
+
runner = RobomimicImageRunner(
|
| 23 |
+
**runner_cfg,
|
| 24 |
+
output_dir='/tmp/test')
|
| 25 |
+
|
| 26 |
+
# import pdb; pdb.set_trace()
|
| 27 |
+
|
| 28 |
+
self = runner
|
| 29 |
+
env = self.env
|
| 30 |
+
env.seed(seeds=self.env_seeds)
|
| 31 |
+
obs = env.reset()
|
| 32 |
+
for i in range(10):
|
| 33 |
+
_ = env.step(env.action_space.sample())
|
| 34 |
+
|
| 35 |
+
imgs = env.render()
|
| 36 |
+
|
| 37 |
+
if __name__ == '__main__':
|
| 38 |
+
test()
|
third_party/diffusion_policy/tests/test_robomimic_lowdim_runner.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
from diffusion_policy.env_runner.robomimic_lowdim_runner import RobomimicLowdimRunner
|
| 9 |
+
|
| 10 |
+
def test():
|
| 11 |
+
import os
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_lowdim.yaml')
|
| 14 |
+
cfg = OmegaConf.load(cfg_path)
|
| 15 |
+
cfg['n_obs_steps'] = 1
|
| 16 |
+
cfg['n_action_steps'] = 1
|
| 17 |
+
cfg['past_action_visible'] = False
|
| 18 |
+
runner_cfg = cfg['env_runner']
|
| 19 |
+
runner_cfg['n_train'] = 1
|
| 20 |
+
runner_cfg['n_test'] = 0
|
| 21 |
+
del runner_cfg['_target_']
|
| 22 |
+
runner = RobomimicLowdimRunner(
|
| 23 |
+
**runner_cfg,
|
| 24 |
+
output_dir='/tmp/test')
|
| 25 |
+
|
| 26 |
+
# import pdb; pdb.set_trace()
|
| 27 |
+
|
| 28 |
+
self = runner
|
| 29 |
+
env = self.env
|
| 30 |
+
env.seed(seeds=self.env_seeds)
|
| 31 |
+
obs = env.reset()
|
| 32 |
+
|
| 33 |
+
if __name__ == '__main__':
|
| 34 |
+
test()
|
third_party/diffusion_policy/tests/test_shared_queue.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from multiprocessing.managers import SharedMemoryManager
|
| 10 |
+
from diffusion_policy.shared_memory.shared_memory_queue import SharedMemoryQueue, Full, Empty
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test():
|
| 14 |
+
shm_manager = SharedMemoryManager()
|
| 15 |
+
shm_manager.start()
|
| 16 |
+
example = {
|
| 17 |
+
'cmd': 0,
|
| 18 |
+
'pose': np.zeros((6,))
|
| 19 |
+
}
|
| 20 |
+
queue = SharedMemoryQueue.create_from_examples(
|
| 21 |
+
shm_manager=shm_manager,
|
| 22 |
+
examples=example,
|
| 23 |
+
buffer_size=3
|
| 24 |
+
)
|
| 25 |
+
raised = False
|
| 26 |
+
try:
|
| 27 |
+
queue.get()
|
| 28 |
+
except Empty:
|
| 29 |
+
raised = True
|
| 30 |
+
assert raised
|
| 31 |
+
|
| 32 |
+
data = {
|
| 33 |
+
'cmd': 1,
|
| 34 |
+
'pose': np.ones((6,))
|
| 35 |
+
}
|
| 36 |
+
queue.put(data)
|
| 37 |
+
result = queue.get()
|
| 38 |
+
assert result['cmd'] == data['cmd']
|
| 39 |
+
assert np.allclose(result['pose'], data['pose'])
|
| 40 |
+
|
| 41 |
+
queue.put(data)
|
| 42 |
+
queue.put(data)
|
| 43 |
+
queue.put(data)
|
| 44 |
+
assert queue.qsize() == 3
|
| 45 |
+
raised = False
|
| 46 |
+
try:
|
| 47 |
+
queue.put(data)
|
| 48 |
+
except Full:
|
| 49 |
+
raised = True
|
| 50 |
+
assert raised
|
| 51 |
+
|
| 52 |
+
result = queue.get_all()
|
| 53 |
+
assert np.allclose(result['cmd'], [1,1,1])
|
| 54 |
+
|
| 55 |
+
queue.put({'cmd': 0})
|
| 56 |
+
queue.put({'cmd': 1})
|
| 57 |
+
queue.put({'cmd': 2})
|
| 58 |
+
queue.get()
|
| 59 |
+
queue.put({'cmd': 3})
|
| 60 |
+
|
| 61 |
+
result = queue.get_k(3)
|
| 62 |
+
assert np.allclose(result['cmd'], [1,2,3])
|
| 63 |
+
|
| 64 |
+
queue.clear()
|
| 65 |
+
|
| 66 |
+
if __name__ == "__main__":
|
| 67 |
+
test()
|
third_party/diffusion_policy/tests/test_single_realsense.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
from multiprocessing.managers import SharedMemoryManager
|
| 12 |
+
from diffusion_policy.real_world.single_realsense import SingleRealsense
|
| 13 |
+
|
| 14 |
+
def test():
|
| 15 |
+
|
| 16 |
+
serials = SingleRealsense.get_connected_devices_serial()
|
| 17 |
+
# import pdb; pdb.set_trace()
|
| 18 |
+
serial = serials[0]
|
| 19 |
+
config = json.load(open('/home/cchi/dev/diffusion_policy/diffusion_policy/real_world/realsense_config/415_high_accuracy_mode.json', 'r'))
|
| 20 |
+
|
| 21 |
+
def transform(data):
|
| 22 |
+
color = data['color']
|
| 23 |
+
h,w,_ = color.shape
|
| 24 |
+
factor = 2
|
| 25 |
+
color = cv2.resize(color, (w//factor,h//factor), interpolation=cv2.INTER_AREA)
|
| 26 |
+
# color = color[:,140:500]
|
| 27 |
+
data['color'] = color
|
| 28 |
+
return data
|
| 29 |
+
|
| 30 |
+
# at 960x540 with //3, 60fps and 30fps are indistinguishable
|
| 31 |
+
|
| 32 |
+
with SharedMemoryManager() as shm_manager:
|
| 33 |
+
with SingleRealsense(
|
| 34 |
+
shm_manager=shm_manager,
|
| 35 |
+
serial_number=serial,
|
| 36 |
+
resolution=(1280,720),
|
| 37 |
+
# resolution=(960,540),
|
| 38 |
+
# resolution=(640,480),
|
| 39 |
+
capture_fps=30,
|
| 40 |
+
enable_color=True,
|
| 41 |
+
# enable_depth=True,
|
| 42 |
+
# enable_infrared=True,
|
| 43 |
+
# advanced_mode_config=config,
|
| 44 |
+
# transform=transform,
|
| 45 |
+
# recording_transform=transform
|
| 46 |
+
# verbose=True
|
| 47 |
+
) as realsense:
|
| 48 |
+
cv2.setNumThreads(1)
|
| 49 |
+
realsense.set_exposure(exposure=150, gain=5)
|
| 50 |
+
intr = realsense.get_intrinsics()
|
| 51 |
+
print(intr)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
video_path = 'data_local/test.mp4'
|
| 55 |
+
rec_start_time = time.time() + 2
|
| 56 |
+
realsense.start_recording(video_path, start_time=rec_start_time)
|
| 57 |
+
|
| 58 |
+
data = None
|
| 59 |
+
while True:
|
| 60 |
+
data = realsense.get(out=data)
|
| 61 |
+
t = time.time()
|
| 62 |
+
# print('capture_latency', data['receive_timestamp']-data['capture_timestamp'], 'receive_latency', t - data['receive_timestamp'])
|
| 63 |
+
# print('receive', t - data['receive_timestamp'])
|
| 64 |
+
|
| 65 |
+
# dt = time.time() - data['timestamp']
|
| 66 |
+
# print(dt)
|
| 67 |
+
# print(data['capture_timestamp'] - rec_start_time)
|
| 68 |
+
|
| 69 |
+
bgr = data['color']
|
| 70 |
+
# print(bgr.shape)
|
| 71 |
+
cv2.imshow('default', bgr)
|
| 72 |
+
key = cv2.pollKey()
|
| 73 |
+
# if key == ord('q'):
|
| 74 |
+
# break
|
| 75 |
+
# elif key == ord('r'):
|
| 76 |
+
# video_path = 'data_local/test.mp4'
|
| 77 |
+
# realsense.start_recording(video_path)
|
| 78 |
+
# elif key == ord('s'):
|
| 79 |
+
# realsense.stop_recording()
|
| 80 |
+
|
| 81 |
+
time.sleep(1/60)
|
| 82 |
+
if time.time() > (rec_start_time + 20.0):
|
| 83 |
+
break
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
test()
|
third_party/diffusion_policy/tests/test_timestamp_accumulator.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
os.chdir(ROOT_DIR)
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import time
|
| 10 |
+
from diffusion_policy.common.timestamp_accumulator import (
|
| 11 |
+
get_accumulate_timestamp_idxs,
|
| 12 |
+
TimestampObsAccumulator,
|
| 13 |
+
TimestampActionAccumulator
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def test_index():
|
| 18 |
+
buffer = np.zeros(16)
|
| 19 |
+
start_time = 0.0
|
| 20 |
+
dt = 1/10
|
| 21 |
+
|
| 22 |
+
timestamps = np.linspace(0,1,100)
|
| 23 |
+
gi = list()
|
| 24 |
+
next_global_idx = 0
|
| 25 |
+
|
| 26 |
+
local_idxs, global_idxs, next_global_idx = get_accumulate_timestamp_idxs(timestamps,
|
| 27 |
+
start_time=start_time, dt=dt, next_global_idx=next_global_idx)
|
| 28 |
+
assert local_idxs[0] == 0
|
| 29 |
+
assert global_idxs[0] == 0
|
| 30 |
+
# print(local_idxs)
|
| 31 |
+
# print(global_idxs)
|
| 32 |
+
# print(timestamps[local_idxs])
|
| 33 |
+
buffer[global_idxs] = timestamps[local_idxs]
|
| 34 |
+
gi.extend(global_idxs)
|
| 35 |
+
|
| 36 |
+
timestamps = np.linspace(0.5,1.5,100)
|
| 37 |
+
local_idxs, global_idxs, next_global_idx = get_accumulate_timestamp_idxs(timestamps,
|
| 38 |
+
start_time=start_time, dt=dt, next_global_idx = next_global_idx)
|
| 39 |
+
# print(local_idxs)
|
| 40 |
+
# print(global_idxs)
|
| 41 |
+
# print(timestamps[local_idxs])
|
| 42 |
+
# import pdb; pdb.set_trace()
|
| 43 |
+
buffer[global_idxs] = timestamps[local_idxs]
|
| 44 |
+
gi.extend(global_idxs)
|
| 45 |
+
|
| 46 |
+
assert np.all(buffer[1:] > buffer[:-1])
|
| 47 |
+
assert np.all(np.array(gi) == np.array(list(range(len(gi)))))
|
| 48 |
+
# print(buffer)
|
| 49 |
+
|
| 50 |
+
# start over
|
| 51 |
+
next_global_idx = 0
|
| 52 |
+
timestamps = np.linspace(0,1,3)
|
| 53 |
+
local_idxs, global_idxs, next_global_idx = get_accumulate_timestamp_idxs(timestamps,
|
| 54 |
+
start_time=start_time, dt=dt, next_global_idx = next_global_idx)
|
| 55 |
+
assert local_idxs[0] == 0
|
| 56 |
+
assert local_idxs[-1] == 2
|
| 57 |
+
# print(local_idxs)
|
| 58 |
+
# print(global_idxs)
|
| 59 |
+
# print(timestamps[local_idxs])
|
| 60 |
+
|
| 61 |
+
# test numerical error issue
|
| 62 |
+
# this becomes a problem when eps <= 1e-7
|
| 63 |
+
start_time = time.time()
|
| 64 |
+
next_global_idx = 0
|
| 65 |
+
timestamps = np.arange(100000) * dt + start_time
|
| 66 |
+
local_idxs, global_idxs, next_global_idx = get_accumulate_timestamp_idxs(timestamps,
|
| 67 |
+
start_time=start_time, dt=dt, next_global_idx = next_global_idx)
|
| 68 |
+
assert local_idxs == global_idxs
|
| 69 |
+
# print(local_idxs)
|
| 70 |
+
# print(global_idxs)
|
| 71 |
+
# print(timestamps[local_idxs])
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_obs_accumulator():
|
| 75 |
+
dt = 1/10
|
| 76 |
+
ddt = 1/100
|
| 77 |
+
n = 100
|
| 78 |
+
d = 6
|
| 79 |
+
start_time = time.time()
|
| 80 |
+
toa = TimestampObsAccumulator(start_time, dt)
|
| 81 |
+
poses = np.arange(n).reshape((n,1))
|
| 82 |
+
poses = np.repeat(poses, d, axis=1)
|
| 83 |
+
timestamps = np.arange(n) * ddt + start_time
|
| 84 |
+
|
| 85 |
+
toa.put({
|
| 86 |
+
'pose': poses,
|
| 87 |
+
'timestamp': timestamps
|
| 88 |
+
}, timestamps)
|
| 89 |
+
assert np.all(toa.data['pose'][:,0] == np.arange(10)*10)
|
| 90 |
+
assert len(toa) == 10
|
| 91 |
+
|
| 92 |
+
# add the same thing, result shouldn't change
|
| 93 |
+
toa.put({
|
| 94 |
+
'pose': poses,
|
| 95 |
+
'timestamp': timestamps
|
| 96 |
+
}, timestamps)
|
| 97 |
+
assert np.all(toa.data['pose'][:,0] == np.arange(10)*10)
|
| 98 |
+
assert len(toa) == 10
|
| 99 |
+
|
| 100 |
+
# add lower than desired freuquency to test fill_in
|
| 101 |
+
dt = 1/10
|
| 102 |
+
ddt = 1/5
|
| 103 |
+
n = 10
|
| 104 |
+
d = 6
|
| 105 |
+
start_time = time.time()
|
| 106 |
+
toa = TimestampObsAccumulator(start_time, dt)
|
| 107 |
+
poses = np.arange(n).reshape((n,1))
|
| 108 |
+
poses = np.repeat(poses, d, axis=1)
|
| 109 |
+
timestamps = np.arange(n) * ddt + start_time
|
| 110 |
+
|
| 111 |
+
toa.put({
|
| 112 |
+
'pose': poses,
|
| 113 |
+
'timestamp': timestamps
|
| 114 |
+
}, timestamps)
|
| 115 |
+
assert len(toa) == 1 + (n-1) * 2
|
| 116 |
+
|
| 117 |
+
timestamps = (np.arange(n) + 2) * ddt + start_time
|
| 118 |
+
toa.put({
|
| 119 |
+
'pose': poses,
|
| 120 |
+
'timestamp': timestamps
|
| 121 |
+
}, timestamps)
|
| 122 |
+
assert len(toa) == 1 + (n-1) * 2 + 4
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def test_action_accumulator():
|
| 126 |
+
dt = 1/10
|
| 127 |
+
n = 10
|
| 128 |
+
d = 6
|
| 129 |
+
start_time = time.time()
|
| 130 |
+
taa = TimestampActionAccumulator(start_time, dt)
|
| 131 |
+
actions = np.arange(n).reshape((n,1))
|
| 132 |
+
actions = np.repeat(actions, d, axis=1)
|
| 133 |
+
|
| 134 |
+
timestamps = np.arange(n) * dt + start_time
|
| 135 |
+
taa.put(actions, timestamps)
|
| 136 |
+
assert np.all(taa.actions == actions)
|
| 137 |
+
assert np.all(taa.timestamps == timestamps)
|
| 138 |
+
|
| 139 |
+
# add another round
|
| 140 |
+
taa.put(actions-5, timestamps-0.5)
|
| 141 |
+
assert np.allclose(taa.timestamps, timestamps)
|
| 142 |
+
|
| 143 |
+
# add another round
|
| 144 |
+
taa.put(actions+5, timestamps+0.5)
|
| 145 |
+
assert len(taa) == 15
|
| 146 |
+
assert np.all(taa.actions[:,0] == np.arange(15))
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
if __name__ == '__main__':
|
| 151 |
+
test_action_accumulator()
|
third_party/diffusion_policy/train.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
Training:
|
| 4 |
+
python train.py --config-name=train_diffusion_lowdim_workspace
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
# use line-buffering for both stdout and stderr
|
| 9 |
+
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
|
| 10 |
+
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
|
| 11 |
+
|
| 12 |
+
import hydra
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
import pathlib
|
| 15 |
+
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
| 16 |
+
|
| 17 |
+
# allows arbitrary python code execution in configs using the ${eval:''} resolver
|
| 18 |
+
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
| 19 |
+
|
| 20 |
+
@hydra.main(
|
| 21 |
+
version_base=None,
|
| 22 |
+
config_path=str(pathlib.Path(__file__).parent.joinpath(
|
| 23 |
+
'diffusion_policy','config'))
|
| 24 |
+
)
|
| 25 |
+
def main(cfg: OmegaConf):
|
| 26 |
+
# resolve immediately so all the ${now:} resolvers
|
| 27 |
+
# will use the same time.
|
| 28 |
+
OmegaConf.resolve(cfg)
|
| 29 |
+
|
| 30 |
+
cls = hydra.utils.get_class(cfg._target_)
|
| 31 |
+
workspace: BaseWorkspace = cls(cfg)
|
| 32 |
+
workspace.run()
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
main()
|