ginwind committed on
Commit
9cea7bd
·
verified ·
1 Parent(s): b712cc3
.DS_Store ADDED
Binary file (6.15 kB). View file
 
Real-world/checkpoints/VLA-JEPA-Real-World.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1ab32b71d2c5f2f755bdb42d9815dece05446d288cbf38bd946c01981949779
3
+ size 6163571823
Real-world/config.json ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "run_id": "fr3_realworld",
3
+ "run_root_dir": "checkpoints",
4
+ "seed": 42,
5
+ "trackers": [
6
+ "json"
7
+ ],
8
+ "is_debug": false,
9
+ "framework": {
10
+ "name": "VLA_JEPA",
11
+ "qwenvl": {
12
+ "base_vlm": "/home/dataset-local/models/Qwen3-VL-2B-Instruct",
13
+ "attn_implementation": "flash_attention_2",
14
+ "vl_hidden_dim": 2048
15
+ },
16
+ "action_model": {
17
+ "action_model_type": "DiT-B",
18
+ "action_hidden_dim": 1024,
19
+ "hidden_size": 1024,
20
+ "add_pos_embed": true,
21
+ "max_seq_len": 1024,
22
+ "action_dim": 7,
23
+ "state_dim": 8,
24
+ "future_action_window_size": 6,
25
+ "action_horizon": 7,
26
+ "past_action_window_size": 0,
27
+ "repeated_diffusion_steps": 8,
28
+ "noise_beta_alpha": 1.5,
29
+ "noise_beta_beta": 1.0,
30
+ "noise_s": 0.999,
31
+ "num_timestep_buckets": 1000,
32
+ "num_inference_timesteps": 4,
33
+ "num_target_vision_tokens": 32,
34
+ "diffusion_model_cfg": {
35
+ "cross_attention_dim": 2048,
36
+ "dropout": 0.2,
37
+ "final_dropout": true,
38
+ "interleave_self_attention": true,
39
+ "norm_type": "ada_norm",
40
+ "num_layers": 16,
41
+ "output_dim": 1024,
42
+ "positional_embeddings": null
43
+ }
44
+ },
45
+ "vj2_model": {
46
+ "base_encoder": "/home/dataset-local/models/vjepa2-vitl-fpc64-256",
47
+ "depth": 12,
48
+ "num_heads": 8,
49
+ "special_action_token": "<|action_{}|>",
50
+ "num_action_tokens_per_timestep": 8,
51
+ "embodied_action_token": "<|embodied_action|>",
52
+ "num_embodied_action_tokens_per_instruction": 32,
53
+ "num_frames": 8
54
+ },
55
+ "reduce_in_full_precision": true
56
+ },
57
+ "datasets": {
58
+ "vla_data": {
59
+ "dataset_py": "lerobot_datasets",
60
+ "data_root_dir": "/home/dataset-local/datasets/LeRobot/lerobot_simple_pp_starvla",
61
+ "data_mix": "fr3_realworld",
62
+ "action_type": "delta_qpos",
63
+ "CoT_prompt": "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}.",
64
+ "resolution_size": 224,
65
+ "per_device_batch_size": 8,
66
+ "video_resolution_size": 256,
67
+ "load_all_data_for_training": true,
68
+ "with_state": true
69
+ }
70
+ },
71
+ "trainer": {
72
+ "epochs": 100,
73
+ "max_train_steps": 20000,
74
+ "num_warmup_steps": 5000,
75
+ "save_interval": 5000,
76
+ "eval_interval": 100,
77
+ "learning_rate": {
78
+ "base": 3e-05,
79
+ "qwen_vl_interface": 1e-05,
80
+ "action_model": 0.0001
81
+ },
82
+ "lr_scheduler_type": "cosine_with_min_lr",
83
+ "scheduler_specific_kwargs": {
84
+ "min_lr": 1e-06
85
+ },
86
+ "freeze_modules": "",
87
+ "loss_scale": {
88
+ "vla": 1.0,
89
+ "vlm": 0.1
90
+ },
91
+ "max_grad_norm": 1.0,
92
+ "warmup_ratio": 0.1,
93
+ "weight_decay": 0.0,
94
+ "logging_frequency": 10,
95
+ "gradient_clipping": 1.0,
96
+ "gradient_accumulation_steps": 1,
97
+ "pretrained_checkpoint": "/home/dataset-local/VLA_JEPA/checkpoints/pretrain/VLA-JEPA-pretrain.pt",
98
+ "optimizer": {
99
+ "name": "AdamW",
100
+ "betas": [
101
+ 0.9,
102
+ 0.95
103
+ ],
104
+ "eps": 1e-08,
105
+ "weight_decay": 1e-08
106
+ },
107
+ "is_resume": false,
108
+ "resume_epoch": null,
109
+ "resume_step": null,
110
+ "enable_gradient_checkpointing": true,
111
+ "enable_mixed_precision_training": true
112
+ },
113
+ "output_dir": "checkpoints/fr3_realworld"
114
+ }
Real-world/config.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run_id: fr3_realworld
2
+ run_root_dir: checkpoints
3
+ seed: 42
4
+ trackers:
5
+ - json
6
+ is_debug: false
7
+ framework:
8
+ name: VLA_JEPA
9
+ qwenvl:
10
+ base_vlm: /home/dataset-local/models/Qwen3-VL-2B-Instruct
11
+ attn_implementation: flash_attention_2
12
+ vl_hidden_dim: 2048
13
+ action_model:
14
+ action_model_type: DiT-B
15
+ action_hidden_dim: 1024
16
+ hidden_size: 1024
17
+ add_pos_embed: true
18
+ max_seq_len: 1024
19
+ action_dim: 7
20
+ state_dim: 8
21
+ future_action_window_size: 6
22
+ action_horizon: 7
23
+ past_action_window_size: 0
24
+ repeated_diffusion_steps: 8
25
+ noise_beta_alpha: 1.5
26
+ noise_beta_beta: 1.0
27
+ noise_s: 0.999
28
+ num_timestep_buckets: 1000
29
+ num_inference_timesteps: 4
30
+ num_target_vision_tokens: 32
31
+ diffusion_model_cfg:
32
+ cross_attention_dim: 2048
33
+ dropout: 0.2
34
+ final_dropout: true
35
+ interleave_self_attention: true
36
+ norm_type: ada_norm
37
+ num_layers: 16
38
+ output_dim: 1024
39
+ positional_embeddings: null
40
+ vj2_model:
41
+ base_encoder: /home/dataset-local/models/vjepa2-vitl-fpc64-256
42
+ depth: 12
43
+ num_heads: 8
44
+ special_action_token: <|action_{}|>
45
+ num_action_tokens_per_timestep: 8
46
+ embodied_action_token: <|embodied_action|>
47
+ num_embodied_action_tokens_per_instruction: 32
48
+ num_frames: 8
49
+ reduce_in_full_precision: true
50
+ datasets:
51
+ vla_data:
52
+ dataset_py: lerobot_datasets
53
+ data_root_dir: /home/dataset-local/datasets/LeRobot/lerobot_simple_pp_starvla
54
+ data_mix: fr3_realworld
55
+ action_type: delta_qpos
56
+ CoT_prompt: Your task is {instruction}. Infer the temporal dynamics from frames
57
+ {actions} and produce the corresponding policy actions {e_actions}.
58
+ resolution_size: 224
59
+ per_device_batch_size: 8
60
+ video_resolution_size: 256
61
+ load_all_data_for_training: true
62
+ with_state: true
63
+ trainer:
64
+ epochs: 100
65
+ max_train_steps: 20000
66
+ num_warmup_steps: 5000
67
+ save_interval: 5000
68
+ eval_interval: 100
69
+ learning_rate:
70
+ base: 3.0e-05
71
+ qwen_vl_interface: 1.0e-05
72
+ action_model: 0.0001
73
+ lr_scheduler_type: cosine_with_min_lr
74
+ scheduler_specific_kwargs:
75
+ min_lr: 1.0e-06
76
+ freeze_modules: ''
77
+ loss_scale:
78
+ vla: 1.0
79
+ vlm: 0.1
80
+ max_grad_norm: 1.0
81
+ warmup_ratio: 0.1
82
+ weight_decay: 0.0
83
+ logging_frequency: 10
84
+ gradient_clipping: 1.0
85
+ gradient_accumulation_steps: 1
86
+ pretrained_checkpoint: /home/dataset-local/VLA_JEPA/checkpoints/pretrain/VLA-JEPA-pretrain.pt
87
+ optimizer:
88
+ name: AdamW
89
+ betas:
90
+ - 0.9
91
+ - 0.95
92
+ eps: 1.0e-08
93
+ weight_decay: 1.0e-08
94
+ is_resume: false
95
+ resume_epoch: null
96
+ resume_step: null
97
+ enable_gradient_checkpointing: true
98
+ enable_mixed_precision_training: true
99
+ output_dir: checkpoints/fr3_realworld
Real-world/dataset_statistics.json ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "new_embodiment": {
3
+ "action": {
4
+ "mean": [
5
+ -1.6495379895786755e-05,
6
+ -1.1559173799469136e-05,
7
+ 6.800065875722794e-06,
8
+ -2.9312453989405185e-05,
9
+ 2.715121809160337e-05,
10
+ -4.9356358431396075e-06,
11
+ 0.8368043303489685
12
+ ],
13
+ "std": [
14
+ 0.0033433015923947096,
15
+ 0.0033241035416722298,
16
+ 0.006203544791787863,
17
+ 0.0064756181091070175,
18
+ 0.006977501790970564,
19
+ 0.008858172222971916,
20
+ 0.308319091796875
21
+ ],
22
+ "max": [
23
+ 0.029447495937347412,
24
+ 0.04054729640483856,
25
+ 0.05029946565628052,
26
+ 0.04862421378493309,
27
+ 0.08689296990633011,
28
+ 0.06699639558792114,
29
+ 1.0
30
+ ],
31
+ "min": [
32
+ -0.022726356983184814,
33
+ -0.027304204180836678,
34
+ -0.04183477163314819,
35
+ -0.06220978870987892,
36
+ -0.07182798534631729,
37
+ -0.09515094757080078,
38
+ 0.0
39
+ ],
40
+ "q01": [
41
+ -0.009029481410980224,
42
+ -0.010530177503824234,
43
+ -0.014956550300121307,
44
+ -0.02058939129114151,
45
+ -0.020688764695078136,
46
+ -0.03448918495327234,
47
+ 0.0
48
+ ],
49
+ "q99": [
50
+ 0.010799104869365696,
51
+ 0.01133852298371494,
52
+ 0.018483443856239404,
53
+ 0.01734598506242038,
54
+ 0.019764822572469743,
55
+ 0.022766803707927472,
56
+ 1.0
57
+ ],
58
+ "mask": [
59
+ true,
60
+ true,
61
+ true,
62
+ true,
63
+ true,
64
+ true,
65
+ false
66
+ ]
67
+ },
68
+ "state": {
69
+ "mean": [
70
+ 0.564656674861908,
71
+ -0.002990487962961197,
72
+ 0.3436369001865387,
73
+ 0.8995662331581116,
74
+ -0.32835933566093445,
75
+ 0.03922935202717781,
76
+ -0.01520092599093914,
77
+ 0.8368043303489685
78
+ ],
79
+ "std": [
80
+ 0.087661974132061,
81
+ 0.09329120814800262,
82
+ 0.16795401275157928,
83
+ 0.18832752108573914,
84
+ 0.1940889209508896,
85
+ 0.07308010756969452,
86
+ 0.051968540996313095,
87
+ 0.308319091796875
88
+ ],
89
+ "max": [
90
+ 0.8323888182640076,
91
+ 0.2805258333683014,
92
+ 0.7719749808311462,
93
+ 0.9999507069587708,
94
+ 0.8582825660705566,
95
+ 0.3295547068119049,
96
+ 0.17514149844646454,
97
+ 1.0
98
+ ],
99
+ "min": [
100
+ 0.2107870727777481,
101
+ -0.3139313757419586,
102
+ 0.06993856281042099,
103
+ -0.7052522301673889,
104
+ -0.7020109295845032,
105
+ -0.24225156009197235,
106
+ -0.22992828488349915,
107
+ 0.0
108
+ ],
109
+ "q01": [
110
+ 0.29528580099344254,
111
+ -0.27157875895500183,
112
+ 0.08648627504706383,
113
+ -0.5524575877189636,
114
+ -0.6196129459142685,
115
+ -0.12603880420327188,
116
+ -0.1619066223502159,
117
+ 0.0
118
+ ],
119
+ "q99": [
120
+ 0.7763650172948837,
121
+ 0.22109754353761682,
122
+ 0.5853771680593494,
123
+ 0.9986419814825058,
124
+ 0.7496437698602678,
125
+ 0.24052399903535857,
126
+ 0.11271571815013894,
127
+ 1.0
128
+ ]
129
+ },
130
+ "num_transitions": 17919,
131
+ "num_trajectories": 100
132
+ }
133
+ }
Real-world/summary.jsonl ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {"steps": 5000}
2
+ {"steps": 10000}
3
+ {"steps": 15000}
4
+ {"steps": 20000}
5
+ {"steps": 25000}
6
+ {"steps": 30000}
SimplerEnv/checkpoints/VLA-JEPA-SimplerEnv.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f797a93590847960cbd14feb56c29e5fbc39220c162ed29451c385caa9dab6e
3
+ size 6163573444
SimplerEnv/config.json ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "run_id": "SimplerEnv",
3
+ "run_root_dir": "checkpoints",
4
+ "seed": 42,
5
+ "trackers": [
6
+ "json"
7
+ ],
8
+ "is_debug": false,
9
+ "framework": {
10
+ "name": "VLA_JEPA",
11
+ "qwenvl": {
12
+ "base_vlm": "/home/dataset-local/models/Qwen3-VL-2B-Instruct",
13
+ "attn_implementation": "flash_attention_2",
14
+ "vl_hidden_dim": 2048
15
+ },
16
+ "action_model": {
17
+ "action_model_type": "DiT-B",
18
+ "action_hidden_dim": 1024,
19
+ "hidden_size": 1024,
20
+ "add_pos_embed": true,
21
+ "max_seq_len": 1024,
22
+ "action_dim": 7,
23
+ "state_dim": 8,
24
+ "future_action_window_size": 6,
25
+ "action_horizon": 7,
26
+ "past_action_window_size": 0,
27
+ "repeated_diffusion_steps": 8,
28
+ "noise_beta_alpha": 1.5,
29
+ "noise_beta_beta": 1.0,
30
+ "noise_s": 0.999,
31
+ "num_timestep_buckets": 1000,
32
+ "num_inference_timesteps": 4,
33
+ "num_target_vision_tokens": 32,
34
+ "diffusion_model_cfg": {
35
+ "cross_attention_dim": 2048,
36
+ "dropout": 0.2,
37
+ "final_dropout": true,
38
+ "interleave_self_attention": true,
39
+ "norm_type": "ada_norm",
40
+ "num_layers": 16,
41
+ "output_dim": 1024,
42
+ "positional_embeddings": null
43
+ }
44
+ },
45
+ "vj2_model": {
46
+ "base_encoder": "/home/dataset-local/models/vjepa2-vitl-fpc64-256",
47
+ "depth": 12,
48
+ "num_heads": 8,
49
+ "special_action_token": "<|action_{}|>",
50
+ "num_action_tokens_per_timestep": 8,
51
+ "embodied_action_token": "<|embodied_action|>",
52
+ "num_embodied_action_tokens_per_instruction": 32,
53
+ "num_frames": 8
54
+ },
55
+ "reduce_in_full_precision": true
56
+ },
57
+ "datasets": {
58
+ "vla_data": {
59
+ "dataset_py": "lerobot_datasets",
60
+ "data_root_dir": "/home/dataset-local/datasets/LeRobot/OXE_LEROBOT_DATASET",
61
+ "data_mix": "bridge_rt_1",
62
+ "action_type": "delta_ee",
63
+ "CoT_prompt": "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}.",
64
+ "resolution_size": 224,
65
+ "video_resolution_size": 256,
66
+ "per_device_batch_size": 32,
67
+ "load_all_data_for_training": true,
68
+ "action_horizon": 7,
69
+ "with_state": false
70
+ }
71
+ },
72
+ "trainer": {
73
+ "epochs": 100,
74
+ "max_train_steps": 30000,
75
+ "num_warmup_steps": 5000,
76
+ "save_interval": 10000,
77
+ "eval_interval": 100,
78
+ "learning_rate": {
79
+ "base": 3e-05,
80
+ "qwen_vl_interface": 1e-05,
81
+ "action_model": 0.0001,
82
+ "vj_predictor": 0.0005
83
+ },
84
+ "lr_scheduler_type": "cosine_with_min_lr",
85
+ "scheduler_specific_kwargs": {
86
+ "min_lr": 1e-05
87
+ },
88
+ "freeze_modules": "",
89
+ "loss_scale": {
90
+ "vla": 1.0,
91
+ "vlm": 0.1
92
+ },
93
+ "max_grad_norm": 1.0,
94
+ "warmup_ratio": 0.1,
95
+ "weight_decay": 0.0,
96
+ "logging_frequency": 10,
97
+ "gradient_clipping": 1.0,
98
+ "gradient_accumulation_steps": 1,
99
+ "optimizer": {
100
+ "name": "AdamW",
101
+ "betas": [
102
+ 0.9,
103
+ 0.95
104
+ ],
105
+ "eps": 1e-08,
106
+ "weight_decay": 1e-08
107
+ },
108
+ "is_resume": false,
109
+ "resume_epoch": null,
110
+ "resume_step": null,
111
+ "enable_gradient_checkpointing": true,
112
+ "enable_mixed_precision_training": true
113
+ },
114
+ "output_dir": "checkpoints/SimplerEnv"
115
+ }
SimplerEnv/config.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run_id: SimplerEnv
2
+ run_root_dir: checkpoints
3
+ seed: 42
4
+ trackers:
5
+ - json
6
+ is_debug: false
7
+ framework:
8
+ name: VLA_JEPA
9
+ qwenvl:
10
+ base_vlm: /home/dataset-local/models/Qwen3-VL-2B-Instruct
11
+ attn_implementation: flash_attention_2
12
+ vl_hidden_dim: 2048
13
+ action_model:
14
+ action_model_type: DiT-B
15
+ action_hidden_dim: 1024
16
+ hidden_size: 1024
17
+ add_pos_embed: true
18
+ max_seq_len: 1024
19
+ action_dim: 7
20
+ state_dim: 8
21
+ future_action_window_size: 6
22
+ action_horizon: 7
23
+ past_action_window_size: 0
24
+ repeated_diffusion_steps: 8
25
+ noise_beta_alpha: 1.5
26
+ noise_beta_beta: 1.0
27
+ noise_s: 0.999
28
+ num_timestep_buckets: 1000
29
+ num_inference_timesteps: 4
30
+ num_target_vision_tokens: 32
31
+ diffusion_model_cfg:
32
+ cross_attention_dim: 2048
33
+ dropout: 0.2
34
+ final_dropout: true
35
+ interleave_self_attention: true
36
+ norm_type: ada_norm
37
+ num_layers: 16
38
+ output_dim: 1024
39
+ positional_embeddings: null
40
+ vj2_model:
41
+ base_encoder: /home/dataset-local/models/vjepa2-vitl-fpc64-256
42
+ depth: 12
43
+ num_heads: 8
44
+ special_action_token: <|action_{}|>
45
+ num_action_tokens_per_timestep: 8
46
+ embodied_action_token: <|embodied_action|>
47
+ num_embodied_action_tokens_per_instruction: 32
48
+ num_frames: 8
49
+ reduce_in_full_precision: true
50
+ datasets:
51
+ vla_data:
52
+ dataset_py: lerobot_datasets
53
+ data_root_dir: /home/dataset-local/datasets/LeRobot/OXE_LEROBOT_DATASET
54
+ data_mix: bridge_rt_1
55
+ action_type: delta_ee
56
+ CoT_prompt: Your task is {instruction}. Infer the temporal dynamics from frames
57
+ {actions} and produce the corresponding policy actions {e_actions}.
58
+ resolution_size: 224
59
+ video_resolution_size: 256
60
+ per_device_batch_size: 32
61
+ load_all_data_for_training: true
62
+ action_horizon: 7
63
+ with_state: false
64
+ trainer:
65
+ epochs: 100
66
+ max_train_steps: 30000
67
+ num_warmup_steps: 5000
68
+ save_interval: 10000
69
+ eval_interval: 100
70
+ learning_rate:
71
+ base: 3.0e-05
72
+ qwen_vl_interface: 1.0e-05
73
+ action_model: 0.0001
74
+ vj_predictor: 0.0005
75
+ lr_scheduler_type: cosine_with_min_lr
76
+ scheduler_specific_kwargs:
77
+ min_lr: 1.0e-05
78
+ freeze_modules: ''
79
+ loss_scale:
80
+ vla: 1.0
81
+ vlm: 0.1
82
+ max_grad_norm: 1.0
83
+ warmup_ratio: 0.1
84
+ weight_decay: 0.0
85
+ logging_frequency: 10
86
+ gradient_clipping: 1.0
87
+ gradient_accumulation_steps: 1
88
+ optimizer:
89
+ name: AdamW
90
+ betas:
91
+ - 0.9
92
+ - 0.95
93
+ eps: 1.0e-08
94
+ weight_decay: 1.0e-08
95
+ is_resume: false
96
+ resume_epoch: null
97
+ resume_step: null
98
+ enable_gradient_checkpointing: true
99
+ enable_mixed_precision_training: true
100
+ output_dir: checkpoints/SimplerEnv
SimplerEnv/dataset_statistics.json ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "oxe_bridge": {
3
+ "action": {
4
+ "mean": [
5
+ 0.00011365935642970726,
6
+ 6.556110747624189e-05,
7
+ -6.320965621853247e-05,
8
+ -7.205353176686913e-05,
9
+ -0.00019515302847139537,
10
+ 0.0001203166029881686,
11
+ 0.28829458355903625
12
+ ],
13
+ "std": [
14
+ 0.006909770731857718,
15
+ 0.009684093232852218,
16
+ 0.00896290498528129,
17
+ 0.020121052930683073,
18
+ 0.021582655517295487,
19
+ 0.054723342223346974,
20
+ 0.4543627821514982
21
+ ],
22
+ "max": [
23
+ 0.41691166162490845,
24
+ 0.25864794850349426,
25
+ 0.21218234300613403,
26
+ 3.122201919555664,
27
+ 1.8618112802505493,
28
+ 6.272472858428955,
29
+ 1.0
30
+ ],
31
+ "min": [
32
+ -0.4007510244846344,
33
+ -0.13874775171279907,
34
+ -0.22553899884223938,
35
+ -3.2010786533355713,
36
+ -1.8618112802505493,
37
+ -6.279075622558594,
38
+ 0.0
39
+ ],
40
+ "q01": [
41
+ -0.02875255048274994,
42
+ -0.041702136397361755,
43
+ -0.02609672024846077,
44
+ -0.08052875101566315,
45
+ -0.09249906986951828,
46
+ -0.20738555490970612,
47
+ 0.0
48
+ ],
49
+ "q99": [
50
+ 0.028306663036346436,
51
+ 0.040898531675338745,
52
+ 0.0401805154979229,
53
+ 0.08173403143882751,
54
+ 0.07760760188102722,
55
+ 0.2038465440273285,
56
+ 1.0
57
+ ],
58
+ "mask": [
59
+ true,
60
+ true,
61
+ true,
62
+ true,
63
+ true,
64
+ true,
65
+ false
66
+ ]
67
+ },
68
+ "state": {
69
+ "mean": [
70
+ 0.15471743047237396,
71
+ 0.015362550504505634,
72
+ 0.03222028166055679,
73
+ 0.0032453201711177826,
74
+ -0.038600146770477295,
75
+ 0.05382946878671646,
76
+ 0.0,
77
+ 0.35406652092933655
78
+ ],
79
+ "std": [
80
+ 0.1605359274864927,
81
+ 0.06677912092232112,
82
+ 0.048657150951527554,
83
+ 0.09275336958940797,
84
+ 0.12566984746914736,
85
+ 0.41222738578801793,
86
+ 0.0,
87
+ 0.4330223535305803
88
+ ],
89
+ "max": [
90
+ 0.5862360596656799,
91
+ 0.4034728705883026,
92
+ 0.3568263053894043,
93
+ 1.3517684936523438,
94
+ 1.570796251296997,
95
+ 3.141204357147217,
96
+ 0.0,
97
+ 1.1121242046356201
98
+ ],
99
+ "min": [
100
+ -0.04167502000927925,
101
+ -0.3563207685947418,
102
+ -0.15537554025650024,
103
+ -3.141592502593994,
104
+ -1.4992541074752808,
105
+ -3.14153790473938,
106
+ 0.0,
107
+ 0.04637829214334488
108
+ ],
109
+ "q01": [
110
+ 0.17102622985839844,
111
+ -0.1698118895292282,
112
+ -0.05563216283917427,
113
+ -0.36493173241615295,
114
+ -0.541871190071106,
115
+ -1.3542754650115967,
116
+ 0.0,
117
+ 0.052190229296684265
118
+ ],
119
+ "q99": [
120
+ 0.45322078466415405,
121
+ 0.2354845106601715,
122
+ 0.19489620625972748,
123
+ 0.3780156075954437,
124
+ 0.27568644285202026,
125
+ 1.8500566482543945,
126
+ 0.0,
127
+ 1.0105689764022827
128
+ ]
129
+ },
130
+ "num_transitions": 1863900,
131
+ "num_trajectories": 53192
132
+ },
133
+ "oxe_rt1": {
134
+ "action": {
135
+ "mean": [
136
+ 0.003493750700727105,
137
+ 0.003132961690425873,
138
+ -0.0063125672750175,
139
+ 0.02166595682501793,
140
+ -0.0028780836146324873,
141
+ 0.0004565489653032273,
142
+ 0.26771023869514465
143
+ ],
144
+ "std": [
145
+ 0.049065014465362655,
146
+ 0.04229853739828572,
147
+ 0.05237628880142378,
148
+ 0.1124860236500875,
149
+ 0.09312952783816872,
150
+ 0.10319098309601091,
151
+ 0.4418448662622395
152
+ ],
153
+ "max": [
154
+ 2.9984593391418457,
155
+ 22.09052848815918,
156
+ 2.7507524490356445,
157
+ 1.570636510848999,
158
+ 1.5321086645126343,
159
+ 1.5691522359848022,
160
+ 1.0
161
+ ],
162
+ "min": [
163
+ -2.0204520225524902,
164
+ -5.497899532318115,
165
+ -2.031663417816162,
166
+ -1.569917917251587,
167
+ -1.569892168045044,
168
+ -1.570419430732727,
169
+ 0.0
170
+ ],
171
+ "q01": [
172
+ -0.224535271525383,
173
+ -0.1482001394033432,
174
+ -0.23158970475196838,
175
+ -0.35179948806762695,
176
+ -0.4193011224269867,
177
+ -0.43643462657928467,
178
+ 0.0
179
+ ],
180
+ "q99": [
181
+ 0.17824687063694,
182
+ 0.1493837833404541,
183
+ 0.21842354536056519,
184
+ 0.5892665982246399,
185
+ 0.352726548910141,
186
+ 0.4479667842388153,
187
+ 1.0
188
+ ],
189
+ "mask": [
190
+ true,
191
+ true,
192
+ true,
193
+ true,
194
+ true,
195
+ true,
196
+ false
197
+ ]
198
+ },
199
+ "state": {
200
+ "mean": [
201
+ 0.2799473702907562,
202
+ -0.04167069122195244,
203
+ 0.38854750990867615,
204
+ -0.12402277439832687,
205
+ 0.24756911396980286,
206
+ 0.046330634504556656,
207
+ 0.10487449914216995,
208
+ 0.21306729316711426
209
+ ],
210
+ "std": [
211
+ 0.29342642876909925,
212
+ 0.09174024655686211,
213
+ 0.42569508885539115,
214
+ 0.38314586427420927,
215
+ 0.44433568806919804,
216
+ 0.1263927443679382,
217
+ 0.22122596673781084,
218
+ 0.38616252611341306
219
+ ],
220
+ "max": [
221
+ 1.0534898042678833,
222
+ 0.48018959164619446,
223
+ 1.6896663904190063,
224
+ 0.9999993443489075,
225
+ 0.9999874830245972,
226
+ 0.9554369449615479,
227
+ 0.9914546012878418,
228
+ 1.0
229
+ ],
230
+ "min": [
231
+ -0.4436439275741577,
232
+ -0.9970501065254211,
233
+ -0.006579156965017319,
234
+ -0.8643477559089661,
235
+ -0.7079970240592957,
236
+ -0.7688722014427185,
237
+ -0.4999994933605194,
238
+ 0.0
239
+ ],
240
+ "q01": [
241
+ 0.3248138129711151,
242
+ -0.2833428978919983,
243
+ 0.14107070863246918,
244
+ -0.6864742040634155,
245
+ -0.6808923482894897,
246
+ -0.3604559600353241,
247
+ -0.45438095927238464,
248
+ 0.0
249
+ ],
250
+ "q99": [
251
+ 0.8750156164169312,
252
+ 0.21247053146362305,
253
+ 1.0727112293243408,
254
+ 0.9377871155738831,
255
+ 0.9563050866127014,
256
+ 0.4599004089832306,
257
+ 0.7216041088104248,
258
+ 1.0
259
+ ]
260
+ },
261
+ "num_transitions": 3449894,
262
+ "num_trajectories": 87212
263
+ }
264
+ }
SimplerEnv/summary.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {"steps": 10000}
2
+ {"steps": 20000}
3
+ {"steps": 30000}