lsnu commited on
Commit
a1fc554
·
verified ·
1 Parent(s): 5ce8761

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. README.md +11 -0
  2. docs/ENVIRONMENT_NOTES.md +42 -0
  3. models/pointflowmatch_take_shoes_out_of_box/1717447341-indigo-quokka/1717447341-indigo-quokka/config.yaml +80 -0
  4. reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/run.log +129 -0
  5. reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/summary.json +16 -0
  6. scripts/run_pointflowmatch_take_shoes_out_of_box.sh +25 -0
  7. third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/config.yaml +14 -0
  8. third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/hydra.yaml +159 -0
  9. third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/overrides.yaml +5 -0
  10. third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/config.yaml +14 -0
  11. third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/hydra.yaml +160 -0
  12. third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/overrides.yaml +6 -0
  13. third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/evaluate.log +3 -0
  14. third_party/diffusion_policy/.gitignore +140 -0
  15. third_party/diffusion_policy/LICENSE +21 -0
  16. third_party/diffusion_policy/README.md +437 -0
  17. third_party/diffusion_policy/conda_environment.yaml +65 -0
  18. third_party/diffusion_policy/conda_environment_macos.yaml +55 -0
  19. third_party/diffusion_policy/conda_environment_real.yaml +73 -0
  20. third_party/diffusion_policy/demo_pusht.py +120 -0
  21. third_party/diffusion_policy/demo_real_robot.py +160 -0
  22. third_party/diffusion_policy/diffusion_policy.egg-info/PKG-INFO +5 -0
  23. third_party/diffusion_policy/diffusion_policy.egg-info/SOURCES.txt +19 -0
  24. third_party/diffusion_policy/diffusion_policy.egg-info/dependency_links.txt +1 -0
  25. third_party/diffusion_policy/diffusion_policy.egg-info/top_level.txt +1 -0
  26. third_party/diffusion_policy/diffusion_policy/workspace/train_diffusion_unet_lowdim_workspace.py +306 -0
  27. third_party/diffusion_policy/diffusion_policy/workspace/train_ibc_dfo_hybrid_workspace.py +283 -0
  28. third_party/diffusion_policy/diffusion_policy/workspace/train_ibc_dfo_lowdim_workspace.py +282 -0
  29. third_party/diffusion_policy/diffusion_policy/workspace/train_robomimic_image_workspace.py +254 -0
  30. third_party/diffusion_policy/diffusion_policy/workspace/train_robomimic_lowdim_workspace.py +221 -0
  31. third_party/diffusion_policy/eval.py +64 -0
  32. third_party/diffusion_policy/eval_real_robot.py +418 -0
  33. third_party/diffusion_policy/multirun_metrics.py +267 -0
  34. third_party/diffusion_policy/pyrightconfig.json +7 -0
  35. third_party/diffusion_policy/ray_exec.py +121 -0
  36. third_party/diffusion_policy/ray_train_multirun.py +271 -0
  37. third_party/diffusion_policy/setup.py +6 -0
  38. third_party/diffusion_policy/tests/test_block_pushing.py +44 -0
  39. third_party/diffusion_policy/tests/test_cv2_util.py +22 -0
  40. third_party/diffusion_policy/tests/test_multi_realsense.py +82 -0
  41. third_party/diffusion_policy/tests/test_pose_trajectory_interpolator.py +126 -0
  42. third_party/diffusion_policy/tests/test_precise_sleep.py +56 -0
  43. third_party/diffusion_policy/tests/test_replay_buffer.py +62 -0
  44. third_party/diffusion_policy/tests/test_ring_buffer.py +188 -0
  45. third_party/diffusion_policy/tests/test_robomimic_image_runner.py +38 -0
  46. third_party/diffusion_policy/tests/test_robomimic_lowdim_runner.py +34 -0
  47. third_party/diffusion_policy/tests/test_shared_queue.py +67 -0
  48. third_party/diffusion_policy/tests/test_single_realsense.py +87 -0
  49. third_party/diffusion_policy/tests/test_timestamp_accumulator.py +151 -0
  50. third_party/diffusion_policy/train.py +35 -0
README.md CHANGED
@@ -7,11 +7,14 @@ iterations, with emphasis on RLBench2 `take tray out of oven`.
7
 
8
  - VLAarchTests benchmark code and generated public benchmark manifest
9
  - Patched AnyBimanual RLBench runtime used to execute the public oven benchmark
 
10
  - Official `katefgroup/3d_flowmatch_actor` PerAct2 checkpoint and public test data
 
11
  - Public AnyBimanual LF baseline weights and comparison logs
12
  - Verified benchmark reports:
13
  - oven subset run: `9/10`
14
  - oven full official run: `95/100 = 0.95`
 
15
  - hybrid public benchmark smoke outputs
16
  - DexGarmentLab benchmark-related validation scripts and validation logs
17
 
@@ -26,12 +29,20 @@ Official oven result artifacts:
26
  - `reports/3dfa_peract2_take_tray_out_of_oven_subset10/eval_after_official_ttm.json`
27
  - `reports/3dfa_peract2_take_tray_out_of_oven_full100/eval.json`
28
 
 
 
 
 
 
29
  ## Important Code Paths
30
 
31
  - `code/VLAarchtests4/code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/public_benchmark_package.py`
32
  - `code/VLAarchtests4/code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_rlbench_hybrid_smoke.py`
33
  - `third_party/AnyBimanual/third_party/RLBench/rlbench/bimanual_tasks/bimanual_take_tray_out_of_oven.py`
34
  - `third_party/AnyBimanual/third_party/RLBench/rlbench/task_ttms/bimanual_take_tray_out_of_oven.ttm`
 
 
 
35
 
36
  ## External Dependencies Not Mirrored Here
37
 
 
7
 
8
  - VLAarchTests benchmark code and generated public benchmark manifest
9
  - Patched AnyBimanual RLBench runtime used to execute the public oven benchmark
10
+ - Patched PointFlowMatch + `diffusion_policy` source snapshot used to execute `take shoes out of box`
11
  - Official `katefgroup/3d_flowmatch_actor` PerAct2 checkpoint and public test data
12
+ - Official PointFlowMatch `1717447341-indigo-quokka` checkpoint for `take_shoes_out_of_box`
13
  - Public AnyBimanual LF baseline weights and comparison logs
14
  - Verified benchmark reports:
15
  - oven subset run: `9/10`
16
  - oven full official run: `95/100 = 0.95`
17
+ - shoes GPU search: non-zero success verified before later simulator crash
18
  - hybrid public benchmark smoke outputs
19
  - DexGarmentLab benchmark-related validation scripts and validation logs
20
 
 
29
  - `reports/3dfa_peract2_take_tray_out_of_oven_subset10/eval_after_official_ttm.json`
30
  - `reports/3dfa_peract2_take_tray_out_of_oven_full100/eval.json`
31
 
32
+ Shoes result artifacts:
33
+
34
+ - `reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/summary.json`
35
+ - `reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/run.log`
36
+
37
  ## Important Code Paths
38
 
39
  - `code/VLAarchtests4/code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/public_benchmark_package.py`
40
  - `code/VLAarchtests4/code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_rlbench_hybrid_smoke.py`
41
  - `third_party/AnyBimanual/third_party/RLBench/rlbench/bimanual_tasks/bimanual_take_tray_out_of_oven.py`
42
  - `third_party/AnyBimanual/third_party/RLBench/rlbench/task_ttms/bimanual_take_tray_out_of_oven.ttm`
43
+ - `third_party/PointFlowMatch/pfp/envs/rlbench_env.py`
44
+ - `third_party/PointFlowMatch/pfp/policy/fm_policy.py`
45
+ - `scripts/run_pointflowmatch_take_shoes_out_of_box.sh`
46
 
47
  ## External Dependencies Not Mirrored Here
48
 
docs/ENVIRONMENT_NOTES.md CHANGED
@@ -45,3 +45,45 @@ xvfb-run -a -s "-screen 0 1400x900x24" python \
45
  --denoise_timesteps 5 \
46
  --denoise_model rectified_flow
47
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  --denoise_timesteps 5 \
46
  --denoise_model rectified_flow
47
  ```
48
+
49
+ The `take shoes out of box` validation path was run from the bundled
50
+ PointFlowMatch source on a Blackwell GPU machine after upgrading the RLBench
51
+ environment to Torch `2.11.0+cu128` / torchvision `0.26.0+cu128` /
52
+ torchaudio `2.11.0+cu128`. The bundled PointFlowMatch tree also contains two
53
+ local compatibility fixes required for this workspace:
54
+
55
+ - `pfp/envs/rlbench_env.py`
56
+ - adapts PointFlowMatch to the local RLBench camera naming and observation API
57
+ - broadens motion-planning failure recovery to handle simulator-side runtime failures
58
+ - `pfp/policy/fm_policy.py`
59
+ - adds an inference-only fallback when legacy `composer` cannot import on modern Torch
60
+ - loads checkpoints with `weights_only=False` for PyTorch 2.6+
61
+
62
+ The shoes evaluation command used here was:
63
+
64
+ ```bash
65
+ scripts/run_pointflowmatch_take_shoes_out_of_box.sh 10 50
66
+ ```
67
+
68
+ That wrapper expands to the equivalent raw command:
69
+
70
+ ```bash
71
+ export PYTHONPATH=third_party/diffusion_policy:third_party/PointFlowMatch:${PYTHONPATH:-}
72
+ export COPPELIASIM_ROOT=/path/to/CoppeliaSim
73
+ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:$COPPELIASIM_ROOT
74
+ export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
75
+ xvfb-run -a -s "-screen 0 1400x900x24" python \
76
+ third_party/PointFlowMatch/scripts/evaluate.py \
77
+ log_wandb=False \
78
+ env_runner.env_config.vis=False \
79
+ env_runner.num_episodes=10 \
80
+ env_runner.max_episode_length=200 \
81
+ policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka \
82
+ policy.num_k_infer=50
83
+ ```
84
+
85
+ Result note for shoes:
86
+
87
+ - `reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/summary.json`
88
+ records a verified non-zero result before a later RLBench/PyRep crash in the
89
+ same longer rollout.
models/pointflowmatch_take_shoes_out_of_box/1717447341-indigo-quokka/1717447341-indigo-quokka/config.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ _target_: pfp.policy.fm_policy.FMPolicy
3
+ x_dim: 266
4
+ y_dim: 10
5
+ n_obs_steps: 2
6
+ n_pred_steps: 32
7
+ num_k_infer: 10
8
+ time_conditioning: true
9
+ norm_pcd_center:
10
+ - 0.4
11
+ - 0.0
12
+ - 1.4
13
+ augment_data: false
14
+ noise_type: gaussian
15
+ noise_scale: 1.0
16
+ loss_type: l2
17
+ flow_schedule: exp
18
+ exp_scale: 4.0
19
+ snr_sampler: uniform
20
+ obs_encoder:
21
+ _target_: pfp.backbones.pointnet.PointNetBackbone
22
+ embed_dim: 256
23
+ input_channels: 3
24
+ input_transform: false
25
+ use_group_norm: false
26
+ diffusion_net:
27
+ _target_: diffusion_policy.model.diffusion.conditional_unet1d.ConditionalUnet1D
28
+ input_dim: 10
29
+ global_cond_dim: 532
30
+ diffusion_step_embed_dim: 256
31
+ down_dims:
32
+ - 256
33
+ - 512
34
+ - 1024
35
+ kernel_size: 5
36
+ n_groups: 8
37
+ cond_predict_scale: true
38
+ loss_weights:
39
+ xyz: 10.0
40
+ rot6d: 10.0
41
+ grip: 1.0
42
+ backbone:
43
+ _target_: pfp.backbones.pointnet.PointNetBackbone
44
+ embed_dim: 256
45
+ input_channels: 3
46
+ input_transform: false
47
+ use_group_norm: false
48
+ seed: 1234
49
+ epochs: 1500
50
+ log_wandb: true
51
+ task_name: take_shoes_out_of_box
52
+ obs_features_dim: 256
53
+ y_dim: 10
54
+ x_dim: 266
55
+ n_obs_steps: 2
56
+ n_pred_steps: 32
57
+ use_ema: true
58
+ save_each_n_epochs: 500
59
+ obs_mode: pcd
60
+ run_name: null
61
+ dataset:
62
+ n_obs_steps: 2
63
+ n_pred_steps: 32
64
+ subs_factor: 3
65
+ use_pc_color: false
66
+ n_points: 4096
67
+ dataloader:
68
+ batch_size: 128
69
+ num_workers: 8
70
+ optimizer:
71
+ _target_: torch.optim.AdamW
72
+ lr: 3.0e-05
73
+ betas:
74
+ - 0.95
75
+ - 0.999
76
+ eps: 1.0e-08
77
+ weight_decay: 1.0e-06
78
+ lr_scheduler:
79
+ name: cosine
80
+ num_warmup_steps: 5000
reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/run.log ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0
  0%| | 0/10 [00:00<?, ?it/s][2026-04-03 00:44:29,033][root][WARNING] - single robot
 
1
  10%|█ | 1/10 [02:31<22:41, 151.33s/it]'NoneType' object has no attribute 'step'
 
 
 
 
 
 
 
 
 
2
  10%|█ | 1/10 [03:32<31:48, 212.03s/it]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /workspace/envs/rlbench/lib/python3.10/site-packages/requests/__init__.py:109: RequestsDependencyWarning: urllib3 (1.26.13) or chardet (7.4.0.post2)/charset_normalizer (2.1.1) doesn't match a supported version!
2
+ warnings.warn(
3
+ pytorch3d not installed
4
+ WARNING: Rerun not installed. Visualization will not work.
5
+ WARNING: Rerun not installed. Visualization will not work.
6
+ seed: 5678
7
+ log_wandb: false
8
+ env_runner:
9
+ num_episodes: 10
10
+ max_episode_length: 200
11
+ verbose: true
12
+ env_config:
13
+ voxel_size: 0.01
14
+ headless: true
15
+ vis: false
16
+ policy:
17
+ ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
18
+ ckpt_episode: ep1500
19
+ num_k_infer: 50
20
+
21
+ seed: 5678
22
+ log_wandb: false
23
+ env_runner:
24
+ num_episodes: 10
25
+ max_episode_length: 200
26
+ verbose: true
27
+ env_config:
28
+ voxel_size: 0.01
29
+ headless: true
30
+ vis: false
31
+ task_name: take_shoes_out_of_box
32
+ obs_mode: pcd
33
+ use_pc_color: false
34
+ n_points: 4096
35
+ policy:
36
+ ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
37
+ ckpt_episode: ep1500
38
+ num_k_infer: 50
39
+ _target_: pfp.policy.fm_policy.FMPolicy.load_from_checkpoint
40
+ model:
41
+ _target_: pfp.policy.fm_policy.FMPolicy
42
+ x_dim: 266
43
+ y_dim: 10
44
+ n_obs_steps: 2
45
+ n_pred_steps: 32
46
+ num_k_infer: 10
47
+ time_conditioning: true
48
+ norm_pcd_center:
49
+ - 0.4
50
+ - 0.0
51
+ - 1.4
52
+ augment_data: false
53
+ noise_type: gaussian
54
+ noise_scale: 1.0
55
+ loss_type: l2
56
+ flow_schedule: exp
57
+ exp_scale: 4.0
58
+ snr_sampler: uniform
59
+ obs_encoder:
60
+ _target_: pfp.backbones.pointnet.PointNetBackbone
61
+ embed_dim: 256
62
+ input_channels: 3
63
+ input_transform: false
64
+ use_group_norm: false
65
+ diffusion_net:
66
+ _target_: diffusion_policy.model.diffusion.conditional_unet1d.ConditionalUnet1D
67
+ input_dim: 10
68
+ global_cond_dim: 532
69
+ diffusion_step_embed_dim: 256
70
+ down_dims:
71
+ - 256
72
+ - 512
73
+ - 1024
74
+ kernel_size: 5
75
+ n_groups: 8
76
+ cond_predict_scale: true
77
+ loss_weights:
78
+ xyz: 10.0
79
+ rot6d: 10.0
80
+ grip: 1.0
81
+
82
+ output_dim 10
83
+ [2026-04-03 00:44:27,184][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
84
+ QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
85
+
86
  0%| | 0/10 [00:00<?, ?it/s][2026-04-03 00:44:29,033][root][WARNING] - single robot
87
+
88
  10%|█ | 1/10 [02:31<22:41, 151.33s/it]'NoneType' object has no attribute 'step'
89
+ 'NoneType' object has no attribute 'step'
90
+ 'NoneType' object has no attribute 'step'
91
+ 'NoneType' object has no attribute 'step'
92
+ 'NoneType' object has no attribute 'step'
93
+ 'NoneType' object has no attribute 'step'
94
+ Steps: 182
95
+ Success: True
96
+ [2026-04-03 00:47:00,210][root][WARNING] - single robot
97
+
98
  10%|█ | 1/10 [03:32<31:48, 212.03s/it]
99
+ 'NoneType' object has no attribute 'step'
100
+ 'NoneType' object has no attribute 'step'
101
+ Error executing job with overrides: ['log_wandb=False', 'env_runner.env_config.vis=False', 'env_runner.num_episodes=10', 'env_runner.max_episode_length=200', 'policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka', 'policy.num_k_infer=50']
102
+ Traceback (most recent call last):
103
+ File "/workspace/third_party/PointFlowMatch/scripts/evaluate.py", line 52, in main
104
+ _ = env_runner.run(policy)
105
+ File "/workspace/third_party/PointFlowMatch/pfp/envs/rlbench_runner.py", line 35, in run
106
+ reward, terminate = self.env.step(next_robot_state)
107
+ File "/workspace/third_party/PointFlowMatch/pfp/envs/rlbench_env.py", line 105, in step
108
+ reward, terminate = self._step_safe(action)
109
+ File "/workspace/third_party/PointFlowMatch/pfp/envs/rlbench_env.py", line 113, in _step_safe
110
+ _, reward, terminate = self.task.step(action)
111
+ File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/task_environment.py", line 109, in step
112
+ self._action_mode.action(self._scene, action)
113
+ File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/action_modes/action_mode.py", line 47, in action
114
+ self.arm_action_mode.action(scene, arm_action, ignore_collisions)
115
+ File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/action_modes/arm_action_modes.py", line 304, in action
116
+ success, terminate = scene.task.success()
117
+ File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/backend/task.py", line 303, in success
118
+ [cond.condition_met()[0] for cond in self._success_conditions])
119
+ File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/backend/task.py", line 303, in <listcomp>
120
+ [cond.condition_met()[0] for cond in self._success_conditions])
121
+ File "/workspace/envs/rlbench/lib/python3.10/site-packages/rlbench/backend/conditions.py", line 51, in condition_met
122
+ met = self._detector.is_detected(self._obj)
123
+ File "/workspace/third_party/PyRep/pyrep/objects/proximity_sensor.py", line 36, in is_detected
124
+ state, point = sim.simCheckProximitySensor(
125
+ File "/workspace/third_party/PyRep/pyrep/backend/sim.py", line 345, in simCheckProximitySensor
126
+ _check_return(state)
127
+ File "/workspace/third_party/PyRep/pyrep/backend/sim.py", line 27, in _check_return
128
+ raise RuntimeError(
129
+ RuntimeError: The call failed on the V-REP side. Return value: -1
130
+
131
+ Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
132
+ QMutex: destroying locked mutex
reports/pointflowmatch_take_shoes_out_of_box_ep10_k50_gpu/summary.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "task_name": "take_shoes_out_of_box",
3
+ "model_family": "PointFlowMatch",
4
+ "checkpoint_name": "1717447341-indigo-quokka",
5
+ "checkpoint_source": "http://pointflowmatch.cs.uni-freiburg.de/download/1717447341-indigo-quokka.zip",
6
+ "num_k_infer": 50,
7
+ "requested_episodes": 10,
8
+ "completed_episodes_before_crash": 1,
9
+ "success_count_before_crash": 1,
10
+ "first_success_episode_index": 0,
11
+ "first_success_steps": 182,
12
+ "verified_non_zero_performance": true,
13
+ "run_status": "partial_success_then_crash",
14
+ "crash_reason": "RLBench/PyRep RuntimeError after first successful episode during continued rollout",
15
+ "report_log": "run.log"
16
+ }
scripts/run_pointflowmatch_take_shoes_out_of_box.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ EPISODES="${1:-10}"
5
+ NUM_K_INFER="${2:-50}"
6
+ ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
7
+ PYTHON_BIN="${PYTHON_BIN:-python}"
8
+ COPPELIASIM_ROOT="${COPPELIASIM_ROOT:?set COPPELIASIM_ROOT to your CoppeliaSim root}"
9
+ REPORT_DIR="${REPORT_DIR:-$ROOT/reports/pointflowmatch_take_shoes_out_of_box_ep${EPISODES}_k${NUM_K_INFER}_gpu}"
10
+
11
+ export PYTHONPATH="$ROOT/third_party/diffusion_policy:$ROOT/third_party/PointFlowMatch:${PYTHONPATH:-}"
12
+ export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-}:$COPPELIASIM_ROOT"
13
+ export QT_QPA_PLATFORM_PLUGIN_PATH="$COPPELIASIM_ROOT"
14
+
15
+ mkdir -p "$REPORT_DIR"
16
+ cd "$ROOT/third_party/PointFlowMatch"
17
+
18
+ xvfb-run -a -s "-screen 0 1400x900x24" "$PYTHON_BIN" scripts/evaluate.py \
19
+ log_wandb=False \
20
+ env_runner.env_config.vis=False \
21
+ env_runner.num_episodes="$EPISODES" \
22
+ env_runner.max_episode_length=200 \
23
+ policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka \
24
+ policy.num_k_infer="$NUM_K_INFER" \
25
+ 2>&1 | tee "$REPORT_DIR/run.log"
third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/config.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 5678
2
+ log_wandb: false
3
+ env_runner:
4
+ num_episodes: 1
5
+ max_episode_length: 200
6
+ verbose: true
7
+ env_config:
8
+ voxel_size: 0.01
9
+ headless: true
10
+ vis: false
11
+ policy:
12
+ ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
13
+ ckpt_episode: ep1500
14
+ num_k_infer: 50
third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/hydra.yaml ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - log_wandb=False
116
+ - env_runner.env_config.vis=False
117
+ - env_runner.num_episodes=1
118
+ - env_runner.max_episode_length=200
119
+ - policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
120
+ job:
121
+ name: evaluate
122
+ chdir: null
123
+ override_dirname: env_runner.env_config.vis=False,env_runner.max_episode_length=200,env_runner.num_episodes=1,log_wandb=False,policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
124
+ id: ???
125
+ num: ???
126
+ config_name: eval
127
+ env_set: {}
128
+ env_copy: []
129
+ config:
130
+ override_dirname:
131
+ kv_sep: '='
132
+ item_sep: ','
133
+ exclude_keys: []
134
+ runtime:
135
+ version: 1.3.2
136
+ version_base: '1.3'
137
+ cwd: /workspace/third_party/PointFlowMatch
138
+ config_sources:
139
+ - path: hydra.conf
140
+ schema: pkg
141
+ provider: hydra
142
+ - path: /workspace/third_party/PointFlowMatch/conf
143
+ schema: file
144
+ provider: main
145
+ - path: ''
146
+ schema: structured
147
+ provider: schema
148
+ output_dir: /workspace/third_party/PointFlowMatch/outputs/2026-04-03/00-25-21
149
+ choices:
150
+ hydra/env: default
151
+ hydra/callbacks: null
152
+ hydra/job_logging: default
153
+ hydra/hydra_logging: default
154
+ hydra/hydra_help: default
155
+ hydra/help: default
156
+ hydra/sweeper: basic
157
+ hydra/launcher: basic
158
+ hydra/output: default
159
+ verbose: false
third_party/PointFlowMatch/outputs/2026-04-03/00-25-21/.hydra/overrides.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ - log_wandb=False
2
+ - env_runner.env_config.vis=False
3
+ - env_runner.num_episodes=1
4
+ - env_runner.max_episode_length=200
5
+ - policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/config.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 5678
2
+ log_wandb: false
3
+ env_runner:
4
+ num_episodes: 10
5
+ max_episode_length: 200
6
+ verbose: true
7
+ env_config:
8
+ voxel_size: 0.01
9
+ headless: true
10
+ vis: false
11
+ policy:
12
+ ckpt_name: 1717447341-indigo-quokka/1717447341-indigo-quokka
13
+ ckpt_episode: ep1500
14
+ num_k_infer: 50
third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/hydra.yaml ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - log_wandb=False
116
+ - env_runner.env_config.vis=False
117
+ - env_runner.num_episodes=10
118
+ - env_runner.max_episode_length=200
119
+ - policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
120
+ - policy.num_k_infer=50
121
+ job:
122
+ name: evaluate
123
+ chdir: null
124
+ override_dirname: env_runner.env_config.vis=False,env_runner.max_episode_length=200,env_runner.num_episodes=10,log_wandb=False,policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka,policy.num_k_infer=50
125
+ id: ???
126
+ num: ???
127
+ config_name: eval
128
+ env_set: {}
129
+ env_copy: []
130
+ config:
131
+ override_dirname:
132
+ kv_sep: '='
133
+ item_sep: ','
134
+ exclude_keys: []
135
+ runtime:
136
+ version: 1.3.2
137
+ version_base: '1.3'
138
+ cwd: /workspace/third_party/PointFlowMatch
139
+ config_sources:
140
+ - path: hydra.conf
141
+ schema: pkg
142
+ provider: hydra
143
+ - path: /workspace/third_party/PointFlowMatch/conf
144
+ schema: file
145
+ provider: main
146
+ - path: ''
147
+ schema: structured
148
+ provider: schema
149
+ output_dir: /workspace/third_party/PointFlowMatch/outputs/2026-04-03/00-44-21
150
+ choices:
151
+ hydra/env: default
152
+ hydra/callbacks: null
153
+ hydra/job_logging: default
154
+ hydra/hydra_logging: default
155
+ hydra/hydra_help: default
156
+ hydra/help: default
157
+ hydra/sweeper: basic
158
+ hydra/launcher: basic
159
+ hydra/output: default
160
+ verbose: false
third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/.hydra/overrides.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ - log_wandb=False
2
+ - env_runner.env_config.vis=False
3
+ - env_runner.num_episodes=10
4
+ - env_runner.max_episode_length=200
5
+ - policy.ckpt_name=1717447341-indigo-quokka/1717447341-indigo-quokka
6
+ - policy.num_k_infer=50
third_party/PointFlowMatch/outputs/2026-04-03/00-44-21/evaluate.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [2026-04-03 00:44:27,184][diffusion_policy.model.diffusion.conditional_unet1d][INFO] - number of parameters: 7.285095e+07
2
+ [2026-04-03 00:44:29,033][root][WARNING] - single robot
3
+ [2026-04-03 00:47:00,210][root][WARNING] - single robot
third_party/diffusion_policy/.gitignore ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ bin
2
+ logs
3
+ wandb
4
+ outputs
5
+ data
6
+ data_local
7
+ .vscode
8
+ _wandb
9
+
10
+ **/.DS_Store
11
+
12
+ fuse.cfg
13
+
14
+ *.ai
15
+
16
+ # Generation results
17
+ results/
18
+
19
+ ray/auth.json
20
+
21
+ # Byte-compiled / optimized / DLL files
22
+ __pycache__/
23
+ *.py[cod]
24
+ *$py.class
25
+
26
+ # C extensions
27
+ *.so
28
+
29
+ # Distribution / packaging
30
+ .Python
31
+ build/
32
+ develop-eggs/
33
+ dist/
34
+ downloads/
35
+ eggs/
36
+ .eggs/
37
+ lib/
38
+ lib64/
39
+ parts/
40
+ sdist/
41
+ var/
42
+ wheels/
43
+ pip-wheel-metadata/
44
+ share/python-wheels/
45
+ *.egg-info/
46
+ .installed.cfg
47
+ *.egg
48
+ MANIFEST
49
+
50
+ # PyInstaller
51
+ # Usually these files are written by a python script from a template
52
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
53
+ *.manifest
54
+ *.spec
55
+
56
+ # Installer logs
57
+ pip-log.txt
58
+ pip-delete-this-directory.txt
59
+
60
+ # Unit test / coverage reports
61
+ htmlcov/
62
+ .tox/
63
+ .nox/
64
+ .coverage
65
+ .coverage.*
66
+ .cache
67
+ nosetests.xml
68
+ coverage.xml
69
+ *.cover
70
+ *.py,cover
71
+ .hypothesis/
72
+ .pytest_cache/
73
+
74
+ # Translations
75
+ *.mo
76
+ *.pot
77
+
78
+ # Django stuff:
79
+ *.log
80
+ local_settings.py
81
+ db.sqlite3
82
+ db.sqlite3-journal
83
+
84
+ # Flask stuff:
85
+ instance/
86
+ .webassets-cache
87
+
88
+ # Scrapy stuff:
89
+ .scrapy
90
+
91
+ # Sphinx documentation
92
+ docs/_build/
93
+
94
+ # PyBuilder
95
+ target/
96
+
97
+ # Jupyter Notebook
98
+ .ipynb_checkpoints
99
+
100
+ # IPython
101
+ profile_default/
102
+ ipython_config.py
103
+
104
+ # pyenv
105
+ .python-version
106
+
107
+ # pipenv
108
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
110
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
111
+ # install all needed dependencies.
112
+ #Pipfile.lock
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
136
+ .dmypy.json
137
+ dmypy.json
138
+
139
+ # Pyre type checker
140
+ .pyre/
third_party/diffusion_policy/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Columbia Artificial Intelligence and Robotics Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
third_party/diffusion_policy/README.md ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Diffusion Policy
2
+
3
+ [[Project page]](https://diffusion-policy.cs.columbia.edu/)
4
+ [[Paper]](https://diffusion-policy.cs.columbia.edu/#paper)
5
+ [[Data]](https://diffusion-policy.cs.columbia.edu/data/)
6
+ [[Colab (state)]](https://colab.research.google.com/drive/1gxdkgRVfM55zihY9TFLja97cSVZOZq2B?usp=sharing)
7
+ [[Colab (vision)]](https://colab.research.google.com/drive/18GIHeOQ5DyjMN8iIRZL2EKZ0745NLIpg?usp=sharing)
8
+
9
+
10
+ [Cheng Chi](http://cheng-chi.github.io/)<sup>1</sup>,
11
+ [Siyuan Feng](https://www.cs.cmu.edu/~sfeng/)<sup>2</sup>,
12
+ [Yilun Du](https://yilundu.github.io/)<sup>3</sup>,
13
+ [Zhenjia Xu](https://www.zhenjiaxu.com/)<sup>1</sup>,
14
+ [Eric Cousineau](https://www.eacousineau.com/)<sup>2</sup>,
15
+ [Benjamin Burchfiel](http://www.benburchfiel.com/)<sup>2</sup>,
16
+ [Shuran Song](https://www.cs.columbia.edu/~shurans/)<sup>1</sup>
17
+
18
+ <sup>1</sup>Columbia University,
19
+ <sup>2</sup>Toyota Research Institute,
20
+ <sup>3</sup>MIT
21
+
22
+ <img src="media/teaser.png" alt="drawing" width="100%"/>
23
+ <img src="media/multimodal_sim.png" alt="drawing" width="100%"/>
24
+
25
+ ## 🛝 Try it out!
26
+ Our self-contained Google Colab notebooks is the easiest way to play with Diffusion Policy. We provide separate notebooks for [state-based environment](https://colab.research.google.com/drive/1gxdkgRVfM55zihY9TFLja97cSVZOZq2B?usp=sharing) and [vision-based environment](https://colab.research.google.com/drive/18GIHeOQ5DyjMN8iIRZL2EKZ0745NLIpg?usp=sharing).
27
+
28
+ ## 🧾 Checkout our experiment logs!
29
+ For each experiment used to generate Table I,II and IV in the [paper](https://diffusion-policy.cs.columbia.edu/#paper), we provide:
30
+ 1. A `config.yaml` that contains all parameters needed to reproduce the experiment.
31
+ 2. Detailed training/eval `logs.json.txt` for every training step.
32
+ 3. Checkpoints for the best `epoch=*-test_mean_score=*.ckpt` and last `latest.ckpt` epoch of each run.
33
+
34
+ Experiment logs are hosted on our website as nested directories in format:
35
+ `https://diffusion-policy.cs.columbia.edu/data/experiments/<image|low_dim>/<task>/<method>/`
36
+
37
+ Within each experiment directory you may find:
38
+ ```
39
+ .
40
+ ├── config.yaml
41
+ ├── metrics
42
+ │   └── logs.json.txt
43
+ ├── train_0
44
+ │   ├── checkpoints
45
+ │   │   ├── epoch=0300-test_mean_score=1.000.ckpt
46
+ │   │   └── latest.ckpt
47
+ │   └── logs.json.txt
48
+ ├── train_1
49
+ │   ├── checkpoints
50
+ │   │   ├── epoch=0250-test_mean_score=1.000.ckpt
51
+ │   │   └── latest.ckpt
52
+ │   └── logs.json.txt
53
+ └── train_2
54
+ ├── checkpoints
55
+ │   ├── epoch=0250-test_mean_score=1.000.ckpt
56
+ │   └── latest.ckpt
57
+ └── logs.json.txt
58
+ ```
59
+ The `metrics/logs.json.txt` file aggregates evaluation metrics from all 3 training runs every 50 epochs using `multirun_metrics.py`. The numbers reported in the paper correspond to `max` and `k_min_train_loss` aggregation keys.
60
+
61
+ To download all files in a subdirectory, use:
62
+
63
+ ```console
64
+ $ wget --recursive --no-parent --no-host-directories --relative --reject="index.html*" https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/square_ph/diffusion_policy_cnn/
65
+ ```
66
+
67
+ ## 🛠️ Installation
68
+ ### 🖥️ Simulation
69
+ To reproduce our simulation benchmark results, install our conda environment on a Linux machine with Nvidia GPU. On Ubuntu 20.04 you need to install the following apt packages for mujoco:
70
+ ```console
71
+ $ sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
72
+ ```
73
+
74
+ We recommend [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) instead of the standard anaconda distribution for faster installation:
75
+ ```console
76
+ $ mamba env create -f conda_environment.yaml
77
+ ```
78
+
79
+ but you can use conda as well:
80
+ ```console
81
+ $ conda env create -f conda_environment.yaml
82
+ ```
83
+
84
+ The `conda_environment_macos.yaml` file is only for development on MacOS and does not have full support for benchmarks.
85
+
86
+ ### 🦾 Real Robot
87
+ Hardware (for Push-T):
88
+ * 1x [UR5-CB3](https://www.universal-robots.com/cb3) or [UR5e](https://www.universal-robots.com/products/ur5-robot/) ([RTDE Interface](https://www.universal-robots.com/articles/ur/interface-communication/real-time-data-exchange-rtde-guide/) is required)
89
+ * 2x [RealSense D415](https://www.intelrealsense.com/depth-camera-d415/)
90
+ * 1x [3Dconnexion SpaceMouse](https://3dconnexion.com/us/product/spacemouse-wireless/) (for teleop)
91
+ * 1x [Millibar Robotics Manual Tool Changer](https://www.millibar.com/manual-tool-changer/) (only need robot side)
92
+ * 1x 3D printed [End effector](https://cad.onshape.com/documents/a818888644a15afa6cc68ee5/w/2885b48b018cda84f425beca/e/3e8771c2124cee024edd2fed?renderMode=0&uiState=63ffcba6631ca919895e64e5)
93
+ * 1x 3D printed [T-block](https://cad.onshape.com/documents/f1140134e38f6ed6902648d5/w/a78cf81827600e4ff4058d03/e/f35f57fb7589f72e05c76caf?renderMode=0&uiState=63ffcbc9af4a881b344898ee)
94
+ * USB-C cables and screws for RealSense
95
+
96
+ Software:
97
+ * Ubuntu 20.04.3 (tested)
98
+ * Mujoco dependencies:
99
+ `sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf`
100
+ * [RealSense SDK](https://github.com/IntelRealSense/librealsense/blob/master/doc/distribution_linux.md)
101
+ * Spacemouse dependencies:
102
+ `sudo apt install libspnav-dev spacenavd; sudo systemctl start spacenavd`
103
+ * Conda environment `mamba env create -f conda_environment_real.yaml`
104
+
105
+ ## 🖥️ Reproducing Simulation Benchmark Results
106
+ ### Download Training Data
107
+ Under the repo root, create data subdirectory:
108
+ ```console
109
+ [diffusion_policy]$ mkdir data && cd data
110
+ ```
111
+
112
+ Download the corresponding zip file from [https://diffusion-policy.cs.columbia.edu/data/training/](https://diffusion-policy.cs.columbia.edu/data/training/)
113
+ ```console
114
+ [data]$ wget https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip
115
+ ```
116
+
117
+ Extract training data:
118
+ ```console
119
+ [data]$ unzip pusht.zip && rm -f pusht.zip && cd ..
120
+ ```
121
+
122
+ Grab config file for the corresponding experiment:
123
+ ```console
124
+ [diffusion_policy]$ wget -O image_pusht_diffusion_policy_cnn.yaml https://diffusion-policy.cs.columbia.edu/data/experiments/image/pusht/diffusion_policy_cnn/config.yaml
125
+ ```
126
+
127
+ ### Running for a single seed
128
+ Activate conda environment and login to [wandb](https://wandb.ai) (if you haven't already).
129
+ ```console
130
+ [diffusion_policy]$ conda activate robodiff
131
+ (robodiff)[diffusion_policy]$ wandb login
132
+ ```
133
+
134
+ Launch training with seed 42 on GPU 0.
135
+ ```console
136
+ (robodiff)[diffusion_policy]$ python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml training.seed=42 training.device=cuda:0 hydra.run.dir='data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}'
137
+ ```
138
+
139
+ This will create a directory in format `data/outputs/yyyy.mm.dd/hh.mm.ss_<method_name>_<task_name>` where configs, logs and checkpoints are written to. The policy will be evaluated every 50 epochs with the success rate logged as `test/mean_score` on wandb, as well as videos for some rollouts.
140
+ ```console
141
+ (robodiff)[diffusion_policy]$ tree data/outputs/2023.03.01/20.02.03_train_diffusion_unet_hybrid_pusht_image -I wandb
142
+ data/outputs/2023.03.01/20.02.03_train_diffusion_unet_hybrid_pusht_image
143
+ ├── checkpoints
144
+ │ ├── epoch=0000-test_mean_score=0.134.ckpt
145
+ │ └── latest.ckpt
146
+ ├── .hydra
147
+ │ ├── config.yaml
148
+ │ ├── hydra.yaml
149
+ │ └── overrides.yaml
150
+ ├── logs.json.txt
151
+ ├── media
152
+ │ ├── 2k5u6wli.mp4
153
+ │ ├── 2kvovxms.mp4
154
+ │ ├── 2pxd9f6b.mp4
155
+ │ ├── 2q5gjt5f.mp4
156
+ │ ├── 2sawbf6m.mp4
157
+ │ └── 538ubl79.mp4
158
+ └── train.log
159
+
160
+ 3 directories, 13 files
161
+ ```
162
+
163
+ ### Running for multiple seeds
164
+ Launch local ray cluster. For large scale experiments, you might want to setup an [AWS cluster with autoscaling](https://docs.ray.io/en/master/cluster/vms/user-guides/launching-clusters/aws.html). All other commands remain the same.
165
+ ```console
166
+ (robodiff)[diffusion_policy]$ export CUDA_VISIBLE_DEVICES=0,1,2 # select GPUs to be managed by the ray cluster
167
+ (robodiff)[diffusion_policy]$ ray start --head --num-gpus=3
168
+ ```
169
+
170
+ Launch a ray client which will start 3 training workers (3 seeds) and 1 metrics monitor worker.
171
+ ```console
172
+ (robodiff)[diffusion_policy]$ python ray_train_multirun.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml --seeds=42,43,44 --monitor_key=test/mean_score -- multi_run.run_dir='data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' multi_run.wandb_name_base='${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}'
173
+ ```
174
+
175
+ In addition to the wandb log written by each training worker individually, the metrics monitor worker will log to wandb project `diffusion_policy_metrics` for the metrics aggregated from all 3 training runs. Local config, logs and checkpoints will be written to `data/outputs/yyyy.mm.dd/hh.mm.ss_<method_name>_<task_name>` in a directory structure identical to our [training logs](https://diffusion-policy.cs.columbia.edu/data/experiments/):
176
+ ```console
177
+ (robodiff)[diffusion_policy]$ tree data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image -I 'wandb|media'
178
+ data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image
179
+ ├── config.yaml
180
+ ├── metrics
181
+ │ ├── logs.json.txt
182
+ │ ├── metrics.json
183
+ │ └── metrics.log
184
+ ├── train_0
185
+ │ ├── checkpoints
186
+ │ │ ├── epoch=0000-test_mean_score=0.174.ckpt
187
+ │ │ └── latest.ckpt
188
+ │ ├── logs.json.txt
189
+ │ └── train.log
190
+ ├── train_1
191
+ │ ├── checkpoints
192
+ │ │ ├── epoch=0000-test_mean_score=0.131.ckpt
193
+ │ │ └── latest.ckpt
194
+ │ ├── logs.json.txt
195
+ │ └── train.log
196
+ └── train_2
197
+ ├── checkpoints
198
+ │ ├── epoch=0000-test_mean_score=0.105.ckpt
199
+ │ └── latest.ckpt
200
+ ├── logs.json.txt
201
+ └── train.log
202
+
203
+ 7 directories, 16 files
204
+ ```
205
+ ### 🆕 Evaluate Pre-trained Checkpoints
206
+ Download a checkpoint from the published training log folders, such as [https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt](https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt).
207
+
208
+ Run the evaluation script:
209
+ ```console
210
+ (robodiff)[diffusion_policy]$ python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
211
+ ```
212
+
213
+ This will generate the following directory structure:
214
+ ```console
215
+ (robodiff)[diffusion_policy]$ tree data/pusht_eval_output
216
+ data/pusht_eval_output
217
+ ├── eval_log.json
218
+ └── media
219
+ ├── 1fxtno84.mp4
220
+ ├── 224l7jqd.mp4
221
+ ├── 2fo4btlf.mp4
222
+ ├── 2in4cn7a.mp4
223
+ ├── 34b3o2qq.mp4
224
+ └── 3p7jqn32.mp4
225
+
226
+ 1 directory, 7 files
227
+ ```
228
+
229
+ `eval_log.json` contains metrics that is logged to wandb during training:
230
+ ```console
231
+ (robodiff)[diffusion_policy]$ cat data/pusht_eval_output/eval_log.json
232
+ {
233
+ "test/mean_score": 0.9150393806777066,
234
+ "test/sim_max_reward_4300000": 1.0,
235
+ "test/sim_max_reward_4300001": 0.9872969750774386,
236
+ ...
237
+ "train/sim_video_1": "data/pusht_eval_output//media/2fo4btlf.mp4"
238
+ }
239
+ ```
240
+
241
+ ## 🦾 Demo, Training and Eval on a Real Robot
242
+ Make sure your UR5 robot is running and accepting command from its network interface (emergency stop button within reach at all time), your RealSense cameras plugged in to your workstation (tested with `realsense-viewer`) and your SpaceMouse connected with the `spacenavd` daemon running (verify with `systemctl status spacenavd`).
243
+
244
+ Start the demonstration collection script. Press "C" to start recording. Use SpaceMouse to move the robot. Press "S" to stop recording.
245
+ ```console
246
+ (robodiff)[diffusion_policy]$ python demo_real_robot.py -o data/demo_pusht_real --robot_ip 192.168.0.204
247
+ ```
248
+
249
+ This should result in a demonstration dataset in `data/demo_pusht_real` with in the same structure as our example [real Push-T training dataset](https://diffusion-policy.cs.columbia.edu/data/training/pusht_real.zip).
250
+
251
+ To train a Diffusion Policy, launch training with config:
252
+ ```console
253
+ (robodiff)[diffusion_policy]$ python train.py --config-name=train_diffusion_unet_real_image_workspace task.dataset_path=data/demo_pusht_real
254
+ ```
255
+ Edit [`diffusion_policy/config/task/real_pusht_image.yaml`](./diffusion_policy/config/task/real_pusht_image.yaml) if your camera setup is different.
256
+
257
+ Assuming the training has finished and you have a checkpoint at `data/outputs/blah/checkpoints/latest.ckpt`, launch the evaluation script with:
258
+ ```console
259
+ python eval_real_robot.py -i data/outputs/blah/checkpoints/latest.ckpt -o data/eval_pusht_real --robot_ip 192.168.0.204
260
+ ```
261
+ Press "C" to start evaluation (handing control over to the policy). Press "S" to stop the current episode.
262
+
263
+ ## 🗺️ Codebase Tutorial
264
+ This codebase is structured under the requirement that:
265
+ 1. implementing `N` tasks and `M` methods will only require `O(N+M)` amount of code instead of `O(N*M)`
266
+ 2. while retaining maximum flexibility.
267
+
268
+ To achieve this requirement, we
269
+ 1. maintained a simple unified interface between tasks and methods and
270
+ 2. made the implementation of the tasks and the methods independent of each other.
271
+
272
+ These design decisions come at the cost of code repetition between the tasks and the methods. However, we believe that the benefit of being able to add/modify task/methods without affecting the remainder and being able understand a task/method by reading the code linearly outweighs the cost of copying and pasting 😊.
273
+
274
+ ### The Split
275
+ On the task side, we have:
276
+ * `Dataset`: adapts a (third-party) dataset to the interface.
277
+ * `EnvRunner`: executes a `Policy` that accepts the interface and produce logs and metrics.
278
+ * `config/task/<task_name>.yaml`: contains all information needed to construct `Dataset` and `EnvRunner`.
279
+ * (optional) `Env`: an `gym==0.21.0` compatible class that encapsulates the task environment.
280
+
281
+ On the policy side, we have:
282
+ * `Policy`: implements inference according to the interface and part of the training process.
283
+ * `Workspace`: manages the life-cycle of training and evaluation (interleaved) of a method.
284
+ * `config/<workspace_name>.yaml`: contains all information needed to construct `Policy` and `Workspace`.
285
+
286
+ ### The Interface
287
+ #### Low Dim
288
+ A [`LowdimPolicy`](./diffusion_policy/policy/base_lowdim_policy.py) takes observation dictionary:
289
+ - `"obs":` Tensor of shape `(B,To,Do)`
290
+
291
+ and predicts action dictionary:
292
+ - `"action": ` Tensor of shape `(B,Ta,Da)`
293
+
294
+ A [`LowdimDataset`](./diffusion_policy/dataset/base_dataset.py) returns a sample of dictionary:
295
+ - `"obs":` Tensor of shape `(To, Do)`
296
+ - `"action":` Tensor of shape `(Ta, Da)`
297
+
298
+ Its `get_normalizer` method returns a [`LinearNormalizer`](./diffusion_policy/model/common/normalizer.py) with keys `"obs","action"`.
299
+
300
+ The `Policy` handles normalization on GPU with its copy of the `LinearNormalizer`. The parameters of the `LinearNormalizer` is saved as part of the `Policy`'s weights checkpoint.
301
+
302
+ #### Image
303
+ A [`ImagePolicy`](./diffusion_policy/policy/base_image_policy.py) takes observation dictionary:
304
+ - `"key0":` Tensor of shape `(B,To,*)`
305
+ - `"key1":` Tensor of shape e.g. `(B,To,H,W,3)` ([0,1] float32)
306
+
307
+ and predicts action dictionary:
308
+ - `"action": ` Tensor of shape `(B,Ta,Da)`
309
+
310
+ A [`ImageDataset`](./diffusion_policy/dataset/base_dataset.py) returns a sample of dictionary:
311
+ - `"obs":` Dict of
312
+ - `"key0":` Tensor of shape `(To, *)`
313
+ - `"key1":` Tensor fo shape `(To,H,W,3)`
314
+ - `"action":` Tensor of shape `(Ta, Da)`
315
+
316
+ Its `get_normalizer` method returns a [`LinearNormalizer`](./diffusion_policy/model/common/normalizer.py) with keys `"key0","key1","action"`.
317
+
318
+ #### Example
319
+ ```
320
+ To = 3
321
+ Ta = 4
322
+ T = 6
323
+ |o|o|o|
324
+ | | |a|a|a|a|
325
+ |o|o|
326
+ | |a|a|a|a|a|
327
+ | | | | |a|a|
328
+ ```
329
+ Terminology in the paper: `varname` in the codebase
330
+ - Observation Horizon: `To|n_obs_steps`
331
+ - Action Horizon: `Ta|n_action_steps`
332
+ - Prediction Horizon: `T|horizon`
333
+
334
+ The classical (e.g. MDP) single step observation/action formulation is included as a special case where `To=1` and `Ta=1`.
335
+
336
+ ## 🔩 Key Components
337
+ ### `Workspace`
338
+ A `Workspace` object encapsulates all states and code needed to run an experiment.
339
+ * Inherits from [`BaseWorkspace`](./diffusion_policy/workspace/base_workspace.py).
340
+ * A single `OmegaConf` config object generated by `hydra` should contain all information needed to construct the Workspace object and running experiments. This config correspond to `config/<workspace_name>.yaml` + hydra overrides.
341
+ * The `run` method contains the entire pipeline for the experiment.
342
+ * Checkpoints happen at the `Workspace` level. All training states implemented as object attributes are automatically saved by the `save_checkpoint` method.
343
+ * All other states for the experiment should be implemented as local variables in the `run` method.
344
+
345
+ The entrypoint for training is `train.py` which uses `@hydra.main` decorator. Read [hydra](https://hydra.cc/)'s official documentation for command line arguments and config overrides. For example, the argument `task=<task_name>` will replace the `task` subtree of the config with the content of `config/task/<task_name>.yaml`, thereby selecting the task to run for this experiment.
346
+
347
+ ### `Dataset`
348
+ A `Dataset` object:
349
+ * Inherits from `torch.utils.data.Dataset`.
350
+ * Returns a sample conforming to [the interface](#the-interface) depending on whether the task has Low Dim or Image observations.
351
+ * Has a method `get_normalizer` that returns a `LinearNormalizer` conforming to [the interface](#the-interface).
352
+
353
+ Normalization is a very common source of bugs during project development. It is sometimes helpful to print out the specific `scale` and `bias` vectors used for each key in the `LinearNormalizer`.
354
+
355
+ Most of our implementations of `Dataset` uses a combination of [`ReplayBuffer`](#replaybuffer) and [`SequenceSampler`](./diffusion_policy/common/sampler.py) to generate samples. Correctly handling padding at the beginning and the end of each demonstration episode according to `To` and `Ta` is important for good performance. Please read our [`SequenceSampler`](./diffusion_policy/common/sampler.py) before implementing your own sampling method.
356
+
357
+ ### `Policy`
358
+ A `Policy` object:
359
+ * Inherits from `BaseLowdimPolicy` or `BaseImagePolicy`.
360
+ * Has a method `predict_action` that given observation dict, predicts actions conforming to [the interface](#the-interface).
361
+ * Has a method `set_normalizer` that takes in a `LinearNormalizer` and handles observation/action normalization internally in the policy.
362
+ * (optional) Might has a method `compute_loss` that takes in a batch and returns the loss to be optimized.
363
+ * (optional) Usually each `Policy` class correspond to a `Workspace` class due to the differences of training and evaluation process between methods.
364
+
365
+ ### `EnvRunner`
366
+ A `EnvRunner` object abstracts away the subtle differences between different task environments.
367
+ * Has a method `run` that takes a `Policy` object for evaluation, and returns a dict of logs and metrics. Each value should be compatible with `wandb.log`.
368
+
369
+ To maximize evaluation speed, we usually vectorize environments using our modification of [`gym.vector.AsyncVectorEnv`](./diffusion_policy/gym_util/async_vector_env.py) which runs each individual environment in a separate process (workaround python GIL).
370
+
371
+ ⚠️ Since subprocesses are launched using `fork` on linux, you need to be specially careful for environments that creates its OpenGL context during initialization (e.g. robosuite) which, once inherited by the child process memory space, often causes obscure bugs like segmentation fault. As a workaround, you can provide a `dummy_env_fn` that constructs an environment without initializing OpenGL.
372
+
373
+ ### `ReplayBuffer`
374
+ The [`ReplayBuffer`](./diffusion_policy/common/replay_buffer.py) is a key data structure for storing a demonstration dataset both in-memory and on-disk with chunking and compression. It makes heavy use of the [`zarr`](https://zarr.readthedocs.io/en/stable/index.html) format but also has a `numpy` backend for lower access overhead.
375
+
376
+ On disk, it can be stored as a nested directory (e.g. `data/pusht_cchi_v7_replay.zarr`) or a zip file (e.g. `data/robomimic/datasets/square/mh/image_abs.hdf5.zarr.zip`).
377
+
378
+ Due to the relative small size of our datasets, it's often possible to store the entire image-based dataset in RAM with [`Jpeg2000` compression](./diffusion_policy/codecs/imagecodecs_numcodecs.py) which eliminates disk IO during training at the expense increasing of CPU workload.
379
+
380
+ Example:
381
+ ```
382
+ data/pusht_cchi_v7_replay.zarr
383
+ ├── data
384
+ │ ├── action (25650, 2) float32
385
+ │ ├── img (25650, 96, 96, 3) float32
386
+ │ ├── keypoint (25650, 9, 2) float32
387
+ │ ├── n_contacts (25650, 1) float32
388
+ │ └── state (25650, 5) float32
389
+ └── meta
390
+ └── episode_ends (206,) int64
391
+ ```
392
+
393
+ Each array in `data` stores one data field from all episodes concatenated along the first dimension (time). The `meta/episode_ends` array stores the end index for each episode along the fist dimension.
394
+
395
+ ### `SharedMemoryRingBuffer`
396
+ The [`SharedMemoryRingBuffer`](./diffusion_policy/shared_memory/shared_memory_ring_buffer.py) is a lock-free FILO data structure used extensively in our [real robot implementation](./diffusion_policy/real_world) to utilize multiple CPU cores while avoiding pickle serialization and locking overhead for `multiprocessing.Queue`.
397
+
398
+ As an example, we would like to get the most recent `To` frames from 5 RealSense cameras. We launch 1 realsense SDK/pipeline per process using [`SingleRealsense`](./diffusion_policy/real_world/single_realsense.py), each continuously writes the captured images into a `SharedMemoryRingBuffer` shared with the main process. We can very quickly get the last `To` frames in the main process due to the FILO nature of `SharedMemoryRingBuffer`.
399
+
400
+ We also implemented [`SharedMemoryQueue`](./diffusion_policy/shared_memory/shared_memory_queue.py) for FIFO, which is used in [`RTDEInterpolationController`](./diffusion_policy/real_world/rtde_interpolation_controller.py).
401
+
402
+ ### `RealEnv`
403
+ In contrast to [OpenAI Gym](https://gymnasium.farama.org/), our polices interact with the environment asynchronously. In [`RealEnv`](./diffusion_policy/real_world/real_env.py), the `step` method in `gym` is split into two methods: `get_obs` and `exec_actions`.
404
+
405
+ The `get_obs` method returns the latest observation from `SharedMemoryRingBuffer` as well as their corresponding timestamps. This method can be call at any time during an evaluation episode.
406
+
407
+ The `exec_actions` method accepts a sequence of actions and timestamps for the expected time of execution for each step. Once called, the actions are simply enqueued to the `RTDEInterpolationController`, and the method returns without blocking for execution.
408
+
409
+ ## 🩹 Adding a Task
410
+ Read and imitate:
411
+ * `diffusion_policy/dataset/pusht_image_dataset.py`
412
+ * `diffusion_policy/env_runner/pusht_image_runner.py`
413
+ * `diffusion_policy/config/task/pusht_image.yaml`
414
+
415
+ Make sure that `shape_meta` correspond to input and output shapes for your task. Make sure `env_runner._target_` and `dataset._target_` point to the new classes you have added. When training, add `task=<your_task_name>` to `train.py`'s arguments.
416
+
417
+ ## 🩹 Adding a Method
418
+ Read and imitate:
419
+ * `diffusion_policy/workspace/train_diffusion_unet_image_workspace.py`
420
+ * `diffusion_policy/policy/diffusion_unet_image_policy.py`
421
+ * `diffusion_policy/config/train_diffusion_unet_image_workspace.yaml`
422
+
423
+ Make sure your workspace yaml's `_target_` points to the new workspace class you created.
424
+
425
+ ## 🏷️ License
426
+ This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.
427
+
428
+ ## 🙏 Acknowledgement
429
+ * Our [`ConditionalUnet1D`](./diffusion_policy/model/diffusion/conditional_unet1d.py) implementation is adapted from [Planning with Diffusion](https://github.com/jannerm/diffuser).
430
+ * Our [`TransformerForDiffusion`](./diffusion_policy/model/diffusion/transformer_for_diffusion.py) implementation is adapted from [MinGPT](https://github.com/karpathy/minGPT).
431
+ * The [BET](./diffusion_policy/model/bet) baseline is adapted from [its original repo](https://github.com/notmahi/bet).
432
+ * The [IBC](./diffusion_policy/policy/ibc_dfo_lowdim_policy.py) baseline is adapted from [Kevin Zakka's reimplementation](https://github.com/kevinzakka/ibc).
433
+ * The [Robomimic](https://github.com/ARISE-Initiative/robomimic) tasks and [`ObservationEncoder`](https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/models/obs_nets.py) are used extensively in this project.
434
+ * The [Push-T](./diffusion_policy/env/pusht) task is adapted from [IBC](https://github.com/google-research/ibc).
435
+ * The [Block Pushing](./diffusion_policy/env/block_pushing) task is adapted from [BET](https://github.com/notmahi/bet) and [IBC](https://github.com/google-research/ibc).
436
+ * The [Kitchen](./diffusion_policy/env/kitchen) task is adapted from [BET](https://github.com/notmahi/bet) and [Relay Policy Learning](https://github.com/google-research/relay-policy-learning).
437
+ * Our [shared_memory](./diffusion_policy/shared_memory) data structures are heavily inspired by [shared-ndarray2](https://gitlab.com/osu-nrsg/shared-ndarray2).
third_party/diffusion_policy/conda_environment.yaml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: robodiff
2
+ channels:
3
+ - pytorch
4
+ - pytorch3d
5
+ - nvidia
6
+ - conda-forge
7
+ dependencies:
8
+ - python=3.9
9
+ - pip=22.2.2
10
+ - cudatoolkit=11.6
11
+ - pytorch=1.12.1
12
+ - torchvision=0.13.1
13
+ - pytorch3d=0.7.0
14
+ - numpy=1.23.3
15
+ - numba==0.56.4
16
+ - scipy==1.9.1
17
+ - py-opencv=4.6.0
18
+ - cffi=1.15.1
19
+ - ipykernel=6.16
20
+ - matplotlib=3.6.1
21
+ - zarr=2.12.0
22
+ - numcodecs=0.10.2
23
+ - h5py=3.7.0
24
+ - hydra-core=1.2.0
25
+ - einops=0.4.1
26
+ - tqdm=4.64.1
27
+ - dill=0.3.5.1
28
+ - scikit-video=1.1.11
29
+ - scikit-image=0.19.3
30
+ - gym=0.21.0
31
+ - pymunk=6.2.1
32
+ - wandb=0.13.3
33
+ - threadpoolctl=3.1.0
34
+ - shapely=1.8.4
35
+ - cython=0.29.32
36
+ - imageio=2.22.0
37
+ - imageio-ffmpeg=0.4.7
38
+ - termcolor=2.0.1
39
+ - tensorboard=2.10.1
40
+ - tensorboardx=2.5.1
41
+ - psutil=5.9.2
42
+ - click=8.0.4
43
+ - boto3=1.24.96
44
+ - accelerate=0.13.2
45
+ - datasets=2.6.1
46
+ - diffusers=0.11.1
47
+ - av=10.0.0
48
+ - cmake=3.24.3
49
+ # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
50
+ - llvm-openmp=14
51
+ # trick to force reinstall imagecodecs via pip
52
+ - imagecodecs==2022.8.8
53
+ - pip:
54
+ - ray[default,tune]==2.2.0
55
+ # requires mujoco py dependencies libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
56
+ - free-mujoco-py==2.1.6
57
+ - pygame==2.1.2
58
+ - pybullet-svl==3.1.6.4
59
+ - robosuite @ https://github.com/cheng-chi/robosuite/archive/277ab9588ad7a4f4b55cf75508b44aa67ec171f0.tar.gz
60
+ - robomimic==0.2.0
61
+ - pytorchvideo==0.1.5
62
+ # pip package required for jpeg-xl
63
+ - imagecodecs==2022.9.26
64
+ - r3m @ https://github.com/facebookresearch/r3m/archive/b2334e726887fa0206962d7984c69c5fb09cceab.tar.gz
65
+ - dm-control==1.0.9
third_party/diffusion_policy/conda_environment_macos.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: robodiff
2
+ channels:
3
+ - pytorch
4
+ - conda-forge
5
+ dependencies:
6
+ - python=3.9
7
+ - pip=22.2.2
8
+ - pytorch=1.12.1
9
+ - torchvision=0.13.1
10
+ - numpy=1.23.3
11
+ - numba==0.56.4
12
+ - scipy==1.9.1
13
+ - py-opencv=4.6.0
14
+ - cffi=1.15.1
15
+ - ipykernel=6.16
16
+ - matplotlib=3.6.1
17
+ - zarr=2.12.0
18
+ - numcodecs=0.10.2
19
+ - h5py=3.7.0
20
+ - hydra-core=1.2.0
21
+ - einops=0.4.1
22
+ - tqdm=4.64.1
23
+ - dill=0.3.5.1
24
+ - scikit-video=1.1.11
25
+ - scikit-image=0.19.3
26
+ - gym=0.21.0
27
+ - pymunk=6.2.1
28
+ - wandb=0.13.3
29
+ - threadpoolctl=3.1.0
30
+ - shapely=1.8.4
31
+ - cython=0.29.32
32
+ - imageio=2.22.0
33
+ - imageio-ffmpeg=0.4.7
34
+ - termcolor=2.0.1
35
+ - tensorboard=2.10.1
36
+ - tensorboardx=2.5.1
37
+ - psutil=5.9.2
38
+ - click=8.0.4
39
+ - boto3=1.24.96
40
+ - accelerate=0.13.2
41
+ - datasets=2.6.1
42
+ - diffusers=0.11.1
43
+ - av=10.0.0
44
+ - cmake=3.24.3
45
+ # trick to force reinstall imagecodecs via pip
46
+ - imagecodecs==2022.8.8
47
+ - pip:
48
+ - ray[default,tune]==2.2.0
49
+ - pygame==2.1.2
50
+ - robomimic==0.2.0
51
+ - pytorchvideo==0.1.5
52
+ - atomics==1.0.2
53
+ # No support for jpeg-xl for MacOS
54
+ - imagecodecs==2022.9.26
55
+ - r3m @ https://github.com/facebookresearch/r3m/archive/b2334e726887fa0206962d7984c69c5fb09cceab.tar.gz
third_party/diffusion_policy/conda_environment_real.yaml ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: robodiff
2
+ channels:
3
+ - pytorch
4
+ - pytorch3d
5
+ - nvidia
6
+ - conda-forge
7
+ dependencies:
8
+ - python=3.9
9
+ - pip=22.2.2
10
+ - cudatoolkit=11.6
11
+ - pytorch=1.12.1
12
+ - torchvision=0.13.1
13
+ - pytorch3d=0.7.0
14
+ - numpy=1.23.3
15
+ - numba==0.56.4
16
+ - scipy==1.9.1
17
+ - py-opencv=4.6.0
18
+ - cffi=1.15.1
19
+ - ipykernel=6.16
20
+ - matplotlib=3.6.1
21
+ - zarr=2.12.0
22
+ - numcodecs=0.10.2
23
+ - h5py=3.7.0
24
+ - hydra-core=1.2.0
25
+ - einops=0.4.1
26
+ - tqdm=4.64.1
27
+ - dill=0.3.5.1
28
+ - scikit-video=1.1.11
29
+ - scikit-image=0.19.3
30
+ - gym=0.21.0
31
+ - pymunk=6.2.1
32
+ - wandb=0.13.3
33
+ - threadpoolctl=3.1.0
34
+ - shapely=1.8.4
35
+ - cython=0.29.32
36
+ - imageio=2.22.0
37
+ - imageio-ffmpeg=0.4.7
38
+ - termcolor=2.0.1
39
+ - tensorboard=2.10.1
40
+ - tensorboardx=2.5.1
41
+ - psutil=5.9.2
42
+ - click=8.0.4
43
+ - boto3=1.24.96
44
+ - accelerate=0.13.2
45
+ - datasets=2.6.1
46
+ - diffusers=0.11.1
47
+ - av=10.0.0
48
+ - cmake=3.24.3
49
+ # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
50
+ - llvm-openmp=14
51
+ # trick to force reinstall imagecodecs via pip
52
+ - imagecodecs==2022.8.8
53
+ - pip:
54
+ - ray[default,tune]==2.2.0
55
+ # requires mujoco py dependencies libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
56
+ - free-mujoco-py==2.1.6
57
+ - pygame==2.1.2
58
+ - pybullet-svl==3.1.6.4
59
+ - robosuite @ https://github.com/cheng-chi/robosuite/archive/277ab9588ad7a4f4b55cf75508b44aa67ec171f0.tar.gz
60
+ - robomimic==0.2.0
61
+ - pytorchvideo==0.1.5
62
+ # requires librealsense https://github.com/IntelRealSense/librealsense/blob/master/doc/distribution_linux.md
63
+ - pyrealsense2==2.51.1.4348
64
+ # requires apt install libspnav-dev spacenavd; systemctl start spacenavd
65
+ - spnav @ https://github.com/cheng-chi/spnav/archive/c1c938ebe3cc542db4685e0d13850ff1abfdb943.tar.gz
66
+ - ur-rtde==1.5.5
67
+ - atomics==1.0.2
68
+ # pip package required for jpeg-xl
69
+ - imagecodecs==2022.9.26
70
+ - r3m @ https://github.com/facebookresearch/r3m/archive/b2334e726887fa0206962d7984c69c5fb09cceab.tar.gz
71
+ - dm-control==1.0.9
72
+ - pynput==1.7.6
73
+
third_party/diffusion_policy/demo_pusht.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import click
3
+ from diffusion_policy.common.replay_buffer import ReplayBuffer
4
+ from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
5
+ import pygame
6
+
7
+ @click.command()
8
+ @click.option('-o', '--output', required=True)
9
+ @click.option('-rs', '--render_size', default=96, type=int)
10
+ @click.option('-hz', '--control_hz', default=10, type=int)
11
+ def main(output, render_size, control_hz):
12
+ """
13
+ Collect demonstration for the Push-T task.
14
+
15
+ Usage: python demo_pusht.py -o data/pusht_demo.zarr
16
+
17
+ This script is compatible with both Linux and MacOS.
18
+ Hover mouse close to the blue circle to start.
19
+ Push the T block into the green area.
20
+ The episode will automatically terminate if the task is succeeded.
21
+ Press "Q" to exit.
22
+ Press "R" to retry.
23
+ Hold "Space" to pause.
24
+ """
25
+
26
+ # create replay buffer in read-write mode
27
+ replay_buffer = ReplayBuffer.create_from_path(output, mode='a')
28
+
29
+ # create PushT env with keypoints
30
+ kp_kwargs = PushTKeypointsEnv.genenerate_keypoint_manager_params()
31
+ env = PushTKeypointsEnv(render_size=render_size, render_action=False, **kp_kwargs)
32
+ agent = env.teleop_agent()
33
+ clock = pygame.time.Clock()
34
+
35
+ # episode-level while loop
36
+ while True:
37
+ episode = list()
38
+ # record in seed order, starting with 0
39
+ seed = replay_buffer.n_episodes
40
+ print(f'starting seed {seed}')
41
+
42
+ # set seed for env
43
+ env.seed(seed)
44
+
45
+ # reset env and get observations (including info and render for recording)
46
+ obs = env.reset()
47
+ info = env._get_info()
48
+ img = env.render(mode='human')
49
+
50
+ # loop state
51
+ retry = False
52
+ pause = False
53
+ done = False
54
+ plan_idx = 0
55
+ pygame.display.set_caption(f'plan_idx:{plan_idx}')
56
+ # step-level while loop
57
+ while not done:
58
+ # process keypress events
59
+ for event in pygame.event.get():
60
+ if event.type == pygame.KEYDOWN:
61
+ if event.key == pygame.K_SPACE:
62
+ # hold Space to pause
63
+ plan_idx += 1
64
+ pygame.display.set_caption(f'plan_idx:{plan_idx}')
65
+ pause = True
66
+ elif event.key == pygame.K_r:
67
+ # press "R" to retry
68
+ retry=True
69
+ elif event.key == pygame.K_q:
70
+ # press "Q" to exit
71
+ exit(0)
72
+ if event.type == pygame.KEYUP:
73
+ if event.key == pygame.K_SPACE:
74
+ pause = False
75
+
76
+ # handle control flow
77
+ if retry:
78
+ break
79
+ if pause:
80
+ continue
81
+
82
+ # get action from mouse
83
+ # None if mouse is not close to the agent
84
+ act = agent.act(obs)
85
+ if not act is None:
86
+ # teleop started
87
+ # state dim 2+3
88
+ state = np.concatenate([info['pos_agent'], info['block_pose']])
89
+ # discard unused information such as visibility mask and agent pos
90
+ # for compatibility
91
+ keypoint = obs.reshape(2,-1)[0].reshape(-1,2)[:9]
92
+ data = {
93
+ 'img': img,
94
+ 'state': np.float32(state),
95
+ 'keypoint': np.float32(keypoint),
96
+ 'action': np.float32(act),
97
+ 'n_contacts': np.float32([info['n_contacts']])
98
+ }
99
+ episode.append(data)
100
+
101
+ # step env and render
102
+ obs, reward, done, info = env.step(act)
103
+ img = env.render(mode='human')
104
+
105
+ # regulate control frequency
106
+ clock.tick(control_hz)
107
+ if not retry:
108
+ # save episode buffer to replay buffer (on disk)
109
+ data_dict = dict()
110
+ for key in episode[0].keys():
111
+ data_dict[key] = np.stack(
112
+ [x[key] for x in episode])
113
+ replay_buffer.add_episode(data_dict, compressors='disk')
114
+ print(f'saved seed {seed}')
115
+ else:
116
+ print(f'retry seed {seed}')
117
+
118
+
119
+ if __name__ == "__main__":
120
+ main()
third_party/diffusion_policy/demo_real_robot.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ (robodiff)$ python demo_real_robot.py -o <demo_save_dir> --robot_ip <ip_of_ur5>
4
+
5
+ Robot movement:
6
+ Move your SpaceMouse to move the robot EEF (locked in xy plane).
7
+ Press SpaceMouse right button to unlock z axis.
8
+ Press SpaceMouse left button to enable rotation axes.
9
+
10
+ Recording control:
11
+ Click the opencv window (make sure it's in focus).
12
+ Press "C" to start recording.
13
+ Press "S" to stop recording.
14
+ Press "Q" to exit program.
15
+ Press "Backspace" to delete the previously recorded episode.
16
+ """
17
+
18
+ # %%
19
+ import time
20
+ from multiprocessing.managers import SharedMemoryManager
21
+ import click
22
+ import cv2
23
+ import numpy as np
24
+ import scipy.spatial.transform as st
25
+ from diffusion_policy.real_world.real_env import RealEnv
26
+ from diffusion_policy.real_world.spacemouse_shared_memory import Spacemouse
27
+ from diffusion_policy.common.precise_sleep import precise_wait
28
+ from diffusion_policy.real_world.keystroke_counter import (
29
+ KeystrokeCounter, Key, KeyCode
30
+ )
31
+
32
+ @click.command()
33
+ @click.option('--output', '-o', required=True, help="Directory to save demonstration dataset.")
34
+ @click.option('--robot_ip', '-ri', required=True, help="UR5's IP address e.g. 192.168.0.204")
35
+ @click.option('--vis_camera_idx', default=0, type=int, help="Which RealSense camera to visualize.")
36
+ @click.option('--init_joints', '-j', is_flag=True, default=False, help="Whether to initialize robot joint configuration in the beginning.")
37
+ @click.option('--frequency', '-f', default=10, type=float, help="Control frequency in Hz.")
38
+ @click.option('--command_latency', '-cl', default=0.01, type=float, help="Latency between receiving SapceMouse command to executing on Robot in Sec.")
39
+ def main(output, robot_ip, vis_camera_idx, init_joints, frequency, command_latency):
40
+ dt = 1/frequency
41
+ with SharedMemoryManager() as shm_manager:
42
+ with KeystrokeCounter() as key_counter, \
43
+ Spacemouse(shm_manager=shm_manager) as sm, \
44
+ RealEnv(
45
+ output_dir=output,
46
+ robot_ip=robot_ip,
47
+ # recording resolution
48
+ obs_image_resolution=(1280,720),
49
+ frequency=frequency,
50
+ init_joints=init_joints,
51
+ enable_multi_cam_vis=True,
52
+ record_raw_video=True,
53
+ # number of threads per camera view for video recording (H.264)
54
+ thread_per_video=3,
55
+ # video recording quality, lower is better (but slower).
56
+ video_crf=21,
57
+ shm_manager=shm_manager
58
+ ) as env:
59
+ cv2.setNumThreads(1)
60
+
61
+ # realsense exposure
62
+ env.realsense.set_exposure(exposure=120, gain=0)
63
+ # realsense white balance
64
+ env.realsense.set_white_balance(white_balance=5900)
65
+
66
+ time.sleep(1.0)
67
+ print('Ready!')
68
+ state = env.get_robot_state()
69
+ target_pose = state['TargetTCPPose']
70
+ t_start = time.monotonic()
71
+ iter_idx = 0
72
+ stop = False
73
+ is_recording = False
74
+ while not stop:
75
+ # calculate timing
76
+ t_cycle_end = t_start + (iter_idx + 1) * dt
77
+ t_sample = t_cycle_end - command_latency
78
+ t_command_target = t_cycle_end + dt
79
+
80
+ # pump obs
81
+ obs = env.get_obs()
82
+
83
+ # handle key presses
84
+ press_events = key_counter.get_press_events()
85
+ for key_stroke in press_events:
86
+ if key_stroke == KeyCode(char='q'):
87
+ # Exit program
88
+ stop = True
89
+ elif key_stroke == KeyCode(char='c'):
90
+ # Start recording
91
+ env.start_episode(t_start + (iter_idx + 2) * dt - time.monotonic() + time.time())
92
+ key_counter.clear()
93
+ is_recording = True
94
+ print('Recording!')
95
+ elif key_stroke == KeyCode(char='s'):
96
+ # Stop recording
97
+ env.end_episode()
98
+ key_counter.clear()
99
+ is_recording = False
100
+ print('Stopped.')
101
+ elif key_stroke == Key.backspace:
102
+ # Delete the most recent recorded episode
103
+ if click.confirm('Are you sure to drop an episode?'):
104
+ env.drop_episode()
105
+ key_counter.clear()
106
+ is_recording = False
107
+ # delete
108
+ stage = key_counter[Key.space]
109
+
110
+ # visualize
111
+ vis_img = obs[f'camera_{vis_camera_idx}'][-1,:,:,::-1].copy()
112
+ episode_id = env.replay_buffer.n_episodes
113
+ text = f'Episode: {episode_id}, Stage: {stage}'
114
+ if is_recording:
115
+ text += ', Recording!'
116
+ cv2.putText(
117
+ vis_img,
118
+ text,
119
+ (10,30),
120
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
121
+ fontScale=1,
122
+ thickness=2,
123
+ color=(255,255,255)
124
+ )
125
+
126
+ cv2.imshow('default', vis_img)
127
+ cv2.pollKey()
128
+
129
+ precise_wait(t_sample)
130
+ # get teleop command
131
+ sm_state = sm.get_motion_state_transformed()
132
+ # print(sm_state)
133
+ dpos = sm_state[:3] * (env.max_pos_speed / frequency)
134
+ drot_xyz = sm_state[3:] * (env.max_rot_speed / frequency)
135
+
136
+ if not sm.is_button_pressed(0):
137
+ # translation mode
138
+ drot_xyz[:] = 0
139
+ else:
140
+ dpos[:] = 0
141
+ if not sm.is_button_pressed(1):
142
+ # 2D translation mode
143
+ dpos[2] = 0
144
+
145
+ drot = st.Rotation.from_euler('xyz', drot_xyz)
146
+ target_pose[:3] += dpos
147
+ target_pose[3:] = (drot * st.Rotation.from_rotvec(
148
+ target_pose[3:])).as_rotvec()
149
+
150
+ # execute teleop command
151
+ env.exec_actions(
152
+ actions=[target_pose],
153
+ timestamps=[t_command_target-time.monotonic()+time.time()],
154
+ stages=[stage])
155
+ precise_wait(t_cycle_end)
156
+ iter_idx += 1
157
+
158
+ # %%
159
+ if __name__ == '__main__':
160
+ main()
third_party/diffusion_policy/diffusion_policy.egg-info/PKG-INFO ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: diffusion_policy
3
+ Version: 0.0.0
4
+ License-File: LICENSE
5
+ Dynamic: license-file
third_party/diffusion_policy/diffusion_policy.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ diffusion_policy.egg-info/PKG-INFO
5
+ diffusion_policy.egg-info/SOURCES.txt
6
+ diffusion_policy.egg-info/dependency_links.txt
7
+ diffusion_policy.egg-info/top_level.txt
8
+ tests/test_block_pushing.py
9
+ tests/test_cv2_util.py
10
+ tests/test_multi_realsense.py
11
+ tests/test_pose_trajectory_interpolator.py
12
+ tests/test_precise_sleep.py
13
+ tests/test_replay_buffer.py
14
+ tests/test_ring_buffer.py
15
+ tests/test_robomimic_image_runner.py
16
+ tests/test_robomimic_lowdim_runner.py
17
+ tests/test_shared_queue.py
18
+ tests/test_single_realsense.py
19
+ tests/test_timestamp_accumulator.py
third_party/diffusion_policy/diffusion_policy.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
third_party/diffusion_policy/diffusion_policy.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
third_party/diffusion_policy/diffusion_policy/workspace/train_diffusion_unet_lowdim_workspace.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ if __name__ == "__main__":
2
+ import sys
3
+ import os
4
+ import pathlib
5
+
6
+ ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7
+ sys.path.append(ROOT_DIR)
8
+ os.chdir(ROOT_DIR)
9
+
10
+ import os
11
+ import hydra
12
+ import torch
13
+ from omegaconf import OmegaConf
14
+ import pathlib
15
+ from torch.utils.data import DataLoader
16
+ import copy
17
+ import numpy as np
18
+ import random
19
+ import wandb
20
+ import tqdm
21
+ import shutil
22
+
23
+ from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
24
+ from diffusion_policy.workspace.base_workspace import BaseWorkspace
25
+ from diffusion_policy.policy.diffusion_unet_lowdim_policy import DiffusionUnetLowdimPolicy
26
+ from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
27
+ from diffusion_policy.env_runner.base_lowdim_runner import BaseLowdimRunner
28
+ from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
29
+ from diffusion_policy.common.json_logger import JsonLogger
30
+ from diffusion_policy.model.common.lr_scheduler import get_scheduler
31
+ from diffusers.training_utils import EMAModel
32
+
33
+ OmegaConf.register_new_resolver("eval", eval, replace=True)
34
+
35
+ # %%
36
+ class TrainDiffusionUnetLowdimWorkspace(BaseWorkspace):
37
+ include_keys = ['global_step', 'epoch']
38
+
39
+ def __init__(self, cfg: OmegaConf, output_dir=None):
40
+ super().__init__(cfg, output_dir=output_dir)
41
+
42
+ # set seed
43
+ seed = cfg.training.seed
44
+ torch.manual_seed(seed)
45
+ np.random.seed(seed)
46
+ random.seed(seed)
47
+
48
+ # configure model
49
+ self.model: DiffusionUnetLowdimPolicy
50
+ self.model = hydra.utils.instantiate(cfg.policy)
51
+
52
+ self.ema_model: DiffusionUnetLowdimPolicy = None
53
+ if cfg.training.use_ema:
54
+ self.ema_model = copy.deepcopy(self.model)
55
+
56
+ # configure training state
57
+ self.optimizer = hydra.utils.instantiate(
58
+ cfg.optimizer, params=self.model.parameters())
59
+
60
+ self.global_step = 0
61
+ self.epoch = 0
62
+
63
+ def run(self):
64
+ cfg = copy.deepcopy(self.cfg)
65
+
66
+ # resume training
67
+ if cfg.training.resume:
68
+ lastest_ckpt_path = self.get_checkpoint_path()
69
+ if lastest_ckpt_path.is_file():
70
+ print(f"Resuming from checkpoint {lastest_ckpt_path}")
71
+ self.load_checkpoint(path=lastest_ckpt_path)
72
+
73
+ # configure dataset
74
+ dataset: BaseLowdimDataset
75
+ dataset = hydra.utils.instantiate(cfg.task.dataset)
76
+ assert isinstance(dataset, BaseLowdimDataset)
77
+ train_dataloader = DataLoader(dataset, **cfg.dataloader)
78
+ normalizer = dataset.get_normalizer()
79
+
80
+ # configure validation dataset
81
+ val_dataset = dataset.get_validation_dataset()
82
+ val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
83
+
84
+ self.model.set_normalizer(normalizer)
85
+ if cfg.training.use_ema:
86
+ self.ema_model.set_normalizer(normalizer)
87
+
88
+ # configure lr scheduler
89
+ lr_scheduler = get_scheduler(
90
+ cfg.training.lr_scheduler,
91
+ optimizer=self.optimizer,
92
+ num_warmup_steps=cfg.training.lr_warmup_steps,
93
+ num_training_steps=(
94
+ len(train_dataloader) * cfg.training.num_epochs) \
95
+ // cfg.training.gradient_accumulate_every,
96
+ # pytorch assumes stepping LRScheduler every epoch
97
+ # however huggingface diffusers steps it every batch
98
+ last_epoch=self.global_step-1
99
+ )
100
+
101
+ # configure ema
102
+ ema: EMAModel = None
103
+ if cfg.training.use_ema:
104
+ ema = hydra.utils.instantiate(
105
+ cfg.ema,
106
+ model=self.ema_model)
107
+
108
+ # configure env runner
109
+ env_runner: BaseLowdimRunner
110
+ env_runner = hydra.utils.instantiate(
111
+ cfg.task.env_runner,
112
+ output_dir=self.output_dir)
113
+ assert isinstance(env_runner, BaseLowdimRunner)
114
+
115
+ # configure logging
116
+ wandb_run = wandb.init(
117
+ dir=str(self.output_dir),
118
+ config=OmegaConf.to_container(cfg, resolve=True),
119
+ **cfg.logging
120
+ )
121
+ wandb.config.update(
122
+ {
123
+ "output_dir": self.output_dir,
124
+ }
125
+ )
126
+
127
+ # configure checkpoint
128
+ topk_manager = TopKCheckpointManager(
129
+ save_dir=os.path.join(self.output_dir, 'checkpoints'),
130
+ **cfg.checkpoint.topk
131
+ )
132
+
133
+ # device transfer
134
+ device = torch.device(cfg.training.device)
135
+ self.model.to(device)
136
+ if self.ema_model is not None:
137
+ self.ema_model.to(device)
138
+ optimizer_to(self.optimizer, device)
139
+
140
+ # save batch for sampling
141
+ train_sampling_batch = None
142
+
143
+ if cfg.training.debug:
144
+ cfg.training.num_epochs = 2
145
+ cfg.training.max_train_steps = 3
146
+ cfg.training.max_val_steps = 3
147
+ cfg.training.rollout_every = 1
148
+ cfg.training.checkpoint_every = 1
149
+ cfg.training.val_every = 1
150
+ cfg.training.sample_every = 1
151
+
152
+ # training loop
153
+ log_path = os.path.join(self.output_dir, 'logs.json.txt')
154
+ with JsonLogger(log_path) as json_logger:
155
+ for local_epoch_idx in range(cfg.training.num_epochs):
156
+ step_log = dict()
157
+ # ========= train for this epoch ==========
158
+ train_losses = list()
159
+ with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
160
+ leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
161
+ for batch_idx, batch in enumerate(tepoch):
162
+ # device transfer
163
+ batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
164
+ if train_sampling_batch is None:
165
+ train_sampling_batch = batch
166
+
167
+ # compute loss
168
+ raw_loss = self.model.compute_loss(batch)
169
+ loss = raw_loss / cfg.training.gradient_accumulate_every
170
+ loss.backward()
171
+
172
+ # step optimizer
173
+ if self.global_step % cfg.training.gradient_accumulate_every == 0:
174
+ self.optimizer.step()
175
+ self.optimizer.zero_grad()
176
+ lr_scheduler.step()
177
+
178
+ # update ema
179
+ if cfg.training.use_ema:
180
+ ema.step(self.model)
181
+
182
+ # logging
183
+ raw_loss_cpu = raw_loss.item()
184
+ tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
185
+ train_losses.append(raw_loss_cpu)
186
+ step_log = {
187
+ 'train_loss': raw_loss_cpu,
188
+ 'global_step': self.global_step,
189
+ 'epoch': self.epoch,
190
+ 'lr': lr_scheduler.get_last_lr()[0]
191
+ }
192
+
193
+ is_last_batch = (batch_idx == (len(train_dataloader)-1))
194
+ if not is_last_batch:
195
+ # log of last step is combined with validation and rollout
196
+ wandb_run.log(step_log, step=self.global_step)
197
+ json_logger.log(step_log)
198
+ self.global_step += 1
199
+
200
+ if (cfg.training.max_train_steps is not None) \
201
+ and batch_idx >= (cfg.training.max_train_steps-1):
202
+ break
203
+
204
+ # at the end of each epoch
205
+ # replace train_loss with epoch average
206
+ train_loss = np.mean(train_losses)
207
+ step_log['train_loss'] = train_loss
208
+
209
+ # ========= eval for this epoch ==========
210
+ policy = self.model
211
+ if cfg.training.use_ema:
212
+ policy = self.ema_model
213
+ policy.eval()
214
+
215
+ # run rollout
216
+ if (self.epoch % cfg.training.rollout_every) == 0:
217
+ runner_log = env_runner.run(policy)
218
+ # log all
219
+ step_log.update(runner_log)
220
+
221
+ # run validation
222
+ if (self.epoch % cfg.training.val_every) == 0:
223
+ with torch.no_grad():
224
+ val_losses = list()
225
+ with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
226
+ leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
227
+ for batch_idx, batch in enumerate(tepoch):
228
+ batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
229
+ loss = self.model.compute_loss(batch)
230
+ val_losses.append(loss)
231
+ if (cfg.training.max_val_steps is not None) \
232
+ and batch_idx >= (cfg.training.max_val_steps-1):
233
+ break
234
+ if len(val_losses) > 0:
235
+ val_loss = torch.mean(torch.tensor(val_losses)).item()
236
+ # log epoch average validation loss
237
+ step_log['val_loss'] = val_loss
238
+
239
+ # run diffusion sampling on a training batch
240
+ if (self.epoch % cfg.training.sample_every) == 0:
241
+ with torch.no_grad():
242
+ # sample trajectory from training set, and evaluate difference
243
+ batch = train_sampling_batch
244
+ obs_dict = {'obs': batch['obs']}
245
+ gt_action = batch['action']
246
+
247
+ result = policy.predict_action(obs_dict)
248
+ if cfg.pred_action_steps_only:
249
+ pred_action = result['action']
250
+ start = cfg.n_obs_steps - 1
251
+ end = start + cfg.n_action_steps
252
+ gt_action = gt_action[:,start:end]
253
+ else:
254
+ pred_action = result['action_pred']
255
+ mse = torch.nn.functional.mse_loss(pred_action, gt_action)
256
+ # log
257
+ step_log['train_action_mse_error'] = mse.item()
258
+ # release RAM
259
+ del batch
260
+ del obs_dict
261
+ del gt_action
262
+ del result
263
+ del pred_action
264
+ del mse
265
+
266
+ # checkpoint
267
+ if (self.epoch % cfg.training.checkpoint_every) == 0:
268
+ # checkpointing
269
+ if cfg.checkpoint.save_last_ckpt:
270
+ self.save_checkpoint()
271
+ if cfg.checkpoint.save_last_snapshot:
272
+ self.save_snapshot()
273
+
274
+ # sanitize metric names
275
+ metric_dict = dict()
276
+ for key, value in step_log.items():
277
+ new_key = key.replace('/', '_')
278
+ metric_dict[new_key] = value
279
+
280
+ # We can't copy the last checkpoint here
281
+ # since save_checkpoint uses threads.
282
+ # therefore at this point the file might have been empty!
283
+ topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
284
+
285
+ if topk_ckpt_path is not None:
286
+ self.save_checkpoint(path=topk_ckpt_path)
287
+ # ========= eval end for this epoch ==========
288
+ policy.train()
289
+
290
+ # end of epoch
291
+ # log of last step is combined with validation and rollout
292
+ wandb_run.log(step_log, step=self.global_step)
293
+ json_logger.log(step_log)
294
+ self.global_step += 1
295
+ self.epoch += 1
296
+
297
+ @hydra.main(
298
+ version_base=None,
299
+ config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
300
+ config_name=pathlib.Path(__file__).stem)
301
+ def main(cfg):
302
+ workspace = TrainDiffusionUnetLowdimWorkspace(cfg)
303
+ workspace.run()
304
+
305
+ if __name__ == "__main__":
306
+ main()
third_party/diffusion_policy/diffusion_policy/workspace/train_ibc_dfo_hybrid_workspace.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ if __name__ == "__main__":
2
+ import sys
3
+ import os
4
+ import pathlib
5
+
6
+ ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7
+ sys.path.append(ROOT_DIR)
8
+ os.chdir(ROOT_DIR)
9
+
10
+ import os
11
+ import hydra
12
+ import torch
13
+ from omegaconf import OmegaConf
14
+ import pathlib
15
+ from torch.utils.data import DataLoader
16
+ import copy
17
+ import random
18
+ import wandb
19
+ import tqdm
20
+ import numpy as np
21
+ import shutil
22
+ from diffusion_policy.workspace.base_workspace import BaseWorkspace
23
+ from diffusion_policy.policy.ibc_dfo_hybrid_image_policy import IbcDfoHybridImagePolicy
24
+ from diffusion_policy.dataset.base_dataset import BaseImageDataset
25
+ from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
26
+ from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
27
+ from diffusion_policy.common.json_logger import JsonLogger
28
+ from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
29
+ from diffusion_policy.model.diffusion.ema_model import EMAModel
30
+ from diffusion_policy.model.common.lr_scheduler import get_scheduler
31
+
32
+ OmegaConf.register_new_resolver("eval", eval, replace=True)
33
+
34
+ class TrainIbcDfoHybridWorkspace(BaseWorkspace):
35
+ include_keys = ['global_step', 'epoch']
36
+
37
+ def __init__(self, cfg: OmegaConf, output_dir=None):
38
+ super().__init__(cfg, output_dir=output_dir)
39
+
40
+ # set seed
41
+ seed = cfg.training.seed
42
+ torch.manual_seed(seed)
43
+ np.random.seed(seed)
44
+ random.seed(seed)
45
+
46
+ # configure model
47
+ self.model: IbcDfoHybridImagePolicy= hydra.utils.instantiate(cfg.policy)
48
+
49
+ # configure training state
50
+ self.optimizer = hydra.utils.instantiate(
51
+ cfg.optimizer, params=self.model.parameters())
52
+
53
+ # configure training state
54
+ self.global_step = 0
55
+ self.epoch = 0
56
+
57
+ def run(self):
58
+ cfg = copy.deepcopy(self.cfg)
59
+
60
+ # resume training
61
+ if cfg.training.resume:
62
+ lastest_ckpt_path = self.get_checkpoint_path()
63
+ if lastest_ckpt_path.is_file():
64
+ print(f"Resuming from checkpoint {lastest_ckpt_path}")
65
+ self.load_checkpoint(path=lastest_ckpt_path)
66
+
67
+ # configure dataset
68
+ dataset: BaseImageDataset
69
+ dataset = hydra.utils.instantiate(cfg.task.dataset)
70
+ assert isinstance(dataset, BaseImageDataset)
71
+ train_dataloader = DataLoader(dataset, **cfg.dataloader)
72
+ normalizer = dataset.get_normalizer()
73
+
74
+ # configure validation dataset
75
+ val_dataset = dataset.get_validation_dataset()
76
+ val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
77
+
78
+ self.model.set_normalizer(normalizer)
79
+
80
+ # configure lr scheduler
81
+ lr_scheduler = get_scheduler(
82
+ cfg.training.lr_scheduler,
83
+ optimizer=self.optimizer,
84
+ num_warmup_steps=cfg.training.lr_warmup_steps,
85
+ num_training_steps=(
86
+ len(train_dataloader) * cfg.training.num_epochs) \
87
+ // cfg.training.gradient_accumulate_every,
88
+ # pytorch assumes stepping LRScheduler every epoch
89
+ # however huggingface diffusers steps it every batch
90
+ last_epoch=self.global_step-1
91
+ )
92
+
93
+ # configure env
94
+ env_runner: BaseImageRunner
95
+ env_runner = hydra.utils.instantiate(
96
+ cfg.task.env_runner,
97
+ output_dir=self.output_dir)
98
+ assert isinstance(env_runner, BaseImageRunner)
99
+
100
+ # configure logging
101
+ wandb_run = wandb.init(
102
+ dir=str(self.output_dir),
103
+ config=OmegaConf.to_container(cfg, resolve=True),
104
+ **cfg.logging
105
+ )
106
+ wandb.config.update(
107
+ {
108
+ "output_dir": self.output_dir,
109
+ }
110
+ )
111
+
112
+ # configure checkpoint
113
+ topk_manager = TopKCheckpointManager(
114
+ save_dir=os.path.join(self.output_dir, 'checkpoints'),
115
+ **cfg.checkpoint.topk
116
+ )
117
+
118
+ # device transfer
119
+ device = torch.device(cfg.training.device)
120
+ self.model.to(device)
121
+ optimizer_to(self.optimizer, device)
122
+
123
+ # save batch for sampling
124
+ train_sampling_batch = None
125
+
126
+ if cfg.training.debug:
127
+ cfg.training.num_epochs = 2
128
+ cfg.training.max_train_steps = 3
129
+ cfg.training.max_val_steps = 3
130
+ cfg.training.rollout_every = 1
131
+ cfg.training.checkpoint_every = 1
132
+ cfg.training.val_every = 1
133
+ cfg.training.sample_every = 1
134
+
135
+ # training loop
136
+ log_path = os.path.join(self.output_dir, 'logs.json.txt')
137
+ with JsonLogger(log_path) as json_logger:
138
+ for local_epoch_idx in range(cfg.training.num_epochs):
139
+ step_log = dict()
140
+ # ========= train for this epoch ==========
141
+ train_losses = list()
142
+ with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
143
+ leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
144
+ for batch_idx, batch in enumerate(tepoch):
145
+ # device transfer
146
+ batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
147
+ if train_sampling_batch is None:
148
+ train_sampling_batch = batch
149
+
150
+ # compute loss
151
+ raw_loss = self.model.compute_loss(batch)
152
+ loss = raw_loss / cfg.training.gradient_accumulate_every
153
+ loss.backward()
154
+
155
+ # step optimizer
156
+ if self.global_step % cfg.training.gradient_accumulate_every == 0:
157
+ self.optimizer.step()
158
+ self.optimizer.zero_grad()
159
+ lr_scheduler.step()
160
+
161
+ # logging
162
+ raw_loss_cpu = raw_loss.item()
163
+ tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
164
+ train_losses.append(raw_loss_cpu)
165
+ step_log = {
166
+ 'train_loss': raw_loss_cpu,
167
+ 'global_step': self.global_step,
168
+ 'epoch': self.epoch,
169
+ 'lr': lr_scheduler.get_last_lr()[0]
170
+ }
171
+
172
+ is_last_batch = (batch_idx == (len(train_dataloader)-1))
173
+ if not is_last_batch:
174
+ # log of last step is combined with validation and rollout
175
+ wandb_run.log(step_log, step=self.global_step)
176
+ json_logger.log(step_log)
177
+ self.global_step += 1
178
+
179
+ if (cfg.training.max_train_steps is not None) \
180
+ and batch_idx >= (cfg.training.max_train_steps-1):
181
+ break
182
+
183
+ # at the end of each epoch
184
+ # replace train_loss with epoch average
185
+ train_loss = np.mean(train_losses)
186
+ step_log['train_loss'] = train_loss
187
+
188
+ # ========= eval for this epoch ==========
189
+ policy = self.model
190
+ policy.eval()
191
+
192
+ # run rollout
193
+ if (self.epoch % cfg.training.rollout_every) == 0:
194
+ runner_log = env_runner.run(policy)
195
+ # log all
196
+ step_log.update(runner_log)
197
+
198
+ # run validation
199
+ if (self.epoch % cfg.training.val_every) == 0:
200
+ with torch.no_grad():
201
+ val_losses = list()
202
+ with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
203
+ leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
204
+ for batch_idx, batch in enumerate(tepoch):
205
+ batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
206
+ loss = self.model.compute_loss(batch)
207
+ val_losses.append(loss)
208
+ if (cfg.training.max_val_steps is not None) \
209
+ and batch_idx >= (cfg.training.max_val_steps-1):
210
+ break
211
+ if len(val_losses) > 0:
212
+ val_loss = torch.mean(torch.tensor(val_losses)).item()
213
+ # log epoch average validation loss
214
+ step_log['val_loss'] = val_loss
215
+
216
+ # run diffusion sampling on a training batch
217
+ if (self.epoch % cfg.training.sample_every) == 0:
218
+ with torch.no_grad():
219
+ # sample trajectory from training set, and evaluate difference
220
+ batch = train_sampling_batch
221
+ n_samples = cfg.training.sample_max_batch
222
+ batch = dict_apply(train_sampling_batch,
223
+ lambda x: x.to(device, non_blocking=True))
224
+ obs_dict = dict_apply(batch['obs'], lambda x: x[:n_samples])
225
+ gt_action = batch['action']
226
+
227
+ result = policy.predict_action(obs_dict)
228
+ pred_action = result['action']
229
+ start = cfg.n_obs_steps - 1
230
+ end = start + cfg.n_action_steps
231
+ gt_action = gt_action[:,start:end]
232
+ mse = torch.nn.functional.mse_loss(pred_action, gt_action)
233
+ # log
234
+ step_log['train_action_mse_error'] = mse.item()
235
+ # release RAM
236
+ del batch
237
+ del obs_dict
238
+ del gt_action
239
+ del result
240
+ del pred_action
241
+ del mse
242
+
243
+ # checkpoint
244
+ if (self.epoch % cfg.training.checkpoint_every) == 0:
245
+ # checkpointing
246
+ if cfg.checkpoint.save_last_ckpt:
247
+ self.save_checkpoint()
248
+ if cfg.checkpoint.save_last_snapshot:
249
+ self.save_snapshot()
250
+
251
+ # sanitize metric names
252
+ metric_dict = dict()
253
+ for key, value in step_log.items():
254
+ new_key = key.replace('/', '_')
255
+ metric_dict[new_key] = value
256
+
257
+ # We can't copy the last checkpoint here
258
+ # since save_checkpoint uses threads.
259
+ # therefore at this point the file might have been empty!
260
+ topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
261
+
262
+ if topk_ckpt_path is not None:
263
+ self.save_checkpoint(path=topk_ckpt_path)
264
+ # ========= eval end for this epoch ==========
265
+ policy.train()
266
+
267
+ # end of epoch
268
+ # log of last step is combined with validation and rollout
269
+ wandb_run.log(step_log, step=self.global_step)
270
+ json_logger.log(step_log)
271
+ self.global_step += 1
272
+ self.epoch += 1
273
+
274
+ @hydra.main(
275
+ version_base=None,
276
+ config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
277
+ config_name=pathlib.Path(__file__).stem)
278
+ def main(cfg):
279
+ workspace = TrainIbcDfoHybridWorkspace(cfg)
280
+ workspace.run()
281
+
282
+ if __name__ == "__main__":
283
+ main()
third_party/diffusion_policy/diffusion_policy/workspace/train_ibc_dfo_lowdim_workspace.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ if __name__ == "__main__":
2
+ import sys
3
+ import os
4
+ import pathlib
5
+
6
+ ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7
+ sys.path.append(ROOT_DIR)
8
+ os.chdir(ROOT_DIR)
9
+
10
+ import os
11
+ import hydra
12
+ import torch
13
+ from omegaconf import OmegaConf
14
+ import pathlib
15
+ from torch.utils.data import DataLoader
16
+ import copy
17
+ import numpy as np
18
+ import random
19
+ import wandb
20
+ import tqdm
21
+ import shutil
22
+
23
+ from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
24
+ from diffusion_policy.workspace.base_workspace import BaseWorkspace
25
+ from diffusion_policy.policy.ibc_dfo_lowdim_policy import IbcDfoLowdimPolicy
26
+ from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
27
+ from diffusion_policy.env_runner.base_lowdim_runner import BaseLowdimRunner
28
+ from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
29
+ from diffusion_policy.common.json_logger import JsonLogger
30
+ from diffusion_policy.model.common.lr_scheduler import get_scheduler
31
+
32
+ OmegaConf.register_new_resolver("eval", eval, replace=True)
33
+
34
+ # %%
35
+ class TrainIbcDfoLowdimWorkspace(BaseWorkspace):
36
+ include_keys = ['global_step', 'epoch']
37
+
38
+ def __init__(self, cfg: OmegaConf, output_dir=None):
39
+ super().__init__(cfg, output_dir=output_dir)
40
+
41
+ # set seed
42
+ seed = cfg.training.seed
43
+ torch.manual_seed(seed)
44
+ np.random.seed(seed)
45
+ random.seed(seed)
46
+
47
+ # configure model
48
+ self.model: IbcDfoLowdimPolicy
49
+ self.model = hydra.utils.instantiate(cfg.policy)
50
+
51
+ # configure training state
52
+ self.optimizer = hydra.utils.instantiate(
53
+ cfg.optimizer, params=self.model.parameters())
54
+
55
+ self.global_step = 0
56
+ self.epoch = 0
57
+
58
+ def run(self):
59
+ cfg = copy.deepcopy(self.cfg)
60
+
61
+ # resume training
62
+ if cfg.training.resume:
63
+ lastest_ckpt_path = self.get_checkpoint_path()
64
+ if lastest_ckpt_path.is_file():
65
+ print(f"Resuming from checkpoint {lastest_ckpt_path}")
66
+ self.load_checkpoint(path=lastest_ckpt_path)
67
+
68
+ # configure dataset
69
+ dataset: BaseLowdimDataset
70
+ dataset = hydra.utils.instantiate(cfg.task.dataset)
71
+ assert isinstance(dataset, BaseLowdimDataset)
72
+ train_dataloader = DataLoader(dataset, **cfg.dataloader)
73
+ normalizer = dataset.get_normalizer()
74
+
75
+ # configure validation dataset
76
+ val_dataset = dataset.get_validation_dataset()
77
+ val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
78
+
79
+ self.model.set_normalizer(normalizer)
80
+
81
+ # configure lr scheduler
82
+ lr_scheduler = get_scheduler(
83
+ cfg.training.lr_scheduler,
84
+ optimizer=self.optimizer,
85
+ num_warmup_steps=cfg.training.lr_warmup_steps,
86
+ num_training_steps=(
87
+ len(train_dataloader) * cfg.training.num_epochs) \
88
+ // cfg.training.gradient_accumulate_every,
89
+ # pytorch assumes stepping LRScheduler every epoch
90
+ # however huggingface diffusers steps it every batch
91
+ last_epoch=self.global_step-1
92
+ )
93
+
94
+ # configure env runner
95
+ env_runner: BaseLowdimRunner
96
+ env_runner = hydra.utils.instantiate(
97
+ cfg.task.env_runner,
98
+ output_dir=self.output_dir)
99
+ assert isinstance(env_runner, BaseLowdimRunner)
100
+
101
+ # configure logging
102
+ wandb_run = wandb.init(
103
+ dir=str(self.output_dir),
104
+ config=OmegaConf.to_container(cfg, resolve=True),
105
+ **cfg.logging
106
+ )
107
+ wandb.config.update(
108
+ {
109
+ "output_dir": self.output_dir,
110
+ }
111
+ )
112
+
113
+ # configure checkpoint
114
+ topk_manager = TopKCheckpointManager(
115
+ save_dir=os.path.join(self.output_dir, 'checkpoints'),
116
+ **cfg.checkpoint.topk
117
+ )
118
+
119
+ # device transfer
120
+ device = torch.device(cfg.training.device)
121
+ self.model.to(device)
122
+ optimizer_to(self.optimizer, device)
123
+
124
+ # save batch for sampling
125
+ train_sampling_batch = None
126
+
127
+ if cfg.training.debug:
128
+ cfg.training.num_epochs = 2
129
+ cfg.training.max_train_steps = 3
130
+ cfg.training.max_val_steps = 3
131
+ cfg.training.rollout_every = 1
132
+ cfg.training.checkpoint_every = 1
133
+ cfg.training.val_every = 1
134
+ cfg.training.sample_every = 1
135
+
136
+ # training loop
137
+ log_path = os.path.join(self.output_dir, 'logs.json.txt')
138
+ with JsonLogger(log_path) as json_logger:
139
+ for local_epoch_idx in range(cfg.training.num_epochs):
140
+ step_log = dict()
141
+ # ========= train for this epoch ==========
142
+ train_losses = list()
143
+ with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
144
+ leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
145
+ for batch_idx, batch in enumerate(tepoch):
146
+ # device transfer
147
+ batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
148
+ if train_sampling_batch is None:
149
+ train_sampling_batch = batch
150
+
151
+ # compute loss
152
+ raw_loss = self.model.compute_loss(batch)
153
+ loss = raw_loss / cfg.training.gradient_accumulate_every
154
+ loss.backward()
155
+
156
+ # step optimizer
157
+ if self.global_step % cfg.training.gradient_accumulate_every == 0:
158
+ self.optimizer.step()
159
+ self.optimizer.zero_grad()
160
+ lr_scheduler.step()
161
+
162
+ # logging
163
+ raw_loss_cpu = raw_loss.item()
164
+ tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
165
+ train_losses.append(raw_loss_cpu)
166
+ step_log = {
167
+ 'train_loss': raw_loss_cpu,
168
+ 'global_step': self.global_step,
169
+ 'epoch': self.epoch,
170
+ 'lr': lr_scheduler.get_last_lr()[0]
171
+ }
172
+
173
+ is_last_batch = (batch_idx == (len(train_dataloader)-1))
174
+ if not is_last_batch:
175
+ # log of last step is combined with validation and rollout
176
+ wandb_run.log(step_log, step=self.global_step)
177
+ json_logger.log(step_log)
178
+ self.global_step += 1
179
+
180
+ if (cfg.training.max_train_steps is not None) \
181
+ and batch_idx >= (cfg.training.max_train_steps-1):
182
+ break
183
+
184
+ # at the end of each epoch
185
+ # replace train_loss with epoch average
186
+ train_loss = np.mean(train_losses)
187
+ step_log['train_loss'] = train_loss
188
+
189
+ # ========= eval for this epoch ==========
190
+ policy = self.model
191
+ policy.eval()
192
+
193
+ # run rollout
194
+ if (self.epoch % cfg.training.rollout_every) == 0:
195
+ runner_log = env_runner.run(policy)
196
+ # log all
197
+ step_log.update(runner_log)
198
+
199
+ # run validation
200
+ if (self.epoch % cfg.training.val_every) == 0:
201
+ with torch.no_grad():
202
+ val_losses = list()
203
+ with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
204
+ leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
205
+ for batch_idx, batch in enumerate(tepoch):
206
+ batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
207
+ loss = self.model.compute_loss(batch)
208
+ val_losses.append(loss)
209
+ if (cfg.training.max_val_steps is not None) \
210
+ and batch_idx >= (cfg.training.max_val_steps-1):
211
+ break
212
+ if len(val_losses) > 0:
213
+ val_loss = torch.mean(torch.tensor(val_losses)).item()
214
+ # log epoch average validation loss
215
+ step_log['val_loss'] = val_loss
216
+
217
+ # run diffusion sampling on a training batch
218
+ if (self.epoch % cfg.training.sample_every) == 0:
219
+ with torch.no_grad():
220
+ # sample trajectory from training set, and evaluate difference
221
+ batch = train_sampling_batch
222
+ n_samples = cfg.training.sample_max_batch
223
+ obs_dict = {'obs': batch['obs'][:n_samples]}
224
+ gt_action = batch['action'][:n_samples]
225
+
226
+ result = policy.predict_action(obs_dict)
227
+ pred_action = result['action']
228
+ start = cfg.n_obs_steps - 1
229
+ end = start + cfg.n_action_steps
230
+ gt_action = gt_action[:,start:end]
231
+ mse = torch.nn.functional.mse_loss(pred_action, gt_action)
232
+ # log
233
+ step_log['train_action_mse_error'] = mse.item()
234
+ # release RAM
235
+ del batch
236
+ del obs_dict
237
+ del gt_action
238
+ del result
239
+ del pred_action
240
+ del mse
241
+
242
+ # checkpoint
243
+ if (self.epoch % cfg.training.checkpoint_every) == 0:
244
+ # checkpointing
245
+ if cfg.checkpoint.save_last_ckpt:
246
+ self.save_checkpoint()
247
+ if cfg.checkpoint.save_last_snapshot:
248
+ self.save_snapshot()
249
+
250
+ # sanitize metric names
251
+ metric_dict = dict()
252
+ for key, value in step_log.items():
253
+ new_key = key.replace('/', '_')
254
+ metric_dict[new_key] = value
255
+
256
+ # We can't copy the last checkpoint here
257
+ # since save_checkpoint uses threads.
258
+ # therefore at this point the file might have been empty!
259
+ topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
260
+
261
+ if topk_ckpt_path is not None:
262
+ self.save_checkpoint(path=topk_ckpt_path)
263
+ # ========= eval end for this epoch ==========
264
+ policy.train()
265
+
266
+ # end of epoch
267
+ # log of last step is combined with validation and rollout
268
+ wandb_run.log(step_log, step=self.global_step)
269
+ json_logger.log(step_log)
270
+ self.global_step += 1
271
+ self.epoch += 1
272
+
273
+ @hydra.main(
274
+ version_base=None,
275
+ config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
276
+ config_name=pathlib.Path(__file__).stem)
277
+ def main(cfg):
278
+ workspace = TrainIbcDfoLowdimWorkspace(cfg)
279
+ workspace.run()
280
+
281
+ if __name__ == "__main__":
282
+ main()
third_party/diffusion_policy/diffusion_policy/workspace/train_robomimic_image_workspace.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ if __name__ == "__main__":
2
+ import sys
3
+ import os
4
+ import pathlib
5
+
6
+ ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7
+ sys.path.append(ROOT_DIR)
8
+ os.chdir(ROOT_DIR)
9
+
10
+ import os
11
+ import hydra
12
+ import torch
13
+ from omegaconf import OmegaConf
14
+ import pathlib
15
+ from torch.utils.data import DataLoader
16
+ import copy
17
+ import random
18
+ import wandb
19
+ import tqdm
20
+ import numpy as np
21
+ import shutil
22
+ from diffusion_policy.workspace.base_workspace import BaseWorkspace
23
+ from diffusion_policy.policy.robomimic_image_policy import RobomimicImagePolicy
24
+ from diffusion_policy.dataset.base_dataset import BaseImageDataset
25
+ from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
26
+ from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
27
+ from diffusion_policy.common.json_logger import JsonLogger
28
+ from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
29
+
30
+
31
+ OmegaConf.register_new_resolver("eval", eval, replace=True)
32
+
33
+ class TrainRobomimicImageWorkspace(BaseWorkspace):
34
+ include_keys = ['global_step', 'epoch']
35
+
36
+ def __init__(self, cfg: OmegaConf, output_dir=None):
37
+ super().__init__(cfg, output_dir=output_dir)
38
+
39
+ # set seed
40
+ seed = cfg.training.seed
41
+ torch.manual_seed(seed)
42
+ np.random.seed(seed)
43
+ random.seed(seed)
44
+
45
+ # configure model
46
+ self.model: RobomimicImagePolicy = hydra.utils.instantiate(cfg.policy)
47
+
48
+ # configure training state
49
+ self.global_step = 0
50
+ self.epoch = 0
51
+
52
+ def run(self):
53
+ cfg = copy.deepcopy(self.cfg)
54
+
55
+ # resume training
56
+ if cfg.training.resume:
57
+ lastest_ckpt_path = self.get_checkpoint_path()
58
+ if lastest_ckpt_path.is_file():
59
+ print(f"Resuming from checkpoint {lastest_ckpt_path}")
60
+ self.load_checkpoint(path=lastest_ckpt_path)
61
+
62
+ # configure dataset
63
+ dataset: BaseImageDataset
64
+ dataset = hydra.utils.instantiate(cfg.task.dataset)
65
+ assert isinstance(dataset, BaseImageDataset)
66
+ train_dataloader = DataLoader(dataset, **cfg.dataloader)
67
+ normalizer = dataset.get_normalizer()
68
+
69
+ # configure validation dataset
70
+ val_dataset = dataset.get_validation_dataset()
71
+ val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
72
+
73
+ self.model.set_normalizer(normalizer)
74
+
75
+ # configure env
76
+ env_runner: BaseImageRunner
77
+ env_runner = hydra.utils.instantiate(
78
+ cfg.task.env_runner,
79
+ output_dir=self.output_dir)
80
+ assert isinstance(env_runner, BaseImageRunner)
81
+
82
+ # configure logging
83
+ wandb_run = wandb.init(
84
+ dir=str(self.output_dir),
85
+ config=OmegaConf.to_container(cfg, resolve=True),
86
+ **cfg.logging
87
+ )
88
+ wandb.config.update(
89
+ {
90
+ "output_dir": self.output_dir,
91
+ }
92
+ )
93
+
94
+ # configure checkpoint
95
+ topk_manager = TopKCheckpointManager(
96
+ save_dir=os.path.join(self.output_dir, 'checkpoints'),
97
+ **cfg.checkpoint.topk
98
+ )
99
+
100
+ # device transfer
101
+ device = torch.device(cfg.training.device)
102
+ self.model.to(device)
103
+
104
+ # save batch for sampling
105
+ train_sampling_batch = None
106
+
107
+ if cfg.training.debug:
108
+ cfg.training.num_epochs = 2
109
+ cfg.training.max_train_steps = 3
110
+ cfg.training.max_val_steps = 3
111
+ cfg.training.rollout_every = 1
112
+ cfg.training.checkpoint_every = 1
113
+ cfg.training.val_every = 1
114
+ cfg.training.sample_every = 1
115
+
116
+ # training loop
117
+ log_path = os.path.join(self.output_dir, 'logs.json.txt')
118
+ with JsonLogger(log_path) as json_logger:
119
+ for local_epoch_idx in range(cfg.training.num_epochs):
120
+ step_log = dict()
121
+ # ========= train for this epoch ==========
122
+ train_losses = list()
123
+ with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
124
+ leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
125
+ for batch_idx, batch in enumerate(tepoch):
126
+ # device transfer
127
+ batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
128
+ if train_sampling_batch is None:
129
+ train_sampling_batch = batch
130
+
131
+ info = self.model.train_on_batch(batch, epoch=self.epoch)
132
+
133
+ # logging
134
+ loss_cpu = info['losses']['action_loss'].item()
135
+ tepoch.set_postfix(loss=loss_cpu, refresh=False)
136
+ train_losses.append(loss_cpu)
137
+ step_log = {
138
+ 'train_loss': loss_cpu,
139
+ 'global_step': self.global_step,
140
+ 'epoch': self.epoch
141
+ }
142
+
143
+ is_last_batch = (batch_idx == (len(train_dataloader)-1))
144
+ if not is_last_batch:
145
+ # log of last step is combined with validation and rollout
146
+ wandb_run.log(step_log, step=self.global_step)
147
+ json_logger.log(step_log)
148
+ self.global_step += 1
149
+
150
+ if (cfg.training.max_train_steps is not None) \
151
+ and batch_idx >= (cfg.training.max_train_steps-1):
152
+ break
153
+
154
+ # at the end of each epoch
155
+ # replace train_loss with epoch average
156
+ train_loss = np.mean(train_losses)
157
+ step_log['train_loss'] = train_loss
158
+
159
+ # ========= eval for this epoch ==========
160
+ self.model.eval()
161
+
162
+ # run rollout
163
+ if (self.epoch % cfg.training.rollout_every) == 0:
164
+ runner_log = env_runner.run(self.model)
165
+ # log all
166
+ step_log.update(runner_log)
167
+
168
+ # run validation
169
+ if (self.epoch % cfg.training.val_every) == 0:
170
+ with torch.no_grad():
171
+ val_losses = list()
172
+ with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
173
+ leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
174
+ for batch_idx, batch in enumerate(tepoch):
175
+ batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
176
+ info = self.model.train_on_batch(batch, epoch=self.epoch, validate=True)
177
+ loss = info['losses']['action_loss']
178
+ val_losses.append(loss)
179
+ if (cfg.training.max_val_steps is not None) \
180
+ and batch_idx >= (cfg.training.max_val_steps-1):
181
+ break
182
+ if len(val_losses) > 0:
183
+ val_loss = torch.mean(torch.tensor(val_losses)).item()
184
+ # log epoch average validation loss
185
+ step_log['val_loss'] = val_loss
186
+
187
+ # run diffusion sampling on a training batch
188
+ if (self.epoch % cfg.training.sample_every) == 0:
189
+ with torch.no_grad():
190
+ # sample trajectory from training set, and evaluate difference
191
+ batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
192
+ obs_dict = batch['obs']
193
+ gt_action = batch['action']
194
+ T = gt_action.shape[1]
195
+
196
+ pred_actions = list()
197
+ self.model.reset()
198
+ for i in range(T):
199
+ result = self.model.predict_action(
200
+ dict_apply(obs_dict, lambda x: x[:,[i]])
201
+ )
202
+ pred_actions.append(result['action'])
203
+ pred_actions = torch.cat(pred_actions, dim=1)
204
+ mse = torch.nn.functional.mse_loss(pred_actions, gt_action)
205
+ step_log['train_action_mse_error'] = mse.item()
206
+ del batch
207
+ del obs_dict
208
+ del gt_action
209
+ del result
210
+ del pred_actions
211
+ del mse
212
+
213
+ # checkpoint
214
+ if (self.epoch % cfg.training.checkpoint_every) == 0:
215
+ # checkpointing
216
+ if cfg.checkpoint.save_last_ckpt:
217
+ self.save_checkpoint()
218
+ if cfg.checkpoint.save_last_snapshot:
219
+ self.save_snapshot()
220
+
221
+ # sanitize metric names
222
+ metric_dict = dict()
223
+ for key, value in step_log.items():
224
+ new_key = key.replace('/', '_')
225
+ metric_dict[new_key] = value
226
+
227
+ # We can't copy the last checkpoint here
228
+ # since save_checkpoint uses threads.
229
+ # therefore at this point the file might have been empty!
230
+ topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
231
+
232
+ if topk_ckpt_path is not None:
233
+ self.save_checkpoint(path=topk_ckpt_path)
234
+ # ========= eval end for this epoch ==========
235
+ self.model.train()
236
+
237
+ # end of epoch
238
+ # log of last step is combined with validation and rollout
239
+ wandb_run.log(step_log, step=self.global_step)
240
+ json_logger.log(step_log)
241
+ self.global_step += 1
242
+ self.epoch += 1
243
+
244
+
245
+ @hydra.main(
246
+ version_base=None,
247
+ config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
248
+ config_name=pathlib.Path(__file__).stem)
249
+ def main(cfg):
250
+ workspace = TrainRobomimicImageWorkspace(cfg)
251
+ workspace.run()
252
+
253
+ if __name__ == "__main__":
254
+ main()
third_party/diffusion_policy/diffusion_policy/workspace/train_robomimic_lowdim_workspace.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ if __name__ == "__main__":
2
+ import sys
3
+ import os
4
+ import pathlib
5
+
6
+ ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7
+ sys.path.append(ROOT_DIR)
8
+ os.chdir(ROOT_DIR)
9
+
10
+ import os
11
+ import hydra
12
+ import torch
13
+ from omegaconf import OmegaConf
14
+ import pathlib
15
+ from torch.utils.data import DataLoader
16
+ import copy
17
+ import random
18
+ import wandb
19
+ import tqdm
20
+ import numpy as np
21
+ import shutil
22
+ from diffusion_policy.workspace.base_workspace import BaseWorkspace
23
+ from diffusion_policy.policy.robomimic_lowdim_policy import RobomimicLowdimPolicy
24
+ from diffusion_policy.dataset.base_dataset import BaseLowdimDataset
25
+ from diffusion_policy.env_runner.base_lowdim_runner import BaseLowdimRunner
26
+ from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
27
+ from diffusion_policy.common.json_logger import JsonLogger
28
+ from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
29
+
30
+
31
+ OmegaConf.register_new_resolver("eval", eval, replace=True)
32
+
33
+ class TrainRobomimicLowdimWorkspace(BaseWorkspace):
34
+ include_keys = ['global_step', 'epoch']
35
+
36
+ def __init__(self, cfg: OmegaConf):
37
+ super().__init__(cfg)
38
+
39
+ # set seed
40
+ seed = cfg.training.seed
41
+ torch.manual_seed(seed)
42
+ np.random.seed(seed)
43
+ random.seed(seed)
44
+
45
+ # configure model
46
+ self.model: RobomimicLowdimPolicy = hydra.utils.instantiate(cfg.policy)
47
+
48
+ # configure training state
49
+ self.global_step = 0
50
+ self.epoch = 0
51
+
52
+ def run(self):
53
+ cfg = copy.deepcopy(self.cfg)
54
+
55
+ # resume training
56
+ if cfg.training.resume:
57
+ lastest_ckpt_path = self.get_checkpoint_path()
58
+ if lastest_ckpt_path.is_file():
59
+ print(f"Resuming from checkpoint {lastest_ckpt_path}")
60
+ self.load_checkpoint(path=lastest_ckpt_path)
61
+
62
+ # configure dataset
63
+ dataset: BaseLowdimDataset
64
+ dataset = hydra.utils.instantiate(cfg.task.dataset)
65
+ assert isinstance(dataset, BaseLowdimDataset)
66
+ train_dataloader = DataLoader(dataset, **cfg.dataloader)
67
+ normalizer = dataset.get_normalizer()
68
+
69
+ # configure validation dataset
70
+ val_dataset = dataset.get_validation_dataset()
71
+ val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
72
+
73
+ self.model.set_normalizer(normalizer)
74
+
75
+ # configure env
76
+ env_runner: BaseLowdimRunner
77
+ env_runner = hydra.utils.instantiate(
78
+ cfg.task.env_runner,
79
+ output_dir=self.output_dir)
80
+ assert isinstance(env_runner, BaseLowdimRunner)
81
+
82
+ # configure logging
83
+ wandb_run = wandb.init(
84
+ dir=str(self.output_dir),
85
+ config=OmegaConf.to_container(cfg, resolve=True),
86
+ **cfg.logging
87
+ )
88
+ wandb.config.update(
89
+ {
90
+ "output_dir": self.output_dir,
91
+ }
92
+ )
93
+
94
+ # configure checkpoint
95
+ topk_manager = TopKCheckpointManager(
96
+ save_dir=os.path.join(self.output_dir, 'checkpoints'),
97
+ **cfg.checkpoint.topk
98
+ )
99
+
100
+ # device transfer
101
+ device = torch.device(cfg.training.device)
102
+ self.model.to(device)
103
+
104
+ if cfg.training.debug:
105
+ cfg.training.num_epochs = 2
106
+ cfg.training.max_train_steps = 3
107
+ cfg.training.max_val_steps = 3
108
+ cfg.training.rollout_every = 1
109
+ cfg.training.checkpoint_every = 1
110
+ cfg.training.val_every = 1
111
+
112
+ # training loop
113
+ log_path = os.path.join(self.output_dir, 'logs.json.txt')
114
+ with JsonLogger(log_path) as json_logger:
115
+ for local_epoch_idx in range(cfg.training.num_epochs):
116
+ step_log = dict()
117
+ # ========= train for this epoch ==========
118
+ train_losses = list()
119
+ with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
120
+ leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
121
+ for batch_idx, batch in enumerate(tepoch):
122
+ # device transfer
123
+ batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
124
+ info = self.model.train_on_batch(batch, epoch=self.epoch)
125
+
126
+ # logging
127
+ loss_cpu = info['losses']['action_loss'].item()
128
+ tepoch.set_postfix(loss=loss_cpu, refresh=False)
129
+ train_losses.append(loss_cpu)
130
+ step_log = {
131
+ 'train_loss': loss_cpu,
132
+ 'global_step': self.global_step,
133
+ 'epoch': self.epoch
134
+ }
135
+
136
+ is_last_batch = (batch_idx == (len(train_dataloader)-1))
137
+ if not is_last_batch:
138
+ # log of last step is combined with validation and rollout
139
+ wandb_run.log(step_log, step=self.global_step)
140
+ json_logger.log(step_log)
141
+ self.global_step += 1
142
+
143
+ if (cfg.training.max_train_steps is not None) \
144
+ and batch_idx >= (cfg.training.max_train_steps-1):
145
+ break
146
+
147
+ # at the end of each epoch
148
+ # replace train_loss with epoch average
149
+ train_loss = np.mean(train_losses)
150
+ step_log['train_loss'] = train_loss
151
+
152
+ # ========= eval for this epoch ==========
153
+ self.model.eval()
154
+
155
+ # run rollout
156
+ if (self.epoch % cfg.training.rollout_every) == 0:
157
+ runner_log = env_runner.run(self.model)
158
+ # log all
159
+ step_log.update(runner_log)
160
+
161
+ # run validation
162
+ if (self.epoch % cfg.training.val_every) == 0:
163
+ with torch.no_grad():
164
+ val_losses = list()
165
+ with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
166
+ leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
167
+ for batch_idx, batch in enumerate(tepoch):
168
+ batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
169
+ info = self.model.train_on_batch(batch, epoch=self.epoch, validate=True)
170
+ loss = info['losses']['action_loss']
171
+ val_losses.append(loss)
172
+ if (cfg.training.max_val_steps is not None) \
173
+ and batch_idx >= (cfg.training.max_val_steps-1):
174
+ break
175
+ if len(val_losses) > 0:
176
+ val_loss = torch.mean(torch.tensor(val_losses)).item()
177
+ # log epoch average validation loss
178
+ step_log['val_loss'] = val_loss
179
+
180
+ # checkpoint
181
+ if (self.epoch % cfg.training.checkpoint_every) == 0:
182
+ # checkpointing
183
+ if cfg.checkpoint.save_last_ckpt:
184
+ self.save_checkpoint()
185
+ if cfg.checkpoint.save_last_snapshot:
186
+ self.save_snapshot()
187
+
188
+ # sanitize metric names
189
+ metric_dict = dict()
190
+ for key, value in step_log.items():
191
+ new_key = key.replace('/', '_')
192
+ metric_dict[new_key] = value
193
+
194
+ # We can't copy the last checkpoint here
195
+ # since save_checkpoint uses threads.
196
+ # therefore at this point the file might have been empty!
197
+ topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
198
+
199
+ if topk_ckpt_path is not None:
200
+ self.save_checkpoint(path=topk_ckpt_path)
201
+ # ========= eval end for this epoch ==========
202
+ self.model.train()
203
+
204
+ # end of epoch
205
+ # log of last step is combined with validation and rollout
206
+ wandb_run.log(step_log, step=self.global_step)
207
+ json_logger.log(step_log)
208
+ self.global_step += 1
209
+ self.epoch += 1
210
+
211
+
212
+ @hydra.main(
213
+ version_base=None,
214
+ config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
215
+ config_name=pathlib.Path(__file__).stem)
216
+ def main(cfg):
217
+ workspace = TrainRobomimicLowdimWorkspace(cfg)
218
+ workspace.run()
219
+
220
+ if __name__ == "__main__":
221
+ main()
third_party/diffusion_policy/eval.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
4
+ """
5
+
6
+ import sys
7
+ # use line-buffering for both stdout and stderr
8
+ sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
9
+ sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
10
+
11
+ import os
12
+ import pathlib
13
+ import click
14
+ import hydra
15
+ import torch
16
+ import dill
17
+ import wandb
18
+ import json
19
+ from diffusion_policy.workspace.base_workspace import BaseWorkspace
20
+
21
+ @click.command()
22
+ @click.option('-c', '--checkpoint', required=True)
23
+ @click.option('-o', '--output_dir', required=True)
24
+ @click.option('-d', '--device', default='cuda:0')
25
+ def main(checkpoint, output_dir, device):
26
+ if os.path.exists(output_dir):
27
+ click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
28
+ pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
29
+
30
+ # load checkpoint
31
+ payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
32
+ cfg = payload['cfg']
33
+ cls = hydra.utils.get_class(cfg._target_)
34
+ workspace = cls(cfg, output_dir=output_dir)
35
+ workspace: BaseWorkspace
36
+ workspace.load_payload(payload, exclude_keys=None, include_keys=None)
37
+
38
+ # get policy from workspace
39
+ policy = workspace.model
40
+ if cfg.training.use_ema:
41
+ policy = workspace.ema_model
42
+
43
+ device = torch.device(device)
44
+ policy.to(device)
45
+ policy.eval()
46
+
47
+ # run eval
48
+ env_runner = hydra.utils.instantiate(
49
+ cfg.task.env_runner,
50
+ output_dir=output_dir)
51
+ runner_log = env_runner.run(policy)
52
+
53
+ # dump log to json
54
+ json_log = dict()
55
+ for key, value in runner_log.items():
56
+ if isinstance(value, wandb.sdk.data_types.video.Video):
57
+ json_log[key] = value._path
58
+ else:
59
+ json_log[key] = value
60
+ out_path = os.path.join(output_dir, 'eval_log.json')
61
+ json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)
62
+
63
+ if __name__ == '__main__':
64
+ main()
third_party/diffusion_policy/eval_real_robot.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ (robodiff)$ python eval_real_robot.py -i <ckpt_path> -o <save_dir> --robot_ip <ip_of_ur5>
4
+
5
+ ================ Human in control ==============
6
+ Robot movement:
7
+ Move your SpaceMouse to move the robot EEF (locked in xy plane).
8
+ Press SpaceMouse right button to unlock z axis.
9
+ Press SpaceMouse left button to enable rotation axes.
10
+
11
+ Recording control:
12
+ Click the opencv window (make sure it's in focus).
13
+ Press "C" to start evaluation (hand control over to policy).
14
+ Press "Q" to exit program.
15
+
16
+ ================ Policy in control ==============
17
+ Make sure you can hit the robot hardware emergency-stop button quickly!
18
+
19
+ Recording control:
20
+ Press "S" to stop evaluation and gain control back.
21
+ """
22
+
23
+ # %%
24
+ import time
25
+ from multiprocessing.managers import SharedMemoryManager
26
+ import click
27
+ import cv2
28
+ import numpy as np
29
+ import torch
30
+ import dill
31
+ import hydra
32
+ import pathlib
33
+ import skvideo.io
34
+ from omegaconf import OmegaConf
35
+ import scipy.spatial.transform as st
36
+ from diffusion_policy.real_world.real_env import RealEnv
37
+ from diffusion_policy.real_world.spacemouse_shared_memory import Spacemouse
38
+ from diffusion_policy.common.precise_sleep import precise_wait
39
+ from diffusion_policy.real_world.real_inference_util import (
40
+ get_real_obs_resolution,
41
+ get_real_obs_dict)
42
+ from diffusion_policy.common.pytorch_util import dict_apply
43
+ from diffusion_policy.workspace.base_workspace import BaseWorkspace
44
+ from diffusion_policy.policy.base_image_policy import BaseImagePolicy
45
+ from diffusion_policy.common.cv2_util import get_image_transform
46
+
47
+
48
+ OmegaConf.register_new_resolver("eval", eval, replace=True)
49
+
50
+ @click.command()
51
+ @click.option('--input', '-i', required=True, help='Path to checkpoint')
52
+ @click.option('--output', '-o', required=True, help='Directory to save recording')
53
+ @click.option('--robot_ip', '-ri', required=True, help="UR5's IP address e.g. 192.168.0.204")
54
+ @click.option('--match_dataset', '-m', default=None, help='Dataset used to overlay and adjust initial condition')
55
+ @click.option('--match_episode', '-me', default=None, type=int, help='Match specific episode from the match dataset')
56
+ @click.option('--vis_camera_idx', default=0, type=int, help="Which RealSense camera to visualize.")
57
+ @click.option('--init_joints', '-j', is_flag=True, default=False, help="Whether to initialize robot joint configuration in the beginning.")
58
+ @click.option('--steps_per_inference', '-si', default=6, type=int, help="Action horizon for inference.")
59
+ @click.option('--max_duration', '-md', default=60, help='Max duration for each epoch in seconds.')
60
+ @click.option('--frequency', '-f', default=10, type=float, help="Control frequency in Hz.")
61
+ @click.option('--command_latency', '-cl', default=0.01, type=float, help="Latency between receiving SapceMouse command to executing on Robot in Sec.")
62
+ def main(input, output, robot_ip, match_dataset, match_episode,
63
+ vis_camera_idx, init_joints,
64
+ steps_per_inference, max_duration,
65
+ frequency, command_latency):
66
+ # load match_dataset
67
+ match_camera_idx = 0
68
+ episode_first_frame_map = dict()
69
+ if match_dataset is not None:
70
+ match_dir = pathlib.Path(match_dataset)
71
+ match_video_dir = match_dir.joinpath('videos')
72
+ for vid_dir in match_video_dir.glob("*/"):
73
+ episode_idx = int(vid_dir.stem)
74
+ match_video_path = vid_dir.joinpath(f'{match_camera_idx}.mp4')
75
+ if match_video_path.exists():
76
+ frames = skvideo.io.vread(
77
+ str(match_video_path), num_frames=1)
78
+ episode_first_frame_map[episode_idx] = frames[0]
79
+ print(f"Loaded initial frame for {len(episode_first_frame_map)} episodes")
80
+
81
+ # load checkpoint
82
+ ckpt_path = input
83
+ payload = torch.load(open(ckpt_path, 'rb'), pickle_module=dill)
84
+ cfg = payload['cfg']
85
+ cls = hydra.utils.get_class(cfg._target_)
86
+ workspace = cls(cfg)
87
+ workspace: BaseWorkspace
88
+ workspace.load_payload(payload, exclude_keys=None, include_keys=None)
89
+
90
+ # hacks for method-specific setup.
91
+ action_offset = 0
92
+ delta_action = False
93
+ if 'diffusion' in cfg.name:
94
+ # diffusion model
95
+ policy: BaseImagePolicy
96
+ policy = workspace.model
97
+ if cfg.training.use_ema:
98
+ policy = workspace.ema_model
99
+
100
+ device = torch.device('cuda')
101
+ policy.eval().to(device)
102
+
103
+ # set inference params
104
+ policy.num_inference_steps = 16 # DDIM inference iterations
105
+ policy.n_action_steps = policy.horizon - policy.n_obs_steps + 1
106
+
107
+ elif 'robomimic' in cfg.name:
108
+ # BCRNN model
109
+ policy: BaseImagePolicy
110
+ policy = workspace.model
111
+
112
+ device = torch.device('cuda')
113
+ policy.eval().to(device)
114
+
115
+ # BCRNN always has action horizon of 1
116
+ steps_per_inference = 1
117
+ action_offset = cfg.n_latency_steps
118
+ delta_action = cfg.task.dataset.get('delta_action', False)
119
+
120
+ elif 'ibc' in cfg.name:
121
+ policy: BaseImagePolicy
122
+ policy = workspace.model
123
+ policy.pred_n_iter = 5
124
+ policy.pred_n_samples = 4096
125
+
126
+ device = torch.device('cuda')
127
+ policy.eval().to(device)
128
+ steps_per_inference = 1
129
+ action_offset = 1
130
+ delta_action = cfg.task.dataset.get('delta_action', False)
131
+ else:
132
+ raise RuntimeError("Unsupported policy type: ", cfg.name)
133
+
134
+ # setup experiment
135
+ dt = 1/frequency
136
+
137
+ obs_res = get_real_obs_resolution(cfg.task.shape_meta)
138
+ n_obs_steps = cfg.n_obs_steps
139
+ print("n_obs_steps: ", n_obs_steps)
140
+ print("steps_per_inference:", steps_per_inference)
141
+ print("action_offset:", action_offset)
142
+
143
+ with SharedMemoryManager() as shm_manager:
144
+ with Spacemouse(shm_manager=shm_manager) as sm, RealEnv(
145
+ output_dir=output,
146
+ robot_ip=robot_ip,
147
+ frequency=frequency,
148
+ n_obs_steps=n_obs_steps,
149
+ obs_image_resolution=obs_res,
150
+ obs_float32=True,
151
+ init_joints=init_joints,
152
+ enable_multi_cam_vis=True,
153
+ record_raw_video=True,
154
+ # number of threads per camera view for video recording (H.264)
155
+ thread_per_video=3,
156
+ # video recording quality, lower is better (but slower).
157
+ video_crf=21,
158
+ shm_manager=shm_manager) as env:
159
+ cv2.setNumThreads(1)
160
+
161
+ # Should be the same as demo
162
+ # realsense exposure
163
+ env.realsense.set_exposure(exposure=120, gain=0)
164
+ # realsense white balance
165
+ env.realsense.set_white_balance(white_balance=5900)
166
+
167
+ print("Waiting for realsense")
168
+ time.sleep(1.0)
169
+
170
+ print("Warming up policy inference")
171
+ obs = env.get_obs()
172
+ with torch.no_grad():
173
+ policy.reset()
174
+ obs_dict_np = get_real_obs_dict(
175
+ env_obs=obs, shape_meta=cfg.task.shape_meta)
176
+ obs_dict = dict_apply(obs_dict_np,
177
+ lambda x: torch.from_numpy(x).unsqueeze(0).to(device))
178
+ result = policy.predict_action(obs_dict)
179
+ action = result['action'][0].detach().to('cpu').numpy()
180
+ assert action.shape[-1] == 2
181
+ del result
182
+
183
+ print('Ready!')
184
+ while True:
185
+ # ========= human control loop ==========
186
+ print("Human in control!")
187
+ state = env.get_robot_state()
188
+ target_pose = state['TargetTCPPose']
189
+ t_start = time.monotonic()
190
+ iter_idx = 0
191
+ while True:
192
+ # calculate timing
193
+ t_cycle_end = t_start + (iter_idx + 1) * dt
194
+ t_sample = t_cycle_end - command_latency
195
+ t_command_target = t_cycle_end + dt
196
+
197
+ # pump obs
198
+ obs = env.get_obs()
199
+
200
+ # visualize
201
+ episode_id = env.replay_buffer.n_episodes
202
+ vis_img = obs[f'camera_{vis_camera_idx}'][-1]
203
+ match_episode_id = episode_id
204
+ if match_episode is not None:
205
+ match_episode_id = match_episode
206
+ if match_episode_id in episode_first_frame_map:
207
+ match_img = episode_first_frame_map[match_episode_id]
208
+ ih, iw, _ = match_img.shape
209
+ oh, ow, _ = vis_img.shape
210
+ tf = get_image_transform(
211
+ input_res=(iw, ih),
212
+ output_res=(ow, oh),
213
+ bgr_to_rgb=False)
214
+ match_img = tf(match_img).astype(np.float32) / 255
215
+ vis_img = np.minimum(vis_img, match_img)
216
+
217
+ text = f'Episode: {episode_id}'
218
+ cv2.putText(
219
+ vis_img,
220
+ text,
221
+ (10,20),
222
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
223
+ fontScale=0.5,
224
+ thickness=1,
225
+ color=(255,255,255)
226
+ )
227
+ cv2.imshow('default', vis_img[...,::-1])
228
+ key_stroke = cv2.pollKey()
229
+ if key_stroke == ord('q'):
230
+ # Exit program
231
+ env.end_episode()
232
+ exit(0)
233
+ elif key_stroke == ord('c'):
234
+ # Exit human control loop
235
+ # hand control over to the policy
236
+ break
237
+
238
+ precise_wait(t_sample)
239
+ # get teleop command
240
+ sm_state = sm.get_motion_state_transformed()
241
+ # print(sm_state)
242
+ dpos = sm_state[:3] * (env.max_pos_speed / frequency)
243
+ drot_xyz = sm_state[3:] * (env.max_rot_speed / frequency)
244
+
245
+ if not sm.is_button_pressed(0):
246
+ # translation mode
247
+ drot_xyz[:] = 0
248
+ else:
249
+ dpos[:] = 0
250
+ if not sm.is_button_pressed(1):
251
+ # 2D translation mode
252
+ dpos[2] = 0
253
+
254
+ drot = st.Rotation.from_euler('xyz', drot_xyz)
255
+ target_pose[:3] += dpos
256
+ target_pose[3:] = (drot * st.Rotation.from_rotvec(
257
+ target_pose[3:])).as_rotvec()
258
+ # clip target pose
259
+ target_pose[:2] = np.clip(target_pose[:2], [0.25, -0.45], [0.77, 0.40])
260
+
261
+ # execute teleop command
262
+ env.exec_actions(
263
+ actions=[target_pose],
264
+ timestamps=[t_command_target-time.monotonic()+time.time()])
265
+ precise_wait(t_cycle_end)
266
+ iter_idx += 1
267
+
268
+ # ========== policy control loop ==============
269
+ try:
270
+ # start episode
271
+ policy.reset()
272
+ start_delay = 1.0
273
+ eval_t_start = time.time() + start_delay
274
+ t_start = time.monotonic() + start_delay
275
+ env.start_episode(eval_t_start)
276
+ # wait for 1/30 sec to get the closest frame actually
277
+ # reduces overall latency
278
+ frame_latency = 1/30
279
+ precise_wait(eval_t_start - frame_latency, time_func=time.time)
280
+ print("Started!")
281
+ iter_idx = 0
282
+ term_area_start_timestamp = float('inf')
283
+ perv_target_pose = None
284
+ while True:
285
+ # calculate timing
286
+ t_cycle_end = t_start + (iter_idx + steps_per_inference) * dt
287
+
288
+ # get obs
289
+ print('get_obs')
290
+ obs = env.get_obs()
291
+ obs_timestamps = obs['timestamp']
292
+ print(f'Obs latency {time.time() - obs_timestamps[-1]}')
293
+
294
+ # run inference
295
+ with torch.no_grad():
296
+ s = time.time()
297
+ obs_dict_np = get_real_obs_dict(
298
+ env_obs=obs, shape_meta=cfg.task.shape_meta)
299
+ obs_dict = dict_apply(obs_dict_np,
300
+ lambda x: torch.from_numpy(x).unsqueeze(0).to(device))
301
+ result = policy.predict_action(obs_dict)
302
+ # this action starts from the first obs step
303
+ action = result['action'][0].detach().to('cpu').numpy()
304
+ print('Inference latency:', time.time() - s)
305
+
306
+ # convert policy action to env actions
307
+ if delta_action:
308
+ assert len(action) == 1
309
+ if perv_target_pose is None:
310
+ perv_target_pose = obs['robot_eef_pose'][-1]
311
+ this_target_pose = perv_target_pose.copy()
312
+ this_target_pose[[0,1]] += action[-1]
313
+ perv_target_pose = this_target_pose
314
+ this_target_poses = np.expand_dims(this_target_pose, axis=0)
315
+ else:
316
+ this_target_poses = np.zeros((len(action), len(target_pose)), dtype=np.float64)
317
+ this_target_poses[:] = target_pose
318
+ this_target_poses[:,[0,1]] = action
319
+
320
+ # deal with timing
321
+ # the same step actions are always the target for
322
+ action_timestamps = (np.arange(len(action), dtype=np.float64) + action_offset
323
+ ) * dt + obs_timestamps[-1]
324
+ action_exec_latency = 0.01
325
+ curr_time = time.time()
326
+ is_new = action_timestamps > (curr_time + action_exec_latency)
327
+ if np.sum(is_new) == 0:
328
+ # exceeded time budget, still do something
329
+ this_target_poses = this_target_poses[[-1]]
330
+ # schedule on next available step
331
+ next_step_idx = int(np.ceil((curr_time - eval_t_start) / dt))
332
+ action_timestamp = eval_t_start + (next_step_idx) * dt
333
+ print('Over budget', action_timestamp - curr_time)
334
+ action_timestamps = np.array([action_timestamp])
335
+ else:
336
+ this_target_poses = this_target_poses[is_new]
337
+ action_timestamps = action_timestamps[is_new]
338
+
339
+ # clip actions
340
+ this_target_poses[:,:2] = np.clip(
341
+ this_target_poses[:,:2], [0.25, -0.45], [0.77, 0.40])
342
+
343
+ # execute actions
344
+ env.exec_actions(
345
+ actions=this_target_poses,
346
+ timestamps=action_timestamps
347
+ )
348
+ print(f"Submitted {len(this_target_poses)} steps of actions.")
349
+
350
+ # visualize
351
+ episode_id = env.replay_buffer.n_episodes
352
+ vis_img = obs[f'camera_{vis_camera_idx}'][-1]
353
+ text = 'Episode: {}, Time: {:.1f}'.format(
354
+ episode_id, time.monotonic() - t_start
355
+ )
356
+ cv2.putText(
357
+ vis_img,
358
+ text,
359
+ (10,20),
360
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
361
+ fontScale=0.5,
362
+ thickness=1,
363
+ color=(255,255,255)
364
+ )
365
+ cv2.imshow('default', vis_img[...,::-1])
366
+
367
+
368
+ key_stroke = cv2.pollKey()
369
+ if key_stroke == ord('s'):
370
+ # Stop episode
371
+ # Hand control back to human
372
+ env.end_episode()
373
+ print('Stopped.')
374
+ break
375
+
376
+ # auto termination
377
+ terminate = False
378
+ if time.monotonic() - t_start > max_duration:
379
+ terminate = True
380
+ print('Terminated by the timeout!')
381
+
382
+ term_pose = np.array([ 3.40948500e-01, 2.17721816e-01, 4.59076878e-02, 2.22014183e+00, -2.22184883e+00, -4.07186655e-04])
383
+ curr_pose = obs['robot_eef_pose'][-1]
384
+ dist = np.linalg.norm((curr_pose - term_pose)[:2], axis=-1)
385
+ if dist < 0.03:
386
+ # in termination area
387
+ curr_timestamp = obs['timestamp'][-1]
388
+ if term_area_start_timestamp > curr_timestamp:
389
+ term_area_start_timestamp = curr_timestamp
390
+ else:
391
+ term_area_time = curr_timestamp - term_area_start_timestamp
392
+ if term_area_time > 0.5:
393
+ terminate = True
394
+ print('Terminated by the policy!')
395
+ else:
396
+ # out of the area
397
+ term_area_start_timestamp = float('inf')
398
+
399
+ if terminate:
400
+ env.end_episode()
401
+ break
402
+
403
+ # wait for execution
404
+ precise_wait(t_cycle_end - frame_latency)
405
+ iter_idx += steps_per_inference
406
+
407
+ except KeyboardInterrupt:
408
+ print("Interrupted!")
409
+ # stop robot.
410
+ env.end_episode()
411
+
412
+ print("Stopped.")
413
+
414
+
415
+
416
+ # %%
417
+ if __name__ == '__main__':
418
+ main()
third_party/diffusion_policy/multirun_metrics.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ import pathlib
3
+ import pandas as pd
4
+ import numpy as np
5
+ import numba
6
+ import click
7
+ import time
8
+ import collections
9
+ import json
10
+ import wandb
11
+ import yaml
12
+ import numbers
13
+ import scipy.ndimage as sn
14
+ from diffusion_policy.common.json_logger import read_json_log, JsonLogger
15
+ import logging
16
+
17
+ @numba.jit(nopython=True)
18
+ def get_indexed_window_average(
19
+ arr: np.ndarray, idxs: np.ndarray, window_size: int):
20
+ result = np.zeros(idxs.shape, dtype=arr.dtype)
21
+ length = arr.shape[0]
22
+ for i in range(len(idxs)):
23
+ idx = idxs[i]
24
+ start = max(idx - window_size, 0)
25
+ end = min(start + window_size, length)
26
+ result[i] = np.mean(arr[start:end])
27
+ return result
28
+
29
+
30
+ def compute_metrics(log_df: pd.DataFrame, key: str,
31
+ end_step: Optional[int]=None,
32
+ k_min_loss: int=10,
33
+ k_around_max: int=10,
34
+ max_k_window: int=10,
35
+ replace_slash: int=True,
36
+ ):
37
+ if key not in log_df:
38
+ return dict()
39
+
40
+ # prepare data
41
+ if end_step is not None:
42
+ log_df = log_df.iloc[:end_step]
43
+ is_key = ~pd.isnull(log_df[key])
44
+ is_key_idxs = is_key.index[is_key].to_numpy()
45
+ if len(is_key_idxs) == 0:
46
+ return dict()
47
+
48
+ key_data = log_df[key][is_key].to_numpy()
49
+ # after adding validation to workspace
50
+ # rollout happens at the last step of each epoch
51
+ # where the reported train_loss and val_loss
52
+ # are already the average for that epoch
53
+ train_loss = log_df['train_loss'][is_key].to_numpy()
54
+ val_loss = log_df['val_loss'][is_key].to_numpy()
55
+
56
+ result = dict()
57
+
58
+ log_key = key
59
+ if replace_slash:
60
+ log_key = key.replace('/', '_')
61
+ # max
62
+ max_value = np.max(key_data)
63
+ result['max/'+log_key] = max_value
64
+
65
+ # k_around_max
66
+ max_idx = np.argmax(key_data)
67
+ end = min(max_idx + k_around_max // 2, len(key_data))
68
+ start = max(end - k_around_max, 0)
69
+ k_around_max_value = np.mean(key_data[start:end])
70
+ result['k_around_max/'+log_key] = k_around_max_value
71
+
72
+ # max_k_window
73
+ k_window_value = sn.uniform_filter1d(key_data, size=max_k_window, axis=0, mode='nearest')
74
+ max_k_window_value = np.max(k_window_value)
75
+ result['max_k_window/'+log_key] = max_k_window_value
76
+
77
+ # min_train_loss
78
+ min_idx = np.argmin(train_loss)
79
+ min_train_loss_value = key_data[min_idx]
80
+ result['min_train_loss/'+log_key] = min_train_loss_value
81
+
82
+ # min_val_loss
83
+ min_idx = np.argmin(val_loss)
84
+ min_val_loss_value = key_data[min_idx]
85
+ result['min_val_loss/'+log_key] = min_val_loss_value
86
+
87
+ # k_min_train_loss
88
+ min_loss_idxs = np.argsort(train_loss)[:k_min_loss]
89
+ k_min_train_loss_value = np.mean(key_data[min_loss_idxs])
90
+ result['k_min_train_loss/'+log_key] = k_min_train_loss_value
91
+
92
+ # k_min_val_loss
93
+ min_loss_idxs = np.argsort(val_loss)[:k_min_loss]
94
+ k_min_val_loss_value = np.mean(key_data[min_loss_idxs])
95
+ result['k_min_val_loss/'+log_key] = k_min_val_loss_value
96
+
97
+ # last
98
+ result['last/'+log_key] = key_data[-1]
99
+
100
+ # global step for visualization
101
+ result['metric_global_step/'+log_key] = is_key_idxs[-1]
102
+ return result
103
+
104
+
105
+ def compute_metrics_agg(
106
+ log_dfs: List[pd.DataFrame],
107
+ key: str, end_step:int,
108
+ **kwargs):
109
+
110
+ # compute metrics
111
+ results = collections.defaultdict(list)
112
+ for log_df in log_dfs:
113
+ result = compute_metrics(log_df, key=key, end_step=end_step, **kwargs)
114
+ for k, v in result.items():
115
+ results[k].append(v)
116
+ # agg
117
+ agg_result = dict()
118
+ for k, v in results.items():
119
+ value = np.mean(v)
120
+ if k.startswith('metric_global_step'):
121
+ value = int(value)
122
+ agg_result[k] = value
123
+ return agg_result
124
+
125
+
126
+ @click.command()
127
+ @click.option('--input', '-i', required=True, help='Root logging dir, contains train_* dirs')
128
+ @click.option('--key', '-k', multiple=True, default=['test/mean_score'])
129
+ @click.option('--interval', default=10, type=float)
130
+ @click.option('--replace_slash', default=True, type=bool)
131
+ @click.option('--index_key', '-ik', default='global_step')
132
+ @click.option('--use_wandb', '-w', is_flag=True, default=False)
133
+ @click.option('--project', default=None)
134
+ @click.option('--name', default=None)
135
+ @click.option('--id', default=None)
136
+ @click.option('--group', default=None)
137
+ def main(
138
+ input,
139
+ key,
140
+ interval,
141
+ replace_slash,
142
+ index_key,
143
+ use_wandb,
144
+ # wandb args
145
+ project,
146
+ name,
147
+ id,
148
+ group):
149
+ root_dir = pathlib.Path(input)
150
+ assert root_dir.is_dir()
151
+ metrics_dir = root_dir.joinpath('metrics')
152
+ metrics_dir.mkdir(exist_ok=True)
153
+
154
+ logging.basicConfig(
155
+ level=logging.INFO,
156
+ format="%(asctime)s [%(levelname)s] %(message)s",
157
+ handlers=[
158
+ logging.FileHandler(str(metrics_dir.joinpath("metrics.log"))),
159
+ logging.StreamHandler()
160
+ ]
161
+ )
162
+
163
+ train_dirs = list(root_dir.glob('train_*'))
164
+ log_files = [x.joinpath('logs.json.txt') for x in train_dirs]
165
+ logging.info("Monitor waiting for log files!")
166
+ while True:
167
+ # wait for files to show up
168
+ files_exist = True
169
+ for log_file in log_files:
170
+ if not log_file.is_file():
171
+ files_exist = False
172
+ if files_exist:
173
+ break
174
+ time.sleep(1.0)
175
+ logging.info("All log files ready!")
176
+
177
+ # init path
178
+ metric_log_path = metrics_dir.joinpath('logs.json.txt')
179
+ metric_path = metrics_dir.joinpath('metrics.json')
180
+ config_path = root_dir.joinpath('config.yaml')
181
+
182
+ # load config
183
+ config = yaml.safe_load(config_path.open('r'))
184
+
185
+ # init wandb
186
+ wandb_run = None
187
+ if use_wandb:
188
+ wandb_kwargs = config['logging']
189
+ if project is not None:
190
+ wandb_kwargs['project'] = project
191
+ if id is not None:
192
+ wandb_kwargs['id'] = id
193
+ if name is not None:
194
+ wandb_kwargs['name'] = name
195
+ if group is not None:
196
+ wandb_kwargs['group'] = group
197
+ wandb_kwargs['resume'] = True
198
+ wandb_run = wandb.init(
199
+ dir=str(metrics_dir),
200
+ config=config,
201
+ # auto-resume run, automatically load id
202
+ # as long as using the same dir.
203
+ # https://docs.wandb.ai/guides/track/advanced/resuming#resuming-guidance
204
+ **wandb_kwargs
205
+ )
206
+ wandb.config.update(
207
+ {
208
+ "output_dir": str(root_dir),
209
+ }
210
+ )
211
+
212
+ with JsonLogger(metric_log_path) as json_logger:
213
+ last_log = json_logger.get_last_log()
214
+ while True:
215
+ # read json files
216
+ log_dfs = [read_json_log(str(x), required_keys=key) for x in log_files]
217
+
218
+ # previously logged data point
219
+ last_log_idx = -1
220
+ if last_log is not None:
221
+ last_log_idx = log_dfs[0].index[log_dfs[0][index_key] <= last_log[index_key]][-1]
222
+
223
+ start_idx = last_log_idx + 1
224
+ # last idx where we have a data point from all logs
225
+ end_idx = min(*[len(x) for x in log_dfs])
226
+
227
+ # log every position
228
+ for this_idx in range(start_idx, end_idx):
229
+ # compute metrics
230
+ all_metrics = dict()
231
+ global_step = log_dfs[0]['global_step'][this_idx]
232
+ epoch = log_dfs[0]['epoch'][this_idx]
233
+ all_metrics['global_step'] = global_step
234
+ all_metrics['epoch'] = epoch
235
+ for k in key:
236
+ metrics = compute_metrics_agg(
237
+ log_dfs=log_dfs, key=k, end_step=this_idx+1,
238
+ replace_slash=replace_slash)
239
+ all_metrics.update(metrics)
240
+
241
+ # sanitize metrics
242
+ old_metrics = all_metrics
243
+ all_metrics = dict()
244
+ for k, v in old_metrics.items():
245
+ if isinstance(v, numbers.Integral):
246
+ all_metrics[k] = int(v)
247
+ elif isinstance(v, numbers.Number):
248
+ all_metrics[k] = float(v)
249
+
250
+ has_update = all_metrics != last_log
251
+ if has_update:
252
+ last_log = all_metrics
253
+ json_logger.log(all_metrics)
254
+
255
+ with metric_path.open('w') as f:
256
+ json.dump(all_metrics, f, sort_keys=True, indent=2)
257
+
258
+ if wandb_run is not None:
259
+ wandb_run.log(all_metrics, step=all_metrics[index_key])
260
+
261
+ logging.info(f"Metrics logged at step {all_metrics[index_key]}")
262
+
263
+ time.sleep(interval)
264
+
265
+
266
+ if __name__ == "__main__":
267
+ main()
third_party/diffusion_policy/pyrightconfig.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "exclude": [
3
+ "data/**",
4
+ "data_local/**",
5
+ "outputs/**"
6
+ ]
7
+ }
third_party/diffusion_policy/ray_exec.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ Training:
4
+ python train.py --config-name=train_diffusion_lowdim_workspace -- logger.mode=online
5
+ """
6
+ import os
7
+ import ray
8
+ import click
9
+
10
+ def worker_fn(command_args, data_src=None, unbuffer_python=False, use_shell=False):
11
+ import os
12
+ import subprocess
13
+ import signal
14
+ import time
15
+
16
+ # setup data symlink
17
+ if data_src is not None:
18
+ cwd = os.getcwd()
19
+ src = os.path.expanduser(data_src)
20
+ dst = os.path.join(cwd, 'data')
21
+ try:
22
+ os.symlink(src=src, dst=dst)
23
+ except FileExistsError:
24
+ # it's fine if it already exists
25
+ pass
26
+
27
+ # run command
28
+ process_env = os.environ.copy()
29
+ if unbuffer_python:
30
+ # disable stdout/stderr buffering for subprocess (if python)
31
+ # to remove latency between print statement and receiving printed result
32
+ process_env['PYTHONUNBUFFERED'] = 'TRUE'
33
+
34
+ # ray worker masks out Ctrl-C signal (ie SIGINT)
35
+ # here we unblock this signal for the child process
36
+ def preexec_function():
37
+ import signal
38
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, {signal.SIGINT})
39
+
40
+ if use_shell:
41
+ command_args = ' '.join(command_args)
42
+
43
+ # stdout passthrough to ray worker, which is then passed to ray driver
44
+ process = subprocess.Popen(
45
+ args=command_args,
46
+ env=process_env,
47
+ preexec_fn=preexec_function,
48
+ shell=use_shell)
49
+
50
+ while process.poll() is None:
51
+ try:
52
+ # sleep to ensure that monitor thread can acquire gil
53
+ # and raise KeyboardInterrupt here.
54
+ time.sleep(0.01)
55
+ except KeyboardInterrupt:
56
+ process.send_signal(signal.SIGINT)
57
+ print('SIGINT sent to subprocess')
58
+ except Exception as e:
59
+ process.terminate()
60
+ raise e
61
+
62
+ if process.returncode not in (0, -2):
63
+ print("Failed execution!")
64
+ raise RuntimeError("Failed execution.")
65
+ return process.returncode
66
+
67
+
68
+ @click.command()
69
+ @click.option('--ray_address', '-ra', default='auto')
70
+ @click.option('--num_cpus', '-nc', default=7, type=float)
71
+ @click.option('--num_gpus', '-ng', default=1, type=float)
72
+ @click.option('--max_retries', '-mr', default=0, type=int)
73
+ @click.option('--data_src', '-d', default='./data', type=str)
74
+ @click.option('--unbuffer_python', '-u', is_flag=True, default=False)
75
+ @click.argument('command_args', nargs=-1, type=str)
76
+ def main(ray_address,
77
+ num_cpus, num_gpus, max_retries,
78
+ data_src, unbuffer_python,
79
+ command_args):
80
+ # expand path
81
+ if data_src is not None:
82
+ data_src = os.path.abspath(os.path.expanduser(data_src))
83
+
84
+ # init ray
85
+ root_dir = os.path.dirname(__file__)
86
+ runtime_env = {
87
+ 'working_dir': root_dir,
88
+ 'excludes': ['.git']
89
+ }
90
+ ray.init(
91
+ address=ray_address,
92
+ runtime_env=runtime_env
93
+ )
94
+ # remote worker func
95
+ worker_ray = ray.remote(worker_fn).options(
96
+ num_cpus=num_cpus,
97
+ num_gpus=num_gpus,
98
+ max_retries=max_retries,
99
+ # resources=resources,
100
+ retry_exceptions=True
101
+ )
102
+ # run
103
+ task_ref = worker_ray.remote(command_args, data_src, unbuffer_python)
104
+
105
+ try:
106
+ # normal case
107
+ result = ray.get(task_ref)
108
+ print('Return code: ', result)
109
+ except KeyboardInterrupt:
110
+ # a KeyboardInterrupt will be raised in worker
111
+ ray.cancel(task_ref, force=False)
112
+ result = ray.get(task_ref)
113
+ print('Return code: ', result)
114
+ except Exception as e:
115
+ # worker will be terminated
116
+ ray.cancel(task_ref, force=True)
117
+ raise e
118
+
119
+
120
+ if __name__ == '__main__':
121
+ main()
third_party/diffusion_policy/ray_train_multirun.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Start local ray cluster
3
+ (robodiff)$ export CUDA_VISIBLE_DEVICES=0,1,2 # select GPUs to be managed by the ray cluster
4
+ (robodiff)$ ray start --head --num-gpus=3
5
+
6
+ Training:
7
+ python ray_train_multirun.py --config-name=train_diffusion_unet_lowdim_workspace --seeds=42,43,44 --monitor_key=test/mean_score -- logger.mode=online training.eval_first=True
8
+ """
9
+ import os
10
+ import ray
11
+ import click
12
+ import hydra
13
+ import yaml
14
+ import wandb
15
+ import pathlib
16
+ import collections
17
+ from pprint import pprint
18
+ from omegaconf import OmegaConf
19
+ from ray_exec import worker_fn
20
+ from ray.util.placement_group import (
21
+ placement_group,
22
+ )
23
+ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
24
+
25
+ OmegaConf.register_new_resolver("eval", eval, replace=True)
26
+
27
+ @click.command()
28
+ @click.option('--config-name', '-cn', required=True, type=str)
29
+ @click.option('--config-dir', '-cd', default=None, type=str)
30
+ @click.option('--seeds', '-s', default='42,43,44', type=str)
31
+ @click.option('--monitor_key', '-k', multiple=True, default=['test/mean_score'])
32
+ @click.option('--ray_address', '-ra', default='auto')
33
+ @click.option('--num_cpus', '-nc', default=7, type=float)
34
+ @click.option('--num_gpus', '-ng', default=1, type=float)
35
+ @click.option('--max_retries', '-mr', default=0, type=int)
36
+ @click.option('--monitor_max_retires', default=3, type=int)
37
+ @click.option('--data_src', '-d', default='./data', type=str)
38
+ @click.option('--unbuffer_python', '-u', is_flag=True, default=False)
39
+ @click.option('--single_node', '-sn', is_flag=True, default=False, help='run all experiments on a single machine')
40
+ @click.argument('command_args', nargs=-1, type=str)
41
+ def main(config_name, config_dir, seeds, monitor_key, ray_address,
42
+ num_cpus, num_gpus, max_retries, monitor_max_retires,
43
+ data_src, unbuffer_python,
44
+ single_node, command_args):
45
+ # parse args
46
+ seeds = [int(x) for x in seeds.split(',')]
47
+ # expand path
48
+ if data_src is not None:
49
+ data_src = os.path.abspath(os.path.expanduser(data_src))
50
+
51
+ # initialize hydra
52
+ if config_dir is None:
53
+ config_path_abs = pathlib.Path(__file__).parent.joinpath(
54
+ 'diffusion_policy','config')
55
+ config_path_rel = str(config_path_abs.relative_to(pathlib.Path.cwd()))
56
+ else:
57
+ config_path_rel = config_dir
58
+
59
+ run_command_args = list()
60
+ monitor_command_args = list()
61
+ with hydra.initialize(
62
+ version_base=None,
63
+ config_path=config_path_rel):
64
+
65
+ # generate raw config
66
+ cfg = hydra.compose(
67
+ config_name=config_name,
68
+ overrides=command_args)
69
+ OmegaConf.resolve(cfg)
70
+
71
+ # manually create output dir
72
+ output_dir = pathlib.Path(cfg.multi_run.run_dir)
73
+ output_dir.mkdir(parents=True, exist_ok=False)
74
+ config_path = output_dir.joinpath('config.yaml')
75
+ print(output_dir)
76
+
77
+ # save current config
78
+ yaml.dump(OmegaConf.to_container(cfg, resolve=True),
79
+ config_path.open('w'), default_flow_style=False)
80
+
81
+ # wandb
82
+ wandb_group_id = wandb.util.generate_id()
83
+ name_base = cfg.multi_run.wandb_name_base
84
+
85
+ # create monitor command args
86
+ monitor_command_args = [
87
+ 'python',
88
+ 'multirun_metrics.py',
89
+ '--input', str(output_dir),
90
+ '--use_wandb',
91
+ '--project', 'diffusion_policy_metrics',
92
+ '--group', wandb_group_id
93
+ ]
94
+ for k in monitor_key:
95
+ monitor_command_args.extend([
96
+ '--key', k
97
+ ])
98
+
99
+ # generate command args
100
+ run_command_args = list()
101
+ for i, seed in enumerate(seeds):
102
+ test_start_seed = (seed + 1) * 100000
103
+ this_output_dir = output_dir.joinpath(f'train_{i}')
104
+ this_output_dir.mkdir()
105
+ wandb_name = name_base + f'_train_{i}'
106
+ wandb_run_id = wandb_group_id + f'_train_{i}'
107
+
108
+ this_command_args = [
109
+ 'python',
110
+ 'train.py',
111
+ '--config-name='+config_name,
112
+ '--config-dir='+config_path_rel
113
+ ]
114
+
115
+ this_command_args.extend(command_args)
116
+ this_command_args.extend([
117
+ f'training.seed={seed}',
118
+ f'task.env_runner.test_start_seed={test_start_seed}',
119
+ f'logging.name={wandb_name}',
120
+ f'logging.id={wandb_run_id}',
121
+ f'logging.group={wandb_group_id}',
122
+ f'hydra.run.dir={this_output_dir}'
123
+ ])
124
+ run_command_args.append(this_command_args)
125
+
126
+ # init ray
127
+ root_dir = os.path.dirname(__file__)
128
+ runtime_env = {
129
+ 'working_dir': root_dir,
130
+ 'excludes': ['.git'],
131
+ 'pip': ['dm-control==1.0.9']
132
+ }
133
+ ray.init(
134
+ address=ray_address,
135
+ runtime_env=runtime_env
136
+ )
137
+
138
+ # create resources for train
139
+ train_resources = dict()
140
+
141
+ train_bundle = dict(train_resources)
142
+ train_bundle['CPU'] = num_cpus
143
+ train_bundle['GPU'] = num_gpus
144
+
145
+ # create resources for monitor
146
+ monitor_resources = dict()
147
+ monitor_resources['CPU'] = 1
148
+
149
+ monitor_bundle = dict(monitor_resources)
150
+
151
+ # aggregate bundle
152
+ bundle = collections.defaultdict(lambda:0)
153
+ n_train_bundles = 1
154
+ if single_node:
155
+ n_train_bundles = len(seeds)
156
+ for _ in range(n_train_bundles):
157
+ for k, v in train_bundle.items():
158
+ bundle[k] += v
159
+ for k, v in monitor_bundle.items():
160
+ bundle[k] += v
161
+ bundle = dict(bundle)
162
+
163
+ # create placement group
164
+ print("Creating placement group with resources:")
165
+ pprint(bundle)
166
+ pg = placement_group([bundle])
167
+
168
+ # run
169
+ task_name_map = dict()
170
+ task_refs = list()
171
+ for i, this_command_args in enumerate(run_command_args):
172
+ if single_node or i == (len(run_command_args) - 1):
173
+ print(f'Training worker {i} with placement group.')
174
+ ray.get(pg.ready())
175
+ print("Placement Group created!")
176
+ worker_ray = ray.remote(worker_fn).options(
177
+ num_cpus=num_cpus,
178
+ num_gpus=num_gpus,
179
+ max_retries=max_retries,
180
+ resources=train_resources,
181
+ retry_exceptions=True,
182
+ scheduling_strategy=PlacementGroupSchedulingStrategy(
183
+ placement_group=pg)
184
+ )
185
+ else:
186
+ print(f'Training worker {i} without placement group.')
187
+ worker_ray = ray.remote(worker_fn).options(
188
+ num_cpus=num_cpus,
189
+ num_gpus=num_gpus,
190
+ max_retries=max_retries,
191
+ resources=train_resources,
192
+ retry_exceptions=True,
193
+ )
194
+ task_ref = worker_ray.remote(
195
+ this_command_args, data_src, unbuffer_python)
196
+ task_refs.append(task_ref)
197
+ task_name_map[task_ref] = f'train_{i}'
198
+
199
+ # monitor worker is always packed on the same node
200
+ # as training worker 0
201
+ ray.get(pg.ready())
202
+ monitor_worker_ray = ray.remote(worker_fn).options(
203
+ num_cpus=1,
204
+ num_gpus=0,
205
+ max_retries=monitor_max_retires,
206
+ # resources=monitor_resources,
207
+ retry_exceptions=True,
208
+ scheduling_strategy=PlacementGroupSchedulingStrategy(
209
+ placement_group=pg)
210
+ )
211
+ monitor_ref = monitor_worker_ray.remote(
212
+ monitor_command_args, data_src, unbuffer_python)
213
+ task_name_map[monitor_ref] = 'metrics'
214
+
215
+ try:
216
+ # normal case
217
+ ready_refs = list()
218
+ rest_refs = task_refs
219
+ while len(ready_refs) < len(task_refs):
220
+ this_ready_refs, rest_refs = ray.wait(rest_refs,
221
+ num_returns=1, timeout=None, fetch_local=True)
222
+ cancel_other_tasks = False
223
+ for ref in this_ready_refs:
224
+ task_name = task_name_map[ref]
225
+ try:
226
+ result = ray.get(ref)
227
+ print(f"Task {task_name} finished with result: {result}")
228
+ except KeyboardInterrupt as e:
229
+ # skip to outer try catch
230
+ raise KeyboardInterrupt
231
+ except Exception as e:
232
+ print(f"Task {task_name} raised exception: {e}")
233
+ this_cancel_other_tasks = True
234
+ if isinstance(e, ray.exceptions.RayTaskError):
235
+ if isinstance(e.cause, ray.exceptions.TaskCancelledError):
236
+ this_cancel_other_tasks = False
237
+ cancel_other_tasks = cancel_other_tasks or this_cancel_other_tasks
238
+ ready_refs.append(ref)
239
+ if cancel_other_tasks:
240
+ print('Exception! Cancelling all other tasks.')
241
+ # cancel all other refs
242
+ for _ref in rest_refs:
243
+ ray.cancel(_ref, force=False)
244
+ print("Training tasks done.")
245
+ ray.cancel(monitor_ref, force=False)
246
+ except KeyboardInterrupt:
247
+ print('KeyboardInterrupt received in the driver.')
248
+ # a KeyboardInterrupt will be raised in worker
249
+ _ = [ray.cancel(x, force=False) for x in task_refs + [monitor_ref]]
250
+ print('KeyboardInterrupt sent to workers.')
251
+ except Exception as e:
252
+ # worker will be terminated
253
+ _ = [ray.cancel(x, force=True) for x in task_refs + [monitor_ref]]
254
+ raise e
255
+
256
+ for ref in task_refs + [monitor_ref]:
257
+ task_name = task_name_map[ref]
258
+ try:
259
+ result = ray.get(ref)
260
+ print(f"Task {task_name} finished with result: {result}")
261
+ except KeyboardInterrupt as e:
262
+ # force kill everything.
263
+ print("Force killing all workers")
264
+ _ = [ray.cancel(x, force=True) for x in task_refs]
265
+ ray.cancel(monitor_ref, force=True)
266
+ except Exception as e:
267
+ print(f"Task {task_name} raised exception: {e}")
268
+
269
+
270
+ if __name__ == "__main__":
271
+ main()
third_party/diffusion_policy/setup.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name = 'diffusion_policy',
5
+ packages = find_packages(),
6
+ )
third_party/diffusion_policy/tests/test_block_pushing.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ from diffusion_policy.env.block_pushing.block_pushing_multimodal import BlockPushMultimodal
9
+ from gym.wrappers import FlattenObservation
10
+ from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
11
+ from diffusion_policy.gym_util.video_wrapper import VideoWrapper
12
+
13
+ def test():
14
+ env = MultiStepWrapper(
15
+ VideoWrapper(
16
+ FlattenObservation(
17
+ BlockPushMultimodal()
18
+ ),
19
+ enabled=True,
20
+ steps_per_render=2
21
+ ),
22
+ n_obs_steps=2,
23
+ n_action_steps=8,
24
+ max_episode_steps=16
25
+ )
26
+ env = BlockPushMultimodal()
27
+ obs = env.reset()
28
+ import pdb; pdb.set_trace()
29
+
30
+ env = FlattenObservation(BlockPushMultimodal())
31
+ obs = env.reset()
32
+ action = env.action_space.sample()
33
+ next_obs, reward, done, info = env.step(action)
34
+ print(obs[8:10] + action - next_obs[8:10])
35
+ import pdb; pdb.set_trace()
36
+
37
+ for i in range(3):
38
+ obs, reward, done, info = env.step(env.action_space.sample())
39
+ img = env.render()
40
+ import pdb; pdb.set_trace()
41
+ print("Done!", done)
42
+
43
+ if __name__ == '__main__':
44
+ test()
third_party/diffusion_policy/tests/test_cv2_util.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ import numpy as np
9
+ from diffusion_policy.common.cv2_util import get_image_transform
10
+
11
+
12
+ def test():
13
+ tf = get_image_transform((1280,720), (640,480), bgr_to_rgb=False)
14
+ in_img = np.zeros((720,1280,3), dtype=np.uint8)
15
+ out_img = tf(in_img)
16
+ # print(out_img.shape)
17
+ assert out_img.shape == (480,640,3)
18
+
19
+ # import pdb; pdb.set_trace()
20
+
21
+ if __name__ == '__main__':
22
+ test()
third_party/diffusion_policy/tests/test_multi_realsense.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ import cv2
9
+ import json
10
+ import time
11
+ import numpy as np
12
+ from diffusion_policy.real_world.multi_realsense import MultiRealsense
13
+ from diffusion_policy.real_world.video_recorder import VideoRecorder
14
+
15
+ def test():
16
+ config = json.load(open('/home/cchi/dev/diffusion_policy/diffusion_policy/real_world/realsense_config/415_high_accuracy_mode.json', 'r'))
17
+
18
+ def transform(data):
19
+ color = data['color']
20
+ h,w,_ = color.shape
21
+ factor = 4
22
+ color = cv2.resize(color, (w//factor,h//factor), interpolation=cv2.INTER_AREA)
23
+ # color = color[:,140:500]
24
+ data['color'] = color
25
+ return data
26
+
27
+ from diffusion_policy.common.cv2_util import get_image_transform
28
+ color_transform = get_image_transform(
29
+ input_res=(1280,720),
30
+ output_res=(640,480),
31
+ bgr_to_rgb=False)
32
+ def transform(data):
33
+ data['color'] = color_transform(data['color'])
34
+ return data
35
+
36
+ # one thread per camera
37
+ video_recorder = VideoRecorder.create_h264(
38
+ fps=30,
39
+ codec='h264',
40
+ thread_type='FRAME'
41
+ )
42
+
43
+ with MultiRealsense(
44
+ resolution=(1280,720),
45
+ capture_fps=30,
46
+ record_fps=15,
47
+ enable_color=True,
48
+ # advanced_mode_config=config,
49
+ transform=transform,
50
+ # recording_transform=transform,
51
+ # video_recorder=video_recorder,
52
+ verbose=True
53
+ ) as realsense:
54
+ realsense.set_exposure(exposure=150, gain=5)
55
+ intr = realsense.get_intrinsics()
56
+ print(intr)
57
+
58
+ video_path = 'data_local/test'
59
+ rec_start_time = time.time() + 1
60
+ realsense.start_recording(video_path, start_time=rec_start_time)
61
+ realsense.restart_put(rec_start_time)
62
+
63
+ out = None
64
+ vis_img = None
65
+ while True:
66
+ out = realsense.get(out=out)
67
+
68
+ # bgr = out['color']
69
+ # print(bgr.shape)
70
+ # vis_img = np.concatenate(list(bgr), axis=0, out=vis_img)
71
+ # cv2.imshow('default', vis_img)
72
+ # key = cv2.pollKey()
73
+ # if key == ord('q'):
74
+ # break
75
+
76
+ time.sleep(1/60)
77
+ if time.time() > (rec_start_time + 20.0):
78
+ break
79
+
80
+
81
+ if __name__ == "__main__":
82
+ test()
third_party/diffusion_policy/tests/test_pose_trajectory_interpolator.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import scipy.interpolate as si
11
+ import scipy.spatial.transform as st
12
+ from diffusion_policy.common.pose_trajectory_interpolator import (
13
+ rotation_distance,
14
+ pose_distance,
15
+ PoseTrajectoryInterpolator)
16
+
17
+
18
+ def test_rotation_distance():
19
+ def rotation_distance_align(a: st.Rotation, b: st.Rotation) -> float:
20
+ return st.Rotation.align_vectors(b.as_matrix().T, a.as_matrix().T)[0].magnitude()
21
+
22
+ for i in range(10000):
23
+ a = st.Rotation.from_euler('xyz', np.random.uniform(-7,7,size=3))
24
+ b = st.Rotation.from_euler('xyz', np.random.uniform(-7,7,size=3))
25
+ x = rotation_distance(a, b)
26
+ y = rotation_distance_align(a, b)
27
+ assert abs(x-y) < 1e-7
28
+
29
+ def test_pose_trajectory_interpolator():
30
+ t = np.linspace(-1,5,100)
31
+ interp = PoseTrajectoryInterpolator(
32
+ [0,1,3],
33
+ np.zeros((3,6))
34
+ )
35
+ times = interp.times
36
+ poses = interp.poses
37
+
38
+ trimmed_interp = interp.trim(-1,4)
39
+ assert len(trimmed_interp.times) == 5
40
+ trimmed_interp(t)
41
+
42
+ trimmed_interp = interp.trim(-1,4)
43
+ assert len(trimmed_interp.times) == 5
44
+ trimmed_interp(t)
45
+
46
+ trimmed_interp = interp.trim(0.5, 3.5)
47
+ assert len(trimmed_interp.times) == 4
48
+ trimmed_interp(t)
49
+
50
+ trimmed_interp = interp.trim(0.5, 2.5)
51
+ assert len(trimmed_interp.times) == 3
52
+ trimmed_interp(t)
53
+
54
+ trimmed_interp = interp.trim(0.5, 1.5)
55
+ assert len(trimmed_interp.times) == 3
56
+ trimmed_interp(t)
57
+
58
+ trimmed_interp = interp.trim(1.2, 1.5)
59
+ assert len(trimmed_interp.times) == 2
60
+ trimmed_interp(t)
61
+
62
+ trimmed_interp = interp.trim(1.3, 1.3)
63
+ assert len(trimmed_interp.times) == 1
64
+ trimmed_interp(t)
65
+
66
+ # import pdb; pdb.set_trace()
67
+
68
+ def test_add_waypoint():
69
+ # fuzz testing
70
+ for i in tqdm(range(10000)):
71
+ rng = np.random.default_rng(i)
72
+ n_waypoints = rng.integers(1, 5)
73
+ waypoint_times = np.sort(rng.uniform(0, 1, size=n_waypoints))
74
+ last_waypoint_time = waypoint_times[-1]
75
+ insert_time = rng.uniform(-0.1, 1.1)
76
+ curr_time = rng.uniform(-0.1, 1.1)
77
+ max_pos_speed = rng.poisson(3) + 1e-3
78
+ max_rot_speed = rng.poisson(3) + 1e-3
79
+ waypoint_poses = rng.normal(0, 3, size=(n_waypoints, 6))
80
+ new_pose = rng.normal(0, 3, size=6)
81
+
82
+ if rng.random() < 0.1:
83
+ last_waypoint_time = None
84
+ if rng.random() < 0.1:
85
+ curr_time = None
86
+
87
+ interp = PoseTrajectoryInterpolator(
88
+ times=waypoint_times,
89
+ poses=waypoint_poses)
90
+ new_interp = interp.add_waypoint(
91
+ pose=new_pose,
92
+ time=insert_time,
93
+ max_pos_speed=max_pos_speed,
94
+ max_rot_speed=max_rot_speed,
95
+ curr_time=curr_time,
96
+ last_waypoint_time=last_waypoint_time
97
+ )
98
+
99
+ def test_drive_to_waypoint():
100
+ # fuzz testing
101
+ for i in tqdm(range(10000)):
102
+ rng = np.random.default_rng(i)
103
+ n_waypoints = rng.integers(1, 5)
104
+ waypoint_times = np.sort(rng.uniform(0, 1, size=n_waypoints))
105
+ insert_time = rng.uniform(-0.1, 1.1)
106
+ curr_time = rng.uniform(-0.1, 1.1)
107
+ max_pos_speed = rng.poisson(3) + 1e-3
108
+ max_rot_speed = rng.poisson(3) + 1e-3
109
+ waypoint_poses = rng.normal(0, 3, size=(n_waypoints, 6))
110
+ new_pose = rng.normal(0, 3, size=6)
111
+
112
+ interp = PoseTrajectoryInterpolator(
113
+ times=waypoint_times,
114
+ poses=waypoint_poses)
115
+ new_interp = interp.drive_to_waypoint(
116
+ pose=new_pose,
117
+ time=insert_time,
118
+ curr_time=curr_time,
119
+ max_pos_speed=max_pos_speed,
120
+ max_rot_speed=max_rot_speed
121
+ )
122
+
123
+
124
+
125
+ if __name__ == '__main__':
126
+ test_drive_to_waypoint()
third_party/diffusion_policy/tests/test_precise_sleep.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ import time
9
+ import numpy as np
10
+ from diffusion_policy.common.precise_sleep import precise_sleep, precise_wait
11
+
12
+
13
+ def test_sleep():
14
+ dt = 0.1
15
+ tol = 1e-3
16
+ time_samples = list()
17
+ for i in range(100):
18
+ precise_sleep(dt)
19
+ # time.sleep(dt)
20
+ time_samples.append(time.monotonic())
21
+ time_deltas = np.diff(time_samples)
22
+
23
+ from matplotlib import pyplot as plt
24
+ plt.plot(time_deltas)
25
+ plt.ylim((dt-tol,dt+tol))
26
+
27
+
28
+ def test_wait():
29
+ dt = 0.1
30
+ tol = 1e-3
31
+ errors = list()
32
+ t_start = time.monotonic()
33
+ for i in range(1,100):
34
+ t_end_desired = t_start + i * dt
35
+ time.sleep(t_end_desired - time.monotonic())
36
+ t_end = time.monotonic()
37
+ errors.append(t_end - t_end_desired)
38
+
39
+ new_errors = list()
40
+ t_start = time.monotonic()
41
+ for i in range(1,100):
42
+ t_end_desired = t_start + i * dt
43
+ precise_wait(t_end_desired)
44
+ t_end = time.monotonic()
45
+ new_errors.append(t_end - t_end_desired)
46
+
47
+ from matplotlib import pyplot as plt
48
+ plt.plot(errors, label='time.sleep')
49
+ plt.plot(new_errors, label='sleep/spin hybrid')
50
+ plt.ylim((-tol,+tol))
51
+ plt.title('0.1 sec sleep error')
52
+ plt.legend()
53
+
54
+
55
+ if __name__ == '__main__':
56
+ test_sleep()
third_party/diffusion_policy/tests/test_replay_buffer.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ import zarr
9
+ from diffusion_policy.common.replay_buffer import ReplayBuffer
10
+
11
+ def test():
12
+ import numpy as np
13
+ buff = ReplayBuffer.create_empty_numpy()
14
+ buff.add_episode({
15
+ 'obs': np.zeros((100,10), dtype=np.float16)
16
+ })
17
+ buff.add_episode({
18
+ 'obs': np.ones((50,10)),
19
+ 'action': np.ones((50,2))
20
+ })
21
+ # buff.rechunk(256)
22
+ obs = buff.get_episode(0)
23
+
24
+ import numpy as np
25
+ buff = ReplayBuffer.create_empty_zarr()
26
+ buff.add_episode({
27
+ 'obs': np.zeros((100,10), dtype=np.float16)
28
+ })
29
+ buff.add_episode({
30
+ 'obs': np.ones((50,10)),
31
+ 'action': np.ones((50,2))
32
+ })
33
+ obs = buff.get_episode(0)
34
+ buff.set_chunks({
35
+ 'obs': (100,10),
36
+ 'action': (100,2)
37
+ })
38
+
39
+
40
+ def test_real():
41
+ import os
42
+ dist_group = zarr.open(
43
+ os.path.expanduser('~/dev/diffusion_policy/data/pusht/pusht_cchi_v2.zarr'), 'r')
44
+
45
+ buff = ReplayBuffer.create_empty_numpy()
46
+ key, group = next(iter(dist_group.items()))
47
+ for key, group in dist_group.items():
48
+ buff.add_episode(group)
49
+
50
+ # out_path = os.path.expanduser('~/dev/diffusion_policy/data/pusht_cchi2_v2_replay.zarr')
51
+ out_path = os.path.expanduser('~/dev/diffusion_policy/data/test.zarr')
52
+ out_store = zarr.DirectoryStore(out_path)
53
+ buff.save_to_store(out_store)
54
+
55
+ buff = ReplayBuffer.copy_from_path(out_path, store=zarr.MemoryStore())
56
+ buff.pop_episode()
57
+
58
+
59
+ def test_pop():
60
+ buff = ReplayBuffer.create_from_path(
61
+ '/home/chengchi/dev/diffusion_policy/data/pusht_cchi_v3_replay.zarr',
62
+ mode='rw')
third_party/diffusion_policy/tests/test_ring_buffer.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ import time
9
+ import numpy as np
10
+ import multiprocessing as mp
11
+ from multiprocessing.managers import SharedMemoryManager
12
+ from diffusion_policy.shared_memory.shared_memory_ring_buffer import (
13
+ SharedMemoryRingBuffer,
14
+ SharedAtomicCounter)
15
+
16
+
17
+ def test():
18
+ shm_manager = SharedMemoryManager()
19
+ shm_manager.start()
20
+ ring_buffer = SharedMemoryRingBuffer.create_from_examples(
21
+ shm_manager,
22
+ {'timestamp': np.array(0, dtype=np.float64)},
23
+ buffer_size=128
24
+ )
25
+ for i in range(30):
26
+ ring_buffer.put({
27
+ 'timestamp': np.array(
28
+ time.perf_counter(),
29
+ dtype=np.float64)
30
+ })
31
+ print(ring_buffer.get())
32
+
33
+
34
+ def _timestamp_worker(ring_buffer, start_event, stop_event):
35
+ while not stop_event.is_set():
36
+ start_event.set()
37
+ ring_buffer.put({
38
+ 'timestamp': np.array(
39
+ time.time(),
40
+ dtype=np.float64)
41
+ })
42
+
43
+
44
+ def test_mp():
45
+ shm_manager = SharedMemoryManager()
46
+ shm_manager.start()
47
+ ring_buffer = SharedMemoryRingBuffer.create_from_examples(
48
+ shm_manager,
49
+ {'timestamp': np.array(0, dtype=np.float64)},
50
+ get_max_k=1,
51
+ get_time_budget=0.01,
52
+ put_desired_frequency=1000
53
+ )
54
+ start_event = mp.Event()
55
+ stop_event = mp.Event()
56
+ worker = mp.Process(target=_timestamp_worker, args=(
57
+ ring_buffer, start_event, stop_event))
58
+ worker.start()
59
+ start_event.wait()
60
+ for i in range(1000):
61
+ t = float(ring_buffer.get()['timestamp'])
62
+ curr_t = time.time()
63
+ print('latency', curr_t - t)
64
+ stop_event.set()
65
+ worker.join()
66
+
67
+
68
+ def test_get_last_k():
69
+ shm_manager = SharedMemoryManager()
70
+ shm_manager.start()
71
+ ring_buffer = SharedMemoryRingBuffer.create_from_examples(
72
+ shm_manager,
73
+ {'counter': np.array(0, dtype=np.int64)},
74
+ buffer_size=8
75
+ )
76
+
77
+ from collections import deque
78
+ k = 4
79
+ last_k = deque(maxlen=k)
80
+ for i in range(100):
81
+ ring_buffer.put({
82
+ 'counter': np.array(i, dtype=np.int64)
83
+ })
84
+ last_k.append(i)
85
+ if i > k:
86
+ result = ring_buffer.get_last_k(k)['counter']
87
+ assert np.allclose(result, last_k)
88
+
89
+ print(ring_buffer.shared_arrays['counter'].get())
90
+ result = ring_buffer.get_last_k(4)
91
+ print(result)
92
+
93
+
94
+ def test_timing():
95
+ shm_manager = SharedMemoryManager()
96
+ shm_manager.start()
97
+ ring_buffer = SharedMemoryRingBuffer.create_from_examples(
98
+ shm_manager,
99
+ {'counter': np.array(0, dtype=np.int64)},
100
+ get_max_k=8,
101
+ get_time_budget=0.1,
102
+ put_desired_frequency=100
103
+ )
104
+ # print(ring_buffer.timestamp_array.get())
105
+ print('buffer_size', ring_buffer.buffer_size)
106
+
107
+ dt = 1 / 150
108
+ t_init = time.monotonic()
109
+ for i in range(1000):
110
+ t_start = time.monotonic()
111
+ ring_buffer.put({
112
+ 'counter': np.array(i, dtype=np.int64)
113
+ }, wait=False)
114
+ if (i % 10 == 0) and (i > 0):
115
+ result = ring_buffer.get_last_k(8)
116
+
117
+ t_end =time.monotonic()
118
+ desired_t = (i+1) * dt + t_init
119
+ if desired_t > t_end:
120
+ time.sleep(desired_t - t_end)
121
+ hz = 1 / (time.monotonic() - t_start)
122
+ print(f'{hz}Hz')
123
+
124
+
125
+ def _timestamp_image_worker(ring_buffer, img_shape, dt, start_event, stop_event):
126
+ i = 0
127
+ t_init = time.monotonic()
128
+ image = np.ones(img_shape, dtype=np.uint8)
129
+ while not stop_event.is_set():
130
+ t_start = time.monotonic()
131
+ start_event.set()
132
+ ring_buffer.put({
133
+ 'img': image,
134
+ 'timestamp': time.time(),
135
+ 'counter': i
136
+ })
137
+ t_end = time.monotonic()
138
+ desired_t = (i+1) * dt + t_init
139
+ # print('alive')
140
+ if desired_t > t_end:
141
+ time.sleep(desired_t - t_end)
142
+ # hz = 1 / (time.monotonic() - t_start)
143
+ i += 1
144
+
145
+
146
+ def test_timing_mp():
147
+ shm_manager = SharedMemoryManager()
148
+ shm_manager.start()
149
+
150
+ hz = 200
151
+ img_shape = (1920,1080,3)
152
+ ring_buffer = SharedMemoryRingBuffer.create_from_examples(
153
+ shm_manager,
154
+ examples={
155
+ 'img': np.zeros(img_shape, dtype=np.uint8),
156
+ 'timestamp': time.time(),
157
+ 'counter': 0
158
+ },
159
+ get_max_k=60,
160
+ get_time_budget=0.02,
161
+ put_desired_frequency=hz
162
+ )
163
+ start_event = mp.Event()
164
+ stop_event = mp.Event()
165
+ worker = mp.Process(target=_timestamp_image_worker, args=(
166
+ ring_buffer, img_shape, 1/hz, start_event, stop_event))
167
+ worker.start()
168
+ start_event.wait()
169
+ out = None
170
+ t_start = time.monotonic()
171
+ k = 1
172
+ for i in range(1000):
173
+ if ring_buffer.count < k:
174
+ time.sleep(0)
175
+ continue
176
+ out = ring_buffer.get_last_k(k=k, out=out)
177
+ t = float(out['timestamp'][-1])
178
+ curr_t = time.time()
179
+ print('latency', curr_t - t)
180
+ t_end = time.monotonic()
181
+ print('Get Hz', 1/(t_end-t_start)*1000)
182
+ stop_event.set()
183
+ worker.join()
184
+
185
+
186
+ if __name__ == '__main__':
187
+ # test_mp()
188
+ test_timing_mp()
third_party/diffusion_policy/tests/test_robomimic_image_runner.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ from diffusion_policy.env_runner.robomimic_image_runner import RobomimicImageRunner
9
+
10
+ def test():
11
+ import os
12
+ from omegaconf import OmegaConf
13
+ cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml')
14
+ cfg = OmegaConf.load(cfg_path)
15
+ cfg['n_obs_steps'] = 1
16
+ cfg['n_action_steps'] = 1
17
+ cfg['past_action_visible'] = False
18
+ runner_cfg = cfg['env_runner']
19
+ runner_cfg['n_train'] = 1
20
+ runner_cfg['n_test'] = 1
21
+ del runner_cfg['_target_']
22
+ runner = RobomimicImageRunner(
23
+ **runner_cfg,
24
+ output_dir='/tmp/test')
25
+
26
+ # import pdb; pdb.set_trace()
27
+
28
+ self = runner
29
+ env = self.env
30
+ env.seed(seeds=self.env_seeds)
31
+ obs = env.reset()
32
+ for i in range(10):
33
+ _ = env.step(env.action_space.sample())
34
+
35
+ imgs = env.render()
36
+
37
+ if __name__ == '__main__':
38
+ test()
third_party/diffusion_policy/tests/test_robomimic_lowdim_runner.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ from diffusion_policy.env_runner.robomimic_lowdim_runner import RobomimicLowdimRunner
9
+
10
+ def test():
11
+ import os
12
+ from omegaconf import OmegaConf
13
+ cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_lowdim.yaml')
14
+ cfg = OmegaConf.load(cfg_path)
15
+ cfg['n_obs_steps'] = 1
16
+ cfg['n_action_steps'] = 1
17
+ cfg['past_action_visible'] = False
18
+ runner_cfg = cfg['env_runner']
19
+ runner_cfg['n_train'] = 1
20
+ runner_cfg['n_test'] = 0
21
+ del runner_cfg['_target_']
22
+ runner = RobomimicLowdimRunner(
23
+ **runner_cfg,
24
+ output_dir='/tmp/test')
25
+
26
+ # import pdb; pdb.set_trace()
27
+
28
+ self = runner
29
+ env = self.env
30
+ env.seed(seeds=self.env_seeds)
31
+ obs = env.reset()
32
+
33
+ if __name__ == '__main__':
34
+ test()
third_party/diffusion_policy/tests/test_shared_queue.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ import numpy as np
9
+ from multiprocessing.managers import SharedMemoryManager
10
+ from diffusion_policy.shared_memory.shared_memory_queue import SharedMemoryQueue, Full, Empty
11
+
12
+
13
+ def test():
14
+ shm_manager = SharedMemoryManager()
15
+ shm_manager.start()
16
+ example = {
17
+ 'cmd': 0,
18
+ 'pose': np.zeros((6,))
19
+ }
20
+ queue = SharedMemoryQueue.create_from_examples(
21
+ shm_manager=shm_manager,
22
+ examples=example,
23
+ buffer_size=3
24
+ )
25
+ raised = False
26
+ try:
27
+ queue.get()
28
+ except Empty:
29
+ raised = True
30
+ assert raised
31
+
32
+ data = {
33
+ 'cmd': 1,
34
+ 'pose': np.ones((6,))
35
+ }
36
+ queue.put(data)
37
+ result = queue.get()
38
+ assert result['cmd'] == data['cmd']
39
+ assert np.allclose(result['pose'], data['pose'])
40
+
41
+ queue.put(data)
42
+ queue.put(data)
43
+ queue.put(data)
44
+ assert queue.qsize() == 3
45
+ raised = False
46
+ try:
47
+ queue.put(data)
48
+ except Full:
49
+ raised = True
50
+ assert raised
51
+
52
+ result = queue.get_all()
53
+ assert np.allclose(result['cmd'], [1,1,1])
54
+
55
+ queue.put({'cmd': 0})
56
+ queue.put({'cmd': 1})
57
+ queue.put({'cmd': 2})
58
+ queue.get()
59
+ queue.put({'cmd': 3})
60
+
61
+ result = queue.get_k(3)
62
+ assert np.allclose(result['cmd'], [1,2,3])
63
+
64
+ queue.clear()
65
+
66
+ if __name__ == "__main__":
67
+ test()
third_party/diffusion_policy/tests/test_single_realsense.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ import cv2
9
+ import json
10
+ import time
11
+ from multiprocessing.managers import SharedMemoryManager
12
+ from diffusion_policy.real_world.single_realsense import SingleRealsense
13
+
14
+ def test():
15
+
16
+ serials = SingleRealsense.get_connected_devices_serial()
17
+ # import pdb; pdb.set_trace()
18
+ serial = serials[0]
19
+ config = json.load(open('/home/cchi/dev/diffusion_policy/diffusion_policy/real_world/realsense_config/415_high_accuracy_mode.json', 'r'))
20
+
21
+ def transform(data):
22
+ color = data['color']
23
+ h,w,_ = color.shape
24
+ factor = 2
25
+ color = cv2.resize(color, (w//factor,h//factor), interpolation=cv2.INTER_AREA)
26
+ # color = color[:,140:500]
27
+ data['color'] = color
28
+ return data
29
+
30
+ # at 960x540 with //3, 60fps and 30fps are indistinguishable
31
+
32
+ with SharedMemoryManager() as shm_manager:
33
+ with SingleRealsense(
34
+ shm_manager=shm_manager,
35
+ serial_number=serial,
36
+ resolution=(1280,720),
37
+ # resolution=(960,540),
38
+ # resolution=(640,480),
39
+ capture_fps=30,
40
+ enable_color=True,
41
+ # enable_depth=True,
42
+ # enable_infrared=True,
43
+ # advanced_mode_config=config,
44
+ # transform=transform,
45
+ # recording_transform=transform
46
+ # verbose=True
47
+ ) as realsense:
48
+ cv2.setNumThreads(1)
49
+ realsense.set_exposure(exposure=150, gain=5)
50
+ intr = realsense.get_intrinsics()
51
+ print(intr)
52
+
53
+
54
+ video_path = 'data_local/test.mp4'
55
+ rec_start_time = time.time() + 2
56
+ realsense.start_recording(video_path, start_time=rec_start_time)
57
+
58
+ data = None
59
+ while True:
60
+ data = realsense.get(out=data)
61
+ t = time.time()
62
+ # print('capture_latency', data['receive_timestamp']-data['capture_timestamp'], 'receive_latency', t - data['receive_timestamp'])
63
+ # print('receive', t - data['receive_timestamp'])
64
+
65
+ # dt = time.time() - data['timestamp']
66
+ # print(dt)
67
+ # print(data['capture_timestamp'] - rec_start_time)
68
+
69
+ bgr = data['color']
70
+ # print(bgr.shape)
71
+ cv2.imshow('default', bgr)
72
+ key = cv2.pollKey()
73
+ # if key == ord('q'):
74
+ # break
75
+ # elif key == ord('r'):
76
+ # video_path = 'data_local/test.mp4'
77
+ # realsense.start_recording(video_path)
78
+ # elif key == ord('s'):
79
+ # realsense.stop_recording()
80
+
81
+ time.sleep(1/60)
82
+ if time.time() > (rec_start_time + 20.0):
83
+ break
84
+
85
+
86
+ if __name__ == "__main__":
87
+ test()
third_party/diffusion_policy/tests/test_timestamp_accumulator.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5
+ sys.path.append(ROOT_DIR)
6
+ os.chdir(ROOT_DIR)
7
+
8
+ import numpy as np
9
+ import time
10
+ from diffusion_policy.common.timestamp_accumulator import (
11
+ get_accumulate_timestamp_idxs,
12
+ TimestampObsAccumulator,
13
+ TimestampActionAccumulator
14
+ )
15
+
16
+
17
+ def test_index():
18
+ buffer = np.zeros(16)
19
+ start_time = 0.0
20
+ dt = 1/10
21
+
22
+ timestamps = np.linspace(0,1,100)
23
+ gi = list()
24
+ next_global_idx = 0
25
+
26
+ local_idxs, global_idxs, next_global_idx = get_accumulate_timestamp_idxs(timestamps,
27
+ start_time=start_time, dt=dt, next_global_idx=next_global_idx)
28
+ assert local_idxs[0] == 0
29
+ assert global_idxs[0] == 0
30
+ # print(local_idxs)
31
+ # print(global_idxs)
32
+ # print(timestamps[local_idxs])
33
+ buffer[global_idxs] = timestamps[local_idxs]
34
+ gi.extend(global_idxs)
35
+
36
+ timestamps = np.linspace(0.5,1.5,100)
37
+ local_idxs, global_idxs, next_global_idx = get_accumulate_timestamp_idxs(timestamps,
38
+ start_time=start_time, dt=dt, next_global_idx = next_global_idx)
39
+ # print(local_idxs)
40
+ # print(global_idxs)
41
+ # print(timestamps[local_idxs])
42
+ # import pdb; pdb.set_trace()
43
+ buffer[global_idxs] = timestamps[local_idxs]
44
+ gi.extend(global_idxs)
45
+
46
+ assert np.all(buffer[1:] > buffer[:-1])
47
+ assert np.all(np.array(gi) == np.array(list(range(len(gi)))))
48
+ # print(buffer)
49
+
50
+ # start over
51
+ next_global_idx = 0
52
+ timestamps = np.linspace(0,1,3)
53
+ local_idxs, global_idxs, next_global_idx = get_accumulate_timestamp_idxs(timestamps,
54
+ start_time=start_time, dt=dt, next_global_idx = next_global_idx)
55
+ assert local_idxs[0] == 0
56
+ assert local_idxs[-1] == 2
57
+ # print(local_idxs)
58
+ # print(global_idxs)
59
+ # print(timestamps[local_idxs])
60
+
61
+ # test numerical error issue
62
+ # this becomes a problem when eps <= 1e-7
63
+ start_time = time.time()
64
+ next_global_idx = 0
65
+ timestamps = np.arange(100000) * dt + start_time
66
+ local_idxs, global_idxs, next_global_idx = get_accumulate_timestamp_idxs(timestamps,
67
+ start_time=start_time, dt=dt, next_global_idx = next_global_idx)
68
+ assert local_idxs == global_idxs
69
+ # print(local_idxs)
70
+ # print(global_idxs)
71
+ # print(timestamps[local_idxs])
72
+
73
+
74
+ def test_obs_accumulator():
75
+ dt = 1/10
76
+ ddt = 1/100
77
+ n = 100
78
+ d = 6
79
+ start_time = time.time()
80
+ toa = TimestampObsAccumulator(start_time, dt)
81
+ poses = np.arange(n).reshape((n,1))
82
+ poses = np.repeat(poses, d, axis=1)
83
+ timestamps = np.arange(n) * ddt + start_time
84
+
85
+ toa.put({
86
+ 'pose': poses,
87
+ 'timestamp': timestamps
88
+ }, timestamps)
89
+ assert np.all(toa.data['pose'][:,0] == np.arange(10)*10)
90
+ assert len(toa) == 10
91
+
92
+ # add the same thing, result shouldn't change
93
+ toa.put({
94
+ 'pose': poses,
95
+ 'timestamp': timestamps
96
+ }, timestamps)
97
+ assert np.all(toa.data['pose'][:,0] == np.arange(10)*10)
98
+ assert len(toa) == 10
99
+
100
+ # add lower than desired freuquency to test fill_in
101
+ dt = 1/10
102
+ ddt = 1/5
103
+ n = 10
104
+ d = 6
105
+ start_time = time.time()
106
+ toa = TimestampObsAccumulator(start_time, dt)
107
+ poses = np.arange(n).reshape((n,1))
108
+ poses = np.repeat(poses, d, axis=1)
109
+ timestamps = np.arange(n) * ddt + start_time
110
+
111
+ toa.put({
112
+ 'pose': poses,
113
+ 'timestamp': timestamps
114
+ }, timestamps)
115
+ assert len(toa) == 1 + (n-1) * 2
116
+
117
+ timestamps = (np.arange(n) + 2) * ddt + start_time
118
+ toa.put({
119
+ 'pose': poses,
120
+ 'timestamp': timestamps
121
+ }, timestamps)
122
+ assert len(toa) == 1 + (n-1) * 2 + 4
123
+
124
+
125
+ def test_action_accumulator():
126
+ dt = 1/10
127
+ n = 10
128
+ d = 6
129
+ start_time = time.time()
130
+ taa = TimestampActionAccumulator(start_time, dt)
131
+ actions = np.arange(n).reshape((n,1))
132
+ actions = np.repeat(actions, d, axis=1)
133
+
134
+ timestamps = np.arange(n) * dt + start_time
135
+ taa.put(actions, timestamps)
136
+ assert np.all(taa.actions == actions)
137
+ assert np.all(taa.timestamps == timestamps)
138
+
139
+ # add another round
140
+ taa.put(actions-5, timestamps-0.5)
141
+ assert np.allclose(taa.timestamps, timestamps)
142
+
143
+ # add another round
144
+ taa.put(actions+5, timestamps+0.5)
145
+ assert len(taa) == 15
146
+ assert np.all(taa.actions[:,0] == np.arange(15))
147
+
148
+
149
+
150
+ if __name__ == '__main__':
151
+ test_action_accumulator()
third_party/diffusion_policy/train.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ Training:
4
+ python train.py --config-name=train_diffusion_lowdim_workspace
5
+ """
6
+
7
+ import sys
8
+ # use line-buffering for both stdout and stderr
9
+ sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
10
+ sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
11
+
12
+ import hydra
13
+ from omegaconf import OmegaConf
14
+ import pathlib
15
+ from diffusion_policy.workspace.base_workspace import BaseWorkspace
16
+
17
+ # allows arbitrary python code execution in configs using the ${eval:''} resolver
18
+ OmegaConf.register_new_resolver("eval", eval, replace=True)
19
+
20
+ @hydra.main(
21
+ version_base=None,
22
+ config_path=str(pathlib.Path(__file__).parent.joinpath(
23
+ 'diffusion_policy','config'))
24
+ )
25
+ def main(cfg: OmegaConf):
26
+ # resolve immediately so all the ${now:} resolvers
27
+ # will use the same time.
28
+ OmegaConf.resolve(cfg)
29
+
30
+ cls = hydra.utils.get_class(cfg._target_)
31
+ workspace: BaseWorkspace = cls(cfg)
32
+ workspace.run()
33
+
34
+ if __name__ == "__main__":
35
+ main()