felixw committed on
Commit
47cacba
·
verified ·
1 Parent(s): 0f2c545

Upload ITPS Maze2D pretrained checkpoint

Browse files
Files changed (4) hide show
  1. README.md +56 -0
  2. config.json +41 -0
  3. config.yaml +173 -0
  4. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: lerobot
3
+ license: mit
4
+ tags:
5
+ - lerobot
6
+ - act
7
+ - robotics
8
+ - maze2d
9
+ - itps
10
+ - pytorch_model_hub_mixin
11
+ pipeline_tag: robotics
12
+ ---
13
+
14
+ # ITPS Maze2D — Action Chunking Transformer (ACT)
15
+
16
+ Pre-trained Action Chunking Transformer checkpoint used in
17
+ **Inference-Time Policy Steering through Human Interactions**
18
+ ([paper](https://huggingface.co/papers/2411.16627), [project page](https://yanweiw.github.io/itps/), [code](https://github.com/yanweiw/itps)).
19
+
20
+ The model was trained on the [D4RL Maze2D](https://github.com/Farama-Foundation/D4RL)
21
+ dataset and is intended to be loaded with the
22
+ [LeRobot](https://github.com/huggingface/lerobot) policy classes.
23
+
24
+ ## Usage
25
+
26
+ Clone the inference repo, then load this checkpoint directly from the Hub:
27
+
28
+ ```bash
29
+ git clone https://github.com/yanweiw/itps.git && cd itps
30
+ pip install -e .
31
+ python interact_maze2d.py -p act --hf
32
+ ```
33
+
34
+ Or load it programmatically:
35
+
36
+ ```python
37
+ from itps.common.policies.act.modeling_act import ACTPolicy
38
+
39
+ policy = ACTPolicy.from_pretrained("felixw/itps-act")
40
+ policy.eval()
41
+ ```
42
+
43
+ ## Citation
44
+
45
+ ```bibtex
46
+ @article{wang2024itps,
47
+ title={Inference-Time Policy Steering through Human Interactions},
48
+ author={Wang, Yanwei and others},
49
+ journal={arXiv preprint arXiv:2411.16627},
50
+ year={2024}
51
+ }
52
+ ```
53
+
54
+ ## License
55
+
56
+ MIT — see [LICENSE](https://github.com/yanweiw/itps/blob/main/LICENSE).
config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chunk_size": 64,
3
+ "dim_feedforward": 3200,
4
+ "dim_model": 512,
5
+ "dropout": 0.1,
6
+ "feedforward_activation": "relu",
7
+ "input_normalization_modes": {
8
+ "observation.environment_state": "mean_std",
9
+ "observation.state": "mean_std"
10
+ },
11
+ "input_shapes": {
12
+ "observation.environment_state": [
13
+ 2
14
+ ],
15
+ "observation.state": [
16
+ 2
17
+ ]
18
+ },
19
+ "kl_weight": 10.0,
20
+ "latent_dim": 32,
21
+ "n_action_steps": 64,
22
+ "n_decoder_layers": 1,
23
+ "n_encoder_layers": 4,
24
+ "n_heads": 8,
25
+ "n_obs_steps": 1,
26
+ "n_vae_encoder_layers": 4,
27
+ "output_normalization_modes": {
28
+ "action": "mean_std"
29
+ },
30
+ "output_shapes": {
31
+ "action": [
32
+ 2
33
+ ]
34
+ },
35
+ "pre_norm": false,
36
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
37
+ "replace_final_stride_with_dilation": false,
38
+ "temporal_ensemble_coeff": null,
39
+ "use_vae": true,
40
+ "vision_backbone": "resnet18"
41
+ }
config.yaml ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ resume: false
2
+ device: cuda:0
3
+ use_amp: false
4
+ seed: 100000
5
+ dataset_repo_id: maze2d
6
+ dataset_root: /afs/csail.mit.edu/u/f/felixw/lerobot/data/maze2d-large-sparse-v1.hdf5
7
+ video_backend: pyav
8
+ training:
9
+ offline_steps: 100000
10
+ num_workers: 4
11
+ batch_size: 256
12
+ eval_freq: 0
13
+ log_freq: 250
14
+ save_checkpoint: true
15
+ save_freq: 5000
16
+ online_steps: 0
17
+ online_rollout_n_episodes: 1
18
+ online_rollout_batch_size: 1
19
+ online_steps_between_rollouts: 1
20
+ online_sampling_ratio: 0.5
21
+ online_env_seed: null
22
+ online_buffer_capacity: null
23
+ online_buffer_seed_size: 0
24
+ do_online_rollout_async: false
25
+ image_transforms:
26
+ enable: false
27
+ max_num_transforms: 3
28
+ random_order: false
29
+ brightness:
30
+ weight: 1
31
+ min_max:
32
+ - 0.8
33
+ - 1.2
34
+ contrast:
35
+ weight: 1
36
+ min_max:
37
+ - 0.8
38
+ - 1.2
39
+ saturation:
40
+ weight: 1
41
+ min_max:
42
+ - 0.5
43
+ - 1.5
44
+ hue:
45
+ weight: 1
46
+ min_max:
47
+ - -0.05
48
+ - 0.05
49
+ sharpness:
50
+ weight: 1
51
+ min_max:
52
+ - 0.8
53
+ - 1.2
54
+ lr: 1.0e-05
55
+ lr_backbone: 1.0e-05
56
+ weight_decay: 0.0001
57
+ grad_clip_norm: 10
58
+ delta_timestamps:
59
+ action:
60
+ - 0.0
61
+ - 0.1
62
+ - 0.2
63
+ - 0.3
64
+ - 0.4
65
+ - 0.5
66
+ - 0.6
67
+ - 0.7
68
+ - 0.8
69
+ - 0.9
70
+ - 1.0
71
+ - 1.1
72
+ - 1.2
73
+ - 1.3
74
+ - 1.4
75
+ - 1.5
76
+ - 1.6
77
+ - 1.7
78
+ - 1.8
79
+ - 1.9
80
+ - 2.0
81
+ - 2.1
82
+ - 2.2
83
+ - 2.3
84
+ - 2.4
85
+ - 2.5
86
+ - 2.6
87
+ - 2.7
88
+ - 2.8
89
+ - 2.9
90
+ - 3.0
91
+ - 3.1
92
+ - 3.2
93
+ - 3.3
94
+ - 3.4
95
+ - 3.5
96
+ - 3.6
97
+ - 3.7
98
+ - 3.8
99
+ - 3.9
100
+ - 4.0
101
+ - 4.1
102
+ - 4.2
103
+ - 4.3
104
+ - 4.4
105
+ - 4.5
106
+ - 4.6
107
+ - 4.7
108
+ - 4.8
109
+ - 4.9
110
+ - 5.0
111
+ - 5.1
112
+ - 5.2
113
+ - 5.3
114
+ - 5.4
115
+ - 5.5
116
+ - 5.6
117
+ - 5.7
118
+ - 5.8
119
+ - 5.9
120
+ - 6.0
121
+ - 6.1
122
+ - 6.2
123
+ - 6.3
124
+ eval:
125
+ n_episodes: 50
126
+ batch_size: 50
127
+ use_async_envs: false
128
+ wandb:
129
+ enable: true
130
+ disable_artifact: false
131
+ project: lerobot
132
+ notes: ''
133
+ fps: 10
134
+ env:
135
+ name: maze2d
136
+ task: null
137
+ state_dim: 2
138
+ action_dim: 2
139
+ fps: ${fps}
140
+ policy:
141
+ name: act
142
+ n_obs_steps: 1
143
+ chunk_size: 64
144
+ n_action_steps: 64
145
+ input_shapes:
146
+ observation.environment_state:
147
+ - 2
148
+ observation.state:
149
+ - ${env.state_dim}
150
+ output_shapes:
151
+ action:
152
+ - ${env.action_dim}
153
+ input_normalization_modes:
154
+ observation.environment_state: mean_std
155
+ observation.state: mean_std
156
+ output_normalization_modes:
157
+ action: mean_std
158
+ vision_backbone: resnet18
159
+ pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
160
+ replace_final_stride_with_dilation: false
161
+ pre_norm: false
162
+ dim_model: 512
163
+ n_heads: 8
164
+ dim_feedforward: 3200
165
+ feedforward_activation: relu
166
+ n_encoder_layers: 4
167
+ n_decoder_layers: 1
168
+ use_vae: true
169
+ latent_dim: 32
170
+ n_vae_encoder_layers: 4
171
+ temporal_ensemble_coeff: null
172
+ dropout: 0.1
173
+ kl_weight: 10.0
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1463f404e77cb9716fe1157180408925de84d13ad2e51552f8f1ae9b796d6c05
3
+ size 160723088