Luna2099 commited on
Commit
5a798e0
·
verified ·
1 Parent(s): 3aad2b4

ACT grab_block model: 30k steps, loss 0.15, trained on RTX 5090

Browse files
ckpt_10000/config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "act",
3
+ "n_obs_steps": 1,
4
+ "input_features": {
5
+ "observation.state": {
6
+ "type": "STATE",
7
+ "shape": [
8
+ 6
9
+ ]
10
+ },
11
+ "observation.images.wrist": {
12
+ "type": "VISUAL",
13
+ "shape": [
14
+ 3,
15
+ 480,
16
+ 640
17
+ ]
18
+ }
19
+ },
20
+ "output_features": {
21
+ "action": {
22
+ "type": "ACTION",
23
+ "shape": [
24
+ 6
25
+ ]
26
+ }
27
+ },
28
+ "device": "cuda",
29
+ "use_amp": false,
30
+ "use_peft": false,
31
+ "push_to_hub": true,
32
+ "repo_id": null,
33
+ "private": null,
34
+ "tags": null,
35
+ "license": null,
36
+ "pretrained_path": null,
37
+ "chunk_size": 100,
38
+ "n_action_steps": 100,
39
+ "normalization_mapping": {
40
+ "VISUAL": "MEAN_STD",
41
+ "STATE": "MEAN_STD",
42
+ "ACTION": "MEAN_STD"
43
+ },
44
+ "vision_backbone": "resnet18",
45
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
46
+ "replace_final_stride_with_dilation": false,
47
+ "pre_norm": false,
48
+ "dim_model": 512,
49
+ "n_heads": 8,
50
+ "dim_feedforward": 3200,
51
+ "feedforward_activation": "relu",
52
+ "n_encoder_layers": 4,
53
+ "n_decoder_layers": 1,
54
+ "use_vae": true,
55
+ "latent_dim": 32,
56
+ "n_vae_encoder_layers": 4,
57
+ "temporal_ensemble_coeff": null,
58
+ "dropout": 0.1,
59
+ "kl_weight": 10.0,
60
+ "optimizer_lr": 1e-05,
61
+ "optimizer_weight_decay": 0.0001,
62
+ "optimizer_lr_backbone": 1e-05
63
+ }
ckpt_10000/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b54971333625f1b96f70506d157116f2531fa9dd8322d95eb8eaa9d38c2cd77
3
+ size 206699736
ckpt_20000/config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "act",
3
+ "n_obs_steps": 1,
4
+ "input_features": {
5
+ "observation.state": {
6
+ "type": "STATE",
7
+ "shape": [
8
+ 6
9
+ ]
10
+ },
11
+ "observation.images.wrist": {
12
+ "type": "VISUAL",
13
+ "shape": [
14
+ 3,
15
+ 480,
16
+ 640
17
+ ]
18
+ }
19
+ },
20
+ "output_features": {
21
+ "action": {
22
+ "type": "ACTION",
23
+ "shape": [
24
+ 6
25
+ ]
26
+ }
27
+ },
28
+ "device": "cuda",
29
+ "use_amp": false,
30
+ "use_peft": false,
31
+ "push_to_hub": true,
32
+ "repo_id": null,
33
+ "private": null,
34
+ "tags": null,
35
+ "license": null,
36
+ "pretrained_path": null,
37
+ "chunk_size": 100,
38
+ "n_action_steps": 100,
39
+ "normalization_mapping": {
40
+ "VISUAL": "MEAN_STD",
41
+ "STATE": "MEAN_STD",
42
+ "ACTION": "MEAN_STD"
43
+ },
44
+ "vision_backbone": "resnet18",
45
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
46
+ "replace_final_stride_with_dilation": false,
47
+ "pre_norm": false,
48
+ "dim_model": 512,
49
+ "n_heads": 8,
50
+ "dim_feedforward": 3200,
51
+ "feedforward_activation": "relu",
52
+ "n_encoder_layers": 4,
53
+ "n_decoder_layers": 1,
54
+ "use_vae": true,
55
+ "latent_dim": 32,
56
+ "n_vae_encoder_layers": 4,
57
+ "temporal_ensemble_coeff": null,
58
+ "dropout": 0.1,
59
+ "kl_weight": 10.0,
60
+ "optimizer_lr": 1e-05,
61
+ "optimizer_weight_decay": 0.0001,
62
+ "optimizer_lr_backbone": 1e-05
63
+ }
ckpt_20000/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41b0699e6634b51d47c60c640ae1c03b2407bd1aed9679421e083fe555572a8d
3
+ size 206699736
config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "act",
3
+ "n_obs_steps": 1,
4
+ "input_features": {
5
+ "observation.state": {
6
+ "type": "STATE",
7
+ "shape": [
8
+ 6
9
+ ]
10
+ },
11
+ "observation.images.wrist": {
12
+ "type": "VISUAL",
13
+ "shape": [
14
+ 3,
15
+ 480,
16
+ 640
17
+ ]
18
+ }
19
+ },
20
+ "output_features": {
21
+ "action": {
22
+ "type": "ACTION",
23
+ "shape": [
24
+ 6
25
+ ]
26
+ }
27
+ },
28
+ "device": "cuda",
29
+ "use_amp": false,
30
+ "use_peft": false,
31
+ "push_to_hub": true,
32
+ "repo_id": null,
33
+ "private": null,
34
+ "tags": null,
35
+ "license": null,
36
+ "pretrained_path": null,
37
+ "chunk_size": 100,
38
+ "n_action_steps": 100,
39
+ "normalization_mapping": {
40
+ "VISUAL": "MEAN_STD",
41
+ "STATE": "MEAN_STD",
42
+ "ACTION": "MEAN_STD"
43
+ },
44
+ "vision_backbone": "resnet18",
45
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
46
+ "replace_final_stride_with_dilation": false,
47
+ "pre_norm": false,
48
+ "dim_model": 512,
49
+ "n_heads": 8,
50
+ "dim_feedforward": 3200,
51
+ "feedforward_activation": "relu",
52
+ "n_encoder_layers": 4,
53
+ "n_decoder_layers": 1,
54
+ "use_vae": true,
55
+ "latent_dim": 32,
56
+ "n_vae_encoder_layers": 4,
57
+ "temporal_ensemble_coeff": null,
58
+ "dropout": 0.1,
59
+ "kl_weight": 10.0,
60
+ "optimizer_lr": 1e-05,
61
+ "optimizer_weight_decay": 0.0001,
62
+ "optimizer_lr_backbone": 1e-05
63
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ed48d7b8e7885e09eaef4622fadf0197dc12391b0e42b04302df667850f004f
3
+ size 206699736
policy_postprocessor.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "policy_postprocessor",
3
+ "steps": [
4
+ {
5
+ "registry_name": "unnormalizer_processor",
6
+ "config": {
7
+ "eps": 1e-08,
8
+ "features": {
9
+ "action": {
10
+ "type": "ACTION",
11
+ "shape": [
12
+ 6
13
+ ]
14
+ }
15
+ },
16
+ "norm_map": {
17
+ "VISUAL": "MEAN_STD",
18
+ "STATE": "MEAN_STD",
19
+ "ACTION": "MEAN_STD"
20
+ }
21
+ },
22
+ "state_file": "policy_postprocessor_step_0_unnormalizer_processor.safetensors"
23
+ },
24
+ {
25
+ "registry_name": "device_processor",
26
+ "config": {
27
+ "device": "cpu",
28
+ "float_dtype": null
29
+ }
30
+ }
31
+ ]
32
+ }
policy_postprocessor_step_0_unnormalizer_processor.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1bbdcc06422b4628f823172bd128458150a4d47a332f59ff28e9d89e7281ccdc
3
+ size 6560
policy_preprocessor.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "policy_preprocessor",
3
+ "steps": [
4
+ {
5
+ "registry_name": "rename_observations_processor",
6
+ "config": {
7
+ "rename_map": {}
8
+ }
9
+ },
10
+ {
11
+ "registry_name": "to_batch_processor",
12
+ "config": {}
13
+ },
14
+ {
15
+ "registry_name": "device_processor",
16
+ "config": {
17
+ "device": "cuda",
18
+ "float_dtype": null
19
+ }
20
+ },
21
+ {
22
+ "registry_name": "normalizer_processor",
23
+ "config": {
24
+ "eps": 1e-08,
25
+ "features": {
26
+ "observation.state": {
27
+ "type": "STATE",
28
+ "shape": [
29
+ 6
30
+ ]
31
+ },
32
+ "observation.images.wrist": {
33
+ "type": "VISUAL",
34
+ "shape": [
35
+ 3,
36
+ 480,
37
+ 640
38
+ ]
39
+ },
40
+ "action": {
41
+ "type": "ACTION",
42
+ "shape": [
43
+ 6
44
+ ]
45
+ }
46
+ },
47
+ "norm_map": {
48
+ "VISUAL": "MEAN_STD",
49
+ "STATE": "MEAN_STD",
50
+ "ACTION": "MEAN_STD"
51
+ }
52
+ },
53
+ "state_file": "policy_preprocessor_step_3_normalizer_processor.safetensors"
54
+ }
55
+ ]
56
+ }
policy_preprocessor_step_3_normalizer_processor.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1bbdcc06422b4628f823172bd128458150a4d47a332f59ff28e9d89e7281ccdc
3
+ size 6560
training_log.json ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "step": 0,
4
+ "loss": 95.05252838134766,
5
+ "elapsed": 5.026012420654297
6
+ },
7
+ {
8
+ "step": 500,
9
+ "loss": 2.0986263751983643,
10
+ "elapsed": 48.24358034133911
11
+ },
12
+ {
13
+ "step": 1000,
14
+ "loss": 1.416904330253601,
15
+ "elapsed": 90.54185724258423
16
+ },
17
+ {
18
+ "step": 1500,
19
+ "loss": 0.9374834299087524,
20
+ "elapsed": 133.1820592880249
21
+ },
22
+ {
23
+ "step": 2000,
24
+ "loss": 0.6080644726753235,
25
+ "elapsed": 175.66594648361206
26
+ },
27
+ {
28
+ "step": 2500,
29
+ "loss": 0.46668586134910583,
30
+ "elapsed": 218.11662578582764
31
+ },
32
+ {
33
+ "step": 3000,
34
+ "loss": 0.39408573508262634,
35
+ "elapsed": 260.9173834323883
36
+ },
37
+ {
38
+ "step": 3500,
39
+ "loss": 0.3152383863925934,
40
+ "elapsed": 303.28292202949524
41
+ },
42
+ {
43
+ "step": 4000,
44
+ "loss": 0.29642900824546814,
45
+ "elapsed": 345.74440598487854
46
+ },
47
+ {
48
+ "step": 4500,
49
+ "loss": 0.2556461989879608,
50
+ "elapsed": 388.4183192253113
51
+ },
52
+ {
53
+ "step": 5000,
54
+ "loss": 0.277762234210968,
55
+ "elapsed": 430.8937876224518
56
+ },
57
+ {
58
+ "step": 5500,
59
+ "loss": 0.25442755222320557,
60
+ "elapsed": 473.4626610279083
61
+ },
62
+ {
63
+ "step": 6000,
64
+ "loss": 0.21272782981395721,
65
+ "elapsed": 516.0056443214417
66
+ },
67
+ {
68
+ "step": 6500,
69
+ "loss": 0.221424862742424,
70
+ "elapsed": 558.4167232513428
71
+ },
72
+ {
73
+ "step": 7000,
74
+ "loss": 0.21034026145935059,
75
+ "elapsed": 601.0059714317322
76
+ },
77
+ {
78
+ "step": 7500,
79
+ "loss": 0.20263026654720306,
80
+ "elapsed": 644.2706835269928
81
+ },
82
+ {
83
+ "step": 8000,
84
+ "loss": 0.18980352580547333,
85
+ "elapsed": 686.9732842445374
86
+ },
87
+ {
88
+ "step": 8500,
89
+ "loss": 0.17610831558704376,
90
+ "elapsed": 729.3965904712677
91
+ },
92
+ {
93
+ "step": 9000,
94
+ "loss": 0.1795623004436493,
95
+ "elapsed": 772.0203955173492
96
+ },
97
+ {
98
+ "step": 9500,
99
+ "loss": 0.19825756549835205,
100
+ "elapsed": 814.214008808136
101
+ },
102
+ {
103
+ "step": 10000,
104
+ "loss": 0.17051126062870026,
105
+ "elapsed": 856.9056174755096
106
+ },
107
+ {
108
+ "step": 10500,
109
+ "loss": 0.1672261506319046,
110
+ "elapsed": 900.7812814712524
111
+ },
112
+ {
113
+ "step": 11000,
114
+ "loss": 0.17105333507061005,
115
+ "elapsed": 944.1276385784149
116
+ },
117
+ {
118
+ "step": 11500,
119
+ "loss": 0.15053829550743103,
120
+ "elapsed": 987.4672486782074
121
+ },
122
+ {
123
+ "step": 12000,
124
+ "loss": 0.16913338005542755,
125
+ "elapsed": 1030.4265339374542
126
+ },
127
+ {
128
+ "step": 12500,
129
+ "loss": 0.17490649223327637,
130
+ "elapsed": 1073.70126414299
131
+ },
132
+ {
133
+ "step": 13000,
134
+ "loss": 0.15259376168251038,
135
+ "elapsed": 1116.8105387687683
136
+ },
137
+ {
138
+ "step": 13500,
139
+ "loss": 0.14562517404556274,
140
+ "elapsed": 1160.817851305008
141
+ },
142
+ {
143
+ "step": 14000,
144
+ "loss": 0.1578102558851242,
145
+ "elapsed": 1204.749274969101
146
+ },
147
+ {
148
+ "step": 14500,
149
+ "loss": 0.15456590056419373,
150
+ "elapsed": 1248.6135349273682
151
+ },
152
+ {
153
+ "step": 15000,
154
+ "loss": 0.17116491496562958,
155
+ "elapsed": 1291.4715003967285
156
+ },
157
+ {
158
+ "step": 15500,
159
+ "loss": 0.137327641248703,
160
+ "elapsed": 1334.6077785491943
161
+ },
162
+ {
163
+ "step": 16000,
164
+ "loss": 0.13515342772006989,
165
+ "elapsed": 1378.242861032486
166
+ },
167
+ {
168
+ "step": 16500,
169
+ "loss": 0.143902987241745,
170
+ "elapsed": 1421.2837686538696
171
+ },
172
+ {
173
+ "step": 17000,
174
+ "loss": 0.14174163341522217,
175
+ "elapsed": 1465.0907399654388
176
+ },
177
+ {
178
+ "step": 17500,
179
+ "loss": 0.14444664120674133,
180
+ "elapsed": 1509.468195438385
181
+ },
182
+ {
183
+ "step": 18000,
184
+ "loss": 0.1439598649740219,
185
+ "elapsed": 1553.0655224323273
186
+ },
187
+ {
188
+ "step": 18500,
189
+ "loss": 0.153713196516037,
190
+ "elapsed": 1597.3860943317413
191
+ },
192
+ {
193
+ "step": 19000,
194
+ "loss": 0.14797145128250122,
195
+ "elapsed": 1641.5034580230713
196
+ },
197
+ {
198
+ "step": 19500,
199
+ "loss": 0.1756235659122467,
200
+ "elapsed": 1684.5155091285706
201
+ },
202
+ {
203
+ "step": 20000,
204
+ "loss": 0.16604095697402954,
205
+ "elapsed": 1727.5247118473053
206
+ },
207
+ {
208
+ "step": 20500,
209
+ "loss": 0.1382945328950882,
210
+ "elapsed": 1771.6523303985596
211
+ },
212
+ {
213
+ "step": 21000,
214
+ "loss": 0.16012033820152283,
215
+ "elapsed": 1814.806452035904
216
+ },
217
+ {
218
+ "step": 21500,
219
+ "loss": 0.14267100393772125,
220
+ "elapsed": 1857.8734240531921
221
+ },
222
+ {
223
+ "step": 22000,
224
+ "loss": 0.1692996323108673,
225
+ "elapsed": 1902.5696635246277
226
+ },
227
+ {
228
+ "step": 22500,
229
+ "loss": 0.15417799353599548,
230
+ "elapsed": 1946.2307980060577
231
+ },
232
+ {
233
+ "step": 23000,
234
+ "loss": 0.16970615088939667,
235
+ "elapsed": 1989.7159111499786
236
+ },
237
+ {
238
+ "step": 23500,
239
+ "loss": 0.14763763546943665,
240
+ "elapsed": 2033.4999907016754
241
+ },
242
+ {
243
+ "step": 24000,
244
+ "loss": 0.168531596660614,
245
+ "elapsed": 2077.4635944366455
246
+ },
247
+ {
248
+ "step": 24500,
249
+ "loss": 0.16383926570415497,
250
+ "elapsed": 2121.584620475769
251
+ },
252
+ {
253
+ "step": 25000,
254
+ "loss": 0.13870900869369507,
255
+ "elapsed": 2165.6783912181854
256
+ },
257
+ {
258
+ "step": 25500,
259
+ "loss": 0.1494102030992508,
260
+ "elapsed": 2209.5933837890625
261
+ },
262
+ {
263
+ "step": 26000,
264
+ "loss": 0.15371178090572357,
265
+ "elapsed": 2253.022980928421
266
+ },
267
+ {
268
+ "step": 26500,
269
+ "loss": 0.15978066623210907,
270
+ "elapsed": 2297.043392896652
271
+ },
272
+ {
273
+ "step": 27000,
274
+ "loss": 0.16112983226776123,
275
+ "elapsed": 2340.2668483257294
276
+ },
277
+ {
278
+ "step": 27500,
279
+ "loss": 0.16778525710105896,
280
+ "elapsed": 2383.6698849201202
281
+ },
282
+ {
283
+ "step": 28000,
284
+ "loss": 0.13977855443954468,
285
+ "elapsed": 2427.16020488739
286
+ },
287
+ {
288
+ "step": 28500,
289
+ "loss": 0.16271978616714478,
290
+ "elapsed": 2471.589874982834
291
+ },
292
+ {
293
+ "step": 29000,
294
+ "loss": 0.14824886620044708,
295
+ "elapsed": 2514.4410626888275
296
+ },
297
+ {
298
+ "step": 29500,
299
+ "loss": 0.1602301150560379,
300
+ "elapsed": 2557.0406301021576
301
+ }
302
+ ]