felixw committed on
Commit
e2f7d64
·
verified ·
1 Parent(s): fb5d440

Upload ITPS Maze2D pretrained checkpoint

Browse files
Files changed (4) hide show
  1. README.md +56 -0
  2. config.json +53 -0
  3. config.yaml +195 -0
  4. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: lerobot
3
+ license: mit
4
+ tags:
5
+ - lerobot
6
+ - dp
7
+ - robotics
8
+ - maze2d
9
+ - itps
10
+ - pytorch_model_hub_mixin
11
+ pipeline_tag: robotics
12
+ ---
13
+
14
+ # ITPS Maze2D — Diffusion Policy (DP)
15
+
16
+ Pre-trained Diffusion Policy checkpoint used in
17
+ **Inference-Time Policy Steering through Human Interactions**
18
+ ([paper](https://huggingface.co/papers/2411.16627), [project page](https://yanweiw.github.io/itps/), [code](https://github.com/yanweiw/itps)).
19
+
20
+ The model was trained on the [D4RL Maze2D](https://github.com/Farama-Foundation/D4RL)
21
+ dataset and is intended to be loaded with the
22
+ [LeRobot](https://github.com/huggingface/lerobot) policy classes.
23
+
24
+ ## Usage
25
+
26
+ Clone the inference repo, then load this checkpoint directly from the Hub:
27
+
28
+ ```bash
29
+ git clone https://github.com/yanweiw/itps.git && cd itps
30
+ pip install -e .
31
+ python interact_maze2d.py -p dp --hf
32
+ ```
33
+
34
+ Or load it programmatically:
35
+
36
+ ```python
37
+ from itps.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
38
+
39
+ policy = DiffusionPolicy.from_pretrained("felixw/itps-dp")
40
+ policy.eval()
41
+ ```
42
+
43
+ ## Citation
44
+
45
+ ```bibtex
46
+ @article{wang2024itps,
47
+ title={Inference-Time Policy Steering through Human Interactions},
48
+ author={Wang, Yanwei and others},
49
+ journal={arXiv preprint arXiv:2411.16627},
50
+ year={2024}
51
+ }
52
+ ```
53
+
54
+ ## License
55
+
56
+ MIT — see [LICENSE](https://github.com/yanweiw/itps/blob/main/LICENSE).
config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "beta_end": 0.02,
3
+ "beta_schedule": "squaredcos_cap_v2",
4
+ "beta_start": 0.0001,
5
+ "clip_sample": true,
6
+ "clip_sample_range": 1.0,
7
+ "crop_is_random": true,
8
+ "crop_shape": [
9
+ 84,
10
+ 84
11
+ ],
12
+ "diffusion_step_embed_dim": 128,
13
+ "do_mask_loss_for_padding": false,
14
+ "down_dims": [
15
+ 128,
16
+ 256,
17
+ 512
18
+ ],
19
+ "horizon": 64,
20
+ "input_normalization_modes": {
21
+ "observation.environment_state": "min_max",
22
+ "observation.state": "min_max"
23
+ },
24
+ "input_shapes": {
25
+ "observation.environment_state": [
26
+ 2
27
+ ],
28
+ "observation.state": [
29
+ 2
30
+ ]
31
+ },
32
+ "kernel_size": 5,
33
+ "n_action_steps": 8,
34
+ "n_groups": 8,
35
+ "n_obs_steps": 2,
36
+ "noise_scheduler_type": "DDIM",
37
+ "num_inference_steps": 10,
38
+ "num_train_timesteps": 100,
39
+ "output_normalization_modes": {
40
+ "action": "min_max"
41
+ },
42
+ "output_shapes": {
43
+ "action": [
44
+ 2
45
+ ]
46
+ },
47
+ "prediction_type": "epsilon",
48
+ "pretrained_backbone_weights": null,
49
+ "spatial_softmax_num_keypoints": 32,
50
+ "use_film_scale_modulation": true,
51
+ "use_group_norm": true,
52
+ "vision_backbone": "resnet18"
53
+ }
config.yaml ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ resume: false
2
+ device: cuda:1
3
+ use_amp: false
4
+ seed: 100000
5
+ dataset_repo_id: maze2d
6
+ dataset_root: /afs/csail.mit.edu/u/f/felixw/lerobot/data/maze2d-large-sparse-v1.hdf5
7
+ video_backend: pyav
8
+ training:
9
+ offline_steps: 100000
10
+ num_workers: 4
11
+ batch_size: 256
12
+ eval_freq: 0
13
+ log_freq: 250
14
+ save_checkpoint: true
15
+ save_freq: 5000
16
+ online_steps: 0
17
+ online_rollout_n_episodes: 1
18
+ online_rollout_batch_size: 1
19
+ online_steps_between_rollouts: 1
20
+ online_sampling_ratio: 0.5
21
+ online_env_seed: null
22
+ online_buffer_capacity: null
23
+ online_buffer_seed_size: 0
24
+ do_online_rollout_async: false
25
+ image_transforms:
26
+ enable: false
27
+ max_num_transforms: 3
28
+ random_order: false
29
+ brightness:
30
+ weight: 1
31
+ min_max:
32
+ - 0.8
33
+ - 1.2
34
+ contrast:
35
+ weight: 1
36
+ min_max:
37
+ - 0.8
38
+ - 1.2
39
+ saturation:
40
+ weight: 1
41
+ min_max:
42
+ - 0.5
43
+ - 1.5
44
+ hue:
45
+ weight: 1
46
+ min_max:
47
+ - -0.05
48
+ - 0.05
49
+ sharpness:
50
+ weight: 1
51
+ min_max:
52
+ - 0.8
53
+ - 1.2
54
+ grad_clip_norm: 10
55
+ lr: 0.0001
56
+ lr_scheduler: cosine
57
+ lr_warmup_steps: 500
58
+ adam_betas:
59
+ - 0.95
60
+ - 0.999
61
+ adam_eps: 1.0e-08
62
+ adam_weight_decay: 1.0e-06
63
+ delta_timestamps:
64
+ observation.environment_state:
65
+ - -0.1
66
+ - 0.0
67
+ observation.state:
68
+ - -0.1
69
+ - 0.0
70
+ action:
71
+ - -0.1
72
+ - 0.0
73
+ - 0.1
74
+ - 0.2
75
+ - 0.3
76
+ - 0.4
77
+ - 0.5
78
+ - 0.6
79
+ - 0.7
80
+ - 0.8
81
+ - 0.9
82
+ - 1.0
83
+ - 1.1
84
+ - 1.2
85
+ - 1.3
86
+ - 1.4
87
+ - 1.5
88
+ - 1.6
89
+ - 1.7
90
+ - 1.8
91
+ - 1.9
92
+ - 2.0
93
+ - 2.1
94
+ - 2.2
95
+ - 2.3
96
+ - 2.4
97
+ - 2.5
98
+ - 2.6
99
+ - 2.7
100
+ - 2.8
101
+ - 2.9
102
+ - 3.0
103
+ - 3.1
104
+ - 3.2
105
+ - 3.3
106
+ - 3.4
107
+ - 3.5
108
+ - 3.6
109
+ - 3.7
110
+ - 3.8
111
+ - 3.9
112
+ - 4.0
113
+ - 4.1
114
+ - 4.2
115
+ - 4.3
116
+ - 4.4
117
+ - 4.5
118
+ - 4.6
119
+ - 4.7
120
+ - 4.8
121
+ - 4.9
122
+ - 5.0
123
+ - 5.1
124
+ - 5.2
125
+ - 5.3
126
+ - 5.4
127
+ - 5.5
128
+ - 5.6
129
+ - 5.7
130
+ - 5.8
131
+ - 5.9
132
+ - 6.0
133
+ - 6.1
134
+ - 6.2
135
+ drop_n_last_frames: 7
136
+ eval:
137
+ n_episodes: 50
138
+ batch_size: 50
139
+ use_async_envs: false
140
+ wandb:
141
+ enable: true
142
+ disable_artifact: false
143
+ project: lerobot
144
+ notes: ''
145
+ fps: 10
146
+ env:
147
+ name: maze2d
148
+ task: null
149
+ state_dim: 2
150
+ action_dim: 2
151
+ fps: ${fps}
152
+ policy:
153
+ name: diffusion
154
+ n_obs_steps: 2
155
+ horizon: 64
156
+ n_action_steps: 8
157
+ input_shapes:
158
+ observation.environment_state:
159
+ - 2
160
+ observation.state:
161
+ - ${env.state_dim}
162
+ output_shapes:
163
+ action:
164
+ - ${env.action_dim}
165
+ input_normalization_modes:
166
+ observation.environment_state: min_max
167
+ observation.state: min_max
168
+ output_normalization_modes:
169
+ action: min_max
170
+ vision_backbone: resnet18
171
+ crop_shape:
172
+ - 84
173
+ - 84
174
+ crop_is_random: true
175
+ pretrained_backbone_weights: null
176
+ use_group_norm: true
177
+ spatial_softmax_num_keypoints: 32
178
+ down_dims:
179
+ - 128
180
+ - 256
181
+ - 512
182
+ kernel_size: 5
183
+ n_groups: 8
184
+ diffusion_step_embed_dim: 128
185
+ use_film_scale_modulation: true
186
+ noise_scheduler_type: DDIM
187
+ num_train_timesteps: 100
188
+ beta_schedule: squaredcos_cap_v2
189
+ beta_start: 0.0001
190
+ beta_end: 0.02
191
+ prediction_type: epsilon
192
+ clip_sample: true
193
+ clip_sample_range: 1.0
194
+ num_inference_steps: 10
195
+ do_mask_loss_for_padding: false
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae14bf229318d2bf3f652086ec96336233acd4a575fdae273d1c3c80e64a3d80
3
+ size 65524656