e-zorzi commited on
Commit
4d172a3
·
verified ·
1 Parent(s): a777ddb

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. checkpoint-10000/config.json +70 -0
  2. checkpoint-10000/embodiment_id.json +11 -0
  3. checkpoint-10000/latest +1 -0
  4. checkpoint-10000/model.safetensors.index.json +0 -0
  5. checkpoint-10000/processor_config.json +526 -0
  6. checkpoint-10000/statistics.json +0 -0
  7. checkpoint-10000/trainer_state.json +0 -0
  8. checkpoint-10000/wandb_config.json +1 -0
  9. checkpoint-10000/zero_to_fp32.py +760 -0
  10. checkpoint-15000/config.json +70 -0
  11. checkpoint-15000/embodiment_id.json +11 -0
  12. checkpoint-15000/experiment_cfg/conf.yaml +270 -0
  13. checkpoint-15000/experiment_cfg/config.yaml +308 -0
  14. checkpoint-15000/experiment_cfg/dataset_statistics.json +573 -0
  15. checkpoint-15000/experiment_cfg/final_model_config.json +53 -0
  16. checkpoint-15000/latest +1 -0
  17. checkpoint-15000/model.safetensors.index.json +0 -0
  18. checkpoint-15000/processor_config.json +526 -0
  19. checkpoint-15000/statistics.json +0 -0
  20. checkpoint-15000/trainer_state.json +0 -0
  21. checkpoint-15000/wandb_config.json +1 -0
  22. checkpoint-20000/config.json +70 -0
  23. checkpoint-20000/experiment_cfg/conf.yaml +270 -0
  24. checkpoint-20000/experiment_cfg/config.yaml +308 -0
  25. checkpoint-20000/experiment_cfg/dataset_statistics.json +573 -0
  26. checkpoint-20000/experiment_cfg/final_model_config.json +53 -0
  27. checkpoint-20000/experiment_cfg/final_processor_config.json +0 -0
  28. checkpoint-20000/latest +1 -0
  29. checkpoint-20000/model.safetensors.index.json +0 -0
  30. checkpoint-20000/zero_to_fp32.py +760 -0
  31. checkpoint-5000/config.json +70 -0
  32. checkpoint-5000/embodiment_id.json +11 -0
  33. checkpoint-5000/latest +1 -0
  34. checkpoint-5000/model.safetensors.index.json +0 -0
  35. checkpoint-5000/processor_config.json +526 -0
  36. checkpoint-5000/statistics.json +0 -0
  37. checkpoint-5000/trainer_state.json +3034 -0
  38. checkpoint-5000/wandb_config.json +1 -0
  39. checkpoint-5000/zero_to_fp32.py +760 -0
  40. config.json +70 -0
  41. experiment_cfg/conf.yaml +270 -0
  42. experiment_cfg/config.yaml +308 -0
  43. experiment_cfg/dataset_statistics.json +573 -0
  44. experiment_cfg/final_model_config.json +53 -0
  45. experiment_cfg/final_processor_config.json +0 -0
  46. model.safetensors.index.json +0 -0
  47. processor/embodiment_id.json +11 -0
  48. processor/processor_config.json +526 -0
  49. processor/statistics.json +0 -0
  50. wandb_config.json +1 -0
checkpoint-10000/config.json ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_horizon": 50,
3
+ "add_pos_embed": true,
4
+ "apply_sincos_state_encoding": true,
5
+ "architectures": [
6
+ "Gr00tN1d6"
7
+ ],
8
+ "attn_dropout": 0.2,
9
+ "attn_implementation": null,
10
+ "backbone_embedding_dim": 2048,
11
+ "backbone_model_type": "eagle",
12
+ "backbone_trainable_params_fp32": true,
13
+ "collator_overwrite_image_inputs": false,
14
+ "color_jitter_params": {
15
+ "brightness": 0.1,
16
+ "contrast": 0.1,
17
+ "hue": 0.1,
18
+ "saturation": 0.1
19
+ },
20
+ "crop_fraction": 0.95,
21
+ "diffusion_model_cfg": {
22
+ "attention_head_dim": 48,
23
+ "dropout": 0.2,
24
+ "final_dropout": true,
25
+ "interleave_self_attention": true,
26
+ "norm_type": "ada_norm",
27
+ "num_attention_heads": 32,
28
+ "num_layers": 32,
29
+ "output_dim": 1024,
30
+ "positional_embeddings": null
31
+ },
32
+ "eagle_collator": true,
33
+ "formalize_language": true,
34
+ "gemma_collator": false,
35
+ "hidden_size": 1024,
36
+ "image_crop_size": null,
37
+ "image_target_size": null,
38
+ "input_embedding_dim": 1536,
39
+ "load_bf16": true,
40
+ "max_action_dim": 128,
41
+ "max_num_embodiments": 32,
42
+ "max_seq_len": 1024,
43
+ "max_state_dim": 128,
44
+ "model_dtype": "bfloat16",
45
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
46
+ "model_type": "Gr00tN1d6",
47
+ "noise_beta_alpha": 1.5,
48
+ "noise_beta_beta": 1.0,
49
+ "noise_s": 0.999,
50
+ "num_inference_timesteps": 4,
51
+ "num_timestep_buckets": 1000,
52
+ "random_rotation_angle": null,
53
+ "reproject_vision": false,
54
+ "select_layer": 16,
55
+ "shortest_image_edge": 256,
56
+ "state_dropout_prob": 0.0,
57
+ "torch_dtype": "bfloat16",
58
+ "transformers_version": "4.51.3",
59
+ "tune_diffusion_model": true,
60
+ "tune_llm": false,
61
+ "tune_projector": true,
62
+ "tune_top_llm_layers": 4,
63
+ "tune_visual": true,
64
+ "tune_vlln": true,
65
+ "use_albumentations_transforms": true,
66
+ "use_alternate_vl_dit": true,
67
+ "use_flash_attention": true,
68
+ "use_relative_action": true,
69
+ "use_vlln": true
70
+ }
checkpoint-10000/embodiment_id.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "robocasa_panda_omron": 13,
3
+ "gr1": 20,
4
+ "behavior_r1_pro": 24,
5
+ "unitree_g1": 8,
6
+ "oxe_google": 0,
7
+ "oxe_widowx": 1,
8
+ "libero_panda": 2,
9
+ "oxe_droid": 16,
10
+ "new_embodiment": 10
11
+ }
checkpoint-10000/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step10000
checkpoint-10000/model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-10000/processor_config.json ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "Gr00tN1d6Processor",
3
+ "processor_kwargs": {
4
+ "modality_configs": {
5
+ "behavior_r1_pro": {
6
+ "video": {
7
+ "delta_indices": [
8
+ 0
9
+ ],
10
+ "modality_keys": [
11
+ "observation.images.rgb.head_256_256",
12
+ "observation.images.rgb.left_wrist_256_256",
13
+ "observation.images.rgb.right_wrist_256_256"
14
+ ],
15
+ "sin_cos_embedding_keys": null,
16
+ "mean_std_embedding_keys": null,
17
+ "action_configs": null
18
+ },
19
+ "state": {
20
+ "delta_indices": [
21
+ 0
22
+ ],
23
+ "modality_keys": [
24
+ "robot_pos",
25
+ "robot_ori_cos",
26
+ "robot_ori_sin",
27
+ "robot_2d_ori",
28
+ "robot_2d_ori_cos",
29
+ "robot_2d_ori_sin",
30
+ "robot_lin_vel",
31
+ "robot_ang_vel",
32
+ "arm_left_qpos",
33
+ "arm_left_qpos_sin",
34
+ "arm_left_qpos_cos",
35
+ "eef_left_pos",
36
+ "eef_left_quat",
37
+ "gripper_left_qpos",
38
+ "arm_right_qpos",
39
+ "arm_right_qpos_sin",
40
+ "arm_right_qpos_cos",
41
+ "eef_right_pos",
42
+ "eef_right_quat",
43
+ "gripper_right_qpos",
44
+ "trunk_qpos"
45
+ ],
46
+ "sin_cos_embedding_keys": null,
47
+ "mean_std_embedding_keys": null,
48
+ "action_configs": null
49
+ },
50
+ "action": {
51
+ "delta_indices": [
52
+ 0,
53
+ 1,
54
+ 2,
55
+ 3,
56
+ 4,
57
+ 5,
58
+ 6,
59
+ 7,
60
+ 8,
61
+ 9,
62
+ 10,
63
+ 11,
64
+ 12,
65
+ 13,
66
+ 14,
67
+ 15,
68
+ 16,
69
+ 17,
70
+ 18,
71
+ 19,
72
+ 20,
73
+ 21,
74
+ 22,
75
+ 23,
76
+ 24,
77
+ 25,
78
+ 26,
79
+ 27,
80
+ 28,
81
+ 29,
82
+ 30,
83
+ 31
84
+ ],
85
+ "modality_keys": [
86
+ "base",
87
+ "torso",
88
+ "left_arm",
89
+ "left_gripper",
90
+ "right_arm",
91
+ "right_gripper"
92
+ ],
93
+ "sin_cos_embedding_keys": null,
94
+ "mean_std_embedding_keys": null,
95
+ "action_configs": [
96
+ {
97
+ "rep": "ABSOLUTE",
98
+ "type": "NON_EEF",
99
+ "format": "DEFAULT",
100
+ "state_key": null
101
+ },
102
+ {
103
+ "rep": "RELATIVE",
104
+ "type": "NON_EEF",
105
+ "format": "DEFAULT",
106
+ "state_key": "trunk_qpos"
107
+ },
108
+ {
109
+ "rep": "RELATIVE",
110
+ "type": "NON_EEF",
111
+ "format": "DEFAULT",
112
+ "state_key": "arm_left_qpos"
113
+ },
114
+ {
115
+ "rep": "ABSOLUTE",
116
+ "type": "NON_EEF",
117
+ "format": "DEFAULT",
118
+ "state_key": null
119
+ },
120
+ {
121
+ "rep": "RELATIVE",
122
+ "type": "NON_EEF",
123
+ "format": "DEFAULT",
124
+ "state_key": "arm_right_qpos"
125
+ },
126
+ {
127
+ "rep": "ABSOLUTE",
128
+ "type": "NON_EEF",
129
+ "format": "DEFAULT",
130
+ "state_key": null
131
+ }
132
+ ]
133
+ },
134
+ "language": {
135
+ "delta_indices": [
136
+ 0
137
+ ],
138
+ "modality_keys": [
139
+ "annotation.human.coarse_action"
140
+ ],
141
+ "sin_cos_embedding_keys": null,
142
+ "mean_std_embedding_keys": null,
143
+ "action_configs": null
144
+ }
145
+ },
146
+ "gr1": {
147
+ "video": {
148
+ "delta_indices": [
149
+ 0
150
+ ],
151
+ "modality_keys": [
152
+ "ego_view_bg_crop_pad_res256_freq20"
153
+ ],
154
+ "sin_cos_embedding_keys": null,
155
+ "mean_std_embedding_keys": null,
156
+ "action_configs": null
157
+ },
158
+ "state": {
159
+ "delta_indices": [
160
+ 0
161
+ ],
162
+ "modality_keys": [
163
+ "left_arm",
164
+ "right_arm",
165
+ "left_hand",
166
+ "right_hand",
167
+ "waist"
168
+ ],
169
+ "sin_cos_embedding_keys": [
170
+ "left_arm",
171
+ "right_arm",
172
+ "left_hand",
173
+ "right_hand",
174
+ "waist"
175
+ ],
176
+ "mean_std_embedding_keys": null,
177
+ "action_configs": null
178
+ },
179
+ "action": {
180
+ "delta_indices": [
181
+ 0,
182
+ 1,
183
+ 2,
184
+ 3,
185
+ 4,
186
+ 5,
187
+ 6,
188
+ 7,
189
+ 8,
190
+ 9,
191
+ 10,
192
+ 11,
193
+ 12,
194
+ 13,
195
+ 14,
196
+ 15
197
+ ],
198
+ "modality_keys": [
199
+ "left_arm",
200
+ "right_arm",
201
+ "left_hand",
202
+ "right_hand",
203
+ "waist"
204
+ ],
205
+ "sin_cos_embedding_keys": null,
206
+ "mean_std_embedding_keys": null,
207
+ "action_configs": [
208
+ {
209
+ "rep": "RELATIVE",
210
+ "type": "NON_EEF",
211
+ "format": "DEFAULT",
212
+ "state_key": null
213
+ },
214
+ {
215
+ "rep": "RELATIVE",
216
+ "type": "NON_EEF",
217
+ "format": "DEFAULT",
218
+ "state_key": null
219
+ },
220
+ {
221
+ "rep": "RELATIVE",
222
+ "type": "NON_EEF",
223
+ "format": "DEFAULT",
224
+ "state_key": null
225
+ },
226
+ {
227
+ "rep": "RELATIVE",
228
+ "type": "NON_EEF",
229
+ "format": "DEFAULT",
230
+ "state_key": null
231
+ },
232
+ {
233
+ "rep": "ABSOLUTE",
234
+ "type": "NON_EEF",
235
+ "format": "DEFAULT",
236
+ "state_key": null
237
+ }
238
+ ]
239
+ },
240
+ "language": {
241
+ "delta_indices": [
242
+ 0
243
+ ],
244
+ "modality_keys": [
245
+ "task"
246
+ ],
247
+ "sin_cos_embedding_keys": null,
248
+ "mean_std_embedding_keys": null,
249
+ "action_configs": null
250
+ }
251
+ },
252
+ "robocasa_panda_omron": {
253
+ "video": {
254
+ "delta_indices": [
255
+ 0
256
+ ],
257
+ "modality_keys": [
258
+ "res256_image_side_0",
259
+ "res256_image_side_1",
260
+ "res256_image_wrist_0"
261
+ ],
262
+ "sin_cos_embedding_keys": null,
263
+ "mean_std_embedding_keys": null,
264
+ "action_configs": null
265
+ },
266
+ "state": {
267
+ "delta_indices": [
268
+ 0
269
+ ],
270
+ "modality_keys": [
271
+ "end_effector_position_relative",
272
+ "end_effector_rotation_relative",
273
+ "gripper_qpos",
274
+ "base_position",
275
+ "base_rotation"
276
+ ],
277
+ "sin_cos_embedding_keys": null,
278
+ "mean_std_embedding_keys": null,
279
+ "action_configs": null
280
+ },
281
+ "action": {
282
+ "delta_indices": [
283
+ 0,
284
+ 1,
285
+ 2,
286
+ 3,
287
+ 4,
288
+ 5,
289
+ 6,
290
+ 7,
291
+ 8,
292
+ 9,
293
+ 10,
294
+ 11,
295
+ 12,
296
+ 13,
297
+ 14,
298
+ 15
299
+ ],
300
+ "modality_keys": [
301
+ "end_effector_position",
302
+ "end_effector_rotation",
303
+ "gripper_close",
304
+ "base_motion",
305
+ "control_mode"
306
+ ],
307
+ "sin_cos_embedding_keys": null,
308
+ "mean_std_embedding_keys": null,
309
+ "action_configs": [
310
+ {
311
+ "rep": "ABSOLUTE",
312
+ "type": "NON_EEF",
313
+ "format": "DEFAULT",
314
+ "state_key": null
315
+ },
316
+ {
317
+ "rep": "ABSOLUTE",
318
+ "type": "NON_EEF",
319
+ "format": "DEFAULT",
320
+ "state_key": null
321
+ },
322
+ {
323
+ "rep": "ABSOLUTE",
324
+ "type": "NON_EEF",
325
+ "format": "DEFAULT",
326
+ "state_key": null
327
+ },
328
+ {
329
+ "rep": "ABSOLUTE",
330
+ "type": "NON_EEF",
331
+ "format": "DEFAULT",
332
+ "state_key": null
333
+ },
334
+ {
335
+ "rep": "ABSOLUTE",
336
+ "type": "NON_EEF",
337
+ "format": "DEFAULT",
338
+ "state_key": null
339
+ }
340
+ ]
341
+ },
342
+ "language": {
343
+ "delta_indices": [
344
+ 0
345
+ ],
346
+ "modality_keys": [
347
+ "annotation.human.action.task_description"
348
+ ],
349
+ "sin_cos_embedding_keys": null,
350
+ "mean_std_embedding_keys": null,
351
+ "action_configs": null
352
+ }
353
+ },
354
+ "new_embodiment": {
355
+ "video": {
356
+ "delta_indices": [
357
+ 0
358
+ ],
359
+ "modality_keys": [
360
+ "ego_view"
361
+ ],
362
+ "sin_cos_embedding_keys": null,
363
+ "mean_std_embedding_keys": null,
364
+ "action_configs": null
365
+ },
366
+ "state": {
367
+ "delta_indices": [
368
+ 0
369
+ ],
370
+ "modality_keys": [
371
+ "left_arm",
372
+ "right_arm",
373
+ "left_hand",
374
+ "right_hand",
375
+ "waist"
376
+ ],
377
+ "sin_cos_embedding_keys": null,
378
+ "mean_std_embedding_keys": null,
379
+ "action_configs": null
380
+ },
381
+ "action": {
382
+ "delta_indices": [
383
+ 0,
384
+ 1,
385
+ 2,
386
+ 3,
387
+ 4,
388
+ 5,
389
+ 6,
390
+ 7,
391
+ 8,
392
+ 9,
393
+ 10,
394
+ 11,
395
+ 12,
396
+ 13,
397
+ 14,
398
+ 15,
399
+ 16,
400
+ 17,
401
+ 18,
402
+ 19,
403
+ 20,
404
+ 21,
405
+ 22,
406
+ 23,
407
+ 24,
408
+ 25,
409
+ 26,
410
+ 27,
411
+ 28,
412
+ 29,
413
+ 30,
414
+ 31,
415
+ 32,
416
+ 33,
417
+ 34,
418
+ 35,
419
+ 36,
420
+ 37,
421
+ 38,
422
+ 39,
423
+ 40,
424
+ 41,
425
+ 42,
426
+ 43,
427
+ 44,
428
+ 45,
429
+ 46,
430
+ 47,
431
+ 48,
432
+ 49
433
+ ],
434
+ "modality_keys": [
435
+ "left_arm",
436
+ "right_arm",
437
+ "left_hand",
438
+ "right_hand",
439
+ "waist",
440
+ "base_height_command",
441
+ "navigate_command"
442
+ ],
443
+ "sin_cos_embedding_keys": null,
444
+ "mean_std_embedding_keys": null,
445
+ "action_configs": [
446
+ {
447
+ "rep": "ABSOLUTE",
448
+ "type": "NON_EEF",
449
+ "format": "DEFAULT",
450
+ "state_key": null
451
+ },
452
+ {
453
+ "rep": "ABSOLUTE",
454
+ "type": "NON_EEF",
455
+ "format": "DEFAULT",
456
+ "state_key": null
457
+ },
458
+ {
459
+ "rep": "ABSOLUTE",
460
+ "type": "NON_EEF",
461
+ "format": "DEFAULT",
462
+ "state_key": null
463
+ },
464
+ {
465
+ "rep": "ABSOLUTE",
466
+ "type": "NON_EEF",
467
+ "format": "DEFAULT",
468
+ "state_key": null
469
+ },
470
+ {
471
+ "rep": "ABSOLUTE",
472
+ "type": "NON_EEF",
473
+ "format": "DEFAULT",
474
+ "state_key": null
475
+ },
476
+ {
477
+ "rep": "ABSOLUTE",
478
+ "type": "NON_EEF",
479
+ "format": "DEFAULT",
480
+ "state_key": null
481
+ },
482
+ {
483
+ "rep": "ABSOLUTE",
484
+ "type": "NON_EEF",
485
+ "format": "DEFAULT",
486
+ "state_key": null
487
+ }
488
+ ]
489
+ },
490
+ "language": {
491
+ "delta_indices": [
492
+ 0
493
+ ],
494
+ "modality_keys": [
495
+ "annotation.human.task_description"
496
+ ],
497
+ "sin_cos_embedding_keys": null,
498
+ "mean_std_embedding_keys": null,
499
+ "action_configs": null
500
+ }
501
+ }
502
+ },
503
+ "image_crop_size": null,
504
+ "image_target_size": null,
505
+ "use_albumentations": true,
506
+ "random_rotation_angle": null,
507
+ "color_jitter_params": {
508
+ "brightness": 0.3,
509
+ "contrast": 0.4,
510
+ "saturation": 0.5,
511
+ "hue": 0.08
512
+ },
513
+ "shortest_image_edge": 256,
514
+ "crop_fraction": 0.95,
515
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
516
+ "model_type": "eagle",
517
+ "formalize_language": true,
518
+ "max_state_dim": 128,
519
+ "max_action_dim": 128,
520
+ "max_action_horizon": 50,
521
+ "use_percentiles": false,
522
+ "clip_outliers": true,
523
+ "apply_sincos_state_encoding": true,
524
+ "use_relative_action": true
525
+ }
526
+ }
checkpoint-10000/statistics.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-10000/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-10000/wandb_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"project": "finetune-gr00t-n1d6", "run_id": "locomanipulation_tutorial"}
checkpoint-10000/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    """
    Compute how a parameter of ``unpartitioned_numel`` elements is split over ``world_size`` ranks.

    Returns:
        - ``(partitioned_numel, padding_numel)``: the per-rank element count (rounded up) and the
          number of elements needed to make the total a multiple of ``world_size``
    """
    leftover = unpartitioned_numel % world_size
    if leftover:
        padding_numel = world_size - leftover
    else:
        padding_numel = 0
    return math.ceil(unpartitioned_numel / world_size), padding_numel
353
+
354
+
355
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
    """
    Rebuild every frozen parameter of a ZeRO-3 checkpoint and store it in ``state_dict``.

    For each name in rank 0's ``frozen_param_shapes`` the per-rank fragments are concatenated in
    rank order, cut down to the parameter's real element count (dropping trailing padding) and
    viewed with the recorded shape. Returns ``None``; does nothing when rank 0 records no frozen
    parameters.
    """
    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    if frozen_param_shapes is None or len(frozen_param_shapes) == 0:
        return

    if debug:
        for i, rank_state in enumerate(zero_model_states[:world_size]) if False else ((r, zero_model_states[r]) for r in range(world_size)):
            num_elem = sum(s.numel() for s in rank_state.frozen_param_fragments.values())
            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    rank0_fragments = zero_model_states[0].frozen_param_fragments
    avail_numel = sum([p.numel() for p in rank0_fragments.values()]) * world_size
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for total_params, (name, shape) in enumerate(frozen_param_shapes.items(), start=1):
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        # glue the rank fragments together, then drop the padding at the tail
        fragments = tuple(rank_state.frozen_param_fragments[name] for rank_state in zero_model_states)
        merged = torch.cat(fragments, 0)
        state_dict[name] = merged.narrow(0, 0, unpartitioned_numel).view(shape)

        # only used for the debug report below
        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
class GatheredTensor:
    """
    A pseudo tensor that collects partitioned weights.
    It is more memory efficient when there are multiple groups.

    Args:
        - ``flat_groups``: indexed as ``flat_groups[rank][group]``; each entry is a flat tensor
        - ``flat_groups_offset``: cumulative group boundaries, so that group ``g`` covers the
          half-open range ``[flat_groups_offset[g], flat_groups_offset[g + 1])``
        - ``offset``: start position of this parameter's partition inside one rank's groups
        - ``partitioned_numel``: number of elements each rank holds for this parameter
        - ``shape``: full (unpartitioned) shape; must provide ``.numel()``

    The real tensor is only materialized when ``contiguous()`` is called.
    """

    def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
        self.flat_groups = flat_groups
        self.flat_groups_offset = flat_groups_offset
        self.offset = offset
        self.partitioned_numel = partitioned_numel
        self.shape = shape
        self.dtype = self.flat_groups[0][0].dtype

    def contiguous(self):
        """
        Merge partitioned weights from flat_groups into a single tensor.

        Returns:
            - a contiguous tensor of shape ``self.shape`` built from every rank's slice
              ``[offset, offset + partitioned_numel)``, with trailing padding dropped

        Raises:
            - ``ValueError``: if the slice does not fall inside the ranges described by
              ``flat_groups_offset``
        """
        offsets = self.flat_groups_offset
        end_idx = self.offset + self.partitioned_numel

        # A zero-element partition touches no group; the search below cannot resolve it when it
        # sits exactly on a group boundary, so handle it up front. ``view`` still validates shape.
        if self.partitioned_numel == 0:
            return self.flat_groups[0][0].new_empty(0).view(self.shape).contiguous()

        # The groups involved are the same for every rank, so look them up once.
        # There are len(offsets) - 1 groups; iterating further would index past the end.
        start_group_id = None
        end_group_id = None
        for group_id in range(len(offsets) - 1):
            group_start, group_end = offsets[group_id], offsets[group_id + 1]
            if group_start <= self.offset < group_end:
                start_group_id = group_id
            if group_start < end_idx <= group_end:
                end_group_id = group_id
                break
        if start_group_id is None or end_group_id is None:
            raise ValueError(f"partition [{self.offset}, {end_idx}) is outside of the flat groups "
                             f"described by offsets {list(offsets)}")

        pad_flat_param_chunks = []
        for flat_groups_at_rank_i in self.flat_groups:
            # collect weights from related group/groups
            for group_id in range(start_group_id, end_group_id + 1):
                group_start = offsets[group_id]
                # Clamp to the group start: for every group after the first one the slice begins
                # at the group's own beginning, not at a (negative) distance from self.offset.
                start_offset = max(self.offset, group_start) - group_start
                end_offset = min(end_idx, offsets[group_id + 1]) - group_start
                pad_flat_param_chunks.append(flat_groups_at_rank_i[group_id][start_offset:end_offset])

        # collect weights from all ranks
        pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
        param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
        return param
435
+
436
+
437
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """
    Register every trainable parameter of a ZeRO-3 checkpoint in ``state_dict`` as a lazy
    ``GatheredTensor``.

    Args:
        - ``state_dict``: mapping that receives one ``GatheredTensor`` per parameter name
        - ``world_size``: number of ranks the parameters were partitioned over
        - ``fp32_flat_groups``: indexed as ``fp32_flat_groups[rank][group]``, each entry a flat
          tensor (this is how both this function and ``GatheredTensor`` index it)
        - ``zero_model_states``: per-rank model states; ``param_shapes`` (a list of dicts) is
          read from the first entry

    Raises:
        - ``ValueError``: if the elements consumed do not add up to the elements available
    """
    param_shapes = zero_model_states[0].param_shapes

    # Per-group sizes of rank 0; used for the availability count and the group boundaries.
    # Computed once from the list of groups: calling ``.numel()``/``.shape`` on the list itself
    # (as a second, redundant computation used to do) is not possible.
    group_numels = [flat_group.numel() for flat_group in fp32_flat_groups[0]]
    avail_numel = sum(group_numels) * world_size

    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    # merge list of dicts, preserving order
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}

    if debug:
        for i in range(world_size):
            group_shapes = [tuple(flat_group.shape) for flat_group in fp32_flat_groups[i]]
            print(f"{FP32_FLAT_GROUPS}[{i}].shape={group_shapes}")

    wanted_params = len(param_shapes)
    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
    # not asserting if there is a mismatch due to possible padding
    print(f"Trainable params: Have {avail_numel} numels to process.")
    print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    # boundaries of the groups inside one rank: group g covers [offsets[g], offsets[g + 1])
    flat_groups_offset = [0] + list(np.cumsum(group_numels))
    for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1
        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # memory efficient tensor
        tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
        state_dict[name] = tensor
        offset += partitioned_numel

    # offset counted one rank's share so far; scale it to compare against all ranks
    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """
    Assemble a consolidated state dict from the per-rank pieces of a ZeRO-3 checkpoint.

    Args:
        - ``world_size``: forwarded to the frozen and trainable merge helpers
        - ``fp32_flat_groups``: forwarded to ``_zero3_merge_trainable_params``
        - ``zero_model_states``: per-rank model states; buffers and shared-parameter pairs are
          read from the first entry only
        - ``exclude_frozen_parameters``: when true, the frozen-parameter merge is skipped

    Returns:
        - an ``OrderedDict`` populated, in order, with buffers, frozen params (optional),
          trainable params and shared-parameter aliases
    """
    first_rank_state = zero_model_states[0]
    state_dict = OrderedDict()

    # buffers come first
    buffers = first_rank_state.buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    # frozen parameters are optional, trainable ones are always merged
    if not exclude_frozen_parameters:
        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters: entry [0] becomes an alias of entry [1] when the latter was merged
    for pair in first_rank_state.shared_params:
        alias_name, source_name = pair[0], pair[1]
        if source_name in state_dict:
            state_dict[alias_name] = state_dict[source_name]

    return state_dict
511
+
512
+
513
def to_torch_tensor(state_dict, return_empty_tensor=False):
    """
    Convert state_dict of GatheredTensor to torch tensor

    Values are materialized with ``.contiguous()``, or replaced by ``torch.empty`` placeholders of
    the same shape and dtype when ``return_empty_tensor`` is true. Entries that refer to the very
    same object (shared tensors) are converted once and share the converted result.
    """
    first_name_by_id = {}
    converted = {}
    for name, tensor in state_dict.items():
        owner = first_name_by_id.setdefault(id(tensor), name)
        if owner != name:
            # shared tensors: reuse what was produced for the first name
            converted[name] = converted[owner]
            continue
        if return_empty_tensor:
            converted[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
        else:
            converted[name] = tensor.contiguous()
    return converted
531
+
532
+
533
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                             tag=None,
                                             exclude_frozen_parameters=False,
                                             lazy_mode=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
    via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
        - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
          Convert a pseudo tensor to a torch tensor with ``.contiguous()``

    Returns:
        - pytorch ``state_dict``

    Raises:
        - ``ValueError``: ``tag`` is not given and ``checkpoint_dir`` has no ``latest`` file
        - ``FileNotFoundError``: the ``checkpoint_dir/tag`` folder does not exist

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

    Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
    You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
    the checkpoint. Or you can load state_dict in lazy mode ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
        for name, lazy_tensor in state_dict.items():
            tensor = lazy_tensor.contiguous()  # to cpu
            print(name, tensor)
            # del tensor to release memory if it is no longer in use
    """
    if tag is None:
        # fall back to the tag recorded in the 'latest' file
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
    # lazy mode hands the collected entries back untouched; otherwise materialize them first
    return state_dict if lazy_mode else to_torch_tensor(state_dict)
596
+
597
+
598
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
                                               output_dir,
                                               max_shard_size="5GB",
                                               safe_serialization=False,
                                               tag=None,
                                               exclude_frozen_parameters=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    The weights are written shard by shard so that only one shard is materialized at a time; when
    the result is sharded an index file mapping weights to shard files is written as well.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_dir``: directory to the pytorch fp32 state_dict output files
        - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
        - ``safe_serialization``:  whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
    """

    # Dependency pre-check: fail early, before any expensive work, if an optional package is missing
    if safe_serialization:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print('If you want to use `safe_serialization`, please `pip install safetensors`')
            raise
    if max_shard_size is not None:
        try:
            from huggingface_hub import split_torch_state_dict_into_shards
        except ImportError:
            print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
            raise

    # Convert zero checkpoint to state_dict (lazy entries, materialized per shard below)
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                                          tag,
                                                          exclude_frozen_parameters,
                                                          lazy_mode=True)

    # Shard the model if it is too big.
    if safe_serialization:
        weights_name = "model.safetensors"
    else:
        weights_name = "pytorch_model.bin"

    if max_shard_size is None:
        # single output file: mimic the two attributes of the split result that are used below
        from collections import namedtuple
        StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
        state_dict_split = StateDictSplit(is_sharded=False,
                                          filename_to_tensors={weights_name: list(state_dict.keys())})
    else:
        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
        # a memory-efficient approach for sharding: plan the split on empty placeholder tensors
        placeholders = to_torch_tensor(state_dict, return_empty_tensor=True)
        state_dict_split = split_torch_state_dict_into_shards(placeholders,
                                                              filename_pattern=filename_pattern,
                                                              max_shard_size=max_shard_size)

    # Save the model by shard
    os.makedirs(output_dir, exist_ok=True)
    shard_plan = state_dict_split.filename_to_tensors.items()
    for shard_file, tensors in tqdm(shard_plan, desc="Saving checkpoint shards"):
        shard_state_dict = to_torch_tensor({tensor_name: state_dict[tensor_name] for tensor_name in tensors})
        output_path = os.path.join(output_dir, shard_file)
        if safe_serialization:
            save_file(shard_state_dict, output_path, metadata={"format": "pt"})
        else:
            torch.save(shard_state_dict, output_path)
        # release the memory of current shard
        for tensor_name in list(shard_state_dict.keys()):
            del state_dict[tensor_name]
            del shard_state_dict[tensor_name]
        del shard_state_dict
        gc.collect()

    # Save index if sharded
    if state_dict_split.is_sharded:
        if safe_serialization:
            index_name = "model.safetensors.index.json"
        else:
            index_name = "pytorch_model.bin.index.json"
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
681
+
682
+
683
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Put the provided model to cpu
    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    3. Load it into the provided model

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
    of the same application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    """
    logger.info("Extracting fp32 weights")
    fp32_weights = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info("Overwriting model with fp32 weights")
    cpu_model = model.cpu()
    # strict=False: keys that do not line up between model and checkpoint are not treated as errors
    cpu_model.load_state_dict(fp32_weights, strict=False)

    return cpu_model
720
+
721
+
722
if __name__ == "__main__":
    # Command-line entry point: convert the ZeRO checkpoint in `checkpoint_dir` into consolidated
    # fp32 weight file(s) under `output_dir`.
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_dir",
                        type=str,
                        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    # NOTE: adjacent string literals are concatenated verbatim, so every fragment that is followed
    # by another one must end with a space.
    parser.add_argument("output_dir",
                        type=str,
                        help="directory to the pytorch fp32 state_dict output files "
                        "(e.g. path/checkpoint-12-output/)")
    parser.add_argument(
        "--max_shard_size",
        type=str,
        default="5GB",
        help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size "
        "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). "
        "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances "
        "without CPU OOM issues.")
    parser.add_argument(
        "--safe_serialization",
        default=False,
        action='store_true',
        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
    parser.add_argument("-t",
                        "--tag",
                        type=str,
                        default=None,
                        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    # module-level flag read by the merge helpers to enable their verbose prints
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
                                               args.output_dir,
                                               max_shard_size=args.max_shard_size,
                                               safe_serialization=args.safe_serialization,
                                               tag=args.tag,
                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
checkpoint-15000/config.json ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_horizon": 50,
3
+ "add_pos_embed": true,
4
+ "apply_sincos_state_encoding": true,
5
+ "architectures": [
6
+ "Gr00tN1d6"
7
+ ],
8
+ "attn_dropout": 0.2,
9
+ "attn_implementation": null,
10
+ "backbone_embedding_dim": 2048,
11
+ "backbone_model_type": "eagle",
12
+ "backbone_trainable_params_fp32": true,
13
+ "collator_overwrite_image_inputs": false,
14
+ "color_jitter_params": {
15
+ "brightness": 0.1,
16
+ "contrast": 0.1,
17
+ "hue": 0.1,
18
+ "saturation": 0.1
19
+ },
20
+ "crop_fraction": 0.95,
21
+ "diffusion_model_cfg": {
22
+ "attention_head_dim": 48,
23
+ "dropout": 0.2,
24
+ "final_dropout": true,
25
+ "interleave_self_attention": true,
26
+ "norm_type": "ada_norm",
27
+ "num_attention_heads": 32,
28
+ "num_layers": 32,
29
+ "output_dim": 1024,
30
+ "positional_embeddings": null
31
+ },
32
+ "eagle_collator": true,
33
+ "formalize_language": true,
34
+ "gemma_collator": false,
35
+ "hidden_size": 1024,
36
+ "image_crop_size": null,
37
+ "image_target_size": null,
38
+ "input_embedding_dim": 1536,
39
+ "load_bf16": true,
40
+ "max_action_dim": 128,
41
+ "max_num_embodiments": 32,
42
+ "max_seq_len": 1024,
43
+ "max_state_dim": 128,
44
+ "model_dtype": "bfloat16",
45
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
46
+ "model_type": "Gr00tN1d6",
47
+ "noise_beta_alpha": 1.5,
48
+ "noise_beta_beta": 1.0,
49
+ "noise_s": 0.999,
50
+ "num_inference_timesteps": 4,
51
+ "num_timestep_buckets": 1000,
52
+ "random_rotation_angle": null,
53
+ "reproject_vision": false,
54
+ "select_layer": 16,
55
+ "shortest_image_edge": 256,
56
+ "state_dropout_prob": 0.0,
57
+ "torch_dtype": "bfloat16",
58
+ "transformers_version": "4.51.3",
59
+ "tune_diffusion_model": true,
60
+ "tune_llm": false,
61
+ "tune_projector": true,
62
+ "tune_top_llm_layers": 4,
63
+ "tune_visual": true,
64
+ "tune_vlln": true,
65
+ "use_albumentations_transforms": true,
66
+ "use_alternate_vl_dit": true,
67
+ "use_flash_attention": true,
68
+ "use_relative_action": true,
69
+ "use_vlln": true
70
+ }
checkpoint-15000/embodiment_id.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "robocasa_panda_omron": 13,
3
+ "gr1": 20,
4
+ "behavior_r1_pro": 24,
5
+ "unitree_g1": 8,
6
+ "oxe_google": 0,
7
+ "oxe_widowx": 1,
8
+ "libero_panda": 2,
9
+ "oxe_droid": 16,
10
+ "new_embodiment": 10
11
+ }
checkpoint-15000/experiment_cfg/conf.yaml ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ load_config_path: null
2
+ model:
3
+ model_type: Gr00tN1d6
4
+ model_dtype: bfloat16
5
+ model_name: nvidia/Eagle-Block2A-2B-v2
6
+ backbone_model_type: eagle
7
+ model_revision: null
8
+ tune_top_llm_layers: 4
9
+ backbone_embedding_dim: 2048
10
+ tune_llm: false
11
+ tune_visual: true
12
+ select_layer: 16
13
+ reproject_vision: false
14
+ use_flash_attention: true
15
+ load_bf16: false
16
+ collator_overwrite_image_inputs: false
17
+ eagle_collator: true
18
+ backbone_trainable_params_fp32: true
19
+ image_crop_size: null
20
+ image_target_size: null
21
+ shortest_image_edge: 256
22
+ crop_fraction: 0.95
23
+ random_rotation_angle: null
24
+ color_jitter_params:
25
+ brightness: 0.3
26
+ contrast: 0.4
27
+ saturation: 0.5
28
+ hue: 0.08
29
+ use_albumentations_transforms: true
30
+ formalize_language: true
31
+ apply_sincos_state_encoding: false
32
+ use_relative_action: true
33
+ max_state_dim: 29
34
+ max_action_dim: 29
35
+ action_horizon: 16
36
+ hidden_size: 1024
37
+ input_embedding_dim: 1536
38
+ add_pos_embed: true
39
+ attn_dropout: 0.2
40
+ use_vlln: true
41
+ max_seq_len: 1024
42
+ use_alternate_vl_dit: true
43
+ attend_text_every_n_blocks: 2
44
+ diffusion_model_cfg:
45
+ positional_embeddings: null
46
+ num_layers: 32
47
+ num_attention_heads: 32
48
+ attention_head_dim: 48
49
+ norm_type: ada_norm
50
+ dropout: 0.2
51
+ final_dropout: true
52
+ output_dim: 1024
53
+ interleave_self_attention: true
54
+ num_inference_timesteps: 4
55
+ noise_beta_alpha: 1.5
56
+ noise_beta_beta: 1.0
57
+ noise_s: 0.999
58
+ num_timestep_buckets: 1000
59
+ tune_projector: true
60
+ tune_diffusion_model: true
61
+ tune_vlln: true
62
+ state_dropout_prob: 0.0
63
+ state_additive_noise_scale: 0.0
64
+ max_num_embodiments: 32
65
+ data:
66
+ datasets:
67
+ - dataset_paths:
68
+ - /datasets/isaaclab_arena/locomanipulation_tutorial/arena_g1_loco_manipulation_dataset_generated/lerobot
69
+ embodiment_tag: new_embodiment
70
+ mix_ratio: 1.0
71
+ dataset_type: physical_embodiment
72
+ val_dataset_path: null
73
+ modality_configs:
74
+ new_embodiment:
75
+ video:
76
+ delta_indices:
77
+ - 0
78
+ modality_keys:
79
+ - ego_view
80
+ sin_cos_embedding_keys: null
81
+ mean_std_embedding_keys: null
82
+ action_configs: null
83
+ state:
84
+ delta_indices:
85
+ - 0
86
+ modality_keys:
87
+ - left_arm
88
+ - right_arm
89
+ - left_hand
90
+ - right_hand
91
+ - waist
92
+ sin_cos_embedding_keys: null
93
+ mean_std_embedding_keys: null
94
+ action_configs: null
95
+ action:
96
+ delta_indices:
97
+ - 0
98
+ - 1
99
+ - 2
100
+ - 3
101
+ - 4
102
+ - 5
103
+ - 6
104
+ - 7
105
+ - 8
106
+ - 9
107
+ - 10
108
+ - 11
109
+ - 12
110
+ - 13
111
+ - 14
112
+ - 15
113
+ - 16
114
+ - 17
115
+ - 18
116
+ - 19
117
+ - 20
118
+ - 21
119
+ - 22
120
+ - 23
121
+ - 24
122
+ - 25
123
+ - 26
124
+ - 27
125
+ - 28
126
+ - 29
127
+ - 30
128
+ - 31
129
+ - 32
130
+ - 33
131
+ - 34
132
+ - 35
133
+ - 36
134
+ - 37
135
+ - 38
136
+ - 39
137
+ - 40
138
+ - 41
139
+ - 42
140
+ - 43
141
+ - 44
142
+ - 45
143
+ - 46
144
+ - 47
145
+ - 48
146
+ - 49
147
+ modality_keys:
148
+ - left_arm
149
+ - right_arm
150
+ - left_hand
151
+ - right_hand
152
+ - waist
153
+ - base_height_command
154
+ - navigate_command
155
+ sin_cos_embedding_keys: null
156
+ mean_std_embedding_keys: null
157
+ action_configs:
158
+ - rep: ABSOLUTE
159
+ type: NON_EEF
160
+ format: DEFAULT
161
+ state_key: null
162
+ - rep: ABSOLUTE
163
+ type: NON_EEF
164
+ format: DEFAULT
165
+ state_key: null
166
+ - rep: ABSOLUTE
167
+ type: NON_EEF
168
+ format: DEFAULT
169
+ state_key: null
170
+ - rep: ABSOLUTE
171
+ type: NON_EEF
172
+ format: DEFAULT
173
+ state_key: null
174
+ - rep: ABSOLUTE
175
+ type: NON_EEF
176
+ format: DEFAULT
177
+ state_key: null
178
+ - rep: ABSOLUTE
179
+ type: NON_EEF
180
+ format: DEFAULT
181
+ state_key: null
182
+ - rep: ABSOLUTE
183
+ type: NON_EEF
184
+ format: DEFAULT
185
+ state_key: null
186
+ language:
187
+ delta_indices:
188
+ - 0
189
+ modality_keys:
190
+ - annotation.human.task_description
191
+ sin_cos_embedding_keys: null
192
+ mean_std_embedding_keys: null
193
+ action_configs: null
194
+ download_cache: false
195
+ shard_size: 1024
196
+ episode_sampling_rate: 0.1
197
+ num_shards_per_epoch: 100000
198
+ override_pretraining_statistics: false
199
+ mode: single_turn
200
+ random_chop: 0.0
201
+ mock_dataset_mode: false
202
+ shuffle: true
203
+ seed: 42
204
+ multiprocessing_context: fork
205
+ allow_padding: false
206
+ subsample_ratio: 1.0
207
+ image_crop_size:
208
+ - 244
209
+ - 244
210
+ image_target_size:
211
+ - 224
212
+ - 224
213
+ video_backend: torchcodec
214
+ training:
215
+ output_dir: /models/isaaclab_arena/locomanipulation_tutorial
216
+ experiment_name: null
217
+ max_steps: 20000
218
+ global_batch_size: 192
219
+ batch_size: null
220
+ gradient_accumulation_steps: 1
221
+ learning_rate: 0.0001
222
+ lr_scheduler_type: cosine
223
+ weight_decay: 1.0e-05
224
+ warmup_ratio: 0.05
225
+ warmup_steps: 0
226
+ max_grad_norm: 1.0
227
+ optim: adamw_torch
228
+ start_from_checkpoint: nvidia/GR00T-N1.6-3B
229
+ tf32: true
230
+ fp16: false
231
+ bf16: true
232
+ eval_bf16: true
233
+ logging_steps: 10
234
+ save_steps: 5000
235
+ save_total_limit: 5
236
+ save_vl_model: false
237
+ upload_checkpoints: false
238
+ upload_every: 1000
239
+ upload_last_n_checkpoints: 5
240
+ max_concurrent_uploads: 2
241
+ eval_strategy: 'no'
242
+ eval_steps: 500
243
+ eval_set_split_ratio: 0.1
244
+ eval_batch_size: 2
245
+ save_best_eval_metric_name: ''
246
+ save_best_eval_metric_greater_is_better: true
247
+ deepspeed_stage: 2
248
+ gradient_checkpointing: false
249
+ transformers_trust_remote_code: true
250
+ transformers_local_files_only: false
251
+ transformers_cache_dir: null
252
+ transformers_access_token: null
253
+ use_ddp: false
254
+ ddp_bucket_cap_mb: 100
255
+ num_gpus: 8
256
+ dataloader_num_workers: 16
257
+ remove_unused_columns: false
258
+ use_wandb: false
259
+ wandb_project: finetune-gr00t-n1d6
260
+ enable_profiling: false
261
+ max_retries: 3
262
+ assert_loss_less_than: null
263
+ add_rl_callback: false
264
+ enable_open_loop_eval: false
265
+ open_loop_eval_traj_ids:
266
+ - 0
267
+ open_loop_eval_steps_per_traj: 100
268
+ open_loop_eval_plot_indices: null
269
+ max_steps: 20000
270
+ save_steps: 5000
checkpoint-15000/experiment_cfg/config.yaml ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ !!python/object:gr00t.configs.base_config.Config
2
+ data: !!python/object:gr00t.configs.data.data_config.DataConfig
3
+ allow_padding: false
4
+ datasets:
5
+ - !!python/object:gr00t.configs.data.data_config.SingleDatasetConfig
6
+ dataset_paths:
7
+ - /datasets/isaaclab_arena/locomanipulation_tutorial/arena_g1_loco_manipulation_dataset_generated/lerobot
8
+ dataset_type: physical_embodiment
9
+ embodiment_tag: new_embodiment
10
+ mix_ratio: 1.0
11
+ val_dataset_path: null
12
+ download_cache: false
13
+ episode_sampling_rate: 0.1
14
+ image_crop_size:
15
+ - 244
16
+ - 244
17
+ image_target_size:
18
+ - 224
19
+ - 224
20
+ mock_dataset_mode: false
21
+ modality_configs:
22
+ new_embodiment:
23
+ action: !!python/object:gr00t.data.types.ModalityConfig
24
+ action_configs:
25
+ - !!python/object:gr00t.data.types.ActionConfig
26
+ format: &id001 !!python/object/apply:gr00t.data.types.ActionFormat
27
+ - default
28
+ rep: &id002 !!python/object/apply:gr00t.data.types.ActionRepresentation
29
+ - absolute
30
+ state_key: null
31
+ type: &id003 !!python/object/apply:gr00t.data.types.ActionType
32
+ - non_eef
33
+ - !!python/object:gr00t.data.types.ActionConfig
34
+ format: *id001
35
+ rep: *id002
36
+ state_key: null
37
+ type: *id003
38
+ - !!python/object:gr00t.data.types.ActionConfig
39
+ format: *id001
40
+ rep: *id002
41
+ state_key: null
42
+ type: *id003
43
+ - !!python/object:gr00t.data.types.ActionConfig
44
+ format: *id001
45
+ rep: *id002
46
+ state_key: null
47
+ type: *id003
48
+ - !!python/object:gr00t.data.types.ActionConfig
49
+ format: *id001
50
+ rep: *id002
51
+ state_key: null
52
+ type: *id003
53
+ - !!python/object:gr00t.data.types.ActionConfig
54
+ format: *id001
55
+ rep: *id002
56
+ state_key: null
57
+ type: *id003
58
+ - !!python/object:gr00t.data.types.ActionConfig
59
+ format: *id001
60
+ rep: *id002
61
+ state_key: null
62
+ type: *id003
63
+ delta_indices:
64
+ - 0
65
+ - 1
66
+ - 2
67
+ - 3
68
+ - 4
69
+ - 5
70
+ - 6
71
+ - 7
72
+ - 8
73
+ - 9
74
+ - 10
75
+ - 11
76
+ - 12
77
+ - 13
78
+ - 14
79
+ - 15
80
+ - 16
81
+ - 17
82
+ - 18
83
+ - 19
84
+ - 20
85
+ - 21
86
+ - 22
87
+ - 23
88
+ - 24
89
+ - 25
90
+ - 26
91
+ - 27
92
+ - 28
93
+ - 29
94
+ - 30
95
+ - 31
96
+ - 32
97
+ - 33
98
+ - 34
99
+ - 35
100
+ - 36
101
+ - 37
102
+ - 38
103
+ - 39
104
+ - 40
105
+ - 41
106
+ - 42
107
+ - 43
108
+ - 44
109
+ - 45
110
+ - 46
111
+ - 47
112
+ - 48
113
+ - 49
114
+ mean_std_embedding_keys: null
115
+ modality_keys:
116
+ - left_arm
117
+ - right_arm
118
+ - left_hand
119
+ - right_hand
120
+ - waist
121
+ - base_height_command
122
+ - navigate_command
123
+ sin_cos_embedding_keys: null
124
+ language: !!python/object:gr00t.data.types.ModalityConfig
125
+ action_configs: null
126
+ delta_indices:
127
+ - 0
128
+ mean_std_embedding_keys: null
129
+ modality_keys:
130
+ - annotation.human.task_description
131
+ sin_cos_embedding_keys: null
132
+ state: !!python/object:gr00t.data.types.ModalityConfig
133
+ action_configs: null
134
+ delta_indices:
135
+ - 0
136
+ mean_std_embedding_keys: null
137
+ modality_keys:
138
+ - left_arm
139
+ - right_arm
140
+ - left_hand
141
+ - right_hand
142
+ - waist
143
+ sin_cos_embedding_keys: null
144
+ video: !!python/object:gr00t.data.types.ModalityConfig
145
+ action_configs: null
146
+ delta_indices:
147
+ - 0
148
+ mean_std_embedding_keys: null
149
+ modality_keys:
150
+ - ego_view
151
+ sin_cos_embedding_keys: null
152
+ mode: single_turn
153
+ multiprocessing_context: fork
154
+ num_shards_per_epoch: 100000
155
+ override_pretraining_statistics: false
156
+ random_chop: 0.0
157
+ seed: 42
158
+ shard_size: 1024
159
+ shuffle: true
160
+ subsample_ratio: 1.0
161
+ video_backend: torchcodec
162
+ load_config_path: null
163
+ model: !!python/object:gr00t.configs.model.gr00t_n1d6.Gr00tN1d6Config
164
+ _attn_implementation_autoset: false
165
+ _attn_implementation_internal: null
166
+ _commit_hash: null
167
+ _name_or_path: ''
168
+ add_cross_attention: false
169
+ architectures: null
170
+ backbone_model_type: eagle
171
+ backbone_trainable_params_fp32: true
172
+ bad_words_ids: null
173
+ begin_suppress_tokens: null
174
+ bos_token_id: null
175
+ chunk_size_feed_forward: 0
176
+ color_jitter_params:
177
+ brightness: 0.3
178
+ contrast: 0.4
179
+ hue: 0.08
180
+ saturation: 0.5
181
+ cross_attention_hidden_size: null
182
+ decoder_start_token_id: null
183
+ diffusion_model_cfg:
184
+ attention_head_dim: 48
185
+ dropout: 0.2
186
+ final_dropout: true
187
+ interleave_self_attention: true
188
+ norm_type: ada_norm
189
+ num_attention_heads: 32
190
+ num_layers: 32
191
+ output_dim: 1024
192
+ positional_embeddings: null
193
+ diversity_penalty: 0.0
194
+ do_sample: false
195
+ eagle_collator: true
196
+ early_stopping: false
197
+ encoder_no_repeat_ngram_size: 0
198
+ eos_token_id: null
199
+ exponential_decay_length_penalty: null
200
+ finetuning_task: null
201
+ forced_bos_token_id: null
202
+ forced_eos_token_id: null
203
+ id2label:
204
+ 0: LABEL_0
205
+ 1: LABEL_1
206
+ is_decoder: false
207
+ is_encoder_decoder: false
208
+ label2id:
209
+ LABEL_0: 0
210
+ LABEL_1: 1
211
+ length_penalty: 1.0
212
+ load_bf16: false
213
+ max_length: 20
214
+ min_length: 0
215
+ model_name: nvidia/Eagle-Block2A-2B-v2
216
+ no_repeat_ngram_size: 0
217
+ num_beam_groups: 1
218
+ num_beams: 1
219
+ num_return_sequences: 1
220
+ output_attentions: false
221
+ output_hidden_states: false
222
+ output_scores: false
223
+ pad_token_id: null
224
+ prefix: null
225
+ problem_type: null
226
+ pruned_heads: {}
227
+ random_rotation_angle: null
228
+ remove_invalid_values: false
229
+ repetition_penalty: 1.0
230
+ reproject_vision: false
231
+ return_dict: true
232
+ return_dict_in_generate: false
233
+ sep_token_id: null
234
+ state_dropout_prob: 0.0
235
+ suppress_tokens: null
236
+ task_specific_params: null
237
+ temperature: 1.0
238
+ tf_legacy_loss: false
239
+ tie_encoder_decoder: false
240
+ tie_word_embeddings: true
241
+ tokenizer_class: null
242
+ top_k: 50
243
+ top_p: 1.0
244
+ torch_dtype: null
245
+ torchscript: false
246
+ transformers_version: null
247
+ tune_diffusion_model: true
248
+ tune_llm: false
249
+ tune_projector: true
250
+ tune_visual: true
251
+ typical_p: 1.0
252
+ use_bfloat16: false
253
+ use_relative_action: true
254
+ training: !!python/object:gr00t.configs.training.training_config.TrainingConfig
255
+ add_rl_callback: false
256
+ assert_loss_less_than: null
257
+ batch_size: null
258
+ bf16: true
259
+ dataloader_num_workers: 16
260
+ ddp_bucket_cap_mb: 100
261
+ deepspeed_stage: 2
262
+ enable_open_loop_eval: false
263
+ enable_profiling: false
264
+ eval_batch_size: 2
265
+ eval_bf16: true
266
+ eval_set_split_ratio: 0.1
267
+ eval_steps: 500
268
+ eval_strategy: 'no'
269
+ experiment_name: null
270
+ fp16: false
271
+ global_batch_size: 192
272
+ gradient_accumulation_steps: 1
273
+ gradient_checkpointing: false
274
+ learning_rate: 0.0001
275
+ logging_steps: 10
276
+ lr_scheduler_type: cosine
277
+ max_concurrent_uploads: 2
278
+ max_grad_norm: 1.0
279
+ max_retries: 3
280
+ max_steps: 20000
281
+ num_gpus: 8
282
+ open_loop_eval_plot_indices: null
283
+ open_loop_eval_steps_per_traj: 100
284
+ open_loop_eval_traj_ids:
285
+ - 0
286
+ optim: adamw_torch
287
+ output_dir: /models/isaaclab_arena/locomanipulation_tutorial
288
+ remove_unused_columns: false
289
+ save_best_eval_metric_greater_is_better: true
290
+ save_best_eval_metric_name: ''
291
+ save_steps: 5000
292
+ save_total_limit: 5
293
+ save_vl_model: false
294
+ start_from_checkpoint: nvidia/GR00T-N1.6-3B
295
+ tf32: true
296
+ transformers_access_token: null
297
+ transformers_cache_dir: null
298
+ transformers_local_files_only: false
299
+ transformers_trust_remote_code: true
300
+ upload_checkpoints: false
301
+ upload_every: 1000
302
+ upload_last_n_checkpoints: 5
303
+ use_ddp: false
304
+ use_wandb: false
305
+ wandb_project: finetune-gr00t-n1d6
306
+ warmup_ratio: 0.05
307
+ warmup_steps: 0
308
+ weight_decay: 1.0e-05
checkpoint-15000/experiment_cfg/dataset_statistics.json ADDED
@@ -0,0 +1,573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "new_embodiment": {
3
+ "state": {
4
+ "left_arm": {
5
+ "min": [
6
+ -1.2616037130355835,
7
+ -0.29025015234947205,
8
+ -0.22703997790813446,
9
+ -0.3353549540042877,
10
+ -0.0829518586397171,
11
+ -0.8195276260375977,
12
+ -0.2688920795917511
13
+ ],
14
+ "max": [
15
+ 0.15299034118652344,
16
+ 0.4194548726081848,
17
+ 0.304278701543808,
18
+ 1.4247486591339111,
19
+ 0.751840353012085,
20
+ 0.6736590266227722,
21
+ 0.569625973701477
22
+ ],
23
+ "mean": [
24
+ -0.6218094229698181,
25
+ -0.03578367084264755,
26
+ 0.05471671372652054,
27
+ 0.3273524045944214,
28
+ 0.16905353963375092,
29
+ 0.1931331604719162,
30
+ 0.0418560616672039
31
+ ],
32
+ "std": [
33
+ 0.2542016804218292,
34
+ 0.08585234731435776,
35
+ 0.05442973971366882,
36
+ 0.3563520908355713,
37
+ 0.10547080636024475,
38
+ 0.21155740320682526,
39
+ 0.0815652459859848
40
+ ],
41
+ "q01": [
42
+ -1.0867726147174834,
43
+ -0.23316791355609895,
44
+ -0.06077688504010439,
45
+ -0.2531130000948906,
46
+ -0.025190447550266983,
47
+ -0.41234332919120786,
48
+ -0.14684838354587554
49
+ ],
50
+ "q99": [
51
+ 0.02166599538177228,
52
+ 0.16592777222394936,
53
+ 0.19437864869832985,
54
+ 1.3526465594768522,
55
+ 0.47515065073966933,
56
+ 0.6158077389001846,
57
+ 0.267849366366863
58
+ ]
59
+ },
60
+ "right_arm": {
61
+ "min": [
62
+ -0.9889344573020935,
63
+ -0.7240632772445679,
64
+ -0.4150152802467346,
65
+ -0.2197991907596588,
66
+ -0.44296473264694214,
67
+ -0.9651272296905518,
68
+ -0.4595109820365906
69
+ ],
70
+ "max": [
71
+ 0.15951132774353027,
72
+ 0.21149154007434845,
73
+ 0.13221219182014465,
74
+ 1.4304473400115967,
75
+ 0.6581774950027466,
76
+ 0.33145904541015625,
77
+ 0.42284855246543884
78
+ ],
79
+ "mean": [
80
+ -0.5138179659843445,
81
+ -0.07899317145347595,
82
+ -0.1299561709165573,
83
+ 0.40922680497169495,
84
+ 0.027388907968997955,
85
+ -0.0835803970694542,
86
+ 0.024336807429790497
87
+ ],
88
+ "std": [
89
+ 0.1910795420408249,
90
+ 0.10697221755981445,
91
+ 0.0633271336555481,
92
+ 0.2594990134239197,
93
+ 0.14704135060310364,
94
+ 0.15591612458229065,
95
+ 0.06830708682537079
96
+ ],
97
+ "q01": [
98
+ -0.83366958796978,
99
+ -0.38898577094078063,
100
+ -0.27746869176626204,
101
+ -0.12615955173969268,
102
+ -0.2731088250875473,
103
+ -0.6371771156787872,
104
+ -0.16048517003655433
105
+ ],
106
+ "q99": [
107
+ 0.019438467640429113,
108
+ 0.13264653384685496,
109
+ 0.03749443646520371,
110
+ 1.3000927805900555,
111
+ 0.3483726784586904,
112
+ 0.12948824167251569,
113
+ 0.168773318082094
114
+ ]
115
+ },
116
+ "left_hand": {
117
+ "min": [
118
+ -0.008645662106573582,
119
+ -0.0016571161104366183,
120
+ -0.008173327893018723,
121
+ -0.0033370573073625565,
122
+ -0.049815986305475235,
123
+ -0.13737092912197113,
124
+ -8.590802735852776e-09
125
+ ],
126
+ "max": [
127
+ 8.85741064848844e-06,
128
+ 1.4383874713530531e-06,
129
+ 7.31344407540746e-05,
130
+ 4.420346158440225e-05,
131
+ 0.026730380952358246,
132
+ 0.06749135255813599,
133
+ 0.004176338668912649
134
+ ],
135
+ "mean": [
136
+ -0.00045161443995311856,
137
+ -9.045441402122378e-05,
138
+ -0.0008751734858378768,
139
+ -0.00010305152682121843,
140
+ -0.0026190115604549646,
141
+ -0.0007728625205345452,
142
+ 3.4298220271011814e-05
143
+ ],
144
+ "std": [
145
+ 0.0010219421237707138,
146
+ 0.00011942393030039966,
147
+ 0.0011946671875193715,
148
+ 0.00021070965158287436,
149
+ 0.004766007885336876,
150
+ 0.008314870297908783,
151
+ 0.00020773601136170328
152
+ ],
153
+ "q01": [
154
+ -0.004614621866494417,
155
+ -0.0005385997559642419,
156
+ -0.004787646210752427,
157
+ -0.0012936698796693236,
158
+ -0.01875622048974037,
159
+ -0.03178232274949551,
160
+ -2.9993839079089924e-10
161
+ ],
162
+ "q99": [
163
+ 1.4417540605826582e-09,
164
+ -5.172329953229189e-10,
165
+ -2.493637962786175e-10,
166
+ -6.717705641756689e-10,
167
+ 0.008347299136221403,
168
+ 0.012830186681821834,
169
+ 0.0014548563922289215
170
+ ]
171
+ },
172
+ "right_hand": {
173
+ "min": [
174
+ -1.5373115047623287e-07,
175
+ -2.7022052151437492e-08,
176
+ -2.0592709915945306e-05,
177
+ -7.066118541843025e-06,
178
+ -0.03601590916514397,
179
+ -0.5857902765274048,
180
+ -0.3214021623134613
181
+ ],
182
+ "max": [
183
+ 0.006290650460869074,
184
+ 0.001731343101710081,
185
+ 0.017454728484153748,
186
+ 0.012643150985240936,
187
+ 0.09934248775243759,
188
+ 0.0994623526930809,
189
+ 3.1769886277288606e-08
190
+ ],
191
+ "mean": [
192
+ 0.00025306272436864674,
193
+ 5.4000069212634116e-05,
194
+ 0.0003351480991113931,
195
+ 0.0008108046022243798,
196
+ 0.0006079890299588442,
197
+ -0.006738435477018356,
198
+ -0.00452095502987504
199
+ ],
200
+ "std": [
201
+ 0.0006930792587809265,
202
+ 0.00016116801998578012,
203
+ 0.0007848768145777285,
204
+ 0.0014818455092608929,
205
+ 0.009566166438162327,
206
+ 0.05241963639855385,
207
+ 0.030341269448399544
208
+ ],
209
+ "q01": [
210
+ -1.1203826366656955e-09,
211
+ 5.471793157463268e-10,
212
+ -7.516792688289087e-10,
213
+ 1.7157600895600922e-10,
214
+ -0.008333299728110432,
215
+ -0.3553843080997467,
216
+ -0.20837910920381547
217
+ ],
218
+ "q99": [
219
+ 0.0038171554915606976,
220
+ 0.0008218895673053339,
221
+ 0.003914117161184549,
222
+ 0.005107918474823237,
223
+ 0.061319448240101194,
224
+ 0.009818258183076798,
225
+ 3.1323699190011206e-10
226
+ ]
227
+ },
228
+ "waist": {
229
+ "min": [
230
+ -0.04632357507944107,
231
+ -0.11110502481460571,
232
+ -0.036814406514167786
233
+ ],
234
+ "max": [
235
+ 0.0633544921875,
236
+ 0.11162503063678741,
237
+ 0.1282370686531067
238
+ ],
239
+ "mean": [
240
+ 0.002279821317642927,
241
+ -0.0016866918886080384,
242
+ 0.05629865825176239
243
+ ],
244
+ "std": [
245
+ 0.019741930067539215,
246
+ 0.04374425858259201,
247
+ 0.023172633722424507
248
+ ],
249
+ "q01": [
250
+ -0.039197818748652934,
251
+ -0.09254500381648541,
252
+ -0.020507800113409757
253
+ ],
254
+ "q99": [
255
+ 0.054476964659988844,
256
+ 0.09499521441757679,
257
+ 0.10415777899324889
258
+ ]
259
+ }
260
+ },
261
+ "action": {
262
+ "left_arm": {
263
+ "min": [
264
+ -1.348067283630371,
265
+ -0.3527751564979553,
266
+ -0.3787360191345215,
267
+ -0.625663697719574,
268
+ -0.09716995060443878,
269
+ -0.9718959331512451,
270
+ -0.41488397121429443
271
+ ],
272
+ "max": [
273
+ 0.1336316466331482,
274
+ 0.4716266393661499,
275
+ 0.30831149220466614,
276
+ 1.4016180038452148,
277
+ 0.9397326111793518,
278
+ 0.6476842761039734,
279
+ 0.8313083648681641
280
+ ],
281
+ "mean": [
282
+ -0.6952570080757141,
283
+ -0.0709061548113823,
284
+ -0.04288463667035103,
285
+ 0.2694568634033203,
286
+ 0.1649714857339859,
287
+ 0.13536368310451508,
288
+ -0.02554020844399929
289
+ ],
290
+ "std": [
291
+ 0.26363858580589294,
292
+ 0.10477105528116226,
293
+ 0.07000378519296646,
294
+ 0.3648890554904938,
295
+ 0.11654239892959595,
296
+ 0.2099701166152954,
297
+ 0.08394794911146164
298
+ ],
299
+ "q01": [
300
+ -1.1805148243904113,
301
+ -0.308816134929657,
302
+ -0.17785422429442405,
303
+ -0.3138654500246048,
304
+ -0.05110809002071619,
305
+ -0.4920081451535225,
306
+ -0.1742709159851074
307
+ ],
308
+ "q99": [
309
+ -0.008620778424665838,
310
+ 0.20248875990509888,
311
+ 0.17697372585535032,
312
+ 1.284248530864715,
313
+ 0.522044214606285,
314
+ 0.5478375405073164,
315
+ 0.24634651243686412
316
+ ]
317
+ },
318
+ "right_arm": {
319
+ "min": [
320
+ -1.0777442455291748,
321
+ -0.7950155735015869,
322
+ -0.4215357005596161,
323
+ -0.33741918206214905,
324
+ -0.5877293348312378,
325
+ -1.0788743495941162,
326
+ -0.573306679725647
327
+ ],
328
+ "max": [
329
+ 0.14458219707012177,
330
+ 0.31825390458106995,
331
+ 0.3697803318500519,
332
+ 1.4193015098571777,
333
+ 0.6486993432044983,
334
+ 0.28742435574531555,
335
+ 0.49852707982063293
336
+ ],
337
+ "mean": [
338
+ -0.604250967502594,
339
+ -0.0556945763528347,
340
+ -0.03765946254134178,
341
+ 0.30660828948020935,
342
+ 0.01742653176188469,
343
+ -0.16916987299919128,
344
+ 0.09518744796514511
345
+ ],
346
+ "std": [
347
+ 0.20923613011837006,
348
+ 0.12663093209266663,
349
+ 0.08735905587673187,
350
+ 0.2593192756175995,
351
+ 0.15945474803447723,
352
+ 0.16604292392730713,
353
+ 0.07976584881544113
354
+ ],
355
+ "q01": [
356
+ -0.9175809919834137,
357
+ -0.5007677406072617,
358
+ -0.21304122656583785,
359
+ -0.21431435346603395,
360
+ -0.2938103020191193,
361
+ -0.7407654404640198,
362
+ -0.1693093843758106
363
+ ],
364
+ "q99": [
365
+ -0.011969150230289034,
366
+ 0.1981081753969192,
367
+ 0.14730184450745581,
368
+ 1.2670192122459407,
369
+ 0.3571772933006279,
370
+ 0.07727374359965306,
371
+ 0.24925321042537663
372
+ ]
373
+ },
374
+ "left_hand": {
375
+ "min": [
376
+ 0.0,
377
+ 0.0,
378
+ 0.0,
379
+ 0.0,
380
+ 0.0,
381
+ 0.0,
382
+ 0.0
383
+ ],
384
+ "max": [
385
+ 0.0,
386
+ 0.0,
387
+ 0.0,
388
+ 0.0,
389
+ 0.0,
390
+ 0.0,
391
+ 0.0
392
+ ],
393
+ "mean": [
394
+ 0.0,
395
+ 0.0,
396
+ 0.0,
397
+ 0.0,
398
+ 0.0,
399
+ 0.0,
400
+ 0.0
401
+ ],
402
+ "std": [
403
+ 0.0,
404
+ 0.0,
405
+ 0.0,
406
+ 0.0,
407
+ 0.0,
408
+ 0.0,
409
+ 0.0
410
+ ],
411
+ "q01": [
412
+ 0.0,
413
+ 0.0,
414
+ 0.0,
415
+ 0.0,
416
+ 0.0,
417
+ 0.0,
418
+ 0.0
419
+ ],
420
+ "q99": [
421
+ 0.0,
422
+ 0.0,
423
+ 0.0,
424
+ 0.0,
425
+ 0.0,
426
+ 0.0,
427
+ 0.0
428
+ ]
429
+ },
430
+ "right_hand": {
431
+ "min": [
432
+ -0.0,
433
+ -0.0,
434
+ -0.0,
435
+ -0.0,
436
+ -0.0,
437
+ -0.0,
438
+ -0.0
439
+ ],
440
+ "max": [
441
+ -0.0,
442
+ -0.0,
443
+ -0.0,
444
+ -0.0,
445
+ -0.0,
446
+ -0.0,
447
+ -0.0
448
+ ],
449
+ "mean": [
450
+ 0.0,
451
+ 0.0,
452
+ 0.0,
453
+ 0.0,
454
+ 0.0,
455
+ 0.0,
456
+ 0.0
457
+ ],
458
+ "std": [
459
+ 0.0,
460
+ 0.0,
461
+ 0.0,
462
+ 0.0,
463
+ 0.0,
464
+ 0.0,
465
+ 0.0
466
+ ],
467
+ "q01": [
468
+ 0.0,
469
+ 0.0,
470
+ 0.0,
471
+ 0.0,
472
+ 0.0,
473
+ 0.0,
474
+ 0.0
475
+ ],
476
+ "q99": [
477
+ -0.0,
478
+ -0.0,
479
+ -0.0,
480
+ -0.0,
481
+ -0.0,
482
+ -0.0,
483
+ -0.0
484
+ ]
485
+ },
486
+ "waist": {
487
+ "min": [
488
+ -0.03817012533545494,
489
+ -0.14767035841941833,
490
+ -0.09924878180027008
491
+ ],
492
+ "max": [
493
+ 0.05044477432966232,
494
+ 0.13773855566978455,
495
+ 0.10575182735919952
496
+ ],
497
+ "mean": [
498
+ 0.0021713885944336653,
499
+ -0.006043997593224049,
500
+ -0.0009960572933778167
501
+ ],
502
+ "std": [
503
+ 0.01315564289689064,
504
+ 0.04625461995601654,
505
+ 0.0275924950838089
506
+ ],
507
+ "q01": [
508
+ -0.02857382604852319,
509
+ -0.1123543307185173,
510
+ -0.09090777784585953
511
+ ],
512
+ "q99": [
513
+ 0.04313158672302961,
514
+ 0.1042894288897514,
515
+ 0.06339201703667638
516
+ ]
517
+ },
518
+ "base_height_command": {
519
+ "min": [
520
+ 0.6000000238418579
521
+ ],
522
+ "max": [
523
+ 0.75
524
+ ],
525
+ "mean": [
526
+ 0.7374278903007507
527
+ ],
528
+ "std": [
529
+ 0.039233911782502955
530
+ ],
531
+ "q01": [
532
+ 0.6000000238418579
533
+ ],
534
+ "q99": [
535
+ 0.75
536
+ ]
537
+ },
538
+ "navigate_command": {
539
+ "min": [
540
+ 0.0,
541
+ -0.12772086262702942,
542
+ -0.4000000059604645
543
+ ],
544
+ "max": [
545
+ 0.4000000059604645,
546
+ 0.15753206610679626,
547
+ 0.10000000149011612
548
+ ],
549
+ "mean": [
550
+ 0.10862857103347778,
551
+ 0.006709238979965448,
552
+ -0.08270397037267685
553
+ ],
554
+ "std": [
555
+ 0.17079046368598938,
556
+ 0.035745956003665924,
557
+ 0.1377689093351364
558
+ ],
559
+ "q01": [
560
+ 0.0,
561
+ -0.06209215875715017,
562
+ -0.4000000059604645
563
+ ],
564
+ "q99": [
565
+ 0.4000000059604645,
566
+ 0.10000000149011612,
567
+ 0.004937881324440136
568
+ ]
569
+ }
570
+ },
571
+ "relative_action": {}
572
+ }
573
+ }
checkpoint-15000/experiment_cfg/final_model_config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "Gr00tN1d6",
3
+ "model_dtype": "bfloat16",
4
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
5
+ "backbone_model_type": "eagle",
6
+ "model_revision": null,
7
+ "tune_top_llm_layers": 4,
8
+ "backbone_embedding_dim": 2048,
9
+ "tune_llm": false,
10
+ "tune_visual": true,
11
+ "select_layer": 16,
12
+ "reproject_vision": false,
13
+ "use_flash_attention": true,
14
+ "load_bf16": true,
15
+ "collator_overwrite_image_inputs": false,
16
+ "eagle_collator": true,
17
+ "backbone_trainable_params_fp32": true,
18
+ "apply_sincos_state_encoding": true,
19
+ "use_relative_action": true,
20
+ "max_state_dim": 128,
21
+ "max_action_dim": 128,
22
+ "action_horizon": 50,
23
+ "hidden_size": 1024,
24
+ "input_embedding_dim": 1536,
25
+ "add_pos_embed": true,
26
+ "attn_dropout": 0.2,
27
+ "use_vlln": true,
28
+ "max_seq_len": 1024,
29
+ "use_alternate_vl_dit": true,
30
+ "attend_text_every_n_blocks": 2,
31
+ "diffusion_model_cfg": {
32
+ "attention_head_dim": 48,
33
+ "dropout": 0.2,
34
+ "final_dropout": true,
35
+ "interleave_self_attention": true,
36
+ "norm_type": "ada_norm",
37
+ "num_attention_heads": 32,
38
+ "num_layers": 32,
39
+ "output_dim": 1024,
40
+ "positional_embeddings": null
41
+ },
42
+ "num_inference_timesteps": 4,
43
+ "noise_beta_alpha": 1.5,
44
+ "noise_beta_beta": 1.0,
45
+ "noise_s": 0.999,
46
+ "num_timestep_buckets": 1000,
47
+ "tune_projector": true,
48
+ "tune_diffusion_model": true,
49
+ "tune_vlln": true,
50
+ "state_dropout_prob": 0.0,
51
+ "state_additive_noise_scale": 0.0,
52
+ "max_num_embodiments": 32
53
+ }
checkpoint-15000/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step15000
checkpoint-15000/model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-15000/processor_config.json ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "Gr00tN1d6Processor",
3
+ "processor_kwargs": {
4
+ "modality_configs": {
5
+ "behavior_r1_pro": {
6
+ "video": {
7
+ "delta_indices": [
8
+ 0
9
+ ],
10
+ "modality_keys": [
11
+ "observation.images.rgb.head_256_256",
12
+ "observation.images.rgb.left_wrist_256_256",
13
+ "observation.images.rgb.right_wrist_256_256"
14
+ ],
15
+ "sin_cos_embedding_keys": null,
16
+ "mean_std_embedding_keys": null,
17
+ "action_configs": null
18
+ },
19
+ "state": {
20
+ "delta_indices": [
21
+ 0
22
+ ],
23
+ "modality_keys": [
24
+ "robot_pos",
25
+ "robot_ori_cos",
26
+ "robot_ori_sin",
27
+ "robot_2d_ori",
28
+ "robot_2d_ori_cos",
29
+ "robot_2d_ori_sin",
30
+ "robot_lin_vel",
31
+ "robot_ang_vel",
32
+ "arm_left_qpos",
33
+ "arm_left_qpos_sin",
34
+ "arm_left_qpos_cos",
35
+ "eef_left_pos",
36
+ "eef_left_quat",
37
+ "gripper_left_qpos",
38
+ "arm_right_qpos",
39
+ "arm_right_qpos_sin",
40
+ "arm_right_qpos_cos",
41
+ "eef_right_pos",
42
+ "eef_right_quat",
43
+ "gripper_right_qpos",
44
+ "trunk_qpos"
45
+ ],
46
+ "sin_cos_embedding_keys": null,
47
+ "mean_std_embedding_keys": null,
48
+ "action_configs": null
49
+ },
50
+ "action": {
51
+ "delta_indices": [
52
+ 0,
53
+ 1,
54
+ 2,
55
+ 3,
56
+ 4,
57
+ 5,
58
+ 6,
59
+ 7,
60
+ 8,
61
+ 9,
62
+ 10,
63
+ 11,
64
+ 12,
65
+ 13,
66
+ 14,
67
+ 15,
68
+ 16,
69
+ 17,
70
+ 18,
71
+ 19,
72
+ 20,
73
+ 21,
74
+ 22,
75
+ 23,
76
+ 24,
77
+ 25,
78
+ 26,
79
+ 27,
80
+ 28,
81
+ 29,
82
+ 30,
83
+ 31
84
+ ],
85
+ "modality_keys": [
86
+ "base",
87
+ "torso",
88
+ "left_arm",
89
+ "left_gripper",
90
+ "right_arm",
91
+ "right_gripper"
92
+ ],
93
+ "sin_cos_embedding_keys": null,
94
+ "mean_std_embedding_keys": null,
95
+ "action_configs": [
96
+ {
97
+ "rep": "ABSOLUTE",
98
+ "type": "NON_EEF",
99
+ "format": "DEFAULT",
100
+ "state_key": null
101
+ },
102
+ {
103
+ "rep": "RELATIVE",
104
+ "type": "NON_EEF",
105
+ "format": "DEFAULT",
106
+ "state_key": "trunk_qpos"
107
+ },
108
+ {
109
+ "rep": "RELATIVE",
110
+ "type": "NON_EEF",
111
+ "format": "DEFAULT",
112
+ "state_key": "arm_left_qpos"
113
+ },
114
+ {
115
+ "rep": "ABSOLUTE",
116
+ "type": "NON_EEF",
117
+ "format": "DEFAULT",
118
+ "state_key": null
119
+ },
120
+ {
121
+ "rep": "RELATIVE",
122
+ "type": "NON_EEF",
123
+ "format": "DEFAULT",
124
+ "state_key": "arm_right_qpos"
125
+ },
126
+ {
127
+ "rep": "ABSOLUTE",
128
+ "type": "NON_EEF",
129
+ "format": "DEFAULT",
130
+ "state_key": null
131
+ }
132
+ ]
133
+ },
134
+ "language": {
135
+ "delta_indices": [
136
+ 0
137
+ ],
138
+ "modality_keys": [
139
+ "annotation.human.coarse_action"
140
+ ],
141
+ "sin_cos_embedding_keys": null,
142
+ "mean_std_embedding_keys": null,
143
+ "action_configs": null
144
+ }
145
+ },
146
+ "gr1": {
147
+ "video": {
148
+ "delta_indices": [
149
+ 0
150
+ ],
151
+ "modality_keys": [
152
+ "ego_view_bg_crop_pad_res256_freq20"
153
+ ],
154
+ "sin_cos_embedding_keys": null,
155
+ "mean_std_embedding_keys": null,
156
+ "action_configs": null
157
+ },
158
+ "state": {
159
+ "delta_indices": [
160
+ 0
161
+ ],
162
+ "modality_keys": [
163
+ "left_arm",
164
+ "right_arm",
165
+ "left_hand",
166
+ "right_hand",
167
+ "waist"
168
+ ],
169
+ "sin_cos_embedding_keys": [
170
+ "left_arm",
171
+ "right_arm",
172
+ "left_hand",
173
+ "right_hand",
174
+ "waist"
175
+ ],
176
+ "mean_std_embedding_keys": null,
177
+ "action_configs": null
178
+ },
179
+ "action": {
180
+ "delta_indices": [
181
+ 0,
182
+ 1,
183
+ 2,
184
+ 3,
185
+ 4,
186
+ 5,
187
+ 6,
188
+ 7,
189
+ 8,
190
+ 9,
191
+ 10,
192
+ 11,
193
+ 12,
194
+ 13,
195
+ 14,
196
+ 15
197
+ ],
198
+ "modality_keys": [
199
+ "left_arm",
200
+ "right_arm",
201
+ "left_hand",
202
+ "right_hand",
203
+ "waist"
204
+ ],
205
+ "sin_cos_embedding_keys": null,
206
+ "mean_std_embedding_keys": null,
207
+ "action_configs": [
208
+ {
209
+ "rep": "RELATIVE",
210
+ "type": "NON_EEF",
211
+ "format": "DEFAULT",
212
+ "state_key": null
213
+ },
214
+ {
215
+ "rep": "RELATIVE",
216
+ "type": "NON_EEF",
217
+ "format": "DEFAULT",
218
+ "state_key": null
219
+ },
220
+ {
221
+ "rep": "RELATIVE",
222
+ "type": "NON_EEF",
223
+ "format": "DEFAULT",
224
+ "state_key": null
225
+ },
226
+ {
227
+ "rep": "RELATIVE",
228
+ "type": "NON_EEF",
229
+ "format": "DEFAULT",
230
+ "state_key": null
231
+ },
232
+ {
233
+ "rep": "ABSOLUTE",
234
+ "type": "NON_EEF",
235
+ "format": "DEFAULT",
236
+ "state_key": null
237
+ }
238
+ ]
239
+ },
240
+ "language": {
241
+ "delta_indices": [
242
+ 0
243
+ ],
244
+ "modality_keys": [
245
+ "task"
246
+ ],
247
+ "sin_cos_embedding_keys": null,
248
+ "mean_std_embedding_keys": null,
249
+ "action_configs": null
250
+ }
251
+ },
252
+ "robocasa_panda_omron": {
253
+ "video": {
254
+ "delta_indices": [
255
+ 0
256
+ ],
257
+ "modality_keys": [
258
+ "res256_image_side_0",
259
+ "res256_image_side_1",
260
+ "res256_image_wrist_0"
261
+ ],
262
+ "sin_cos_embedding_keys": null,
263
+ "mean_std_embedding_keys": null,
264
+ "action_configs": null
265
+ },
266
+ "state": {
267
+ "delta_indices": [
268
+ 0
269
+ ],
270
+ "modality_keys": [
271
+ "end_effector_position_relative",
272
+ "end_effector_rotation_relative",
273
+ "gripper_qpos",
274
+ "base_position",
275
+ "base_rotation"
276
+ ],
277
+ "sin_cos_embedding_keys": null,
278
+ "mean_std_embedding_keys": null,
279
+ "action_configs": null
280
+ },
281
+ "action": {
282
+ "delta_indices": [
283
+ 0,
284
+ 1,
285
+ 2,
286
+ 3,
287
+ 4,
288
+ 5,
289
+ 6,
290
+ 7,
291
+ 8,
292
+ 9,
293
+ 10,
294
+ 11,
295
+ 12,
296
+ 13,
297
+ 14,
298
+ 15
299
+ ],
300
+ "modality_keys": [
301
+ "end_effector_position",
302
+ "end_effector_rotation",
303
+ "gripper_close",
304
+ "base_motion",
305
+ "control_mode"
306
+ ],
307
+ "sin_cos_embedding_keys": null,
308
+ "mean_std_embedding_keys": null,
309
+ "action_configs": [
310
+ {
311
+ "rep": "ABSOLUTE",
312
+ "type": "NON_EEF",
313
+ "format": "DEFAULT",
314
+ "state_key": null
315
+ },
316
+ {
317
+ "rep": "ABSOLUTE",
318
+ "type": "NON_EEF",
319
+ "format": "DEFAULT",
320
+ "state_key": null
321
+ },
322
+ {
323
+ "rep": "ABSOLUTE",
324
+ "type": "NON_EEF",
325
+ "format": "DEFAULT",
326
+ "state_key": null
327
+ },
328
+ {
329
+ "rep": "ABSOLUTE",
330
+ "type": "NON_EEF",
331
+ "format": "DEFAULT",
332
+ "state_key": null
333
+ },
334
+ {
335
+ "rep": "ABSOLUTE",
336
+ "type": "NON_EEF",
337
+ "format": "DEFAULT",
338
+ "state_key": null
339
+ }
340
+ ]
341
+ },
342
+ "language": {
343
+ "delta_indices": [
344
+ 0
345
+ ],
346
+ "modality_keys": [
347
+ "annotation.human.action.task_description"
348
+ ],
349
+ "sin_cos_embedding_keys": null,
350
+ "mean_std_embedding_keys": null,
351
+ "action_configs": null
352
+ }
353
+ },
354
+ "new_embodiment": {
355
+ "video": {
356
+ "delta_indices": [
357
+ 0
358
+ ],
359
+ "modality_keys": [
360
+ "ego_view"
361
+ ],
362
+ "sin_cos_embedding_keys": null,
363
+ "mean_std_embedding_keys": null,
364
+ "action_configs": null
365
+ },
366
+ "state": {
367
+ "delta_indices": [
368
+ 0
369
+ ],
370
+ "modality_keys": [
371
+ "left_arm",
372
+ "right_arm",
373
+ "left_hand",
374
+ "right_hand",
375
+ "waist"
376
+ ],
377
+ "sin_cos_embedding_keys": null,
378
+ "mean_std_embedding_keys": null,
379
+ "action_configs": null
380
+ },
381
+ "action": {
382
+ "delta_indices": [
383
+ 0,
384
+ 1,
385
+ 2,
386
+ 3,
387
+ 4,
388
+ 5,
389
+ 6,
390
+ 7,
391
+ 8,
392
+ 9,
393
+ 10,
394
+ 11,
395
+ 12,
396
+ 13,
397
+ 14,
398
+ 15,
399
+ 16,
400
+ 17,
401
+ 18,
402
+ 19,
403
+ 20,
404
+ 21,
405
+ 22,
406
+ 23,
407
+ 24,
408
+ 25,
409
+ 26,
410
+ 27,
411
+ 28,
412
+ 29,
413
+ 30,
414
+ 31,
415
+ 32,
416
+ 33,
417
+ 34,
418
+ 35,
419
+ 36,
420
+ 37,
421
+ 38,
422
+ 39,
423
+ 40,
424
+ 41,
425
+ 42,
426
+ 43,
427
+ 44,
428
+ 45,
429
+ 46,
430
+ 47,
431
+ 48,
432
+ 49
433
+ ],
434
+ "modality_keys": [
435
+ "left_arm",
436
+ "right_arm",
437
+ "left_hand",
438
+ "right_hand",
439
+ "waist",
440
+ "base_height_command",
441
+ "navigate_command"
442
+ ],
443
+ "sin_cos_embedding_keys": null,
444
+ "mean_std_embedding_keys": null,
445
+ "action_configs": [
446
+ {
447
+ "rep": "ABSOLUTE",
448
+ "type": "NON_EEF",
449
+ "format": "DEFAULT",
450
+ "state_key": null
451
+ },
452
+ {
453
+ "rep": "ABSOLUTE",
454
+ "type": "NON_EEF",
455
+ "format": "DEFAULT",
456
+ "state_key": null
457
+ },
458
+ {
459
+ "rep": "ABSOLUTE",
460
+ "type": "NON_EEF",
461
+ "format": "DEFAULT",
462
+ "state_key": null
463
+ },
464
+ {
465
+ "rep": "ABSOLUTE",
466
+ "type": "NON_EEF",
467
+ "format": "DEFAULT",
468
+ "state_key": null
469
+ },
470
+ {
471
+ "rep": "ABSOLUTE",
472
+ "type": "NON_EEF",
473
+ "format": "DEFAULT",
474
+ "state_key": null
475
+ },
476
+ {
477
+ "rep": "ABSOLUTE",
478
+ "type": "NON_EEF",
479
+ "format": "DEFAULT",
480
+ "state_key": null
481
+ },
482
+ {
483
+ "rep": "ABSOLUTE",
484
+ "type": "NON_EEF",
485
+ "format": "DEFAULT",
486
+ "state_key": null
487
+ }
488
+ ]
489
+ },
490
+ "language": {
491
+ "delta_indices": [
492
+ 0
493
+ ],
494
+ "modality_keys": [
495
+ "annotation.human.task_description"
496
+ ],
497
+ "sin_cos_embedding_keys": null,
498
+ "mean_std_embedding_keys": null,
499
+ "action_configs": null
500
+ }
501
+ }
502
+ },
503
+ "image_crop_size": null,
504
+ "image_target_size": null,
505
+ "use_albumentations": true,
506
+ "random_rotation_angle": null,
507
+ "color_jitter_params": {
508
+ "brightness": 0.3,
509
+ "contrast": 0.4,
510
+ "saturation": 0.5,
511
+ "hue": 0.08
512
+ },
513
+ "shortest_image_edge": 256,
514
+ "crop_fraction": 0.95,
515
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
516
+ "model_type": "eagle",
517
+ "formalize_language": true,
518
+ "max_state_dim": 128,
519
+ "max_action_dim": 128,
520
+ "max_action_horizon": 50,
521
+ "use_percentiles": false,
522
+ "clip_outliers": true,
523
+ "apply_sincos_state_encoding": true,
524
+ "use_relative_action": true
525
+ }
526
+ }
checkpoint-15000/statistics.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-15000/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-15000/wandb_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"project": "finetune-gr00t-n1d6", "run_id": "locomanipulation_tutorial"}
checkpoint-20000/config.json ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_horizon": 50,
3
+ "add_pos_embed": true,
4
+ "apply_sincos_state_encoding": true,
5
+ "architectures": [
6
+ "Gr00tN1d6"
7
+ ],
8
+ "attn_dropout": 0.2,
9
+ "attn_implementation": null,
10
+ "backbone_embedding_dim": 2048,
11
+ "backbone_model_type": "eagle",
12
+ "backbone_trainable_params_fp32": true,
13
+ "collator_overwrite_image_inputs": false,
14
+ "color_jitter_params": {
15
+ "brightness": 0.1,
16
+ "contrast": 0.1,
17
+ "hue": 0.1,
18
+ "saturation": 0.1
19
+ },
20
+ "crop_fraction": 0.95,
21
+ "diffusion_model_cfg": {
22
+ "attention_head_dim": 48,
23
+ "dropout": 0.2,
24
+ "final_dropout": true,
25
+ "interleave_self_attention": true,
26
+ "norm_type": "ada_norm",
27
+ "num_attention_heads": 32,
28
+ "num_layers": 32,
29
+ "output_dim": 1024,
30
+ "positional_embeddings": null
31
+ },
32
+ "eagle_collator": true,
33
+ "formalize_language": true,
34
+ "gemma_collator": false,
35
+ "hidden_size": 1024,
36
+ "image_crop_size": null,
37
+ "image_target_size": null,
38
+ "input_embedding_dim": 1536,
39
+ "load_bf16": true,
40
+ "max_action_dim": 128,
41
+ "max_num_embodiments": 32,
42
+ "max_seq_len": 1024,
43
+ "max_state_dim": 128,
44
+ "model_dtype": "bfloat16",
45
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
46
+ "model_type": "Gr00tN1d6",
47
+ "noise_beta_alpha": 1.5,
48
+ "noise_beta_beta": 1.0,
49
+ "noise_s": 0.999,
50
+ "num_inference_timesteps": 4,
51
+ "num_timestep_buckets": 1000,
52
+ "random_rotation_angle": null,
53
+ "reproject_vision": false,
54
+ "select_layer": 16,
55
+ "shortest_image_edge": 256,
56
+ "state_dropout_prob": 0.0,
57
+ "torch_dtype": "bfloat16",
58
+ "transformers_version": "4.51.3",
59
+ "tune_diffusion_model": true,
60
+ "tune_llm": false,
61
+ "tune_projector": true,
62
+ "tune_top_llm_layers": 4,
63
+ "tune_visual": true,
64
+ "tune_vlln": true,
65
+ "use_albumentations_transforms": true,
66
+ "use_alternate_vl_dit": true,
67
+ "use_flash_attention": true,
68
+ "use_relative_action": true,
69
+ "use_vlln": true
70
+ }
checkpoint-20000/experiment_cfg/conf.yaml ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ load_config_path: null
2
+ model:
3
+ model_type: Gr00tN1d6
4
+ model_dtype: bfloat16
5
+ model_name: nvidia/Eagle-Block2A-2B-v2
6
+ backbone_model_type: eagle
7
+ model_revision: null
8
+ tune_top_llm_layers: 4
9
+ backbone_embedding_dim: 2048
10
+ tune_llm: false
11
+ tune_visual: true
12
+ select_layer: 16
13
+ reproject_vision: false
14
+ use_flash_attention: true
15
+ load_bf16: false
16
+ collator_overwrite_image_inputs: false
17
+ eagle_collator: true
18
+ backbone_trainable_params_fp32: true
19
+ image_crop_size: null
20
+ image_target_size: null
21
+ shortest_image_edge: 256
22
+ crop_fraction: 0.95
23
+ random_rotation_angle: null
24
+ color_jitter_params:
25
+ brightness: 0.3
26
+ contrast: 0.4
27
+ saturation: 0.5
28
+ hue: 0.08
29
+ use_albumentations_transforms: true
30
+ formalize_language: true
31
+ apply_sincos_state_encoding: false
32
+ use_relative_action: true
33
+ max_state_dim: 29
34
+ max_action_dim: 29
35
+ action_horizon: 16
36
+ hidden_size: 1024
37
+ input_embedding_dim: 1536
38
+ add_pos_embed: true
39
+ attn_dropout: 0.2
40
+ use_vlln: true
41
+ max_seq_len: 1024
42
+ use_alternate_vl_dit: true
43
+ attend_text_every_n_blocks: 2
44
+ diffusion_model_cfg:
45
+ positional_embeddings: null
46
+ num_layers: 32
47
+ num_attention_heads: 32
48
+ attention_head_dim: 48
49
+ norm_type: ada_norm
50
+ dropout: 0.2
51
+ final_dropout: true
52
+ output_dim: 1024
53
+ interleave_self_attention: true
54
+ num_inference_timesteps: 4
55
+ noise_beta_alpha: 1.5
56
+ noise_beta_beta: 1.0
57
+ noise_s: 0.999
58
+ num_timestep_buckets: 1000
59
+ tune_projector: true
60
+ tune_diffusion_model: true
61
+ tune_vlln: true
62
+ state_dropout_prob: 0.0
63
+ state_additive_noise_scale: 0.0
64
+ max_num_embodiments: 32
65
+ data:
66
+ datasets:
67
+ - dataset_paths:
68
+ - /datasets/isaaclab_arena/locomanipulation_tutorial/arena_g1_loco_manipulation_dataset_generated/lerobot
69
+ embodiment_tag: new_embodiment
70
+ mix_ratio: 1.0
71
+ dataset_type: physical_embodiment
72
+ val_dataset_path: null
73
+ modality_configs:
74
+ new_embodiment:
75
+ video:
76
+ delta_indices:
77
+ - 0
78
+ modality_keys:
79
+ - ego_view
80
+ sin_cos_embedding_keys: null
81
+ mean_std_embedding_keys: null
82
+ action_configs: null
83
+ state:
84
+ delta_indices:
85
+ - 0
86
+ modality_keys:
87
+ - left_arm
88
+ - right_arm
89
+ - left_hand
90
+ - right_hand
91
+ - waist
92
+ sin_cos_embedding_keys: null
93
+ mean_std_embedding_keys: null
94
+ action_configs: null
95
+ action:
96
+ delta_indices:
97
+ - 0
98
+ - 1
99
+ - 2
100
+ - 3
101
+ - 4
102
+ - 5
103
+ - 6
104
+ - 7
105
+ - 8
106
+ - 9
107
+ - 10
108
+ - 11
109
+ - 12
110
+ - 13
111
+ - 14
112
+ - 15
113
+ - 16
114
+ - 17
115
+ - 18
116
+ - 19
117
+ - 20
118
+ - 21
119
+ - 22
120
+ - 23
121
+ - 24
122
+ - 25
123
+ - 26
124
+ - 27
125
+ - 28
126
+ - 29
127
+ - 30
128
+ - 31
129
+ - 32
130
+ - 33
131
+ - 34
132
+ - 35
133
+ - 36
134
+ - 37
135
+ - 38
136
+ - 39
137
+ - 40
138
+ - 41
139
+ - 42
140
+ - 43
141
+ - 44
142
+ - 45
143
+ - 46
144
+ - 47
145
+ - 48
146
+ - 49
147
+ modality_keys:
148
+ - left_arm
149
+ - right_arm
150
+ - left_hand
151
+ - right_hand
152
+ - waist
153
+ - base_height_command
154
+ - navigate_command
155
+ sin_cos_embedding_keys: null
156
+ mean_std_embedding_keys: null
157
+ action_configs:
158
+ - rep: ABSOLUTE
159
+ type: NON_EEF
160
+ format: DEFAULT
161
+ state_key: null
162
+ - rep: ABSOLUTE
163
+ type: NON_EEF
164
+ format: DEFAULT
165
+ state_key: null
166
+ - rep: ABSOLUTE
167
+ type: NON_EEF
168
+ format: DEFAULT
169
+ state_key: null
170
+ - rep: ABSOLUTE
171
+ type: NON_EEF
172
+ format: DEFAULT
173
+ state_key: null
174
+ - rep: ABSOLUTE
175
+ type: NON_EEF
176
+ format: DEFAULT
177
+ state_key: null
178
+ - rep: ABSOLUTE
179
+ type: NON_EEF
180
+ format: DEFAULT
181
+ state_key: null
182
+ - rep: ABSOLUTE
183
+ type: NON_EEF
184
+ format: DEFAULT
185
+ state_key: null
186
+ language:
187
+ delta_indices:
188
+ - 0
189
+ modality_keys:
190
+ - annotation.human.task_description
191
+ sin_cos_embedding_keys: null
192
+ mean_std_embedding_keys: null
193
+ action_configs: null
194
+ download_cache: false
195
+ shard_size: 1024
196
+ episode_sampling_rate: 0.1
197
+ num_shards_per_epoch: 100000
198
+ override_pretraining_statistics: false
199
+ mode: single_turn
200
+ random_chop: 0.0
201
+ mock_dataset_mode: false
202
+ shuffle: true
203
+ seed: 42
204
+ multiprocessing_context: fork
205
+ allow_padding: false
206
+ subsample_ratio: 1.0
207
+ image_crop_size:
208
+ - 244
209
+ - 244
210
+ image_target_size:
211
+ - 224
212
+ - 224
213
+ video_backend: torchcodec
214
+ training:
215
+ output_dir: /models/isaaclab_arena/locomanipulation_tutorial
216
+ experiment_name: null
217
+ max_steps: 20000
218
+ global_batch_size: 192
219
+ batch_size: null
220
+ gradient_accumulation_steps: 1
221
+ learning_rate: 0.0001
222
+ lr_scheduler_type: cosine
223
+ weight_decay: 1.0e-05
224
+ warmup_ratio: 0.05
225
+ warmup_steps: 0
226
+ max_grad_norm: 1.0
227
+ optim: adamw_torch
228
+ start_from_checkpoint: nvidia/GR00T-N1.6-3B
229
+ tf32: true
230
+ fp16: false
231
+ bf16: true
232
+ eval_bf16: true
233
+ logging_steps: 10
234
+ save_steps: 5000
235
+ save_total_limit: 5
236
+ save_vl_model: false
237
+ upload_checkpoints: false
238
+ upload_every: 1000
239
+ upload_last_n_checkpoints: 5
240
+ max_concurrent_uploads: 2
241
+ eval_strategy: 'no'
242
+ eval_steps: 500
243
+ eval_set_split_ratio: 0.1
244
+ eval_batch_size: 2
245
+ save_best_eval_metric_name: ''
246
+ save_best_eval_metric_greater_is_better: true
247
+ deepspeed_stage: 2
248
+ gradient_checkpointing: false
249
+ transformers_trust_remote_code: true
250
+ transformers_local_files_only: false
251
+ transformers_cache_dir: null
252
+ transformers_access_token: null
253
+ use_ddp: false
254
+ ddp_bucket_cap_mb: 100
255
+ num_gpus: 8
256
+ dataloader_num_workers: 16
257
+ remove_unused_columns: false
258
+ use_wandb: false
259
+ wandb_project: finetune-gr00t-n1d6
260
+ enable_profiling: false
261
+ max_retries: 3
262
+ assert_loss_less_than: null
263
+ add_rl_callback: false
264
+ enable_open_loop_eval: false
265
+ open_loop_eval_traj_ids:
266
+ - 0
267
+ open_loop_eval_steps_per_traj: 100
268
+ open_loop_eval_plot_indices: null
269
+ max_steps: 20000
270
+ save_steps: 5000
checkpoint-20000/experiment_cfg/config.yaml ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ !!python/object:gr00t.configs.base_config.Config
2
+ data: !!python/object:gr00t.configs.data.data_config.DataConfig
3
+ allow_padding: false
4
+ datasets:
5
+ - !!python/object:gr00t.configs.data.data_config.SingleDatasetConfig
6
+ dataset_paths:
7
+ - /datasets/isaaclab_arena/locomanipulation_tutorial/arena_g1_loco_manipulation_dataset_generated/lerobot
8
+ dataset_type: physical_embodiment
9
+ embodiment_tag: new_embodiment
10
+ mix_ratio: 1.0
11
+ val_dataset_path: null
12
+ download_cache: false
13
+ episode_sampling_rate: 0.1
14
+ image_crop_size:
15
+ - 244
16
+ - 244
17
+ image_target_size:
18
+ - 224
19
+ - 224
20
+ mock_dataset_mode: false
21
+ modality_configs:
22
+ new_embodiment:
23
+ action: !!python/object:gr00t.data.types.ModalityConfig
24
+ action_configs:
25
+ - !!python/object:gr00t.data.types.ActionConfig
26
+ format: &id001 !!python/object/apply:gr00t.data.types.ActionFormat
27
+ - default
28
+ rep: &id002 !!python/object/apply:gr00t.data.types.ActionRepresentation
29
+ - absolute
30
+ state_key: null
31
+ type: &id003 !!python/object/apply:gr00t.data.types.ActionType
32
+ - non_eef
33
+ - !!python/object:gr00t.data.types.ActionConfig
34
+ format: *id001
35
+ rep: *id002
36
+ state_key: null
37
+ type: *id003
38
+ - !!python/object:gr00t.data.types.ActionConfig
39
+ format: *id001
40
+ rep: *id002
41
+ state_key: null
42
+ type: *id003
43
+ - !!python/object:gr00t.data.types.ActionConfig
44
+ format: *id001
45
+ rep: *id002
46
+ state_key: null
47
+ type: *id003
48
+ - !!python/object:gr00t.data.types.ActionConfig
49
+ format: *id001
50
+ rep: *id002
51
+ state_key: null
52
+ type: *id003
53
+ - !!python/object:gr00t.data.types.ActionConfig
54
+ format: *id001
55
+ rep: *id002
56
+ state_key: null
57
+ type: *id003
58
+ - !!python/object:gr00t.data.types.ActionConfig
59
+ format: *id001
60
+ rep: *id002
61
+ state_key: null
62
+ type: *id003
63
+ delta_indices:
64
+ - 0
65
+ - 1
66
+ - 2
67
+ - 3
68
+ - 4
69
+ - 5
70
+ - 6
71
+ - 7
72
+ - 8
73
+ - 9
74
+ - 10
75
+ - 11
76
+ - 12
77
+ - 13
78
+ - 14
79
+ - 15
80
+ - 16
81
+ - 17
82
+ - 18
83
+ - 19
84
+ - 20
85
+ - 21
86
+ - 22
87
+ - 23
88
+ - 24
89
+ - 25
90
+ - 26
91
+ - 27
92
+ - 28
93
+ - 29
94
+ - 30
95
+ - 31
96
+ - 32
97
+ - 33
98
+ - 34
99
+ - 35
100
+ - 36
101
+ - 37
102
+ - 38
103
+ - 39
104
+ - 40
105
+ - 41
106
+ - 42
107
+ - 43
108
+ - 44
109
+ - 45
110
+ - 46
111
+ - 47
112
+ - 48
113
+ - 49
114
+ mean_std_embedding_keys: null
115
+ modality_keys:
116
+ - left_arm
117
+ - right_arm
118
+ - left_hand
119
+ - right_hand
120
+ - waist
121
+ - base_height_command
122
+ - navigate_command
123
+ sin_cos_embedding_keys: null
124
+ language: !!python/object:gr00t.data.types.ModalityConfig
125
+ action_configs: null
126
+ delta_indices:
127
+ - 0
128
+ mean_std_embedding_keys: null
129
+ modality_keys:
130
+ - annotation.human.task_description
131
+ sin_cos_embedding_keys: null
132
+ state: !!python/object:gr00t.data.types.ModalityConfig
133
+ action_configs: null
134
+ delta_indices:
135
+ - 0
136
+ mean_std_embedding_keys: null
137
+ modality_keys:
138
+ - left_arm
139
+ - right_arm
140
+ - left_hand
141
+ - right_hand
142
+ - waist
143
+ sin_cos_embedding_keys: null
144
+ video: !!python/object:gr00t.data.types.ModalityConfig
145
+ action_configs: null
146
+ delta_indices:
147
+ - 0
148
+ mean_std_embedding_keys: null
149
+ modality_keys:
150
+ - ego_view
151
+ sin_cos_embedding_keys: null
152
+ mode: single_turn
153
+ multiprocessing_context: fork
154
+ num_shards_per_epoch: 100000
155
+ override_pretraining_statistics: false
156
+ random_chop: 0.0
157
+ seed: 42
158
+ shard_size: 1024
159
+ shuffle: true
160
+ subsample_ratio: 1.0
161
+ video_backend: torchcodec
162
+ load_config_path: null
163
+ model: !!python/object:gr00t.configs.model.gr00t_n1d6.Gr00tN1d6Config
164
+ _attn_implementation_autoset: false
165
+ _attn_implementation_internal: null
166
+ _commit_hash: null
167
+ _name_or_path: ''
168
+ add_cross_attention: false
169
+ architectures: null
170
+ backbone_model_type: eagle
171
+ backbone_trainable_params_fp32: true
172
+ bad_words_ids: null
173
+ begin_suppress_tokens: null
174
+ bos_token_id: null
175
+ chunk_size_feed_forward: 0
176
+ color_jitter_params:
177
+ brightness: 0.3
178
+ contrast: 0.4
179
+ hue: 0.08
180
+ saturation: 0.5
181
+ cross_attention_hidden_size: null
182
+ decoder_start_token_id: null
183
+ diffusion_model_cfg:
184
+ attention_head_dim: 48
185
+ dropout: 0.2
186
+ final_dropout: true
187
+ interleave_self_attention: true
188
+ norm_type: ada_norm
189
+ num_attention_heads: 32
190
+ num_layers: 32
191
+ output_dim: 1024
192
+ positional_embeddings: null
193
+ diversity_penalty: 0.0
194
+ do_sample: false
195
+ eagle_collator: true
196
+ early_stopping: false
197
+ encoder_no_repeat_ngram_size: 0
198
+ eos_token_id: null
199
+ exponential_decay_length_penalty: null
200
+ finetuning_task: null
201
+ forced_bos_token_id: null
202
+ forced_eos_token_id: null
203
+ id2label:
204
+ 0: LABEL_0
205
+ 1: LABEL_1
206
+ is_decoder: false
207
+ is_encoder_decoder: false
208
+ label2id:
209
+ LABEL_0: 0
210
+ LABEL_1: 1
211
+ length_penalty: 1.0
212
+ load_bf16: false
213
+ max_length: 20
214
+ min_length: 0
215
+ model_name: nvidia/Eagle-Block2A-2B-v2
216
+ no_repeat_ngram_size: 0
217
+ num_beam_groups: 1
218
+ num_beams: 1
219
+ num_return_sequences: 1
220
+ output_attentions: false
221
+ output_hidden_states: false
222
+ output_scores: false
223
+ pad_token_id: null
224
+ prefix: null
225
+ problem_type: null
226
+ pruned_heads: {}
227
+ random_rotation_angle: null
228
+ remove_invalid_values: false
229
+ repetition_penalty: 1.0
230
+ reproject_vision: false
231
+ return_dict: true
232
+ return_dict_in_generate: false
233
+ sep_token_id: null
234
+ state_dropout_prob: 0.0
235
+ suppress_tokens: null
236
+ task_specific_params: null
237
+ temperature: 1.0
238
+ tf_legacy_loss: false
239
+ tie_encoder_decoder: false
240
+ tie_word_embeddings: true
241
+ tokenizer_class: null
242
+ top_k: 50
243
+ top_p: 1.0
244
+ torch_dtype: null
245
+ torchscript: false
246
+ transformers_version: null
247
+ tune_diffusion_model: true
248
+ tune_llm: false
249
+ tune_projector: true
250
+ tune_visual: true
251
+ typical_p: 1.0
252
+ use_bfloat16: false
253
+ use_relative_action: true
254
+ training: !!python/object:gr00t.configs.training.training_config.TrainingConfig
255
+ add_rl_callback: false
256
+ assert_loss_less_than: null
257
+ batch_size: null
258
+ bf16: true
259
+ dataloader_num_workers: 16
260
+ ddp_bucket_cap_mb: 100
261
+ deepspeed_stage: 2
262
+ enable_open_loop_eval: false
263
+ enable_profiling: false
264
+ eval_batch_size: 2
265
+ eval_bf16: true
266
+ eval_set_split_ratio: 0.1
267
+ eval_steps: 500
268
+ eval_strategy: 'no'
269
+ experiment_name: null
270
+ fp16: false
271
+ global_batch_size: 192
272
+ gradient_accumulation_steps: 1
273
+ gradient_checkpointing: false
274
+ learning_rate: 0.0001
275
+ logging_steps: 10
276
+ lr_scheduler_type: cosine
277
+ max_concurrent_uploads: 2
278
+ max_grad_norm: 1.0
279
+ max_retries: 3
280
+ max_steps: 20000
281
+ num_gpus: 8
282
+ open_loop_eval_plot_indices: null
283
+ open_loop_eval_steps_per_traj: 100
284
+ open_loop_eval_traj_ids:
285
+ - 0
286
+ optim: adamw_torch
287
+ output_dir: /models/isaaclab_arena/locomanipulation_tutorial
288
+ remove_unused_columns: false
289
+ save_best_eval_metric_greater_is_better: true
290
+ save_best_eval_metric_name: ''
291
+ save_steps: 5000
292
+ save_total_limit: 5
293
+ save_vl_model: false
294
+ start_from_checkpoint: nvidia/GR00T-N1.6-3B
295
+ tf32: true
296
+ transformers_access_token: null
297
+ transformers_cache_dir: null
298
+ transformers_local_files_only: false
299
+ transformers_trust_remote_code: true
300
+ upload_checkpoints: false
301
+ upload_every: 1000
302
+ upload_last_n_checkpoints: 5
303
+ use_ddp: false
304
+ use_wandb: false
305
+ wandb_project: finetune-gr00t-n1d6
306
+ warmup_ratio: 0.05
307
+ warmup_steps: 0
308
+ weight_decay: 1.0e-05
checkpoint-20000/experiment_cfg/dataset_statistics.json ADDED
@@ -0,0 +1,573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "new_embodiment": {
3
+ "state": {
4
+ "left_arm": {
5
+ "min": [
6
+ -1.2616037130355835,
7
+ -0.29025015234947205,
8
+ -0.22703997790813446,
9
+ -0.3353549540042877,
10
+ -0.0829518586397171,
11
+ -0.8195276260375977,
12
+ -0.2688920795917511
13
+ ],
14
+ "max": [
15
+ 0.15299034118652344,
16
+ 0.4194548726081848,
17
+ 0.304278701543808,
18
+ 1.4247486591339111,
19
+ 0.751840353012085,
20
+ 0.6736590266227722,
21
+ 0.569625973701477
22
+ ],
23
+ "mean": [
24
+ -0.6218094229698181,
25
+ -0.03578367084264755,
26
+ 0.05471671372652054,
27
+ 0.3273524045944214,
28
+ 0.16905353963375092,
29
+ 0.1931331604719162,
30
+ 0.0418560616672039
31
+ ],
32
+ "std": [
33
+ 0.2542016804218292,
34
+ 0.08585234731435776,
35
+ 0.05442973971366882,
36
+ 0.3563520908355713,
37
+ 0.10547080636024475,
38
+ 0.21155740320682526,
39
+ 0.0815652459859848
40
+ ],
41
+ "q01": [
42
+ -1.0867726147174834,
43
+ -0.23316791355609895,
44
+ -0.06077688504010439,
45
+ -0.2531130000948906,
46
+ -0.025190447550266983,
47
+ -0.41234332919120786,
48
+ -0.14684838354587554
49
+ ],
50
+ "q99": [
51
+ 0.02166599538177228,
52
+ 0.16592777222394936,
53
+ 0.19437864869832985,
54
+ 1.3526465594768522,
55
+ 0.47515065073966933,
56
+ 0.6158077389001846,
57
+ 0.267849366366863
58
+ ]
59
+ },
60
+ "right_arm": {
61
+ "min": [
62
+ -0.9889344573020935,
63
+ -0.7240632772445679,
64
+ -0.4150152802467346,
65
+ -0.2197991907596588,
66
+ -0.44296473264694214,
67
+ -0.9651272296905518,
68
+ -0.4595109820365906
69
+ ],
70
+ "max": [
71
+ 0.15951132774353027,
72
+ 0.21149154007434845,
73
+ 0.13221219182014465,
74
+ 1.4304473400115967,
75
+ 0.6581774950027466,
76
+ 0.33145904541015625,
77
+ 0.42284855246543884
78
+ ],
79
+ "mean": [
80
+ -0.5138179659843445,
81
+ -0.07899317145347595,
82
+ -0.1299561709165573,
83
+ 0.40922680497169495,
84
+ 0.027388907968997955,
85
+ -0.0835803970694542,
86
+ 0.024336807429790497
87
+ ],
88
+ "std": [
89
+ 0.1910795420408249,
90
+ 0.10697221755981445,
91
+ 0.0633271336555481,
92
+ 0.2594990134239197,
93
+ 0.14704135060310364,
94
+ 0.15591612458229065,
95
+ 0.06830708682537079
96
+ ],
97
+ "q01": [
98
+ -0.83366958796978,
99
+ -0.38898577094078063,
100
+ -0.27746869176626204,
101
+ -0.12615955173969268,
102
+ -0.2731088250875473,
103
+ -0.6371771156787872,
104
+ -0.16048517003655433
105
+ ],
106
+ "q99": [
107
+ 0.019438467640429113,
108
+ 0.13264653384685496,
109
+ 0.03749443646520371,
110
+ 1.3000927805900555,
111
+ 0.3483726784586904,
112
+ 0.12948824167251569,
113
+ 0.168773318082094
114
+ ]
115
+ },
116
+ "left_hand": {
117
+ "min": [
118
+ -0.008645662106573582,
119
+ -0.0016571161104366183,
120
+ -0.008173327893018723,
121
+ -0.0033370573073625565,
122
+ -0.049815986305475235,
123
+ -0.13737092912197113,
124
+ -8.590802735852776e-09
125
+ ],
126
+ "max": [
127
+ 8.85741064848844e-06,
128
+ 1.4383874713530531e-06,
129
+ 7.31344407540746e-05,
130
+ 4.420346158440225e-05,
131
+ 0.026730380952358246,
132
+ 0.06749135255813599,
133
+ 0.004176338668912649
134
+ ],
135
+ "mean": [
136
+ -0.00045161443995311856,
137
+ -9.045441402122378e-05,
138
+ -0.0008751734858378768,
139
+ -0.00010305152682121843,
140
+ -0.0026190115604549646,
141
+ -0.0007728625205345452,
142
+ 3.4298220271011814e-05
143
+ ],
144
+ "std": [
145
+ 0.0010219421237707138,
146
+ 0.00011942393030039966,
147
+ 0.0011946671875193715,
148
+ 0.00021070965158287436,
149
+ 0.004766007885336876,
150
+ 0.008314870297908783,
151
+ 0.00020773601136170328
152
+ ],
153
+ "q01": [
154
+ -0.004614621866494417,
155
+ -0.0005385997559642419,
156
+ -0.004787646210752427,
157
+ -0.0012936698796693236,
158
+ -0.01875622048974037,
159
+ -0.03178232274949551,
160
+ -2.9993839079089924e-10
161
+ ],
162
+ "q99": [
163
+ 1.4417540605826582e-09,
164
+ -5.172329953229189e-10,
165
+ -2.493637962786175e-10,
166
+ -6.717705641756689e-10,
167
+ 0.008347299136221403,
168
+ 0.012830186681821834,
169
+ 0.0014548563922289215
170
+ ]
171
+ },
172
+ "right_hand": {
173
+ "min": [
174
+ -1.5373115047623287e-07,
175
+ -2.7022052151437492e-08,
176
+ -2.0592709915945306e-05,
177
+ -7.066118541843025e-06,
178
+ -0.03601590916514397,
179
+ -0.5857902765274048,
180
+ -0.3214021623134613
181
+ ],
182
+ "max": [
183
+ 0.006290650460869074,
184
+ 0.001731343101710081,
185
+ 0.017454728484153748,
186
+ 0.012643150985240936,
187
+ 0.09934248775243759,
188
+ 0.0994623526930809,
189
+ 3.1769886277288606e-08
190
+ ],
191
+ "mean": [
192
+ 0.00025306272436864674,
193
+ 5.4000069212634116e-05,
194
+ 0.0003351480991113931,
195
+ 0.0008108046022243798,
196
+ 0.0006079890299588442,
197
+ -0.006738435477018356,
198
+ -0.00452095502987504
199
+ ],
200
+ "std": [
201
+ 0.0006930792587809265,
202
+ 0.00016116801998578012,
203
+ 0.0007848768145777285,
204
+ 0.0014818455092608929,
205
+ 0.009566166438162327,
206
+ 0.05241963639855385,
207
+ 0.030341269448399544
208
+ ],
209
+ "q01": [
210
+ -1.1203826366656955e-09,
211
+ 5.471793157463268e-10,
212
+ -7.516792688289087e-10,
213
+ 1.7157600895600922e-10,
214
+ -0.008333299728110432,
215
+ -0.3553843080997467,
216
+ -0.20837910920381547
217
+ ],
218
+ "q99": [
219
+ 0.0038171554915606976,
220
+ 0.0008218895673053339,
221
+ 0.003914117161184549,
222
+ 0.005107918474823237,
223
+ 0.061319448240101194,
224
+ 0.009818258183076798,
225
+ 3.1323699190011206e-10
226
+ ]
227
+ },
228
+ "waist": {
229
+ "min": [
230
+ -0.04632357507944107,
231
+ -0.11110502481460571,
232
+ -0.036814406514167786
233
+ ],
234
+ "max": [
235
+ 0.0633544921875,
236
+ 0.11162503063678741,
237
+ 0.1282370686531067
238
+ ],
239
+ "mean": [
240
+ 0.002279821317642927,
241
+ -0.0016866918886080384,
242
+ 0.05629865825176239
243
+ ],
244
+ "std": [
245
+ 0.019741930067539215,
246
+ 0.04374425858259201,
247
+ 0.023172633722424507
248
+ ],
249
+ "q01": [
250
+ -0.039197818748652934,
251
+ -0.09254500381648541,
252
+ -0.020507800113409757
253
+ ],
254
+ "q99": [
255
+ 0.054476964659988844,
256
+ 0.09499521441757679,
257
+ 0.10415777899324889
258
+ ]
259
+ }
260
+ },
261
+ "action": {
262
+ "left_arm": {
263
+ "min": [
264
+ -1.348067283630371,
265
+ -0.3527751564979553,
266
+ -0.3787360191345215,
267
+ -0.625663697719574,
268
+ -0.09716995060443878,
269
+ -0.9718959331512451,
270
+ -0.41488397121429443
271
+ ],
272
+ "max": [
273
+ 0.1336316466331482,
274
+ 0.4716266393661499,
275
+ 0.30831149220466614,
276
+ 1.4016180038452148,
277
+ 0.9397326111793518,
278
+ 0.6476842761039734,
279
+ 0.8313083648681641
280
+ ],
281
+ "mean": [
282
+ -0.6952570080757141,
283
+ -0.0709061548113823,
284
+ -0.04288463667035103,
285
+ 0.2694568634033203,
286
+ 0.1649714857339859,
287
+ 0.13536368310451508,
288
+ -0.02554020844399929
289
+ ],
290
+ "std": [
291
+ 0.26363858580589294,
292
+ 0.10477105528116226,
293
+ 0.07000378519296646,
294
+ 0.3648890554904938,
295
+ 0.11654239892959595,
296
+ 0.2099701166152954,
297
+ 0.08394794911146164
298
+ ],
299
+ "q01": [
300
+ -1.1805148243904113,
301
+ -0.308816134929657,
302
+ -0.17785422429442405,
303
+ -0.3138654500246048,
304
+ -0.05110809002071619,
305
+ -0.4920081451535225,
306
+ -0.1742709159851074
307
+ ],
308
+ "q99": [
309
+ -0.008620778424665838,
310
+ 0.20248875990509888,
311
+ 0.17697372585535032,
312
+ 1.284248530864715,
313
+ 0.522044214606285,
314
+ 0.5478375405073164,
315
+ 0.24634651243686412
316
+ ]
317
+ },
318
+ "right_arm": {
319
+ "min": [
320
+ -1.0777442455291748,
321
+ -0.7950155735015869,
322
+ -0.4215357005596161,
323
+ -0.33741918206214905,
324
+ -0.5877293348312378,
325
+ -1.0788743495941162,
326
+ -0.573306679725647
327
+ ],
328
+ "max": [
329
+ 0.14458219707012177,
330
+ 0.31825390458106995,
331
+ 0.3697803318500519,
332
+ 1.4193015098571777,
333
+ 0.6486993432044983,
334
+ 0.28742435574531555,
335
+ 0.49852707982063293
336
+ ],
337
+ "mean": [
338
+ -0.604250967502594,
339
+ -0.0556945763528347,
340
+ -0.03765946254134178,
341
+ 0.30660828948020935,
342
+ 0.01742653176188469,
343
+ -0.16916987299919128,
344
+ 0.09518744796514511
345
+ ],
346
+ "std": [
347
+ 0.20923613011837006,
348
+ 0.12663093209266663,
349
+ 0.08735905587673187,
350
+ 0.2593192756175995,
351
+ 0.15945474803447723,
352
+ 0.16604292392730713,
353
+ 0.07976584881544113
354
+ ],
355
+ "q01": [
356
+ -0.9175809919834137,
357
+ -0.5007677406072617,
358
+ -0.21304122656583785,
359
+ -0.21431435346603395,
360
+ -0.2938103020191193,
361
+ -0.7407654404640198,
362
+ -0.1693093843758106
363
+ ],
364
+ "q99": [
365
+ -0.011969150230289034,
366
+ 0.1981081753969192,
367
+ 0.14730184450745581,
368
+ 1.2670192122459407,
369
+ 0.3571772933006279,
370
+ 0.07727374359965306,
371
+ 0.24925321042537663
372
+ ]
373
+ },
374
+ "left_hand": {
375
+ "min": [
376
+ 0.0,
377
+ 0.0,
378
+ 0.0,
379
+ 0.0,
380
+ 0.0,
381
+ 0.0,
382
+ 0.0
383
+ ],
384
+ "max": [
385
+ 0.0,
386
+ 0.0,
387
+ 0.0,
388
+ 0.0,
389
+ 0.0,
390
+ 0.0,
391
+ 0.0
392
+ ],
393
+ "mean": [
394
+ 0.0,
395
+ 0.0,
396
+ 0.0,
397
+ 0.0,
398
+ 0.0,
399
+ 0.0,
400
+ 0.0
401
+ ],
402
+ "std": [
403
+ 0.0,
404
+ 0.0,
405
+ 0.0,
406
+ 0.0,
407
+ 0.0,
408
+ 0.0,
409
+ 0.0
410
+ ],
411
+ "q01": [
412
+ 0.0,
413
+ 0.0,
414
+ 0.0,
415
+ 0.0,
416
+ 0.0,
417
+ 0.0,
418
+ 0.0
419
+ ],
420
+ "q99": [
421
+ 0.0,
422
+ 0.0,
423
+ 0.0,
424
+ 0.0,
425
+ 0.0,
426
+ 0.0,
427
+ 0.0
428
+ ]
429
+ },
430
+ "right_hand": {
431
+ "min": [
432
+ -0.0,
433
+ -0.0,
434
+ -0.0,
435
+ -0.0,
436
+ -0.0,
437
+ -0.0,
438
+ -0.0
439
+ ],
440
+ "max": [
441
+ -0.0,
442
+ -0.0,
443
+ -0.0,
444
+ -0.0,
445
+ -0.0,
446
+ -0.0,
447
+ -0.0
448
+ ],
449
+ "mean": [
450
+ 0.0,
451
+ 0.0,
452
+ 0.0,
453
+ 0.0,
454
+ 0.0,
455
+ 0.0,
456
+ 0.0
457
+ ],
458
+ "std": [
459
+ 0.0,
460
+ 0.0,
461
+ 0.0,
462
+ 0.0,
463
+ 0.0,
464
+ 0.0,
465
+ 0.0
466
+ ],
467
+ "q01": [
468
+ 0.0,
469
+ 0.0,
470
+ 0.0,
471
+ 0.0,
472
+ 0.0,
473
+ 0.0,
474
+ 0.0
475
+ ],
476
+ "q99": [
477
+ -0.0,
478
+ -0.0,
479
+ -0.0,
480
+ -0.0,
481
+ -0.0,
482
+ -0.0,
483
+ -0.0
484
+ ]
485
+ },
486
+ "waist": {
487
+ "min": [
488
+ -0.03817012533545494,
489
+ -0.14767035841941833,
490
+ -0.09924878180027008
491
+ ],
492
+ "max": [
493
+ 0.05044477432966232,
494
+ 0.13773855566978455,
495
+ 0.10575182735919952
496
+ ],
497
+ "mean": [
498
+ 0.0021713885944336653,
499
+ -0.006043997593224049,
500
+ -0.0009960572933778167
501
+ ],
502
+ "std": [
503
+ 0.01315564289689064,
504
+ 0.04625461995601654,
505
+ 0.0275924950838089
506
+ ],
507
+ "q01": [
508
+ -0.02857382604852319,
509
+ -0.1123543307185173,
510
+ -0.09090777784585953
511
+ ],
512
+ "q99": [
513
+ 0.04313158672302961,
514
+ 0.1042894288897514,
515
+ 0.06339201703667638
516
+ ]
517
+ },
518
+ "base_height_command": {
519
+ "min": [
520
+ 0.6000000238418579
521
+ ],
522
+ "max": [
523
+ 0.75
524
+ ],
525
+ "mean": [
526
+ 0.7374278903007507
527
+ ],
528
+ "std": [
529
+ 0.039233911782502955
530
+ ],
531
+ "q01": [
532
+ 0.6000000238418579
533
+ ],
534
+ "q99": [
535
+ 0.75
536
+ ]
537
+ },
538
+ "navigate_command": {
539
+ "min": [
540
+ 0.0,
541
+ -0.12772086262702942,
542
+ -0.4000000059604645
543
+ ],
544
+ "max": [
545
+ 0.4000000059604645,
546
+ 0.15753206610679626,
547
+ 0.10000000149011612
548
+ ],
549
+ "mean": [
550
+ 0.10862857103347778,
551
+ 0.006709238979965448,
552
+ -0.08270397037267685
553
+ ],
554
+ "std": [
555
+ 0.17079046368598938,
556
+ 0.035745956003665924,
557
+ 0.1377689093351364
558
+ ],
559
+ "q01": [
560
+ 0.0,
561
+ -0.06209215875715017,
562
+ -0.4000000059604645
563
+ ],
564
+ "q99": [
565
+ 0.4000000059604645,
566
+ 0.10000000149011612,
567
+ 0.004937881324440136
568
+ ]
569
+ }
570
+ },
571
+ "relative_action": {}
572
+ }
573
+ }
checkpoint-20000/experiment_cfg/final_model_config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "Gr00tN1d6",
3
+ "model_dtype": "bfloat16",
4
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
5
+ "backbone_model_type": "eagle",
6
+ "model_revision": null,
7
+ "tune_top_llm_layers": 4,
8
+ "backbone_embedding_dim": 2048,
9
+ "tune_llm": false,
10
+ "tune_visual": true,
11
+ "select_layer": 16,
12
+ "reproject_vision": false,
13
+ "use_flash_attention": true,
14
+ "load_bf16": true,
15
+ "collator_overwrite_image_inputs": false,
16
+ "eagle_collator": true,
17
+ "backbone_trainable_params_fp32": true,
18
+ "apply_sincos_state_encoding": true,
19
+ "use_relative_action": true,
20
+ "max_state_dim": 128,
21
+ "max_action_dim": 128,
22
+ "action_horizon": 50,
23
+ "hidden_size": 1024,
24
+ "input_embedding_dim": 1536,
25
+ "add_pos_embed": true,
26
+ "attn_dropout": 0.2,
27
+ "use_vlln": true,
28
+ "max_seq_len": 1024,
29
+ "use_alternate_vl_dit": true,
30
+ "attend_text_every_n_blocks": 2,
31
+ "diffusion_model_cfg": {
32
+ "attention_head_dim": 48,
33
+ "dropout": 0.2,
34
+ "final_dropout": true,
35
+ "interleave_self_attention": true,
36
+ "norm_type": "ada_norm",
37
+ "num_attention_heads": 32,
38
+ "num_layers": 32,
39
+ "output_dim": 1024,
40
+ "positional_embeddings": null
41
+ },
42
+ "num_inference_timesteps": 4,
43
+ "noise_beta_alpha": 1.5,
44
+ "noise_beta_beta": 1.0,
45
+ "noise_s": 0.999,
46
+ "num_timestep_buckets": 1000,
47
+ "tune_projector": true,
48
+ "tune_diffusion_model": true,
49
+ "tune_vlln": true,
50
+ "state_dropout_prob": 0.0,
51
+ "state_additive_noise_scale": 0.0,
52
+ "max_num_embodiments": 32
53
+ }
checkpoint-20000/experiment_cfg/final_processor_config.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-20000/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step20000
checkpoint-20000/model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-20000/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from ZeRO stage 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
547
+ Convert a pseudo tensor to a torch tensor by calling ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.items():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it is no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # a memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model``: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info("Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info("Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
checkpoint-5000/config.json ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_horizon": 50,
3
+ "add_pos_embed": true,
4
+ "apply_sincos_state_encoding": true,
5
+ "architectures": [
6
+ "Gr00tN1d6"
7
+ ],
8
+ "attn_dropout": 0.2,
9
+ "attn_implementation": null,
10
+ "backbone_embedding_dim": 2048,
11
+ "backbone_model_type": "eagle",
12
+ "backbone_trainable_params_fp32": true,
13
+ "collator_overwrite_image_inputs": false,
14
+ "color_jitter_params": {
15
+ "brightness": 0.1,
16
+ "contrast": 0.1,
17
+ "hue": 0.1,
18
+ "saturation": 0.1
19
+ },
20
+ "crop_fraction": 0.95,
21
+ "diffusion_model_cfg": {
22
+ "attention_head_dim": 48,
23
+ "dropout": 0.2,
24
+ "final_dropout": true,
25
+ "interleave_self_attention": true,
26
+ "norm_type": "ada_norm",
27
+ "num_attention_heads": 32,
28
+ "num_layers": 32,
29
+ "output_dim": 1024,
30
+ "positional_embeddings": null
31
+ },
32
+ "eagle_collator": true,
33
+ "formalize_language": true,
34
+ "gemma_collator": false,
35
+ "hidden_size": 1024,
36
+ "image_crop_size": null,
37
+ "image_target_size": null,
38
+ "input_embedding_dim": 1536,
39
+ "load_bf16": true,
40
+ "max_action_dim": 128,
41
+ "max_num_embodiments": 32,
42
+ "max_seq_len": 1024,
43
+ "max_state_dim": 128,
44
+ "model_dtype": "bfloat16",
45
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
46
+ "model_type": "Gr00tN1d6",
47
+ "noise_beta_alpha": 1.5,
48
+ "noise_beta_beta": 1.0,
49
+ "noise_s": 0.999,
50
+ "num_inference_timesteps": 4,
51
+ "num_timestep_buckets": 1000,
52
+ "random_rotation_angle": null,
53
+ "reproject_vision": false,
54
+ "select_layer": 16,
55
+ "shortest_image_edge": 256,
56
+ "state_dropout_prob": 0.0,
57
+ "torch_dtype": "bfloat16",
58
+ "transformers_version": "4.51.3",
59
+ "tune_diffusion_model": true,
60
+ "tune_llm": false,
61
+ "tune_projector": true,
62
+ "tune_top_llm_layers": 4,
63
+ "tune_visual": true,
64
+ "tune_vlln": true,
65
+ "use_albumentations_transforms": true,
66
+ "use_alternate_vl_dit": true,
67
+ "use_flash_attention": true,
68
+ "use_relative_action": true,
69
+ "use_vlln": true
70
+ }
checkpoint-5000/embodiment_id.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "robocasa_panda_omron": 13,
3
+ "gr1": 20,
4
+ "behavior_r1_pro": 24,
5
+ "unitree_g1": 8,
6
+ "oxe_google": 0,
7
+ "oxe_widowx": 1,
8
+ "libero_panda": 2,
9
+ "oxe_droid": 16,
10
+ "new_embodiment": 10
11
+ }
checkpoint-5000/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step5000
checkpoint-5000/model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-5000/processor_config.json ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "Gr00tN1d6Processor",
3
+ "processor_kwargs": {
4
+ "modality_configs": {
5
+ "behavior_r1_pro": {
6
+ "video": {
7
+ "delta_indices": [
8
+ 0
9
+ ],
10
+ "modality_keys": [
11
+ "observation.images.rgb.head_256_256",
12
+ "observation.images.rgb.left_wrist_256_256",
13
+ "observation.images.rgb.right_wrist_256_256"
14
+ ],
15
+ "sin_cos_embedding_keys": null,
16
+ "mean_std_embedding_keys": null,
17
+ "action_configs": null
18
+ },
19
+ "state": {
20
+ "delta_indices": [
21
+ 0
22
+ ],
23
+ "modality_keys": [
24
+ "robot_pos",
25
+ "robot_ori_cos",
26
+ "robot_ori_sin",
27
+ "robot_2d_ori",
28
+ "robot_2d_ori_cos",
29
+ "robot_2d_ori_sin",
30
+ "robot_lin_vel",
31
+ "robot_ang_vel",
32
+ "arm_left_qpos",
33
+ "arm_left_qpos_sin",
34
+ "arm_left_qpos_cos",
35
+ "eef_left_pos",
36
+ "eef_left_quat",
37
+ "gripper_left_qpos",
38
+ "arm_right_qpos",
39
+ "arm_right_qpos_sin",
40
+ "arm_right_qpos_cos",
41
+ "eef_right_pos",
42
+ "eef_right_quat",
43
+ "gripper_right_qpos",
44
+ "trunk_qpos"
45
+ ],
46
+ "sin_cos_embedding_keys": null,
47
+ "mean_std_embedding_keys": null,
48
+ "action_configs": null
49
+ },
50
+ "action": {
51
+ "delta_indices": [
52
+ 0,
53
+ 1,
54
+ 2,
55
+ 3,
56
+ 4,
57
+ 5,
58
+ 6,
59
+ 7,
60
+ 8,
61
+ 9,
62
+ 10,
63
+ 11,
64
+ 12,
65
+ 13,
66
+ 14,
67
+ 15,
68
+ 16,
69
+ 17,
70
+ 18,
71
+ 19,
72
+ 20,
73
+ 21,
74
+ 22,
75
+ 23,
76
+ 24,
77
+ 25,
78
+ 26,
79
+ 27,
80
+ 28,
81
+ 29,
82
+ 30,
83
+ 31
84
+ ],
85
+ "modality_keys": [
86
+ "base",
87
+ "torso",
88
+ "left_arm",
89
+ "left_gripper",
90
+ "right_arm",
91
+ "right_gripper"
92
+ ],
93
+ "sin_cos_embedding_keys": null,
94
+ "mean_std_embedding_keys": null,
95
+ "action_configs": [
96
+ {
97
+ "rep": "ABSOLUTE",
98
+ "type": "NON_EEF",
99
+ "format": "DEFAULT",
100
+ "state_key": null
101
+ },
102
+ {
103
+ "rep": "RELATIVE",
104
+ "type": "NON_EEF",
105
+ "format": "DEFAULT",
106
+ "state_key": "trunk_qpos"
107
+ },
108
+ {
109
+ "rep": "RELATIVE",
110
+ "type": "NON_EEF",
111
+ "format": "DEFAULT",
112
+ "state_key": "arm_left_qpos"
113
+ },
114
+ {
115
+ "rep": "ABSOLUTE",
116
+ "type": "NON_EEF",
117
+ "format": "DEFAULT",
118
+ "state_key": null
119
+ },
120
+ {
121
+ "rep": "RELATIVE",
122
+ "type": "NON_EEF",
123
+ "format": "DEFAULT",
124
+ "state_key": "arm_right_qpos"
125
+ },
126
+ {
127
+ "rep": "ABSOLUTE",
128
+ "type": "NON_EEF",
129
+ "format": "DEFAULT",
130
+ "state_key": null
131
+ }
132
+ ]
133
+ },
134
+ "language": {
135
+ "delta_indices": [
136
+ 0
137
+ ],
138
+ "modality_keys": [
139
+ "annotation.human.coarse_action"
140
+ ],
141
+ "sin_cos_embedding_keys": null,
142
+ "mean_std_embedding_keys": null,
143
+ "action_configs": null
144
+ }
145
+ },
146
+ "gr1": {
147
+ "video": {
148
+ "delta_indices": [
149
+ 0
150
+ ],
151
+ "modality_keys": [
152
+ "ego_view_bg_crop_pad_res256_freq20"
153
+ ],
154
+ "sin_cos_embedding_keys": null,
155
+ "mean_std_embedding_keys": null,
156
+ "action_configs": null
157
+ },
158
+ "state": {
159
+ "delta_indices": [
160
+ 0
161
+ ],
162
+ "modality_keys": [
163
+ "left_arm",
164
+ "right_arm",
165
+ "left_hand",
166
+ "right_hand",
167
+ "waist"
168
+ ],
169
+ "sin_cos_embedding_keys": [
170
+ "left_arm",
171
+ "right_arm",
172
+ "left_hand",
173
+ "right_hand",
174
+ "waist"
175
+ ],
176
+ "mean_std_embedding_keys": null,
177
+ "action_configs": null
178
+ },
179
+ "action": {
180
+ "delta_indices": [
181
+ 0,
182
+ 1,
183
+ 2,
184
+ 3,
185
+ 4,
186
+ 5,
187
+ 6,
188
+ 7,
189
+ 8,
190
+ 9,
191
+ 10,
192
+ 11,
193
+ 12,
194
+ 13,
195
+ 14,
196
+ 15
197
+ ],
198
+ "modality_keys": [
199
+ "left_arm",
200
+ "right_arm",
201
+ "left_hand",
202
+ "right_hand",
203
+ "waist"
204
+ ],
205
+ "sin_cos_embedding_keys": null,
206
+ "mean_std_embedding_keys": null,
207
+ "action_configs": [
208
+ {
209
+ "rep": "RELATIVE",
210
+ "type": "NON_EEF",
211
+ "format": "DEFAULT",
212
+ "state_key": null
213
+ },
214
+ {
215
+ "rep": "RELATIVE",
216
+ "type": "NON_EEF",
217
+ "format": "DEFAULT",
218
+ "state_key": null
219
+ },
220
+ {
221
+ "rep": "RELATIVE",
222
+ "type": "NON_EEF",
223
+ "format": "DEFAULT",
224
+ "state_key": null
225
+ },
226
+ {
227
+ "rep": "RELATIVE",
228
+ "type": "NON_EEF",
229
+ "format": "DEFAULT",
230
+ "state_key": null
231
+ },
232
+ {
233
+ "rep": "ABSOLUTE",
234
+ "type": "NON_EEF",
235
+ "format": "DEFAULT",
236
+ "state_key": null
237
+ }
238
+ ]
239
+ },
240
+ "language": {
241
+ "delta_indices": [
242
+ 0
243
+ ],
244
+ "modality_keys": [
245
+ "task"
246
+ ],
247
+ "sin_cos_embedding_keys": null,
248
+ "mean_std_embedding_keys": null,
249
+ "action_configs": null
250
+ }
251
+ },
252
+ "robocasa_panda_omron": {
253
+ "video": {
254
+ "delta_indices": [
255
+ 0
256
+ ],
257
+ "modality_keys": [
258
+ "res256_image_side_0",
259
+ "res256_image_side_1",
260
+ "res256_image_wrist_0"
261
+ ],
262
+ "sin_cos_embedding_keys": null,
263
+ "mean_std_embedding_keys": null,
264
+ "action_configs": null
265
+ },
266
+ "state": {
267
+ "delta_indices": [
268
+ 0
269
+ ],
270
+ "modality_keys": [
271
+ "end_effector_position_relative",
272
+ "end_effector_rotation_relative",
273
+ "gripper_qpos",
274
+ "base_position",
275
+ "base_rotation"
276
+ ],
277
+ "sin_cos_embedding_keys": null,
278
+ "mean_std_embedding_keys": null,
279
+ "action_configs": null
280
+ },
281
+ "action": {
282
+ "delta_indices": [
283
+ 0,
284
+ 1,
285
+ 2,
286
+ 3,
287
+ 4,
288
+ 5,
289
+ 6,
290
+ 7,
291
+ 8,
292
+ 9,
293
+ 10,
294
+ 11,
295
+ 12,
296
+ 13,
297
+ 14,
298
+ 15
299
+ ],
300
+ "modality_keys": [
301
+ "end_effector_position",
302
+ "end_effector_rotation",
303
+ "gripper_close",
304
+ "base_motion",
305
+ "control_mode"
306
+ ],
307
+ "sin_cos_embedding_keys": null,
308
+ "mean_std_embedding_keys": null,
309
+ "action_configs": [
310
+ {
311
+ "rep": "ABSOLUTE",
312
+ "type": "NON_EEF",
313
+ "format": "DEFAULT",
314
+ "state_key": null
315
+ },
316
+ {
317
+ "rep": "ABSOLUTE",
318
+ "type": "NON_EEF",
319
+ "format": "DEFAULT",
320
+ "state_key": null
321
+ },
322
+ {
323
+ "rep": "ABSOLUTE",
324
+ "type": "NON_EEF",
325
+ "format": "DEFAULT",
326
+ "state_key": null
327
+ },
328
+ {
329
+ "rep": "ABSOLUTE",
330
+ "type": "NON_EEF",
331
+ "format": "DEFAULT",
332
+ "state_key": null
333
+ },
334
+ {
335
+ "rep": "ABSOLUTE",
336
+ "type": "NON_EEF",
337
+ "format": "DEFAULT",
338
+ "state_key": null
339
+ }
340
+ ]
341
+ },
342
+ "language": {
343
+ "delta_indices": [
344
+ 0
345
+ ],
346
+ "modality_keys": [
347
+ "annotation.human.action.task_description"
348
+ ],
349
+ "sin_cos_embedding_keys": null,
350
+ "mean_std_embedding_keys": null,
351
+ "action_configs": null
352
+ }
353
+ },
354
+ "new_embodiment": {
355
+ "video": {
356
+ "delta_indices": [
357
+ 0
358
+ ],
359
+ "modality_keys": [
360
+ "ego_view"
361
+ ],
362
+ "sin_cos_embedding_keys": null,
363
+ "mean_std_embedding_keys": null,
364
+ "action_configs": null
365
+ },
366
+ "state": {
367
+ "delta_indices": [
368
+ 0
369
+ ],
370
+ "modality_keys": [
371
+ "left_arm",
372
+ "right_arm",
373
+ "left_hand",
374
+ "right_hand",
375
+ "waist"
376
+ ],
377
+ "sin_cos_embedding_keys": null,
378
+ "mean_std_embedding_keys": null,
379
+ "action_configs": null
380
+ },
381
+ "action": {
382
+ "delta_indices": [
383
+ 0,
384
+ 1,
385
+ 2,
386
+ 3,
387
+ 4,
388
+ 5,
389
+ 6,
390
+ 7,
391
+ 8,
392
+ 9,
393
+ 10,
394
+ 11,
395
+ 12,
396
+ 13,
397
+ 14,
398
+ 15,
399
+ 16,
400
+ 17,
401
+ 18,
402
+ 19,
403
+ 20,
404
+ 21,
405
+ 22,
406
+ 23,
407
+ 24,
408
+ 25,
409
+ 26,
410
+ 27,
411
+ 28,
412
+ 29,
413
+ 30,
414
+ 31,
415
+ 32,
416
+ 33,
417
+ 34,
418
+ 35,
419
+ 36,
420
+ 37,
421
+ 38,
422
+ 39,
423
+ 40,
424
+ 41,
425
+ 42,
426
+ 43,
427
+ 44,
428
+ 45,
429
+ 46,
430
+ 47,
431
+ 48,
432
+ 49
433
+ ],
434
+ "modality_keys": [
435
+ "left_arm",
436
+ "right_arm",
437
+ "left_hand",
438
+ "right_hand",
439
+ "waist",
440
+ "base_height_command",
441
+ "navigate_command"
442
+ ],
443
+ "sin_cos_embedding_keys": null,
444
+ "mean_std_embedding_keys": null,
445
+ "action_configs": [
446
+ {
447
+ "rep": "ABSOLUTE",
448
+ "type": "NON_EEF",
449
+ "format": "DEFAULT",
450
+ "state_key": null
451
+ },
452
+ {
453
+ "rep": "ABSOLUTE",
454
+ "type": "NON_EEF",
455
+ "format": "DEFAULT",
456
+ "state_key": null
457
+ },
458
+ {
459
+ "rep": "ABSOLUTE",
460
+ "type": "NON_EEF",
461
+ "format": "DEFAULT",
462
+ "state_key": null
463
+ },
464
+ {
465
+ "rep": "ABSOLUTE",
466
+ "type": "NON_EEF",
467
+ "format": "DEFAULT",
468
+ "state_key": null
469
+ },
470
+ {
471
+ "rep": "ABSOLUTE",
472
+ "type": "NON_EEF",
473
+ "format": "DEFAULT",
474
+ "state_key": null
475
+ },
476
+ {
477
+ "rep": "ABSOLUTE",
478
+ "type": "NON_EEF",
479
+ "format": "DEFAULT",
480
+ "state_key": null
481
+ },
482
+ {
483
+ "rep": "ABSOLUTE",
484
+ "type": "NON_EEF",
485
+ "format": "DEFAULT",
486
+ "state_key": null
487
+ }
488
+ ]
489
+ },
490
+ "language": {
491
+ "delta_indices": [
492
+ 0
493
+ ],
494
+ "modality_keys": [
495
+ "annotation.human.task_description"
496
+ ],
497
+ "sin_cos_embedding_keys": null,
498
+ "mean_std_embedding_keys": null,
499
+ "action_configs": null
500
+ }
501
+ }
502
+ },
503
+ "image_crop_size": null,
504
+ "image_target_size": null,
505
+ "use_albumentations": true,
506
+ "random_rotation_angle": null,
507
+ "color_jitter_params": {
508
+ "brightness": 0.3,
509
+ "contrast": 0.4,
510
+ "saturation": 0.5,
511
+ "hue": 0.08
512
+ },
513
+ "shortest_image_edge": 256,
514
+ "crop_fraction": 0.95,
515
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
516
+ "model_type": "eagle",
517
+ "formalize_language": true,
518
+ "max_state_dim": 128,
519
+ "max_action_dim": 128,
520
+ "max_action_horizon": 50,
521
+ "use_percentiles": false,
522
+ "clip_outliers": true,
523
+ "apply_sincos_state_encoding": true,
524
+ "use_relative_action": true
525
+ }
526
+ }
checkpoint-5000/statistics.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-5000/trainer_state.json ADDED
@@ -0,0 +1,3034 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 0.25,
6
+ "eval_steps": 500,
7
+ "global_step": 5000,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "grad_norm": 1.3594739437103271,
14
+ "learning_rate": 9e-07,
15
+ "loss": 1.1913,
16
+ "step": 10
17
+ },
18
+ {
19
+ "grad_norm": 1.0572824478149414,
20
+ "learning_rate": 1.9e-06,
21
+ "loss": 1.1841,
22
+ "step": 20
23
+ },
24
+ {
25
+ "grad_norm": 0.5717663764953613,
26
+ "learning_rate": 2.9e-06,
27
+ "loss": 1.1508,
28
+ "step": 30
29
+ },
30
+ {
31
+ "grad_norm": 0.3898443877696991,
32
+ "learning_rate": 3.9e-06,
33
+ "loss": 1.1205,
34
+ "step": 40
35
+ },
36
+ {
37
+ "grad_norm": 0.28664326667785645,
38
+ "learning_rate": 4.9000000000000005e-06,
39
+ "loss": 1.0888,
40
+ "step": 50
41
+ },
42
+ {
43
+ "grad_norm": 0.1729290783405304,
44
+ "learning_rate": 5.9e-06,
45
+ "loss": 1.0782,
46
+ "step": 60
47
+ },
48
+ {
49
+ "grad_norm": 0.17002208530902863,
50
+ "learning_rate": 6.900000000000001e-06,
51
+ "loss": 1.0691,
52
+ "step": 70
53
+ },
54
+ {
55
+ "grad_norm": 0.2152942717075348,
56
+ "learning_rate": 7.9e-06,
57
+ "loss": 1.0562,
58
+ "step": 80
59
+ },
60
+ {
61
+ "grad_norm": 0.19103780388832092,
62
+ "learning_rate": 8.9e-06,
63
+ "loss": 1.0479,
64
+ "step": 90
65
+ },
66
+ {
67
+ "grad_norm": 0.3243984878063202,
68
+ "learning_rate": 9.900000000000002e-06,
69
+ "loss": 1.0372,
70
+ "step": 100
71
+ },
72
+ {
73
+ "grad_norm": 0.1820673942565918,
74
+ "learning_rate": 1.09e-05,
75
+ "loss": 1.0272,
76
+ "step": 110
77
+ },
78
+ {
79
+ "grad_norm": 0.21819084882736206,
80
+ "learning_rate": 1.19e-05,
81
+ "loss": 1.0236,
82
+ "step": 120
83
+ },
84
+ {
85
+ "grad_norm": 0.20377595722675323,
86
+ "learning_rate": 1.29e-05,
87
+ "loss": 1.0237,
88
+ "step": 130
89
+ },
90
+ {
91
+ "grad_norm": 0.20572194457054138,
92
+ "learning_rate": 1.3900000000000002e-05,
93
+ "loss": 1.0228,
94
+ "step": 140
95
+ },
96
+ {
97
+ "grad_norm": 0.20157840847969055,
98
+ "learning_rate": 1.49e-05,
99
+ "loss": 1.0217,
100
+ "step": 150
101
+ },
102
+ {
103
+ "grad_norm": 0.23459017276763916,
104
+ "learning_rate": 1.59e-05,
105
+ "loss": 1.0192,
106
+ "step": 160
107
+ },
108
+ {
109
+ "grad_norm": 0.32469043135643005,
110
+ "learning_rate": 1.69e-05,
111
+ "loss": 1.0063,
112
+ "step": 170
113
+ },
114
+ {
115
+ "grad_norm": 0.36008527874946594,
116
+ "learning_rate": 1.79e-05,
117
+ "loss": 0.9873,
118
+ "step": 180
119
+ },
120
+ {
121
+ "grad_norm": 0.5633573532104492,
122
+ "learning_rate": 1.8900000000000002e-05,
123
+ "loss": 0.9672,
124
+ "step": 190
125
+ },
126
+ {
127
+ "grad_norm": 0.7019369006156921,
128
+ "learning_rate": 1.9900000000000003e-05,
129
+ "loss": 0.9315,
130
+ "step": 200
131
+ },
132
+ {
133
+ "grad_norm": 0.5538105964660645,
134
+ "learning_rate": 2.09e-05,
135
+ "loss": 0.8958,
136
+ "step": 210
137
+ },
138
+ {
139
+ "grad_norm": 0.5306029319763184,
140
+ "learning_rate": 2.19e-05,
141
+ "loss": 0.8707,
142
+ "step": 220
143
+ },
144
+ {
145
+ "grad_norm": 0.6606974005699158,
146
+ "learning_rate": 2.29e-05,
147
+ "loss": 0.8479,
148
+ "step": 230
149
+ },
150
+ {
151
+ "grad_norm": 0.8058410882949829,
152
+ "learning_rate": 2.39e-05,
153
+ "loss": 0.8169,
154
+ "step": 240
155
+ },
156
+ {
157
+ "grad_norm": 0.7277475595474243,
158
+ "learning_rate": 2.4900000000000002e-05,
159
+ "loss": 0.77,
160
+ "step": 250
161
+ },
162
+ {
163
+ "grad_norm": 0.6617355942726135,
164
+ "learning_rate": 2.5900000000000003e-05,
165
+ "loss": 0.7456,
166
+ "step": 260
167
+ },
168
+ {
169
+ "grad_norm": 0.8156651258468628,
170
+ "learning_rate": 2.6900000000000003e-05,
171
+ "loss": 0.6984,
172
+ "step": 270
173
+ },
174
+ {
175
+ "grad_norm": 0.7090954780578613,
176
+ "learning_rate": 2.7900000000000004e-05,
177
+ "loss": 0.6774,
178
+ "step": 280
179
+ },
180
+ {
181
+ "grad_norm": 0.8667084574699402,
182
+ "learning_rate": 2.8899999999999998e-05,
183
+ "loss": 0.6429,
184
+ "step": 290
185
+ },
186
+ {
187
+ "grad_norm": 0.946596622467041,
188
+ "learning_rate": 2.9900000000000002e-05,
189
+ "loss": 0.6052,
190
+ "step": 300
191
+ },
192
+ {
193
+ "grad_norm": 0.8120863437652588,
194
+ "learning_rate": 3.09e-05,
195
+ "loss": 0.5681,
196
+ "step": 310
197
+ },
198
+ {
199
+ "grad_norm": 0.9630921483039856,
200
+ "learning_rate": 3.19e-05,
201
+ "loss": 0.5267,
202
+ "step": 320
203
+ },
204
+ {
205
+ "grad_norm": 0.9185823798179626,
206
+ "learning_rate": 3.29e-05,
207
+ "loss": 0.497,
208
+ "step": 330
209
+ },
210
+ {
211
+ "grad_norm": 0.9909350872039795,
212
+ "learning_rate": 3.3900000000000004e-05,
213
+ "loss": 0.4704,
214
+ "step": 340
215
+ },
216
+ {
217
+ "grad_norm": 0.7408623695373535,
218
+ "learning_rate": 3.49e-05,
219
+ "loss": 0.4463,
220
+ "step": 350
221
+ },
222
+ {
223
+ "grad_norm": 0.8417967557907104,
224
+ "learning_rate": 3.59e-05,
225
+ "loss": 0.4515,
226
+ "step": 360
227
+ },
228
+ {
229
+ "grad_norm": 0.9200495481491089,
230
+ "learning_rate": 3.69e-05,
231
+ "loss": 0.417,
232
+ "step": 370
233
+ },
234
+ {
235
+ "grad_norm": 1.146302342414856,
236
+ "learning_rate": 3.79e-05,
237
+ "loss": 0.3937,
238
+ "step": 380
239
+ },
240
+ {
241
+ "grad_norm": 1.0057293176651,
242
+ "learning_rate": 3.8900000000000004e-05,
243
+ "loss": 0.3773,
244
+ "step": 390
245
+ },
246
+ {
247
+ "grad_norm": 1.112216591835022,
248
+ "learning_rate": 3.99e-05,
249
+ "loss": 0.348,
250
+ "step": 400
251
+ },
252
+ {
253
+ "grad_norm": 1.0176512002944946,
254
+ "learning_rate": 4.09e-05,
255
+ "loss": 0.3392,
256
+ "step": 410
257
+ },
258
+ {
259
+ "grad_norm": 1.0310163497924805,
260
+ "learning_rate": 4.19e-05,
261
+ "loss": 0.3065,
262
+ "step": 420
263
+ },
264
+ {
265
+ "grad_norm": 1.022374153137207,
266
+ "learning_rate": 4.29e-05,
267
+ "loss": 0.2808,
268
+ "step": 430
269
+ },
270
+ {
271
+ "grad_norm": 1.368080735206604,
272
+ "learning_rate": 4.39e-05,
273
+ "loss": 0.2624,
274
+ "step": 440
275
+ },
276
+ {
277
+ "grad_norm": 1.1092591285705566,
278
+ "learning_rate": 4.49e-05,
279
+ "loss": 0.2405,
280
+ "step": 450
281
+ },
282
+ {
283
+ "grad_norm": 0.9738430380821228,
284
+ "learning_rate": 4.5900000000000004e-05,
285
+ "loss": 0.2254,
286
+ "step": 460
287
+ },
288
+ {
289
+ "grad_norm": 1.033246636390686,
290
+ "learning_rate": 4.69e-05,
291
+ "loss": 0.2162,
292
+ "step": 470
293
+ },
294
+ {
295
+ "grad_norm": 0.9855560064315796,
296
+ "learning_rate": 4.79e-05,
297
+ "loss": 0.2088,
298
+ "step": 480
299
+ },
300
+ {
301
+ "grad_norm": 1.0313360691070557,
302
+ "learning_rate": 4.89e-05,
303
+ "loss": 0.2188,
304
+ "step": 490
305
+ },
306
+ {
307
+ "grad_norm": 1.100176215171814,
308
+ "learning_rate": 4.99e-05,
309
+ "loss": 0.2007,
310
+ "step": 500
311
+ },
312
+ {
313
+ "grad_norm": 1.0784265995025635,
314
+ "learning_rate": 5.0900000000000004e-05,
315
+ "loss": 0.2016,
316
+ "step": 510
317
+ },
318
+ {
319
+ "grad_norm": 1.0822303295135498,
320
+ "learning_rate": 5.19e-05,
321
+ "loss": 0.1961,
322
+ "step": 520
323
+ },
324
+ {
325
+ "grad_norm": 1.067589282989502,
326
+ "learning_rate": 5.2900000000000005e-05,
327
+ "loss": 0.1801,
328
+ "step": 530
329
+ },
330
+ {
331
+ "grad_norm": 1.1917147636413574,
332
+ "learning_rate": 5.390000000000001e-05,
333
+ "loss": 0.1705,
334
+ "step": 540
335
+ },
336
+ {
337
+ "grad_norm": 1.3141072988510132,
338
+ "learning_rate": 5.4900000000000006e-05,
339
+ "loss": 0.1851,
340
+ "step": 550
341
+ },
342
+ {
343
+ "grad_norm": 1.002855658531189,
344
+ "learning_rate": 5.590000000000001e-05,
345
+ "loss": 0.1663,
346
+ "step": 560
347
+ },
348
+ {
349
+ "grad_norm": 1.167011022567749,
350
+ "learning_rate": 5.69e-05,
351
+ "loss": 0.1741,
352
+ "step": 570
353
+ },
354
+ {
355
+ "grad_norm": 1.0936863422393799,
356
+ "learning_rate": 5.79e-05,
357
+ "loss": 0.1661,
358
+ "step": 580
359
+ },
360
+ {
361
+ "grad_norm": 0.9669778347015381,
362
+ "learning_rate": 5.89e-05,
363
+ "loss": 0.1648,
364
+ "step": 590
365
+ },
366
+ {
367
+ "grad_norm": 0.9405611753463745,
368
+ "learning_rate": 5.99e-05,
369
+ "loss": 0.1627,
370
+ "step": 600
371
+ },
372
+ {
373
+ "grad_norm": 1.0284767150878906,
374
+ "learning_rate": 6.09e-05,
375
+ "loss": 0.1496,
376
+ "step": 610
377
+ },
378
+ {
379
+ "grad_norm": 1.1097605228424072,
380
+ "learning_rate": 6.19e-05,
381
+ "loss": 0.1628,
382
+ "step": 620
383
+ },
384
+ {
385
+ "grad_norm": 0.9104214310646057,
386
+ "learning_rate": 6.29e-05,
387
+ "loss": 0.1302,
388
+ "step": 630
389
+ },
390
+ {
391
+ "grad_norm": 0.8578998446464539,
392
+ "learning_rate": 6.390000000000001e-05,
393
+ "loss": 0.1326,
394
+ "step": 640
395
+ },
396
+ {
397
+ "grad_norm": 1.1287304162979126,
398
+ "learning_rate": 6.49e-05,
399
+ "loss": 0.1127,
400
+ "step": 650
401
+ },
402
+ {
403
+ "grad_norm": 0.8655268549919128,
404
+ "learning_rate": 6.59e-05,
405
+ "loss": 0.1202,
406
+ "step": 660
407
+ },
408
+ {
409
+ "grad_norm": 0.9937160015106201,
410
+ "learning_rate": 6.690000000000001e-05,
411
+ "loss": 0.1198,
412
+ "step": 670
413
+ },
414
+ {
415
+ "grad_norm": 0.9691420197486877,
416
+ "learning_rate": 6.790000000000001e-05,
417
+ "loss": 0.1096,
418
+ "step": 680
419
+ },
420
+ {
421
+ "grad_norm": 1.0945252180099487,
422
+ "learning_rate": 6.89e-05,
423
+ "loss": 0.105,
424
+ "step": 690
425
+ },
426
+ {
427
+ "grad_norm": 1.0388752222061157,
428
+ "learning_rate": 6.99e-05,
429
+ "loss": 0.1027,
430
+ "step": 700
431
+ },
432
+ {
433
+ "grad_norm": 0.881949245929718,
434
+ "learning_rate": 7.09e-05,
435
+ "loss": 0.1044,
436
+ "step": 710
437
+ },
438
+ {
439
+ "grad_norm": 0.8678519129753113,
440
+ "learning_rate": 7.19e-05,
441
+ "loss": 0.0842,
442
+ "step": 720
443
+ },
444
+ {
445
+ "grad_norm": 1.2314260005950928,
446
+ "learning_rate": 7.29e-05,
447
+ "loss": 0.0841,
448
+ "step": 730
449
+ },
450
+ {
451
+ "grad_norm": 0.7337191700935364,
452
+ "learning_rate": 7.390000000000001e-05,
453
+ "loss": 0.0771,
454
+ "step": 740
455
+ },
456
+ {
457
+ "grad_norm": 1.194354772567749,
458
+ "learning_rate": 7.49e-05,
459
+ "loss": 0.0791,
460
+ "step": 750
461
+ },
462
+ {
463
+ "grad_norm": 1.0703870058059692,
464
+ "learning_rate": 7.59e-05,
465
+ "loss": 0.0697,
466
+ "step": 760
467
+ },
468
+ {
469
+ "grad_norm": 0.9820927977561951,
470
+ "learning_rate": 7.69e-05,
471
+ "loss": 0.0798,
472
+ "step": 770
473
+ },
474
+ {
475
+ "grad_norm": 1.099042534828186,
476
+ "learning_rate": 7.790000000000001e-05,
477
+ "loss": 0.0736,
478
+ "step": 780
479
+ },
480
+ {
481
+ "grad_norm": 0.9056155681610107,
482
+ "learning_rate": 7.890000000000001e-05,
483
+ "loss": 0.0756,
484
+ "step": 790
485
+ },
486
+ {
487
+ "grad_norm": 0.8292648792266846,
488
+ "learning_rate": 7.99e-05,
489
+ "loss": 0.0796,
490
+ "step": 800
491
+ },
492
+ {
493
+ "grad_norm": 0.9507290720939636,
494
+ "learning_rate": 8.090000000000001e-05,
495
+ "loss": 0.0829,
496
+ "step": 810
497
+ },
498
+ {
499
+ "grad_norm": 0.9466397762298584,
500
+ "learning_rate": 8.19e-05,
501
+ "loss": 0.0688,
502
+ "step": 820
503
+ },
504
+ {
505
+ "grad_norm": 0.7956731915473938,
506
+ "learning_rate": 8.29e-05,
507
+ "loss": 0.0747,
508
+ "step": 830
509
+ },
510
+ {
511
+ "grad_norm": 0.7995853424072266,
512
+ "learning_rate": 8.39e-05,
513
+ "loss": 0.0634,
514
+ "step": 840
515
+ },
516
+ {
517
+ "grad_norm": 0.7665478587150574,
518
+ "learning_rate": 8.49e-05,
519
+ "loss": 0.0661,
520
+ "step": 850
521
+ },
522
+ {
523
+ "grad_norm": 0.9283880591392517,
524
+ "learning_rate": 8.59e-05,
525
+ "loss": 0.0702,
526
+ "step": 860
527
+ },
528
+ {
529
+ "grad_norm": 1.126967191696167,
530
+ "learning_rate": 8.69e-05,
531
+ "loss": 0.0716,
532
+ "step": 870
533
+ },
534
+ {
535
+ "grad_norm": 0.8662194609642029,
536
+ "learning_rate": 8.790000000000001e-05,
537
+ "loss": 0.0667,
538
+ "step": 880
539
+ },
540
+ {
541
+ "grad_norm": 0.9572857022285461,
542
+ "learning_rate": 8.89e-05,
543
+ "loss": 0.0791,
544
+ "step": 890
545
+ },
546
+ {
547
+ "grad_norm": 0.9036967158317566,
548
+ "learning_rate": 8.99e-05,
549
+ "loss": 0.0745,
550
+ "step": 900
551
+ },
552
+ {
553
+ "grad_norm": 0.7550048828125,
554
+ "learning_rate": 9.090000000000001e-05,
555
+ "loss": 0.0746,
556
+ "step": 910
557
+ },
558
+ {
559
+ "grad_norm": 0.9990408420562744,
560
+ "learning_rate": 9.190000000000001e-05,
561
+ "loss": 0.0648,
562
+ "step": 920
563
+ },
564
+ {
565
+ "grad_norm": 0.8286410570144653,
566
+ "learning_rate": 9.290000000000001e-05,
567
+ "loss": 0.0697,
568
+ "step": 930
569
+ },
570
+ {
571
+ "grad_norm": 0.9783310890197754,
572
+ "learning_rate": 9.39e-05,
573
+ "loss": 0.0749,
574
+ "step": 940
575
+ },
576
+ {
577
+ "grad_norm": 0.9899768233299255,
578
+ "learning_rate": 9.49e-05,
579
+ "loss": 0.0722,
580
+ "step": 950
581
+ },
582
+ {
583
+ "grad_norm": 0.7450554370880127,
584
+ "learning_rate": 9.59e-05,
585
+ "loss": 0.0599,
586
+ "step": 960
587
+ },
588
+ {
589
+ "grad_norm": 0.7791635394096375,
590
+ "learning_rate": 9.69e-05,
591
+ "loss": 0.0654,
592
+ "step": 970
593
+ },
594
+ {
595
+ "grad_norm": 0.7614015340805054,
596
+ "learning_rate": 9.790000000000001e-05,
597
+ "loss": 0.0558,
598
+ "step": 980
599
+ },
600
+ {
601
+ "grad_norm": 0.9096309542655945,
602
+ "learning_rate": 9.89e-05,
603
+ "loss": 0.0581,
604
+ "step": 990
605
+ },
606
+ {
607
+ "grad_norm": 0.668950080871582,
608
+ "learning_rate": 9.99e-05,
609
+ "loss": 0.0652,
610
+ "step": 1000
611
+ },
612
+ {
613
+ "grad_norm": 0.8658283948898315,
614
+ "learning_rate": 9.999994463727085e-05,
615
+ "loss": 0.0529,
616
+ "step": 1010
617
+ },
618
+ {
619
+ "grad_norm": 0.7495288848876953,
620
+ "learning_rate": 9.999975326009292e-05,
621
+ "loss": 0.059,
622
+ "step": 1020
623
+ },
624
+ {
625
+ "grad_norm": 0.9980189204216003,
626
+ "learning_rate": 9.999942518549879e-05,
627
+ "loss": 0.0638,
628
+ "step": 1030
629
+ },
630
+ {
631
+ "grad_norm": 0.7826606035232544,
632
+ "learning_rate": 9.999896041438544e-05,
633
+ "loss": 0.0546,
634
+ "step": 1040
635
+ },
636
+ {
637
+ "grad_norm": 0.6360778212547302,
638
+ "learning_rate": 9.999835894802353e-05,
639
+ "loss": 0.054,
640
+ "step": 1050
641
+ },
642
+ {
643
+ "grad_norm": 0.7757160067558289,
644
+ "learning_rate": 9.999762078805743e-05,
645
+ "loss": 0.0591,
646
+ "step": 1060
647
+ },
648
+ {
649
+ "grad_norm": 0.7390689849853516,
650
+ "learning_rate": 9.999674593650526e-05,
651
+ "loss": 0.0595,
652
+ "step": 1070
653
+ },
654
+ {
655
+ "grad_norm": 0.6460424065589905,
656
+ "learning_rate": 9.99957343957588e-05,
657
+ "loss": 0.0658,
658
+ "step": 1080
659
+ },
660
+ {
661
+ "grad_norm": 0.8082983493804932,
662
+ "learning_rate": 9.99945861685836e-05,
663
+ "loss": 0.0596,
664
+ "step": 1090
665
+ },
666
+ {
667
+ "grad_norm": 0.7415626645088196,
668
+ "learning_rate": 9.999330125811884e-05,
669
+ "loss": 0.0483,
670
+ "step": 1100
671
+ },
672
+ {
673
+ "grad_norm": 0.8829818367958069,
674
+ "learning_rate": 9.999187966787744e-05,
675
+ "loss": 0.0619,
676
+ "step": 1110
677
+ },
678
+ {
679
+ "grad_norm": 0.8239393830299377,
680
+ "learning_rate": 9.999032140174595e-05,
681
+ "loss": 0.0528,
682
+ "step": 1120
683
+ },
684
+ {
685
+ "grad_norm": 0.8529507517814636,
686
+ "learning_rate": 9.998862646398464e-05,
687
+ "loss": 0.0654,
688
+ "step": 1130
689
+ },
690
+ {
691
+ "grad_norm": 0.7502208948135376,
692
+ "learning_rate": 9.998679485922739e-05,
693
+ "loss": 0.0526,
694
+ "step": 1140
695
+ },
696
+ {
697
+ "grad_norm": 0.6970030069351196,
698
+ "learning_rate": 9.998482659248174e-05,
699
+ "loss": 0.0547,
700
+ "step": 1150
701
+ },
702
+ {
703
+ "grad_norm": 0.9376399517059326,
704
+ "learning_rate": 9.998272166912883e-05,
705
+ "loss": 0.0557,
706
+ "step": 1160
707
+ },
708
+ {
709
+ "grad_norm": 0.7249330282211304,
710
+ "learning_rate": 9.998048009492347e-05,
711
+ "loss": 0.0504,
712
+ "step": 1170
713
+ },
714
+ {
715
+ "grad_norm": 0.8968970775604248,
716
+ "learning_rate": 9.997810187599403e-05,
717
+ "loss": 0.0526,
718
+ "step": 1180
719
+ },
720
+ {
721
+ "grad_norm": 0.7676458358764648,
722
+ "learning_rate": 9.997558701884249e-05,
723
+ "loss": 0.0506,
724
+ "step": 1190
725
+ },
726
+ {
727
+ "grad_norm": 0.6501711010932922,
728
+ "learning_rate": 9.997293553034433e-05,
729
+ "loss": 0.061,
730
+ "step": 1200
731
+ },
732
+ {
733
+ "grad_norm": 0.677116870880127,
734
+ "learning_rate": 9.997014741774866e-05,
735
+ "loss": 0.0462,
736
+ "step": 1210
737
+ },
738
+ {
739
+ "grad_norm": 0.8147766590118408,
740
+ "learning_rate": 9.996722268867803e-05,
741
+ "loss": 0.0486,
742
+ "step": 1220
743
+ },
744
+ {
745
+ "grad_norm": 0.706069827079773,
746
+ "learning_rate": 9.996416135112858e-05,
747
+ "loss": 0.0511,
748
+ "step": 1230
749
+ },
750
+ {
751
+ "grad_norm": 0.6159539818763733,
752
+ "learning_rate": 9.996096341346988e-05,
753
+ "loss": 0.0492,
754
+ "step": 1240
755
+ },
756
+ {
757
+ "grad_norm": 0.6369336843490601,
758
+ "learning_rate": 9.995762888444495e-05,
759
+ "loss": 0.0479,
760
+ "step": 1250
761
+ },
762
+ {
763
+ "grad_norm": 0.7543830275535583,
764
+ "learning_rate": 9.995415777317027e-05,
765
+ "loss": 0.0493,
766
+ "step": 1260
767
+ },
768
+ {
769
+ "grad_norm": 0.7505154609680176,
770
+ "learning_rate": 9.995055008913574e-05,
771
+ "loss": 0.053,
772
+ "step": 1270
773
+ },
774
+ {
775
+ "grad_norm": 0.5397493243217468,
776
+ "learning_rate": 9.994680584220463e-05,
777
+ "loss": 0.0432,
778
+ "step": 1280
779
+ },
780
+ {
781
+ "grad_norm": 0.6707198619842529,
782
+ "learning_rate": 9.994292504261355e-05,
783
+ "loss": 0.0472,
784
+ "step": 1290
785
+ },
786
+ {
787
+ "grad_norm": 0.8792182803153992,
788
+ "learning_rate": 9.993890770097247e-05,
789
+ "loss": 0.0453,
790
+ "step": 1300
791
+ },
792
+ {
793
+ "grad_norm": 0.7324561476707458,
794
+ "learning_rate": 9.993475382826467e-05,
795
+ "loss": 0.0479,
796
+ "step": 1310
797
+ },
798
+ {
799
+ "grad_norm": 0.8385289907455444,
800
+ "learning_rate": 9.993046343584664e-05,
801
+ "loss": 0.0549,
802
+ "step": 1320
803
+ },
804
+ {
805
+ "grad_norm": 0.5908923745155334,
806
+ "learning_rate": 9.992603653544816e-05,
807
+ "loss": 0.0483,
808
+ "step": 1330
809
+ },
810
+ {
811
+ "grad_norm": 0.63700932264328,
812
+ "learning_rate": 9.992147313917222e-05,
813
+ "loss": 0.0485,
814
+ "step": 1340
815
+ },
816
+ {
817
+ "grad_norm": 0.7525864839553833,
818
+ "learning_rate": 9.991677325949497e-05,
819
+ "loss": 0.0469,
820
+ "step": 1350
821
+ },
822
+ {
823
+ "grad_norm": 0.5628486275672913,
824
+ "learning_rate": 9.991193690926568e-05,
825
+ "loss": 0.0459,
826
+ "step": 1360
827
+ },
828
+ {
829
+ "grad_norm": 0.795554518699646,
830
+ "learning_rate": 9.990696410170678e-05,
831
+ "loss": 0.0467,
832
+ "step": 1370
833
+ },
834
+ {
835
+ "grad_norm": 0.7957155704498291,
836
+ "learning_rate": 9.990185485041371e-05,
837
+ "loss": 0.0481,
838
+ "step": 1380
839
+ },
840
+ {
841
+ "grad_norm": 0.5773254632949829,
842
+ "learning_rate": 9.989660916935498e-05,
843
+ "loss": 0.0471,
844
+ "step": 1390
845
+ },
846
+ {
847
+ "grad_norm": 0.6150880455970764,
848
+ "learning_rate": 9.989122707287208e-05,
849
+ "loss": 0.0426,
850
+ "step": 1400
851
+ },
852
+ {
853
+ "grad_norm": 0.7106145620346069,
854
+ "learning_rate": 9.988570857567945e-05,
855
+ "loss": 0.0537,
856
+ "step": 1410
857
+ },
858
+ {
859
+ "grad_norm": 0.9491516947746277,
860
+ "learning_rate": 9.988005369286446e-05,
861
+ "loss": 0.0525,
862
+ "step": 1420
863
+ },
864
+ {
865
+ "grad_norm": 0.6860232353210449,
866
+ "learning_rate": 9.987426243988734e-05,
867
+ "loss": 0.0429,
868
+ "step": 1430
869
+ },
870
+ {
871
+ "grad_norm": 0.7841853499412537,
872
+ "learning_rate": 9.986833483258114e-05,
873
+ "loss": 0.0524,
874
+ "step": 1440
875
+ },
876
+ {
877
+ "grad_norm": 0.6175568103790283,
878
+ "learning_rate": 9.986227088715173e-05,
879
+ "loss": 0.0385,
880
+ "step": 1450
881
+ },
882
+ {
883
+ "grad_norm": 0.5932314991950989,
884
+ "learning_rate": 9.98560706201777e-05,
885
+ "loss": 0.0408,
886
+ "step": 1460
887
+ },
888
+ {
889
+ "grad_norm": 0.7410153150558472,
890
+ "learning_rate": 9.984973404861036e-05,
891
+ "loss": 0.043,
892
+ "step": 1470
893
+ },
894
+ {
895
+ "grad_norm": 0.8330276608467102,
896
+ "learning_rate": 9.984326118977361e-05,
897
+ "loss": 0.051,
898
+ "step": 1480
899
+ },
900
+ {
901
+ "grad_norm": 0.7202706933021545,
902
+ "learning_rate": 9.983665206136406e-05,
903
+ "loss": 0.0493,
904
+ "step": 1490
905
+ },
906
+ {
907
+ "grad_norm": 0.574433445930481,
908
+ "learning_rate": 9.982990668145075e-05,
909
+ "loss": 0.0466,
910
+ "step": 1500
911
+ },
912
+ {
913
+ "grad_norm": 0.7351802587509155,
914
+ "learning_rate": 9.982302506847534e-05,
915
+ "loss": 0.057,
916
+ "step": 1510
917
+ },
918
+ {
919
+ "grad_norm": 0.819564163684845,
920
+ "learning_rate": 9.981600724125189e-05,
921
+ "loss": 0.0555,
922
+ "step": 1520
923
+ },
924
+ {
925
+ "grad_norm": 0.6065496206283569,
926
+ "learning_rate": 9.980885321896685e-05,
927
+ "loss": 0.0509,
928
+ "step": 1530
929
+ },
930
+ {
931
+ "grad_norm": 0.6572223901748657,
932
+ "learning_rate": 9.980156302117905e-05,
933
+ "loss": 0.044,
934
+ "step": 1540
935
+ },
936
+ {
937
+ "grad_norm": 0.6978927254676819,
938
+ "learning_rate": 9.979413666781963e-05,
939
+ "loss": 0.0465,
940
+ "step": 1550
941
+ },
942
+ {
943
+ "grad_norm": 0.5508580803871155,
944
+ "learning_rate": 9.978657417919193e-05,
945
+ "loss": 0.0452,
946
+ "step": 1560
947
+ },
948
+ {
949
+ "grad_norm": 0.5769541263580322,
950
+ "learning_rate": 9.977887557597153e-05,
951
+ "loss": 0.0475,
952
+ "step": 1570
953
+ },
954
+ {
955
+ "grad_norm": 0.5610742568969727,
956
+ "learning_rate": 9.97710408792061e-05,
957
+ "loss": 0.0469,
958
+ "step": 1580
959
+ },
960
+ {
961
+ "grad_norm": 0.5692776441574097,
962
+ "learning_rate": 9.976307011031542e-05,
963
+ "loss": 0.0449,
964
+ "step": 1590
965
+ },
966
+ {
967
+ "grad_norm": 0.5226185321807861,
968
+ "learning_rate": 9.975496329109126e-05,
969
+ "loss": 0.0476,
970
+ "step": 1600
971
+ },
972
+ {
973
+ "grad_norm": 0.7111744284629822,
974
+ "learning_rate": 9.974672044369732e-05,
975
+ "loss": 0.047,
976
+ "step": 1610
977
+ },
978
+ {
979
+ "grad_norm": 0.514858067035675,
980
+ "learning_rate": 9.97383415906693e-05,
981
+ "loss": 0.043,
982
+ "step": 1620
983
+ },
984
+ {
985
+ "grad_norm": 0.5856963396072388,
986
+ "learning_rate": 9.97298267549146e-05,
987
+ "loss": 0.0471,
988
+ "step": 1630
989
+ },
990
+ {
991
+ "grad_norm": 0.6191436052322388,
992
+ "learning_rate": 9.972117595971249e-05,
993
+ "loss": 0.0422,
994
+ "step": 1640
995
+ },
996
+ {
997
+ "grad_norm": 0.5670982599258423,
998
+ "learning_rate": 9.971238922871391e-05,
999
+ "loss": 0.0419,
1000
+ "step": 1650
1001
+ },
1002
+ {
1003
+ "grad_norm": 0.7190003991127014,
1004
+ "learning_rate": 9.970346658594142e-05,
1005
+ "loss": 0.0453,
1006
+ "step": 1660
1007
+ },
1008
+ {
1009
+ "grad_norm": 0.6552428007125854,
1010
+ "learning_rate": 9.969440805578923e-05,
1011
+ "loss": 0.046,
1012
+ "step": 1670
1013
+ },
1014
+ {
1015
+ "grad_norm": 0.578118622303009,
1016
+ "learning_rate": 9.968521366302298e-05,
1017
+ "loss": 0.0392,
1018
+ "step": 1680
1019
+ },
1020
+ {
1021
+ "grad_norm": 0.7054030895233154,
1022
+ "learning_rate": 9.967588343277981e-05,
1023
+ "loss": 0.0455,
1024
+ "step": 1690
1025
+ },
1026
+ {
1027
+ "grad_norm": 0.6531293392181396,
1028
+ "learning_rate": 9.966641739056818e-05,
1029
+ "loss": 0.0421,
1030
+ "step": 1700
1031
+ },
1032
+ {
1033
+ "grad_norm": 0.6111751198768616,
1034
+ "learning_rate": 9.965681556226793e-05,
1035
+ "loss": 0.0517,
1036
+ "step": 1710
1037
+ },
1038
+ {
1039
+ "grad_norm": 0.4928556978702545,
1040
+ "learning_rate": 9.964707797413006e-05,
1041
+ "loss": 0.044,
1042
+ "step": 1720
1043
+ },
1044
+ {
1045
+ "grad_norm": 0.6597058773040771,
1046
+ "learning_rate": 9.963720465277679e-05,
1047
+ "loss": 0.047,
1048
+ "step": 1730
1049
+ },
1050
+ {
1051
+ "grad_norm": 0.6202155351638794,
1052
+ "learning_rate": 9.96271956252014e-05,
1053
+ "loss": 0.0384,
1054
+ "step": 1740
1055
+ },
1056
+ {
1057
+ "grad_norm": 0.5262959599494934,
1058
+ "learning_rate": 9.961705091876816e-05,
1059
+ "loss": 0.0425,
1060
+ "step": 1750
1061
+ },
1062
+ {
1063
+ "grad_norm": 0.6935763955116272,
1064
+ "learning_rate": 9.960677056121235e-05,
1065
+ "loss": 0.0409,
1066
+ "step": 1760
1067
+ },
1068
+ {
1069
+ "grad_norm": 0.6149827837944031,
1070
+ "learning_rate": 9.959635458064005e-05,
1071
+ "loss": 0.0383,
1072
+ "step": 1770
1073
+ },
1074
+ {
1075
+ "grad_norm": 0.5901826024055481,
1076
+ "learning_rate": 9.958580300552815e-05,
1077
+ "loss": 0.0426,
1078
+ "step": 1780
1079
+ },
1080
+ {
1081
+ "grad_norm": 0.5597098469734192,
1082
+ "learning_rate": 9.957511586472426e-05,
1083
+ "loss": 0.0352,
1084
+ "step": 1790
1085
+ },
1086
+ {
1087
+ "grad_norm": 0.5581690073013306,
1088
+ "learning_rate": 9.956429318744662e-05,
1089
+ "loss": 0.0366,
1090
+ "step": 1800
1091
+ },
1092
+ {
1093
+ "grad_norm": 0.5969916582107544,
1094
+ "learning_rate": 9.955333500328404e-05,
1095
+ "loss": 0.0355,
1096
+ "step": 1810
1097
+ },
1098
+ {
1099
+ "grad_norm": 0.5474916696548462,
1100
+ "learning_rate": 9.95422413421957e-05,
1101
+ "loss": 0.0376,
1102
+ "step": 1820
1103
+ },
1104
+ {
1105
+ "grad_norm": 0.5651562809944153,
1106
+ "learning_rate": 9.953101223451133e-05,
1107
+ "loss": 0.0359,
1108
+ "step": 1830
1109
+ },
1110
+ {
1111
+ "grad_norm": 0.6243921518325806,
1112
+ "learning_rate": 9.951964771093085e-05,
1113
+ "loss": 0.0373,
1114
+ "step": 1840
1115
+ },
1116
+ {
1117
+ "grad_norm": 0.4624647796154022,
1118
+ "learning_rate": 9.950814780252442e-05,
1119
+ "loss": 0.0347,
1120
+ "step": 1850
1121
+ },
1122
+ {
1123
+ "grad_norm": 0.5893751382827759,
1124
+ "learning_rate": 9.949651254073236e-05,
1125
+ "loss": 0.0408,
1126
+ "step": 1860
1127
+ },
1128
+ {
1129
+ "grad_norm": 0.526287317276001,
1130
+ "learning_rate": 9.948474195736504e-05,
1131
+ "loss": 0.0388,
1132
+ "step": 1870
1133
+ },
1134
+ {
1135
+ "grad_norm": 0.6111840605735779,
1136
+ "learning_rate": 9.947283608460277e-05,
1137
+ "loss": 0.0346,
1138
+ "step": 1880
1139
+ },
1140
+ {
1141
+ "grad_norm": 0.46461328864097595,
1142
+ "learning_rate": 9.946079495499577e-05,
1143
+ "loss": 0.0411,
1144
+ "step": 1890
1145
+ },
1146
+ {
1147
+ "grad_norm": 0.610548734664917,
1148
+ "learning_rate": 9.944861860146401e-05,
1149
+ "loss": 0.0407,
1150
+ "step": 1900
1151
+ },
1152
+ {
1153
+ "grad_norm": 0.5339504480361938,
1154
+ "learning_rate": 9.943630705729719e-05,
1155
+ "loss": 0.0398,
1156
+ "step": 1910
1157
+ },
1158
+ {
1159
+ "grad_norm": 0.46559029817581177,
1160
+ "learning_rate": 9.942386035615459e-05,
1161
+ "loss": 0.039,
1162
+ "step": 1920
1163
+ },
1164
+ {
1165
+ "grad_norm": 0.7745798826217651,
1166
+ "learning_rate": 9.941127853206503e-05,
1167
+ "loss": 0.04,
1168
+ "step": 1930
1169
+ },
1170
+ {
1171
+ "grad_norm": 0.5811882019042969,
1172
+ "learning_rate": 9.939856161942673e-05,
1173
+ "loss": 0.0425,
1174
+ "step": 1940
1175
+ },
1176
+ {
1177
+ "grad_norm": 0.4856541156768799,
1178
+ "learning_rate": 9.938570965300724e-05,
1179
+ "loss": 0.0363,
1180
+ "step": 1950
1181
+ },
1182
+ {
1183
+ "grad_norm": 0.5952467918395996,
1184
+ "learning_rate": 9.937272266794335e-05,
1185
+ "loss": 0.0439,
1186
+ "step": 1960
1187
+ },
1188
+ {
1189
+ "grad_norm": 0.5669976472854614,
1190
+ "learning_rate": 9.935960069974096e-05,
1191
+ "loss": 0.05,
1192
+ "step": 1970
1193
+ },
1194
+ {
1195
+ "grad_norm": 0.5959198474884033,
1196
+ "learning_rate": 9.934634378427506e-05,
1197
+ "loss": 0.0382,
1198
+ "step": 1980
1199
+ },
1200
+ {
1201
+ "grad_norm": 0.520875096321106,
1202
+ "learning_rate": 9.933295195778954e-05,
1203
+ "loss": 0.0386,
1204
+ "step": 1990
1205
+ },
1206
+ {
1207
+ "grad_norm": 0.4351758360862732,
1208
+ "learning_rate": 9.931942525689715e-05,
1209
+ "loss": 0.0488,
1210
+ "step": 2000
1211
+ },
1212
+ {
1213
+ "grad_norm": 0.6345981359481812,
1214
+ "learning_rate": 9.930576371857936e-05,
1215
+ "loss": 0.0391,
1216
+ "step": 2010
1217
+ },
1218
+ {
1219
+ "grad_norm": 0.6230748295783997,
1220
+ "learning_rate": 9.929196738018629e-05,
1221
+ "loss": 0.0388,
1222
+ "step": 2020
1223
+ },
1224
+ {
1225
+ "grad_norm": 0.5425089001655579,
1226
+ "learning_rate": 9.927803627943662e-05,
1227
+ "loss": 0.0395,
1228
+ "step": 2030
1229
+ },
1230
+ {
1231
+ "grad_norm": 0.49332770705223083,
1232
+ "learning_rate": 9.926397045441744e-05,
1233
+ "loss": 0.039,
1234
+ "step": 2040
1235
+ },
1236
+ {
1237
+ "grad_norm": 0.6731558442115784,
1238
+ "learning_rate": 9.924976994358417e-05,
1239
+ "loss": 0.0427,
1240
+ "step": 2050
1241
+ },
1242
+ {
1243
+ "grad_norm": 0.5310463309288025,
1244
+ "learning_rate": 9.923543478576048e-05,
1245
+ "loss": 0.0474,
1246
+ "step": 2060
1247
+ },
1248
+ {
1249
+ "grad_norm": 0.548930823802948,
1250
+ "learning_rate": 9.922096502013813e-05,
1251
+ "loss": 0.0423,
1252
+ "step": 2070
1253
+ },
1254
+ {
1255
+ "grad_norm": 0.5744786262512207,
1256
+ "learning_rate": 9.92063606862769e-05,
1257
+ "loss": 0.0372,
1258
+ "step": 2080
1259
+ },
1260
+ {
1261
+ "grad_norm": 0.6390929222106934,
1262
+ "learning_rate": 9.919162182410453e-05,
1263
+ "loss": 0.0368,
1264
+ "step": 2090
1265
+ },
1266
+ {
1267
+ "grad_norm": 0.5252511501312256,
1268
+ "learning_rate": 9.917674847391645e-05,
1269
+ "loss": 0.038,
1270
+ "step": 2100
1271
+ },
1272
+ {
1273
+ "grad_norm": 0.5656434297561646,
1274
+ "learning_rate": 9.916174067637584e-05,
1275
+ "loss": 0.0333,
1276
+ "step": 2110
1277
+ },
1278
+ {
1279
+ "grad_norm": 0.5288258790969849,
1280
+ "learning_rate": 9.914659847251348e-05,
1281
+ "loss": 0.0406,
1282
+ "step": 2120
1283
+ },
1284
+ {
1285
+ "grad_norm": 0.5040147304534912,
1286
+ "learning_rate": 9.913132190372753e-05,
1287
+ "loss": 0.0369,
1288
+ "step": 2130
1289
+ },
1290
+ {
1291
+ "grad_norm": 0.5128138661384583,
1292
+ "learning_rate": 9.911591101178359e-05,
1293
+ "loss": 0.0368,
1294
+ "step": 2140
1295
+ },
1296
+ {
1297
+ "grad_norm": 0.4942684769630432,
1298
+ "learning_rate": 9.910036583881443e-05,
1299
+ "loss": 0.0334,
1300
+ "step": 2150
1301
+ },
1302
+ {
1303
+ "grad_norm": 0.5318565368652344,
1304
+ "learning_rate": 9.908468642731995e-05,
1305
+ "loss": 0.0325,
1306
+ "step": 2160
1307
+ },
1308
+ {
1309
+ "grad_norm": 0.5772367715835571,
1310
+ "learning_rate": 9.906887282016707e-05,
1311
+ "loss": 0.0344,
1312
+ "step": 2170
1313
+ },
1314
+ {
1315
+ "grad_norm": 0.5957911014556885,
1316
+ "learning_rate": 9.90529250605896e-05,
1317
+ "loss": 0.0368,
1318
+ "step": 2180
1319
+ },
1320
+ {
1321
+ "grad_norm": 0.6259480714797974,
1322
+ "learning_rate": 9.903684319218809e-05,
1323
+ "loss": 0.0375,
1324
+ "step": 2190
1325
+ },
1326
+ {
1327
+ "grad_norm": 0.691277801990509,
1328
+ "learning_rate": 9.902062725892976e-05,
1329
+ "loss": 0.0402,
1330
+ "step": 2200
1331
+ },
1332
+ {
1333
+ "grad_norm": 0.624859094619751,
1334
+ "learning_rate": 9.900427730514834e-05,
1335
+ "loss": 0.0316,
1336
+ "step": 2210
1337
+ },
1338
+ {
1339
+ "grad_norm": 0.46915674209594727,
1340
+ "learning_rate": 9.8987793375544e-05,
1341
+ "loss": 0.0352,
1342
+ "step": 2220
1343
+ },
1344
+ {
1345
+ "grad_norm": 0.5559591054916382,
1346
+ "learning_rate": 9.897117551518318e-05,
1347
+ "loss": 0.0353,
1348
+ "step": 2230
1349
+ },
1350
+ {
1351
+ "grad_norm": 0.47577548027038574,
1352
+ "learning_rate": 9.895442376949844e-05,
1353
+ "loss": 0.0395,
1354
+ "step": 2240
1355
+ },
1356
+ {
1357
+ "grad_norm": 0.7231595516204834,
1358
+ "learning_rate": 9.893753818428845e-05,
1359
+ "loss": 0.0442,
1360
+ "step": 2250
1361
+ },
1362
+ {
1363
+ "grad_norm": 0.4607575535774231,
1364
+ "learning_rate": 9.892051880571773e-05,
1365
+ "loss": 0.037,
1366
+ "step": 2260
1367
+ },
1368
+ {
1369
+ "grad_norm": 0.4901242256164551,
1370
+ "learning_rate": 9.890336568031663e-05,
1371
+ "loss": 0.0342,
1372
+ "step": 2270
1373
+ },
1374
+ {
1375
+ "grad_norm": 0.46413323283195496,
1376
+ "learning_rate": 9.888607885498113e-05,
1377
+ "loss": 0.0386,
1378
+ "step": 2280
1379
+ },
1380
+ {
1381
+ "grad_norm": 0.5028432607650757,
1382
+ "learning_rate": 9.886865837697275e-05,
1383
+ "loss": 0.0384,
1384
+ "step": 2290
1385
+ },
1386
+ {
1387
+ "grad_norm": 0.6079827547073364,
1388
+ "learning_rate": 9.88511042939184e-05,
1389
+ "loss": 0.0416,
1390
+ "step": 2300
1391
+ },
1392
+ {
1393
+ "grad_norm": 0.6189248561859131,
1394
+ "learning_rate": 9.883341665381028e-05,
1395
+ "loss": 0.0372,
1396
+ "step": 2310
1397
+ },
1398
+ {
1399
+ "grad_norm": 0.569456160068512,
1400
+ "learning_rate": 9.881559550500575e-05,
1401
+ "loss": 0.0317,
1402
+ "step": 2320
1403
+ },
1404
+ {
1405
+ "grad_norm": 0.5782006978988647,
1406
+ "learning_rate": 9.879764089622712e-05,
1407
+ "loss": 0.0363,
1408
+ "step": 2330
1409
+ },
1410
+ {
1411
+ "grad_norm": 0.6612024307250977,
1412
+ "learning_rate": 9.87795528765616e-05,
1413
+ "loss": 0.0386,
1414
+ "step": 2340
1415
+ },
1416
+ {
1417
+ "grad_norm": 0.45619797706604004,
1418
+ "learning_rate": 9.876133149546118e-05,
1419
+ "loss": 0.0385,
1420
+ "step": 2350
1421
+ },
1422
+ {
1423
+ "grad_norm": 0.4743977189064026,
1424
+ "learning_rate": 9.874297680274238e-05,
1425
+ "loss": 0.0384,
1426
+ "step": 2360
1427
+ },
1428
+ {
1429
+ "grad_norm": 0.5303918719291687,
1430
+ "learning_rate": 9.872448884858624e-05,
1431
+ "loss": 0.0364,
1432
+ "step": 2370
1433
+ },
1434
+ {
1435
+ "grad_norm": 0.5923212766647339,
1436
+ "learning_rate": 9.870586768353815e-05,
1437
+ "loss": 0.0366,
1438
+ "step": 2380
1439
+ },
1440
+ {
1441
+ "grad_norm": 0.5156052112579346,
1442
+ "learning_rate": 9.868711335850764e-05,
1443
+ "loss": 0.0412,
1444
+ "step": 2390
1445
+ },
1446
+ {
1447
+ "grad_norm": 0.4702778458595276,
1448
+ "learning_rate": 9.866822592476833e-05,
1449
+ "loss": 0.0353,
1450
+ "step": 2400
1451
+ },
1452
+ {
1453
+ "grad_norm": 0.4955006241798401,
1454
+ "learning_rate": 9.86492054339577e-05,
1455
+ "loss": 0.0356,
1456
+ "step": 2410
1457
+ },
1458
+ {
1459
+ "grad_norm": 0.4722374677658081,
1460
+ "learning_rate": 9.863005193807711e-05,
1461
+ "loss": 0.0328,
1462
+ "step": 2420
1463
+ },
1464
+ {
1465
+ "grad_norm": 0.5261074900627136,
1466
+ "learning_rate": 9.861076548949143e-05,
1467
+ "loss": 0.0314,
1468
+ "step": 2430
1469
+ },
1470
+ {
1471
+ "grad_norm": 0.43109720945358276,
1472
+ "learning_rate": 9.859134614092912e-05,
1473
+ "loss": 0.0306,
1474
+ "step": 2440
1475
+ },
1476
+ {
1477
+ "grad_norm": 0.5150691270828247,
1478
+ "learning_rate": 9.857179394548191e-05,
1479
+ "loss": 0.0331,
1480
+ "step": 2450
1481
+ },
1482
+ {
1483
+ "grad_norm": 0.413881778717041,
1484
+ "learning_rate": 9.855210895660477e-05,
1485
+ "loss": 0.0313,
1486
+ "step": 2460
1487
+ },
1488
+ {
1489
+ "grad_norm": 0.5778813362121582,
1490
+ "learning_rate": 9.853229122811568e-05,
1491
+ "loss": 0.0327,
1492
+ "step": 2470
1493
+ },
1494
+ {
1495
+ "grad_norm": 0.5499809980392456,
1496
+ "learning_rate": 9.851234081419559e-05,
1497
+ "loss": 0.0371,
1498
+ "step": 2480
1499
+ },
1500
+ {
1501
+ "grad_norm": 0.533755898475647,
1502
+ "learning_rate": 9.849225776938814e-05,
1503
+ "loss": 0.0347,
1504
+ "step": 2490
1505
+ },
1506
+ {
1507
+ "grad_norm": 0.5036794543266296,
1508
+ "learning_rate": 9.847204214859964e-05,
1509
+ "loss": 0.0365,
1510
+ "step": 2500
1511
+ },
1512
+ {
1513
+ "grad_norm": 0.4547636806964874,
1514
+ "learning_rate": 9.845169400709879e-05,
1515
+ "loss": 0.0284,
1516
+ "step": 2510
1517
+ },
1518
+ {
1519
+ "grad_norm": 0.4148177206516266,
1520
+ "learning_rate": 9.843121340051664e-05,
1521
+ "loss": 0.0338,
1522
+ "step": 2520
1523
+ },
1524
+ {
1525
+ "grad_norm": 0.4307814836502075,
1526
+ "learning_rate": 9.841060038484641e-05,
1527
+ "loss": 0.0401,
1528
+ "step": 2530
1529
+ },
1530
+ {
1531
+ "grad_norm": 0.5055217146873474,
1532
+ "learning_rate": 9.838985501644328e-05,
1533
+ "loss": 0.0413,
1534
+ "step": 2540
1535
+ },
1536
+ {
1537
+ "grad_norm": 0.5252987742424011,
1538
+ "learning_rate": 9.83689773520243e-05,
1539
+ "loss": 0.0334,
1540
+ "step": 2550
1541
+ },
1542
+ {
1543
+ "grad_norm": 0.5325053334236145,
1544
+ "learning_rate": 9.834796744866819e-05,
1545
+ "loss": 0.0339,
1546
+ "step": 2560
1547
+ },
1548
+ {
1549
+ "grad_norm": 0.5485632419586182,
1550
+ "learning_rate": 9.832682536381525e-05,
1551
+ "loss": 0.0354,
1552
+ "step": 2570
1553
+ },
1554
+ {
1555
+ "grad_norm": 0.5406777262687683,
1556
+ "learning_rate": 9.830555115526711e-05,
1557
+ "loss": 0.0368,
1558
+ "step": 2580
1559
+ },
1560
+ {
1561
+ "grad_norm": 0.37698280811309814,
1562
+ "learning_rate": 9.828414488118667e-05,
1563
+ "loss": 0.0336,
1564
+ "step": 2590
1565
+ },
1566
+ {
1567
+ "grad_norm": 0.5253736972808838,
1568
+ "learning_rate": 9.826260660009785e-05,
1569
+ "loss": 0.0337,
1570
+ "step": 2600
1571
+ },
1572
+ {
1573
+ "grad_norm": 0.482319176197052,
1574
+ "learning_rate": 9.824093637088547e-05,
1575
+ "loss": 0.0299,
1576
+ "step": 2610
1577
+ },
1578
+ {
1579
+ "grad_norm": 0.43845584988594055,
1580
+ "learning_rate": 9.821913425279514e-05,
1581
+ "loss": 0.032,
1582
+ "step": 2620
1583
+ },
1584
+ {
1585
+ "grad_norm": 0.4526597559452057,
1586
+ "learning_rate": 9.8197200305433e-05,
1587
+ "loss": 0.034,
1588
+ "step": 2630
1589
+ },
1590
+ {
1591
+ "grad_norm": 0.45589521527290344,
1592
+ "learning_rate": 9.817513458876564e-05,
1593
+ "loss": 0.0464,
1594
+ "step": 2640
1595
+ },
1596
+ {
1597
+ "grad_norm": 0.5381149649620056,
1598
+ "learning_rate": 9.815293716311987e-05,
1599
+ "loss": 0.0334,
1600
+ "step": 2650
1601
+ },
1602
+ {
1603
+ "grad_norm": 0.5279123187065125,
1604
+ "learning_rate": 9.813060808918262e-05,
1605
+ "loss": 0.0318,
1606
+ "step": 2660
1607
+ },
1608
+ {
1609
+ "grad_norm": 0.3532435894012451,
1610
+ "learning_rate": 9.810814742800069e-05,
1611
+ "loss": 0.0285,
1612
+ "step": 2670
1613
+ },
1614
+ {
1615
+ "grad_norm": 0.3765302896499634,
1616
+ "learning_rate": 9.808555524098074e-05,
1617
+ "loss": 0.0289,
1618
+ "step": 2680
1619
+ },
1620
+ {
1621
+ "grad_norm": 0.46037837862968445,
1622
+ "learning_rate": 9.806283158988887e-05,
1623
+ "loss": 0.0291,
1624
+ "step": 2690
1625
+ },
1626
+ {
1627
+ "grad_norm": 0.483735591173172,
1628
+ "learning_rate": 9.803997653685072e-05,
1629
+ "loss": 0.0392,
1630
+ "step": 2700
1631
+ },
1632
+ {
1633
+ "grad_norm": 0.45865148305892944,
1634
+ "learning_rate": 9.801699014435112e-05,
1635
+ "loss": 0.0393,
1636
+ "step": 2710
1637
+ },
1638
+ {
1639
+ "grad_norm": 0.4620376229286194,
1640
+ "learning_rate": 9.799387247523398e-05,
1641
+ "loss": 0.0352,
1642
+ "step": 2720
1643
+ },
1644
+ {
1645
+ "grad_norm": 0.41832435131073,
1646
+ "learning_rate": 9.797062359270215e-05,
1647
+ "loss": 0.0319,
1648
+ "step": 2730
1649
+ },
1650
+ {
1651
+ "grad_norm": 0.4439375400543213,
1652
+ "learning_rate": 9.794724356031715e-05,
1653
+ "loss": 0.0307,
1654
+ "step": 2740
1655
+ },
1656
+ {
1657
+ "grad_norm": 0.5037664771080017,
1658
+ "learning_rate": 9.792373244199913e-05,
1659
+ "loss": 0.0306,
1660
+ "step": 2750
1661
+ },
1662
+ {
1663
+ "grad_norm": 0.378164678812027,
1664
+ "learning_rate": 9.790009030202658e-05,
1665
+ "loss": 0.0313,
1666
+ "step": 2760
1667
+ },
1668
+ {
1669
+ "grad_norm": 0.5053073763847351,
1670
+ "learning_rate": 9.78763172050362e-05,
1671
+ "loss": 0.0295,
1672
+ "step": 2770
1673
+ },
1674
+ {
1675
+ "grad_norm": 0.4680381119251251,
1676
+ "learning_rate": 9.785241321602274e-05,
1677
+ "loss": 0.0277,
1678
+ "step": 2780
1679
+ },
1680
+ {
1681
+ "grad_norm": 0.4624013304710388,
1682
+ "learning_rate": 9.782837840033879e-05,
1683
+ "loss": 0.0288,
1684
+ "step": 2790
1685
+ },
1686
+ {
1687
+ "grad_norm": 0.5074241757392883,
1688
+ "learning_rate": 9.780421282369461e-05,
1689
+ "loss": 0.0292,
1690
+ "step": 2800
1691
+ },
1692
+ {
1693
+ "grad_norm": 0.4835506081581116,
1694
+ "learning_rate": 9.777991655215797e-05,
1695
+ "loss": 0.0294,
1696
+ "step": 2810
1697
+ },
1698
+ {
1699
+ "grad_norm": 0.5738292336463928,
1700
+ "learning_rate": 9.775548965215394e-05,
1701
+ "loss": 0.0295,
1702
+ "step": 2820
1703
+ },
1704
+ {
1705
+ "grad_norm": 0.5334445238113403,
1706
+ "learning_rate": 9.773093219046474e-05,
1707
+ "loss": 0.0293,
1708
+ "step": 2830
1709
+ },
1710
+ {
1711
+ "grad_norm": 0.4011390507221222,
1712
+ "learning_rate": 9.770624423422954e-05,
1713
+ "loss": 0.0291,
1714
+ "step": 2840
1715
+ },
1716
+ {
1717
+ "grad_norm": 0.41171419620513916,
1718
+ "learning_rate": 9.768142585094426e-05,
1719
+ "loss": 0.0302,
1720
+ "step": 2850
1721
+ },
1722
+ {
1723
+ "grad_norm": 0.46391263604164124,
1724
+ "learning_rate": 9.765647710846142e-05,
1725
+ "loss": 0.0405,
1726
+ "step": 2860
1727
+ },
1728
+ {
1729
+ "grad_norm": 0.5071845650672913,
1730
+ "learning_rate": 9.763139807498991e-05,
1731
+ "loss": 0.0285,
1732
+ "step": 2870
1733
+ },
1734
+ {
1735
+ "grad_norm": 0.4814237058162689,
1736
+ "learning_rate": 9.760618881909487e-05,
1737
+ "loss": 0.0317,
1738
+ "step": 2880
1739
+ },
1740
+ {
1741
+ "grad_norm": 0.5396919846534729,
1742
+ "learning_rate": 9.758084940969744e-05,
1743
+ "loss": 0.0316,
1744
+ "step": 2890
1745
+ },
1746
+ {
1747
+ "grad_norm": 0.5363779664039612,
1748
+ "learning_rate": 9.755537991607459e-05,
1749
+ "loss": 0.027,
1750
+ "step": 2900
1751
+ },
1752
+ {
1753
+ "grad_norm": 0.505138099193573,
1754
+ "learning_rate": 9.752978040785895e-05,
1755
+ "loss": 0.0354,
1756
+ "step": 2910
1757
+ },
1758
+ {
1759
+ "grad_norm": 0.5476271510124207,
1760
+ "learning_rate": 9.750405095503859e-05,
1761
+ "loss": 0.0299,
1762
+ "step": 2920
1763
+ },
1764
+ {
1765
+ "grad_norm": 0.5189036130905151,
1766
+ "learning_rate": 9.747819162795686e-05,
1767
+ "loss": 0.0331,
1768
+ "step": 2930
1769
+ },
1770
+ {
1771
+ "grad_norm": 0.45717042684555054,
1772
+ "learning_rate": 9.745220249731217e-05,
1773
+ "loss": 0.026,
1774
+ "step": 2940
1775
+ },
1776
+ {
1777
+ "grad_norm": 0.4337165355682373,
1778
+ "learning_rate": 9.742608363415781e-05,
1779
+ "loss": 0.0272,
1780
+ "step": 2950
1781
+ },
1782
+ {
1783
+ "grad_norm": 0.4811023771762848,
1784
+ "learning_rate": 9.739983510990176e-05,
1785
+ "loss": 0.0288,
1786
+ "step": 2960
1787
+ },
1788
+ {
1789
+ "grad_norm": 0.3455168902873993,
1790
+ "learning_rate": 9.737345699630647e-05,
1791
+ "loss": 0.0298,
1792
+ "step": 2970
1793
+ },
1794
+ {
1795
+ "grad_norm": 0.5057815313339233,
1796
+ "learning_rate": 9.734694936548869e-05,
1797
+ "loss": 0.0332,
1798
+ "step": 2980
1799
+ },
1800
+ {
1801
+ "grad_norm": 0.38619765639305115,
1802
+ "learning_rate": 9.732031228991932e-05,
1803
+ "loss": 0.0256,
1804
+ "step": 2990
1805
+ },
1806
+ {
1807
+ "grad_norm": 0.3297816514968872,
1808
+ "learning_rate": 9.729354584242302e-05,
1809
+ "loss": 0.0355,
1810
+ "step": 3000
1811
+ },
1812
+ {
1813
+ "grad_norm": 0.5174765586853027,
1814
+ "learning_rate": 9.726665009617832e-05,
1815
+ "loss": 0.0309,
1816
+ "step": 3010
1817
+ },
1818
+ {
1819
+ "grad_norm": 0.43245866894721985,
1820
+ "learning_rate": 9.723962512471714e-05,
1821
+ "loss": 0.033,
1822
+ "step": 3020
1823
+ },
1824
+ {
1825
+ "grad_norm": 0.516598105430603,
1826
+ "learning_rate": 9.72124710019247e-05,
1827
+ "loss": 0.03,
1828
+ "step": 3030
1829
+ },
1830
+ {
1831
+ "grad_norm": 0.48712822794914246,
1832
+ "learning_rate": 9.718518780203934e-05,
1833
+ "loss": 0.0322,
1834
+ "step": 3040
1835
+ },
1836
+ {
1837
+ "grad_norm": 0.3674415946006775,
1838
+ "learning_rate": 9.715777559965228e-05,
1839
+ "loss": 0.0319,
1840
+ "step": 3050
1841
+ },
1842
+ {
1843
+ "grad_norm": 0.4218079149723053,
1844
+ "learning_rate": 9.713023446970746e-05,
1845
+ "loss": 0.0255,
1846
+ "step": 3060
1847
+ },
1848
+ {
1849
+ "grad_norm": 0.4967867136001587,
1850
+ "learning_rate": 9.710256448750126e-05,
1851
+ "loss": 0.0311,
1852
+ "step": 3070
1853
+ },
1854
+ {
1855
+ "grad_norm": 0.497653067111969,
1856
+ "learning_rate": 9.707476572868235e-05,
1857
+ "loss": 0.0341,
1858
+ "step": 3080
1859
+ },
1860
+ {
1861
+ "grad_norm": 0.4222137928009033,
1862
+ "learning_rate": 9.704683826925149e-05,
1863
+ "loss": 0.0273,
1864
+ "step": 3090
1865
+ },
1866
+ {
1867
+ "grad_norm": 0.37705838680267334,
1868
+ "learning_rate": 9.701878218556129e-05,
1869
+ "loss": 0.036,
1870
+ "step": 3100
1871
+ },
1872
+ {
1873
+ "grad_norm": 0.5626199841499329,
1874
+ "learning_rate": 9.699059755431598e-05,
1875
+ "loss": 0.0331,
1876
+ "step": 3110
1877
+ },
1878
+ {
1879
+ "grad_norm": 0.46293774247169495,
1880
+ "learning_rate": 9.696228445257132e-05,
1881
+ "loss": 0.0277,
1882
+ "step": 3120
1883
+ },
1884
+ {
1885
+ "grad_norm": 0.42764750123023987,
1886
+ "learning_rate": 9.693384295773419e-05,
1887
+ "loss": 0.0327,
1888
+ "step": 3130
1889
+ },
1890
+ {
1891
+ "grad_norm": 0.4717363715171814,
1892
+ "learning_rate": 9.690527314756259e-05,
1893
+ "loss": 0.0339,
1894
+ "step": 3140
1895
+ },
1896
+ {
1897
+ "grad_norm": 0.458967387676239,
1898
+ "learning_rate": 9.687657510016527e-05,
1899
+ "loss": 0.0261,
1900
+ "step": 3150
1901
+ },
1902
+ {
1903
+ "grad_norm": 0.45871081948280334,
1904
+ "learning_rate": 9.684774889400161e-05,
1905
+ "loss": 0.0309,
1906
+ "step": 3160
1907
+ },
1908
+ {
1909
+ "grad_norm": 0.5132860541343689,
1910
+ "learning_rate": 9.681879460788135e-05,
1911
+ "loss": 0.0264,
1912
+ "step": 3170
1913
+ },
1914
+ {
1915
+ "grad_norm": 0.4729975461959839,
1916
+ "learning_rate": 9.67897123209644e-05,
1917
+ "loss": 0.0315,
1918
+ "step": 3180
1919
+ },
1920
+ {
1921
+ "grad_norm": 0.4921012818813324,
1922
+ "learning_rate": 9.676050211276062e-05,
1923
+ "loss": 0.035,
1924
+ "step": 3190
1925
+ },
1926
+ {
1927
+ "grad_norm": 0.4574073255062103,
1928
+ "learning_rate": 9.673116406312962e-05,
1929
+ "loss": 0.0284,
1930
+ "step": 3200
1931
+ },
1932
+ {
1933
+ "grad_norm": 0.48541590571403503,
1934
+ "learning_rate": 9.67016982522805e-05,
1935
+ "loss": 0.028,
1936
+ "step": 3210
1937
+ },
1938
+ {
1939
+ "grad_norm": 0.4924331307411194,
1940
+ "learning_rate": 9.667210476077164e-05,
1941
+ "loss": 0.028,
1942
+ "step": 3220
1943
+ },
1944
+ {
1945
+ "grad_norm": 0.5730510950088501,
1946
+ "learning_rate": 9.664238366951055e-05,
1947
+ "loss": 0.0288,
1948
+ "step": 3230
1949
+ },
1950
+ {
1951
+ "grad_norm": 0.5551027059555054,
1952
+ "learning_rate": 9.661253505975355e-05,
1953
+ "loss": 0.0269,
1954
+ "step": 3240
1955
+ },
1956
+ {
1957
+ "grad_norm": 0.4366356134414673,
1958
+ "learning_rate": 9.658255901310557e-05,
1959
+ "loss": 0.0301,
1960
+ "step": 3250
1961
+ },
1962
+ {
1963
+ "grad_norm": 0.5327138304710388,
1964
+ "learning_rate": 9.655245561152e-05,
1965
+ "loss": 0.0278,
1966
+ "step": 3260
1967
+ },
1968
+ {
1969
+ "grad_norm": 0.4516207277774811,
1970
+ "learning_rate": 9.65222249372984e-05,
1971
+ "loss": 0.0266,
1972
+ "step": 3270
1973
+ },
1974
+ {
1975
+ "grad_norm": 0.4709407687187195,
1976
+ "learning_rate": 9.649186707309026e-05,
1977
+ "loss": 0.0325,
1978
+ "step": 3280
1979
+ },
1980
+ {
1981
+ "grad_norm": 0.36673372983932495,
1982
+ "learning_rate": 9.646138210189283e-05,
1983
+ "loss": 0.0285,
1984
+ "step": 3290
1985
+ },
1986
+ {
1987
+ "grad_norm": 0.5308244824409485,
1988
+ "learning_rate": 9.643077010705087e-05,
1989
+ "loss": 0.0281,
1990
+ "step": 3300
1991
+ },
1992
+ {
1993
+ "grad_norm": 0.45568153262138367,
1994
+ "learning_rate": 9.640003117225637e-05,
1995
+ "loss": 0.0286,
1996
+ "step": 3310
1997
+ },
1998
+ {
1999
+ "grad_norm": 0.4082559049129486,
2000
+ "learning_rate": 9.636916538154846e-05,
2001
+ "loss": 0.0241,
2002
+ "step": 3320
2003
+ },
2004
+ {
2005
+ "grad_norm": 0.48012563586235046,
2006
+ "learning_rate": 9.633817281931296e-05,
2007
+ "loss": 0.0297,
2008
+ "step": 3330
2009
+ },
2010
+ {
2011
+ "grad_norm": 0.4177444875240326,
2012
+ "learning_rate": 9.630705357028242e-05,
2013
+ "loss": 0.032,
2014
+ "step": 3340
2015
+ },
2016
+ {
2017
+ "grad_norm": 0.48793429136276245,
2018
+ "learning_rate": 9.627580771953563e-05,
2019
+ "loss": 0.0285,
2020
+ "step": 3350
2021
+ },
2022
+ {
2023
+ "grad_norm": 0.4371464252471924,
2024
+ "learning_rate": 9.624443535249759e-05,
2025
+ "loss": 0.0275,
2026
+ "step": 3360
2027
+ },
2028
+ {
2029
+ "grad_norm": 0.4983312487602234,
2030
+ "learning_rate": 9.621293655493913e-05,
2031
+ "loss": 0.0254,
2032
+ "step": 3370
2033
+ },
2034
+ {
2035
+ "grad_norm": 0.5624396204948425,
2036
+ "learning_rate": 9.618131141297675e-05,
2037
+ "loss": 0.027,
2038
+ "step": 3380
2039
+ },
2040
+ {
2041
+ "grad_norm": 0.43570947647094727,
2042
+ "learning_rate": 9.614956001307242e-05,
2043
+ "loss": 0.0301,
2044
+ "step": 3390
2045
+ },
2046
+ {
2047
+ "grad_norm": 0.4448493719100952,
2048
+ "learning_rate": 9.611768244203321e-05,
2049
+ "loss": 0.0351,
2050
+ "step": 3400
2051
+ },
2052
+ {
2053
+ "grad_norm": 0.4213621914386749,
2054
+ "learning_rate": 9.60856787870112e-05,
2055
+ "loss": 0.0292,
2056
+ "step": 3410
2057
+ },
2058
+ {
2059
+ "grad_norm": 0.4154338836669922,
2060
+ "learning_rate": 9.605354913550318e-05,
2061
+ "loss": 0.0262,
2062
+ "step": 3420
2063
+ },
2064
+ {
2065
+ "grad_norm": 0.45102718472480774,
2066
+ "learning_rate": 9.602129357535037e-05,
2067
+ "loss": 0.0313,
2068
+ "step": 3430
2069
+ },
2070
+ {
2071
+ "grad_norm": 0.38145503401756287,
2072
+ "learning_rate": 9.598891219473825e-05,
2073
+ "loss": 0.027,
2074
+ "step": 3440
2075
+ },
2076
+ {
2077
+ "grad_norm": 0.41790488362312317,
2078
+ "learning_rate": 9.595640508219625e-05,
2079
+ "loss": 0.0291,
2080
+ "step": 3450
2081
+ },
2082
+ {
2083
+ "grad_norm": 0.4644753336906433,
2084
+ "learning_rate": 9.592377232659761e-05,
2085
+ "loss": 0.0249,
2086
+ "step": 3460
2087
+ },
2088
+ {
2089
+ "grad_norm": 0.4731713533401489,
2090
+ "learning_rate": 9.589101401715904e-05,
2091
+ "loss": 0.0263,
2092
+ "step": 3470
2093
+ },
2094
+ {
2095
+ "grad_norm": 0.42398542165756226,
2096
+ "learning_rate": 9.585813024344045e-05,
2097
+ "loss": 0.026,
2098
+ "step": 3480
2099
+ },
2100
+ {
2101
+ "grad_norm": 0.5419644117355347,
2102
+ "learning_rate": 9.58251210953449e-05,
2103
+ "loss": 0.0296,
2104
+ "step": 3490
2105
+ },
2106
+ {
2107
+ "grad_norm": 0.463670939207077,
2108
+ "learning_rate": 9.579198666311809e-05,
2109
+ "loss": 0.0238,
2110
+ "step": 3500
2111
+ },
2112
+ {
2113
+ "grad_norm": 0.39643239974975586,
2114
+ "learning_rate": 9.575872703734832e-05,
2115
+ "loss": 0.0292,
2116
+ "step": 3510
2117
+ },
2118
+ {
2119
+ "grad_norm": 0.3542700409889221,
2120
+ "learning_rate": 9.572534230896611e-05,
2121
+ "loss": 0.0231,
2122
+ "step": 3520
2123
+ },
2124
+ {
2125
+ "grad_norm": 0.43060752749443054,
2126
+ "learning_rate": 9.569183256924403e-05,
2127
+ "loss": 0.025,
2128
+ "step": 3530
2129
+ },
2130
+ {
2131
+ "grad_norm": 0.40233463048934937,
2132
+ "learning_rate": 9.565819790979646e-05,
2133
+ "loss": 0.0422,
2134
+ "step": 3540
2135
+ },
2136
+ {
2137
+ "grad_norm": 0.4497774839401245,
2138
+ "learning_rate": 9.562443842257925e-05,
2139
+ "loss": 0.029,
2140
+ "step": 3550
2141
+ },
2142
+ {
2143
+ "grad_norm": 0.5018470287322998,
2144
+ "learning_rate": 9.559055419988956e-05,
2145
+ "loss": 0.0283,
2146
+ "step": 3560
2147
+ },
2148
+ {
2149
+ "grad_norm": 0.47868454456329346,
2150
+ "learning_rate": 9.555654533436557e-05,
2151
+ "loss": 0.0349,
2152
+ "step": 3570
2153
+ },
2154
+ {
2155
+ "grad_norm": 0.4413691759109497,
2156
+ "learning_rate": 9.552241191898621e-05,
2157
+ "loss": 0.0238,
2158
+ "step": 3580
2159
+ },
2160
+ {
2161
+ "grad_norm": 0.40998080372810364,
2162
+ "learning_rate": 9.548815404707092e-05,
2163
+ "loss": 0.03,
2164
+ "step": 3590
2165
+ },
2166
+ {
2167
+ "grad_norm": 0.43824273347854614,
2168
+ "learning_rate": 9.545377181227942e-05,
2169
+ "loss": 0.0284,
2170
+ "step": 3600
2171
+ },
2172
+ {
2173
+ "grad_norm": 0.4570449888706207,
2174
+ "learning_rate": 9.541926530861145e-05,
2175
+ "loss": 0.0266,
2176
+ "step": 3610
2177
+ },
2178
+ {
2179
+ "grad_norm": 0.44766074419021606,
2180
+ "learning_rate": 9.538463463040645e-05,
2181
+ "loss": 0.0278,
2182
+ "step": 3620
2183
+ },
2184
+ {
2185
+ "grad_norm": 0.481611967086792,
2186
+ "learning_rate": 9.534987987234337e-05,
2187
+ "loss": 0.0277,
2188
+ "step": 3630
2189
+ },
2190
+ {
2191
+ "grad_norm": 0.4858357608318329,
2192
+ "learning_rate": 9.53150011294404e-05,
2193
+ "loss": 0.0265,
2194
+ "step": 3640
2195
+ },
2196
+ {
2197
+ "grad_norm": 0.40574368834495544,
2198
+ "learning_rate": 9.527999849705471e-05,
2199
+ "loss": 0.0297,
2200
+ "step": 3650
2201
+ },
2202
+ {
2203
+ "grad_norm": 0.4581122100353241,
2204
+ "learning_rate": 9.524487207088213e-05,
2205
+ "loss": 0.0224,
2206
+ "step": 3660
2207
+ },
2208
+ {
2209
+ "grad_norm": 0.4100882411003113,
2210
+ "learning_rate": 9.520962194695698e-05,
2211
+ "loss": 0.0239,
2212
+ "step": 3670
2213
+ },
2214
+ {
2215
+ "grad_norm": 0.40333643555641174,
2216
+ "learning_rate": 9.517424822165175e-05,
2217
+ "loss": 0.0238,
2218
+ "step": 3680
2219
+ },
2220
+ {
2221
+ "grad_norm": 0.5596145987510681,
2222
+ "learning_rate": 9.513875099167685e-05,
2223
+ "loss": 0.0245,
2224
+ "step": 3690
2225
+ },
2226
+ {
2227
+ "grad_norm": 0.5230712890625,
2228
+ "learning_rate": 9.510313035408035e-05,
2229
+ "loss": 0.0262,
2230
+ "step": 3700
2231
+ },
2232
+ {
2233
+ "grad_norm": 0.39155617356300354,
2234
+ "learning_rate": 9.506738640624775e-05,
2235
+ "loss": 0.0264,
2236
+ "step": 3710
2237
+ },
2238
+ {
2239
+ "grad_norm": 0.4129464328289032,
2240
+ "learning_rate": 9.50315192459016e-05,
2241
+ "loss": 0.0208,
2242
+ "step": 3720
2243
+ },
2244
+ {
2245
+ "grad_norm": 0.5159543752670288,
2246
+ "learning_rate": 9.499552897110136e-05,
2247
+ "loss": 0.0239,
2248
+ "step": 3730
2249
+ },
2250
+ {
2251
+ "grad_norm": 0.5178094506263733,
2252
+ "learning_rate": 9.495941568024304e-05,
2253
+ "loss": 0.0253,
2254
+ "step": 3740
2255
+ },
2256
+ {
2257
+ "grad_norm": 0.43580612540245056,
2258
+ "learning_rate": 9.492317947205904e-05,
2259
+ "loss": 0.0268,
2260
+ "step": 3750
2261
+ },
2262
+ {
2263
+ "grad_norm": 0.4596274495124817,
2264
+ "learning_rate": 9.488682044561775e-05,
2265
+ "loss": 0.0256,
2266
+ "step": 3760
2267
+ },
2268
+ {
2269
+ "grad_norm": 0.41573286056518555,
2270
+ "learning_rate": 9.485033870032335e-05,
2271
+ "loss": 0.0243,
2272
+ "step": 3770
2273
+ },
2274
+ {
2275
+ "grad_norm": 0.47876912355422974,
2276
+ "learning_rate": 9.481373433591556e-05,
2277
+ "loss": 0.0215,
2278
+ "step": 3780
2279
+ },
2280
+ {
2281
+ "grad_norm": 0.4741547703742981,
2282
+ "learning_rate": 9.47770074524693e-05,
2283
+ "loss": 0.027,
2284
+ "step": 3790
2285
+ },
2286
+ {
2287
+ "grad_norm": 0.4306631088256836,
2288
+ "learning_rate": 9.474015815039446e-05,
2289
+ "loss": 0.0277,
2290
+ "step": 3800
2291
+ },
2292
+ {
2293
+ "grad_norm": 0.46127429604530334,
2294
+ "learning_rate": 9.470318653043565e-05,
2295
+ "loss": 0.0273,
2296
+ "step": 3810
2297
+ },
2298
+ {
2299
+ "grad_norm": 0.5021414160728455,
2300
+ "learning_rate": 9.466609269367185e-05,
2301
+ "loss": 0.0263,
2302
+ "step": 3820
2303
+ },
2304
+ {
2305
+ "grad_norm": 0.5333779454231262,
2306
+ "learning_rate": 9.46288767415162e-05,
2307
+ "loss": 0.0234,
2308
+ "step": 3830
2309
+ },
2310
+ {
2311
+ "grad_norm": 0.4366990625858307,
2312
+ "learning_rate": 9.459153877571567e-05,
2313
+ "loss": 0.0225,
2314
+ "step": 3840
2315
+ },
2316
+ {
2317
+ "grad_norm": 0.4819251298904419,
2318
+ "learning_rate": 9.455407889835087e-05,
2319
+ "loss": 0.0238,
2320
+ "step": 3850
2321
+ },
2322
+ {
2323
+ "grad_norm": 0.3999616503715515,
2324
+ "learning_rate": 9.451649721183564e-05,
2325
+ "loss": 0.0234,
2326
+ "step": 3860
2327
+ },
2328
+ {
2329
+ "grad_norm": 0.37807697057724,
2330
+ "learning_rate": 9.447879381891692e-05,
2331
+ "loss": 0.0258,
2332
+ "step": 3870
2333
+ },
2334
+ {
2335
+ "grad_norm": 0.5266739130020142,
2336
+ "learning_rate": 9.444096882267428e-05,
2337
+ "loss": 0.0329,
2338
+ "step": 3880
2339
+ },
2340
+ {
2341
+ "grad_norm": 0.3961910903453827,
2342
+ "learning_rate": 9.440302232651988e-05,
2343
+ "loss": 0.0226,
2344
+ "step": 3890
2345
+ },
2346
+ {
2347
+ "grad_norm": 0.3786242604255676,
2348
+ "learning_rate": 9.436495443419795e-05,
2349
+ "loss": 0.024,
2350
+ "step": 3900
2351
+ },
2352
+ {
2353
+ "grad_norm": 0.4175941050052643,
2354
+ "learning_rate": 9.432676524978466e-05,
2355
+ "loss": 0.0219,
2356
+ "step": 3910
2357
+ },
2358
+ {
2359
+ "grad_norm": 0.44096827507019043,
2360
+ "learning_rate": 9.42884548776878e-05,
2361
+ "loss": 0.0253,
2362
+ "step": 3920
2363
+ },
2364
+ {
2365
+ "grad_norm": 0.41201087832450867,
2366
+ "learning_rate": 9.425002342264646e-05,
2367
+ "loss": 0.0223,
2368
+ "step": 3930
2369
+ },
2370
+ {
2371
+ "grad_norm": 0.5009353160858154,
2372
+ "learning_rate": 9.421147098973077e-05,
2373
+ "loss": 0.0266,
2374
+ "step": 3940
2375
+ },
2376
+ {
2377
+ "grad_norm": 0.5505723357200623,
2378
+ "learning_rate": 9.41727976843416e-05,
2379
+ "loss": 0.0258,
2380
+ "step": 3950
2381
+ },
2382
+ {
2383
+ "grad_norm": 0.45981982350349426,
2384
+ "learning_rate": 9.413400361221029e-05,
2385
+ "loss": 0.0279,
2386
+ "step": 3960
2387
+ },
2388
+ {
2389
+ "grad_norm": 0.4804719388484955,
2390
+ "learning_rate": 9.409508887939835e-05,
2391
+ "loss": 0.022,
2392
+ "step": 3970
2393
+ },
2394
+ {
2395
+ "grad_norm": 0.4238436222076416,
2396
+ "learning_rate": 9.40560535922972e-05,
2397
+ "loss": 0.0212,
2398
+ "step": 3980
2399
+ },
2400
+ {
2401
+ "grad_norm": 0.403974324464798,
2402
+ "learning_rate": 9.40168978576278e-05,
2403
+ "loss": 0.0189,
2404
+ "step": 3990
2405
+ },
2406
+ {
2407
+ "grad_norm": 0.48837044835090637,
2408
+ "learning_rate": 9.397762178244043e-05,
2409
+ "loss": 0.0244,
2410
+ "step": 4000
2411
+ },
2412
+ {
2413
+ "grad_norm": 0.48128196597099304,
2414
+ "learning_rate": 9.393822547411439e-05,
2415
+ "loss": 0.0217,
2416
+ "step": 4010
2417
+ },
2418
+ {
2419
+ "grad_norm": 0.3272818624973297,
2420
+ "learning_rate": 9.389870904035769e-05,
2421
+ "loss": 0.0242,
2422
+ "step": 4020
2423
+ },
2424
+ {
2425
+ "grad_norm": 0.36953118443489075,
2426
+ "learning_rate": 9.385907258920672e-05,
2427
+ "loss": 0.0246,
2428
+ "step": 4030
2429
+ },
2430
+ {
2431
+ "grad_norm": 0.41161492466926575,
2432
+ "learning_rate": 9.381931622902607e-05,
2433
+ "loss": 0.021,
2434
+ "step": 4040
2435
+ },
2436
+ {
2437
+ "grad_norm": 0.4544064998626709,
2438
+ "learning_rate": 9.377944006850807e-05,
2439
+ "loss": 0.0193,
2440
+ "step": 4050
2441
+ },
2442
+ {
2443
+ "grad_norm": 0.47396498918533325,
2444
+ "learning_rate": 9.373944421667265e-05,
2445
+ "loss": 0.0213,
2446
+ "step": 4060
2447
+ },
2448
+ {
2449
+ "grad_norm": 0.4621795117855072,
2450
+ "learning_rate": 9.369932878286691e-05,
2451
+ "loss": 0.0266,
2452
+ "step": 4070
2453
+ },
2454
+ {
2455
+ "grad_norm": 0.5184421539306641,
2456
+ "learning_rate": 9.365909387676494e-05,
2457
+ "loss": 0.0196,
2458
+ "step": 4080
2459
+ },
2460
+ {
2461
+ "grad_norm": 0.4004800319671631,
2462
+ "learning_rate": 9.361873960836744e-05,
2463
+ "loss": 0.0263,
2464
+ "step": 4090
2465
+ },
2466
+ {
2467
+ "grad_norm": 0.3737598657608032,
2468
+ "learning_rate": 9.357826608800142e-05,
2469
+ "loss": 0.0196,
2470
+ "step": 4100
2471
+ },
2472
+ {
2473
+ "grad_norm": 0.4000731110572815,
2474
+ "learning_rate": 9.353767342631994e-05,
2475
+ "loss": 0.0203,
2476
+ "step": 4110
2477
+ },
2478
+ {
2479
+ "grad_norm": 0.3826330006122589,
2480
+ "learning_rate": 9.34969617343018e-05,
2481
+ "loss": 0.0219,
2482
+ "step": 4120
2483
+ },
2484
+ {
2485
+ "grad_norm": 0.5988262891769409,
2486
+ "learning_rate": 9.345613112325122e-05,
2487
+ "loss": 0.0204,
2488
+ "step": 4130
2489
+ },
2490
+ {
2491
+ "grad_norm": 0.4280189275741577,
2492
+ "learning_rate": 9.34151817047975e-05,
2493
+ "loss": 0.0224,
2494
+ "step": 4140
2495
+ },
2496
+ {
2497
+ "grad_norm": 0.3716961145401001,
2498
+ "learning_rate": 9.33741135908948e-05,
2499
+ "loss": 0.0262,
2500
+ "step": 4150
2501
+ },
2502
+ {
2503
+ "grad_norm": 0.4295980930328369,
2504
+ "learning_rate": 9.33329268938218e-05,
2505
+ "loss": 0.0207,
2506
+ "step": 4160
2507
+ },
2508
+ {
2509
+ "grad_norm": 0.425942063331604,
2510
+ "learning_rate": 9.329162172618132e-05,
2511
+ "loss": 0.0238,
2512
+ "step": 4170
2513
+ },
2514
+ {
2515
+ "grad_norm": 0.416522741317749,
2516
+ "learning_rate": 9.325019820090013e-05,
2517
+ "loss": 0.0226,
2518
+ "step": 4180
2519
+ },
2520
+ {
2521
+ "grad_norm": 0.5610533952713013,
2522
+ "learning_rate": 9.320865643122855e-05,
2523
+ "loss": 0.0208,
2524
+ "step": 4190
2525
+ },
2526
+ {
2527
+ "grad_norm": 0.379802942276001,
2528
+ "learning_rate": 9.316699653074023e-05,
2529
+ "loss": 0.022,
2530
+ "step": 4200
2531
+ },
2532
+ {
2533
+ "grad_norm": 0.4576219618320465,
2534
+ "learning_rate": 9.312521861333172e-05,
2535
+ "loss": 0.0166,
2536
+ "step": 4210
2537
+ },
2538
+ {
2539
+ "grad_norm": 0.45310190320014954,
2540
+ "learning_rate": 9.308332279322224e-05,
2541
+ "loss": 0.0242,
2542
+ "step": 4220
2543
+ },
2544
+ {
2545
+ "grad_norm": 0.4080248177051544,
2546
+ "learning_rate": 9.304130918495338e-05,
2547
+ "loss": 0.0224,
2548
+ "step": 4230
2549
+ },
2550
+ {
2551
+ "grad_norm": 0.33399489521980286,
2552
+ "learning_rate": 9.299917790338874e-05,
2553
+ "loss": 0.0187,
2554
+ "step": 4240
2555
+ },
2556
+ {
2557
+ "grad_norm": 0.356057733297348,
2558
+ "learning_rate": 9.295692906371363e-05,
2559
+ "loss": 0.0173,
2560
+ "step": 4250
2561
+ },
2562
+ {
2563
+ "grad_norm": 0.42619287967681885,
2564
+ "learning_rate": 9.291456278143476e-05,
2565
+ "loss": 0.0264,
2566
+ "step": 4260
2567
+ },
2568
+ {
2569
+ "grad_norm": 0.3479536175727844,
2570
+ "learning_rate": 9.287207917237994e-05,
2571
+ "loss": 0.0213,
2572
+ "step": 4270
2573
+ },
2574
+ {
2575
+ "grad_norm": 0.3362795114517212,
2576
+ "learning_rate": 9.282947835269773e-05,
2577
+ "loss": 0.0206,
2578
+ "step": 4280
2579
+ },
2580
+ {
2581
+ "grad_norm": 0.43236204981803894,
2582
+ "learning_rate": 9.278676043885715e-05,
2583
+ "loss": 0.0191,
2584
+ "step": 4290
2585
+ },
2586
+ {
2587
+ "grad_norm": 0.32585880160331726,
2588
+ "learning_rate": 9.274392554764733e-05,
2589
+ "loss": 0.0194,
2590
+ "step": 4300
2591
+ },
2592
+ {
2593
+ "grad_norm": 0.4723697900772095,
2594
+ "learning_rate": 9.270097379617723e-05,
2595
+ "loss": 0.016,
2596
+ "step": 4310
2597
+ },
2598
+ {
2599
+ "grad_norm": 0.42713454365730286,
2600
+ "learning_rate": 9.26579053018753e-05,
2601
+ "loss": 0.0154,
2602
+ "step": 4320
2603
+ },
2604
+ {
2605
+ "grad_norm": 0.33830246329307556,
2606
+ "learning_rate": 9.261472018248918e-05,
2607
+ "loss": 0.0146,
2608
+ "step": 4330
2609
+ },
2610
+ {
2611
+ "grad_norm": 0.4066753387451172,
2612
+ "learning_rate": 9.25714185560853e-05,
2613
+ "loss": 0.0259,
2614
+ "step": 4340
2615
+ },
2616
+ {
2617
+ "grad_norm": 0.448772668838501,
2618
+ "learning_rate": 9.252800054104868e-05,
2619
+ "loss": 0.0187,
2620
+ "step": 4350
2621
+ },
2622
+ {
2623
+ "grad_norm": 0.4219300448894501,
2624
+ "learning_rate": 9.248446625608252e-05,
2625
+ "loss": 0.0208,
2626
+ "step": 4360
2627
+ },
2628
+ {
2629
+ "grad_norm": 0.39920371770858765,
2630
+ "learning_rate": 9.244081582020789e-05,
2631
+ "loss": 0.0175,
2632
+ "step": 4370
2633
+ },
2634
+ {
2635
+ "grad_norm": 0.42131638526916504,
2636
+ "learning_rate": 9.239704935276339e-05,
2637
+ "loss": 0.0182,
2638
+ "step": 4380
2639
+ },
2640
+ {
2641
+ "grad_norm": 0.45648935437202454,
2642
+ "learning_rate": 9.235316697340489e-05,
2643
+ "loss": 0.0158,
2644
+ "step": 4390
2645
+ },
2646
+ {
2647
+ "grad_norm": 0.42188429832458496,
2648
+ "learning_rate": 9.230916880210512e-05,
2649
+ "loss": 0.0183,
2650
+ "step": 4400
2651
+ },
2652
+ {
2653
+ "grad_norm": 0.36581969261169434,
2654
+ "learning_rate": 9.226505495915342e-05,
2655
+ "loss": 0.0147,
2656
+ "step": 4410
2657
+ },
2658
+ {
2659
+ "grad_norm": 0.42502549290657043,
2660
+ "learning_rate": 9.222082556515536e-05,
2661
+ "loss": 0.0198,
2662
+ "step": 4420
2663
+ },
2664
+ {
2665
+ "grad_norm": 0.35229989886283875,
2666
+ "learning_rate": 9.217648074103242e-05,
2667
+ "loss": 0.0153,
2668
+ "step": 4430
2669
+ },
2670
+ {
2671
+ "grad_norm": 0.4085313379764557,
2672
+ "learning_rate": 9.213202060802161e-05,
2673
+ "loss": 0.0192,
2674
+ "step": 4440
2675
+ },
2676
+ {
2677
+ "grad_norm": 0.4650028645992279,
2678
+ "learning_rate": 9.208744528767528e-05,
2679
+ "loss": 0.0173,
2680
+ "step": 4450
2681
+ },
2682
+ {
2683
+ "grad_norm": 0.4048616886138916,
2684
+ "learning_rate": 9.204275490186064e-05,
2685
+ "loss": 0.0204,
2686
+ "step": 4460
2687
+ },
2688
+ {
2689
+ "grad_norm": 0.4178619980812073,
2690
+ "learning_rate": 9.199794957275949e-05,
2691
+ "loss": 0.0204,
2692
+ "step": 4470
2693
+ },
2694
+ {
2695
+ "grad_norm": 0.46256691217422485,
2696
+ "learning_rate": 9.19530294228679e-05,
2697
+ "loss": 0.0177,
2698
+ "step": 4480
2699
+ },
2700
+ {
2701
+ "grad_norm": 0.35352519154548645,
2702
+ "learning_rate": 9.190799457499583e-05,
2703
+ "loss": 0.028,
2704
+ "step": 4490
2705
+ },
2706
+ {
2707
+ "grad_norm": 0.4470050632953644,
2708
+ "learning_rate": 9.186284515226686e-05,
2709
+ "loss": 0.0194,
2710
+ "step": 4500
2711
+ },
2712
+ {
2713
+ "grad_norm": 0.3508913815021515,
2714
+ "learning_rate": 9.181758127811777e-05,
2715
+ "loss": 0.0241,
2716
+ "step": 4510
2717
+ },
2718
+ {
2719
+ "grad_norm": 0.411702424287796,
2720
+ "learning_rate": 9.177220307629825e-05,
2721
+ "loss": 0.0204,
2722
+ "step": 4520
2723
+ },
2724
+ {
2725
+ "grad_norm": 0.4468960762023926,
2726
+ "learning_rate": 9.172671067087059e-05,
2727
+ "loss": 0.0194,
2728
+ "step": 4530
2729
+ },
2730
+ {
2731
+ "grad_norm": 0.4807928204536438,
2732
+ "learning_rate": 9.16811041862093e-05,
2733
+ "loss": 0.0256,
2734
+ "step": 4540
2735
+ },
2736
+ {
2737
+ "grad_norm": 0.39205247163772583,
2738
+ "learning_rate": 9.163538374700076e-05,
2739
+ "loss": 0.0185,
2740
+ "step": 4550
2741
+ },
2742
+ {
2743
+ "grad_norm": 0.44329723715782166,
2744
+ "learning_rate": 9.158954947824287e-05,
2745
+ "loss": 0.0178,
2746
+ "step": 4560
2747
+ },
2748
+ {
2749
+ "grad_norm": 0.47283023595809937,
2750
+ "learning_rate": 9.154360150524482e-05,
2751
+ "loss": 0.0174,
2752
+ "step": 4570
2753
+ },
2754
+ {
2755
+ "grad_norm": 0.38849857449531555,
2756
+ "learning_rate": 9.14975399536266e-05,
2757
+ "loss": 0.0143,
2758
+ "step": 4580
2759
+ },
2760
+ {
2761
+ "grad_norm": 0.3656264543533325,
2762
+ "learning_rate": 9.14513649493187e-05,
2763
+ "loss": 0.0212,
2764
+ "step": 4590
2765
+ },
2766
+ {
2767
+ "grad_norm": 0.4674840271472931,
2768
+ "learning_rate": 9.140507661856187e-05,
2769
+ "loss": 0.0153,
2770
+ "step": 4600
2771
+ },
2772
+ {
2773
+ "grad_norm": 0.4313472509384155,
2774
+ "learning_rate": 9.135867508790661e-05,
2775
+ "loss": 0.0214,
2776
+ "step": 4610
2777
+ },
2778
+ {
2779
+ "grad_norm": 0.3471619486808777,
2780
+ "learning_rate": 9.131216048421291e-05,
2781
+ "loss": 0.0165,
2782
+ "step": 4620
2783
+ },
2784
+ {
2785
+ "grad_norm": 0.4542539715766907,
2786
+ "learning_rate": 9.126553293464998e-05,
2787
+ "loss": 0.0189,
2788
+ "step": 4630
2789
+ },
2790
+ {
2791
+ "grad_norm": 0.47608688473701477,
2792
+ "learning_rate": 9.121879256669572e-05,
2793
+ "loss": 0.017,
2794
+ "step": 4640
2795
+ },
2796
+ {
2797
+ "grad_norm": 0.3959465026855469,
2798
+ "learning_rate": 9.117193950813652e-05,
2799
+ "loss": 0.0164,
2800
+ "step": 4650
2801
+ },
2802
+ {
2803
+ "grad_norm": 0.408431738615036,
2804
+ "learning_rate": 9.112497388706685e-05,
2805
+ "loss": 0.0255,
2806
+ "step": 4660
2807
+ },
2808
+ {
2809
+ "grad_norm": 0.4116475582122803,
2810
+ "learning_rate": 9.10778958318889e-05,
2811
+ "loss": 0.0174,
2812
+ "step": 4670
2813
+ },
2814
+ {
2815
+ "grad_norm": 0.3917919993400574,
2816
+ "learning_rate": 9.103070547131232e-05,
2817
+ "loss": 0.0199,
2818
+ "step": 4680
2819
+ },
2820
+ {
2821
+ "grad_norm": 0.3482106029987335,
2822
+ "learning_rate": 9.098340293435375e-05,
2823
+ "loss": 0.0179,
2824
+ "step": 4690
2825
+ },
2826
+ {
2827
+ "grad_norm": 0.34646838903427124,
2828
+ "learning_rate": 9.093598835033649e-05,
2829
+ "loss": 0.0174,
2830
+ "step": 4700
2831
+ },
2832
+ {
2833
+ "grad_norm": 0.39419376850128174,
2834
+ "learning_rate": 9.088846184889021e-05,
2835
+ "loss": 0.0191,
2836
+ "step": 4710
2837
+ },
2838
+ {
2839
+ "grad_norm": 0.4543268084526062,
2840
+ "learning_rate": 9.084082355995057e-05,
2841
+ "loss": 0.0213,
2842
+ "step": 4720
2843
+ },
2844
+ {
2845
+ "grad_norm": 0.4212946891784668,
2846
+ "learning_rate": 9.079307361375882e-05,
2847
+ "loss": 0.0181,
2848
+ "step": 4730
2849
+ },
2850
+ {
2851
+ "grad_norm": 0.3014923334121704,
2852
+ "learning_rate": 9.074521214086149e-05,
2853
+ "loss": 0.019,
2854
+ "step": 4740
2855
+ },
2856
+ {
2857
+ "grad_norm": 0.36527299880981445,
2858
+ "learning_rate": 9.069723927211001e-05,
2859
+ "loss": 0.0179,
2860
+ "step": 4750
2861
+ },
2862
+ {
2863
+ "grad_norm": 0.3752840757369995,
2864
+ "learning_rate": 9.064915513866037e-05,
2865
+ "loss": 0.0183,
2866
+ "step": 4760
2867
+ },
2868
+ {
2869
+ "grad_norm": 0.42201003432273865,
2870
+ "learning_rate": 9.060095987197279e-05,
2871
+ "loss": 0.0162,
2872
+ "step": 4770
2873
+ },
2874
+ {
2875
+ "grad_norm": 0.3307137191295624,
2876
+ "learning_rate": 9.055265360381126e-05,
2877
+ "loss": 0.0206,
2878
+ "step": 4780
2879
+ },
2880
+ {
2881
+ "grad_norm": 0.33322593569755554,
2882
+ "learning_rate": 9.050423646624326e-05,
2883
+ "loss": 0.016,
2884
+ "step": 4790
2885
+ },
2886
+ {
2887
+ "grad_norm": 0.35324618220329285,
2888
+ "learning_rate": 9.045570859163943e-05,
2889
+ "loss": 0.0194,
2890
+ "step": 4800
2891
+ },
2892
+ {
2893
+ "grad_norm": 0.427572637796402,
2894
+ "learning_rate": 9.04070701126731e-05,
2895
+ "loss": 0.015,
2896
+ "step": 4810
2897
+ },
2898
+ {
2899
+ "grad_norm": 0.3561609983444214,
2900
+ "learning_rate": 9.035832116232001e-05,
2901
+ "loss": 0.0145,
2902
+ "step": 4820
2903
+ },
2904
+ {
2905
+ "grad_norm": 0.37716561555862427,
2906
+ "learning_rate": 9.030946187385796e-05,
2907
+ "loss": 0.016,
2908
+ "step": 4830
2909
+ },
2910
+ {
2911
+ "grad_norm": 0.39859738945961,
2912
+ "learning_rate": 9.026049238086635e-05,
2913
+ "loss": 0.0178,
2914
+ "step": 4840
2915
+ },
2916
+ {
2917
+ "grad_norm": 0.4500395655632019,
2918
+ "learning_rate": 9.021141281722591e-05,
2919
+ "loss": 0.0202,
2920
+ "step": 4850
2921
+ },
2922
+ {
2923
+ "grad_norm": 0.34830138087272644,
2924
+ "learning_rate": 9.01622233171183e-05,
2925
+ "loss": 0.0169,
2926
+ "step": 4860
2927
+ },
2928
+ {
2929
+ "grad_norm": 0.3729107677936554,
2930
+ "learning_rate": 9.011292401502574e-05,
2931
+ "loss": 0.0212,
2932
+ "step": 4870
2933
+ },
2934
+ {
2935
+ "grad_norm": 0.3912448585033417,
2936
+ "learning_rate": 9.006351504573063e-05,
2937
+ "loss": 0.0146,
2938
+ "step": 4880
2939
+ },
2940
+ {
2941
+ "grad_norm": 0.4137353003025055,
2942
+ "learning_rate": 9.001399654431519e-05,
2943
+ "loss": 0.0171,
2944
+ "step": 4890
2945
+ },
2946
+ {
2947
+ "grad_norm": 0.4444160759449005,
2948
+ "learning_rate": 8.996436864616116e-05,
2949
+ "loss": 0.0162,
2950
+ "step": 4900
2951
+ },
2952
+ {
2953
+ "grad_norm": 0.3148241639137268,
2954
+ "learning_rate": 8.991463148694925e-05,
2955
+ "loss": 0.0191,
2956
+ "step": 4910
2957
+ },
2958
+ {
2959
+ "grad_norm": 0.4391416907310486,
2960
+ "learning_rate": 8.986478520265902e-05,
2961
+ "loss": 0.0187,
2962
+ "step": 4920
2963
+ },
2964
+ {
2965
+ "grad_norm": 0.4296688139438629,
2966
+ "learning_rate": 8.981482992956827e-05,
2967
+ "loss": 0.0143,
2968
+ "step": 4930
2969
+ },
2970
+ {
2971
+ "grad_norm": 0.29728299379348755,
2972
+ "learning_rate": 8.976476580425282e-05,
2973
+ "loss": 0.0148,
2974
+ "step": 4940
2975
+ },
2976
+ {
2977
+ "grad_norm": 0.4356195032596588,
2978
+ "learning_rate": 8.971459296358606e-05,
2979
+ "loss": 0.0287,
2980
+ "step": 4950
2981
+ },
2982
+ {
2983
+ "grad_norm": 0.4179481565952301,
2984
+ "learning_rate": 8.966431154473864e-05,
2985
+ "loss": 0.0157,
2986
+ "step": 4960
2987
+ },
2988
+ {
2989
+ "grad_norm": 0.3610477149486542,
2990
+ "learning_rate": 8.961392168517803e-05,
2991
+ "loss": 0.0159,
2992
+ "step": 4970
2993
+ },
2994
+ {
2995
+ "grad_norm": 0.34345686435699463,
2996
+ "learning_rate": 8.956342352266821e-05,
2997
+ "loss": 0.016,
2998
+ "step": 4980
2999
+ },
3000
+ {
3001
+ "grad_norm": 0.3698787987232208,
3002
+ "learning_rate": 8.95128171952692e-05,
3003
+ "loss": 0.0214,
3004
+ "step": 4990
3005
+ },
3006
+ {
3007
+ "grad_norm": 0.327648788690567,
3008
+ "learning_rate": 8.946210284133676e-05,
3009
+ "loss": 0.0173,
3010
+ "step": 5000
3011
+ }
3012
+ ],
3013
+ "logging_steps": 10,
3014
+ "max_steps": 20000,
3015
+ "num_input_tokens_seen": 0,
3016
+ "num_train_epochs": 9223372036854775807,
3017
+ "save_steps": 5000,
3018
+ "stateful_callbacks": {
3019
+ "TrainerControl": {
3020
+ "args": {
3021
+ "should_epoch_stop": false,
3022
+ "should_evaluate": false,
3023
+ "should_log": false,
3024
+ "should_save": true,
3025
+ "should_training_stop": false
3026
+ },
3027
+ "attributes": {}
3028
+ }
3029
+ },
3030
+ "total_flos": 0.0,
3031
+ "train_batch_size": 24,
3032
+ "trial_name": null,
3033
+ "trial_params": null
3034
+ }
checkpoint-5000/wandb_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"project": "finetune-gr00t-n1d6", "run_id": "locomanipulation_tutorial"}
checkpoint-5000/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
546
+ Convert a pseudo tensor to a torch tensor by calling ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.items():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it is no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # a memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model``: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info("Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info("Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
config.json ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_horizon": 50,
3
+ "add_pos_embed": true,
4
+ "apply_sincos_state_encoding": true,
5
+ "architectures": [
6
+ "Gr00tN1d6"
7
+ ],
8
+ "attn_dropout": 0.2,
9
+ "attn_implementation": null,
10
+ "backbone_embedding_dim": 2048,
11
+ "backbone_model_type": "eagle",
12
+ "backbone_trainable_params_fp32": true,
13
+ "collator_overwrite_image_inputs": false,
14
+ "color_jitter_params": {
15
+ "brightness": 0.1,
16
+ "contrast": 0.1,
17
+ "hue": 0.1,
18
+ "saturation": 0.1
19
+ },
20
+ "crop_fraction": 0.95,
21
+ "diffusion_model_cfg": {
22
+ "attention_head_dim": 48,
23
+ "dropout": 0.2,
24
+ "final_dropout": true,
25
+ "interleave_self_attention": true,
26
+ "norm_type": "ada_norm",
27
+ "num_attention_heads": 32,
28
+ "num_layers": 32,
29
+ "output_dim": 1024,
30
+ "positional_embeddings": null
31
+ },
32
+ "eagle_collator": true,
33
+ "formalize_language": true,
34
+ "gemma_collator": false,
35
+ "hidden_size": 1024,
36
+ "image_crop_size": null,
37
+ "image_target_size": null,
38
+ "input_embedding_dim": 1536,
39
+ "load_bf16": true,
40
+ "max_action_dim": 128,
41
+ "max_num_embodiments": 32,
42
+ "max_seq_len": 1024,
43
+ "max_state_dim": 128,
44
+ "model_dtype": "bfloat16",
45
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
46
+ "model_type": "Gr00tN1d6",
47
+ "noise_beta_alpha": 1.5,
48
+ "noise_beta_beta": 1.0,
49
+ "noise_s": 0.999,
50
+ "num_inference_timesteps": 4,
51
+ "num_timestep_buckets": 1000,
52
+ "random_rotation_angle": null,
53
+ "reproject_vision": false,
54
+ "select_layer": 16,
55
+ "shortest_image_edge": 256,
56
+ "state_dropout_prob": 0.0,
57
+ "torch_dtype": "bfloat16",
58
+ "transformers_version": "4.51.3",
59
+ "tune_diffusion_model": true,
60
+ "tune_llm": false,
61
+ "tune_projector": true,
62
+ "tune_top_llm_layers": 4,
63
+ "tune_visual": true,
64
+ "tune_vlln": true,
65
+ "use_albumentations_transforms": true,
66
+ "use_alternate_vl_dit": true,
67
+ "use_flash_attention": true,
68
+ "use_relative_action": true,
69
+ "use_vlln": true
70
+ }
experiment_cfg/conf.yaml ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ load_config_path: null
2
+ model:
3
+ model_type: Gr00tN1d6
4
+ model_dtype: bfloat16
5
+ model_name: nvidia/Eagle-Block2A-2B-v2
6
+ backbone_model_type: eagle
7
+ model_revision: null
8
+ tune_top_llm_layers: 4
9
+ backbone_embedding_dim: 2048
10
+ tune_llm: false
11
+ tune_visual: true
12
+ select_layer: 16
13
+ reproject_vision: false
14
+ use_flash_attention: true
15
+ load_bf16: false
16
+ collator_overwrite_image_inputs: false
17
+ eagle_collator: true
18
+ backbone_trainable_params_fp32: true
19
+ image_crop_size: null
20
+ image_target_size: null
21
+ shortest_image_edge: 256
22
+ crop_fraction: 0.95
23
+ random_rotation_angle: null
24
+ color_jitter_params:
25
+ brightness: 0.3
26
+ contrast: 0.4
27
+ saturation: 0.5
28
+ hue: 0.08
29
+ use_albumentations_transforms: true
30
+ formalize_language: true
31
+ apply_sincos_state_encoding: false
32
+ use_relative_action: true
33
+ max_state_dim: 29
34
+ max_action_dim: 29
35
+ action_horizon: 16
36
+ hidden_size: 1024
37
+ input_embedding_dim: 1536
38
+ add_pos_embed: true
39
+ attn_dropout: 0.2
40
+ use_vlln: true
41
+ max_seq_len: 1024
42
+ use_alternate_vl_dit: true
43
+ attend_text_every_n_blocks: 2
44
+ diffusion_model_cfg:
45
+ positional_embeddings: null
46
+ num_layers: 32
47
+ num_attention_heads: 32
48
+ attention_head_dim: 48
49
+ norm_type: ada_norm
50
+ dropout: 0.2
51
+ final_dropout: true
52
+ output_dim: 1024
53
+ interleave_self_attention: true
54
+ num_inference_timesteps: 4
55
+ noise_beta_alpha: 1.5
56
+ noise_beta_beta: 1.0
57
+ noise_s: 0.999
58
+ num_timestep_buckets: 1000
59
+ tune_projector: true
60
+ tune_diffusion_model: true
61
+ tune_vlln: true
62
+ state_dropout_prob: 0.0
63
+ state_additive_noise_scale: 0.0
64
+ max_num_embodiments: 32
65
+ data:
66
+ datasets:
67
+ - dataset_paths:
68
+ - /datasets/isaaclab_arena/locomanipulation_tutorial/arena_g1_loco_manipulation_dataset_generated/lerobot
69
+ embodiment_tag: new_embodiment
70
+ mix_ratio: 1.0
71
+ dataset_type: physical_embodiment
72
+ val_dataset_path: null
73
+ modality_configs:
74
+ new_embodiment:
75
+ video:
76
+ delta_indices:
77
+ - 0
78
+ modality_keys:
79
+ - ego_view
80
+ sin_cos_embedding_keys: null
81
+ mean_std_embedding_keys: null
82
+ action_configs: null
83
+ state:
84
+ delta_indices:
85
+ - 0
86
+ modality_keys:
87
+ - left_arm
88
+ - right_arm
89
+ - left_hand
90
+ - right_hand
91
+ - waist
92
+ sin_cos_embedding_keys: null
93
+ mean_std_embedding_keys: null
94
+ action_configs: null
95
+ action:
96
+ delta_indices:
97
+ - 0
98
+ - 1
99
+ - 2
100
+ - 3
101
+ - 4
102
+ - 5
103
+ - 6
104
+ - 7
105
+ - 8
106
+ - 9
107
+ - 10
108
+ - 11
109
+ - 12
110
+ - 13
111
+ - 14
112
+ - 15
113
+ - 16
114
+ - 17
115
+ - 18
116
+ - 19
117
+ - 20
118
+ - 21
119
+ - 22
120
+ - 23
121
+ - 24
122
+ - 25
123
+ - 26
124
+ - 27
125
+ - 28
126
+ - 29
127
+ - 30
128
+ - 31
129
+ - 32
130
+ - 33
131
+ - 34
132
+ - 35
133
+ - 36
134
+ - 37
135
+ - 38
136
+ - 39
137
+ - 40
138
+ - 41
139
+ - 42
140
+ - 43
141
+ - 44
142
+ - 45
143
+ - 46
144
+ - 47
145
+ - 48
146
+ - 49
147
+ modality_keys:
148
+ - left_arm
149
+ - right_arm
150
+ - left_hand
151
+ - right_hand
152
+ - waist
153
+ - base_height_command
154
+ - navigate_command
155
+ sin_cos_embedding_keys: null
156
+ mean_std_embedding_keys: null
157
+ action_configs:
158
+ - rep: ABSOLUTE
159
+ type: NON_EEF
160
+ format: DEFAULT
161
+ state_key: null
162
+ - rep: ABSOLUTE
163
+ type: NON_EEF
164
+ format: DEFAULT
165
+ state_key: null
166
+ - rep: ABSOLUTE
167
+ type: NON_EEF
168
+ format: DEFAULT
169
+ state_key: null
170
+ - rep: ABSOLUTE
171
+ type: NON_EEF
172
+ format: DEFAULT
173
+ state_key: null
174
+ - rep: ABSOLUTE
175
+ type: NON_EEF
176
+ format: DEFAULT
177
+ state_key: null
178
+ - rep: ABSOLUTE
179
+ type: NON_EEF
180
+ format: DEFAULT
181
+ state_key: null
182
+ - rep: ABSOLUTE
183
+ type: NON_EEF
184
+ format: DEFAULT
185
+ state_key: null
186
+ language:
187
+ delta_indices:
188
+ - 0
189
+ modality_keys:
190
+ - annotation.human.task_description
191
+ sin_cos_embedding_keys: null
192
+ mean_std_embedding_keys: null
193
+ action_configs: null
194
+ download_cache: false
195
+ shard_size: 1024
196
+ episode_sampling_rate: 0.1
197
+ num_shards_per_epoch: 100000
198
+ override_pretraining_statistics: false
199
+ mode: single_turn
200
+ random_chop: 0.0
201
+ mock_dataset_mode: false
202
+ shuffle: true
203
+ seed: 42
204
+ multiprocessing_context: fork
205
+ allow_padding: false
206
+ subsample_ratio: 1.0
207
+ image_crop_size:
208
+ - 244
209
+ - 244
210
+ image_target_size:
211
+ - 224
212
+ - 224
213
+ video_backend: torchcodec
214
+ training:
215
+ output_dir: /models/isaaclab_arena/locomanipulation_tutorial
216
+ experiment_name: null
217
+ max_steps: 20000
218
+ global_batch_size: 192
219
+ batch_size: null
220
+ gradient_accumulation_steps: 1
221
+ learning_rate: 0.0001
222
+ lr_scheduler_type: cosine
223
+ weight_decay: 1.0e-05
224
+ warmup_ratio: 0.05
225
+ warmup_steps: 0
226
+ max_grad_norm: 1.0
227
+ optim: adamw_torch
228
+ start_from_checkpoint: nvidia/GR00T-N1.6-3B
229
+ tf32: true
230
+ fp16: false
231
+ bf16: true
232
+ eval_bf16: true
233
+ logging_steps: 10
234
+ save_steps: 5000
235
+ save_total_limit: 5
236
+ save_vl_model: false
237
+ upload_checkpoints: false
238
+ upload_every: 1000
239
+ upload_last_n_checkpoints: 5
240
+ max_concurrent_uploads: 2
241
+ eval_strategy: 'no'
242
+ eval_steps: 500
243
+ eval_set_split_ratio: 0.1
244
+ eval_batch_size: 2
245
+ save_best_eval_metric_name: ''
246
+ save_best_eval_metric_greater_is_better: true
247
+ deepspeed_stage: 2
248
+ gradient_checkpointing: false
249
+ transformers_trust_remote_code: true
250
+ transformers_local_files_only: false
251
+ transformers_cache_dir: null
252
+ transformers_access_token: null
253
+ use_ddp: false
254
+ ddp_bucket_cap_mb: 100
255
+ num_gpus: 8
256
+ dataloader_num_workers: 16
257
+ remove_unused_columns: false
258
+ use_wandb: false
259
+ wandb_project: finetune-gr00t-n1d6
260
+ enable_profiling: false
261
+ max_retries: 3
262
+ assert_loss_less_than: null
263
+ add_rl_callback: false
264
+ enable_open_loop_eval: false
265
+ open_loop_eval_traj_ids:
266
+ - 0
267
+ open_loop_eval_steps_per_traj: 100
268
+ open_loop_eval_plot_indices: null
269
+ max_steps: 20000
270
+ save_steps: 5000
experiment_cfg/config.yaml ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ !!python/object:gr00t.configs.base_config.Config
2
+ data: !!python/object:gr00t.configs.data.data_config.DataConfig
3
+ allow_padding: false
4
+ datasets:
5
+ - !!python/object:gr00t.configs.data.data_config.SingleDatasetConfig
6
+ dataset_paths:
7
+ - /datasets/isaaclab_arena/locomanipulation_tutorial/arena_g1_loco_manipulation_dataset_generated/lerobot
8
+ dataset_type: physical_embodiment
9
+ embodiment_tag: new_embodiment
10
+ mix_ratio: 1.0
11
+ val_dataset_path: null
12
+ download_cache: false
13
+ episode_sampling_rate: 0.1
14
+ image_crop_size:
15
+ - 244
16
+ - 244
17
+ image_target_size:
18
+ - 224
19
+ - 224
20
+ mock_dataset_mode: false
21
+ modality_configs:
22
+ new_embodiment:
23
+ action: !!python/object:gr00t.data.types.ModalityConfig
24
+ action_configs:
25
+ - !!python/object:gr00t.data.types.ActionConfig
26
+ format: &id001 !!python/object/apply:gr00t.data.types.ActionFormat
27
+ - default
28
+ rep: &id002 !!python/object/apply:gr00t.data.types.ActionRepresentation
29
+ - absolute
30
+ state_key: null
31
+ type: &id003 !!python/object/apply:gr00t.data.types.ActionType
32
+ - non_eef
33
+ - !!python/object:gr00t.data.types.ActionConfig
34
+ format: *id001
35
+ rep: *id002
36
+ state_key: null
37
+ type: *id003
38
+ - !!python/object:gr00t.data.types.ActionConfig
39
+ format: *id001
40
+ rep: *id002
41
+ state_key: null
42
+ type: *id003
43
+ - !!python/object:gr00t.data.types.ActionConfig
44
+ format: *id001
45
+ rep: *id002
46
+ state_key: null
47
+ type: *id003
48
+ - !!python/object:gr00t.data.types.ActionConfig
49
+ format: *id001
50
+ rep: *id002
51
+ state_key: null
52
+ type: *id003
53
+ - !!python/object:gr00t.data.types.ActionConfig
54
+ format: *id001
55
+ rep: *id002
56
+ state_key: null
57
+ type: *id003
58
+ - !!python/object:gr00t.data.types.ActionConfig
59
+ format: *id001
60
+ rep: *id002
61
+ state_key: null
62
+ type: *id003
63
+ delta_indices:
64
+ - 0
65
+ - 1
66
+ - 2
67
+ - 3
68
+ - 4
69
+ - 5
70
+ - 6
71
+ - 7
72
+ - 8
73
+ - 9
74
+ - 10
75
+ - 11
76
+ - 12
77
+ - 13
78
+ - 14
79
+ - 15
80
+ - 16
81
+ - 17
82
+ - 18
83
+ - 19
84
+ - 20
85
+ - 21
86
+ - 22
87
+ - 23
88
+ - 24
89
+ - 25
90
+ - 26
91
+ - 27
92
+ - 28
93
+ - 29
94
+ - 30
95
+ - 31
96
+ - 32
97
+ - 33
98
+ - 34
99
+ - 35
100
+ - 36
101
+ - 37
102
+ - 38
103
+ - 39
104
+ - 40
105
+ - 41
106
+ - 42
107
+ - 43
108
+ - 44
109
+ - 45
110
+ - 46
111
+ - 47
112
+ - 48
113
+ - 49
114
+ mean_std_embedding_keys: null
115
+ modality_keys:
116
+ - left_arm
117
+ - right_arm
118
+ - left_hand
119
+ - right_hand
120
+ - waist
121
+ - base_height_command
122
+ - navigate_command
123
+ sin_cos_embedding_keys: null
124
+ language: !!python/object:gr00t.data.types.ModalityConfig
125
+ action_configs: null
126
+ delta_indices:
127
+ - 0
128
+ mean_std_embedding_keys: null
129
+ modality_keys:
130
+ - annotation.human.task_description
131
+ sin_cos_embedding_keys: null
132
+ state: !!python/object:gr00t.data.types.ModalityConfig
133
+ action_configs: null
134
+ delta_indices:
135
+ - 0
136
+ mean_std_embedding_keys: null
137
+ modality_keys:
138
+ - left_arm
139
+ - right_arm
140
+ - left_hand
141
+ - right_hand
142
+ - waist
143
+ sin_cos_embedding_keys: null
144
+ video: !!python/object:gr00t.data.types.ModalityConfig
145
+ action_configs: null
146
+ delta_indices:
147
+ - 0
148
+ mean_std_embedding_keys: null
149
+ modality_keys:
150
+ - ego_view
151
+ sin_cos_embedding_keys: null
152
+ mode: single_turn
153
+ multiprocessing_context: fork
154
+ num_shards_per_epoch: 100000
155
+ override_pretraining_statistics: false
156
+ random_chop: 0.0
157
+ seed: 42
158
+ shard_size: 1024
159
+ shuffle: true
160
+ subsample_ratio: 1.0
161
+ video_backend: torchcodec
162
+ load_config_path: null
163
+ model: !!python/object:gr00t.configs.model.gr00t_n1d6.Gr00tN1d6Config
164
+ _attn_implementation_autoset: false
165
+ _attn_implementation_internal: null
166
+ _commit_hash: null
167
+ _name_or_path: ''
168
+ add_cross_attention: false
169
+ architectures: null
170
+ backbone_model_type: eagle
171
+ backbone_trainable_params_fp32: true
172
+ bad_words_ids: null
173
+ begin_suppress_tokens: null
174
+ bos_token_id: null
175
+ chunk_size_feed_forward: 0
176
+ color_jitter_params:
177
+ brightness: 0.3
178
+ contrast: 0.4
179
+ hue: 0.08
180
+ saturation: 0.5
181
+ cross_attention_hidden_size: null
182
+ decoder_start_token_id: null
183
+ diffusion_model_cfg:
184
+ attention_head_dim: 48
185
+ dropout: 0.2
186
+ final_dropout: true
187
+ interleave_self_attention: true
188
+ norm_type: ada_norm
189
+ num_attention_heads: 32
190
+ num_layers: 32
191
+ output_dim: 1024
192
+ positional_embeddings: null
193
+ diversity_penalty: 0.0
194
+ do_sample: false
195
+ eagle_collator: true
196
+ early_stopping: false
197
+ encoder_no_repeat_ngram_size: 0
198
+ eos_token_id: null
199
+ exponential_decay_length_penalty: null
200
+ finetuning_task: null
201
+ forced_bos_token_id: null
202
+ forced_eos_token_id: null
203
+ id2label:
204
+ 0: LABEL_0
205
+ 1: LABEL_1
206
+ is_decoder: false
207
+ is_encoder_decoder: false
208
+ label2id:
209
+ LABEL_0: 0
210
+ LABEL_1: 1
211
+ length_penalty: 1.0
212
+ load_bf16: false
213
+ max_length: 20
214
+ min_length: 0
215
+ model_name: nvidia/Eagle-Block2A-2B-v2
216
+ no_repeat_ngram_size: 0
217
+ num_beam_groups: 1
218
+ num_beams: 1
219
+ num_return_sequences: 1
220
+ output_attentions: false
221
+ output_hidden_states: false
222
+ output_scores: false
223
+ pad_token_id: null
224
+ prefix: null
225
+ problem_type: null
226
+ pruned_heads: {}
227
+ random_rotation_angle: null
228
+ remove_invalid_values: false
229
+ repetition_penalty: 1.0
230
+ reproject_vision: false
231
+ return_dict: true
232
+ return_dict_in_generate: false
233
+ sep_token_id: null
234
+ state_dropout_prob: 0.0
235
+ suppress_tokens: null
236
+ task_specific_params: null
237
+ temperature: 1.0
238
+ tf_legacy_loss: false
239
+ tie_encoder_decoder: false
240
+ tie_word_embeddings: true
241
+ tokenizer_class: null
242
+ top_k: 50
243
+ top_p: 1.0
244
+ torch_dtype: null
245
+ torchscript: false
246
+ transformers_version: null
247
+ tune_diffusion_model: true
248
+ tune_llm: false
249
+ tune_projector: true
250
+ tune_visual: true
251
+ typical_p: 1.0
252
+ use_bfloat16: false
253
+ use_relative_action: true
254
+ training: !!python/object:gr00t.configs.training.training_config.TrainingConfig
255
+ add_rl_callback: false
256
+ assert_loss_less_than: null
257
+ batch_size: null
258
+ bf16: true
259
+ dataloader_num_workers: 16
260
+ ddp_bucket_cap_mb: 100
261
+ deepspeed_stage: 2
262
+ enable_open_loop_eval: false
263
+ enable_profiling: false
264
+ eval_batch_size: 2
265
+ eval_bf16: true
266
+ eval_set_split_ratio: 0.1
267
+ eval_steps: 500
268
+ eval_strategy: 'no'
269
+ experiment_name: null
270
+ fp16: false
271
+ global_batch_size: 192
272
+ gradient_accumulation_steps: 1
273
+ gradient_checkpointing: false
274
+ learning_rate: 0.0001
275
+ logging_steps: 10
276
+ lr_scheduler_type: cosine
277
+ max_concurrent_uploads: 2
278
+ max_grad_norm: 1.0
279
+ max_retries: 3
280
+ max_steps: 20000
281
+ num_gpus: 8
282
+ open_loop_eval_plot_indices: null
283
+ open_loop_eval_steps_per_traj: 100
284
+ open_loop_eval_traj_ids:
285
+ - 0
286
+ optim: adamw_torch
287
+ output_dir: /models/isaaclab_arena/locomanipulation_tutorial
288
+ remove_unused_columns: false
289
+ save_best_eval_metric_greater_is_better: true
290
+ save_best_eval_metric_name: ''
291
+ save_steps: 5000
292
+ save_total_limit: 5
293
+ save_vl_model: false
294
+ start_from_checkpoint: nvidia/GR00T-N1.6-3B
295
+ tf32: true
296
+ transformers_access_token: null
297
+ transformers_cache_dir: null
298
+ transformers_local_files_only: false
299
+ transformers_trust_remote_code: true
300
+ upload_checkpoints: false
301
+ upload_every: 1000
302
+ upload_last_n_checkpoints: 5
303
+ use_ddp: false
304
+ use_wandb: false
305
+ wandb_project: finetune-gr00t-n1d6
306
+ warmup_ratio: 0.05
307
+ warmup_steps: 0
308
+ weight_decay: 1.0e-05
experiment_cfg/dataset_statistics.json ADDED
@@ -0,0 +1,573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "new_embodiment": {
3
+ "state": {
4
+ "left_arm": {
5
+ "min": [
6
+ -1.2616037130355835,
7
+ -0.29025015234947205,
8
+ -0.22703997790813446,
9
+ -0.3353549540042877,
10
+ -0.0829518586397171,
11
+ -0.8195276260375977,
12
+ -0.2688920795917511
13
+ ],
14
+ "max": [
15
+ 0.15299034118652344,
16
+ 0.4194548726081848,
17
+ 0.304278701543808,
18
+ 1.4247486591339111,
19
+ 0.751840353012085,
20
+ 0.6736590266227722,
21
+ 0.569625973701477
22
+ ],
23
+ "mean": [
24
+ -0.6218094229698181,
25
+ -0.03578367084264755,
26
+ 0.05471671372652054,
27
+ 0.3273524045944214,
28
+ 0.16905353963375092,
29
+ 0.1931331604719162,
30
+ 0.0418560616672039
31
+ ],
32
+ "std": [
33
+ 0.2542016804218292,
34
+ 0.08585234731435776,
35
+ 0.05442973971366882,
36
+ 0.3563520908355713,
37
+ 0.10547080636024475,
38
+ 0.21155740320682526,
39
+ 0.0815652459859848
40
+ ],
41
+ "q01": [
42
+ -1.0867726147174834,
43
+ -0.23316791355609895,
44
+ -0.06077688504010439,
45
+ -0.2531130000948906,
46
+ -0.025190447550266983,
47
+ -0.41234332919120786,
48
+ -0.14684838354587554
49
+ ],
50
+ "q99": [
51
+ 0.02166599538177228,
52
+ 0.16592777222394936,
53
+ 0.19437864869832985,
54
+ 1.3526465594768522,
55
+ 0.47515065073966933,
56
+ 0.6158077389001846,
57
+ 0.267849366366863
58
+ ]
59
+ },
60
+ "right_arm": {
61
+ "min": [
62
+ -0.9889344573020935,
63
+ -0.7240632772445679,
64
+ -0.4150152802467346,
65
+ -0.2197991907596588,
66
+ -0.44296473264694214,
67
+ -0.9651272296905518,
68
+ -0.4595109820365906
69
+ ],
70
+ "max": [
71
+ 0.15951132774353027,
72
+ 0.21149154007434845,
73
+ 0.13221219182014465,
74
+ 1.4304473400115967,
75
+ 0.6581774950027466,
76
+ 0.33145904541015625,
77
+ 0.42284855246543884
78
+ ],
79
+ "mean": [
80
+ -0.5138179659843445,
81
+ -0.07899317145347595,
82
+ -0.1299561709165573,
83
+ 0.40922680497169495,
84
+ 0.027388907968997955,
85
+ -0.0835803970694542,
86
+ 0.024336807429790497
87
+ ],
88
+ "std": [
89
+ 0.1910795420408249,
90
+ 0.10697221755981445,
91
+ 0.0633271336555481,
92
+ 0.2594990134239197,
93
+ 0.14704135060310364,
94
+ 0.15591612458229065,
95
+ 0.06830708682537079
96
+ ],
97
+ "q01": [
98
+ -0.83366958796978,
99
+ -0.38898577094078063,
100
+ -0.27746869176626204,
101
+ -0.12615955173969268,
102
+ -0.2731088250875473,
103
+ -0.6371771156787872,
104
+ -0.16048517003655433
105
+ ],
106
+ "q99": [
107
+ 0.019438467640429113,
108
+ 0.13264653384685496,
109
+ 0.03749443646520371,
110
+ 1.3000927805900555,
111
+ 0.3483726784586904,
112
+ 0.12948824167251569,
113
+ 0.168773318082094
114
+ ]
115
+ },
116
+ "left_hand": {
117
+ "min": [
118
+ -0.008645662106573582,
119
+ -0.0016571161104366183,
120
+ -0.008173327893018723,
121
+ -0.0033370573073625565,
122
+ -0.049815986305475235,
123
+ -0.13737092912197113,
124
+ -8.590802735852776e-09
125
+ ],
126
+ "max": [
127
+ 8.85741064848844e-06,
128
+ 1.4383874713530531e-06,
129
+ 7.31344407540746e-05,
130
+ 4.420346158440225e-05,
131
+ 0.026730380952358246,
132
+ 0.06749135255813599,
133
+ 0.004176338668912649
134
+ ],
135
+ "mean": [
136
+ -0.00045161443995311856,
137
+ -9.045441402122378e-05,
138
+ -0.0008751734858378768,
139
+ -0.00010305152682121843,
140
+ -0.0026190115604549646,
141
+ -0.0007728625205345452,
142
+ 3.4298220271011814e-05
143
+ ],
144
+ "std": [
145
+ 0.0010219421237707138,
146
+ 0.00011942393030039966,
147
+ 0.0011946671875193715,
148
+ 0.00021070965158287436,
149
+ 0.004766007885336876,
150
+ 0.008314870297908783,
151
+ 0.00020773601136170328
152
+ ],
153
+ "q01": [
154
+ -0.004614621866494417,
155
+ -0.0005385997559642419,
156
+ -0.004787646210752427,
157
+ -0.0012936698796693236,
158
+ -0.01875622048974037,
159
+ -0.03178232274949551,
160
+ -2.9993839079089924e-10
161
+ ],
162
+ "q99": [
163
+ 1.4417540605826582e-09,
164
+ -5.172329953229189e-10,
165
+ -2.493637962786175e-10,
166
+ -6.717705641756689e-10,
167
+ 0.008347299136221403,
168
+ 0.012830186681821834,
169
+ 0.0014548563922289215
170
+ ]
171
+ },
172
+ "right_hand": {
173
+ "min": [
174
+ -1.5373115047623287e-07,
175
+ -2.7022052151437492e-08,
176
+ -2.0592709915945306e-05,
177
+ -7.066118541843025e-06,
178
+ -0.03601590916514397,
179
+ -0.5857902765274048,
180
+ -0.3214021623134613
181
+ ],
182
+ "max": [
183
+ 0.006290650460869074,
184
+ 0.001731343101710081,
185
+ 0.017454728484153748,
186
+ 0.012643150985240936,
187
+ 0.09934248775243759,
188
+ 0.0994623526930809,
189
+ 3.1769886277288606e-08
190
+ ],
191
+ "mean": [
192
+ 0.00025306272436864674,
193
+ 5.4000069212634116e-05,
194
+ 0.0003351480991113931,
195
+ 0.0008108046022243798,
196
+ 0.0006079890299588442,
197
+ -0.006738435477018356,
198
+ -0.00452095502987504
199
+ ],
200
+ "std": [
201
+ 0.0006930792587809265,
202
+ 0.00016116801998578012,
203
+ 0.0007848768145777285,
204
+ 0.0014818455092608929,
205
+ 0.009566166438162327,
206
+ 0.05241963639855385,
207
+ 0.030341269448399544
208
+ ],
209
+ "q01": [
210
+ -1.1203826366656955e-09,
211
+ 5.471793157463268e-10,
212
+ -7.516792688289087e-10,
213
+ 1.7157600895600922e-10,
214
+ -0.008333299728110432,
215
+ -0.3553843080997467,
216
+ -0.20837910920381547
217
+ ],
218
+ "q99": [
219
+ 0.0038171554915606976,
220
+ 0.0008218895673053339,
221
+ 0.003914117161184549,
222
+ 0.005107918474823237,
223
+ 0.061319448240101194,
224
+ 0.009818258183076798,
225
+ 3.1323699190011206e-10
226
+ ]
227
+ },
228
+ "waist": {
229
+ "min": [
230
+ -0.04632357507944107,
231
+ -0.11110502481460571,
232
+ -0.036814406514167786
233
+ ],
234
+ "max": [
235
+ 0.0633544921875,
236
+ 0.11162503063678741,
237
+ 0.1282370686531067
238
+ ],
239
+ "mean": [
240
+ 0.002279821317642927,
241
+ -0.0016866918886080384,
242
+ 0.05629865825176239
243
+ ],
244
+ "std": [
245
+ 0.019741930067539215,
246
+ 0.04374425858259201,
247
+ 0.023172633722424507
248
+ ],
249
+ "q01": [
250
+ -0.039197818748652934,
251
+ -0.09254500381648541,
252
+ -0.020507800113409757
253
+ ],
254
+ "q99": [
255
+ 0.054476964659988844,
256
+ 0.09499521441757679,
257
+ 0.10415777899324889
258
+ ]
259
+ }
260
+ },
261
+ "action": {
262
+ "left_arm": {
263
+ "min": [
264
+ -1.348067283630371,
265
+ -0.3527751564979553,
266
+ -0.3787360191345215,
267
+ -0.625663697719574,
268
+ -0.09716995060443878,
269
+ -0.9718959331512451,
270
+ -0.41488397121429443
271
+ ],
272
+ "max": [
273
+ 0.1336316466331482,
274
+ 0.4716266393661499,
275
+ 0.30831149220466614,
276
+ 1.4016180038452148,
277
+ 0.9397326111793518,
278
+ 0.6476842761039734,
279
+ 0.8313083648681641
280
+ ],
281
+ "mean": [
282
+ -0.6952570080757141,
283
+ -0.0709061548113823,
284
+ -0.04288463667035103,
285
+ 0.2694568634033203,
286
+ 0.1649714857339859,
287
+ 0.13536368310451508,
288
+ -0.02554020844399929
289
+ ],
290
+ "std": [
291
+ 0.26363858580589294,
292
+ 0.10477105528116226,
293
+ 0.07000378519296646,
294
+ 0.3648890554904938,
295
+ 0.11654239892959595,
296
+ 0.2099701166152954,
297
+ 0.08394794911146164
298
+ ],
299
+ "q01": [
300
+ -1.1805148243904113,
301
+ -0.308816134929657,
302
+ -0.17785422429442405,
303
+ -0.3138654500246048,
304
+ -0.05110809002071619,
305
+ -0.4920081451535225,
306
+ -0.1742709159851074
307
+ ],
308
+ "q99": [
309
+ -0.008620778424665838,
310
+ 0.20248875990509888,
311
+ 0.17697372585535032,
312
+ 1.284248530864715,
313
+ 0.522044214606285,
314
+ 0.5478375405073164,
315
+ 0.24634651243686412
316
+ ]
317
+ },
318
+ "right_arm": {
319
+ "min": [
320
+ -1.0777442455291748,
321
+ -0.7950155735015869,
322
+ -0.4215357005596161,
323
+ -0.33741918206214905,
324
+ -0.5877293348312378,
325
+ -1.0788743495941162,
326
+ -0.573306679725647
327
+ ],
328
+ "max": [
329
+ 0.14458219707012177,
330
+ 0.31825390458106995,
331
+ 0.3697803318500519,
332
+ 1.4193015098571777,
333
+ 0.6486993432044983,
334
+ 0.28742435574531555,
335
+ 0.49852707982063293
336
+ ],
337
+ "mean": [
338
+ -0.604250967502594,
339
+ -0.0556945763528347,
340
+ -0.03765946254134178,
341
+ 0.30660828948020935,
342
+ 0.01742653176188469,
343
+ -0.16916987299919128,
344
+ 0.09518744796514511
345
+ ],
346
+ "std": [
347
+ 0.20923613011837006,
348
+ 0.12663093209266663,
349
+ 0.08735905587673187,
350
+ 0.2593192756175995,
351
+ 0.15945474803447723,
352
+ 0.16604292392730713,
353
+ 0.07976584881544113
354
+ ],
355
+ "q01": [
356
+ -0.9175809919834137,
357
+ -0.5007677406072617,
358
+ -0.21304122656583785,
359
+ -0.21431435346603395,
360
+ -0.2938103020191193,
361
+ -0.7407654404640198,
362
+ -0.1693093843758106
363
+ ],
364
+ "q99": [
365
+ -0.011969150230289034,
366
+ 0.1981081753969192,
367
+ 0.14730184450745581,
368
+ 1.2670192122459407,
369
+ 0.3571772933006279,
370
+ 0.07727374359965306,
371
+ 0.24925321042537663
372
+ ]
373
+ },
374
+ "left_hand": {
375
+ "min": [
376
+ 0.0,
377
+ 0.0,
378
+ 0.0,
379
+ 0.0,
380
+ 0.0,
381
+ 0.0,
382
+ 0.0
383
+ ],
384
+ "max": [
385
+ 0.0,
386
+ 0.0,
387
+ 0.0,
388
+ 0.0,
389
+ 0.0,
390
+ 0.0,
391
+ 0.0
392
+ ],
393
+ "mean": [
394
+ 0.0,
395
+ 0.0,
396
+ 0.0,
397
+ 0.0,
398
+ 0.0,
399
+ 0.0,
400
+ 0.0
401
+ ],
402
+ "std": [
403
+ 0.0,
404
+ 0.0,
405
+ 0.0,
406
+ 0.0,
407
+ 0.0,
408
+ 0.0,
409
+ 0.0
410
+ ],
411
+ "q01": [
412
+ 0.0,
413
+ 0.0,
414
+ 0.0,
415
+ 0.0,
416
+ 0.0,
417
+ 0.0,
418
+ 0.0
419
+ ],
420
+ "q99": [
421
+ 0.0,
422
+ 0.0,
423
+ 0.0,
424
+ 0.0,
425
+ 0.0,
426
+ 0.0,
427
+ 0.0
428
+ ]
429
+ },
430
+ "right_hand": {
431
+ "min": [
432
+ -0.0,
433
+ -0.0,
434
+ -0.0,
435
+ -0.0,
436
+ -0.0,
437
+ -0.0,
438
+ -0.0
439
+ ],
440
+ "max": [
441
+ -0.0,
442
+ -0.0,
443
+ -0.0,
444
+ -0.0,
445
+ -0.0,
446
+ -0.0,
447
+ -0.0
448
+ ],
449
+ "mean": [
450
+ 0.0,
451
+ 0.0,
452
+ 0.0,
453
+ 0.0,
454
+ 0.0,
455
+ 0.0,
456
+ 0.0
457
+ ],
458
+ "std": [
459
+ 0.0,
460
+ 0.0,
461
+ 0.0,
462
+ 0.0,
463
+ 0.0,
464
+ 0.0,
465
+ 0.0
466
+ ],
467
+ "q01": [
468
+ 0.0,
469
+ 0.0,
470
+ 0.0,
471
+ 0.0,
472
+ 0.0,
473
+ 0.0,
474
+ 0.0
475
+ ],
476
+ "q99": [
477
+ -0.0,
478
+ -0.0,
479
+ -0.0,
480
+ -0.0,
481
+ -0.0,
482
+ -0.0,
483
+ -0.0
484
+ ]
485
+ },
486
+ "waist": {
487
+ "min": [
488
+ -0.03817012533545494,
489
+ -0.14767035841941833,
490
+ -0.09924878180027008
491
+ ],
492
+ "max": [
493
+ 0.05044477432966232,
494
+ 0.13773855566978455,
495
+ 0.10575182735919952
496
+ ],
497
+ "mean": [
498
+ 0.0021713885944336653,
499
+ -0.006043997593224049,
500
+ -0.0009960572933778167
501
+ ],
502
+ "std": [
503
+ 0.01315564289689064,
504
+ 0.04625461995601654,
505
+ 0.0275924950838089
506
+ ],
507
+ "q01": [
508
+ -0.02857382604852319,
509
+ -0.1123543307185173,
510
+ -0.09090777784585953
511
+ ],
512
+ "q99": [
513
+ 0.04313158672302961,
514
+ 0.1042894288897514,
515
+ 0.06339201703667638
516
+ ]
517
+ },
518
+ "base_height_command": {
519
+ "min": [
520
+ 0.6000000238418579
521
+ ],
522
+ "max": [
523
+ 0.75
524
+ ],
525
+ "mean": [
526
+ 0.7374278903007507
527
+ ],
528
+ "std": [
529
+ 0.039233911782502955
530
+ ],
531
+ "q01": [
532
+ 0.6000000238418579
533
+ ],
534
+ "q99": [
535
+ 0.75
536
+ ]
537
+ },
538
+ "navigate_command": {
539
+ "min": [
540
+ 0.0,
541
+ -0.12772086262702942,
542
+ -0.4000000059604645
543
+ ],
544
+ "max": [
545
+ 0.4000000059604645,
546
+ 0.15753206610679626,
547
+ 0.10000000149011612
548
+ ],
549
+ "mean": [
550
+ 0.10862857103347778,
551
+ 0.006709238979965448,
552
+ -0.08270397037267685
553
+ ],
554
+ "std": [
555
+ 0.17079046368598938,
556
+ 0.035745956003665924,
557
+ 0.1377689093351364
558
+ ],
559
+ "q01": [
560
+ 0.0,
561
+ -0.06209215875715017,
562
+ -0.4000000059604645
563
+ ],
564
+ "q99": [
565
+ 0.4000000059604645,
566
+ 0.10000000149011612,
567
+ 0.004937881324440136
568
+ ]
569
+ }
570
+ },
571
+ "relative_action": {}
572
+ }
573
+ }
experiment_cfg/final_model_config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "Gr00tN1d6",
3
+ "model_dtype": "bfloat16",
4
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
5
+ "backbone_model_type": "eagle",
6
+ "model_revision": null,
7
+ "tune_top_llm_layers": 4,
8
+ "backbone_embedding_dim": 2048,
9
+ "tune_llm": false,
10
+ "tune_visual": true,
11
+ "select_layer": 16,
12
+ "reproject_vision": false,
13
+ "use_flash_attention": true,
14
+ "load_bf16": true,
15
+ "collator_overwrite_image_inputs": false,
16
+ "eagle_collator": true,
17
+ "backbone_trainable_params_fp32": true,
18
+ "apply_sincos_state_encoding": true,
19
+ "use_relative_action": true,
20
+ "max_state_dim": 128,
21
+ "max_action_dim": 128,
22
+ "action_horizon": 50,
23
+ "hidden_size": 1024,
24
+ "input_embedding_dim": 1536,
25
+ "add_pos_embed": true,
26
+ "attn_dropout": 0.2,
27
+ "use_vlln": true,
28
+ "max_seq_len": 1024,
29
+ "use_alternate_vl_dit": true,
30
+ "attend_text_every_n_blocks": 2,
31
+ "diffusion_model_cfg": {
32
+ "attention_head_dim": 48,
33
+ "dropout": 0.2,
34
+ "final_dropout": true,
35
+ "interleave_self_attention": true,
36
+ "norm_type": "ada_norm",
37
+ "num_attention_heads": 32,
38
+ "num_layers": 32,
39
+ "output_dim": 1024,
40
+ "positional_embeddings": null
41
+ },
42
+ "num_inference_timesteps": 4,
43
+ "noise_beta_alpha": 1.5,
44
+ "noise_beta_beta": 1.0,
45
+ "noise_s": 0.999,
46
+ "num_timestep_buckets": 1000,
47
+ "tune_projector": true,
48
+ "tune_diffusion_model": true,
49
+ "tune_vlln": true,
50
+ "state_dropout_prob": 0.0,
51
+ "state_additive_noise_scale": 0.0,
52
+ "max_num_embodiments": 32
53
+ }
experiment_cfg/final_processor_config.json ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
processor/embodiment_id.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "robocasa_panda_omron": 13,
3
+ "gr1": 20,
4
+ "behavior_r1_pro": 24,
5
+ "unitree_g1": 8,
6
+ "oxe_google": 0,
7
+ "oxe_widowx": 1,
8
+ "libero_panda": 2,
9
+ "oxe_droid": 16,
10
+ "new_embodiment": 10
11
+ }
processor/processor_config.json ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "Gr00tN1d6Processor",
3
+ "processor_kwargs": {
4
+ "modality_configs": {
5
+ "behavior_r1_pro": {
6
+ "video": {
7
+ "delta_indices": [
8
+ 0
9
+ ],
10
+ "modality_keys": [
11
+ "observation.images.rgb.head_256_256",
12
+ "observation.images.rgb.left_wrist_256_256",
13
+ "observation.images.rgb.right_wrist_256_256"
14
+ ],
15
+ "sin_cos_embedding_keys": null,
16
+ "mean_std_embedding_keys": null,
17
+ "action_configs": null
18
+ },
19
+ "state": {
20
+ "delta_indices": [
21
+ 0
22
+ ],
23
+ "modality_keys": [
24
+ "robot_pos",
25
+ "robot_ori_cos",
26
+ "robot_ori_sin",
27
+ "robot_2d_ori",
28
+ "robot_2d_ori_cos",
29
+ "robot_2d_ori_sin",
30
+ "robot_lin_vel",
31
+ "robot_ang_vel",
32
+ "arm_left_qpos",
33
+ "arm_left_qpos_sin",
34
+ "arm_left_qpos_cos",
35
+ "eef_left_pos",
36
+ "eef_left_quat",
37
+ "gripper_left_qpos",
38
+ "arm_right_qpos",
39
+ "arm_right_qpos_sin",
40
+ "arm_right_qpos_cos",
41
+ "eef_right_pos",
42
+ "eef_right_quat",
43
+ "gripper_right_qpos",
44
+ "trunk_qpos"
45
+ ],
46
+ "sin_cos_embedding_keys": null,
47
+ "mean_std_embedding_keys": null,
48
+ "action_configs": null
49
+ },
50
+ "action": {
51
+ "delta_indices": [
52
+ 0,
53
+ 1,
54
+ 2,
55
+ 3,
56
+ 4,
57
+ 5,
58
+ 6,
59
+ 7,
60
+ 8,
61
+ 9,
62
+ 10,
63
+ 11,
64
+ 12,
65
+ 13,
66
+ 14,
67
+ 15,
68
+ 16,
69
+ 17,
70
+ 18,
71
+ 19,
72
+ 20,
73
+ 21,
74
+ 22,
75
+ 23,
76
+ 24,
77
+ 25,
78
+ 26,
79
+ 27,
80
+ 28,
81
+ 29,
82
+ 30,
83
+ 31
84
+ ],
85
+ "modality_keys": [
86
+ "base",
87
+ "torso",
88
+ "left_arm",
89
+ "left_gripper",
90
+ "right_arm",
91
+ "right_gripper"
92
+ ],
93
+ "sin_cos_embedding_keys": null,
94
+ "mean_std_embedding_keys": null,
95
+ "action_configs": [
96
+ {
97
+ "rep": "ABSOLUTE",
98
+ "type": "NON_EEF",
99
+ "format": "DEFAULT",
100
+ "state_key": null
101
+ },
102
+ {
103
+ "rep": "RELATIVE",
104
+ "type": "NON_EEF",
105
+ "format": "DEFAULT",
106
+ "state_key": "trunk_qpos"
107
+ },
108
+ {
109
+ "rep": "RELATIVE",
110
+ "type": "NON_EEF",
111
+ "format": "DEFAULT",
112
+ "state_key": "arm_left_qpos"
113
+ },
114
+ {
115
+ "rep": "ABSOLUTE",
116
+ "type": "NON_EEF",
117
+ "format": "DEFAULT",
118
+ "state_key": null
119
+ },
120
+ {
121
+ "rep": "RELATIVE",
122
+ "type": "NON_EEF",
123
+ "format": "DEFAULT",
124
+ "state_key": "arm_right_qpos"
125
+ },
126
+ {
127
+ "rep": "ABSOLUTE",
128
+ "type": "NON_EEF",
129
+ "format": "DEFAULT",
130
+ "state_key": null
131
+ }
132
+ ]
133
+ },
134
+ "language": {
135
+ "delta_indices": [
136
+ 0
137
+ ],
138
+ "modality_keys": [
139
+ "annotation.human.coarse_action"
140
+ ],
141
+ "sin_cos_embedding_keys": null,
142
+ "mean_std_embedding_keys": null,
143
+ "action_configs": null
144
+ }
145
+ },
146
+ "gr1": {
147
+ "video": {
148
+ "delta_indices": [
149
+ 0
150
+ ],
151
+ "modality_keys": [
152
+ "ego_view_bg_crop_pad_res256_freq20"
153
+ ],
154
+ "sin_cos_embedding_keys": null,
155
+ "mean_std_embedding_keys": null,
156
+ "action_configs": null
157
+ },
158
+ "state": {
159
+ "delta_indices": [
160
+ 0
161
+ ],
162
+ "modality_keys": [
163
+ "left_arm",
164
+ "right_arm",
165
+ "left_hand",
166
+ "right_hand",
167
+ "waist"
168
+ ],
169
+ "sin_cos_embedding_keys": [
170
+ "left_arm",
171
+ "right_arm",
172
+ "left_hand",
173
+ "right_hand",
174
+ "waist"
175
+ ],
176
+ "mean_std_embedding_keys": null,
177
+ "action_configs": null
178
+ },
179
+ "action": {
180
+ "delta_indices": [
181
+ 0,
182
+ 1,
183
+ 2,
184
+ 3,
185
+ 4,
186
+ 5,
187
+ 6,
188
+ 7,
189
+ 8,
190
+ 9,
191
+ 10,
192
+ 11,
193
+ 12,
194
+ 13,
195
+ 14,
196
+ 15
197
+ ],
198
+ "modality_keys": [
199
+ "left_arm",
200
+ "right_arm",
201
+ "left_hand",
202
+ "right_hand",
203
+ "waist"
204
+ ],
205
+ "sin_cos_embedding_keys": null,
206
+ "mean_std_embedding_keys": null,
207
+ "action_configs": [
208
+ {
209
+ "rep": "RELATIVE",
210
+ "type": "NON_EEF",
211
+ "format": "DEFAULT",
212
+ "state_key": null
213
+ },
214
+ {
215
+ "rep": "RELATIVE",
216
+ "type": "NON_EEF",
217
+ "format": "DEFAULT",
218
+ "state_key": null
219
+ },
220
+ {
221
+ "rep": "RELATIVE",
222
+ "type": "NON_EEF",
223
+ "format": "DEFAULT",
224
+ "state_key": null
225
+ },
226
+ {
227
+ "rep": "RELATIVE",
228
+ "type": "NON_EEF",
229
+ "format": "DEFAULT",
230
+ "state_key": null
231
+ },
232
+ {
233
+ "rep": "ABSOLUTE",
234
+ "type": "NON_EEF",
235
+ "format": "DEFAULT",
236
+ "state_key": null
237
+ }
238
+ ]
239
+ },
240
+ "language": {
241
+ "delta_indices": [
242
+ 0
243
+ ],
244
+ "modality_keys": [
245
+ "task"
246
+ ],
247
+ "sin_cos_embedding_keys": null,
248
+ "mean_std_embedding_keys": null,
249
+ "action_configs": null
250
+ }
251
+ },
252
+ "robocasa_panda_omron": {
253
+ "video": {
254
+ "delta_indices": [
255
+ 0
256
+ ],
257
+ "modality_keys": [
258
+ "res256_image_side_0",
259
+ "res256_image_side_1",
260
+ "res256_image_wrist_0"
261
+ ],
262
+ "sin_cos_embedding_keys": null,
263
+ "mean_std_embedding_keys": null,
264
+ "action_configs": null
265
+ },
266
+ "state": {
267
+ "delta_indices": [
268
+ 0
269
+ ],
270
+ "modality_keys": [
271
+ "end_effector_position_relative",
272
+ "end_effector_rotation_relative",
273
+ "gripper_qpos",
274
+ "base_position",
275
+ "base_rotation"
276
+ ],
277
+ "sin_cos_embedding_keys": null,
278
+ "mean_std_embedding_keys": null,
279
+ "action_configs": null
280
+ },
281
+ "action": {
282
+ "delta_indices": [
283
+ 0,
284
+ 1,
285
+ 2,
286
+ 3,
287
+ 4,
288
+ 5,
289
+ 6,
290
+ 7,
291
+ 8,
292
+ 9,
293
+ 10,
294
+ 11,
295
+ 12,
296
+ 13,
297
+ 14,
298
+ 15
299
+ ],
300
+ "modality_keys": [
301
+ "end_effector_position",
302
+ "end_effector_rotation",
303
+ "gripper_close",
304
+ "base_motion",
305
+ "control_mode"
306
+ ],
307
+ "sin_cos_embedding_keys": null,
308
+ "mean_std_embedding_keys": null,
309
+ "action_configs": [
310
+ {
311
+ "rep": "ABSOLUTE",
312
+ "type": "NON_EEF",
313
+ "format": "DEFAULT",
314
+ "state_key": null
315
+ },
316
+ {
317
+ "rep": "ABSOLUTE",
318
+ "type": "NON_EEF",
319
+ "format": "DEFAULT",
320
+ "state_key": null
321
+ },
322
+ {
323
+ "rep": "ABSOLUTE",
324
+ "type": "NON_EEF",
325
+ "format": "DEFAULT",
326
+ "state_key": null
327
+ },
328
+ {
329
+ "rep": "ABSOLUTE",
330
+ "type": "NON_EEF",
331
+ "format": "DEFAULT",
332
+ "state_key": null
333
+ },
334
+ {
335
+ "rep": "ABSOLUTE",
336
+ "type": "NON_EEF",
337
+ "format": "DEFAULT",
338
+ "state_key": null
339
+ }
340
+ ]
341
+ },
342
+ "language": {
343
+ "delta_indices": [
344
+ 0
345
+ ],
346
+ "modality_keys": [
347
+ "annotation.human.action.task_description"
348
+ ],
349
+ "sin_cos_embedding_keys": null,
350
+ "mean_std_embedding_keys": null,
351
+ "action_configs": null
352
+ }
353
+ },
354
+ "new_embodiment": {
355
+ "video": {
356
+ "delta_indices": [
357
+ 0
358
+ ],
359
+ "modality_keys": [
360
+ "ego_view"
361
+ ],
362
+ "sin_cos_embedding_keys": null,
363
+ "mean_std_embedding_keys": null,
364
+ "action_configs": null
365
+ },
366
+ "state": {
367
+ "delta_indices": [
368
+ 0
369
+ ],
370
+ "modality_keys": [
371
+ "left_arm",
372
+ "right_arm",
373
+ "left_hand",
374
+ "right_hand",
375
+ "waist"
376
+ ],
377
+ "sin_cos_embedding_keys": null,
378
+ "mean_std_embedding_keys": null,
379
+ "action_configs": null
380
+ },
381
+ "action": {
382
+ "delta_indices": [
383
+ 0,
384
+ 1,
385
+ 2,
386
+ 3,
387
+ 4,
388
+ 5,
389
+ 6,
390
+ 7,
391
+ 8,
392
+ 9,
393
+ 10,
394
+ 11,
395
+ 12,
396
+ 13,
397
+ 14,
398
+ 15,
399
+ 16,
400
+ 17,
401
+ 18,
402
+ 19,
403
+ 20,
404
+ 21,
405
+ 22,
406
+ 23,
407
+ 24,
408
+ 25,
409
+ 26,
410
+ 27,
411
+ 28,
412
+ 29,
413
+ 30,
414
+ 31,
415
+ 32,
416
+ 33,
417
+ 34,
418
+ 35,
419
+ 36,
420
+ 37,
421
+ 38,
422
+ 39,
423
+ 40,
424
+ 41,
425
+ 42,
426
+ 43,
427
+ 44,
428
+ 45,
429
+ 46,
430
+ 47,
431
+ 48,
432
+ 49
433
+ ],
434
+ "modality_keys": [
435
+ "left_arm",
436
+ "right_arm",
437
+ "left_hand",
438
+ "right_hand",
439
+ "waist",
440
+ "base_height_command",
441
+ "navigate_command"
442
+ ],
443
+ "sin_cos_embedding_keys": null,
444
+ "mean_std_embedding_keys": null,
445
+ "action_configs": [
446
+ {
447
+ "rep": "ABSOLUTE",
448
+ "type": "NON_EEF",
449
+ "format": "DEFAULT",
450
+ "state_key": null
451
+ },
452
+ {
453
+ "rep": "ABSOLUTE",
454
+ "type": "NON_EEF",
455
+ "format": "DEFAULT",
456
+ "state_key": null
457
+ },
458
+ {
459
+ "rep": "ABSOLUTE",
460
+ "type": "NON_EEF",
461
+ "format": "DEFAULT",
462
+ "state_key": null
463
+ },
464
+ {
465
+ "rep": "ABSOLUTE",
466
+ "type": "NON_EEF",
467
+ "format": "DEFAULT",
468
+ "state_key": null
469
+ },
470
+ {
471
+ "rep": "ABSOLUTE",
472
+ "type": "NON_EEF",
473
+ "format": "DEFAULT",
474
+ "state_key": null
475
+ },
476
+ {
477
+ "rep": "ABSOLUTE",
478
+ "type": "NON_EEF",
479
+ "format": "DEFAULT",
480
+ "state_key": null
481
+ },
482
+ {
483
+ "rep": "ABSOLUTE",
484
+ "type": "NON_EEF",
485
+ "format": "DEFAULT",
486
+ "state_key": null
487
+ }
488
+ ]
489
+ },
490
+ "language": {
491
+ "delta_indices": [
492
+ 0
493
+ ],
494
+ "modality_keys": [
495
+ "annotation.human.task_description"
496
+ ],
497
+ "sin_cos_embedding_keys": null,
498
+ "mean_std_embedding_keys": null,
499
+ "action_configs": null
500
+ }
501
+ }
502
+ },
503
+ "image_crop_size": null,
504
+ "image_target_size": null,
505
+ "use_albumentations": true,
506
+ "random_rotation_angle": null,
507
+ "color_jitter_params": {
508
+ "brightness": 0.3,
509
+ "contrast": 0.4,
510
+ "saturation": 0.5,
511
+ "hue": 0.08
512
+ },
513
+ "shortest_image_edge": 256,
514
+ "crop_fraction": 0.95,
515
+ "model_name": "nvidia/Eagle-Block2A-2B-v2",
516
+ "model_type": "eagle",
517
+ "formalize_language": true,
518
+ "max_state_dim": 128,
519
+ "max_action_dim": 128,
520
+ "max_action_horizon": 50,
521
+ "use_percentiles": false,
522
+ "clip_outliers": true,
523
+ "apply_sincos_state_encoding": true,
524
+ "use_relative_action": true
525
+ }
526
+ }
processor/statistics.json ADDED
The diff for this file is too large to render. See raw diff
 
wandb_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"project": "finetune-gr00t-n1d6", "run_id": "locomanipulation_tutorial"}