Robotics
LeRobot
Safetensors
act
Basr88 commited on
Commit
4a50b57
·
verified ·
1 Parent(s): ded3cd8

Upload policy weights, train config and readme

Browse files
Files changed (4) hide show
  1. README.md +62 -0
  2. config.json +63 -0
  3. model.safetensors +3 -0
  4. train_config.json +696 -0
README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets: Basr88/Merged_800_2
3
+ library_name: lerobot
4
+ license: apache-2.0
5
+ model_name: act
6
+ pipeline_tag: robotics
7
+ tags:
8
+ - lerobot
9
+ - robotics
10
+ - act
11
+ ---
12
+
13
+ # Model Card for act
14
+
15
+ <!-- Provide a quick summary of what the model is/does. -->
16
+
17
+
18
+ [Action Chunking with Transformers (ACT)](https://huggingface.co/papers/2304.13705) is an imitation-learning method that predicts short action chunks instead of single steps. It learns from teleoperated data and often achieves high success rates.
19
+
20
+
21
+ This policy has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot).
22
+ See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index).
23
+
24
+ ---
25
+
26
+ ## How to Get Started with the Model
27
+
28
+ For a complete walkthrough, see the [training guide](https://huggingface.co/docs/lerobot/il_robots#train-a-policy).
29
+ Below is the short version on how to train and run inference/eval:
30
+
31
+ ### Train from scratch
32
+
33
+ ```bash
34
+ lerobot-train \
35
+ --dataset.repo_id=${HF_USER}/<dataset> \
36
+ --policy.type=act \
37
+ --output_dir=outputs/train/<desired_policy_repo_id> \
38
+ --job_name=lerobot_training \
39
+ --policy.device=cuda \
40
+ --policy.repo_id=${HF_USER}/<desired_policy_repo_id>
41
+ --wandb.enable=true
42
+ ```
43
+
44
+ _Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
45
+
46
+ ### Evaluate the policy/run inference
47
+
48
+ ```bash
49
+ lerobot-record \
50
+ --robot.type=so100_follower \
51
+ --dataset.repo_id=<hf_user>/eval_<dataset> \
52
+ --policy.path=<hf_user>/<desired_policy_repo_id> \
53
+ --episodes=10
54
+ ```
55
+
56
+ Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint.
57
+
58
+ ---
59
+
60
+ ## Model Details
61
+
62
+ - **License:** apache-2.0
config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "act",
3
+ "n_obs_steps": 1,
4
+ "input_features": {
5
+ "observation.state": {
6
+ "type": "STATE",
7
+ "shape": [
8
+ 6
9
+ ]
10
+ },
11
+ "observation.images.front": {
12
+ "type": "VISUAL",
13
+ "shape": [
14
+ 3,
15
+ 480,
16
+ 640
17
+ ]
18
+ }
19
+ },
20
+ "output_features": {
21
+ "action": {
22
+ "type": "ACTION",
23
+ "shape": [
24
+ 6
25
+ ]
26
+ }
27
+ },
28
+ "device": "cuda",
29
+ "use_amp": false,
30
+ "use_peft": false,
31
+ "push_to_hub": true,
32
+ "repo_id": "Basr88/Merged_dataset_4",
33
+ "private": null,
34
+ "tags": null,
35
+ "license": null,
36
+ "pretrained_path": null,
37
+ "chunk_size": 100,
38
+ "n_action_steps": 100,
39
+ "normalization_mapping": {
40
+ "VISUAL": "MEAN_STD",
41
+ "STATE": "MEAN_STD",
42
+ "ACTION": "MEAN_STD"
43
+ },
44
+ "vision_backbone": "resnet18",
45
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
46
+ "replace_final_stride_with_dilation": false,
47
+ "pre_norm": false,
48
+ "dim_model": 512,
49
+ "n_heads": 8,
50
+ "dim_feedforward": 3200,
51
+ "feedforward_activation": "relu",
52
+ "n_encoder_layers": 4,
53
+ "n_decoder_layers": 1,
54
+ "use_vae": true,
55
+ "latent_dim": 32,
56
+ "n_vae_encoder_layers": 4,
57
+ "temporal_ensemble_coeff": null,
58
+ "dropout": 0.1,
59
+ "kl_weight": 10.0,
60
+ "optimizer_lr": 1e-05,
61
+ "optimizer_weight_decay": 0.0001,
62
+ "optimizer_lr_backbone": 1e-05
63
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9578e1754f657f192e19c5730112abd2ef935f58380929dabde6fe9463868e49
3
+ size 206699736
train_config.json ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset": {
3
+ "repo_id": "Basr88/Merged_800_2",
4
+ "root": null,
5
+ "episodes": [
6
+ 444,
7
+ 731,
8
+ 6,
9
+ 601,
10
+ 586,
11
+ 725,
12
+ 416,
13
+ 67,
14
+ 496,
15
+ 392,
16
+ 342,
17
+ 300,
18
+ 785,
19
+ 434,
20
+ 710,
21
+ 108,
22
+ 177,
23
+ 415,
24
+ 619,
25
+ 171,
26
+ 658,
27
+ 560,
28
+ 763,
29
+ 764,
30
+ 635,
31
+ 758,
32
+ 548,
33
+ 448,
34
+ 137,
35
+ 59,
36
+ 344,
37
+ 335,
38
+ 253,
39
+ 657,
40
+ 593,
41
+ 71,
42
+ 21,
43
+ 669,
44
+ 231,
45
+ 293,
46
+ 433,
47
+ 277,
48
+ 131,
49
+ 303,
50
+ 715,
51
+ 523,
52
+ 732,
53
+ 81,
54
+ 693,
55
+ 738,
56
+ 32,
57
+ 99,
58
+ 324,
59
+ 720,
60
+ 698,
61
+ 643,
62
+ 609,
63
+ 594,
64
+ 76,
65
+ 84,
66
+ 296,
67
+ 632,
68
+ 306,
69
+ 75,
70
+ 311,
71
+ 753,
72
+ 114,
73
+ 289,
74
+ 430,
75
+ 672,
76
+ 435,
77
+ 535,
78
+ 298,
79
+ 280,
80
+ 723,
81
+ 552,
82
+ 459,
83
+ 272,
84
+ 29,
85
+ 684,
86
+ 572,
87
+ 227,
88
+ 441,
89
+ 760,
90
+ 189,
91
+ 573,
92
+ 405,
93
+ 489,
94
+ 425,
95
+ 15,
96
+ 349,
97
+ 506,
98
+ 183,
99
+ 125,
100
+ 510,
101
+ 610,
102
+ 505,
103
+ 39,
104
+ 508,
105
+ 642,
106
+ 624,
107
+ 403,
108
+ 676,
109
+ 395,
110
+ 9,
111
+ 238,
112
+ 126,
113
+ 668,
114
+ 597,
115
+ 770,
116
+ 302,
117
+ 418,
118
+ 682,
119
+ 718,
120
+ 616,
121
+ 475,
122
+ 608,
123
+ 603,
124
+ 143,
125
+ 185,
126
+ 182,
127
+ 389,
128
+ 413,
129
+ 336,
130
+ 359,
131
+ 366,
132
+ 480,
133
+ 540,
134
+ 122,
135
+ 730,
136
+ 796,
137
+ 134,
138
+ 757,
139
+ 148,
140
+ 542,
141
+ 12,
142
+ 559,
143
+ 60,
144
+ 229,
145
+ 90,
146
+ 176,
147
+ 74,
148
+ 320,
149
+ 761,
150
+ 279,
151
+ 13,
152
+ 467,
153
+ 788,
154
+ 239,
155
+ 562,
156
+ 736,
157
+ 541,
158
+ 144,
159
+ 391,
160
+ 631,
161
+ 769,
162
+ 246,
163
+ 561,
164
+ 504,
165
+ 742,
166
+ 53,
167
+ 719,
168
+ 3,
169
+ 604,
170
+ 717,
171
+ 46,
172
+ 707,
173
+ 22,
174
+ 347,
175
+ 283,
176
+ 452,
177
+ 153,
178
+ 368,
179
+ 97,
180
+ 644,
181
+ 49,
182
+ 524,
183
+ 690,
184
+ 666,
185
+ 269,
186
+ 274,
187
+ 107,
188
+ 312,
189
+ 149,
190
+ 607,
191
+ 95,
192
+ 514,
193
+ 266,
194
+ 522,
195
+ 756,
196
+ 410,
197
+ 61,
198
+ 388,
199
+ 353,
200
+ 590,
201
+ 217,
202
+ 193,
203
+ 94,
204
+ 294,
205
+ 529,
206
+ 611,
207
+ 281,
208
+ 352,
209
+ 721,
210
+ 665,
211
+ 722,
212
+ 778,
213
+ 356,
214
+ 646,
215
+ 711,
216
+ 579,
217
+ 50,
218
+ 640,
219
+ 321,
220
+ 614,
221
+ 583,
222
+ 408,
223
+ 629,
224
+ 376,
225
+ 432,
226
+ 204,
227
+ 650,
228
+ 453,
229
+ 194,
230
+ 119,
231
+ 737,
232
+ 549,
233
+ 196,
234
+ 315,
235
+ 599,
236
+ 132,
237
+ 145,
238
+ 591,
239
+ 54,
240
+ 292,
241
+ 175,
242
+ 7,
243
+ 243,
244
+ 337,
245
+ 772,
246
+ 654,
247
+ 129,
248
+ 438,
249
+ 613,
250
+ 265,
251
+ 139,
252
+ 458,
253
+ 304,
254
+ 398,
255
+ 553,
256
+ 313,
257
+ 261,
258
+ 257,
259
+ 254,
260
+ 663,
261
+ 301,
262
+ 38,
263
+ 558,
264
+ 35,
265
+ 89,
266
+ 792,
267
+ 568,
268
+ 206,
269
+ 248,
270
+ 495,
271
+ 533,
272
+ 521,
273
+ 240,
274
+ 534,
275
+ 709,
276
+ 498,
277
+ 492,
278
+ 714,
279
+ 621,
280
+ 188,
281
+ 531,
282
+ 497,
283
+ 526,
284
+ 437,
285
+ 235,
286
+ 519,
287
+ 746,
288
+ 214,
289
+ 641,
290
+ 158,
291
+ 166,
292
+ 701,
293
+ 207,
294
+ 726,
295
+ 645,
296
+ 197,
297
+ 617,
298
+ 637,
299
+ 93,
300
+ 318,
301
+ 400,
302
+ 449,
303
+ 402,
304
+ 515,
305
+ 502,
306
+ 140,
307
+ 488,
308
+ 361,
309
+ 96,
310
+ 685,
311
+ 520,
312
+ 588,
313
+ 428,
314
+ 799,
315
+ 110,
316
+ 308,
317
+ 716,
318
+ 696,
319
+ 705,
320
+ 345,
321
+ 439,
322
+ 307,
323
+ 695,
324
+ 73,
325
+ 636,
326
+ 86,
327
+ 466,
328
+ 369,
329
+ 462,
330
+ 782,
331
+ 85,
332
+ 278,
333
+ 426,
334
+ 771,
335
+ 530,
336
+ 592,
337
+ 270,
338
+ 331,
339
+ 759,
340
+ 674,
341
+ 712,
342
+ 370,
343
+ 355,
344
+ 332,
345
+ 170,
346
+ 199,
347
+ 547,
348
+ 589,
349
+ 169,
350
+ 781,
351
+ 151,
352
+ 4,
353
+ 380,
354
+ 406,
355
+ 536,
356
+ 454,
357
+ 252,
358
+ 208,
359
+ 679,
360
+ 659,
361
+ 503,
362
+ 576,
363
+ 124,
364
+ 156,
365
+ 442,
366
+ 186,
367
+ 681,
368
+ 440,
369
+ 223,
370
+ 457,
371
+ 363,
372
+ 136,
373
+ 484,
374
+ 485,
375
+ 299,
376
+ 443,
377
+ 284,
378
+ 688,
379
+ 27,
380
+ 138,
381
+ 202,
382
+ 386,
383
+ 546,
384
+ 686,
385
+ 135,
386
+ 655,
387
+ 638,
388
+ 739,
389
+ 436,
390
+ 451,
391
+ 555,
392
+ 83,
393
+ 351,
394
+ 456,
395
+ 314,
396
+ 622,
397
+ 282,
398
+ 20,
399
+ 66,
400
+ 211,
401
+ 768,
402
+ 264,
403
+ 791,
404
+ 178,
405
+ 8,
406
+ 776,
407
+ 255,
408
+ 330,
409
+ 752,
410
+ 774,
411
+ 678,
412
+ 578,
413
+ 116,
414
+ 1,
415
+ 476,
416
+ 371,
417
+ 326,
418
+ 790,
419
+ 639,
420
+ 163,
421
+ 664,
422
+ 507,
423
+ 581,
424
+ 411,
425
+ 630,
426
+ 319,
427
+ 226,
428
+ 794,
429
+ 338,
430
+ 130,
431
+ 218,
432
+ 346,
433
+ 662,
434
+ 241,
435
+ 165,
436
+ 48,
437
+ 626,
438
+ 728,
439
+ 286,
440
+ 113,
441
+ 215,
442
+ 747,
443
+ 128,
444
+ 111,
445
+ 787,
446
+ 587,
447
+ 627,
448
+ 551,
449
+ 112,
450
+ 699,
451
+ 56,
452
+ 564,
453
+ 494,
454
+ 98,
455
+ 734,
456
+ 571,
457
+ 667,
458
+ 779,
459
+ 343,
460
+ 585,
461
+ 234,
462
+ 651,
463
+ 472,
464
+ 31,
465
+ 192,
466
+ 103,
467
+ 478,
468
+ 58,
469
+ 735,
470
+ 23,
471
+ 754,
472
+ 382,
473
+ 80,
474
+ 618,
475
+ 57,
476
+ 364,
477
+ 584,
478
+ 180,
479
+ 26,
480
+ 201,
481
+ 276,
482
+ 713,
483
+ 350,
484
+ 288,
485
+ 647,
486
+ 419,
487
+ 538,
488
+ 154,
489
+ 620,
490
+ 367,
491
+ 577,
492
+ 450,
493
+ 748,
494
+ 40,
495
+ 147,
496
+ 16,
497
+ 675,
498
+ 471,
499
+ 729,
500
+ 291,
501
+ 677,
502
+ 393,
503
+ 115,
504
+ 205,
505
+ 82
506
+ ],
507
+ "image_transforms": {
508
+ "enable": false,
509
+ "max_num_transforms": 3,
510
+ "random_order": false,
511
+ "tfs": {
512
+ "brightness": {
513
+ "weight": 1.0,
514
+ "type": "ColorJitter",
515
+ "kwargs": {
516
+ "brightness": [
517
+ 0.8,
518
+ 1.2
519
+ ]
520
+ }
521
+ },
522
+ "contrast": {
523
+ "weight": 1.0,
524
+ "type": "ColorJitter",
525
+ "kwargs": {
526
+ "contrast": [
527
+ 0.8,
528
+ 1.2
529
+ ]
530
+ }
531
+ },
532
+ "saturation": {
533
+ "weight": 1.0,
534
+ "type": "ColorJitter",
535
+ "kwargs": {
536
+ "saturation": [
537
+ 0.5,
538
+ 1.5
539
+ ]
540
+ }
541
+ },
542
+ "hue": {
543
+ "weight": 1.0,
544
+ "type": "ColorJitter",
545
+ "kwargs": {
546
+ "hue": [
547
+ -0.05,
548
+ 0.05
549
+ ]
550
+ }
551
+ },
552
+ "sharpness": {
553
+ "weight": 1.0,
554
+ "type": "SharpnessJitter",
555
+ "kwargs": {
556
+ "sharpness": [
557
+ 0.5,
558
+ 1.5
559
+ ]
560
+ }
561
+ },
562
+ "affine": {
563
+ "weight": 1.0,
564
+ "type": "RandomAffine",
565
+ "kwargs": {
566
+ "degrees": [
567
+ -5.0,
568
+ 5.0
569
+ ],
570
+ "translate": [
571
+ 0.05,
572
+ 0.05
573
+ ]
574
+ }
575
+ }
576
+ }
577
+ },
578
+ "revision": null,
579
+ "use_imagenet_stats": true,
580
+ "video_backend": "pyav",
581
+ "streaming": false
582
+ },
583
+ "env": null,
584
+ "policy": {
585
+ "type": "act",
586
+ "n_obs_steps": 1,
587
+ "input_features": {
588
+ "observation.state": {
589
+ "type": "STATE",
590
+ "shape": [
591
+ 6
592
+ ]
593
+ },
594
+ "observation.images.front": {
595
+ "type": "VISUAL",
596
+ "shape": [
597
+ 3,
598
+ 480,
599
+ 640
600
+ ]
601
+ }
602
+ },
603
+ "output_features": {
604
+ "action": {
605
+ "type": "ACTION",
606
+ "shape": [
607
+ 6
608
+ ]
609
+ }
610
+ },
611
+ "device": "cuda",
612
+ "use_amp": false,
613
+ "use_peft": false,
614
+ "push_to_hub": true,
615
+ "repo_id": "Basr88/Merged_dataset_4",
616
+ "private": null,
617
+ "tags": null,
618
+ "license": null,
619
+ "pretrained_path": null,
620
+ "chunk_size": 100,
621
+ "n_action_steps": 100,
622
+ "normalization_mapping": {
623
+ "VISUAL": "MEAN_STD",
624
+ "STATE": "MEAN_STD",
625
+ "ACTION": "MEAN_STD"
626
+ },
627
+ "vision_backbone": "resnet18",
628
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
629
+ "replace_final_stride_with_dilation": false,
630
+ "pre_norm": false,
631
+ "dim_model": 512,
632
+ "n_heads": 8,
633
+ "dim_feedforward": 3200,
634
+ "feedforward_activation": "relu",
635
+ "n_encoder_layers": 4,
636
+ "n_decoder_layers": 1,
637
+ "use_vae": true,
638
+ "latent_dim": 32,
639
+ "n_vae_encoder_layers": 4,
640
+ "temporal_ensemble_coeff": null,
641
+ "dropout": 0.1,
642
+ "kl_weight": 10.0,
643
+ "optimizer_lr": 1e-05,
644
+ "optimizer_weight_decay": 0.0001,
645
+ "optimizer_lr_backbone": 1e-05
646
+ },
647
+ "output_dir": "/home/basr88/data/outputs/train/act_gluestick_merged_500",
648
+ "job_name": "act_gluestick",
649
+ "resume": false,
650
+ "seed": 1000,
651
+ "cudnn_deterministic": false,
652
+ "num_workers": 4,
653
+ "batch_size": 8,
654
+ "steps": 150000,
655
+ "eval_freq": 20000,
656
+ "log_freq": 200,
657
+ "tolerance_s": 0.0001,
658
+ "save_checkpoint": true,
659
+ "save_freq": 20000,
660
+ "use_policy_training_preset": true,
661
+ "optimizer": {
662
+ "type": "adamw",
663
+ "lr": 1e-05,
664
+ "weight_decay": 0.0001,
665
+ "grad_clip_norm": 10.0,
666
+ "betas": [
667
+ 0.9,
668
+ 0.999
669
+ ],
670
+ "eps": 1e-08
671
+ },
672
+ "scheduler": null,
673
+ "eval": {
674
+ "n_episodes": 50,
675
+ "batch_size": 50,
676
+ "use_async_envs": false
677
+ },
678
+ "wandb": {
679
+ "enable": false,
680
+ "disable_artifact": false,
681
+ "project": "lerobot",
682
+ "entity": null,
683
+ "notes": null,
684
+ "run_id": null,
685
+ "mode": null,
686
+ "add_tags": true
687
+ },
688
+ "peft": null,
689
+ "use_rabc": false,
690
+ "rabc_progress_path": null,
691
+ "rabc_kappa": 0.01,
692
+ "rabc_epsilon": 1e-06,
693
+ "rabc_head_mode": "sparse",
694
+ "rename_map": {},
695
+ "checkpoint_path": null
696
+ }