Robotics
LeRobot
Safetensors
act
Basr88 commited on
Commit
db135fd
·
verified ·
1 Parent(s): d236303

Upload policy weights, train config and readme

Browse files
Files changed (4) hide show
  1. README.md +62 -0
  2. config.json +63 -0
  3. model.safetensors +3 -0
  4. train_config.json +596 -0
README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets: Basr88/Merged_800_2
3
+ library_name: lerobot
4
+ license: apache-2.0
5
+ model_name: act
6
+ pipeline_tag: robotics
7
+ tags:
8
+ - act
9
+ - robotics
10
+ - lerobot
11
+ ---
12
+
13
+ # Model Card for act
14
+
15
+ <!-- Provide a quick summary of what the model is/does. -->
16
+
17
+
18
+ [Action Chunking with Transformers (ACT)](https://huggingface.co/papers/2304.13705) is an imitation-learning method that predicts short action chunks instead of single steps. It learns from teleoperated data and often achieves high success rates.
19
+
20
+
21
+ This policy has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot).
22
+ See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index).
23
+
24
+ ---
25
+
26
+ ## How to Get Started with the Model
27
+
28
+ For a complete walkthrough, see the [training guide](https://huggingface.co/docs/lerobot/il_robots#train-a-policy).
29
+ Below is the short version on how to train and run inference/eval:
30
+
31
+ ### Train from scratch
32
+
33
+ ```bash
34
+ lerobot-train \
35
+ --dataset.repo_id=${HF_USER}/<dataset> \
36
+ --policy.type=act \
37
+ --output_dir=outputs/train/<desired_policy_repo_id> \
38
+ --job_name=lerobot_training \
39
+ --policy.device=cuda \
40
+ --policy.repo_id=${HF_USER}/<desired_policy_repo_id>
41
+ --wandb.enable=true
42
+ ```
43
+
44
+ _Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
45
+
46
+ ### Evaluate the policy/run inference
47
+
48
+ ```bash
49
+ lerobot-record \
50
+ --robot.type=so100_follower \
51
+ --dataset.repo_id=<hf_user>/eval_<dataset> \
52
+ --policy.path=<hf_user>/<desired_policy_repo_id> \
53
+ --episodes=10
54
+ ```
55
+
56
+ Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint.
57
+
58
+ ---
59
+
60
+ ## Model Details
61
+
62
+ - **License:** apache-2.0
config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "act",
3
+ "n_obs_steps": 1,
4
+ "input_features": {
5
+ "observation.state": {
6
+ "type": "STATE",
7
+ "shape": [
8
+ 6
9
+ ]
10
+ },
11
+ "observation.images.front": {
12
+ "type": "VISUAL",
13
+ "shape": [
14
+ 3,
15
+ 480,
16
+ 640
17
+ ]
18
+ }
19
+ },
20
+ "output_features": {
21
+ "action": {
22
+ "type": "ACTION",
23
+ "shape": [
24
+ 6
25
+ ]
26
+ }
27
+ },
28
+ "device": "cuda",
29
+ "use_amp": false,
30
+ "use_peft": false,
31
+ "push_to_hub": true,
32
+ "repo_id": "Basr88/Merged_dataset_8",
33
+ "private": null,
34
+ "tags": null,
35
+ "license": null,
36
+ "pretrained_path": null,
37
+ "chunk_size": 100,
38
+ "n_action_steps": 100,
39
+ "normalization_mapping": {
40
+ "VISUAL": "MEAN_STD",
41
+ "STATE": "MEAN_STD",
42
+ "ACTION": "MEAN_STD"
43
+ },
44
+ "vision_backbone": "resnet18",
45
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
46
+ "replace_final_stride_with_dilation": false,
47
+ "pre_norm": false,
48
+ "dim_model": 512,
49
+ "n_heads": 8,
50
+ "dim_feedforward": 3200,
51
+ "feedforward_activation": "relu",
52
+ "n_encoder_layers": 4,
53
+ "n_decoder_layers": 1,
54
+ "use_vae": true,
55
+ "latent_dim": 32,
56
+ "n_vae_encoder_layers": 4,
57
+ "temporal_ensemble_coeff": null,
58
+ "dropout": 0.1,
59
+ "kl_weight": 10.0,
60
+ "optimizer_lr": 1e-05,
61
+ "optimizer_weight_decay": 0.0001,
62
+ "optimizer_lr_backbone": 1e-05
63
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:331e09852c2bf6292f66c5b7d2f98e2e25b7edbb7b3a2cfb34ad5fd9d54cef77
3
+ size 206699736
train_config.json ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset": {
3
+ "repo_id": "Basr88/Merged_800_2",
4
+ "root": null,
5
+ "episodes": [
6
+ 294,
7
+ 18,
8
+ 651,
9
+ 130,
10
+ 293,
11
+ 647,
12
+ 291,
13
+ 185,
14
+ 694,
15
+ 466,
16
+ 190,
17
+ 146,
18
+ 303,
19
+ 596,
20
+ 285,
21
+ 202,
22
+ 248,
23
+ 691,
24
+ 564,
25
+ 589,
26
+ 492,
27
+ 10,
28
+ 186,
29
+ 743,
30
+ 422,
31
+ 455,
32
+ 221,
33
+ 773,
34
+ 548,
35
+ 195,
36
+ 259,
37
+ 565,
38
+ 545,
39
+ 442,
40
+ 749,
41
+ 114,
42
+ 542,
43
+ 101,
44
+ 61,
45
+ 500,
46
+ 556,
47
+ 271,
48
+ 732,
49
+ 251,
50
+ 412,
51
+ 458,
52
+ 482,
53
+ 634,
54
+ 727,
55
+ 196,
56
+ 700,
57
+ 613,
58
+ 757,
59
+ 42,
60
+ 332,
61
+ 338,
62
+ 409,
63
+ 258,
64
+ 561,
65
+ 41,
66
+ 665,
67
+ 677,
68
+ 348,
69
+ 770,
70
+ 121,
71
+ 430,
72
+ 328,
73
+ 381,
74
+ 66,
75
+ 527,
76
+ 446,
77
+ 126,
78
+ 206,
79
+ 85,
80
+ 388,
81
+ 626,
82
+ 641,
83
+ 687,
84
+ 9,
85
+ 169,
86
+ 600,
87
+ 33,
88
+ 284,
89
+ 278,
90
+ 111,
91
+ 153,
92
+ 636,
93
+ 539,
94
+ 8,
95
+ 423,
96
+ 679,
97
+ 142,
98
+ 39,
99
+ 70,
100
+ 650,
101
+ 48,
102
+ 756,
103
+ 241,
104
+ 23,
105
+ 484,
106
+ 502,
107
+ 580,
108
+ 403,
109
+ 230,
110
+ 566,
111
+ 189,
112
+ 530,
113
+ 427,
114
+ 334,
115
+ 171,
116
+ 380,
117
+ 668,
118
+ 587,
119
+ 87,
120
+ 467,
121
+ 22,
122
+ 617,
123
+ 748,
124
+ 180,
125
+ 397,
126
+ 65,
127
+ 345,
128
+ 299,
129
+ 611,
130
+ 551,
131
+ 290,
132
+ 78,
133
+ 383,
134
+ 287,
135
+ 740,
136
+ 624,
137
+ 220,
138
+ 209,
139
+ 374,
140
+ 417,
141
+ 312,
142
+ 481,
143
+ 553,
144
+ 407,
145
+ 473,
146
+ 438,
147
+ 778,
148
+ 385,
149
+ 663,
150
+ 281,
151
+ 791,
152
+ 567,
153
+ 168,
154
+ 493,
155
+ 597,
156
+ 30,
157
+ 103,
158
+ 582,
159
+ 27,
160
+ 761,
161
+ 755,
162
+ 779,
163
+ 274,
164
+ 5,
165
+ 346,
166
+ 262,
167
+ 655,
168
+ 264,
169
+ 225,
170
+ 292,
171
+ 421,
172
+ 62,
173
+ 166,
174
+ 769,
175
+ 648,
176
+ 84,
177
+ 245,
178
+ 243,
179
+ 35,
180
+ 419,
181
+ 642,
182
+ 776,
183
+ 0,
184
+ 184,
185
+ 376,
186
+ 224,
187
+ 763,
188
+ 675,
189
+ 558,
190
+ 550,
191
+ 640,
192
+ 440,
193
+ 486,
194
+ 237,
195
+ 199,
196
+ 750,
197
+ 413,
198
+ 754,
199
+ 733,
200
+ 676,
201
+ 216,
202
+ 621,
203
+ 112,
204
+ 115,
205
+ 239,
206
+ 148,
207
+ 505,
208
+ 158,
209
+ 786,
210
+ 777,
211
+ 129,
212
+ 639,
213
+ 362,
214
+ 360,
215
+ 339,
216
+ 705,
217
+ 249,
218
+ 83,
219
+ 167,
220
+ 147,
221
+ 375,
222
+ 116,
223
+ 14,
224
+ 678,
225
+ 231,
226
+ 140,
227
+ 170,
228
+ 433,
229
+ 570,
230
+ 483,
231
+ 7,
232
+ 498,
233
+ 477,
234
+ 628,
235
+ 363,
236
+ 98,
237
+ 586,
238
+ 767,
239
+ 656,
240
+ 522,
241
+ 207,
242
+ 32,
243
+ 214,
244
+ 144,
245
+ 315,
246
+ 797,
247
+ 357,
248
+ 784,
249
+ 2,
250
+ 75,
251
+ 415,
252
+ 135,
253
+ 19,
254
+ 51,
255
+ 329,
256
+ 437,
257
+ 289,
258
+ 418,
259
+ 758,
260
+ 183,
261
+ 526,
262
+ 429,
263
+ 208,
264
+ 560,
265
+ 664,
266
+ 276,
267
+ 627,
268
+ 529,
269
+ 674,
270
+ 480,
271
+ 305,
272
+ 444,
273
+ 173,
274
+ 313,
275
+ 331,
276
+ 145,
277
+ 325,
278
+ 635,
279
+ 514,
280
+ 704,
281
+ 390,
282
+ 540,
283
+ 396,
284
+ 682,
285
+ 428,
286
+ 425,
287
+ 269,
288
+ 298,
289
+ 222,
290
+ 535,
291
+ 603,
292
+ 246,
293
+ 775,
294
+ 164,
295
+ 752,
296
+ 798,
297
+ 592,
298
+ 400,
299
+ 358,
300
+ 267,
301
+ 250,
302
+ 771,
303
+ 254,
304
+ 373,
305
+ 468,
306
+ 163,
307
+ 464,
308
+ 584,
309
+ 102,
310
+ 3,
311
+ 93,
312
+ 300,
313
+ 160,
314
+ 471,
315
+ 118,
316
+ 336,
317
+ 585,
318
+ 534,
319
+ 175,
320
+ 669,
321
+ 697,
322
+ 354,
323
+ 552,
324
+ 571,
325
+ 172,
326
+ 426,
327
+ 17,
328
+ 310,
329
+ 314,
330
+ 598,
331
+ 105,
332
+ 701,
333
+ 124,
334
+ 319,
335
+ 470,
336
+ 389,
337
+ 452,
338
+ 681,
339
+ 43,
340
+ 309,
341
+ 394,
342
+ 99,
343
+ 53,
344
+ 741,
345
+ 462,
346
+ 670,
347
+ 781,
348
+ 36,
349
+ 524,
350
+ 745,
351
+ 787,
352
+ 295,
353
+ 660,
354
+ 316,
355
+ 731,
356
+ 123,
357
+ 633,
358
+ 459,
359
+ 335,
360
+ 521,
361
+ 232,
362
+ 12,
363
+ 192,
364
+ 368,
365
+ 68,
366
+ 49,
367
+ 141,
368
+ 401,
369
+ 538,
370
+ 495,
371
+ 594,
372
+ 67,
373
+ 277,
374
+ 131,
375
+ 461,
376
+ 79,
377
+ 785,
378
+ 6,
379
+ 395,
380
+ 450,
381
+ 795,
382
+ 15,
383
+ 24,
384
+ 789,
385
+ 593,
386
+ 725,
387
+ 509,
388
+ 513,
389
+ 179,
390
+ 557,
391
+ 52,
392
+ 94,
393
+ 410,
394
+ 653,
395
+ 317,
396
+ 579,
397
+ 583,
398
+ 408,
399
+ 620,
400
+ 1,
401
+ 503,
402
+ 434,
403
+ 457,
404
+ 44,
405
+ 178
406
+ ],
407
+ "image_transforms": {
408
+ "enable": false,
409
+ "max_num_transforms": 3,
410
+ "random_order": false,
411
+ "tfs": {
412
+ "brightness": {
413
+ "weight": 1.0,
414
+ "type": "ColorJitter",
415
+ "kwargs": {
416
+ "brightness": [
417
+ 0.8,
418
+ 1.2
419
+ ]
420
+ }
421
+ },
422
+ "contrast": {
423
+ "weight": 1.0,
424
+ "type": "ColorJitter",
425
+ "kwargs": {
426
+ "contrast": [
427
+ 0.8,
428
+ 1.2
429
+ ]
430
+ }
431
+ },
432
+ "saturation": {
433
+ "weight": 1.0,
434
+ "type": "ColorJitter",
435
+ "kwargs": {
436
+ "saturation": [
437
+ 0.5,
438
+ 1.5
439
+ ]
440
+ }
441
+ },
442
+ "hue": {
443
+ "weight": 1.0,
444
+ "type": "ColorJitter",
445
+ "kwargs": {
446
+ "hue": [
447
+ -0.05,
448
+ 0.05
449
+ ]
450
+ }
451
+ },
452
+ "sharpness": {
453
+ "weight": 1.0,
454
+ "type": "SharpnessJitter",
455
+ "kwargs": {
456
+ "sharpness": [
457
+ 0.5,
458
+ 1.5
459
+ ]
460
+ }
461
+ },
462
+ "affine": {
463
+ "weight": 1.0,
464
+ "type": "RandomAffine",
465
+ "kwargs": {
466
+ "degrees": [
467
+ -5.0,
468
+ 5.0
469
+ ],
470
+ "translate": [
471
+ 0.05,
472
+ 0.05
473
+ ]
474
+ }
475
+ }
476
+ }
477
+ },
478
+ "revision": null,
479
+ "use_imagenet_stats": true,
480
+ "video_backend": "pyav",
481
+ "streaming": false
482
+ },
483
+ "env": null,
484
+ "policy": {
485
+ "type": "act",
486
+ "n_obs_steps": 1,
487
+ "input_features": {
488
+ "observation.state": {
489
+ "type": "STATE",
490
+ "shape": [
491
+ 6
492
+ ]
493
+ },
494
+ "observation.images.front": {
495
+ "type": "VISUAL",
496
+ "shape": [
497
+ 3,
498
+ 480,
499
+ 640
500
+ ]
501
+ }
502
+ },
503
+ "output_features": {
504
+ "action": {
505
+ "type": "ACTION",
506
+ "shape": [
507
+ 6
508
+ ]
509
+ }
510
+ },
511
+ "device": "cuda",
512
+ "use_amp": false,
513
+ "use_peft": false,
514
+ "push_to_hub": true,
515
+ "repo_id": "Basr88/Merged_dataset_8",
516
+ "private": null,
517
+ "tags": null,
518
+ "license": null,
519
+ "pretrained_path": null,
520
+ "chunk_size": 100,
521
+ "n_action_steps": 100,
522
+ "normalization_mapping": {
523
+ "VISUAL": "MEAN_STD",
524
+ "STATE": "MEAN_STD",
525
+ "ACTION": "MEAN_STD"
526
+ },
527
+ "vision_backbone": "resnet18",
528
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
529
+ "replace_final_stride_with_dilation": false,
530
+ "pre_norm": false,
531
+ "dim_model": 512,
532
+ "n_heads": 8,
533
+ "dim_feedforward": 3200,
534
+ "feedforward_activation": "relu",
535
+ "n_encoder_layers": 4,
536
+ "n_decoder_layers": 1,
537
+ "use_vae": true,
538
+ "latent_dim": 32,
539
+ "n_vae_encoder_layers": 4,
540
+ "temporal_ensemble_coeff": null,
541
+ "dropout": 0.1,
542
+ "kl_weight": 10.0,
543
+ "optimizer_lr": 1e-05,
544
+ "optimizer_weight_decay": 0.0001,
545
+ "optimizer_lr_backbone": 1e-05
546
+ },
547
+ "output_dir": "/home/basr88/data/outputs/train/act_gluestick_merged_400",
548
+ "job_name": "act_gluestick",
549
+ "resume": false,
550
+ "seed": 1000,
551
+ "cudnn_deterministic": false,
552
+ "num_workers": 4,
553
+ "batch_size": 8,
554
+ "steps": 150000,
555
+ "eval_freq": 20000,
556
+ "log_freq": 200,
557
+ "tolerance_s": 0.0001,
558
+ "save_checkpoint": true,
559
+ "save_freq": 20000,
560
+ "use_policy_training_preset": true,
561
+ "optimizer": {
562
+ "type": "adamw",
563
+ "lr": 1e-05,
564
+ "weight_decay": 0.0001,
565
+ "grad_clip_norm": 10.0,
566
+ "betas": [
567
+ 0.9,
568
+ 0.999
569
+ ],
570
+ "eps": 1e-08
571
+ },
572
+ "scheduler": null,
573
+ "eval": {
574
+ "n_episodes": 50,
575
+ "batch_size": 50,
576
+ "use_async_envs": false
577
+ },
578
+ "wandb": {
579
+ "enable": false,
580
+ "disable_artifact": false,
581
+ "project": "lerobot",
582
+ "entity": null,
583
+ "notes": null,
584
+ "run_id": null,
585
+ "mode": null,
586
+ "add_tags": true
587
+ },
588
+ "peft": null,
589
+ "use_rabc": false,
590
+ "rabc_progress_path": null,
591
+ "rabc_kappa": 0.01,
592
+ "rabc_epsilon": 1e-06,
593
+ "rabc_head_mode": "sparse",
594
+ "rename_map": {},
595
+ "checkpoint_path": null
596
+ }