lsnu committed on
Commit
ccf25b1
·
verified ·
1 Parent(s): da5e1bd

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. README.md +43 -0
  2. artifacts/twin_split_expert_bringup_20260310/README.md +59 -0
  3. artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_communicating_packed_from_single/config.json +15 -0
  4. artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_communicating_packed_from_single/init_parallel_metadata.json +654 -0
  5. artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_communicating_packed_from_single/model.safetensors +3 -0
  6. artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_independent_packed_from_single/config.json +15 -0
  7. artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_independent_packed_from_single/init_parallel_metadata.json +633 -0
  8. artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_independent_packed_from_single/model.safetensors +3 -0
  9. artifacts/twin_split_expert_bringup_20260310/repro/commands_bringup.sh +82 -0
  10. artifacts/twin_split_expert_bringup_20260310/run_logs/split_communicating_real_smoke3.log +104 -0
  11. artifacts/twin_split_expert_bringup_20260310/run_logs/split_communicating_real_train20.log +173 -0
  12. artifacts/twin_split_expert_bringup_20260310/run_logs/split_independent_real_smoke3_r2.log +104 -0
  13. artifacts/twin_split_expert_bringup_20260310/run_logs/split_independent_real_train20.log +173 -0
  14. artifacts/twin_split_expert_bringup_20260310/sanity_checks/split_communicating_invariants.txt +6 -0
  15. artifacts/twin_split_expert_bringup_20260310/sanity_checks/split_independent_invariants.txt +8 -0
  16. openpi/assets/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json +152 -0
  17. openpi/assets/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json +152 -0
  18. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3/3/assets/lsnu/twin_dual_push_128_train/norm_stats.json +152 -0
  19. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3/3/metadata.pt +3 -0
  20. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3/3/model.safetensors +3 -0
  21. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3/3/optimizer.pt +3 -0
  22. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20/20/assets/lsnu/twin_dual_push_128_train/norm_stats.json +152 -0
  23. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20/20/metadata.pt +3 -0
  24. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20/20/model.safetensors +3 -0
  25. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20/20/optimizer.pt +3 -0
  26. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2/3/assets/lsnu/twin_dual_push_128_train/norm_stats.json +152 -0
  27. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2/3/metadata.pt +3 -0
  28. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2/3/model.safetensors +3 -0
  29. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2/3/optimizer.pt +3 -0
  30. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20/20/assets/lsnu/twin_dual_push_128_train/norm_stats.json +152 -0
  31. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20/20/metadata.pt +3 -0
  32. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20/20/model.safetensors +3 -0
  33. openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20/20/optimizer.pt +3 -0
  34. openpi/run_logs/split_communicating_real_smoke3.log +104 -0
  35. openpi/run_logs/split_communicating_real_train20.log +173 -0
  36. openpi/run_logs/split_independent_real_smoke3_r2.log +104 -0
  37. openpi/run_logs/split_independent_real_train20.log +173 -0
  38. openpi/scripts/check_parallel_warmstart_equivalence.py +7 -0
  39. openpi/scripts/check_split_expert_invariants.py +154 -0
  40. openpi/scripts/eval_twin_val_loss_pytorch.py +1 -1
  41. openpi/scripts/init_parallel_pi05_from_single_pytorch.py +151 -37
  42. openpi/scripts/run_twin_dual_push_128_packed_5k.sh +2 -2
  43. openpi/scripts/run_twin_handover_packed_10k.sh +2 -2
  44. openpi/scripts/train_pytorch.py +34 -2
  45. openpi/src/openpi/models/pi0_config.py +29 -2
  46. openpi/src/openpi/models/utils/fsq_tokenizer.py +17 -4
  47. openpi/src/openpi/models_pytorch/gemma_pytorch.py +236 -170
  48. openpi/src/openpi/models_pytorch/pi0_pytorch.py +238 -150
  49. openpi/src/openpi/training/config.py +233 -0
  50. openpi/src/openpi/training/data_loader.py +111 -1
README.md CHANGED
@@ -13,6 +13,13 @@ Three runs are included:
13
  2. a longer `10K` follow-up on the same packed setup
14
  3. a `5K` dual-push `128` screening study on the same packed path
15
 
 
 
 
 
 
 
 
16
  ## Experiment setup
17
 
18
  - Handover train/val: `lsnu/twin_handover_256_train`, `lsnu/twin_handover_256_val`
@@ -60,6 +67,34 @@ The packed parallel warm-start uses the slice/fuse mapping implemented in `openp
60
 
61
  So this repo should be read as a matched warm-start study, not as a bitwise-identical step-0 control.
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  ## Repo layout
64
 
65
  - `openpi/`
@@ -72,6 +107,8 @@ So this repo should be read as a matched warm-start study, not as a bitwise-iden
72
  - `10K` follow-up bundle with metrics, logs, repro manifests, and environment snapshot
73
  - `artifacts/twin_dual_push_128_packed_parallelization_5k_20260310/`
74
  - dual-push `128` screening bundle with metrics, logs, repro manifests, and environment snapshot
 
 
75
  - `artifacts/pi05_base_params/`
76
  - staged base parameter snapshot used during JAX-to-PyTorch conversion
77
 
@@ -85,6 +122,11 @@ So this repo should be read as a matched warm-start study, not as a bitwise-iden
85
  - dual-push `5K` teacher-forced table: `artifacts/twin_dual_push_128_packed_parallelization_5k_20260310/metrics/teacher_forced_eval_table.csv`
86
  - dual-push `5K` sample eval table: `artifacts/twin_dual_push_128_packed_parallelization_5k_20260310/metrics/sample_eval_table.csv`
87
  - dual-push `5K` environment snapshot: `artifacts/twin_dual_push_128_packed_parallelization_5k_20260310/environment/`
 
 
 
 
 
88
  - `10K` repro commands: `artifacts/twin_handover_packed_parallelization_10k_20260309/repro/commands_reproduce.sh`
89
  - `10K` changed-file manifest: `artifacts/twin_handover_packed_parallelization_10k_20260309/repro/changed_files.txt`
90
  - `10K` environment snapshot: `artifacts/twin_handover_packed_parallelization_10k_20260309/environment/`
@@ -104,6 +146,7 @@ Initial `2K` + `10K` study logic lives primarily in:
104
  - `openpi/scripts/init_parallel_pi05_from_single_pytorch.py`
105
  - `openpi/scripts/inspect_twin_packed_batch.py`
106
  - `openpi/scripts/check_parallel_warmstart_equivalence.py`
 
107
  - `openpi/scripts/run_twin_handover_packed_followup.sh`
108
  - `openpi/scripts/run_twin_handover_packed_10k.sh`
109
  - `openpi/scripts/run_twin_dual_push_128_packed_5k.sh`
 
13
  2. a longer `10K` follow-up on the same packed setup
14
  3. a `5K` dual-push `128` screening study on the same packed path
15
 
16
+ This update also adds a split-action-expert bring-up bundle for the packed TWIN path, covering:
17
+
18
+ - exact single-to-split warm-start checkpoints for `split_independent` and `split_communicating`
19
+ - invariant checks for the new split architecture
20
+ - detached real-data smoke and `20`-step training runs on `lsnu/twin_dual_push_128_train`
21
+ - the code changes that introduce the new split-expert action path
22
+
23
  ## Experiment setup
24
 
25
  - Handover train/val: `lsnu/twin_handover_256_train`, `lsnu/twin_handover_256_val`
 
67
 
68
  So this repo should be read as a matched warm-start study, not as a bitwise-identical step-0 control.
69
 
70
+ ## Split-Expert Bring-Up (`2026-03-10`)
71
+
72
+ The current repo now contains a true split-action-expert implementation in addition to the earlier packed head-only factorization. The new config flag is `action_expert_mode` with:
73
+
74
+ - `shared`
75
+ - `head_only_parallel`
76
+ - `split_independent`
77
+ - `split_communicating`
78
+
79
+ Key bring-up results:
80
+
81
+ - the split warm-start copies the original single `gemma_expert` into exact left/right expert branches for both split modes
82
+ - `split_independent` passes the branch-local invariants:
83
+ - identical left/right inputs produce identical suffix outputs
84
+ - perturbing right-arm inputs leaves left-arm outputs unchanged, and vice versa
85
+ - both split modes pass detached real-data training on packed TWIN dual-push:
86
+ - `3`-step real-data smoke run with checkpoint save
87
+ - `20`-step real-data training run with checkpoint save
88
+ - the communicating model emits nonzero cross-arm attention diagnostics and remains finite through the real-data `20`-step run
89
+
90
+ New bring-up artifact bundle:
91
+
92
+ - `artifacts/twin_split_expert_bringup_20260310/`
93
+ - split warm-start checkpoints
94
+ - invariant-check outputs
95
+ - reproducibility commands
96
+ - summary README for the split-expert bring-up
97
+
98
  ## Repo layout
99
 
100
  - `openpi/`
 
107
  - `10K` follow-up bundle with metrics, logs, repro manifests, and environment snapshot
108
  - `artifacts/twin_dual_push_128_packed_parallelization_5k_20260310/`
109
  - dual-push `128` screening bundle with metrics, logs, repro manifests, and environment snapshot
110
+ - `artifacts/twin_split_expert_bringup_20260310/`
111
+ - split-expert warm-start checkpoints, sanity checks, and bring-up repro commands
112
  - `artifacts/pi05_base_params/`
113
  - staged base parameter snapshot used during JAX-to-PyTorch conversion
114
 
 
122
  - dual-push `5K` teacher-forced table: `artifacts/twin_dual_push_128_packed_parallelization_5k_20260310/metrics/teacher_forced_eval_table.csv`
123
  - dual-push `5K` sample eval table: `artifacts/twin_dual_push_128_packed_parallelization_5k_20260310/metrics/sample_eval_table.csv`
124
  - dual-push `5K` environment snapshot: `artifacts/twin_dual_push_128_packed_parallelization_5k_20260310/environment/`
125
+ - split-expert bring-up summary: `artifacts/twin_split_expert_bringup_20260310/README.md`
126
+ - split-expert repro commands: `artifacts/twin_split_expert_bringup_20260310/repro/commands_bringup.sh`
127
+ - split-expert invariant check outputs: `artifacts/twin_split_expert_bringup_20260310/sanity_checks/`
128
+ - split-expert real-data logs: `openpi/run_logs/split_independent_real_smoke3_r2.log`, `openpi/run_logs/split_communicating_real_smoke3.log`, `openpi/run_logs/split_independent_real_train20.log`, `openpi/run_logs/split_communicating_real_train20.log`
129
+ - split-expert real-data checkpoints: `openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/`, `openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/`
130
  - `10K` repro commands: `artifacts/twin_handover_packed_parallelization_10k_20260309/repro/commands_reproduce.sh`
131
  - `10K` changed-file manifest: `artifacts/twin_handover_packed_parallelization_10k_20260309/repro/changed_files.txt`
132
  - `10K` environment snapshot: `artifacts/twin_handover_packed_parallelization_10k_20260309/environment/`
 
146
  - `openpi/scripts/init_parallel_pi05_from_single_pytorch.py`
147
  - `openpi/scripts/inspect_twin_packed_batch.py`
148
  - `openpi/scripts/check_parallel_warmstart_equivalence.py`
149
+ - `openpi/scripts/check_split_expert_invariants.py`
150
  - `openpi/scripts/run_twin_handover_packed_followup.sh`
151
  - `openpi/scripts/run_twin_handover_packed_10k.sh`
152
  - `openpi/scripts/run_twin_dual_push_128_packed_5k.sh`
artifacts/twin_split_expert_bringup_20260310/README.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Split-Expert Bring-Up (`2026-03-10`)
2
+
3
+ This bundle captures the initial PyTorch bring-up for the new packed TWIN split-action-expert path on `pi0.5`.
4
+
5
+ Included here:
6
+
7
+ - exact split warm-start checkpoints created from the original single-head PyTorch base checkpoint
8
+ - invariant-check outputs for `split_independent` and `split_communicating`
9
+ - detached real-data smoke and `20`-step training logs on `lsnu/twin_dual_push_128_train`
10
+ - reproducibility commands used for the bring-up
11
+
12
+ ## Warm-start summary
13
+
14
+ Both split modes inherit the same base expert weights and per-arm input/output projections from the single-head checkpoint.
15
+
16
+ - `split_independent`
17
+ - `left_expert_max_abs_diff = 0.0`
18
+ - `right_expert_max_abs_diff = 0.0`
19
+ - `left_input_projection_max_abs_diff = 0.0`
20
+ - `right_input_projection_max_abs_diff = 0.0`
21
+ - `left_output_projection_max_abs_diff = 0.0`
22
+ - `right_output_projection_max_abs_diff = 0.0`
23
+ - `split_communicating`
24
+ - same exact inherited diffs as above
25
+ - added cross-arm communication parameters are zero-initialized at warm start
26
+
27
+ ## Real-data bring-up summary
28
+
29
+ Dataset used for real-data smoke and short training:
30
+
31
+ - `lsnu/twin_dual_push_128_train`
32
+
33
+ Successful detached runs:
34
+
35
+ - `split_independent_real_smoke3_r2`
36
+ - `3` train steps on real packed TWIN data
37
+ - checkpoint saved at step `3`
38
+ - `split_communicating_real_smoke3`
39
+ - `3` train steps on real packed TWIN data
40
+ - checkpoint saved at step `3`
41
+ - `split_independent_real_train20`
42
+ - `20` train steps on real packed TWIN data
43
+ - final logged train loss at step `20`: `0.6038`
44
+ - checkpoint saved at step `20`
45
+ - `split_communicating_real_train20`
46
+ - `20` train steps on real packed TWIN data
47
+ - final logged train loss at step `20`: `0.5943`
48
+ - checkpoint saved at step `20`
49
+
50
+ ## Layout
51
+
52
+ - `bootstrap_checkpoints/`
53
+ - exact split warm-start checkpoints
54
+ - `sanity_checks/`
55
+ - invariant-check outputs
56
+ - `run_logs/`
57
+ - detached real-data run logs
58
+ - `repro/commands_bringup.sh`
59
+ - reproduction commands used during the bring-up
artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_communicating_packed_from_single/config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_dim": 32,
3
+ "action_expert_mode": "split_communicating",
4
+ "action_expert_variant": "gemma_300m",
5
+ "action_horizon": 16,
6
+ "arm_action_dims": [
7
+ 16,
8
+ 16
9
+ ],
10
+ "discrete_state_input": true,
11
+ "dtype": "bfloat16",
12
+ "max_token_len": 200,
13
+ "paligemma_variant": "gemma_2b",
14
+ "pi05": true
15
+ }
artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_communicating_packed_from_single/init_parallel_metadata.json ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_expert_mode": "split_communicating",
3
+ "config_name": "pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k",
4
+ "cross_arm_comm_init": [
5
+ 0.0,
6
+ 0.0,
7
+ 0.0,
8
+ 0.0,
9
+ 0.0,
10
+ 0.0,
11
+ 0.0,
12
+ 0.0,
13
+ 0.0,
14
+ 0.0,
15
+ 0.0,
16
+ 0.0,
17
+ 0.0,
18
+ 0.0,
19
+ 0.0,
20
+ 0.0,
21
+ 0.0,
22
+ 0.0
23
+ ],
24
+ "left_expert_max_abs_diff": 0.0,
25
+ "left_input_projection_max_abs_diff": 0.0,
26
+ "left_output_projection_max_abs_diff": 0.0,
27
+ "load_state_missing_keys": [
28
+ "paligemma_with_expert.cross_arm_comm",
29
+ "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight",
30
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.self_attn.q_proj.weight",
31
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.self_attn.k_proj.weight",
32
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.self_attn.v_proj.weight",
33
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.self_attn.o_proj.weight",
34
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.mlp.gate_proj.weight",
35
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.mlp.up_proj.weight",
36
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.mlp.down_proj.weight",
37
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.input_layernorm.dense.weight",
38
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.input_layernorm.dense.bias",
39
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.post_attention_layernorm.dense.weight",
40
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.post_attention_layernorm.dense.bias",
41
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.self_attn.q_proj.weight",
42
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.self_attn.k_proj.weight",
43
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.self_attn.v_proj.weight",
44
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.self_attn.o_proj.weight",
45
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.mlp.gate_proj.weight",
46
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.mlp.up_proj.weight",
47
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.mlp.down_proj.weight",
48
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.input_layernorm.dense.weight",
49
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.input_layernorm.dense.bias",
50
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.post_attention_layernorm.dense.weight",
51
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.post_attention_layernorm.dense.bias",
52
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.self_attn.q_proj.weight",
53
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.self_attn.k_proj.weight",
54
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.self_attn.v_proj.weight",
55
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.self_attn.o_proj.weight",
56
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.mlp.gate_proj.weight",
57
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.mlp.up_proj.weight",
58
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.mlp.down_proj.weight",
59
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.input_layernorm.dense.weight",
60
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.input_layernorm.dense.bias",
61
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.post_attention_layernorm.dense.weight",
62
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.post_attention_layernorm.dense.bias",
63
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.self_attn.q_proj.weight",
64
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.self_attn.k_proj.weight",
65
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.self_attn.v_proj.weight",
66
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.self_attn.o_proj.weight",
67
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.mlp.gate_proj.weight",
68
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.mlp.up_proj.weight",
69
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.mlp.down_proj.weight",
70
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.input_layernorm.dense.weight",
71
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.input_layernorm.dense.bias",
72
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.post_attention_layernorm.dense.weight",
73
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.post_attention_layernorm.dense.bias",
74
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.self_attn.q_proj.weight",
75
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.self_attn.k_proj.weight",
76
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.self_attn.v_proj.weight",
77
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.self_attn.o_proj.weight",
78
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.mlp.gate_proj.weight",
79
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.mlp.up_proj.weight",
80
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.mlp.down_proj.weight",
81
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.input_layernorm.dense.weight",
82
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.input_layernorm.dense.bias",
83
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.post_attention_layernorm.dense.weight",
84
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.post_attention_layernorm.dense.bias",
85
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.self_attn.q_proj.weight",
86
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.self_attn.k_proj.weight",
87
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.self_attn.v_proj.weight",
88
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.self_attn.o_proj.weight",
89
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.mlp.gate_proj.weight",
90
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.mlp.up_proj.weight",
91
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.mlp.down_proj.weight",
92
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.input_layernorm.dense.weight",
93
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.input_layernorm.dense.bias",
94
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.post_attention_layernorm.dense.weight",
95
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.post_attention_layernorm.dense.bias",
96
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.self_attn.q_proj.weight",
97
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.self_attn.k_proj.weight",
98
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.self_attn.v_proj.weight",
99
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.self_attn.o_proj.weight",
100
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.mlp.gate_proj.weight",
101
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.mlp.up_proj.weight",
102
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.mlp.down_proj.weight",
103
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.input_layernorm.dense.weight",
104
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.input_layernorm.dense.bias",
105
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.post_attention_layernorm.dense.weight",
106
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.post_attention_layernorm.dense.bias",
107
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.self_attn.q_proj.weight",
108
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.self_attn.k_proj.weight",
109
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.self_attn.v_proj.weight",
110
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.self_attn.o_proj.weight",
111
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.mlp.gate_proj.weight",
112
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.mlp.up_proj.weight",
113
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.mlp.down_proj.weight",
114
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.input_layernorm.dense.weight",
115
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.input_layernorm.dense.bias",
116
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.post_attention_layernorm.dense.weight",
117
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.post_attention_layernorm.dense.bias",
118
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.self_attn.q_proj.weight",
119
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.self_attn.k_proj.weight",
120
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.self_attn.v_proj.weight",
121
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.self_attn.o_proj.weight",
122
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.mlp.gate_proj.weight",
123
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.mlp.up_proj.weight",
124
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.mlp.down_proj.weight",
125
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.input_layernorm.dense.weight",
126
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.input_layernorm.dense.bias",
127
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.post_attention_layernorm.dense.weight",
128
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.post_attention_layernorm.dense.bias",
129
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.self_attn.q_proj.weight",
130
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.self_attn.k_proj.weight",
131
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.self_attn.v_proj.weight",
132
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.self_attn.o_proj.weight",
133
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.mlp.gate_proj.weight",
134
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.mlp.up_proj.weight",
135
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.mlp.down_proj.weight",
136
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.input_layernorm.dense.weight",
137
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.input_layernorm.dense.bias",
138
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.post_attention_layernorm.dense.weight",
139
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.post_attention_layernorm.dense.bias",
140
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.self_attn.q_proj.weight",
141
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.self_attn.k_proj.weight",
142
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.self_attn.v_proj.weight",
143
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.self_attn.o_proj.weight",
144
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.mlp.gate_proj.weight",
145
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.mlp.up_proj.weight",
146
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.mlp.down_proj.weight",
147
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.input_layernorm.dense.weight",
148
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.input_layernorm.dense.bias",
149
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.post_attention_layernorm.dense.weight",
150
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.post_attention_layernorm.dense.bias",
151
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.self_attn.q_proj.weight",
152
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.self_attn.k_proj.weight",
153
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.self_attn.v_proj.weight",
154
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.self_attn.o_proj.weight",
155
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.mlp.gate_proj.weight",
156
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.mlp.up_proj.weight",
157
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.mlp.down_proj.weight",
158
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.input_layernorm.dense.weight",
159
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.input_layernorm.dense.bias",
160
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.post_attention_layernorm.dense.weight",
161
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.post_attention_layernorm.dense.bias",
162
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.self_attn.q_proj.weight",
163
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.self_attn.k_proj.weight",
164
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.self_attn.v_proj.weight",
165
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.self_attn.o_proj.weight",
166
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.mlp.gate_proj.weight",
167
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.mlp.up_proj.weight",
168
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.mlp.down_proj.weight",
169
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.input_layernorm.dense.weight",
170
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.input_layernorm.dense.bias",
171
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.post_attention_layernorm.dense.weight",
172
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.post_attention_layernorm.dense.bias",
173
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.self_attn.q_proj.weight",
174
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.self_attn.k_proj.weight",
175
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.self_attn.v_proj.weight",
176
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.self_attn.o_proj.weight",
177
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.mlp.gate_proj.weight",
178
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.mlp.up_proj.weight",
179
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.mlp.down_proj.weight",
180
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.input_layernorm.dense.weight",
181
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.input_layernorm.dense.bias",
182
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.post_attention_layernorm.dense.weight",
183
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.post_attention_layernorm.dense.bias",
184
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.self_attn.q_proj.weight",
185
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.self_attn.k_proj.weight",
186
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.self_attn.v_proj.weight",
187
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.self_attn.o_proj.weight",
188
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.mlp.gate_proj.weight",
189
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.mlp.up_proj.weight",
190
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.mlp.down_proj.weight",
191
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.input_layernorm.dense.weight",
192
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.input_layernorm.dense.bias",
193
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.post_attention_layernorm.dense.weight",
194
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.post_attention_layernorm.dense.bias",
195
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.self_attn.q_proj.weight",
196
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.self_attn.k_proj.weight",
197
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.self_attn.v_proj.weight",
198
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.self_attn.o_proj.weight",
199
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.mlp.gate_proj.weight",
200
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.mlp.up_proj.weight",
201
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.mlp.down_proj.weight",
202
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.input_layernorm.dense.weight",
203
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.input_layernorm.dense.bias",
204
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.post_attention_layernorm.dense.weight",
205
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.post_attention_layernorm.dense.bias",
206
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.self_attn.q_proj.weight",
207
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.self_attn.k_proj.weight",
208
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.self_attn.v_proj.weight",
209
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.self_attn.o_proj.weight",
210
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.mlp.gate_proj.weight",
211
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.mlp.up_proj.weight",
212
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.mlp.down_proj.weight",
213
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.input_layernorm.dense.weight",
214
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.input_layernorm.dense.bias",
215
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.post_attention_layernorm.dense.weight",
216
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.post_attention_layernorm.dense.bias",
217
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.self_attn.q_proj.weight",
218
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.self_attn.k_proj.weight",
219
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.self_attn.v_proj.weight",
220
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.self_attn.o_proj.weight",
221
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.mlp.gate_proj.weight",
222
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.mlp.up_proj.weight",
223
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.mlp.down_proj.weight",
224
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.input_layernorm.dense.weight",
225
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.input_layernorm.dense.bias",
226
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.post_attention_layernorm.dense.weight",
227
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.post_attention_layernorm.dense.bias",
228
+ "paligemma_with_expert.left_gemma_expert.model.norm.dense.weight",
229
+ "paligemma_with_expert.left_gemma_expert.model.norm.dense.bias",
230
+ "paligemma_with_expert.left_gemma_expert.lm_head.weight",
231
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.self_attn.q_proj.weight",
232
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.self_attn.k_proj.weight",
233
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.self_attn.v_proj.weight",
234
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.self_attn.o_proj.weight",
235
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.mlp.gate_proj.weight",
236
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.mlp.up_proj.weight",
237
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.mlp.down_proj.weight",
238
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.input_layernorm.dense.weight",
239
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.input_layernorm.dense.bias",
240
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.post_attention_layernorm.dense.weight",
241
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.post_attention_layernorm.dense.bias",
242
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.self_attn.q_proj.weight",
243
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.self_attn.k_proj.weight",
244
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.self_attn.v_proj.weight",
245
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.self_attn.o_proj.weight",
246
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.mlp.gate_proj.weight",
247
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.mlp.up_proj.weight",
248
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.mlp.down_proj.weight",
249
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.input_layernorm.dense.weight",
250
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.input_layernorm.dense.bias",
251
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.post_attention_layernorm.dense.weight",
252
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.post_attention_layernorm.dense.bias",
253
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.self_attn.q_proj.weight",
254
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.self_attn.k_proj.weight",
255
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.self_attn.v_proj.weight",
256
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.self_attn.o_proj.weight",
257
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.mlp.gate_proj.weight",
258
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.mlp.up_proj.weight",
259
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.mlp.down_proj.weight",
260
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.input_layernorm.dense.weight",
261
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.input_layernorm.dense.bias",
262
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.post_attention_layernorm.dense.weight",
263
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.post_attention_layernorm.dense.bias",
264
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.self_attn.q_proj.weight",
265
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.self_attn.k_proj.weight",
266
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.self_attn.v_proj.weight",
267
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.self_attn.o_proj.weight",
268
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.mlp.gate_proj.weight",
269
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.mlp.up_proj.weight",
270
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.mlp.down_proj.weight",
271
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.input_layernorm.dense.weight",
272
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.input_layernorm.dense.bias",
273
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.post_attention_layernorm.dense.weight",
274
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.post_attention_layernorm.dense.bias",
275
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.self_attn.q_proj.weight",
276
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.self_attn.k_proj.weight",
277
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.self_attn.v_proj.weight",
278
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.self_attn.o_proj.weight",
279
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.mlp.gate_proj.weight",
280
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.mlp.up_proj.weight",
281
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.mlp.down_proj.weight",
282
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.input_layernorm.dense.weight",
283
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.input_layernorm.dense.bias",
284
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.post_attention_layernorm.dense.weight",
285
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.post_attention_layernorm.dense.bias",
286
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.self_attn.q_proj.weight",
287
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.self_attn.k_proj.weight",
288
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.self_attn.v_proj.weight",
289
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.self_attn.o_proj.weight",
290
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.mlp.gate_proj.weight",
291
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.mlp.up_proj.weight",
292
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.mlp.down_proj.weight",
293
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.input_layernorm.dense.weight",
294
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.input_layernorm.dense.bias",
295
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.post_attention_layernorm.dense.weight",
296
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.post_attention_layernorm.dense.bias",
297
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.self_attn.q_proj.weight",
298
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.self_attn.k_proj.weight",
299
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.self_attn.v_proj.weight",
300
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.self_attn.o_proj.weight",
301
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.mlp.gate_proj.weight",
302
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.mlp.up_proj.weight",
303
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.mlp.down_proj.weight",
304
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.input_layernorm.dense.weight",
305
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.input_layernorm.dense.bias",
306
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.post_attention_layernorm.dense.weight",
307
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.post_attention_layernorm.dense.bias",
308
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.self_attn.q_proj.weight",
309
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.self_attn.k_proj.weight",
310
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.self_attn.v_proj.weight",
311
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.self_attn.o_proj.weight",
312
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.mlp.gate_proj.weight",
313
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.mlp.up_proj.weight",
314
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.mlp.down_proj.weight",
315
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.input_layernorm.dense.weight",
316
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.input_layernorm.dense.bias",
317
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.post_attention_layernorm.dense.weight",
318
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.post_attention_layernorm.dense.bias",
319
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.self_attn.q_proj.weight",
320
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.self_attn.k_proj.weight",
321
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.self_attn.v_proj.weight",
322
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.self_attn.o_proj.weight",
323
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.mlp.gate_proj.weight",
324
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.mlp.up_proj.weight",
325
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.mlp.down_proj.weight",
326
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.input_layernorm.dense.weight",
327
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.input_layernorm.dense.bias",
328
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.post_attention_layernorm.dense.weight",
329
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.post_attention_layernorm.dense.bias",
330
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.self_attn.q_proj.weight",
331
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.self_attn.k_proj.weight",
332
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.self_attn.v_proj.weight",
333
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.self_attn.o_proj.weight",
334
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.mlp.gate_proj.weight",
335
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.mlp.up_proj.weight",
336
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.mlp.down_proj.weight",
337
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.input_layernorm.dense.weight",
338
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.input_layernorm.dense.bias",
339
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.post_attention_layernorm.dense.weight",
340
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.post_attention_layernorm.dense.bias",
341
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.self_attn.q_proj.weight",
342
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.self_attn.k_proj.weight",
343
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.self_attn.v_proj.weight",
344
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.self_attn.o_proj.weight",
345
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.mlp.gate_proj.weight",
346
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.mlp.up_proj.weight",
347
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.mlp.down_proj.weight",
348
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.input_layernorm.dense.weight",
349
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.input_layernorm.dense.bias",
350
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.post_attention_layernorm.dense.weight",
351
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.post_attention_layernorm.dense.bias",
352
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.self_attn.q_proj.weight",
353
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.self_attn.k_proj.weight",
354
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.self_attn.v_proj.weight",
355
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.self_attn.o_proj.weight",
356
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.mlp.gate_proj.weight",
357
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.mlp.up_proj.weight",
358
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.mlp.down_proj.weight",
359
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.input_layernorm.dense.weight",
360
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.input_layernorm.dense.bias",
361
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.post_attention_layernorm.dense.weight",
362
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.post_attention_layernorm.dense.bias",
363
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.self_attn.q_proj.weight",
364
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.self_attn.k_proj.weight",
365
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.self_attn.v_proj.weight",
366
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.self_attn.o_proj.weight",
367
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.mlp.gate_proj.weight",
368
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.mlp.up_proj.weight",
369
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.mlp.down_proj.weight",
370
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.input_layernorm.dense.weight",
371
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.input_layernorm.dense.bias",
372
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.post_attention_layernorm.dense.weight",
373
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.post_attention_layernorm.dense.bias",
374
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.self_attn.q_proj.weight",
375
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.self_attn.k_proj.weight",
376
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.self_attn.v_proj.weight",
377
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.self_attn.o_proj.weight",
378
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.mlp.gate_proj.weight",
379
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.mlp.up_proj.weight",
380
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.mlp.down_proj.weight",
381
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.input_layernorm.dense.weight",
382
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.input_layernorm.dense.bias",
383
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.post_attention_layernorm.dense.weight",
384
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.post_attention_layernorm.dense.bias",
385
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.self_attn.q_proj.weight",
386
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.self_attn.k_proj.weight",
387
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.self_attn.v_proj.weight",
388
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.self_attn.o_proj.weight",
389
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.mlp.gate_proj.weight",
390
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.mlp.up_proj.weight",
391
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.mlp.down_proj.weight",
392
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.input_layernorm.dense.weight",
393
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.input_layernorm.dense.bias",
394
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.post_attention_layernorm.dense.weight",
395
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.post_attention_layernorm.dense.bias",
396
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.self_attn.q_proj.weight",
397
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.self_attn.k_proj.weight",
398
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.self_attn.v_proj.weight",
399
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.self_attn.o_proj.weight",
400
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.mlp.gate_proj.weight",
401
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.mlp.up_proj.weight",
402
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.mlp.down_proj.weight",
403
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.input_layernorm.dense.weight",
404
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.input_layernorm.dense.bias",
405
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.post_attention_layernorm.dense.weight",
406
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.post_attention_layernorm.dense.bias",
407
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.self_attn.q_proj.weight",
408
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.self_attn.k_proj.weight",
409
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.self_attn.v_proj.weight",
410
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.self_attn.o_proj.weight",
411
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.mlp.gate_proj.weight",
412
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.mlp.up_proj.weight",
413
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.mlp.down_proj.weight",
414
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.input_layernorm.dense.weight",
415
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.input_layernorm.dense.bias",
416
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.post_attention_layernorm.dense.weight",
417
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.post_attention_layernorm.dense.bias",
418
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.self_attn.q_proj.weight",
419
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.self_attn.k_proj.weight",
420
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.self_attn.v_proj.weight",
421
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.self_attn.o_proj.weight",
422
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.mlp.gate_proj.weight",
423
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.mlp.up_proj.weight",
424
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.mlp.down_proj.weight",
425
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.input_layernorm.dense.weight",
426
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.input_layernorm.dense.bias",
427
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.post_attention_layernorm.dense.weight",
428
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.post_attention_layernorm.dense.bias",
429
+ "paligemma_with_expert.right_gemma_expert.model.norm.dense.weight",
430
+ "paligemma_with_expert.right_gemma_expert.model.norm.dense.bias",
431
+ "paligemma_with_expert.right_gemma_expert.lm_head.weight",
432
+ "action_in_proj_arms.0.weight",
433
+ "action_in_proj_arms.0.bias",
434
+ "action_in_proj_arms.1.weight",
435
+ "action_in_proj_arms.1.bias",
436
+ "action_out_proj_arms.0.weight",
437
+ "action_out_proj_arms.0.bias",
438
+ "action_out_proj_arms.1.weight",
439
+ "action_out_proj_arms.1.bias"
440
+ ],
441
+ "load_state_unexpected_keys": [
442
+ "action_in_proj.bias",
443
+ "action_in_proj.weight",
444
+ "action_out_proj.bias",
445
+ "action_out_proj.weight",
446
+ "paligemma_with_expert.gemma_expert.lm_head.weight",
447
+ "paligemma_with_expert.gemma_expert.model.layers.0.input_layernorm.dense.bias",
448
+ "paligemma_with_expert.gemma_expert.model.layers.0.input_layernorm.dense.weight",
449
+ "paligemma_with_expert.gemma_expert.model.layers.0.mlp.down_proj.weight",
450
+ "paligemma_with_expert.gemma_expert.model.layers.0.mlp.gate_proj.weight",
451
+ "paligemma_with_expert.gemma_expert.model.layers.0.mlp.up_proj.weight",
452
+ "paligemma_with_expert.gemma_expert.model.layers.0.post_attention_layernorm.dense.bias",
453
+ "paligemma_with_expert.gemma_expert.model.layers.0.post_attention_layernorm.dense.weight",
454
+ "paligemma_with_expert.gemma_expert.model.layers.0.self_attn.k_proj.weight",
455
+ "paligemma_with_expert.gemma_expert.model.layers.0.self_attn.o_proj.weight",
456
+ "paligemma_with_expert.gemma_expert.model.layers.0.self_attn.q_proj.weight",
457
+ "paligemma_with_expert.gemma_expert.model.layers.0.self_attn.v_proj.weight",
458
+ "paligemma_with_expert.gemma_expert.model.layers.1.input_layernorm.dense.bias",
459
+ "paligemma_with_expert.gemma_expert.model.layers.1.input_layernorm.dense.weight",
460
+ "paligemma_with_expert.gemma_expert.model.layers.1.mlp.down_proj.weight",
461
+ "paligemma_with_expert.gemma_expert.model.layers.1.mlp.gate_proj.weight",
462
+ "paligemma_with_expert.gemma_expert.model.layers.1.mlp.up_proj.weight",
463
+ "paligemma_with_expert.gemma_expert.model.layers.1.post_attention_layernorm.dense.bias",
464
+ "paligemma_with_expert.gemma_expert.model.layers.1.post_attention_layernorm.dense.weight",
465
+ "paligemma_with_expert.gemma_expert.model.layers.1.self_attn.k_proj.weight",
466
+ "paligemma_with_expert.gemma_expert.model.layers.1.self_attn.o_proj.weight",
467
+ "paligemma_with_expert.gemma_expert.model.layers.1.self_attn.q_proj.weight",
468
+ "paligemma_with_expert.gemma_expert.model.layers.1.self_attn.v_proj.weight",
469
+ "paligemma_with_expert.gemma_expert.model.layers.10.input_layernorm.dense.bias",
470
+ "paligemma_with_expert.gemma_expert.model.layers.10.input_layernorm.dense.weight",
471
+ "paligemma_with_expert.gemma_expert.model.layers.10.mlp.down_proj.weight",
472
+ "paligemma_with_expert.gemma_expert.model.layers.10.mlp.gate_proj.weight",
473
+ "paligemma_with_expert.gemma_expert.model.layers.10.mlp.up_proj.weight",
474
+ "paligemma_with_expert.gemma_expert.model.layers.10.post_attention_layernorm.dense.bias",
475
+ "paligemma_with_expert.gemma_expert.model.layers.10.post_attention_layernorm.dense.weight",
476
+ "paligemma_with_expert.gemma_expert.model.layers.10.self_attn.k_proj.weight",
477
+ "paligemma_with_expert.gemma_expert.model.layers.10.self_attn.o_proj.weight",
478
+ "paligemma_with_expert.gemma_expert.model.layers.10.self_attn.q_proj.weight",
479
+ "paligemma_with_expert.gemma_expert.model.layers.10.self_attn.v_proj.weight",
480
+ "paligemma_with_expert.gemma_expert.model.layers.11.input_layernorm.dense.bias",
481
+ "paligemma_with_expert.gemma_expert.model.layers.11.input_layernorm.dense.weight",
482
+ "paligemma_with_expert.gemma_expert.model.layers.11.mlp.down_proj.weight",
483
+ "paligemma_with_expert.gemma_expert.model.layers.11.mlp.gate_proj.weight",
484
+ "paligemma_with_expert.gemma_expert.model.layers.11.mlp.up_proj.weight",
485
+ "paligemma_with_expert.gemma_expert.model.layers.11.post_attention_layernorm.dense.bias",
486
+ "paligemma_with_expert.gemma_expert.model.layers.11.post_attention_layernorm.dense.weight",
487
+ "paligemma_with_expert.gemma_expert.model.layers.11.self_attn.k_proj.weight",
488
+ "paligemma_with_expert.gemma_expert.model.layers.11.self_attn.o_proj.weight",
489
+ "paligemma_with_expert.gemma_expert.model.layers.11.self_attn.q_proj.weight",
490
+ "paligemma_with_expert.gemma_expert.model.layers.11.self_attn.v_proj.weight",
491
+ "paligemma_with_expert.gemma_expert.model.layers.12.input_layernorm.dense.bias",
492
+ "paligemma_with_expert.gemma_expert.model.layers.12.input_layernorm.dense.weight",
493
+ "paligemma_with_expert.gemma_expert.model.layers.12.mlp.down_proj.weight",
494
+ "paligemma_with_expert.gemma_expert.model.layers.12.mlp.gate_proj.weight",
495
+ "paligemma_with_expert.gemma_expert.model.layers.12.mlp.up_proj.weight",
496
+ "paligemma_with_expert.gemma_expert.model.layers.12.post_attention_layernorm.dense.bias",
497
+ "paligemma_with_expert.gemma_expert.model.layers.12.post_attention_layernorm.dense.weight",
498
+ "paligemma_with_expert.gemma_expert.model.layers.12.self_attn.k_proj.weight",
499
+ "paligemma_with_expert.gemma_expert.model.layers.12.self_attn.o_proj.weight",
500
+ "paligemma_with_expert.gemma_expert.model.layers.12.self_attn.q_proj.weight",
501
+ "paligemma_with_expert.gemma_expert.model.layers.12.self_attn.v_proj.weight",
502
+ "paligemma_with_expert.gemma_expert.model.layers.13.input_layernorm.dense.bias",
503
+ "paligemma_with_expert.gemma_expert.model.layers.13.input_layernorm.dense.weight",
504
+ "paligemma_with_expert.gemma_expert.model.layers.13.mlp.down_proj.weight",
505
+ "paligemma_with_expert.gemma_expert.model.layers.13.mlp.gate_proj.weight",
506
+ "paligemma_with_expert.gemma_expert.model.layers.13.mlp.up_proj.weight",
507
+ "paligemma_with_expert.gemma_expert.model.layers.13.post_attention_layernorm.dense.bias",
508
+ "paligemma_with_expert.gemma_expert.model.layers.13.post_attention_layernorm.dense.weight",
509
+ "paligemma_with_expert.gemma_expert.model.layers.13.self_attn.k_proj.weight",
510
+ "paligemma_with_expert.gemma_expert.model.layers.13.self_attn.o_proj.weight",
511
+ "paligemma_with_expert.gemma_expert.model.layers.13.self_attn.q_proj.weight",
512
+ "paligemma_with_expert.gemma_expert.model.layers.13.self_attn.v_proj.weight",
513
+ "paligemma_with_expert.gemma_expert.model.layers.14.input_layernorm.dense.bias",
514
+ "paligemma_with_expert.gemma_expert.model.layers.14.input_layernorm.dense.weight",
515
+ "paligemma_with_expert.gemma_expert.model.layers.14.mlp.down_proj.weight",
516
+ "paligemma_with_expert.gemma_expert.model.layers.14.mlp.gate_proj.weight",
517
+ "paligemma_with_expert.gemma_expert.model.layers.14.mlp.up_proj.weight",
518
+ "paligemma_with_expert.gemma_expert.model.layers.14.post_attention_layernorm.dense.bias",
519
+ "paligemma_with_expert.gemma_expert.model.layers.14.post_attention_layernorm.dense.weight",
520
+ "paligemma_with_expert.gemma_expert.model.layers.14.self_attn.k_proj.weight",
521
+ "paligemma_with_expert.gemma_expert.model.layers.14.self_attn.o_proj.weight",
522
+ "paligemma_with_expert.gemma_expert.model.layers.14.self_attn.q_proj.weight",
523
+ "paligemma_with_expert.gemma_expert.model.layers.14.self_attn.v_proj.weight",
524
+ "paligemma_with_expert.gemma_expert.model.layers.15.input_layernorm.dense.bias",
525
+ "paligemma_with_expert.gemma_expert.model.layers.15.input_layernorm.dense.weight",
526
+ "paligemma_with_expert.gemma_expert.model.layers.15.mlp.down_proj.weight",
527
+ "paligemma_with_expert.gemma_expert.model.layers.15.mlp.gate_proj.weight",
528
+ "paligemma_with_expert.gemma_expert.model.layers.15.mlp.up_proj.weight",
529
+ "paligemma_with_expert.gemma_expert.model.layers.15.post_attention_layernorm.dense.bias",
530
+ "paligemma_with_expert.gemma_expert.model.layers.15.post_attention_layernorm.dense.weight",
531
+ "paligemma_with_expert.gemma_expert.model.layers.15.self_attn.k_proj.weight",
532
+ "paligemma_with_expert.gemma_expert.model.layers.15.self_attn.o_proj.weight",
533
+ "paligemma_with_expert.gemma_expert.model.layers.15.self_attn.q_proj.weight",
534
+ "paligemma_with_expert.gemma_expert.model.layers.15.self_attn.v_proj.weight",
535
+ "paligemma_with_expert.gemma_expert.model.layers.16.input_layernorm.dense.bias",
536
+ "paligemma_with_expert.gemma_expert.model.layers.16.input_layernorm.dense.weight",
537
+ "paligemma_with_expert.gemma_expert.model.layers.16.mlp.down_proj.weight",
538
+ "paligemma_with_expert.gemma_expert.model.layers.16.mlp.gate_proj.weight",
539
+ "paligemma_with_expert.gemma_expert.model.layers.16.mlp.up_proj.weight",
540
+ "paligemma_with_expert.gemma_expert.model.layers.16.post_attention_layernorm.dense.bias",
541
+ "paligemma_with_expert.gemma_expert.model.layers.16.post_attention_layernorm.dense.weight",
542
+ "paligemma_with_expert.gemma_expert.model.layers.16.self_attn.k_proj.weight",
543
+ "paligemma_with_expert.gemma_expert.model.layers.16.self_attn.o_proj.weight",
544
+ "paligemma_with_expert.gemma_expert.model.layers.16.self_attn.q_proj.weight",
545
+ "paligemma_with_expert.gemma_expert.model.layers.16.self_attn.v_proj.weight",
546
+ "paligemma_with_expert.gemma_expert.model.layers.17.input_layernorm.dense.bias",
547
+ "paligemma_with_expert.gemma_expert.model.layers.17.input_layernorm.dense.weight",
548
+ "paligemma_with_expert.gemma_expert.model.layers.17.mlp.down_proj.weight",
549
+ "paligemma_with_expert.gemma_expert.model.layers.17.mlp.gate_proj.weight",
550
+ "paligemma_with_expert.gemma_expert.model.layers.17.mlp.up_proj.weight",
551
+ "paligemma_with_expert.gemma_expert.model.layers.17.post_attention_layernorm.dense.bias",
552
+ "paligemma_with_expert.gemma_expert.model.layers.17.post_attention_layernorm.dense.weight",
553
+ "paligemma_with_expert.gemma_expert.model.layers.17.self_attn.k_proj.weight",
554
+ "paligemma_with_expert.gemma_expert.model.layers.17.self_attn.o_proj.weight",
555
+ "paligemma_with_expert.gemma_expert.model.layers.17.self_attn.q_proj.weight",
556
+ "paligemma_with_expert.gemma_expert.model.layers.17.self_attn.v_proj.weight",
557
+ "paligemma_with_expert.gemma_expert.model.layers.2.input_layernorm.dense.bias",
558
+ "paligemma_with_expert.gemma_expert.model.layers.2.input_layernorm.dense.weight",
559
+ "paligemma_with_expert.gemma_expert.model.layers.2.mlp.down_proj.weight",
560
+ "paligemma_with_expert.gemma_expert.model.layers.2.mlp.gate_proj.weight",
561
+ "paligemma_with_expert.gemma_expert.model.layers.2.mlp.up_proj.weight",
562
+ "paligemma_with_expert.gemma_expert.model.layers.2.post_attention_layernorm.dense.bias",
563
+ "paligemma_with_expert.gemma_expert.model.layers.2.post_attention_layernorm.dense.weight",
564
+ "paligemma_with_expert.gemma_expert.model.layers.2.self_attn.k_proj.weight",
565
+ "paligemma_with_expert.gemma_expert.model.layers.2.self_attn.o_proj.weight",
566
+ "paligemma_with_expert.gemma_expert.model.layers.2.self_attn.q_proj.weight",
567
+ "paligemma_with_expert.gemma_expert.model.layers.2.self_attn.v_proj.weight",
568
+ "paligemma_with_expert.gemma_expert.model.layers.3.input_layernorm.dense.bias",
569
+ "paligemma_with_expert.gemma_expert.model.layers.3.input_layernorm.dense.weight",
570
+ "paligemma_with_expert.gemma_expert.model.layers.3.mlp.down_proj.weight",
571
+ "paligemma_with_expert.gemma_expert.model.layers.3.mlp.gate_proj.weight",
572
+ "paligemma_with_expert.gemma_expert.model.layers.3.mlp.up_proj.weight",
573
+ "paligemma_with_expert.gemma_expert.model.layers.3.post_attention_layernorm.dense.bias",
574
+ "paligemma_with_expert.gemma_expert.model.layers.3.post_attention_layernorm.dense.weight",
575
+ "paligemma_with_expert.gemma_expert.model.layers.3.self_attn.k_proj.weight",
576
+ "paligemma_with_expert.gemma_expert.model.layers.3.self_attn.o_proj.weight",
577
+ "paligemma_with_expert.gemma_expert.model.layers.3.self_attn.q_proj.weight",
578
+ "paligemma_with_expert.gemma_expert.model.layers.3.self_attn.v_proj.weight",
579
+ "paligemma_with_expert.gemma_expert.model.layers.4.input_layernorm.dense.bias",
580
+ "paligemma_with_expert.gemma_expert.model.layers.4.input_layernorm.dense.weight",
581
+ "paligemma_with_expert.gemma_expert.model.layers.4.mlp.down_proj.weight",
582
+ "paligemma_with_expert.gemma_expert.model.layers.4.mlp.gate_proj.weight",
583
+ "paligemma_with_expert.gemma_expert.model.layers.4.mlp.up_proj.weight",
584
+ "paligemma_with_expert.gemma_expert.model.layers.4.post_attention_layernorm.dense.bias",
585
+ "paligemma_with_expert.gemma_expert.model.layers.4.post_attention_layernorm.dense.weight",
586
+ "paligemma_with_expert.gemma_expert.model.layers.4.self_attn.k_proj.weight",
587
+ "paligemma_with_expert.gemma_expert.model.layers.4.self_attn.o_proj.weight",
588
+ "paligemma_with_expert.gemma_expert.model.layers.4.self_attn.q_proj.weight",
589
+ "paligemma_with_expert.gemma_expert.model.layers.4.self_attn.v_proj.weight",
590
+ "paligemma_with_expert.gemma_expert.model.layers.5.input_layernorm.dense.bias",
591
+ "paligemma_with_expert.gemma_expert.model.layers.5.input_layernorm.dense.weight",
592
+ "paligemma_with_expert.gemma_expert.model.layers.5.mlp.down_proj.weight",
593
+ "paligemma_with_expert.gemma_expert.model.layers.5.mlp.gate_proj.weight",
594
+ "paligemma_with_expert.gemma_expert.model.layers.5.mlp.up_proj.weight",
595
+ "paligemma_with_expert.gemma_expert.model.layers.5.post_attention_layernorm.dense.bias",
596
+ "paligemma_with_expert.gemma_expert.model.layers.5.post_attention_layernorm.dense.weight",
597
+ "paligemma_with_expert.gemma_expert.model.layers.5.self_attn.k_proj.weight",
598
+ "paligemma_with_expert.gemma_expert.model.layers.5.self_attn.o_proj.weight",
599
+ "paligemma_with_expert.gemma_expert.model.layers.5.self_attn.q_proj.weight",
600
+ "paligemma_with_expert.gemma_expert.model.layers.5.self_attn.v_proj.weight",
601
+ "paligemma_with_expert.gemma_expert.model.layers.6.input_layernorm.dense.bias",
602
+ "paligemma_with_expert.gemma_expert.model.layers.6.input_layernorm.dense.weight",
603
+ "paligemma_with_expert.gemma_expert.model.layers.6.mlp.down_proj.weight",
604
+ "paligemma_with_expert.gemma_expert.model.layers.6.mlp.gate_proj.weight",
605
+ "paligemma_with_expert.gemma_expert.model.layers.6.mlp.up_proj.weight",
606
+ "paligemma_with_expert.gemma_expert.model.layers.6.post_attention_layernorm.dense.bias",
607
+ "paligemma_with_expert.gemma_expert.model.layers.6.post_attention_layernorm.dense.weight",
608
+ "paligemma_with_expert.gemma_expert.model.layers.6.self_attn.k_proj.weight",
609
+ "paligemma_with_expert.gemma_expert.model.layers.6.self_attn.o_proj.weight",
610
+ "paligemma_with_expert.gemma_expert.model.layers.6.self_attn.q_proj.weight",
611
+ "paligemma_with_expert.gemma_expert.model.layers.6.self_attn.v_proj.weight",
612
+ "paligemma_with_expert.gemma_expert.model.layers.7.input_layernorm.dense.bias",
613
+ "paligemma_with_expert.gemma_expert.model.layers.7.input_layernorm.dense.weight",
614
+ "paligemma_with_expert.gemma_expert.model.layers.7.mlp.down_proj.weight",
615
+ "paligemma_with_expert.gemma_expert.model.layers.7.mlp.gate_proj.weight",
616
+ "paligemma_with_expert.gemma_expert.model.layers.7.mlp.up_proj.weight",
617
+ "paligemma_with_expert.gemma_expert.model.layers.7.post_attention_layernorm.dense.bias",
618
+ "paligemma_with_expert.gemma_expert.model.layers.7.post_attention_layernorm.dense.weight",
619
+ "paligemma_with_expert.gemma_expert.model.layers.7.self_attn.k_proj.weight",
620
+ "paligemma_with_expert.gemma_expert.model.layers.7.self_attn.o_proj.weight",
621
+ "paligemma_with_expert.gemma_expert.model.layers.7.self_attn.q_proj.weight",
622
+ "paligemma_with_expert.gemma_expert.model.layers.7.self_attn.v_proj.weight",
623
+ "paligemma_with_expert.gemma_expert.model.layers.8.input_layernorm.dense.bias",
624
+ "paligemma_with_expert.gemma_expert.model.layers.8.input_layernorm.dense.weight",
625
+ "paligemma_with_expert.gemma_expert.model.layers.8.mlp.down_proj.weight",
626
+ "paligemma_with_expert.gemma_expert.model.layers.8.mlp.gate_proj.weight",
627
+ "paligemma_with_expert.gemma_expert.model.layers.8.mlp.up_proj.weight",
628
+ "paligemma_with_expert.gemma_expert.model.layers.8.post_attention_layernorm.dense.bias",
629
+ "paligemma_with_expert.gemma_expert.model.layers.8.post_attention_layernorm.dense.weight",
630
+ "paligemma_with_expert.gemma_expert.model.layers.8.self_attn.k_proj.weight",
631
+ "paligemma_with_expert.gemma_expert.model.layers.8.self_attn.o_proj.weight",
632
+ "paligemma_with_expert.gemma_expert.model.layers.8.self_attn.q_proj.weight",
633
+ "paligemma_with_expert.gemma_expert.model.layers.8.self_attn.v_proj.weight",
634
+ "paligemma_with_expert.gemma_expert.model.layers.9.input_layernorm.dense.bias",
635
+ "paligemma_with_expert.gemma_expert.model.layers.9.input_layernorm.dense.weight",
636
+ "paligemma_with_expert.gemma_expert.model.layers.9.mlp.down_proj.weight",
637
+ "paligemma_with_expert.gemma_expert.model.layers.9.mlp.gate_proj.weight",
638
+ "paligemma_with_expert.gemma_expert.model.layers.9.mlp.up_proj.weight",
639
+ "paligemma_with_expert.gemma_expert.model.layers.9.post_attention_layernorm.dense.bias",
640
+ "paligemma_with_expert.gemma_expert.model.layers.9.post_attention_layernorm.dense.weight",
641
+ "paligemma_with_expert.gemma_expert.model.layers.9.self_attn.k_proj.weight",
642
+ "paligemma_with_expert.gemma_expert.model.layers.9.self_attn.o_proj.weight",
643
+ "paligemma_with_expert.gemma_expert.model.layers.9.self_attn.q_proj.weight",
644
+ "paligemma_with_expert.gemma_expert.model.layers.9.self_attn.v_proj.weight",
645
+ "paligemma_with_expert.gemma_expert.model.norm.dense.bias",
646
+ "paligemma_with_expert.gemma_expert.model.norm.dense.weight"
647
+ ],
648
+ "output_path": "/workspace/checkpoints/pi05_base_split_communicating_packed_from_single",
649
+ "right_expert_max_abs_diff": 0.0,
650
+ "right_input_projection_max_abs_diff": 0.0,
651
+ "right_output_projection_max_abs_diff": 0.0,
652
+ "single_ckpt": "/workspace/checkpoints/pi05_base_single_pytorch",
653
+ "warm_start_exact": true
654
+ }
artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_communicating_packed_from_single/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de1302c45e9d80cbfe5124e0b288ed4da27c18599f1c73fc84714d6c6f45d998
3
+ size 9088652708
artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_independent_packed_from_single/config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_dim": 32,
3
+ "action_expert_mode": "split_independent",
4
+ "action_expert_variant": "gemma_300m",
5
+ "action_horizon": 16,
6
+ "arm_action_dims": [
7
+ 16,
8
+ 16
9
+ ],
10
+ "discrete_state_input": true,
11
+ "dtype": "bfloat16",
12
+ "max_token_len": 200,
13
+ "paligemma_variant": "gemma_2b",
14
+ "pi05": true
15
+ }
artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_independent_packed_from_single/init_parallel_metadata.json ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_expert_mode": "split_independent",
3
+ "config_name": "pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k",
4
+ "left_expert_max_abs_diff": 0.0,
5
+ "left_input_projection_max_abs_diff": 0.0,
6
+ "left_output_projection_max_abs_diff": 0.0,
7
+ "load_state_missing_keys": [
8
+ "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight",
9
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.self_attn.q_proj.weight",
10
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.self_attn.k_proj.weight",
11
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.self_attn.v_proj.weight",
12
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.self_attn.o_proj.weight",
13
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.mlp.gate_proj.weight",
14
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.mlp.up_proj.weight",
15
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.mlp.down_proj.weight",
16
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.input_layernorm.dense.weight",
17
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.input_layernorm.dense.bias",
18
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.post_attention_layernorm.dense.weight",
19
+ "paligemma_with_expert.left_gemma_expert.model.layers.0.post_attention_layernorm.dense.bias",
20
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.self_attn.q_proj.weight",
21
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.self_attn.k_proj.weight",
22
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.self_attn.v_proj.weight",
23
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.self_attn.o_proj.weight",
24
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.mlp.gate_proj.weight",
25
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.mlp.up_proj.weight",
26
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.mlp.down_proj.weight",
27
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.input_layernorm.dense.weight",
28
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.input_layernorm.dense.bias",
29
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.post_attention_layernorm.dense.weight",
30
+ "paligemma_with_expert.left_gemma_expert.model.layers.1.post_attention_layernorm.dense.bias",
31
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.self_attn.q_proj.weight",
32
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.self_attn.k_proj.weight",
33
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.self_attn.v_proj.weight",
34
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.self_attn.o_proj.weight",
35
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.mlp.gate_proj.weight",
36
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.mlp.up_proj.weight",
37
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.mlp.down_proj.weight",
38
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.input_layernorm.dense.weight",
39
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.input_layernorm.dense.bias",
40
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.post_attention_layernorm.dense.weight",
41
+ "paligemma_with_expert.left_gemma_expert.model.layers.2.post_attention_layernorm.dense.bias",
42
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.self_attn.q_proj.weight",
43
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.self_attn.k_proj.weight",
44
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.self_attn.v_proj.weight",
45
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.self_attn.o_proj.weight",
46
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.mlp.gate_proj.weight",
47
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.mlp.up_proj.weight",
48
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.mlp.down_proj.weight",
49
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.input_layernorm.dense.weight",
50
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.input_layernorm.dense.bias",
51
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.post_attention_layernorm.dense.weight",
52
+ "paligemma_with_expert.left_gemma_expert.model.layers.3.post_attention_layernorm.dense.bias",
53
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.self_attn.q_proj.weight",
54
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.self_attn.k_proj.weight",
55
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.self_attn.v_proj.weight",
56
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.self_attn.o_proj.weight",
57
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.mlp.gate_proj.weight",
58
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.mlp.up_proj.weight",
59
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.mlp.down_proj.weight",
60
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.input_layernorm.dense.weight",
61
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.input_layernorm.dense.bias",
62
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.post_attention_layernorm.dense.weight",
63
+ "paligemma_with_expert.left_gemma_expert.model.layers.4.post_attention_layernorm.dense.bias",
64
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.self_attn.q_proj.weight",
65
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.self_attn.k_proj.weight",
66
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.self_attn.v_proj.weight",
67
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.self_attn.o_proj.weight",
68
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.mlp.gate_proj.weight",
69
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.mlp.up_proj.weight",
70
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.mlp.down_proj.weight",
71
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.input_layernorm.dense.weight",
72
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.input_layernorm.dense.bias",
73
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.post_attention_layernorm.dense.weight",
74
+ "paligemma_with_expert.left_gemma_expert.model.layers.5.post_attention_layernorm.dense.bias",
75
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.self_attn.q_proj.weight",
76
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.self_attn.k_proj.weight",
77
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.self_attn.v_proj.weight",
78
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.self_attn.o_proj.weight",
79
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.mlp.gate_proj.weight",
80
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.mlp.up_proj.weight",
81
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.mlp.down_proj.weight",
82
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.input_layernorm.dense.weight",
83
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.input_layernorm.dense.bias",
84
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.post_attention_layernorm.dense.weight",
85
+ "paligemma_with_expert.left_gemma_expert.model.layers.6.post_attention_layernorm.dense.bias",
86
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.self_attn.q_proj.weight",
87
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.self_attn.k_proj.weight",
88
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.self_attn.v_proj.weight",
89
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.self_attn.o_proj.weight",
90
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.mlp.gate_proj.weight",
91
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.mlp.up_proj.weight",
92
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.mlp.down_proj.weight",
93
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.input_layernorm.dense.weight",
94
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.input_layernorm.dense.bias",
95
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.post_attention_layernorm.dense.weight",
96
+ "paligemma_with_expert.left_gemma_expert.model.layers.7.post_attention_layernorm.dense.bias",
97
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.self_attn.q_proj.weight",
98
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.self_attn.k_proj.weight",
99
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.self_attn.v_proj.weight",
100
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.self_attn.o_proj.weight",
101
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.mlp.gate_proj.weight",
102
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.mlp.up_proj.weight",
103
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.mlp.down_proj.weight",
104
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.input_layernorm.dense.weight",
105
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.input_layernorm.dense.bias",
106
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.post_attention_layernorm.dense.weight",
107
+ "paligemma_with_expert.left_gemma_expert.model.layers.8.post_attention_layernorm.dense.bias",
108
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.self_attn.q_proj.weight",
109
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.self_attn.k_proj.weight",
110
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.self_attn.v_proj.weight",
111
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.self_attn.o_proj.weight",
112
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.mlp.gate_proj.weight",
113
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.mlp.up_proj.weight",
114
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.mlp.down_proj.weight",
115
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.input_layernorm.dense.weight",
116
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.input_layernorm.dense.bias",
117
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.post_attention_layernorm.dense.weight",
118
+ "paligemma_with_expert.left_gemma_expert.model.layers.9.post_attention_layernorm.dense.bias",
119
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.self_attn.q_proj.weight",
120
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.self_attn.k_proj.weight",
121
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.self_attn.v_proj.weight",
122
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.self_attn.o_proj.weight",
123
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.mlp.gate_proj.weight",
124
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.mlp.up_proj.weight",
125
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.mlp.down_proj.weight",
126
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.input_layernorm.dense.weight",
127
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.input_layernorm.dense.bias",
128
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.post_attention_layernorm.dense.weight",
129
+ "paligemma_with_expert.left_gemma_expert.model.layers.10.post_attention_layernorm.dense.bias",
130
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.self_attn.q_proj.weight",
131
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.self_attn.k_proj.weight",
132
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.self_attn.v_proj.weight",
133
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.self_attn.o_proj.weight",
134
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.mlp.gate_proj.weight",
135
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.mlp.up_proj.weight",
136
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.mlp.down_proj.weight",
137
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.input_layernorm.dense.weight",
138
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.input_layernorm.dense.bias",
139
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.post_attention_layernorm.dense.weight",
140
+ "paligemma_with_expert.left_gemma_expert.model.layers.11.post_attention_layernorm.dense.bias",
141
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.self_attn.q_proj.weight",
142
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.self_attn.k_proj.weight",
143
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.self_attn.v_proj.weight",
144
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.self_attn.o_proj.weight",
145
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.mlp.gate_proj.weight",
146
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.mlp.up_proj.weight",
147
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.mlp.down_proj.weight",
148
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.input_layernorm.dense.weight",
149
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.input_layernorm.dense.bias",
150
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.post_attention_layernorm.dense.weight",
151
+ "paligemma_with_expert.left_gemma_expert.model.layers.12.post_attention_layernorm.dense.bias",
152
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.self_attn.q_proj.weight",
153
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.self_attn.k_proj.weight",
154
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.self_attn.v_proj.weight",
155
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.self_attn.o_proj.weight",
156
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.mlp.gate_proj.weight",
157
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.mlp.up_proj.weight",
158
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.mlp.down_proj.weight",
159
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.input_layernorm.dense.weight",
160
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.input_layernorm.dense.bias",
161
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.post_attention_layernorm.dense.weight",
162
+ "paligemma_with_expert.left_gemma_expert.model.layers.13.post_attention_layernorm.dense.bias",
163
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.self_attn.q_proj.weight",
164
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.self_attn.k_proj.weight",
165
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.self_attn.v_proj.weight",
166
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.self_attn.o_proj.weight",
167
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.mlp.gate_proj.weight",
168
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.mlp.up_proj.weight",
169
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.mlp.down_proj.weight",
170
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.input_layernorm.dense.weight",
171
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.input_layernorm.dense.bias",
172
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.post_attention_layernorm.dense.weight",
173
+ "paligemma_with_expert.left_gemma_expert.model.layers.14.post_attention_layernorm.dense.bias",
174
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.self_attn.q_proj.weight",
175
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.self_attn.k_proj.weight",
176
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.self_attn.v_proj.weight",
177
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.self_attn.o_proj.weight",
178
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.mlp.gate_proj.weight",
179
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.mlp.up_proj.weight",
180
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.mlp.down_proj.weight",
181
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.input_layernorm.dense.weight",
182
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.input_layernorm.dense.bias",
183
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.post_attention_layernorm.dense.weight",
184
+ "paligemma_with_expert.left_gemma_expert.model.layers.15.post_attention_layernorm.dense.bias",
185
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.self_attn.q_proj.weight",
186
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.self_attn.k_proj.weight",
187
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.self_attn.v_proj.weight",
188
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.self_attn.o_proj.weight",
189
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.mlp.gate_proj.weight",
190
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.mlp.up_proj.weight",
191
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.mlp.down_proj.weight",
192
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.input_layernorm.dense.weight",
193
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.input_layernorm.dense.bias",
194
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.post_attention_layernorm.dense.weight",
195
+ "paligemma_with_expert.left_gemma_expert.model.layers.16.post_attention_layernorm.dense.bias",
196
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.self_attn.q_proj.weight",
197
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.self_attn.k_proj.weight",
198
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.self_attn.v_proj.weight",
199
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.self_attn.o_proj.weight",
200
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.mlp.gate_proj.weight",
201
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.mlp.up_proj.weight",
202
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.mlp.down_proj.weight",
203
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.input_layernorm.dense.weight",
204
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.input_layernorm.dense.bias",
205
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.post_attention_layernorm.dense.weight",
206
+ "paligemma_with_expert.left_gemma_expert.model.layers.17.post_attention_layernorm.dense.bias",
207
+ "paligemma_with_expert.left_gemma_expert.model.norm.dense.weight",
208
+ "paligemma_with_expert.left_gemma_expert.model.norm.dense.bias",
209
+ "paligemma_with_expert.left_gemma_expert.lm_head.weight",
210
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.self_attn.q_proj.weight",
211
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.self_attn.k_proj.weight",
212
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.self_attn.v_proj.weight",
213
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.self_attn.o_proj.weight",
214
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.mlp.gate_proj.weight",
215
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.mlp.up_proj.weight",
216
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.mlp.down_proj.weight",
217
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.input_layernorm.dense.weight",
218
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.input_layernorm.dense.bias",
219
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.post_attention_layernorm.dense.weight",
220
+ "paligemma_with_expert.right_gemma_expert.model.layers.0.post_attention_layernorm.dense.bias",
221
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.self_attn.q_proj.weight",
222
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.self_attn.k_proj.weight",
223
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.self_attn.v_proj.weight",
224
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.self_attn.o_proj.weight",
225
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.mlp.gate_proj.weight",
226
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.mlp.up_proj.weight",
227
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.mlp.down_proj.weight",
228
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.input_layernorm.dense.weight",
229
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.input_layernorm.dense.bias",
230
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.post_attention_layernorm.dense.weight",
231
+ "paligemma_with_expert.right_gemma_expert.model.layers.1.post_attention_layernorm.dense.bias",
232
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.self_attn.q_proj.weight",
233
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.self_attn.k_proj.weight",
234
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.self_attn.v_proj.weight",
235
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.self_attn.o_proj.weight",
236
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.mlp.gate_proj.weight",
237
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.mlp.up_proj.weight",
238
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.mlp.down_proj.weight",
239
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.input_layernorm.dense.weight",
240
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.input_layernorm.dense.bias",
241
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.post_attention_layernorm.dense.weight",
242
+ "paligemma_with_expert.right_gemma_expert.model.layers.2.post_attention_layernorm.dense.bias",
243
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.self_attn.q_proj.weight",
244
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.self_attn.k_proj.weight",
245
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.self_attn.v_proj.weight",
246
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.self_attn.o_proj.weight",
247
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.mlp.gate_proj.weight",
248
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.mlp.up_proj.weight",
249
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.mlp.down_proj.weight",
250
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.input_layernorm.dense.weight",
251
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.input_layernorm.dense.bias",
252
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.post_attention_layernorm.dense.weight",
253
+ "paligemma_with_expert.right_gemma_expert.model.layers.3.post_attention_layernorm.dense.bias",
254
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.self_attn.q_proj.weight",
255
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.self_attn.k_proj.weight",
256
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.self_attn.v_proj.weight",
257
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.self_attn.o_proj.weight",
258
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.mlp.gate_proj.weight",
259
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.mlp.up_proj.weight",
260
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.mlp.down_proj.weight",
261
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.input_layernorm.dense.weight",
262
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.input_layernorm.dense.bias",
263
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.post_attention_layernorm.dense.weight",
264
+ "paligemma_with_expert.right_gemma_expert.model.layers.4.post_attention_layernorm.dense.bias",
265
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.self_attn.q_proj.weight",
266
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.self_attn.k_proj.weight",
267
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.self_attn.v_proj.weight",
268
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.self_attn.o_proj.weight",
269
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.mlp.gate_proj.weight",
270
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.mlp.up_proj.weight",
271
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.mlp.down_proj.weight",
272
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.input_layernorm.dense.weight",
273
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.input_layernorm.dense.bias",
274
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.post_attention_layernorm.dense.weight",
275
+ "paligemma_with_expert.right_gemma_expert.model.layers.5.post_attention_layernorm.dense.bias",
276
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.self_attn.q_proj.weight",
277
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.self_attn.k_proj.weight",
278
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.self_attn.v_proj.weight",
279
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.self_attn.o_proj.weight",
280
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.mlp.gate_proj.weight",
281
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.mlp.up_proj.weight",
282
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.mlp.down_proj.weight",
283
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.input_layernorm.dense.weight",
284
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.input_layernorm.dense.bias",
285
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.post_attention_layernorm.dense.weight",
286
+ "paligemma_with_expert.right_gemma_expert.model.layers.6.post_attention_layernorm.dense.bias",
287
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.self_attn.q_proj.weight",
288
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.self_attn.k_proj.weight",
289
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.self_attn.v_proj.weight",
290
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.self_attn.o_proj.weight",
291
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.mlp.gate_proj.weight",
292
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.mlp.up_proj.weight",
293
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.mlp.down_proj.weight",
294
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.input_layernorm.dense.weight",
295
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.input_layernorm.dense.bias",
296
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.post_attention_layernorm.dense.weight",
297
+ "paligemma_with_expert.right_gemma_expert.model.layers.7.post_attention_layernorm.dense.bias",
298
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.self_attn.q_proj.weight",
299
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.self_attn.k_proj.weight",
300
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.self_attn.v_proj.weight",
301
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.self_attn.o_proj.weight",
302
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.mlp.gate_proj.weight",
303
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.mlp.up_proj.weight",
304
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.mlp.down_proj.weight",
305
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.input_layernorm.dense.weight",
306
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.input_layernorm.dense.bias",
307
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.post_attention_layernorm.dense.weight",
308
+ "paligemma_with_expert.right_gemma_expert.model.layers.8.post_attention_layernorm.dense.bias",
309
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.self_attn.q_proj.weight",
310
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.self_attn.k_proj.weight",
311
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.self_attn.v_proj.weight",
312
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.self_attn.o_proj.weight",
313
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.mlp.gate_proj.weight",
314
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.mlp.up_proj.weight",
315
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.mlp.down_proj.weight",
316
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.input_layernorm.dense.weight",
317
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.input_layernorm.dense.bias",
318
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.post_attention_layernorm.dense.weight",
319
+ "paligemma_with_expert.right_gemma_expert.model.layers.9.post_attention_layernorm.dense.bias",
320
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.self_attn.q_proj.weight",
321
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.self_attn.k_proj.weight",
322
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.self_attn.v_proj.weight",
323
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.self_attn.o_proj.weight",
324
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.mlp.gate_proj.weight",
325
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.mlp.up_proj.weight",
326
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.mlp.down_proj.weight",
327
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.input_layernorm.dense.weight",
328
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.input_layernorm.dense.bias",
329
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.post_attention_layernorm.dense.weight",
330
+ "paligemma_with_expert.right_gemma_expert.model.layers.10.post_attention_layernorm.dense.bias",
331
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.self_attn.q_proj.weight",
332
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.self_attn.k_proj.weight",
333
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.self_attn.v_proj.weight",
334
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.self_attn.o_proj.weight",
335
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.mlp.gate_proj.weight",
336
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.mlp.up_proj.weight",
337
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.mlp.down_proj.weight",
338
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.input_layernorm.dense.weight",
339
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.input_layernorm.dense.bias",
340
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.post_attention_layernorm.dense.weight",
341
+ "paligemma_with_expert.right_gemma_expert.model.layers.11.post_attention_layernorm.dense.bias",
342
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.self_attn.q_proj.weight",
343
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.self_attn.k_proj.weight",
344
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.self_attn.v_proj.weight",
345
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.self_attn.o_proj.weight",
346
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.mlp.gate_proj.weight",
347
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.mlp.up_proj.weight",
348
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.mlp.down_proj.weight",
349
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.input_layernorm.dense.weight",
350
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.input_layernorm.dense.bias",
351
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.post_attention_layernorm.dense.weight",
352
+ "paligemma_with_expert.right_gemma_expert.model.layers.12.post_attention_layernorm.dense.bias",
353
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.self_attn.q_proj.weight",
354
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.self_attn.k_proj.weight",
355
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.self_attn.v_proj.weight",
356
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.self_attn.o_proj.weight",
357
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.mlp.gate_proj.weight",
358
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.mlp.up_proj.weight",
359
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.mlp.down_proj.weight",
360
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.input_layernorm.dense.weight",
361
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.input_layernorm.dense.bias",
362
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.post_attention_layernorm.dense.weight",
363
+ "paligemma_with_expert.right_gemma_expert.model.layers.13.post_attention_layernorm.dense.bias",
364
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.self_attn.q_proj.weight",
365
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.self_attn.k_proj.weight",
366
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.self_attn.v_proj.weight",
367
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.self_attn.o_proj.weight",
368
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.mlp.gate_proj.weight",
369
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.mlp.up_proj.weight",
370
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.mlp.down_proj.weight",
371
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.input_layernorm.dense.weight",
372
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.input_layernorm.dense.bias",
373
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.post_attention_layernorm.dense.weight",
374
+ "paligemma_with_expert.right_gemma_expert.model.layers.14.post_attention_layernorm.dense.bias",
375
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.self_attn.q_proj.weight",
376
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.self_attn.k_proj.weight",
377
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.self_attn.v_proj.weight",
378
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.self_attn.o_proj.weight",
379
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.mlp.gate_proj.weight",
380
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.mlp.up_proj.weight",
381
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.mlp.down_proj.weight",
382
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.input_layernorm.dense.weight",
383
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.input_layernorm.dense.bias",
384
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.post_attention_layernorm.dense.weight",
385
+ "paligemma_with_expert.right_gemma_expert.model.layers.15.post_attention_layernorm.dense.bias",
386
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.self_attn.q_proj.weight",
387
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.self_attn.k_proj.weight",
388
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.self_attn.v_proj.weight",
389
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.self_attn.o_proj.weight",
390
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.mlp.gate_proj.weight",
391
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.mlp.up_proj.weight",
392
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.mlp.down_proj.weight",
393
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.input_layernorm.dense.weight",
394
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.input_layernorm.dense.bias",
395
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.post_attention_layernorm.dense.weight",
396
+ "paligemma_with_expert.right_gemma_expert.model.layers.16.post_attention_layernorm.dense.bias",
397
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.self_attn.q_proj.weight",
398
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.self_attn.k_proj.weight",
399
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.self_attn.v_proj.weight",
400
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.self_attn.o_proj.weight",
401
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.mlp.gate_proj.weight",
402
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.mlp.up_proj.weight",
403
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.mlp.down_proj.weight",
404
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.input_layernorm.dense.weight",
405
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.input_layernorm.dense.bias",
406
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.post_attention_layernorm.dense.weight",
407
+ "paligemma_with_expert.right_gemma_expert.model.layers.17.post_attention_layernorm.dense.bias",
408
+ "paligemma_with_expert.right_gemma_expert.model.norm.dense.weight",
409
+ "paligemma_with_expert.right_gemma_expert.model.norm.dense.bias",
410
+ "paligemma_with_expert.right_gemma_expert.lm_head.weight",
411
+ "action_in_proj_arms.0.weight",
412
+ "action_in_proj_arms.0.bias",
413
+ "action_in_proj_arms.1.weight",
414
+ "action_in_proj_arms.1.bias",
415
+ "action_out_proj_arms.0.weight",
416
+ "action_out_proj_arms.0.bias",
417
+ "action_out_proj_arms.1.weight",
418
+ "action_out_proj_arms.1.bias"
419
+ ],
420
+ "load_state_unexpected_keys": [
421
+ "action_in_proj.bias",
422
+ "action_in_proj.weight",
423
+ "action_out_proj.bias",
424
+ "action_out_proj.weight",
425
+ "paligemma_with_expert.gemma_expert.lm_head.weight",
426
+ "paligemma_with_expert.gemma_expert.model.layers.0.input_layernorm.dense.bias",
427
+ "paligemma_with_expert.gemma_expert.model.layers.0.input_layernorm.dense.weight",
428
+ "paligemma_with_expert.gemma_expert.model.layers.0.mlp.down_proj.weight",
429
+ "paligemma_with_expert.gemma_expert.model.layers.0.mlp.gate_proj.weight",
430
+ "paligemma_with_expert.gemma_expert.model.layers.0.mlp.up_proj.weight",
431
+ "paligemma_with_expert.gemma_expert.model.layers.0.post_attention_layernorm.dense.bias",
432
+ "paligemma_with_expert.gemma_expert.model.layers.0.post_attention_layernorm.dense.weight",
433
+ "paligemma_with_expert.gemma_expert.model.layers.0.self_attn.k_proj.weight",
434
+ "paligemma_with_expert.gemma_expert.model.layers.0.self_attn.o_proj.weight",
435
+ "paligemma_with_expert.gemma_expert.model.layers.0.self_attn.q_proj.weight",
436
+ "paligemma_with_expert.gemma_expert.model.layers.0.self_attn.v_proj.weight",
437
+ "paligemma_with_expert.gemma_expert.model.layers.1.input_layernorm.dense.bias",
438
+ "paligemma_with_expert.gemma_expert.model.layers.1.input_layernorm.dense.weight",
439
+ "paligemma_with_expert.gemma_expert.model.layers.1.mlp.down_proj.weight",
440
+ "paligemma_with_expert.gemma_expert.model.layers.1.mlp.gate_proj.weight",
441
+ "paligemma_with_expert.gemma_expert.model.layers.1.mlp.up_proj.weight",
442
+ "paligemma_with_expert.gemma_expert.model.layers.1.post_attention_layernorm.dense.bias",
443
+ "paligemma_with_expert.gemma_expert.model.layers.1.post_attention_layernorm.dense.weight",
444
+ "paligemma_with_expert.gemma_expert.model.layers.1.self_attn.k_proj.weight",
445
+ "paligemma_with_expert.gemma_expert.model.layers.1.self_attn.o_proj.weight",
446
+ "paligemma_with_expert.gemma_expert.model.layers.1.self_attn.q_proj.weight",
447
+ "paligemma_with_expert.gemma_expert.model.layers.1.self_attn.v_proj.weight",
448
+ "paligemma_with_expert.gemma_expert.model.layers.10.input_layernorm.dense.bias",
449
+ "paligemma_with_expert.gemma_expert.model.layers.10.input_layernorm.dense.weight",
450
+ "paligemma_with_expert.gemma_expert.model.layers.10.mlp.down_proj.weight",
451
+ "paligemma_with_expert.gemma_expert.model.layers.10.mlp.gate_proj.weight",
452
+ "paligemma_with_expert.gemma_expert.model.layers.10.mlp.up_proj.weight",
453
+ "paligemma_with_expert.gemma_expert.model.layers.10.post_attention_layernorm.dense.bias",
454
+ "paligemma_with_expert.gemma_expert.model.layers.10.post_attention_layernorm.dense.weight",
455
+ "paligemma_with_expert.gemma_expert.model.layers.10.self_attn.k_proj.weight",
456
+ "paligemma_with_expert.gemma_expert.model.layers.10.self_attn.o_proj.weight",
457
+ "paligemma_with_expert.gemma_expert.model.layers.10.self_attn.q_proj.weight",
458
+ "paligemma_with_expert.gemma_expert.model.layers.10.self_attn.v_proj.weight",
459
+ "paligemma_with_expert.gemma_expert.model.layers.11.input_layernorm.dense.bias",
460
+ "paligemma_with_expert.gemma_expert.model.layers.11.input_layernorm.dense.weight",
461
+ "paligemma_with_expert.gemma_expert.model.layers.11.mlp.down_proj.weight",
462
+ "paligemma_with_expert.gemma_expert.model.layers.11.mlp.gate_proj.weight",
463
+ "paligemma_with_expert.gemma_expert.model.layers.11.mlp.up_proj.weight",
464
+ "paligemma_with_expert.gemma_expert.model.layers.11.post_attention_layernorm.dense.bias",
465
+ "paligemma_with_expert.gemma_expert.model.layers.11.post_attention_layernorm.dense.weight",
466
+ "paligemma_with_expert.gemma_expert.model.layers.11.self_attn.k_proj.weight",
467
+ "paligemma_with_expert.gemma_expert.model.layers.11.self_attn.o_proj.weight",
468
+ "paligemma_with_expert.gemma_expert.model.layers.11.self_attn.q_proj.weight",
469
+ "paligemma_with_expert.gemma_expert.model.layers.11.self_attn.v_proj.weight",
470
+ "paligemma_with_expert.gemma_expert.model.layers.12.input_layernorm.dense.bias",
471
+ "paligemma_with_expert.gemma_expert.model.layers.12.input_layernorm.dense.weight",
472
+ "paligemma_with_expert.gemma_expert.model.layers.12.mlp.down_proj.weight",
473
+ "paligemma_with_expert.gemma_expert.model.layers.12.mlp.gate_proj.weight",
474
+ "paligemma_with_expert.gemma_expert.model.layers.12.mlp.up_proj.weight",
475
+ "paligemma_with_expert.gemma_expert.model.layers.12.post_attention_layernorm.dense.bias",
476
+ "paligemma_with_expert.gemma_expert.model.layers.12.post_attention_layernorm.dense.weight",
477
+ "paligemma_with_expert.gemma_expert.model.layers.12.self_attn.k_proj.weight",
478
+ "paligemma_with_expert.gemma_expert.model.layers.12.self_attn.o_proj.weight",
479
+ "paligemma_with_expert.gemma_expert.model.layers.12.self_attn.q_proj.weight",
480
+ "paligemma_with_expert.gemma_expert.model.layers.12.self_attn.v_proj.weight",
481
+ "paligemma_with_expert.gemma_expert.model.layers.13.input_layernorm.dense.bias",
482
+ "paligemma_with_expert.gemma_expert.model.layers.13.input_layernorm.dense.weight",
483
+ "paligemma_with_expert.gemma_expert.model.layers.13.mlp.down_proj.weight",
484
+ "paligemma_with_expert.gemma_expert.model.layers.13.mlp.gate_proj.weight",
485
+ "paligemma_with_expert.gemma_expert.model.layers.13.mlp.up_proj.weight",
486
+ "paligemma_with_expert.gemma_expert.model.layers.13.post_attention_layernorm.dense.bias",
487
+ "paligemma_with_expert.gemma_expert.model.layers.13.post_attention_layernorm.dense.weight",
488
+ "paligemma_with_expert.gemma_expert.model.layers.13.self_attn.k_proj.weight",
489
+ "paligemma_with_expert.gemma_expert.model.layers.13.self_attn.o_proj.weight",
490
+ "paligemma_with_expert.gemma_expert.model.layers.13.self_attn.q_proj.weight",
491
+ "paligemma_with_expert.gemma_expert.model.layers.13.self_attn.v_proj.weight",
492
+ "paligemma_with_expert.gemma_expert.model.layers.14.input_layernorm.dense.bias",
493
+ "paligemma_with_expert.gemma_expert.model.layers.14.input_layernorm.dense.weight",
494
+ "paligemma_with_expert.gemma_expert.model.layers.14.mlp.down_proj.weight",
495
+ "paligemma_with_expert.gemma_expert.model.layers.14.mlp.gate_proj.weight",
496
+ "paligemma_with_expert.gemma_expert.model.layers.14.mlp.up_proj.weight",
497
+ "paligemma_with_expert.gemma_expert.model.layers.14.post_attention_layernorm.dense.bias",
498
+ "paligemma_with_expert.gemma_expert.model.layers.14.post_attention_layernorm.dense.weight",
499
+ "paligemma_with_expert.gemma_expert.model.layers.14.self_attn.k_proj.weight",
500
+ "paligemma_with_expert.gemma_expert.model.layers.14.self_attn.o_proj.weight",
501
+ "paligemma_with_expert.gemma_expert.model.layers.14.self_attn.q_proj.weight",
502
+ "paligemma_with_expert.gemma_expert.model.layers.14.self_attn.v_proj.weight",
503
+ "paligemma_with_expert.gemma_expert.model.layers.15.input_layernorm.dense.bias",
504
+ "paligemma_with_expert.gemma_expert.model.layers.15.input_layernorm.dense.weight",
505
+ "paligemma_with_expert.gemma_expert.model.layers.15.mlp.down_proj.weight",
506
+ "paligemma_with_expert.gemma_expert.model.layers.15.mlp.gate_proj.weight",
507
+ "paligemma_with_expert.gemma_expert.model.layers.15.mlp.up_proj.weight",
508
+ "paligemma_with_expert.gemma_expert.model.layers.15.post_attention_layernorm.dense.bias",
509
+ "paligemma_with_expert.gemma_expert.model.layers.15.post_attention_layernorm.dense.weight",
510
+ "paligemma_with_expert.gemma_expert.model.layers.15.self_attn.k_proj.weight",
511
+ "paligemma_with_expert.gemma_expert.model.layers.15.self_attn.o_proj.weight",
512
+ "paligemma_with_expert.gemma_expert.model.layers.15.self_attn.q_proj.weight",
513
+ "paligemma_with_expert.gemma_expert.model.layers.15.self_attn.v_proj.weight",
514
+ "paligemma_with_expert.gemma_expert.model.layers.16.input_layernorm.dense.bias",
515
+ "paligemma_with_expert.gemma_expert.model.layers.16.input_layernorm.dense.weight",
516
+ "paligemma_with_expert.gemma_expert.model.layers.16.mlp.down_proj.weight",
517
+ "paligemma_with_expert.gemma_expert.model.layers.16.mlp.gate_proj.weight",
518
+ "paligemma_with_expert.gemma_expert.model.layers.16.mlp.up_proj.weight",
519
+ "paligemma_with_expert.gemma_expert.model.layers.16.post_attention_layernorm.dense.bias",
520
+ "paligemma_with_expert.gemma_expert.model.layers.16.post_attention_layernorm.dense.weight",
521
+ "paligemma_with_expert.gemma_expert.model.layers.16.self_attn.k_proj.weight",
522
+ "paligemma_with_expert.gemma_expert.model.layers.16.self_attn.o_proj.weight",
523
+ "paligemma_with_expert.gemma_expert.model.layers.16.self_attn.q_proj.weight",
524
+ "paligemma_with_expert.gemma_expert.model.layers.16.self_attn.v_proj.weight",
525
+ "paligemma_with_expert.gemma_expert.model.layers.17.input_layernorm.dense.bias",
526
+ "paligemma_with_expert.gemma_expert.model.layers.17.input_layernorm.dense.weight",
527
+ "paligemma_with_expert.gemma_expert.model.layers.17.mlp.down_proj.weight",
528
+ "paligemma_with_expert.gemma_expert.model.layers.17.mlp.gate_proj.weight",
529
+ "paligemma_with_expert.gemma_expert.model.layers.17.mlp.up_proj.weight",
530
+ "paligemma_with_expert.gemma_expert.model.layers.17.post_attention_layernorm.dense.bias",
531
+ "paligemma_with_expert.gemma_expert.model.layers.17.post_attention_layernorm.dense.weight",
532
+ "paligemma_with_expert.gemma_expert.model.layers.17.self_attn.k_proj.weight",
533
+ "paligemma_with_expert.gemma_expert.model.layers.17.self_attn.o_proj.weight",
534
+ "paligemma_with_expert.gemma_expert.model.layers.17.self_attn.q_proj.weight",
535
+ "paligemma_with_expert.gemma_expert.model.layers.17.self_attn.v_proj.weight",
536
+ "paligemma_with_expert.gemma_expert.model.layers.2.input_layernorm.dense.bias",
537
+ "paligemma_with_expert.gemma_expert.model.layers.2.input_layernorm.dense.weight",
538
+ "paligemma_with_expert.gemma_expert.model.layers.2.mlp.down_proj.weight",
539
+ "paligemma_with_expert.gemma_expert.model.layers.2.mlp.gate_proj.weight",
540
+ "paligemma_with_expert.gemma_expert.model.layers.2.mlp.up_proj.weight",
541
+ "paligemma_with_expert.gemma_expert.model.layers.2.post_attention_layernorm.dense.bias",
542
+ "paligemma_with_expert.gemma_expert.model.layers.2.post_attention_layernorm.dense.weight",
543
+ "paligemma_with_expert.gemma_expert.model.layers.2.self_attn.k_proj.weight",
544
+ "paligemma_with_expert.gemma_expert.model.layers.2.self_attn.o_proj.weight",
545
+ "paligemma_with_expert.gemma_expert.model.layers.2.self_attn.q_proj.weight",
546
+ "paligemma_with_expert.gemma_expert.model.layers.2.self_attn.v_proj.weight",
547
+ "paligemma_with_expert.gemma_expert.model.layers.3.input_layernorm.dense.bias",
548
+ "paligemma_with_expert.gemma_expert.model.layers.3.input_layernorm.dense.weight",
549
+ "paligemma_with_expert.gemma_expert.model.layers.3.mlp.down_proj.weight",
550
+ "paligemma_with_expert.gemma_expert.model.layers.3.mlp.gate_proj.weight",
551
+ "paligemma_with_expert.gemma_expert.model.layers.3.mlp.up_proj.weight",
552
+ "paligemma_with_expert.gemma_expert.model.layers.3.post_attention_layernorm.dense.bias",
553
+ "paligemma_with_expert.gemma_expert.model.layers.3.post_attention_layernorm.dense.weight",
554
+ "paligemma_with_expert.gemma_expert.model.layers.3.self_attn.k_proj.weight",
555
+ "paligemma_with_expert.gemma_expert.model.layers.3.self_attn.o_proj.weight",
556
+ "paligemma_with_expert.gemma_expert.model.layers.3.self_attn.q_proj.weight",
557
+ "paligemma_with_expert.gemma_expert.model.layers.3.self_attn.v_proj.weight",
558
+ "paligemma_with_expert.gemma_expert.model.layers.4.input_layernorm.dense.bias",
559
+ "paligemma_with_expert.gemma_expert.model.layers.4.input_layernorm.dense.weight",
560
+ "paligemma_with_expert.gemma_expert.model.layers.4.mlp.down_proj.weight",
561
+ "paligemma_with_expert.gemma_expert.model.layers.4.mlp.gate_proj.weight",
562
+ "paligemma_with_expert.gemma_expert.model.layers.4.mlp.up_proj.weight",
563
+ "paligemma_with_expert.gemma_expert.model.layers.4.post_attention_layernorm.dense.bias",
564
+ "paligemma_with_expert.gemma_expert.model.layers.4.post_attention_layernorm.dense.weight",
565
+ "paligemma_with_expert.gemma_expert.model.layers.4.self_attn.k_proj.weight",
566
+ "paligemma_with_expert.gemma_expert.model.layers.4.self_attn.o_proj.weight",
567
+ "paligemma_with_expert.gemma_expert.model.layers.4.self_attn.q_proj.weight",
568
+ "paligemma_with_expert.gemma_expert.model.layers.4.self_attn.v_proj.weight",
569
+ "paligemma_with_expert.gemma_expert.model.layers.5.input_layernorm.dense.bias",
570
+ "paligemma_with_expert.gemma_expert.model.layers.5.input_layernorm.dense.weight",
571
+ "paligemma_with_expert.gemma_expert.model.layers.5.mlp.down_proj.weight",
572
+ "paligemma_with_expert.gemma_expert.model.layers.5.mlp.gate_proj.weight",
573
+ "paligemma_with_expert.gemma_expert.model.layers.5.mlp.up_proj.weight",
574
+ "paligemma_with_expert.gemma_expert.model.layers.5.post_attention_layernorm.dense.bias",
575
+ "paligemma_with_expert.gemma_expert.model.layers.5.post_attention_layernorm.dense.weight",
576
+ "paligemma_with_expert.gemma_expert.model.layers.5.self_attn.k_proj.weight",
577
+ "paligemma_with_expert.gemma_expert.model.layers.5.self_attn.o_proj.weight",
578
+ "paligemma_with_expert.gemma_expert.model.layers.5.self_attn.q_proj.weight",
579
+ "paligemma_with_expert.gemma_expert.model.layers.5.self_attn.v_proj.weight",
580
+ "paligemma_with_expert.gemma_expert.model.layers.6.input_layernorm.dense.bias",
581
+ "paligemma_with_expert.gemma_expert.model.layers.6.input_layernorm.dense.weight",
582
+ "paligemma_with_expert.gemma_expert.model.layers.6.mlp.down_proj.weight",
583
+ "paligemma_with_expert.gemma_expert.model.layers.6.mlp.gate_proj.weight",
584
+ "paligemma_with_expert.gemma_expert.model.layers.6.mlp.up_proj.weight",
585
+ "paligemma_with_expert.gemma_expert.model.layers.6.post_attention_layernorm.dense.bias",
586
+ "paligemma_with_expert.gemma_expert.model.layers.6.post_attention_layernorm.dense.weight",
587
+ "paligemma_with_expert.gemma_expert.model.layers.6.self_attn.k_proj.weight",
588
+ "paligemma_with_expert.gemma_expert.model.layers.6.self_attn.o_proj.weight",
589
+ "paligemma_with_expert.gemma_expert.model.layers.6.self_attn.q_proj.weight",
590
+ "paligemma_with_expert.gemma_expert.model.layers.6.self_attn.v_proj.weight",
591
+ "paligemma_with_expert.gemma_expert.model.layers.7.input_layernorm.dense.bias",
592
+ "paligemma_with_expert.gemma_expert.model.layers.7.input_layernorm.dense.weight",
593
+ "paligemma_with_expert.gemma_expert.model.layers.7.mlp.down_proj.weight",
594
+ "paligemma_with_expert.gemma_expert.model.layers.7.mlp.gate_proj.weight",
595
+ "paligemma_with_expert.gemma_expert.model.layers.7.mlp.up_proj.weight",
596
+ "paligemma_with_expert.gemma_expert.model.layers.7.post_attention_layernorm.dense.bias",
597
+ "paligemma_with_expert.gemma_expert.model.layers.7.post_attention_layernorm.dense.weight",
598
+ "paligemma_with_expert.gemma_expert.model.layers.7.self_attn.k_proj.weight",
599
+ "paligemma_with_expert.gemma_expert.model.layers.7.self_attn.o_proj.weight",
600
+ "paligemma_with_expert.gemma_expert.model.layers.7.self_attn.q_proj.weight",
601
+ "paligemma_with_expert.gemma_expert.model.layers.7.self_attn.v_proj.weight",
602
+ "paligemma_with_expert.gemma_expert.model.layers.8.input_layernorm.dense.bias",
603
+ "paligemma_with_expert.gemma_expert.model.layers.8.input_layernorm.dense.weight",
604
+ "paligemma_with_expert.gemma_expert.model.layers.8.mlp.down_proj.weight",
605
+ "paligemma_with_expert.gemma_expert.model.layers.8.mlp.gate_proj.weight",
606
+ "paligemma_with_expert.gemma_expert.model.layers.8.mlp.up_proj.weight",
607
+ "paligemma_with_expert.gemma_expert.model.layers.8.post_attention_layernorm.dense.bias",
608
+ "paligemma_with_expert.gemma_expert.model.layers.8.post_attention_layernorm.dense.weight",
609
+ "paligemma_with_expert.gemma_expert.model.layers.8.self_attn.k_proj.weight",
610
+ "paligemma_with_expert.gemma_expert.model.layers.8.self_attn.o_proj.weight",
611
+ "paligemma_with_expert.gemma_expert.model.layers.8.self_attn.q_proj.weight",
612
+ "paligemma_with_expert.gemma_expert.model.layers.8.self_attn.v_proj.weight",
613
+ "paligemma_with_expert.gemma_expert.model.layers.9.input_layernorm.dense.bias",
614
+ "paligemma_with_expert.gemma_expert.model.layers.9.input_layernorm.dense.weight",
615
+ "paligemma_with_expert.gemma_expert.model.layers.9.mlp.down_proj.weight",
616
+ "paligemma_with_expert.gemma_expert.model.layers.9.mlp.gate_proj.weight",
617
+ "paligemma_with_expert.gemma_expert.model.layers.9.mlp.up_proj.weight",
618
+ "paligemma_with_expert.gemma_expert.model.layers.9.post_attention_layernorm.dense.bias",
619
+ "paligemma_with_expert.gemma_expert.model.layers.9.post_attention_layernorm.dense.weight",
620
+ "paligemma_with_expert.gemma_expert.model.layers.9.self_attn.k_proj.weight",
621
+ "paligemma_with_expert.gemma_expert.model.layers.9.self_attn.o_proj.weight",
622
+ "paligemma_with_expert.gemma_expert.model.layers.9.self_attn.q_proj.weight",
623
+ "paligemma_with_expert.gemma_expert.model.layers.9.self_attn.v_proj.weight",
624
+ "paligemma_with_expert.gemma_expert.model.norm.dense.bias",
625
+ "paligemma_with_expert.gemma_expert.model.norm.dense.weight"
626
+ ],
627
+ "output_path": "/workspace/checkpoints/pi05_base_split_independent_packed_from_single",
628
+ "right_expert_max_abs_diff": 0.0,
629
+ "right_input_projection_max_abs_diff": 0.0,
630
+ "right_output_projection_max_abs_diff": 0.0,
631
+ "single_ckpt": "/workspace/checkpoints/pi05_base_single_pytorch",
632
+ "warm_start_exact": true
633
+ }
artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_independent_packed_from_single/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5164534199e5396320dfc44ac251f50e117d92cb82d90aa3f3f2fe8e82c620dc
3
+ size 9088652560
artifacts/twin_split_expert_bringup_20260310/repro/commands_bringup.sh ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ export HF_TOKEN="${HF_TOKEN:-}"
5
+ export HF_HOME=/workspace/.hf
6
+ export HF_HUB_CACHE=/workspace/.hf/hub
7
+ export HF_DATASETS_CACHE=/workspace/.hf/datasets
8
+ export HUGGINGFACE_HUB_CACHE=/workspace/.hf/hub
9
+ export XDG_CACHE_HOME=/workspace/.cache
10
+ export OPENPI_LEROBOT_HOME=/workspace/lerobot
11
+ export OPENPI_TORCH_COMPILE_SAMPLE_ACTIONS=0
12
+ export TOKENIZERS_PARALLELISM=false
13
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
14
+ export PYTHONPATH=/workspace/pi05tests/openpi/src
15
+
16
+ cd /workspace/pi05tests/openpi
17
+
18
+ # Create exact split warm-start checkpoints from the single-head PyTorch base checkpoint.
19
+ ./.venv/bin/python -u scripts/init_parallel_pi05_from_single_pytorch.py \
20
+ --single_ckpt /workspace/checkpoints/pi05_base_single_pytorch \
21
+ --config_name pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k \
22
+ --output_path /workspace/pi05tests/artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_independent_packed_from_single
23
+
24
+ ./.venv/bin/python -u scripts/init_parallel_pi05_from_single_pytorch.py \
25
+ --single_ckpt /workspace/checkpoints/pi05_base_single_pytorch \
26
+ --config_name pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k \
27
+ --output_path /workspace/pi05tests/artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_communicating_packed_from_single
28
+
29
+ # Check split invariants.
30
+ ./.venv/bin/python -u scripts/check_split_expert_invariants.py \
31
+ --config_name pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k \
32
+ --checkpoint_dir /workspace/pi05tests/artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_independent_packed_from_single
33
+
34
+ ./.venv/bin/python -u scripts/check_split_expert_invariants.py \
35
+ --config_name pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k \
36
+ --checkpoint_dir /workspace/pi05tests/artifacts/twin_split_expert_bringup_20260310/bootstrap_checkpoints/pi05_base_split_communicating_packed_from_single
37
+
38
+ # Detached real-data smoke runs.
39
+ CUDA_VISIBLE_DEVICES=0 setsid -f ./.venv/bin/python -u scripts/train_pytorch.py \
40
+ pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k \
41
+ --exp_name split_independent_real_smoke3_r2 \
42
+ --num_train_steps 3 \
43
+ --save_interval 3 \
44
+ --log_interval 1 \
45
+ --batch_size 1 \
46
+ --num_workers 0 \
47
+ --pytorch_training_precision float32 \
48
+ > run_logs/split_independent_real_smoke3_r2.log 2>&1
49
+
50
+ CUDA_VISIBLE_DEVICES=1 setsid -f ./.venv/bin/python -u scripts/train_pytorch.py \
51
+ pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k \
52
+ --exp_name split_communicating_real_smoke3 \
53
+ --num_train_steps 3 \
54
+ --save_interval 3 \
55
+ --log_interval 1 \
56
+ --batch_size 1 \
57
+ --num_workers 0 \
58
+ --pytorch_training_precision float32 \
59
+ > run_logs/split_communicating_real_smoke3.log 2>&1
60
+
61
+ # Detached short real-data training runs.
62
+ CUDA_VISIBLE_DEVICES=0 setsid -f ./.venv/bin/python -u scripts/train_pytorch.py \
63
+ pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k \
64
+ --exp_name split_independent_real_train20 \
65
+ --num_train_steps 20 \
66
+ --save_interval 20 \
67
+ --log_interval 1 \
68
+ --batch_size 1 \
69
+ --num_workers 0 \
70
+ --pytorch_training_precision float32 \
71
+ > run_logs/split_independent_real_train20.log 2>&1
72
+
73
+ CUDA_VISIBLE_DEVICES=1 setsid -f ./.venv/bin/python -u scripts/train_pytorch.py \
74
+ pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k \
75
+ --exp_name split_communicating_real_train20 \
76
+ --num_train_steps 20 \
77
+ --save_interval 20 \
78
+ --log_interval 1 \
79
+ --batch_size 1 \
80
+ --num_workers 0 \
81
+ --pytorch_training_precision float32 \
82
+ > run_logs/split_communicating_real_train20.log 2>&1
artifacts/twin_split_expert_bringup_20260310/run_logs/split_communicating_real_smoke3.log ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 19:55:02.788 [I] Created experiment checkpoint directory: /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3 (22110:train_pytorch.py:533)
2
+ 19:55:02.789 [I] Using batch size per GPU: 1 (total batch size across 1 GPUs: 1) (22110:train_pytorch.py:552)
3
+ 19:55:02.865 [I] Loaded norm stats from /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/lsnu/twin_dual_push_128_train (22110:config.py:234)
4
+ 19:55:02.867 [I] data_config: DataConfig(repo_id='lsnu/twin_dual_push_128_train', asset_id='lsnu/twin_dual_push_128_train', norm_stats={'state': NormStats(mean=array([ 0.10604009, 0.20956482, 0.09184283, -1.98801565, -0.04930164,
5
+ 2.20065784, 1.07595289, 0.52742052, 0.01585805, 0.08288047,
6
+ -0.06887393, -1.906394 , 0.04810138, 2.01086807, -0.92902797,
7
+ 0.8440811 ]), std=array([0.09207697, 0.31317395, 0.08127229, 0.53812712, 0.06093267,
8
+ 0.51205784, 0.22527155, 0.49924755, 0.20230208, 0.31408131,
9
+ 0.21665592, 0.5264315 , 0.20170984, 0.4745712 , 1.17861438,
10
+ 0.36277843]), q01=array([-5.00321221e-06, -3.88026012e-01, -2.23782954e-05, -2.98962682e+00,
11
+ -2.38592355e-01, 1.22146201e+00, 7.85383821e-01, 0.00000000e+00,
12
+ -6.15615927e-01, -4.14941930e-01, -9.43696350e-01, -2.88397729e+00,
13
+ -9.05083556e-01, 1.22148895e+00, -2.79564499e+00, 0.00000000e+00]), q99=array([ 0.31251293, 0.86546916, 0.35174239, -0.87634897, 0.05212194,
14
+ 2.97208117, 1.64465171, 0.9998 , 0.7670313 , 0.96073459,
15
+ 0.68710467, -0.87498123, 0.35838486, 2.9773227 , 0.78477909,
16
+ 0.9998 ])), 'actions': NormStats(mean=array([ 0.03630241, 0.09624442, 0.01367408, -0.2224988 , -0.02762174,
17
+ 0.27498844, 0.0892187 , 0.45650524, -0.00378086, 0.09113847,
18
+ -0.00376227, -0.22537093, 0.00826233, 0.26799494, -0.57452869,
19
+ 0.7731654 ]), std=array([0.04995174, 0.29268014, 0.06852161, 0.3647725 , 0.07012808,
20
+ 0.27129024, 0.11329207, 0.4981046 , 0.0917461 , 0.22704004,
21
+ 0.1069391 , 0.2572591 , 0.11801817, 0.1235588 , 0.35835782,
22
+ 0.41878474]), q01=array([-5.86206436e-04, -3.88117499e-01, -2.55800724e-01, -8.34769463e-01,
23
+ -3.51454727e-01, -1.54787922e-03, -5.81741333e-04, 0.00000000e+00,
24
+ -2.64436970e-01, -3.51582764e-01, -3.69693995e-01, -7.30919549e-01,
25
+ -3.35441585e-01, -6.62303925e-04, -9.34731126e-01, 0.00000000e+00]), q99=array([0.20790743, 0.81198567, 0.19612836, 0.33958174, 0.05568643,
26
+ 0.75265345, 0.425256 , 0.9998 , 0.2558236 , 0.58901345,
27
+ 0.35822071, 0.18567593, 0.44035054, 0.49966629, 0.12655233,
28
+ 0.9998 ]))}, repack_transforms=Group(inputs=[RepackTransform(structure={'images': {'cam_high': 'front_image', 'cam_left_wrist': 'wrist_left_image', 'cam_right_wrist': 'wrist_right_image'}, 'state': 'state', 'actions': 'action', 'prompt': 'task'})], outputs=()), data_transforms=Group(inputs=[AlohaInputs(adapt_to_pi=False)], outputs=[]), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x7ec79fca8910>, discrete_state_input=True), PackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))], outputs=[UnpackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))]), use_quantile_norm=True, action_sequence_keys=('action',), prompt_from_task=False, rlds_data_dir=None, action_space=None, datasets=()) (22110:data_loader.py:284)
29
+ 19:55:09.225 [I] JAX version 0.5.3 available. (22110:config.py:125)
30
+ 19:55:34.099 [I] Using existing local LeRobot dataset mirror for lsnu/twin_dual_push_128_train: /workspace/lerobot/lsnu/twin_dual_push_128_train (22110:data_loader.py:148)
31
+ 19:55:34.205 [W] 'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder (22110:video_utils.py:36)
32
+ 19:56:38.376 [I] local_batch_size: 1 (22110:data_loader.py:365)
33
+ 19:58:25.969 [I] Enabled gradient checkpointing for PI0Pytorch model (22110:pi0_pytorch.py:138)
34
+ 19:58:25.971 [I] Enabled gradient checkpointing for memory optimization (22110:train_pytorch.py:624)
35
+ 19:58:25.972 [I] Step 0 (after_model_creation): GPU memory - allocated: 17.23GB, reserved: 17.23GB, free: 0.00GB, peak_allocated: 17.23GB, peak_reserved: 17.23GB (22110:train_pytorch.py:493)
36
+ 19:58:25.972 [I] Loading weights from: /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22110:train_pytorch.py:653)
37
+ 19:58:29.565 [I] Weight loading missing key count: 0 (22110:train_pytorch.py:657)
38
+ 19:58:29.566 [I] Weight loading missing keys: set() (22110:train_pytorch.py:658)
39
+ 19:58:29.566 [I] Weight loading unexpected key count: 0 (22110:train_pytorch.py:659)
40
+ 19:58:29.566 [I] Weight loading unexpected keys: [] (22110:train_pytorch.py:660)
41
+ 19:58:29.567 [I] Loaded PyTorch weights from /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22110:train_pytorch.py:661)
42
+ 19:58:29.571 [I] Running on: 963c158043aa | world_size=1 (22110:train_pytorch.py:701)
43
+ 19:58:29.571 [I] Training config: batch_size=1, effective_batch_size=1, num_train_steps=3 (22110:train_pytorch.py:702)
44
+ 19:58:29.572 [I] Memory optimizations: gradient_checkpointing=True (22110:train_pytorch.py:705)
45
+ 19:58:29.572 [I] DDP settings: find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True (22110:train_pytorch.py:706)
46
+ 19:58:29.573 [I] LR schedule: warmup=250, peak_lr=2.50e-05, decay_steps=5000, end_lr=2.50e-06 (22110:train_pytorch.py:707)
47
+ 19:58:29.573 [I] Optimizer: AdamW, weight_decay=1e-10, clip_norm=1.0 (22110:train_pytorch.py:710)
48
+ 19:58:29.573 [I] EMA is not supported for PyTorch training (22110:train_pytorch.py:713)
49
+ 19:58:29.574 [I] Training precision: float32 (22110:train_pytorch.py:714)
50
+ 19:58:29.590 [I] Resolved config name: pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k (22110:train_pytorch.py:308)
51
+ 19:58:29.590 [I] Dataset repo_id: lsnu/twin_dual_push_128_train (22110:train_pytorch.py:309)
52
+ 19:58:29.591 [I] Norm-stats file path: /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json (22110:train_pytorch.py:310)
53
+ 19:58:29.592 [I] Norm-stats summary: {'keys': ['actions', 'state'], 'state_mean_len': 16, 'state_std_len': 16, 'actions_mean_len': 16, 'actions_std_len': 16} (22110:train_pytorch.py:311)
54
+ 19:58:29.592 [I] Checkpoint source path: /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22110:train_pytorch.py:312)
55
+ 19:58:29.592 [I] Model type: split_communicating (22110:train_pytorch.py:313)
56
+ 19:58:29.593 [I] Packed transforms active: True (22110:train_pytorch.py:314)
57
+ 19:58:29.593 [I] World size: 1 (22110:train_pytorch.py:315)
58
+ 19:58:29.594 [I] Batch size: local=1, global=1 (22110:train_pytorch.py:316)
59
+ 19:58:29.594 [I] num_workers: 0 (22110:train_pytorch.py:317)
60
+ 19:58:29.595 [I] Precision: float32 (22110:train_pytorch.py:318)
61
+ 19:58:29.595 [I] LR schedule summary: warmup_steps=250, peak_lr=2.50e-05, decay_steps=5000, decay_lr=2.50e-06 (22110:train_pytorch.py:319)
62
+ 19:58:29.595 [I] Save/log intervals: save_interval=3, log_interval=1 (22110:train_pytorch.py:326)
63
+ 19:58:29.596 [I] Action-loss mask: (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (22110:train_pytorch.py:327)
64
+ 19:58:29.596 [I] Active mask dims: [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] (22110:train_pytorch.py:328)
65
+ 19:58:29.597 [I] Masked dims: [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] (22110:train_pytorch.py:329)
66
+ 19:58:29.597 [I] Gradient bucket diagnostics: left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm (22110:train_pytorch.py:722)
67
+
68
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
69
+ 19:58:31.354 [I] debug_step=1 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22110:train_pytorch.py:831)
70
+ 19:58:31.355 [I] debug_step=1 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22110:train_pytorch.py:835)
71
+ 19:58:31.356 [I] debug_step=1 prompt_token_lengths=[75] (22110:train_pytorch.py:838)
72
+ 19:58:31.356 [I] debug_step=1 state_stats min=-1.0000 max=1.0004 mean=0.0112 std=0.3876 (22110:train_pytorch.py:839)
73
+ 19:58:31.357 [I] debug_step=1 action_stats min=-1.0016 max=1.0004 mean=-0.0454 std=0.4716 (22110:train_pytorch.py:842)
74
+ 19:58:31.358 [I] debug_step=1 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22110:train_pytorch.py:845)
75
+ 19:58:31.372 [I] debug_step=1 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22110:train_pytorch.py:849)
76
+ 19:58:31.372 [I] debug_step=1 lr=9.96e-08 grad_norm=60.0472 data_time=0.3311s step_time=1.3966s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.25GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.25GB (22110:train_pytorch.py:854)
77
+ 19:58:31.373 [I] debug_step=1 grad_shared_backbone=36.9945 grad_left_action_in=2.3769 grad_right_action_in=1.7630 grad_left_expert=31.1244 grad_right_expert=27.8917 grad_action_out=13.0720 grad_cross_arm_comm=3.1067 cross_arm_comm_gate_layer_0=0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=0.0000 cross_arm_comm_gate_layer_5=0.0000 cross_arm_comm_gate_layer_6=0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=0.0000 cross_arm_comm_gate_layer_9=0.0000 cross_arm_comm_gate_layer_10=0.0000 cross_arm_comm_gate_layer_11=0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=0.0000 cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0050 cross_arm_attention_mass_layer_2=0.0217 cross_arm_attention_mass_layer_3=0.0086 cross_arm_attention_mass_layer_4=0.0279 cross_arm_attention_mass_layer_5=0.0355 cross_arm_attention_mass_layer_6=0.0179 cross_arm_attention_mass_layer_7=0.0369 cross_arm_attention_mass_layer_8=0.0183 cross_arm_attention_mass_layer_9=0.0153 cross_arm_attention_mass_layer_10=0.0188 cross_arm_attention_mass_layer_11=0.0278 cross_arm_attention_mass_layer_12=0.0052 cross_arm_attention_mass_layer_13=0.0161 cross_arm_attention_mass_layer_14=0.0091 cross_arm_attention_mass_layer_15=0.0342 cross_arm_attention_mass_layer_16=0.0457 cross_arm_attention_mass_layer_17=0.0454 (22110:train_pytorch.py:862)
78
+ 19:58:31.374 [I] step=1 loss=3.8411 smoothed_loss=3.8411 lr=9.96e-08 grad_norm=60.0472 step_time=1.3966s data_time=0.3311s it/s=0.555 eta_to_3=3.6s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0050 cross_arm_attention_mass_layer_10=0.0188 cross_arm_attention_mass_layer_11=0.0278 cross_arm_attention_mass_layer_12=0.0052 cross_arm_attention_mass_layer_13=0.0161 cross_arm_attention_mass_layer_14=0.0091 cross_arm_attention_mass_layer_15=0.0342 cross_arm_attention_mass_layer_16=0.0457 cross_arm_attention_mass_layer_17=0.0454 cross_arm_attention_mass_layer_2=0.0217 cross_arm_attention_mass_layer_3=0.0086 cross_arm_attention_mass_layer_4=0.0279 cross_arm_attention_mass_layer_5=0.0355 cross_arm_attention_mass_layer_6=0.0179 cross_arm_attention_mass_layer_7=0.0369 cross_arm_attention_mass_layer_8=0.0183 cross_arm_attention_mass_layer_9=0.0153 cross_arm_comm_gate_layer_0=0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=0.0000 cross_arm_comm_gate_layer_11=0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=0.0000 cross_arm_comm_gate_layer_2=0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=0.0000 cross_arm_comm_gate_layer_5=0.0000 cross_arm_comm_gate_layer_6=0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=0.0000 cross_arm_comm_gate_layer_9=0.0000 grad_action_out=13.0720 grad_cross_arm_comm=3.1067 grad_left_action_in=2.3769 grad_left_expert=31.1244 grad_right_action_in=1.7630 grad_right_expert=27.8917 grad_shared_backbone=36.9945 (22110:train_pytorch.py:882)
79
+
80
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
81
+ 19:58:32.164 [I] debug_step=2 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22110:train_pytorch.py:831)
82
+ 19:58:32.165 [I] debug_step=2 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22110:train_pytorch.py:835)
83
+ 19:58:32.166 [I] debug_step=2 prompt_token_lengths=[76] (22110:train_pytorch.py:838)
84
+ 19:58:32.166 [I] debug_step=2 state_stats min=-0.9415 max=1.0004 mean=-0.0010 std=0.4295 (22110:train_pytorch.py:839)
85
+ 19:58:32.167 [I] debug_step=2 action_stats min=-1.0000 max=1.1367 mean=0.0272 std=0.4576 (22110:train_pytorch.py:842)
86
+ 19:58:32.168 [I] debug_step=2 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22110:train_pytorch.py:845)
87
+ 19:58:32.168 [I] debug_step=2 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22110:train_pytorch.py:849)
88
+ 19:58:32.169 [I] debug_step=2 lr=1.99e-07 grad_norm=10.7300 data_time=0.1812s step_time=0.6234s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22110:train_pytorch.py:854)
89
+ 19:58:32.169 [I] debug_step=2 grad_shared_backbone=9.2018 grad_left_action_in=0.1651 grad_right_action_in=0.1485 grad_left_expert=2.5032 grad_right_expert=2.3988 grad_action_out=4.0772 grad_cross_arm_comm=0.0166 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=-0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0019 cross_arm_attention_mass_layer_2=0.0161 cross_arm_attention_mass_layer_3=0.0029 cross_arm_attention_mass_layer_4=0.0175 cross_arm_attention_mass_layer_5=0.0243 cross_arm_attention_mass_layer_6=0.0074 cross_arm_attention_mass_layer_7=0.0232 cross_arm_attention_mass_layer_8=0.0155 cross_arm_attention_mass_layer_9=0.0135 cross_arm_attention_mass_layer_10=0.0094 cross_arm_attention_mass_layer_11=0.0151 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0053 cross_arm_attention_mass_layer_14=0.0056 cross_arm_attention_mass_layer_15=0.0250 cross_arm_attention_mass_layer_16=0.0356 cross_arm_attention_mass_layer_17=0.0413 (22110:train_pytorch.py:862)
90
+ 19:58:32.170 [I] step=2 loss=1.1389 smoothed_loss=3.5709 lr=1.99e-07 grad_norm=10.7300 step_time=0.6234s data_time=0.1812s it/s=1.257 eta_to_3=0.8s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0019 cross_arm_attention_mass_layer_10=0.0094 cross_arm_attention_mass_layer_11=0.0151 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0053 cross_arm_attention_mass_layer_14=0.0056 cross_arm_attention_mass_layer_15=0.0250 cross_arm_attention_mass_layer_16=0.0356 cross_arm_attention_mass_layer_17=0.0413 cross_arm_attention_mass_layer_2=0.0161 cross_arm_attention_mass_layer_3=0.0029 cross_arm_attention_mass_layer_4=0.0175 cross_arm_attention_mass_layer_5=0.0243 cross_arm_attention_mass_layer_6=0.0074 cross_arm_attention_mass_layer_7=0.0232 cross_arm_attention_mass_layer_8=0.0155 cross_arm_attention_mass_layer_9=0.0135 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=-0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=4.0772 grad_cross_arm_comm=0.0166 grad_left_action_in=0.1651 grad_left_expert=2.5032 grad_right_action_in=0.1485 grad_right_expert=2.3988 grad_shared_backbone=9.2018 (22110:train_pytorch.py:882)
91
+
92
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
93
+ 19:58:32.708 [I] debug_step=3 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22110:train_pytorch.py:831)
94
+ 19:58:32.709 [I] debug_step=3 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22110:train_pytorch.py:835)
95
+ 19:58:32.709 [I] debug_step=3 prompt_token_lengths=[75] (22110:train_pytorch.py:838)
96
+ 19:58:32.710 [I] debug_step=3 state_stats min=-1.0000 max=1.0004 mean=0.0558 std=0.4300 (22110:train_pytorch.py:839)
97
+ 19:58:32.711 [I] debug_step=3 action_stats min=-1.0033 max=1.0004 mean=-0.0658 std=0.4704 (22110:train_pytorch.py:842)
98
+ 19:58:32.711 [I] debug_step=3 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22110:train_pytorch.py:845)
99
+ 19:58:32.712 [I] debug_step=3 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22110:train_pytorch.py:849)
100
+ 19:58:32.712 [I] debug_step=3 lr=2.99e-07 grad_norm=343.7256 data_time=0.1312s step_time=0.4126s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22110:train_pytorch.py:854)
101
+ 19:58:32.713 [I] debug_step=3 grad_shared_backbone=215.2880 grad_left_action_in=4.7981 grad_right_action_in=9.5346 grad_left_expert=72.6437 grad_right_expert=227.6029 grad_action_out=23.7709 grad_cross_arm_comm=3.3555 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0003 cross_arm_attention_mass_layer_1=0.0127 cross_arm_attention_mass_layer_2=0.0275 cross_arm_attention_mass_layer_3=0.0190 cross_arm_attention_mass_layer_4=0.0359 cross_arm_attention_mass_layer_5=0.0454 cross_arm_attention_mass_layer_6=0.0228 cross_arm_attention_mass_layer_7=0.0346 cross_arm_attention_mass_layer_8=0.0149 cross_arm_attention_mass_layer_9=0.0296 cross_arm_attention_mass_layer_10=0.0177 cross_arm_attention_mass_layer_11=0.0230 cross_arm_attention_mass_layer_12=0.0134 cross_arm_attention_mass_layer_13=0.0242 cross_arm_attention_mass_layer_14=0.0109 cross_arm_attention_mass_layer_15=0.0285 cross_arm_attention_mass_layer_16=0.0403 cross_arm_attention_mass_layer_17=0.0268 (22110:train_pytorch.py:862)
102
+ 19:58:32.713 [I] step=3 loss=5.0518 smoothed_loss=3.7190 lr=2.99e-07 grad_norm=343.7256 step_time=0.4126s data_time=0.1312s it/s=1.843 eta_to_3=0.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0003 cross_arm_attention_mass_layer_1=0.0127 cross_arm_attention_mass_layer_10=0.0177 cross_arm_attention_mass_layer_11=0.0230 cross_arm_attention_mass_layer_12=0.0134 cross_arm_attention_mass_layer_13=0.0242 cross_arm_attention_mass_layer_14=0.0109 cross_arm_attention_mass_layer_15=0.0285 cross_arm_attention_mass_layer_16=0.0403 cross_arm_attention_mass_layer_17=0.0268 cross_arm_attention_mass_layer_2=0.0275 cross_arm_attention_mass_layer_3=0.0190 cross_arm_attention_mass_layer_4=0.0359 cross_arm_attention_mass_layer_5=0.0454 cross_arm_attention_mass_layer_6=0.0228 cross_arm_attention_mass_layer_7=0.0346 cross_arm_attention_mass_layer_8=0.0149 cross_arm_attention_mass_layer_9=0.0296 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=23.7709 grad_cross_arm_comm=3.3555 grad_left_action_in=4.7981 grad_left_expert=72.6437 grad_right_action_in=9.5346 grad_right_expert=227.6029 grad_shared_backbone=215.2880 (22110:train_pytorch.py:882)
103
+ 20:01:38.475 [I] Saved checkpoint at step 3 -> /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3/3 (22110:train_pytorch.py:378)
104
+
artifacts/twin_split_expert_bringup_20260310/run_logs/split_communicating_real_train20.log ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 20:03:03.480 [I] Created experiment checkpoint directory: /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20 (22938:train_pytorch.py:533)
2
+ 20:03:03.486 [I] Using batch size per GPU: 1 (total batch size across 1 GPUs: 1) (22938:train_pytorch.py:552)
3
+ 20:03:03.634 [I] Loaded norm stats from /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/lsnu/twin_dual_push_128_train (22938:config.py:234)
4
+ 20:03:03.637 [I] data_config: DataConfig(repo_id='lsnu/twin_dual_push_128_train', asset_id='lsnu/twin_dual_push_128_train', norm_stats={'state': NormStats(mean=array([ 0.10604009, 0.20956482, 0.09184283, -1.98801565, -0.04930164,
5
+ 2.20065784, 1.07595289, 0.52742052, 0.01585805, 0.08288047,
6
+ -0.06887393, -1.906394 , 0.04810138, 2.01086807, -0.92902797,
7
+ 0.8440811 ]), std=array([0.09207697, 0.31317395, 0.08127229, 0.53812712, 0.06093267,
8
+ 0.51205784, 0.22527155, 0.49924755, 0.20230208, 0.31408131,
9
+ 0.21665592, 0.5264315 , 0.20170984, 0.4745712 , 1.17861438,
10
+ 0.36277843]), q01=array([-5.00321221e-06, -3.88026012e-01, -2.23782954e-05, -2.98962682e+00,
11
+ -2.38592355e-01, 1.22146201e+00, 7.85383821e-01, 0.00000000e+00,
12
+ -6.15615927e-01, -4.14941930e-01, -9.43696350e-01, -2.88397729e+00,
13
+ -9.05083556e-01, 1.22148895e+00, -2.79564499e+00, 0.00000000e+00]), q99=array([ 0.31251293, 0.86546916, 0.35174239, -0.87634897, 0.05212194,
14
+ 2.97208117, 1.64465171, 0.9998 , 0.7670313 , 0.96073459,
15
+ 0.68710467, -0.87498123, 0.35838486, 2.9773227 , 0.78477909,
16
+ 0.9998 ])), 'actions': NormStats(mean=array([ 0.03630241, 0.09624442, 0.01367408, -0.2224988 , -0.02762174,
17
+ 0.27498844, 0.0892187 , 0.45650524, -0.00378086, 0.09113847,
18
+ -0.00376227, -0.22537093, 0.00826233, 0.26799494, -0.57452869,
19
+ 0.7731654 ]), std=array([0.04995174, 0.29268014, 0.06852161, 0.3647725 , 0.07012808,
20
+ 0.27129024, 0.11329207, 0.4981046 , 0.0917461 , 0.22704004,
21
+ 0.1069391 , 0.2572591 , 0.11801817, 0.1235588 , 0.35835782,
22
+ 0.41878474]), q01=array([-5.86206436e-04, -3.88117499e-01, -2.55800724e-01, -8.34769463e-01,
23
+ -3.51454727e-01, -1.54787922e-03, -5.81741333e-04, 0.00000000e+00,
24
+ -2.64436970e-01, -3.51582764e-01, -3.69693995e-01, -7.30919549e-01,
25
+ -3.35441585e-01, -6.62303925e-04, -9.34731126e-01, 0.00000000e+00]), q99=array([0.20790743, 0.81198567, 0.19612836, 0.33958174, 0.05568643,
26
+ 0.75265345, 0.425256 , 0.9998 , 0.2558236 , 0.58901345,
27
+ 0.35822071, 0.18567593, 0.44035054, 0.49966629, 0.12655233,
28
+ 0.9998 ]))}, repack_transforms=Group(inputs=[RepackTransform(structure={'images': {'cam_high': 'front_image', 'cam_left_wrist': 'wrist_left_image', 'cam_right_wrist': 'wrist_right_image'}, 'state': 'state', 'actions': 'action', 'prompt': 'task'})], outputs=()), data_transforms=Group(inputs=[AlohaInputs(adapt_to_pi=False)], outputs=[]), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x7303f4ce5b90>, discrete_state_input=True), PackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))], outputs=[UnpackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))]), use_quantile_norm=True, action_sequence_keys=('action',), prompt_from_task=False, rlds_data_dir=None, action_space=None, datasets=()) (22938:data_loader.py:393)
29
+ 20:03:15.223 [I] JAX version 0.5.3 available. (22938:config.py:125)
30
+ 20:04:19.283 [I] Using existing local LeRobot dataset mirror for lsnu/twin_dual_push_128_train: /workspace/lerobot/lsnu/twin_dual_push_128_train (22938:data_loader.py:148)
31
+ 20:04:19.378 [W] 'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder (22938:video_utils.py:36)
32
+ 20:09:10.375 [I] local_batch_size: 1 (22938:data_loader.py:474)
33
+ 20:11:59.735 [I] Enabled gradient checkpointing for PI0Pytorch model (22938:pi0_pytorch.py:138)
34
+ 20:11:59.737 [I] Enabled gradient checkpointing for memory optimization (22938:train_pytorch.py:624)
35
+ 20:11:59.738 [I] Step 0 (after_model_creation): GPU memory - allocated: 17.23GB, reserved: 17.23GB, free: 0.00GB, peak_allocated: 17.23GB, peak_reserved: 17.23GB (22938:train_pytorch.py:493)
36
+ 20:11:59.738 [I] Loading weights from: /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22938:train_pytorch.py:653)
37
+ 20:12:04.492 [I] Weight loading missing key count: 0 (22938:train_pytorch.py:657)
38
+ 20:12:04.492 [I] Weight loading missing keys: set() (22938:train_pytorch.py:658)
39
+ 20:12:04.492 [I] Weight loading unexpected key count: 0 (22938:train_pytorch.py:659)
40
+ 20:12:04.493 [I] Weight loading unexpected keys: [] (22938:train_pytorch.py:660)
41
+ 20:12:04.493 [I] Loaded PyTorch weights from /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22938:train_pytorch.py:661)
42
+ 20:12:04.497 [I] Running on: 963c158043aa | world_size=1 (22938:train_pytorch.py:701)
43
+ 20:12:04.498 [I] Training config: batch_size=1, effective_batch_size=1, num_train_steps=20 (22938:train_pytorch.py:702)
44
+ 20:12:04.498 [I] Memory optimizations: gradient_checkpointing=True (22938:train_pytorch.py:705)
45
+ 20:12:04.499 [I] DDP settings: find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True (22938:train_pytorch.py:706)
46
+ 20:12:04.499 [I] LR schedule: warmup=250, peak_lr=2.50e-05, decay_steps=5000, end_lr=2.50e-06 (22938:train_pytorch.py:707)
47
+ 20:12:04.499 [I] Optimizer: AdamW, weight_decay=1e-10, clip_norm=1.0 (22938:train_pytorch.py:710)
48
+ 20:12:04.500 [I] EMA is not supported for PyTorch training (22938:train_pytorch.py:713)
49
+ 20:12:04.500 [I] Training precision: float32 (22938:train_pytorch.py:714)
50
+ 20:12:04.509 [I] Resolved config name: pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k (22938:train_pytorch.py:308)
51
+ 20:12:04.509 [I] Dataset repo_id: lsnu/twin_dual_push_128_train (22938:train_pytorch.py:309)
52
+ 20:12:04.510 [I] Norm-stats file path: /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json (22938:train_pytorch.py:310)
53
+ 20:12:04.510 [I] Norm-stats summary: {'keys': ['actions', 'state'], 'state_mean_len': 16, 'state_std_len': 16, 'actions_mean_len': 16, 'actions_std_len': 16} (22938:train_pytorch.py:311)
54
+ 20:12:04.511 [I] Checkpoint source path: /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22938:train_pytorch.py:312)
55
+ 20:12:04.511 [I] Model type: split_communicating (22938:train_pytorch.py:313)
56
+ 20:12:04.511 [I] Packed transforms active: True (22938:train_pytorch.py:314)
57
+ 20:12:04.512 [I] World size: 1 (22938:train_pytorch.py:315)
58
+ 20:12:04.512 [I] Batch size: local=1, global=1 (22938:train_pytorch.py:316)
59
+ 20:12:04.512 [I] num_workers: 0 (22938:train_pytorch.py:317)
60
+ 20:12:04.513 [I] Precision: float32 (22938:train_pytorch.py:318)
61
+ 20:12:04.513 [I] LR schedule summary: warmup_steps=250, peak_lr=2.50e-05, decay_steps=5000, decay_lr=2.50e-06 (22938:train_pytorch.py:319)
62
+ 20:12:04.513 [I] Save/log intervals: save_interval=20, log_interval=1 (22938:train_pytorch.py:326)
63
+ 20:12:04.514 [I] Action-loss mask: (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (22938:train_pytorch.py:327)
64
+ 20:12:04.514 [I] Active mask dims: [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] (22938:train_pytorch.py:328)
65
+ 20:12:04.515 [I] Masked dims: [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] (22938:train_pytorch.py:329)
66
+ 20:12:04.515 [I] Gradient bucket diagnostics: left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm (22938:train_pytorch.py:722)
67
+
68
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
69
+ 20:12:06.079 [I] debug_step=1 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22938:train_pytorch.py:831)
70
+ 20:12:06.080 [I] debug_step=1 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22938:train_pytorch.py:835)
71
+ 20:12:06.080 [I] debug_step=1 prompt_token_lengths=[75] (22938:train_pytorch.py:838)
72
+ 20:12:06.081 [I] debug_step=1 state_stats min=-1.0000 max=1.0004 mean=0.0112 std=0.3876 (22938:train_pytorch.py:839)
73
+ 20:12:06.081 [I] debug_step=1 action_stats min=-1.0016 max=1.0004 mean=-0.0454 std=0.4716 (22938:train_pytorch.py:842)
74
+ 20:12:06.082 [I] debug_step=1 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22938:train_pytorch.py:845)
75
+ 20:12:06.097 [I] debug_step=1 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22938:train_pytorch.py:849)
76
+ 20:12:06.097 [I] debug_step=1 lr=9.96e-08 grad_norm=60.0473 data_time=0.2034s step_time=1.3216s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.25GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.25GB (22938:train_pytorch.py:854)
77
+ 20:12:06.098 [I] debug_step=1 grad_shared_backbone=36.9946 grad_left_action_in=2.3769 grad_right_action_in=1.7630 grad_left_expert=31.1244 grad_right_expert=27.8917 grad_action_out=13.0720 grad_cross_arm_comm=3.1067 cross_arm_comm_gate_layer_0=0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=0.0000 cross_arm_comm_gate_layer_5=0.0000 cross_arm_comm_gate_layer_6=0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=0.0000 cross_arm_comm_gate_layer_9=0.0000 cross_arm_comm_gate_layer_10=0.0000 cross_arm_comm_gate_layer_11=0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=0.0000 cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0050 cross_arm_attention_mass_layer_2=0.0217 cross_arm_attention_mass_layer_3=0.0086 cross_arm_attention_mass_layer_4=0.0279 cross_arm_attention_mass_layer_5=0.0355 cross_arm_attention_mass_layer_6=0.0179 cross_arm_attention_mass_layer_7=0.0369 cross_arm_attention_mass_layer_8=0.0183 cross_arm_attention_mass_layer_9=0.0153 cross_arm_attention_mass_layer_10=0.0188 cross_arm_attention_mass_layer_11=0.0278 cross_arm_attention_mass_layer_12=0.0052 cross_arm_attention_mass_layer_13=0.0161 cross_arm_attention_mass_layer_14=0.0091 cross_arm_attention_mass_layer_15=0.0342 cross_arm_attention_mass_layer_16=0.0457 cross_arm_attention_mass_layer_17=0.0454 (22938:train_pytorch.py:862)
78
+ 20:12:06.099 [I] step=1 loss=3.8411 smoothed_loss=3.8411 lr=9.96e-08 grad_norm=60.0473 step_time=1.3216s data_time=0.2034s it/s=0.625 eta_to_20=30.4s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0050 cross_arm_attention_mass_layer_10=0.0188 cross_arm_attention_mass_layer_11=0.0278 cross_arm_attention_mass_layer_12=0.0052 cross_arm_attention_mass_layer_13=0.0161 cross_arm_attention_mass_layer_14=0.0091 cross_arm_attention_mass_layer_15=0.0342 cross_arm_attention_mass_layer_16=0.0457 cross_arm_attention_mass_layer_17=0.0454 cross_arm_attention_mass_layer_2=0.0217 cross_arm_attention_mass_layer_3=0.0086 cross_arm_attention_mass_layer_4=0.0279 cross_arm_attention_mass_layer_5=0.0355 cross_arm_attention_mass_layer_6=0.0179 cross_arm_attention_mass_layer_7=0.0369 cross_arm_attention_mass_layer_8=0.0183 cross_arm_attention_mass_layer_9=0.0153 cross_arm_comm_gate_layer_0=0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=0.0000 cross_arm_comm_gate_layer_11=0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=0.0000 cross_arm_comm_gate_layer_2=0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=0.0000 cross_arm_comm_gate_layer_5=0.0000 cross_arm_comm_gate_layer_6=0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=0.0000 cross_arm_comm_gate_layer_9=0.0000 grad_action_out=13.0720 grad_cross_arm_comm=3.1067 grad_left_action_in=2.3769 grad_left_expert=31.1244 grad_right_action_in=1.7630 grad_right_expert=27.8917 grad_shared_backbone=36.9946 (22938:train_pytorch.py:882)
79
+
80
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
81
+ 20:12:07.067 [I] debug_step=2 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22938:train_pytorch.py:831)
82
+ 20:12:07.067 [I] debug_step=2 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22938:train_pytorch.py:835)
83
+ 20:12:07.068 [I] debug_step=2 prompt_token_lengths=[76] (22938:train_pytorch.py:838)
84
+ 20:12:07.069 [I] debug_step=2 state_stats min=-0.9415 max=1.0004 mean=-0.0010 std=0.4295 (22938:train_pytorch.py:839)
85
+ 20:12:07.069 [I] debug_step=2 action_stats min=-1.0000 max=1.1367 mean=0.0272 std=0.4576 (22938:train_pytorch.py:842)
86
+ 20:12:07.070 [I] debug_step=2 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22938:train_pytorch.py:845)
87
+ 20:12:07.070 [I] debug_step=2 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22938:train_pytorch.py:849)
88
+ 20:12:07.071 [I] debug_step=2 lr=1.99e-07 grad_norm=10.7247 data_time=0.2263s step_time=0.7585s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22938:train_pytorch.py:854)
89
+ 20:12:07.071 [I] debug_step=2 grad_shared_backbone=9.1973 grad_left_action_in=0.1651 grad_right_action_in=0.1484 grad_left_expert=2.5023 grad_right_expert=2.3935 grad_action_out=4.0770 grad_cross_arm_comm=0.0166 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=-0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0019 cross_arm_attention_mass_layer_2=0.0161 cross_arm_attention_mass_layer_3=0.0029 cross_arm_attention_mass_layer_4=0.0175 cross_arm_attention_mass_layer_5=0.0243 cross_arm_attention_mass_layer_6=0.0074 cross_arm_attention_mass_layer_7=0.0232 cross_arm_attention_mass_layer_8=0.0155 cross_arm_attention_mass_layer_9=0.0135 cross_arm_attention_mass_layer_10=0.0094 cross_arm_attention_mass_layer_11=0.0151 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0053 cross_arm_attention_mass_layer_14=0.0056 cross_arm_attention_mass_layer_15=0.0250 cross_arm_attention_mass_layer_16=0.0356 cross_arm_attention_mass_layer_17=0.0413 (22938:train_pytorch.py:862)
90
+ 20:12:07.072 [I] step=2 loss=1.1389 smoothed_loss=3.5709 lr=1.99e-07 grad_norm=10.7247 step_time=0.7585s data_time=0.2263s it/s=1.028 eta_to_20=17.5s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0019 cross_arm_attention_mass_layer_10=0.0094 cross_arm_attention_mass_layer_11=0.0151 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0053 cross_arm_attention_mass_layer_14=0.0056 cross_arm_attention_mass_layer_15=0.0250 cross_arm_attention_mass_layer_16=0.0356 cross_arm_attention_mass_layer_17=0.0413 cross_arm_attention_mass_layer_2=0.0161 cross_arm_attention_mass_layer_3=0.0029 cross_arm_attention_mass_layer_4=0.0175 cross_arm_attention_mass_layer_5=0.0243 cross_arm_attention_mass_layer_6=0.0074 cross_arm_attention_mass_layer_7=0.0232 cross_arm_attention_mass_layer_8=0.0155 cross_arm_attention_mass_layer_9=0.0135 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=-0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=4.0770 grad_cross_arm_comm=0.0166 grad_left_action_in=0.1651 grad_left_expert=2.5023 grad_right_action_in=0.1484 grad_right_expert=2.3935 grad_shared_backbone=9.1973 (22938:train_pytorch.py:882)
91
+
92
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
93
+ 20:12:07.689 [I] debug_step=3 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22938:train_pytorch.py:831)
94
+ 20:12:07.690 [I] debug_step=3 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22938:train_pytorch.py:835)
95
+ 20:12:07.690 [I] debug_step=3 prompt_token_lengths=[75] (22938:train_pytorch.py:838)
96
+ 20:12:07.691 [I] debug_step=3 state_stats min=-1.0000 max=1.0004 mean=0.0558 std=0.4300 (22938:train_pytorch.py:839)
97
+ 20:12:07.692 [I] debug_step=3 action_stats min=-1.0033 max=1.0004 mean=-0.0658 std=0.4704 (22938:train_pytorch.py:842)
98
+ 20:12:07.692 [I] debug_step=3 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22938:train_pytorch.py:845)
99
+ 20:12:07.693 [I] debug_step=3 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22938:train_pytorch.py:849)
100
+ 20:12:07.693 [I] debug_step=3 lr=2.99e-07 grad_norm=343.6402 data_time=0.1557s step_time=0.4654s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22938:train_pytorch.py:854)
101
+ 20:12:07.694 [I] debug_step=3 grad_shared_backbone=215.2410 grad_left_action_in=4.7969 grad_right_action_in=9.5325 grad_left_expert=72.6238 grad_right_expert=227.5470 grad_action_out=23.7695 grad_cross_arm_comm=3.3548 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0003 cross_arm_attention_mass_layer_1=0.0127 cross_arm_attention_mass_layer_2=0.0275 cross_arm_attention_mass_layer_3=0.0190 cross_arm_attention_mass_layer_4=0.0359 cross_arm_attention_mass_layer_5=0.0454 cross_arm_attention_mass_layer_6=0.0228 cross_arm_attention_mass_layer_7=0.0346 cross_arm_attention_mass_layer_8=0.0149 cross_arm_attention_mass_layer_9=0.0296 cross_arm_attention_mass_layer_10=0.0177 cross_arm_attention_mass_layer_11=0.0230 cross_arm_attention_mass_layer_12=0.0134 cross_arm_attention_mass_layer_13=0.0242 cross_arm_attention_mass_layer_14=0.0109 cross_arm_attention_mass_layer_15=0.0285 cross_arm_attention_mass_layer_16=0.0403 cross_arm_attention_mass_layer_17=0.0268 (22938:train_pytorch.py:862)
102
+ 20:12:07.694 [I] step=3 loss=5.0512 smoothed_loss=3.7189 lr=2.99e-07 grad_norm=343.6402 step_time=0.4654s data_time=0.1557s it/s=1.609 eta_to_20=10.6s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0003 cross_arm_attention_mass_layer_1=0.0127 cross_arm_attention_mass_layer_10=0.0177 cross_arm_attention_mass_layer_11=0.0230 cross_arm_attention_mass_layer_12=0.0134 cross_arm_attention_mass_layer_13=0.0242 cross_arm_attention_mass_layer_14=0.0109 cross_arm_attention_mass_layer_15=0.0285 cross_arm_attention_mass_layer_16=0.0403 cross_arm_attention_mass_layer_17=0.0268 cross_arm_attention_mass_layer_2=0.0275 cross_arm_attention_mass_layer_3=0.0190 cross_arm_attention_mass_layer_4=0.0359 cross_arm_attention_mass_layer_5=0.0454 cross_arm_attention_mass_layer_6=0.0228 cross_arm_attention_mass_layer_7=0.0346 cross_arm_attention_mass_layer_8=0.0149 cross_arm_attention_mass_layer_9=0.0296 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=23.7695 grad_cross_arm_comm=3.3548 grad_left_action_in=4.7969 grad_left_expert=72.6238 grad_right_action_in=9.5325 grad_right_expert=227.5470 grad_shared_backbone=215.2410 (22938:train_pytorch.py:882)
103
+
104
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
105
+ 20:12:08.256 [I] debug_step=4 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22938:train_pytorch.py:831)
106
+ 20:12:08.257 [I] debug_step=4 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22938:train_pytorch.py:835)
107
+ 20:12:08.257 [I] debug_step=4 prompt_token_lengths=[78] (22938:train_pytorch.py:838)
108
+ 20:12:08.258 [I] debug_step=4 state_stats min=-0.7017 max=1.0004 mean=0.0553 std=0.3507 (22938:train_pytorch.py:839)
109
+ 20:12:08.258 [I] debug_step=4 action_stats min=-1.0014 max=1.0004 mean=-0.0683 std=0.4561 (22938:train_pytorch.py:842)
110
+ 20:12:08.259 [I] debug_step=4 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22938:train_pytorch.py:845)
111
+ 20:12:08.259 [I] debug_step=4 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22938:train_pytorch.py:849)
112
+ 20:12:08.260 [I] debug_step=4 lr=3.98e-07 grad_norm=8.7944 data_time=0.1312s step_time=0.4359s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22938:train_pytorch.py:854)
113
+ 20:12:08.260 [I] debug_step=4 grad_shared_backbone=7.5903 grad_left_action_in=0.1438 grad_right_action_in=0.1015 grad_left_expert=2.4058 grad_right_expert=1.2982 grad_action_out=3.3839 grad_cross_arm_comm=0.0147 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0015 cross_arm_attention_mass_layer_2=0.0133 cross_arm_attention_mass_layer_3=0.0026 cross_arm_attention_mass_layer_4=0.0148 cross_arm_attention_mass_layer_5=0.0199 cross_arm_attention_mass_layer_6=0.0062 cross_arm_attention_mass_layer_7=0.0154 cross_arm_attention_mass_layer_8=0.0102 cross_arm_attention_mass_layer_9=0.0086 cross_arm_attention_mass_layer_10=0.0065 cross_arm_attention_mass_layer_11=0.0099 cross_arm_attention_mass_layer_12=0.0010 cross_arm_attention_mass_layer_13=0.0040 cross_arm_attention_mass_layer_14=0.0072 cross_arm_attention_mass_layer_15=0.0227 cross_arm_attention_mass_layer_16=0.0351 cross_arm_attention_mass_layer_17=0.0406 (22938:train_pytorch.py:862)
114
+ 20:12:08.261 [I] step=4 loss=1.1860 smoothed_loss=3.4656 lr=3.98e-07 grad_norm=8.7944 step_time=0.4359s data_time=0.1312s it/s=1.768 eta_to_20=9.1s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0015 cross_arm_attention_mass_layer_10=0.0065 cross_arm_attention_mass_layer_11=0.0099 cross_arm_attention_mass_layer_12=0.0010 cross_arm_attention_mass_layer_13=0.0040 cross_arm_attention_mass_layer_14=0.0072 cross_arm_attention_mass_layer_15=0.0227 cross_arm_attention_mass_layer_16=0.0351 cross_arm_attention_mass_layer_17=0.0406 cross_arm_attention_mass_layer_2=0.0133 cross_arm_attention_mass_layer_3=0.0026 cross_arm_attention_mass_layer_4=0.0148 cross_arm_attention_mass_layer_5=0.0199 cross_arm_attention_mass_layer_6=0.0062 cross_arm_attention_mass_layer_7=0.0154 cross_arm_attention_mass_layer_8=0.0102 cross_arm_attention_mass_layer_9=0.0086 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=3.3839 grad_cross_arm_comm=0.0147 grad_left_action_in=0.1438 grad_left_expert=2.4058 grad_right_action_in=0.1015 grad_right_expert=1.2982 grad_shared_backbone=7.5903 (22938:train_pytorch.py:882)
115
+
116
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
117
+ 20:12:08.933 [I] debug_step=5 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22938:train_pytorch.py:831)
118
+ 20:12:08.934 [I] debug_step=5 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22938:train_pytorch.py:835)
119
+ 20:12:08.934 [I] debug_step=5 prompt_token_lengths=[73] (22938:train_pytorch.py:838)
120
+ 20:12:08.935 [I] debug_step=5 state_stats min=-0.9599 max=1.0004 mean=0.0170 std=0.5364 (22938:train_pytorch.py:839)
121
+ 20:12:08.935 [I] debug_step=5 action_stats min=-1.0392 max=1.0004 mean=-0.0159 std=0.4488 (22938:train_pytorch.py:842)
122
+ 20:12:08.935 [I] debug_step=5 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22938:train_pytorch.py:845)
123
+ 20:12:08.936 [I] debug_step=5 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22938:train_pytorch.py:849)
124
+ 20:12:08.936 [I] debug_step=5 lr=4.98e-07 grad_norm=20.1429 data_time=0.2048s step_time=0.4721s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22938:train_pytorch.py:854)
125
+ 20:12:08.937 [I] debug_step=5 grad_shared_backbone=16.7899 grad_left_action_in=0.2534 grad_right_action_in=0.3335 grad_left_expert=7.9047 grad_right_expert=3.6853 grad_action_out=6.0934 grad_cross_arm_comm=0.0735 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0020 cross_arm_attention_mass_layer_2=0.0178 cross_arm_attention_mass_layer_3=0.0039 cross_arm_attention_mass_layer_4=0.0203 cross_arm_attention_mass_layer_5=0.0294 cross_arm_attention_mass_layer_6=0.0106 cross_arm_attention_mass_layer_7=0.0286 cross_arm_attention_mass_layer_8=0.0175 cross_arm_attention_mass_layer_9=0.0157 cross_arm_attention_mass_layer_10=0.0148 cross_arm_attention_mass_layer_11=0.0181 cross_arm_attention_mass_layer_12=0.0023 cross_arm_attention_mass_layer_13=0.0128 cross_arm_attention_mass_layer_14=0.0072 cross_arm_attention_mass_layer_15=0.0232 cross_arm_attention_mass_layer_16=0.0437 cross_arm_attention_mass_layer_17=0.0451 (22938:train_pytorch.py:862)
126
+ 20:12:08.937 [I] step=5 loss=1.8898 smoothed_loss=3.3081 lr=4.98e-07 grad_norm=20.1429 step_time=0.4721s data_time=0.2048s it/s=1.481 eta_to_20=10.1s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0020 cross_arm_attention_mass_layer_10=0.0148 cross_arm_attention_mass_layer_11=0.0181 cross_arm_attention_mass_layer_12=0.0023 cross_arm_attention_mass_layer_13=0.0128 cross_arm_attention_mass_layer_14=0.0072 cross_arm_attention_mass_layer_15=0.0232 cross_arm_attention_mass_layer_16=0.0437 cross_arm_attention_mass_layer_17=0.0451 cross_arm_attention_mass_layer_2=0.0178 cross_arm_attention_mass_layer_3=0.0039 cross_arm_attention_mass_layer_4=0.0203 cross_arm_attention_mass_layer_5=0.0294 cross_arm_attention_mass_layer_6=0.0106 cross_arm_attention_mass_layer_7=0.0286 cross_arm_attention_mass_layer_8=0.0175 cross_arm_attention_mass_layer_9=0.0157 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=6.0934 grad_cross_arm_comm=0.0735 grad_left_action_in=0.2534 grad_left_expert=7.9047 grad_right_action_in=0.3335 grad_right_expert=3.6853 grad_shared_backbone=16.7899 (22938:train_pytorch.py:882)
127
+
128
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
129
+ 20:12:09.727 [I] step=6 loss=2.2855 smoothed_loss=3.2058 lr=5.98e-07 grad_norm=22.2605 step_time=0.5043s data_time=0.2901s it/s=1.267 eta_to_20=11.1s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0030 cross_arm_attention_mass_layer_10=0.0179 cross_arm_attention_mass_layer_11=0.0219 cross_arm_attention_mass_layer_12=0.0017 cross_arm_attention_mass_layer_13=0.0164 cross_arm_attention_mass_layer_14=0.0065 cross_arm_attention_mass_layer_15=0.0300 cross_arm_attention_mass_layer_16=0.0448 cross_arm_attention_mass_layer_17=0.0482 cross_arm_attention_mass_layer_2=0.0201 cross_arm_attention_mass_layer_3=0.0064 cross_arm_attention_mass_layer_4=0.0234 cross_arm_attention_mass_layer_5=0.0308 cross_arm_attention_mass_layer_6=0.0131 cross_arm_attention_mass_layer_7=0.0312 cross_arm_attention_mass_layer_8=0.0206 cross_arm_attention_mass_layer_9=0.0180 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=7.8420 grad_cross_arm_comm=0.1508 grad_left_action_in=0.2907 grad_left_expert=7.9865 grad_right_action_in=0.5407 grad_right_expert=5.3887 grad_shared_backbone=18.0209 (22938:train_pytorch.py:882)
130
+
131
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
132
+ 20:12:10.423 [I] step=7 loss=1.0335 smoothed_loss=2.9886 lr=6.97e-07 grad_norm=8.7208 step_time=0.4962s data_time=0.1999s it/s=1.439 eta_to_20=9.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0014 cross_arm_attention_mass_layer_10=0.0066 cross_arm_attention_mass_layer_11=0.0060 cross_arm_attention_mass_layer_12=0.0024 cross_arm_attention_mass_layer_13=0.0015 cross_arm_attention_mass_layer_14=0.0062 cross_arm_attention_mass_layer_15=0.0146 cross_arm_attention_mass_layer_16=0.0319 cross_arm_attention_mass_layer_17=0.0417 cross_arm_attention_mass_layer_2=0.0105 cross_arm_attention_mass_layer_3=0.0022 cross_arm_attention_mass_layer_4=0.0130 cross_arm_attention_mass_layer_5=0.0188 cross_arm_attention_mass_layer_6=0.0045 cross_arm_attention_mass_layer_7=0.0127 cross_arm_attention_mass_layer_8=0.0097 cross_arm_attention_mass_layer_9=0.0097 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=4.0753 grad_cross_arm_comm=0.0098 grad_left_action_in=0.1514 grad_left_expert=2.5886 grad_right_action_in=0.0879 grad_right_expert=1.9729 grad_shared_backbone=6.8576 (22938:train_pytorch.py:882)
133
+
134
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
135
+ 20:12:11.020 [I] step=8 loss=2.0034 smoothed_loss=2.8901 lr=7.97e-07 grad_norm=15.7969 step_time=0.4407s data_time=0.1564s it/s=1.677 eta_to_20=7.2s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0027 cross_arm_attention_mass_layer_10=0.0129 cross_arm_attention_mass_layer_11=0.0269 cross_arm_attention_mass_layer_12=0.0032 cross_arm_attention_mass_layer_13=0.0177 cross_arm_attention_mass_layer_14=0.0074 cross_arm_attention_mass_layer_15=0.0309 cross_arm_attention_mass_layer_16=0.0446 cross_arm_attention_mass_layer_17=0.0503 cross_arm_attention_mass_layer_2=0.0196 cross_arm_attention_mass_layer_3=0.0046 cross_arm_attention_mass_layer_4=0.0227 cross_arm_attention_mass_layer_5=0.0319 cross_arm_attention_mass_layer_6=0.0114 cross_arm_attention_mass_layer_7=0.0298 cross_arm_attention_mass_layer_8=0.0194 cross_arm_attention_mass_layer_9=0.0117 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=6.6005 grad_cross_arm_comm=0.1531 grad_left_action_in=0.1726 grad_left_expert=4.6426 grad_right_action_in=0.4530 grad_right_expert=3.8705 grad_shared_backbone=12.4324 (22938:train_pytorch.py:882)
136
+
137
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
138
+ 20:12:11.571 [I] step=9 loss=0.4132 smoothed_loss=2.6424 lr=8.96e-07 grad_norm=3.3497 step_time=0.4161s data_time=0.1347s it/s=1.820 eta_to_20=6.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0008 cross_arm_attention_mass_layer_10=0.0014 cross_arm_attention_mass_layer_11=0.0006 cross_arm_attention_mass_layer_12=0.0028 cross_arm_attention_mass_layer_13=0.0018 cross_arm_attention_mass_layer_14=0.0059 cross_arm_attention_mass_layer_15=0.0078 cross_arm_attention_mass_layer_16=0.0337 cross_arm_attention_mass_layer_17=0.0442 cross_arm_attention_mass_layer_2=0.0015 cross_arm_attention_mass_layer_3=0.0012 cross_arm_attention_mass_layer_4=0.0019 cross_arm_attention_mass_layer_5=0.0036 cross_arm_attention_mass_layer_6=0.0013 cross_arm_attention_mass_layer_7=0.0022 cross_arm_attention_mass_layer_8=0.0006 cross_arm_attention_mass_layer_9=0.0052 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=1.7915 grad_cross_arm_comm=0.0012 grad_left_action_in=0.0692 grad_left_expert=1.0033 grad_right_action_in=0.0554 grad_right_expert=0.7293 grad_shared_backbone=2.5249 (22938:train_pytorch.py:882)
139
+
140
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
141
+ 20:12:12.422 [I] step=10 loss=0.6162 smoothed_loss=2.4397 lr=9.96e-07 grad_norm=5.5674 step_time=0.6599s data_time=0.1905s it/s=1.178 eta_to_20=8.5s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0014 cross_arm_attention_mass_layer_10=0.0024 cross_arm_attention_mass_layer_11=0.0047 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0018 cross_arm_attention_mass_layer_14=0.0062 cross_arm_attention_mass_layer_15=0.0094 cross_arm_attention_mass_layer_16=0.0283 cross_arm_attention_mass_layer_17=0.0357 cross_arm_attention_mass_layer_2=0.0074 cross_arm_attention_mass_layer_3=0.0016 cross_arm_attention_mass_layer_4=0.0081 cross_arm_attention_mass_layer_5=0.0156 cross_arm_attention_mass_layer_6=0.0028 cross_arm_attention_mass_layer_7=0.0050 cross_arm_attention_mass_layer_8=0.0040 cross_arm_attention_mass_layer_9=0.0045 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=2.2079 grad_cross_arm_comm=0.0071 grad_left_action_in=0.0841 grad_left_expert=1.2018 grad_right_action_in=0.0868 grad_right_expert=1.2814 grad_shared_backbone=4.7763 (22938:train_pytorch.py:882)
142
+
143
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
144
+ 20:12:12.957 [I] step=11 loss=0.9030 smoothed_loss=2.2861 lr=1.10e-06 grad_norm=7.2282 step_time=0.4104s data_time=0.1251s it/s=1.872 eta_to_20=4.8s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0015 cross_arm_attention_mass_layer_10=0.0064 cross_arm_attention_mass_layer_11=0.0098 cross_arm_attention_mass_layer_12=0.0013 cross_arm_attention_mass_layer_13=0.0031 cross_arm_attention_mass_layer_14=0.0072 cross_arm_attention_mass_layer_15=0.0208 cross_arm_attention_mass_layer_16=0.0355 cross_arm_attention_mass_layer_17=0.0421 cross_arm_attention_mass_layer_2=0.0136 cross_arm_attention_mass_layer_3=0.0023 cross_arm_attention_mass_layer_4=0.0152 cross_arm_attention_mass_layer_5=0.0219 cross_arm_attention_mass_layer_6=0.0054 cross_arm_attention_mass_layer_7=0.0144 cross_arm_attention_mass_layer_8=0.0131 cross_arm_attention_mass_layer_9=0.0082 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=3.3357 grad_cross_arm_comm=0.0099 grad_left_action_in=0.1355 grad_left_expert=2.0379 grad_right_action_in=0.0836 grad_right_expert=1.1722 grad_shared_backbone=5.8293 (22938:train_pytorch.py:882)
145
+
146
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
147
+ 20:12:13.628 [I] step=12 loss=0.7531 smoothed_loss=2.1328 lr=1.20e-06 grad_norm=6.0473 step_time=0.4968s data_time=0.1739s it/s=1.493 eta_to_20=5.4s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0012 cross_arm_attention_mass_layer_10=0.0078 cross_arm_attention_mass_layer_11=0.0121 cross_arm_attention_mass_layer_12=0.0032 cross_arm_attention_mass_layer_13=0.0032 cross_arm_attention_mass_layer_14=0.0048 cross_arm_attention_mass_layer_15=0.0136 cross_arm_attention_mass_layer_16=0.0331 cross_arm_attention_mass_layer_17=0.0404 cross_arm_attention_mass_layer_2=0.0127 cross_arm_attention_mass_layer_3=0.0020 cross_arm_attention_mass_layer_4=0.0138 cross_arm_attention_mass_layer_5=0.0221 cross_arm_attention_mass_layer_6=0.0055 cross_arm_attention_mass_layer_7=0.0174 cross_arm_attention_mass_layer_8=0.0100 cross_arm_attention_mass_layer_9=0.0094 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=2.8673 grad_cross_arm_comm=0.0090 grad_left_action_in=0.1128 grad_left_expert=1.8561 grad_right_action_in=0.0739 grad_right_expert=1.0243 grad_shared_backbone=4.8443 (22938:train_pytorch.py:882)
148
+
149
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
150
+ 20:12:14.427 [I] step=13 loss=3.7746 smoothed_loss=2.2970 lr=1.29e-06 grad_norm=206.8044 step_time=0.5601s data_time=0.2394s it/s=1.252 eta_to_20=5.6s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0003 cross_arm_attention_mass_layer_1=0.0128 cross_arm_attention_mass_layer_10=0.0240 cross_arm_attention_mass_layer_11=0.0241 cross_arm_attention_mass_layer_12=0.0213 cross_arm_attention_mass_layer_13=0.0213 cross_arm_attention_mass_layer_14=0.0164 cross_arm_attention_mass_layer_15=0.0265 cross_arm_attention_mass_layer_16=0.0367 cross_arm_attention_mass_layer_17=0.0289 cross_arm_attention_mass_layer_2=0.0282 cross_arm_attention_mass_layer_3=0.0184 cross_arm_attention_mass_layer_4=0.0365 cross_arm_attention_mass_layer_5=0.0441 cross_arm_attention_mass_layer_6=0.0238 cross_arm_attention_mass_layer_7=0.0371 cross_arm_attention_mass_layer_8=0.0137 cross_arm_attention_mass_layer_9=0.0293 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=15.4957 grad_cross_arm_comm=2.1022 grad_left_action_in=2.3745 grad_left_expert=37.1536 grad_right_action_in=5.2568 grad_right_expert=138.8291 grad_shared_backbone=127.7336 (22938:train_pytorch.py:882)
151
+
152
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
153
+ 20:12:15.255 [I] step=14 loss=1.2933 smoothed_loss=2.1966 lr=1.39e-06 grad_norm=7.9182 step_time=0.5738s data_time=0.2541s it/s=1.210 eta_to_20=5.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0015 cross_arm_attention_mass_layer_10=0.0079 cross_arm_attention_mass_layer_11=0.0120 cross_arm_attention_mass_layer_12=0.0016 cross_arm_attention_mass_layer_13=0.0036 cross_arm_attention_mass_layer_14=0.0047 cross_arm_attention_mass_layer_15=0.0131 cross_arm_attention_mass_layer_16=0.0244 cross_arm_attention_mass_layer_17=0.0419 cross_arm_attention_mass_layer_2=0.0129 cross_arm_attention_mass_layer_3=0.0022 cross_arm_attention_mass_layer_4=0.0152 cross_arm_attention_mass_layer_5=0.0233 cross_arm_attention_mass_layer_6=0.0067 cross_arm_attention_mass_layer_7=0.0161 cross_arm_attention_mass_layer_8=0.0092 cross_arm_attention_mass_layer_9=0.0097 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=4.2052 grad_cross_arm_comm=0.0107 grad_left_action_in=0.1570 grad_left_expert=2.3411 grad_right_action_in=0.1025 grad_right_expert=1.1691 grad_shared_backbone=6.0836 (22938:train_pytorch.py:882)
154
+
155
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
156
+ 20:12:16.034 [I] step=15 loss=3.1068 smoothed_loss=2.2876 lr=1.49e-06 grad_norm=24.4182 step_time=0.5474s data_time=0.2314s it/s=1.286 eta_to_20=3.9s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0033 cross_arm_attention_mass_layer_10=0.0154 cross_arm_attention_mass_layer_11=0.0284 cross_arm_attention_mass_layer_12=0.0046 cross_arm_attention_mass_layer_13=0.0187 cross_arm_attention_mass_layer_14=0.0121 cross_arm_attention_mass_layer_15=0.0370 cross_arm_attention_mass_layer_16=0.0460 cross_arm_attention_mass_layer_17=0.0516 cross_arm_attention_mass_layer_2=0.0206 cross_arm_attention_mass_layer_3=0.0064 cross_arm_attention_mass_layer_4=0.0239 cross_arm_attention_mass_layer_5=0.0299 cross_arm_attention_mass_layer_6=0.0143 cross_arm_attention_mass_layer_7=0.0349 cross_arm_attention_mass_layer_8=0.0213 cross_arm_attention_mass_layer_9=0.0171 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=9.3484 grad_cross_arm_comm=0.3843 grad_left_action_in=0.3015 grad_left_expert=7.0086 grad_right_action_in=0.6660 grad_right_expert=6.4185 grad_shared_backbone=18.8039 (22938:train_pytorch.py:882)
157
+
158
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
159
+ 20:12:16.810 [I] step=16 loss=0.8710 smoothed_loss=2.1460 lr=1.59e-06 grad_norm=7.5162 step_time=0.5638s data_time=0.2117s it/s=1.292 eta_to_20=3.1s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0016 cross_arm_attention_mass_layer_10=0.0051 cross_arm_attention_mass_layer_11=0.0114 cross_arm_attention_mass_layer_12=0.0017 cross_arm_attention_mass_layer_13=0.0062 cross_arm_attention_mass_layer_14=0.0073 cross_arm_attention_mass_layer_15=0.0221 cross_arm_attention_mass_layer_16=0.0370 cross_arm_attention_mass_layer_17=0.0436 cross_arm_attention_mass_layer_2=0.0138 cross_arm_attention_mass_layer_3=0.0022 cross_arm_attention_mass_layer_4=0.0152 cross_arm_attention_mass_layer_5=0.0195 cross_arm_attention_mass_layer_6=0.0056 cross_arm_attention_mass_layer_7=0.0154 cross_arm_attention_mass_layer_8=0.0132 cross_arm_attention_mass_layer_9=0.0103 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=2.7344 grad_cross_arm_comm=0.0228 grad_left_action_in=0.1118 grad_left_expert=2.2761 grad_right_action_in=0.1234 grad_right_expert=1.1808 grad_shared_backbone=6.4124 (22938:train_pytorch.py:882)
160
+
161
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
162
+ 20:12:17.396 [I] step=17 loss=1.7002 smoothed_loss=2.1014 lr=1.69e-06 grad_norm=14.0785 step_time=0.4252s data_time=0.1614s it/s=1.708 eta_to_20=1.8s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0021 cross_arm_attention_mass_layer_10=0.0103 cross_arm_attention_mass_layer_11=0.0197 cross_arm_attention_mass_layer_12=0.0029 cross_arm_attention_mass_layer_13=0.0103 cross_arm_attention_mass_layer_14=0.0067 cross_arm_attention_mass_layer_15=0.0163 cross_arm_attention_mass_layer_16=0.0436 cross_arm_attention_mass_layer_17=0.0446 cross_arm_attention_mass_layer_2=0.0162 cross_arm_attention_mass_layer_3=0.0030 cross_arm_attention_mass_layer_4=0.0192 cross_arm_attention_mass_layer_5=0.0268 cross_arm_attention_mass_layer_6=0.0092 cross_arm_attention_mass_layer_7=0.0242 cross_arm_attention_mass_layer_8=0.0146 cross_arm_attention_mass_layer_9=0.0095 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=4.2605 grad_cross_arm_comm=0.0625 grad_left_action_in=0.1989 grad_left_expert=4.9518 grad_right_action_in=0.2156 grad_right_expert=2.1764 grad_shared_backbone=12.0796 (22938:train_pytorch.py:882)
163
+
164
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
165
+ 20:12:18.392 [I] step=18 loss=0.4844 smoothed_loss=1.9397 lr=1.79e-06 grad_norm=3.3459 step_time=0.6297s data_time=0.3660s it/s=1.005 eta_to_20=2.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0008 cross_arm_attention_mass_layer_10=0.0016 cross_arm_attention_mass_layer_11=0.0014 cross_arm_attention_mass_layer_12=0.0034 cross_arm_attention_mass_layer_13=0.0007 cross_arm_attention_mass_layer_14=0.0054 cross_arm_attention_mass_layer_15=0.0063 cross_arm_attention_mass_layer_16=0.0319 cross_arm_attention_mass_layer_17=0.0418 cross_arm_attention_mass_layer_2=0.0027 cross_arm_attention_mass_layer_3=0.0013 cross_arm_attention_mass_layer_4=0.0035 cross_arm_attention_mass_layer_5=0.0058 cross_arm_attention_mass_layer_6=0.0015 cross_arm_attention_mass_layer_7=0.0028 cross_arm_attention_mass_layer_8=0.0019 cross_arm_attention_mass_layer_9=0.0049 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=-0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=1.9561 grad_cross_arm_comm=0.0017 grad_left_action_in=0.0746 grad_left_expert=1.1140 grad_right_action_in=0.0388 grad_right_expert=0.5290 grad_shared_backbone=2.3985 (22938:train_pytorch.py:882)
166
+
167
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
168
+ 20:12:19.239 [I] step=19 loss=0.7633 smoothed_loss=1.8220 lr=1.89e-06 grad_norm=7.1468 step_time=0.5757s data_time=0.2714s it/s=1.182 eta_to_20=0.8s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0015 cross_arm_attention_mass_layer_10=0.0069 cross_arm_attention_mass_layer_11=0.0093 cross_arm_attention_mass_layer_12=0.0016 cross_arm_attention_mass_layer_13=0.0034 cross_arm_attention_mass_layer_14=0.0046 cross_arm_attention_mass_layer_15=0.0166 cross_arm_attention_mass_layer_16=0.0297 cross_arm_attention_mass_layer_17=0.0418 cross_arm_attention_mass_layer_2=0.0130 cross_arm_attention_mass_layer_3=0.0026 cross_arm_attention_mass_layer_4=0.0156 cross_arm_attention_mass_layer_5=0.0208 cross_arm_attention_mass_layer_6=0.0062 cross_arm_attention_mass_layer_7=0.0164 cross_arm_attention_mass_layer_8=0.0124 cross_arm_attention_mass_layer_9=0.0115 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=-0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=3.7548 grad_cross_arm_comm=0.0125 grad_left_action_in=0.1160 grad_left_expert=2.3520 grad_right_action_in=0.0799 grad_right_expert=1.3128 grad_shared_backbone=5.3838 (22938:train_pytorch.py:882)
169
+
170
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
171
+ 20:12:19.905 [I] step=20 loss=0.5943 smoothed_loss=1.6993 lr=1.99e-06 grad_norm=6.2792 step_time=0.4954s data_time=0.1707s it/s=1.504 eta_to_20=0.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0012 cross_arm_attention_mass_layer_10=0.0050 cross_arm_attention_mass_layer_11=0.0026 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0009 cross_arm_attention_mass_layer_14=0.0060 cross_arm_attention_mass_layer_15=0.0119 cross_arm_attention_mass_layer_16=0.0308 cross_arm_attention_mass_layer_17=0.0412 cross_arm_attention_mass_layer_2=0.0054 cross_arm_attention_mass_layer_3=0.0020 cross_arm_attention_mass_layer_4=0.0084 cross_arm_attention_mass_layer_5=0.0116 cross_arm_attention_mass_layer_6=0.0029 cross_arm_attention_mass_layer_7=0.0066 cross_arm_attention_mass_layer_8=0.0051 cross_arm_attention_mass_layer_9=0.0094 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=-0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=2.7963 grad_cross_arm_comm=0.0044 grad_left_action_in=0.0861 grad_left_expert=1.9694 grad_right_action_in=0.0578 grad_right_expert=1.3971 grad_shared_backbone=5.0478 (22938:train_pytorch.py:882)
172
+ 20:19:41.020 [I] Saved checkpoint at step 20 -> /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20/20 (22938:train_pytorch.py:378)
173
+
artifacts/twin_split_expert_bringup_20260310/run_logs/split_independent_real_smoke3_r2.log ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 19:45:11.253 [I] Created experiment checkpoint directory: /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2 (20567:train_pytorch.py:533)
2
+ 19:45:11.254 [I] Using batch size per GPU: 1 (total batch size across 1 GPUs: 1) (20567:train_pytorch.py:552)
3
+ 19:45:11.330 [I] Loaded norm stats from /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/lsnu/twin_dual_push_128_train (20567:config.py:234)
4
+ 19:45:11.331 [I] data_config: DataConfig(repo_id='lsnu/twin_dual_push_128_train', asset_id='lsnu/twin_dual_push_128_train', norm_stats={'state': NormStats(mean=array([ 0.10604009, 0.20956482, 0.09184283, -1.98801565, -0.04930164,
5
+ 2.20065784, 1.07595289, 0.52742052, 0.01585805, 0.08288047,
6
+ -0.06887393, -1.906394 , 0.04810138, 2.01086807, -0.92902797,
7
+ 0.8440811 ]), std=array([0.09207697, 0.31317395, 0.08127229, 0.53812712, 0.06093267,
8
+ 0.51205784, 0.22527155, 0.49924755, 0.20230208, 0.31408131,
9
+ 0.21665592, 0.5264315 , 0.20170984, 0.4745712 , 1.17861438,
10
+ 0.36277843]), q01=array([-5.00321221e-06, -3.88026012e-01, -2.23782954e-05, -2.98962682e+00,
11
+ -2.38592355e-01, 1.22146201e+00, 7.85383821e-01, 0.00000000e+00,
12
+ -6.15615927e-01, -4.14941930e-01, -9.43696350e-01, -2.88397729e+00,
13
+ -9.05083556e-01, 1.22148895e+00, -2.79564499e+00, 0.00000000e+00]), q99=array([ 0.31251293, 0.86546916, 0.35174239, -0.87634897, 0.05212194,
14
+ 2.97208117, 1.64465171, 0.9998 , 0.7670313 , 0.96073459,
15
+ 0.68710467, -0.87498123, 0.35838486, 2.9773227 , 0.78477909,
16
+ 0.9998 ])), 'actions': NormStats(mean=array([ 0.03630241, 0.09624442, 0.01367408, -0.2224988 , -0.02762174,
17
+ 0.27498844, 0.0892187 , 0.45650524, -0.00378086, 0.09113847,
18
+ -0.00376227, -0.22537093, 0.00826233, 0.26799494, -0.57452869,
19
+ 0.7731654 ]), std=array([0.04995174, 0.29268014, 0.06852161, 0.3647725 , 0.07012808,
20
+ 0.27129024, 0.11329207, 0.4981046 , 0.0917461 , 0.22704004,
21
+ 0.1069391 , 0.2572591 , 0.11801817, 0.1235588 , 0.35835782,
22
+ 0.41878474]), q01=array([-5.86206436e-04, -3.88117499e-01, -2.55800724e-01, -8.34769463e-01,
23
+ -3.51454727e-01, -1.54787922e-03, -5.81741333e-04, 0.00000000e+00,
24
+ -2.64436970e-01, -3.51582764e-01, -3.69693995e-01, -7.30919549e-01,
25
+ -3.35441585e-01, -6.62303925e-04, -9.34731126e-01, 0.00000000e+00]), q99=array([0.20790743, 0.81198567, 0.19612836, 0.33958174, 0.05568643,
26
+ 0.75265345, 0.425256 , 0.9998 , 0.2558236 , 0.58901345,
27
+ 0.35822071, 0.18567593, 0.44035054, 0.49966629, 0.12655233,
28
+ 0.9998 ]))}, repack_transforms=Group(inputs=[RepackTransform(structure={'images': {'cam_high': 'front_image', 'cam_left_wrist': 'wrist_left_image', 'cam_right_wrist': 'wrist_right_image'}, 'state': 'state', 'actions': 'action', 'prompt': 'task'})], outputs=()), data_transforms=Group(inputs=[AlohaInputs(adapt_to_pi=False)], outputs=[]), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x79458ad85b50>, discrete_state_input=True), PackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))], outputs=[UnpackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))]), use_quantile_norm=True, action_sequence_keys=('action',), prompt_from_task=False, rlds_data_dir=None, action_space=None, datasets=()) (20567:data_loader.py:284)
29
+ 19:45:16.791 [I] JAX version 0.5.3 available. (20567:config.py:125)
30
+ 19:45:40.542 [I] Using existing local LeRobot dataset mirror for lsnu/twin_dual_push_128_train: /workspace/lerobot/lsnu/twin_dual_push_128_train (20567:data_loader.py:148)
31
+ 19:45:40.654 [W] 'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder (20567:video_utils.py:36)
32
+ 19:46:47.372 [I] local_batch_size: 1 (20567:data_loader.py:365)
33
+ 19:50:09.799 [I] Enabled gradient checkpointing for PI0Pytorch model (20567:pi0_pytorch.py:138)
34
+ 19:50:09.802 [I] Enabled gradient checkpointing for memory optimization (20567:train_pytorch.py:624)
35
+ 19:50:09.803 [I] Step 0 (after_model_creation): GPU memory - allocated: 17.23GB, reserved: 17.23GB, free: 0.00GB, peak_allocated: 17.23GB, peak_reserved: 17.23GB (20567:train_pytorch.py:493)
36
+ 19:50:09.804 [I] Loading weights from: /workspace/checkpoints/pi05_base_split_independent_packed_from_single (20567:train_pytorch.py:653)
37
+ 19:50:13.559 [I] Weight loading missing key count: 0 (20567:train_pytorch.py:657)
38
+ 19:50:13.560 [I] Weight loading missing keys: set() (20567:train_pytorch.py:658)
39
+ 19:50:13.560 [I] Weight loading unexpected key count: 0 (20567:train_pytorch.py:659)
40
+ 19:50:13.560 [I] Weight loading unexpected keys: [] (20567:train_pytorch.py:660)
41
+ 19:50:13.560 [I] Loaded PyTorch weights from /workspace/checkpoints/pi05_base_split_independent_packed_from_single (20567:train_pytorch.py:661)
42
+ 19:50:13.565 [I] Running on: 963c158043aa | world_size=1 (20567:train_pytorch.py:701)
43
+ 19:50:13.565 [I] Training config: batch_size=1, effective_batch_size=1, num_train_steps=3 (20567:train_pytorch.py:702)
44
+ 19:50:13.565 [I] Memory optimizations: gradient_checkpointing=True (20567:train_pytorch.py:705)
45
+ 19:50:13.566 [I] DDP settings: find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True (20567:train_pytorch.py:706)
46
+ 19:50:13.566 [I] LR schedule: warmup=250, peak_lr=2.50e-05, decay_steps=5000, end_lr=2.50e-06 (20567:train_pytorch.py:707)
47
+ 19:50:13.567 [I] Optimizer: AdamW, weight_decay=1e-10, clip_norm=1.0 (20567:train_pytorch.py:710)
48
+ 19:50:13.567 [I] EMA is not supported for PyTorch training (20567:train_pytorch.py:713)
49
+ 19:50:13.567 [I] Training precision: float32 (20567:train_pytorch.py:714)
50
+ 19:50:13.576 [I] Resolved config name: pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k (20567:train_pytorch.py:308)
51
+ 19:50:13.576 [I] Dataset repo_id: lsnu/twin_dual_push_128_train (20567:train_pytorch.py:309)
52
+ 19:50:13.577 [I] Norm-stats file path: /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json (20567:train_pytorch.py:310)
53
+ 19:50:13.577 [I] Norm-stats summary: {'keys': ['actions', 'state'], 'state_mean_len': 16, 'state_std_len': 16, 'actions_mean_len': 16, 'actions_std_len': 16} (20567:train_pytorch.py:311)
54
+ 19:50:13.578 [I] Checkpoint source path: /workspace/checkpoints/pi05_base_split_independent_packed_from_single (20567:train_pytorch.py:312)
55
+ 19:50:13.578 [I] Model type: split_independent (20567:train_pytorch.py:313)
56
+ 19:50:13.578 [I] Packed transforms active: True (20567:train_pytorch.py:314)
57
+ 19:50:13.579 [I] World size: 1 (20567:train_pytorch.py:315)
58
+ 19:50:13.579 [I] Batch size: local=1, global=1 (20567:train_pytorch.py:316)
59
+ 19:50:13.580 [I] num_workers: 0 (20567:train_pytorch.py:317)
60
+ 19:50:13.580 [I] Precision: float32 (20567:train_pytorch.py:318)
61
+ 19:50:13.580 [I] LR schedule summary: warmup_steps=250, peak_lr=2.50e-05, decay_steps=5000, decay_lr=2.50e-06 (20567:train_pytorch.py:319)
62
+ 19:50:13.581 [I] Save/log intervals: save_interval=3, log_interval=1 (20567:train_pytorch.py:326)
63
+ 19:50:13.581 [I] Action-loss mask: (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (20567:train_pytorch.py:327)
64
+ 19:50:13.581 [I] Active mask dims: [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] (20567:train_pytorch.py:328)
65
+ 19:50:13.582 [I] Masked dims: [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] (20567:train_pytorch.py:329)
66
+ 19:50:13.582 [I] Gradient bucket diagnostics: left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm (20567:train_pytorch.py:722)
67
+
68
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
69
+ 19:50:15.125 [I] debug_step=1 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (20567:train_pytorch.py:831)
70
+ 19:50:15.126 [I] debug_step=1 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (20567:train_pytorch.py:835)
71
+ 19:50:15.126 [I] debug_step=1 prompt_token_lengths=[75] (20567:train_pytorch.py:838)
72
+ 19:50:15.127 [I] debug_step=1 state_stats min=-1.0000 max=1.0004 mean=0.0112 std=0.3876 (20567:train_pytorch.py:839)
73
+ 19:50:15.127 [I] debug_step=1 action_stats min=-1.0016 max=1.0004 mean=-0.0454 std=0.4716 (20567:train_pytorch.py:842)
74
+ 19:50:15.128 [I] debug_step=1 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (20567:train_pytorch.py:845)
75
+ 19:50:15.143 [I] debug_step=1 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (20567:train_pytorch.py:849)
76
+ 19:50:15.143 [I] debug_step=1 lr=9.96e-08 grad_norm=31.4779 data_time=0.2101s step_time=1.2943s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.25GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.25GB (20567:train_pytorch.py:854)
77
+ 19:50:15.144 [I] debug_step=1 grad_shared_backbone=25.5606 grad_left_action_in=0.2318 grad_right_action_in=0.9885 grad_left_expert=5.5978 grad_right_expert=12.3518 grad_action_out=9.6154 (20567:train_pytorch.py:862)
78
+ 19:50:15.144 [I] step=1 loss=2.6238 smoothed_loss=2.6238 lr=9.96e-08 grad_norm=31.4779 step_time=1.2943s data_time=0.2101s it/s=0.633 eta_to_3=3.2s max_cuda_memory=76.13GB grad_action_out=9.6154 grad_left_action_in=0.2318 grad_left_expert=5.5978 grad_right_action_in=0.9885 grad_right_expert=12.3518 grad_shared_backbone=25.5606 (20567:train_pytorch.py:882)
79
+
80
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
81
+ 19:50:16.012 [I] debug_step=2 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (20567:train_pytorch.py:831)
82
+ 19:50:16.013 [I] debug_step=2 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (20567:train_pytorch.py:835)
83
+ 19:50:16.013 [I] debug_step=2 prompt_token_lengths=[76] (20567:train_pytorch.py:838)
84
+ 19:50:16.014 [I] debug_step=2 state_stats min=-0.9415 max=1.0004 mean=-0.0010 std=0.4295 (20567:train_pytorch.py:839)
85
+ 19:50:16.015 [I] debug_step=2 action_stats min=-1.0000 max=1.1367 mean=0.0272 std=0.4576 (20567:train_pytorch.py:842)
86
+ 19:50:16.016 [I] debug_step=2 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (20567:train_pytorch.py:845)
87
+ 19:50:16.016 [I] debug_step=2 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (20567:train_pytorch.py:849)
88
+ 19:50:16.017 [I] debug_step=2 lr=1.99e-07 grad_norm=12.2770 data_time=0.2123s step_time=0.6695s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (20567:train_pytorch.py:854)
89
+ 19:50:16.017 [I] debug_step=2 grad_shared_backbone=10.3527 grad_left_action_in=0.1586 grad_right_action_in=0.1584 grad_left_expert=2.8415 grad_right_expert=4.0156 grad_action_out=4.1478 (20567:train_pytorch.py:862)
90
+ 19:50:16.018 [I] step=2 loss=1.1717 smoothed_loss=2.4786 lr=1.99e-07 grad_norm=12.2770 step_time=0.6695s data_time=0.2123s it/s=1.146 eta_to_3=0.9s max_cuda_memory=76.13GB grad_action_out=4.1478 grad_left_action_in=0.1586 grad_left_expert=2.8415 grad_right_action_in=0.1584 grad_right_expert=4.0156 grad_shared_backbone=10.3527 (20567:train_pytorch.py:882)
91
+
92
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
93
+ 19:50:16.906 [I] debug_step=3 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (20567:train_pytorch.py:831)
94
+ 19:50:16.907 [I] debug_step=3 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (20567:train_pytorch.py:835)
95
+ 19:50:16.908 [I] debug_step=3 prompt_token_lengths=[75] (20567:train_pytorch.py:838)
96
+ 19:50:16.908 [I] debug_step=3 state_stats min=-1.0000 max=1.0004 mean=0.0558 std=0.4300 (20567:train_pytorch.py:839)
97
+ 19:50:16.908 [I] debug_step=3 action_stats min=-1.0033 max=1.0004 mean=-0.0658 std=0.4704 (20567:train_pytorch.py:842)
98
+ 19:50:16.909 [I] debug_step=3 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (20567:train_pytorch.py:845)
99
+ 19:50:16.910 [I] debug_step=3 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (20567:train_pytorch.py:849)
100
+ 19:50:16.910 [I] debug_step=3 lr=2.99e-07 grad_norm=15.1079 data_time=0.2612s step_time=0.6330s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (20567:train_pytorch.py:854)
101
+ 19:50:16.911 [I] debug_step=3 grad_shared_backbone=8.6850 grad_left_action_in=0.2570 grad_right_action_in=0.3869 grad_left_expert=4.4422 grad_right_expert=10.5777 grad_action_out=3.5502 (20567:train_pytorch.py:862)
102
+ 19:50:16.911 [I] step=3 loss=0.9128 smoothed_loss=2.3220 lr=2.99e-07 grad_norm=15.1079 step_time=0.6330s data_time=0.2612s it/s=1.120 eta_to_3=0.0s max_cuda_memory=76.13GB grad_action_out=3.5502 grad_left_action_in=0.2570 grad_left_expert=4.4422 grad_right_action_in=0.3869 grad_right_expert=10.5777 grad_shared_backbone=8.6850 (20567:train_pytorch.py:882)
103
+ 19:53:54.052 [I] Saved checkpoint at step 3 -> /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2/3 (20567:train_pytorch.py:378)
104
+
artifacts/twin_split_expert_bringup_20260310/run_logs/split_independent_real_train20.log ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 20:03:03.080 [I] Created experiment checkpoint directory: /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20 (22934:train_pytorch.py:533)
2
+ 20:03:03.082 [I] Using batch size per GPU: 1 (total batch size across 1 GPUs: 1) (22934:train_pytorch.py:552)
3
+ 20:03:03.183 [I] Loaded norm stats from /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/lsnu/twin_dual_push_128_train (22934:config.py:234)
4
+ 20:03:03.185 [I] data_config: DataConfig(repo_id='lsnu/twin_dual_push_128_train', asset_id='lsnu/twin_dual_push_128_train', norm_stats={'state': NormStats(mean=array([ 0.10604009, 0.20956482, 0.09184283, -1.98801565, -0.04930164,
5
+ 2.20065784, 1.07595289, 0.52742052, 0.01585805, 0.08288047,
6
+ -0.06887393, -1.906394 , 0.04810138, 2.01086807, -0.92902797,
7
+ 0.8440811 ]), std=array([0.09207697, 0.31317395, 0.08127229, 0.53812712, 0.06093267,
8
+ 0.51205784, 0.22527155, 0.49924755, 0.20230208, 0.31408131,
9
+ 0.21665592, 0.5264315 , 0.20170984, 0.4745712 , 1.17861438,
10
+ 0.36277843]), q01=array([-5.00321221e-06, -3.88026012e-01, -2.23782954e-05, -2.98962682e+00,
11
+ -2.38592355e-01, 1.22146201e+00, 7.85383821e-01, 0.00000000e+00,
12
+ -6.15615927e-01, -4.14941930e-01, -9.43696350e-01, -2.88397729e+00,
13
+ -9.05083556e-01, 1.22148895e+00, -2.79564499e+00, 0.00000000e+00]), q99=array([ 0.31251293, 0.86546916, 0.35174239, -0.87634897, 0.05212194,
14
+ 2.97208117, 1.64465171, 0.9998 , 0.7670313 , 0.96073459,
15
+ 0.68710467, -0.87498123, 0.35838486, 2.9773227 , 0.78477909,
16
+ 0.9998 ])), 'actions': NormStats(mean=array([ 0.03630241, 0.09624442, 0.01367408, -0.2224988 , -0.02762174,
17
+ 0.27498844, 0.0892187 , 0.45650524, -0.00378086, 0.09113847,
18
+ -0.00376227, -0.22537093, 0.00826233, 0.26799494, -0.57452869,
19
+ 0.7731654 ]), std=array([0.04995174, 0.29268014, 0.06852161, 0.3647725 , 0.07012808,
20
+ 0.27129024, 0.11329207, 0.4981046 , 0.0917461 , 0.22704004,
21
+ 0.1069391 , 0.2572591 , 0.11801817, 0.1235588 , 0.35835782,
22
+ 0.41878474]), q01=array([-5.86206436e-04, -3.88117499e-01, -2.55800724e-01, -8.34769463e-01,
23
+ -3.51454727e-01, -1.54787922e-03, -5.81741333e-04, 0.00000000e+00,
24
+ -2.64436970e-01, -3.51582764e-01, -3.69693995e-01, -7.30919549e-01,
25
+ -3.35441585e-01, -6.62303925e-04, -9.34731126e-01, 0.00000000e+00]), q99=array([0.20790743, 0.81198567, 0.19612836, 0.33958174, 0.05568643,
26
+ 0.75265345, 0.425256 , 0.9998 , 0.2558236 , 0.58901345,
27
+ 0.35822071, 0.18567593, 0.44035054, 0.49966629, 0.12655233,
28
+ 0.9998 ]))}, repack_transforms=Group(inputs=[RepackTransform(structure={'images': {'cam_high': 'front_image', 'cam_left_wrist': 'wrist_left_image', 'cam_right_wrist': 'wrist_right_image'}, 'state': 'state', 'actions': 'action', 'prompt': 'task'})], outputs=()), data_transforms=Group(inputs=[AlohaInputs(adapt_to_pi=False)], outputs=[]), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x721cdf0dd610>, discrete_state_input=True), PackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))], outputs=[UnpackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))]), use_quantile_norm=True, action_sequence_keys=('action',), prompt_from_task=False, rlds_data_dir=None, action_space=None, datasets=()) (22934:data_loader.py:393)
29
+ 20:03:13.494 [I] JAX version 0.5.3 available. (22934:config.py:125)
30
+ 20:04:17.801 [I] Using existing local LeRobot dataset mirror for lsnu/twin_dual_push_128_train: /workspace/lerobot/lsnu/twin_dual_push_128_train (22934:data_loader.py:148)
31
+ 20:04:17.904 [W] 'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder (22934:video_utils.py:36)
32
+ 20:09:04.645 [I] local_batch_size: 1 (22934:data_loader.py:474)
33
+ 20:11:56.606 [I] Enabled gradient checkpointing for PI0Pytorch model (22934:pi0_pytorch.py:138)
34
+ 20:11:56.607 [I] Enabled gradient checkpointing for memory optimization (22934:train_pytorch.py:624)
35
+ 20:11:56.608 [I] Step 0 (after_model_creation): GPU memory - allocated: 17.23GB, reserved: 17.23GB, free: 0.00GB, peak_allocated: 17.23GB, peak_reserved: 17.23GB (22934:train_pytorch.py:493)
36
+ 20:11:56.609 [I] Loading weights from: /workspace/checkpoints/pi05_base_split_independent_packed_from_single (22934:train_pytorch.py:653)
37
+ 20:12:01.374 [I] Weight loading missing key count: 0 (22934:train_pytorch.py:657)
38
+ 20:12:01.375 [I] Weight loading missing keys: set() (22934:train_pytorch.py:658)
39
+ 20:12:01.375 [I] Weight loading unexpected key count: 0 (22934:train_pytorch.py:659)
40
+ 20:12:01.375 [I] Weight loading unexpected keys: [] (22934:train_pytorch.py:660)
41
+ 20:12:01.376 [I] Loaded PyTorch weights from /workspace/checkpoints/pi05_base_split_independent_packed_from_single (22934:train_pytorch.py:661)
42
+ 20:12:01.380 [I] Running on: 963c158043aa | world_size=1 (22934:train_pytorch.py:701)
43
+ 20:12:01.381 [I] Training config: batch_size=1, effective_batch_size=1, num_train_steps=20 (22934:train_pytorch.py:702)
44
+ 20:12:01.381 [I] Memory optimizations: gradient_checkpointing=True (22934:train_pytorch.py:705)
45
+ 20:12:01.381 [I] DDP settings: find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True (22934:train_pytorch.py:706)
46
+ 20:12:01.382 [I] LR schedule: warmup=250, peak_lr=2.50e-05, decay_steps=5000, end_lr=2.50e-06 (22934:train_pytorch.py:707)
47
+ 20:12:01.382 [I] Optimizer: AdamW, weight_decay=1e-10, clip_norm=1.0 (22934:train_pytorch.py:710)
48
+ 20:12:01.382 [I] EMA is not supported for PyTorch training (22934:train_pytorch.py:713)
49
+ 20:12:01.383 [I] Training precision: float32 (22934:train_pytorch.py:714)
50
+ 20:12:01.410 [I] Resolved config name: pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k (22934:train_pytorch.py:308)
51
+ 20:12:01.410 [I] Dataset repo_id: lsnu/twin_dual_push_128_train (22934:train_pytorch.py:309)
52
+ 20:12:01.411 [I] Norm-stats file path: /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json (22934:train_pytorch.py:310)
53
+ 20:12:01.411 [I] Norm-stats summary: {'keys': ['actions', 'state'], 'state_mean_len': 16, 'state_std_len': 16, 'actions_mean_len': 16, 'actions_std_len': 16} (22934:train_pytorch.py:311)
54
+ 20:12:01.412 [I] Checkpoint source path: /workspace/checkpoints/pi05_base_split_independent_packed_from_single (22934:train_pytorch.py:312)
55
+ 20:12:01.412 [I] Model type: split_independent (22934:train_pytorch.py:313)
56
+ 20:12:01.412 [I] Packed transforms active: True (22934:train_pytorch.py:314)
57
+ 20:12:01.413 [I] World size: 1 (22934:train_pytorch.py:315)
58
+ 20:12:01.413 [I] Batch size: local=1, global=1 (22934:train_pytorch.py:316)
59
+ 20:12:01.414 [I] num_workers: 0 (22934:train_pytorch.py:317)
60
+ 20:12:01.414 [I] Precision: float32 (22934:train_pytorch.py:318)
61
+ 20:12:01.414 [I] LR schedule summary: warmup_steps=250, peak_lr=2.50e-05, decay_steps=5000, decay_lr=2.50e-06 (22934:train_pytorch.py:319)
62
+ 20:12:01.415 [I] Save/log intervals: save_interval=20, log_interval=1 (22934:train_pytorch.py:326)
63
+ 20:12:01.415 [I] Action-loss mask: (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (22934:train_pytorch.py:327)
64
+ 20:12:01.415 [I] Active mask dims: [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] (22934:train_pytorch.py:328)
65
+ 20:12:01.416 [I] Masked dims: [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] (22934:train_pytorch.py:329)
66
+ 20:12:01.416 [I] Gradient bucket diagnostics: left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm (22934:train_pytorch.py:722)
67
+
68
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
69
+ 20:12:03.701 [I] debug_step=1 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22934:train_pytorch.py:831)
70
+ 20:12:03.702 [I] debug_step=1 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22934:train_pytorch.py:835)
71
+ 20:12:03.702 [I] debug_step=1 prompt_token_lengths=[75] (22934:train_pytorch.py:838)
72
+ 20:12:03.702 [I] debug_step=1 state_stats min=-1.0000 max=1.0004 mean=0.0112 std=0.3876 (22934:train_pytorch.py:839)
73
+ 20:12:03.702 [I] debug_step=1 action_stats min=-1.0016 max=1.0004 mean=-0.0454 std=0.4716 (22934:train_pytorch.py:842)
74
+ 20:12:03.703 [I] debug_step=1 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22934:train_pytorch.py:845)
75
+ 20:12:03.729 [I] debug_step=1 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22934:train_pytorch.py:849)
76
+ 20:12:03.730 [I] debug_step=1 lr=9.96e-08 grad_norm=31.4779 data_time=0.5472s step_time=1.7166s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.25GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.25GB (22934:train_pytorch.py:854)
77
+ 20:12:03.730 [I] debug_step=1 grad_shared_backbone=25.5606 grad_left_action_in=0.2318 grad_right_action_in=0.9885 grad_left_expert=5.5978 grad_right_expert=12.3518 grad_action_out=9.6154 (22934:train_pytorch.py:862)
78
+ 20:12:03.731 [I] step=1 loss=2.6238 smoothed_loss=2.6238 lr=9.96e-08 grad_norm=31.4779 step_time=1.7166s data_time=0.5472s it/s=0.425 eta_to_20=44.7s max_cuda_memory=76.13GB grad_action_out=9.6154 grad_left_action_in=0.2318 grad_left_expert=5.5978 grad_right_action_in=0.9885 grad_right_expert=12.3518 grad_shared_backbone=25.5606 (22934:train_pytorch.py:882)
79
+
80
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
81
+ 20:12:05.012 [I] debug_step=2 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22934:train_pytorch.py:831)
82
+ 20:12:05.013 [I] debug_step=2 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22934:train_pytorch.py:835)
83
+ 20:12:05.014 [I] debug_step=2 prompt_token_lengths=[76] (22934:train_pytorch.py:838)
84
+ 20:12:05.014 [I] debug_step=2 state_stats min=-0.9415 max=1.0004 mean=-0.0010 std=0.4295 (22934:train_pytorch.py:839)
85
+ 20:12:05.015 [I] debug_step=2 action_stats min=-1.0000 max=1.1367 mean=0.0272 std=0.4576 (22934:train_pytorch.py:842)
86
+ 20:12:05.016 [I] debug_step=2 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22934:train_pytorch.py:845)
87
+ 20:12:05.016 [I] debug_step=2 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22934:train_pytorch.py:849)
88
+ 20:12:05.017 [I] debug_step=2 lr=1.99e-07 grad_norm=12.2749 data_time=0.5381s step_time=0.7692s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (22934:train_pytorch.py:854)
89
+ 20:12:05.017 [I] debug_step=2 grad_shared_backbone=10.3515 grad_left_action_in=0.1585 grad_right_action_in=0.1584 grad_left_expert=2.8412 grad_right_expert=4.0131 grad_action_out=4.1470 (22934:train_pytorch.py:862)
90
+ 20:12:05.018 [I] step=2 loss=1.1715 smoothed_loss=2.4786 lr=1.99e-07 grad_norm=12.2749 step_time=0.7692s data_time=0.5381s it/s=0.777 eta_to_20=23.2s max_cuda_memory=76.13GB grad_action_out=4.1470 grad_left_action_in=0.1585 grad_left_expert=2.8412 grad_right_action_in=0.1584 grad_right_expert=4.0131 grad_shared_backbone=10.3515 (22934:train_pytorch.py:882)
91
+
92
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
93
+ 20:12:05.585 [I] debug_step=3 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22934:train_pytorch.py:831)
94
+ 20:12:05.586 [I] debug_step=3 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22934:train_pytorch.py:835)
95
+ 20:12:05.586 [I] debug_step=3 prompt_token_lengths=[75] (22934:train_pytorch.py:838)
96
+ 20:12:05.586 [I] debug_step=3 state_stats min=-1.0000 max=1.0004 mean=0.0558 std=0.4300 (22934:train_pytorch.py:839)
97
+ 20:12:05.587 [I] debug_step=3 action_stats min=-1.0033 max=1.0004 mean=-0.0658 std=0.4704 (22934:train_pytorch.py:842)
98
+ 20:12:05.588 [I] debug_step=3 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22934:train_pytorch.py:845)
99
+ 20:12:05.588 [I] debug_step=3 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22934:train_pytorch.py:849)
100
+ 20:12:05.589 [I] debug_step=3 lr=2.99e-07 grad_norm=15.1205 data_time=0.1545s step_time=0.4182s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (22934:train_pytorch.py:854)
101
+ 20:12:05.589 [I] debug_step=3 grad_shared_backbone=8.6946 grad_left_action_in=0.2568 grad_right_action_in=0.3873 grad_left_expert=4.4408 grad_right_expert=10.5877 grad_action_out=3.5507 (22934:train_pytorch.py:862)
102
+ 20:12:05.590 [I] step=3 loss=0.9126 smoothed_loss=2.3220 lr=2.99e-07 grad_norm=15.1205 step_time=0.4182s data_time=0.1545s it/s=1.751 eta_to_20=9.7s max_cuda_memory=76.13GB grad_action_out=3.5507 grad_left_action_in=0.2568 grad_left_expert=4.4408 grad_right_action_in=0.3873 grad_right_expert=10.5877 grad_shared_backbone=8.6946 (22934:train_pytorch.py:882)
103
+
104
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
105
+ 20:12:06.414 [I] debug_step=4 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22934:train_pytorch.py:831)
106
+ 20:12:06.415 [I] debug_step=4 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22934:train_pytorch.py:835)
107
+ 20:12:06.416 [I] debug_step=4 prompt_token_lengths=[78] (22934:train_pytorch.py:838)
108
+ 20:12:06.416 [I] debug_step=4 state_stats min=-0.7017 max=1.0004 mean=0.0553 std=0.3507 (22934:train_pytorch.py:839)
109
+ 20:12:06.417 [I] debug_step=4 action_stats min=-1.0014 max=1.0004 mean=-0.0683 std=0.4561 (22934:train_pytorch.py:842)
110
+ 20:12:06.417 [I] debug_step=4 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22934:train_pytorch.py:845)
111
+ 20:12:06.418 [I] debug_step=4 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22934:train_pytorch.py:849)
112
+ 20:12:06.419 [I] debug_step=4 lr=3.98e-07 grad_norm=9.2670 data_time=0.2679s step_time=0.5621s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (22934:train_pytorch.py:854)
113
+ 20:12:06.419 [I] debug_step=4 grad_shared_backbone=7.8629 grad_left_action_in=0.1341 grad_right_action_in=0.0877 grad_left_expert=3.2369 grad_right_expert=1.0658 grad_action_out=3.4116 (22934:train_pytorch.py:862)
114
+ 20:12:06.420 [I] step=4 loss=1.1718 smoothed_loss=2.2070 lr=3.98e-07 grad_norm=9.2670 step_time=0.5621s data_time=0.2679s it/s=1.206 eta_to_20=13.3s max_cuda_memory=76.13GB grad_action_out=3.4116 grad_left_action_in=0.1341 grad_left_expert=3.2369 grad_right_action_in=0.0877 grad_right_expert=1.0658 grad_shared_backbone=7.8629 (22934:train_pytorch.py:882)
115
+
116
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
117
+ 20:12:07.218 [I] debug_step=5 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22934:train_pytorch.py:831)
118
+ 20:12:07.219 [I] debug_step=5 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22934:train_pytorch.py:835)
119
+ 20:12:07.219 [I] debug_step=5 prompt_token_lengths=[73] (22934:train_pytorch.py:838)
120
+ 20:12:07.219 [I] debug_step=5 state_stats min=-0.9599 max=1.0004 mean=0.0170 std=0.5364 (22934:train_pytorch.py:839)
121
+ 20:12:07.220 [I] debug_step=5 action_stats min=-1.0392 max=1.0004 mean=-0.0159 std=0.4488 (22934:train_pytorch.py:842)
122
+ 20:12:07.220 [I] debug_step=5 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22934:train_pytorch.py:845)
123
+ 20:12:07.221 [I] debug_step=5 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22934:train_pytorch.py:849)
124
+ 20:12:07.221 [I] debug_step=5 lr=4.98e-07 grad_norm=18.8576 data_time=0.2330s step_time=0.5704s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (22934:train_pytorch.py:854)
125
+ 20:12:07.222 [I] debug_step=5 grad_shared_backbone=15.0420 grad_left_action_in=0.2664 grad_right_action_in=0.2257 grad_left_expert=7.9881 grad_right_expert=3.7966 grad_action_out=6.1884 (22934:train_pytorch.py:862)
126
+ 20:12:07.223 [I] step=5 loss=1.6473 smoothed_loss=2.1510 lr=4.98e-07 grad_norm=18.8576 step_time=0.5704s data_time=0.2330s it/s=1.246 eta_to_20=12.0s max_cuda_memory=76.13GB grad_action_out=6.1884 grad_left_action_in=0.2664 grad_left_expert=7.9881 grad_right_action_in=0.2257 grad_right_expert=3.7966 grad_shared_backbone=15.0420 (22934:train_pytorch.py:882)
127
+
128
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
129
+ 20:12:07.822 [I] step=6 loss=1.6098 smoothed_loss=2.0969 lr=5.98e-07 grad_norm=20.9772 step_time=0.4435s data_time=0.1600s it/s=1.671 eta_to_20=8.4s max_cuda_memory=76.13GB grad_action_out=6.0592 grad_left_action_in=0.2873 grad_left_expert=8.8574 grad_right_action_in=0.4264 grad_right_expert=6.3071 grad_shared_backbone=16.1173 (22934:train_pytorch.py:882)
130
+
131
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
132
+ 20:12:08.395 [I] step=7 loss=1.0401 smoothed_loss=1.9912 lr=6.97e-07 grad_norm=9.5173 step_time=0.4240s data_time=0.1495s it/s=1.747 eta_to_20=7.4s max_cuda_memory=76.13GB grad_action_out=4.1689 grad_left_action_in=0.1489 grad_left_expert=3.1996 grad_right_action_in=0.0904 grad_right_expert=2.4983 grad_shared_backbone=7.4213 (22934:train_pytorch.py:882)
133
+
134
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
135
+ 20:12:08.914 [I] step=8 loss=1.7539 smoothed_loss=1.9675 lr=7.97e-07 grad_norm=12.9701 step_time=0.3829s data_time=0.1362s it/s=1.931 eta_to_20=6.2s max_cuda_memory=76.13GB grad_action_out=5.3617 grad_left_action_in=0.1890 grad_left_expert=3.6536 grad_right_action_in=0.3790 grad_right_expert=2.7904 grad_shared_backbone=10.5667 (22934:train_pytorch.py:882)
136
+
137
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
138
+ 20:12:09.692 [I] step=9 loss=0.4114 smoothed_loss=1.8119 lr=8.96e-07 grad_norm=3.5873 step_time=0.5166s data_time=0.2609s it/s=1.288 eta_to_20=8.5s max_cuda_memory=76.13GB grad_action_out=1.8283 grad_left_action_in=0.0689 grad_left_expert=1.3656 grad_right_action_in=0.0549 grad_right_expert=0.7330 grad_shared_backbone=2.6507 (22934:train_pytorch.py:882)
139
+
140
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
141
+ 20:12:10.646 [I] step=10 loss=0.6228 smoothed_loss=1.6930 lr=9.96e-07 grad_norm=6.7396 step_time=0.7100s data_time=0.2450s it/s=1.049 eta_to_20=9.5s max_cuda_memory=76.13GB grad_action_out=2.2553 grad_left_action_in=0.0813 grad_left_expert=1.3495 grad_right_action_in=0.0919 grad_right_expert=2.0906 grad_shared_backbone=5.8179 (22934:train_pytorch.py:882)
142
+
143
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
144
+ 20:12:11.288 [I] step=11 loss=0.8688 smoothed_loss=1.6105 lr=1.10e-06 grad_norm=7.2182 step_time=0.4823s data_time=0.1593s it/s=1.561 eta_to_20=5.8s max_cuda_memory=76.13GB grad_action_out=3.3031 grad_left_action_in=0.1262 grad_left_expert=2.5456 grad_right_action_in=0.0809 grad_right_expert=0.9216 grad_shared_backbone=5.7177 (22934:train_pytorch.py:882)
145
+
146
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
147
+ 20:12:11.903 [I] step=12 loss=0.7319 smoothed_loss=1.5227 lr=1.20e-06 grad_norm=6.1848 step_time=0.4468s data_time=0.1681s it/s=1.629 eta_to_20=4.9s max_cuda_memory=76.13GB grad_action_out=2.7925 grad_left_action_in=0.1038 grad_left_expert=2.4508 grad_right_action_in=0.0680 grad_right_expert=0.8716 grad_shared_backbone=4.8333 (22934:train_pytorch.py:882)
148
+
149
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
150
+ 20:12:12.684 [I] step=13 loss=0.8788 smoothed_loss=1.4583 lr=1.29e-06 grad_norm=20.2227 step_time=0.5649s data_time=0.2162s it/s=1.282 eta_to_20=5.5s max_cuda_memory=76.13GB grad_action_out=3.0176 grad_left_action_in=0.1300 grad_left_expert=2.8276 grad_right_action_in=0.4691 grad_right_expert=12.9156 grad_shared_backbone=11.2157 (22934:train_pytorch.py:882)
151
+
152
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
153
+ 20:12:13.370 [I] step=14 loss=1.2741 smoothed_loss=1.4399 lr=1.39e-06 grad_norm=7.8620 step_time=0.5100s data_time=0.1755s it/s=1.461 eta_to_20=4.1s max_cuda_memory=76.13GB grad_action_out=4.2194 grad_left_action_in=0.1433 grad_left_expert=2.8949 grad_right_action_in=0.0958 grad_right_expert=1.0096 grad_shared_backbone=5.8070 (22934:train_pytorch.py:882)
154
+
155
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
156
+ 20:12:14.027 [I] step=15 loss=2.3729 smoothed_loss=1.5332 lr=1.49e-06 grad_norm=19.3589 step_time=0.4678s data_time=0.1899s it/s=1.523 eta_to_20=3.3s max_cuda_memory=76.13GB grad_action_out=7.2135 grad_left_action_in=0.2665 grad_left_expert=7.5354 grad_right_action_in=0.5496 grad_right_expert=4.5295 grad_shared_backbone=15.2257 (22934:train_pytorch.py:882)
157
+
158
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
159
+ 20:12:14.874 [I] step=16 loss=0.8147 smoothed_loss=1.4613 lr=1.59e-06 grad_norm=7.7365 step_time=0.5547s data_time=0.2919s it/s=1.183 eta_to_20=3.4s max_cuda_memory=76.13GB grad_action_out=2.7237 grad_left_action_in=0.1192 grad_left_expert=2.8822 grad_right_action_in=0.0900 grad_right_expert=0.8615 grad_shared_backbone=6.4500 (22934:train_pytorch.py:882)
160
+
161
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
162
+ 20:12:15.664 [I] step=17 loss=1.4318 smoothed_loss=1.4584 lr=1.69e-06 grad_norm=19.5452 step_time=0.5511s data_time=0.2382s it/s=1.268 eta_to_20=2.4s max_cuda_memory=76.13GB grad_action_out=3.9684 grad_left_action_in=0.3767 grad_left_expert=7.8636 grad_right_action_in=0.1317 grad_right_expert=1.6847 grad_shared_backbone=16.9059 (22934:train_pytorch.py:882)
163
+
164
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
165
+ 20:12:16.588 [I] step=18 loss=0.4858 smoothed_loss=1.3611 lr=1.79e-06 grad_norm=3.4382 step_time=0.6846s data_time=0.2403s it/s=1.082 eta_to_20=1.8s max_cuda_memory=76.13GB grad_action_out=1.9985 grad_left_action_in=0.0749 grad_left_expert=1.4156 grad_right_action_in=0.0390 grad_right_expert=0.5210 grad_shared_backbone=2.3369 (22934:train_pytorch.py:882)
166
+
167
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
168
+ 20:12:17.216 [I] step=19 loss=0.7492 smoothed_loss=1.2999 lr=1.89e-06 grad_norm=6.9377 step_time=0.4815s data_time=0.1459s it/s=1.596 eta_to_20=0.6s max_cuda_memory=76.13GB grad_action_out=3.7478 grad_left_action_in=0.1113 grad_left_expert=2.8716 grad_right_action_in=0.0729 grad_right_expert=1.0784 grad_shared_backbone=4.9024 (22934:train_pytorch.py:882)
169
+
170
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
171
+ 20:12:18.186 [I] step=20 loss=0.6038 smoothed_loss=1.2303 lr=1.99e-06 grad_norm=7.0090 step_time=0.7175s data_time=0.2525s it/s=1.032 eta_to_20=0.0s max_cuda_memory=76.13GB grad_action_out=2.8786 grad_left_action_in=0.0890 grad_left_expert=2.7778 grad_right_action_in=0.0549 grad_right_expert=1.4578 grad_shared_backbone=5.5395 (22934:train_pytorch.py:882)
172
+ 20:19:39.399 [I] Saved checkpoint at step 20 -> /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20/20 (22934:train_pytorch.py:378)
173
+
artifacts/twin_split_expert_bringup_20260310/sanity_checks/split_communicating_invariants.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ weight_loading_missing_keys: []
2
+ weight_loading_unexpected_keys: []
3
+ identical_branch_suffix_max_abs_diff: 0.00000000
4
+ identical_branch_suffix_match: True
5
+ left_branch_invariance_check: skipped_for_split_communicating
6
+ right_branch_invariance_check: skipped_for_split_communicating
artifacts/twin_split_expert_bringup_20260310/sanity_checks/split_independent_invariants.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ weight_loading_missing_keys: []
2
+ weight_loading_unexpected_keys: []
3
+ identical_branch_suffix_max_abs_diff: 0.00000000
4
+ identical_branch_suffix_match: True
5
+ left_branch_invariance_max_abs_diff: 0.00000000
6
+ right_branch_invariance_max_abs_diff: 0.00000000
7
+ left_branch_invariant: True
8
+ right_branch_invariant: True
openpi/assets/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "state": {
4
+ "mean": [
5
+ 0.1060400903224945,
6
+ 0.20956481993198395,
7
+ 0.09184283018112183,
8
+ -1.9880156517028809,
9
+ -0.0493016354739666,
10
+ 2.200657844543457,
11
+ 1.0759528875350952,
12
+ 0.5274205207824707,
13
+ 0.015858052298426628,
14
+ 0.08288046717643738,
15
+ -0.06887393444776535,
16
+ -1.9063940048217773,
17
+ 0.048101384192705154,
18
+ 2.0108680725097656,
19
+ -0.9290279746055603,
20
+ 0.8440811038017273
21
+ ],
22
+ "std": [
23
+ 0.09207697212696075,
24
+ 0.31317394971847534,
25
+ 0.08127228915691376,
26
+ 0.5381271243095398,
27
+ 0.060932666063308716,
28
+ 0.5120578408241272,
29
+ 0.2252715528011322,
30
+ 0.49924755096435547,
31
+ 0.20230208337306976,
32
+ 0.3140813112258911,
33
+ 0.2166559249162674,
34
+ 0.5264315009117126,
35
+ 0.2017098367214203,
36
+ 0.47457119822502136,
37
+ 1.1786143779754639,
38
+ 0.3627784252166748
39
+ ],
40
+ "q01": [
41
+ -5.003212208976038e-6,
42
+ -0.3880260119378567,
43
+ -0.000022378295398084447,
44
+ -2.9896268159270285,
45
+ -0.23859235458523037,
46
+ 1.2214620113372803,
47
+ 0.7853838205337524,
48
+ 0.0,
49
+ -0.6156159267425537,
50
+ -0.4149419299602509,
51
+ -0.9436963497161865,
52
+ -2.8839772893309594,
53
+ -0.9050835555553436,
54
+ 1.2214889526367188,
55
+ -2.795644993972778,
56
+ 0.0
57
+ ],
58
+ "q99": [
59
+ 0.31251292623596166,
60
+ 0.8654691616654395,
61
+ 0.35174238551805614,
62
+ -0.8763489654541017,
63
+ 0.052121943226456635,
64
+ 2.9720811741352082,
65
+ 1.6446517068386077,
66
+ 0.9998,
67
+ 0.7670312994003294,
68
+ 0.9607345881462095,
69
+ 0.6871046730995181,
70
+ -0.874981226503849,
71
+ 0.35838486022949234,
72
+ 2.977322695541382,
73
+ 0.7847790859222412,
74
+ 0.9998
75
+ ]
76
+ },
77
+ "actions": {
78
+ "mean": [
79
+ 0.03630240634083748,
80
+ 0.09624441713094711,
81
+ 0.01367407850921154,
82
+ -0.2224988043308258,
83
+ -0.027621738612651825,
84
+ 0.27498844265937805,
85
+ 0.08921869844198227,
86
+ 0.4565052390098572,
87
+ -0.0037808618508279324,
88
+ 0.09113847464323044,
89
+ -0.0037622663658112288,
90
+ -0.22537092864513397,
91
+ 0.008262325078248978,
92
+ 0.2679949402809143,
93
+ -0.574528694152832,
94
+ 0.7731654047966003
95
+ ],
96
+ "std": [
97
+ 0.049951743334531784,
98
+ 0.29268014430999756,
99
+ 0.06852161139249802,
100
+ 0.3647724986076355,
101
+ 0.07012807577848434,
102
+ 0.27129024267196655,
103
+ 0.11329206824302673,
104
+ 0.49810460209846497,
105
+ 0.09174609929323196,
106
+ 0.22704003751277924,
107
+ 0.10693909972906113,
108
+ 0.2572591006755829,
109
+ 0.11801817268133163,
110
+ 0.12355879694223404,
111
+ 0.35835781693458557,
112
+ 0.4187847375869751
113
+ ],
114
+ "q01": [
115
+ -0.0005862064361572272,
116
+ -0.38811749875545504,
117
+ -0.255800724029541,
118
+ -0.8347694625854493,
119
+ -0.35145472717285153,
120
+ -0.0015478792190551753,
121
+ -0.0005817413330078125,
122
+ 0.0,
123
+ -0.2644369697570801,
124
+ -0.351582763671875,
125
+ -0.3696939945220947,
126
+ -0.7309195489883423,
127
+ -0.3354415845870973,
128
+ -0.000662303924560547,
129
+ -0.934731125831604,
130
+ 0.0
131
+ ],
132
+ "q99": [
133
+ 0.20790743064880374,
134
+ 0.811985669732094,
135
+ 0.19612836360931396,
136
+ 0.3395817384719848,
137
+ 0.05568643188476563,
138
+ 0.7526534500122071,
139
+ 0.4252559995651245,
140
+ 0.9998,
141
+ 0.2558236026763916,
142
+ 0.5890134544372558,
143
+ 0.35822071075439466,
144
+ 0.18567593073844912,
145
+ 0.44035053730010976,
146
+ 0.4996662902832031,
147
+ 0.1265523338317871,
148
+ 0.9998
149
+ ]
150
+ }
151
+ }
152
+ }
openpi/assets/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "state": {
4
+ "mean": [
5
+ 0.1060400903224945,
6
+ 0.20956481993198395,
7
+ 0.09184283018112183,
8
+ -1.9880156517028809,
9
+ -0.0493016354739666,
10
+ 2.200657844543457,
11
+ 1.0759528875350952,
12
+ 0.5274205207824707,
13
+ 0.015858052298426628,
14
+ 0.08288046717643738,
15
+ -0.06887393444776535,
16
+ -1.9063940048217773,
17
+ 0.048101384192705154,
18
+ 2.0108680725097656,
19
+ -0.9290279746055603,
20
+ 0.8440811038017273
21
+ ],
22
+ "std": [
23
+ 0.09207697212696075,
24
+ 0.31317394971847534,
25
+ 0.08127228915691376,
26
+ 0.5381271243095398,
27
+ 0.060932666063308716,
28
+ 0.5120578408241272,
29
+ 0.2252715528011322,
30
+ 0.49924755096435547,
31
+ 0.20230208337306976,
32
+ 0.3140813112258911,
33
+ 0.2166559249162674,
34
+ 0.5264315009117126,
35
+ 0.2017098367214203,
36
+ 0.47457119822502136,
37
+ 1.1786143779754639,
38
+ 0.3627784252166748
39
+ ],
40
+ "q01": [
41
+ -5.003212208976038e-6,
42
+ -0.3880260119378567,
43
+ -0.000022378295398084447,
44
+ -2.9896268159270285,
45
+ -0.23859235458523037,
46
+ 1.2214620113372803,
47
+ 0.7853838205337524,
48
+ 0.0,
49
+ -0.6156159267425537,
50
+ -0.4149419299602509,
51
+ -0.9436963497161865,
52
+ -2.8839772893309594,
53
+ -0.9050835555553436,
54
+ 1.2214889526367188,
55
+ -2.795644993972778,
56
+ 0.0
57
+ ],
58
+ "q99": [
59
+ 0.31251292623596166,
60
+ 0.8654691616654395,
61
+ 0.35174238551805614,
62
+ -0.8763489654541017,
63
+ 0.052121943226456635,
64
+ 2.9720811741352082,
65
+ 1.6446517068386077,
66
+ 0.9998,
67
+ 0.7670312994003294,
68
+ 0.9607345881462095,
69
+ 0.6871046730995181,
70
+ -0.874981226503849,
71
+ 0.35838486022949234,
72
+ 2.977322695541382,
73
+ 0.7847790859222412,
74
+ 0.9998
75
+ ]
76
+ },
77
+ "actions": {
78
+ "mean": [
79
+ 0.03630240634083748,
80
+ 0.09624441713094711,
81
+ 0.01367407850921154,
82
+ -0.2224988043308258,
83
+ -0.027621738612651825,
84
+ 0.27498844265937805,
85
+ 0.08921869844198227,
86
+ 0.4565052390098572,
87
+ -0.0037808618508279324,
88
+ 0.09113847464323044,
89
+ -0.0037622663658112288,
90
+ -0.22537092864513397,
91
+ 0.008262325078248978,
92
+ 0.2679949402809143,
93
+ -0.574528694152832,
94
+ 0.7731654047966003
95
+ ],
96
+ "std": [
97
+ 0.049951743334531784,
98
+ 0.29268014430999756,
99
+ 0.06852161139249802,
100
+ 0.3647724986076355,
101
+ 0.07012807577848434,
102
+ 0.27129024267196655,
103
+ 0.11329206824302673,
104
+ 0.49810460209846497,
105
+ 0.09174609929323196,
106
+ 0.22704003751277924,
107
+ 0.10693909972906113,
108
+ 0.2572591006755829,
109
+ 0.11801817268133163,
110
+ 0.12355879694223404,
111
+ 0.35835781693458557,
112
+ 0.4187847375869751
113
+ ],
114
+ "q01": [
115
+ -0.0005862064361572272,
116
+ -0.38811749875545504,
117
+ -0.255800724029541,
118
+ -0.8347694625854493,
119
+ -0.35145472717285153,
120
+ -0.0015478792190551753,
121
+ -0.0005817413330078125,
122
+ 0.0,
123
+ -0.2644369697570801,
124
+ -0.351582763671875,
125
+ -0.3696939945220947,
126
+ -0.7309195489883423,
127
+ -0.3354415845870973,
128
+ -0.000662303924560547,
129
+ -0.934731125831604,
130
+ 0.0
131
+ ],
132
+ "q99": [
133
+ 0.20790743064880374,
134
+ 0.811985669732094,
135
+ 0.19612836360931396,
136
+ 0.3395817384719848,
137
+ 0.05568643188476563,
138
+ 0.7526534500122071,
139
+ 0.4252559995651245,
140
+ 0.9998,
141
+ 0.2558236026763916,
142
+ 0.5890134544372558,
143
+ 0.35822071075439466,
144
+ 0.18567593073844912,
145
+ 0.44035053730010976,
146
+ 0.4996662902832031,
147
+ 0.1265523338317871,
148
+ 0.9998
149
+ ]
150
+ }
151
+ }
152
+ }
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3/3/assets/lsnu/twin_dual_push_128_train/norm_stats.json ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "state": {
4
+ "mean": [
5
+ 0.1060400903224945,
6
+ 0.20956481993198395,
7
+ 0.09184283018112183,
8
+ -1.9880156517028809,
9
+ -0.0493016354739666,
10
+ 2.200657844543457,
11
+ 1.0759528875350952,
12
+ 0.5274205207824707,
13
+ 0.015858052298426628,
14
+ 0.08288046717643738,
15
+ -0.06887393444776535,
16
+ -1.9063940048217773,
17
+ 0.048101384192705154,
18
+ 2.0108680725097656,
19
+ -0.9290279746055603,
20
+ 0.8440811038017273
21
+ ],
22
+ "std": [
23
+ 0.09207697212696075,
24
+ 0.31317394971847534,
25
+ 0.08127228915691376,
26
+ 0.5381271243095398,
27
+ 0.060932666063308716,
28
+ 0.5120578408241272,
29
+ 0.2252715528011322,
30
+ 0.49924755096435547,
31
+ 0.20230208337306976,
32
+ 0.3140813112258911,
33
+ 0.2166559249162674,
34
+ 0.5264315009117126,
35
+ 0.2017098367214203,
36
+ 0.47457119822502136,
37
+ 1.1786143779754639,
38
+ 0.3627784252166748
39
+ ],
40
+ "q01": [
41
+ -5.003212208976038e-6,
42
+ -0.3880260119378567,
43
+ -0.000022378295398084447,
44
+ -2.9896268159270285,
45
+ -0.23859235458523037,
46
+ 1.2214620113372803,
47
+ 0.7853838205337524,
48
+ 0.0,
49
+ -0.6156159267425537,
50
+ -0.4149419299602509,
51
+ -0.9436963497161865,
52
+ -2.8839772893309594,
53
+ -0.9050835555553436,
54
+ 1.2214889526367188,
55
+ -2.795644993972778,
56
+ 0.0
57
+ ],
58
+ "q99": [
59
+ 0.31251292623596166,
60
+ 0.8654691616654395,
61
+ 0.35174238551805614,
62
+ -0.8763489654541017,
63
+ 0.052121943226456635,
64
+ 2.9720811741352082,
65
+ 1.6446517068386077,
66
+ 0.9998,
67
+ 0.7670312994003294,
68
+ 0.9607345881462095,
69
+ 0.6871046730995181,
70
+ -0.874981226503849,
71
+ 0.35838486022949234,
72
+ 2.977322695541382,
73
+ 0.7847790859222412,
74
+ 0.9998
75
+ ]
76
+ },
77
+ "actions": {
78
+ "mean": [
79
+ 0.03630240634083748,
80
+ 0.09624441713094711,
81
+ 0.01367407850921154,
82
+ -0.2224988043308258,
83
+ -0.027621738612651825,
84
+ 0.27498844265937805,
85
+ 0.08921869844198227,
86
+ 0.4565052390098572,
87
+ -0.0037808618508279324,
88
+ 0.09113847464323044,
89
+ -0.0037622663658112288,
90
+ -0.22537092864513397,
91
+ 0.008262325078248978,
92
+ 0.2679949402809143,
93
+ -0.574528694152832,
94
+ 0.7731654047966003
95
+ ],
96
+ "std": [
97
+ 0.049951743334531784,
98
+ 0.29268014430999756,
99
+ 0.06852161139249802,
100
+ 0.3647724986076355,
101
+ 0.07012807577848434,
102
+ 0.27129024267196655,
103
+ 0.11329206824302673,
104
+ 0.49810460209846497,
105
+ 0.09174609929323196,
106
+ 0.22704003751277924,
107
+ 0.10693909972906113,
108
+ 0.2572591006755829,
109
+ 0.11801817268133163,
110
+ 0.12355879694223404,
111
+ 0.35835781693458557,
112
+ 0.4187847375869751
113
+ ],
114
+ "q01": [
115
+ -0.0005862064361572272,
116
+ -0.38811749875545504,
117
+ -0.255800724029541,
118
+ -0.8347694625854493,
119
+ -0.35145472717285153,
120
+ -0.0015478792190551753,
121
+ -0.0005817413330078125,
122
+ 0.0,
123
+ -0.2644369697570801,
124
+ -0.351582763671875,
125
+ -0.3696939945220947,
126
+ -0.7309195489883423,
127
+ -0.3354415845870973,
128
+ -0.000662303924560547,
129
+ -0.934731125831604,
130
+ 0.0
131
+ ],
132
+ "q99": [
133
+ 0.20790743064880374,
134
+ 0.811985669732094,
135
+ 0.19612836360931396,
136
+ 0.3395817384719848,
137
+ 0.05568643188476563,
138
+ 0.7526534500122071,
139
+ 0.4252559995651245,
140
+ 0.9998,
141
+ 0.2558236026763916,
142
+ 0.5890134544372558,
143
+ 0.35822071075439466,
144
+ 0.18567593073844912,
145
+ 0.44035053730010976,
146
+ 0.4996662902832031,
147
+ 0.1265523338317871,
148
+ 0.9998
149
+ ]
150
+ }
151
+ }
152
+ }
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3/3/metadata.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61e6cd3a4c82f532df9754b41acaa1702add6fed61c90bd1e302f1ee902b13cd
3
+ size 3044
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3/3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5729412bbd417c49d081aebe4be29d5dc462be0a5e46b00033e942bfaa6f82ba
3
+ size 17232229008
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3/3/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52e95638a814c7b6be696fa5f2fa5353a7d6c9679d5a7fc88a95569b35ff9bf5
3
+ size 29412931288
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20/20/assets/lsnu/twin_dual_push_128_train/norm_stats.json ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "state": {
4
+ "mean": [
5
+ 0.1060400903224945,
6
+ 0.20956481993198395,
7
+ 0.09184283018112183,
8
+ -1.9880156517028809,
9
+ -0.0493016354739666,
10
+ 2.200657844543457,
11
+ 1.0759528875350952,
12
+ 0.5274205207824707,
13
+ 0.015858052298426628,
14
+ 0.08288046717643738,
15
+ -0.06887393444776535,
16
+ -1.9063940048217773,
17
+ 0.048101384192705154,
18
+ 2.0108680725097656,
19
+ -0.9290279746055603,
20
+ 0.8440811038017273
21
+ ],
22
+ "std": [
23
+ 0.09207697212696075,
24
+ 0.31317394971847534,
25
+ 0.08127228915691376,
26
+ 0.5381271243095398,
27
+ 0.060932666063308716,
28
+ 0.5120578408241272,
29
+ 0.2252715528011322,
30
+ 0.49924755096435547,
31
+ 0.20230208337306976,
32
+ 0.3140813112258911,
33
+ 0.2166559249162674,
34
+ 0.5264315009117126,
35
+ 0.2017098367214203,
36
+ 0.47457119822502136,
37
+ 1.1786143779754639,
38
+ 0.3627784252166748
39
+ ],
40
+ "q01": [
41
+ -5.003212208976038e-6,
42
+ -0.3880260119378567,
43
+ -0.000022378295398084447,
44
+ -2.9896268159270285,
45
+ -0.23859235458523037,
46
+ 1.2214620113372803,
47
+ 0.7853838205337524,
48
+ 0.0,
49
+ -0.6156159267425537,
50
+ -0.4149419299602509,
51
+ -0.9436963497161865,
52
+ -2.8839772893309594,
53
+ -0.9050835555553436,
54
+ 1.2214889526367188,
55
+ -2.795644993972778,
56
+ 0.0
57
+ ],
58
+ "q99": [
59
+ 0.31251292623596166,
60
+ 0.8654691616654395,
61
+ 0.35174238551805614,
62
+ -0.8763489654541017,
63
+ 0.052121943226456635,
64
+ 2.9720811741352082,
65
+ 1.6446517068386077,
66
+ 0.9998,
67
+ 0.7670312994003294,
68
+ 0.9607345881462095,
69
+ 0.6871046730995181,
70
+ -0.874981226503849,
71
+ 0.35838486022949234,
72
+ 2.977322695541382,
73
+ 0.7847790859222412,
74
+ 0.9998
75
+ ]
76
+ },
77
+ "actions": {
78
+ "mean": [
79
+ 0.03630240634083748,
80
+ 0.09624441713094711,
81
+ 0.01367407850921154,
82
+ -0.2224988043308258,
83
+ -0.027621738612651825,
84
+ 0.27498844265937805,
85
+ 0.08921869844198227,
86
+ 0.4565052390098572,
87
+ -0.0037808618508279324,
88
+ 0.09113847464323044,
89
+ -0.0037622663658112288,
90
+ -0.22537092864513397,
91
+ 0.008262325078248978,
92
+ 0.2679949402809143,
93
+ -0.574528694152832,
94
+ 0.7731654047966003
95
+ ],
96
+ "std": [
97
+ 0.049951743334531784,
98
+ 0.29268014430999756,
99
+ 0.06852161139249802,
100
+ 0.3647724986076355,
101
+ 0.07012807577848434,
102
+ 0.27129024267196655,
103
+ 0.11329206824302673,
104
+ 0.49810460209846497,
105
+ 0.09174609929323196,
106
+ 0.22704003751277924,
107
+ 0.10693909972906113,
108
+ 0.2572591006755829,
109
+ 0.11801817268133163,
110
+ 0.12355879694223404,
111
+ 0.35835781693458557,
112
+ 0.4187847375869751
113
+ ],
114
+ "q01": [
115
+ -0.0005862064361572272,
116
+ -0.38811749875545504,
117
+ -0.255800724029541,
118
+ -0.8347694625854493,
119
+ -0.35145472717285153,
120
+ -0.0015478792190551753,
121
+ -0.0005817413330078125,
122
+ 0.0,
123
+ -0.2644369697570801,
124
+ -0.351582763671875,
125
+ -0.3696939945220947,
126
+ -0.7309195489883423,
127
+ -0.3354415845870973,
128
+ -0.000662303924560547,
129
+ -0.934731125831604,
130
+ 0.0
131
+ ],
132
+ "q99": [
133
+ 0.20790743064880374,
134
+ 0.811985669732094,
135
+ 0.19612836360931396,
136
+ 0.3395817384719848,
137
+ 0.05568643188476563,
138
+ 0.7526534500122071,
139
+ 0.4252559995651245,
140
+ 0.9998,
141
+ 0.2558236026763916,
142
+ 0.5890134544372558,
143
+ 0.35822071075439466,
144
+ 0.18567593073844912,
145
+ 0.44035053730010976,
146
+ 0.4996662902832031,
147
+ 0.1265523338317871,
148
+ 0.9998
149
+ ]
150
+ }
151
+ }
152
+ }
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20/20/metadata.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f437d3d865b5a6c0c8a907f6e68f92b55ee83c0dc09a12dcc47ab74a82dbf6f
3
+ size 3044
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20/20/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39ffa219f42d07ad472b0c128d2dc4bad0b33bf8750f9d4d13a5afb39debe72b
3
+ size 17232229008
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20/20/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72592549367cbce929b8af3f0bc47054126facb774bfd4aca067a9f579b29db1
3
+ size 29412931288
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2/3/assets/lsnu/twin_dual_push_128_train/norm_stats.json ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "state": {
4
+ "mean": [
5
+ 0.1060400903224945,
6
+ 0.20956481993198395,
7
+ 0.09184283018112183,
8
+ -1.9880156517028809,
9
+ -0.0493016354739666,
10
+ 2.200657844543457,
11
+ 1.0759528875350952,
12
+ 0.5274205207824707,
13
+ 0.015858052298426628,
14
+ 0.08288046717643738,
15
+ -0.06887393444776535,
16
+ -1.9063940048217773,
17
+ 0.048101384192705154,
18
+ 2.0108680725097656,
19
+ -0.9290279746055603,
20
+ 0.8440811038017273
21
+ ],
22
+ "std": [
23
+ 0.09207697212696075,
24
+ 0.31317394971847534,
25
+ 0.08127228915691376,
26
+ 0.5381271243095398,
27
+ 0.060932666063308716,
28
+ 0.5120578408241272,
29
+ 0.2252715528011322,
30
+ 0.49924755096435547,
31
+ 0.20230208337306976,
32
+ 0.3140813112258911,
33
+ 0.2166559249162674,
34
+ 0.5264315009117126,
35
+ 0.2017098367214203,
36
+ 0.47457119822502136,
37
+ 1.1786143779754639,
38
+ 0.3627784252166748
39
+ ],
40
+ "q01": [
41
+ -5.003212208976038e-6,
42
+ -0.3880260119378567,
43
+ -0.000022378295398084447,
44
+ -2.9896268159270285,
45
+ -0.23859235458523037,
46
+ 1.2214620113372803,
47
+ 0.7853838205337524,
48
+ 0.0,
49
+ -0.6156159267425537,
50
+ -0.4149419299602509,
51
+ -0.9436963497161865,
52
+ -2.8839772893309594,
53
+ -0.9050835555553436,
54
+ 1.2214889526367188,
55
+ -2.795644993972778,
56
+ 0.0
57
+ ],
58
+ "q99": [
59
+ 0.31251292623596166,
60
+ 0.8654691616654395,
61
+ 0.35174238551805614,
62
+ -0.8763489654541017,
63
+ 0.052121943226456635,
64
+ 2.9720811741352082,
65
+ 1.6446517068386077,
66
+ 0.9998,
67
+ 0.7670312994003294,
68
+ 0.9607345881462095,
69
+ 0.6871046730995181,
70
+ -0.874981226503849,
71
+ 0.35838486022949234,
72
+ 2.977322695541382,
73
+ 0.7847790859222412,
74
+ 0.9998
75
+ ]
76
+ },
77
+ "actions": {
78
+ "mean": [
79
+ 0.03630240634083748,
80
+ 0.09624441713094711,
81
+ 0.01367407850921154,
82
+ -0.2224988043308258,
83
+ -0.027621738612651825,
84
+ 0.27498844265937805,
85
+ 0.08921869844198227,
86
+ 0.4565052390098572,
87
+ -0.0037808618508279324,
88
+ 0.09113847464323044,
89
+ -0.0037622663658112288,
90
+ -0.22537092864513397,
91
+ 0.008262325078248978,
92
+ 0.2679949402809143,
93
+ -0.574528694152832,
94
+ 0.7731654047966003
95
+ ],
96
+ "std": [
97
+ 0.049951743334531784,
98
+ 0.29268014430999756,
99
+ 0.06852161139249802,
100
+ 0.3647724986076355,
101
+ 0.07012807577848434,
102
+ 0.27129024267196655,
103
+ 0.11329206824302673,
104
+ 0.49810460209846497,
105
+ 0.09174609929323196,
106
+ 0.22704003751277924,
107
+ 0.10693909972906113,
108
+ 0.2572591006755829,
109
+ 0.11801817268133163,
110
+ 0.12355879694223404,
111
+ 0.35835781693458557,
112
+ 0.4187847375869751
113
+ ],
114
+ "q01": [
115
+ -0.0005862064361572272,
116
+ -0.38811749875545504,
117
+ -0.255800724029541,
118
+ -0.8347694625854493,
119
+ -0.35145472717285153,
120
+ -0.0015478792190551753,
121
+ -0.0005817413330078125,
122
+ 0.0,
123
+ -0.2644369697570801,
124
+ -0.351582763671875,
125
+ -0.3696939945220947,
126
+ -0.7309195489883423,
127
+ -0.3354415845870973,
128
+ -0.000662303924560547,
129
+ -0.934731125831604,
130
+ 0.0
131
+ ],
132
+ "q99": [
133
+ 0.20790743064880374,
134
+ 0.811985669732094,
135
+ 0.19612836360931396,
136
+ 0.3395817384719848,
137
+ 0.05568643188476563,
138
+ 0.7526534500122071,
139
+ 0.4252559995651245,
140
+ 0.9998,
141
+ 0.2558236026763916,
142
+ 0.5890134544372558,
143
+ 0.35822071075439466,
144
+ 0.18567593073844912,
145
+ 0.44035053730010976,
146
+ 0.4996662902832031,
147
+ 0.1265523338317871,
148
+ 0.9998
149
+ ]
150
+ }
151
+ }
152
+ }
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2/3/metadata.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0ca730216bba32605f1e94e967e0b4b88b237f6c234c668da521bd445bc33ff
3
+ size 3044
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2/3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2177e5fc2fdaa16dfa6711a21213c952b2913d36e290759003a58554ef7bd9f9
3
+ size 17232228840
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2/3/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da191448fd206e9e012789f3a03faf31ff5694800dc621f163c4fa1af0295ee1
3
+ size 29412930337
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20/20/assets/lsnu/twin_dual_push_128_train/norm_stats.json ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "state": {
4
+ "mean": [
5
+ 0.1060400903224945,
6
+ 0.20956481993198395,
7
+ 0.09184283018112183,
8
+ -1.9880156517028809,
9
+ -0.0493016354739666,
10
+ 2.200657844543457,
11
+ 1.0759528875350952,
12
+ 0.5274205207824707,
13
+ 0.015858052298426628,
14
+ 0.08288046717643738,
15
+ -0.06887393444776535,
16
+ -1.9063940048217773,
17
+ 0.048101384192705154,
18
+ 2.0108680725097656,
19
+ -0.9290279746055603,
20
+ 0.8440811038017273
21
+ ],
22
+ "std": [
23
+ 0.09207697212696075,
24
+ 0.31317394971847534,
25
+ 0.08127228915691376,
26
+ 0.5381271243095398,
27
+ 0.060932666063308716,
28
+ 0.5120578408241272,
29
+ 0.2252715528011322,
30
+ 0.49924755096435547,
31
+ 0.20230208337306976,
32
+ 0.3140813112258911,
33
+ 0.2166559249162674,
34
+ 0.5264315009117126,
35
+ 0.2017098367214203,
36
+ 0.47457119822502136,
37
+ 1.1786143779754639,
38
+ 0.3627784252166748
39
+ ],
40
+ "q01": [
41
+ -5.003212208976038e-6,
42
+ -0.3880260119378567,
43
+ -0.000022378295398084447,
44
+ -2.9896268159270285,
45
+ -0.23859235458523037,
46
+ 1.2214620113372803,
47
+ 0.7853838205337524,
48
+ 0.0,
49
+ -0.6156159267425537,
50
+ -0.4149419299602509,
51
+ -0.9436963497161865,
52
+ -2.8839772893309594,
53
+ -0.9050835555553436,
54
+ 1.2214889526367188,
55
+ -2.795644993972778,
56
+ 0.0
57
+ ],
58
+ "q99": [
59
+ 0.31251292623596166,
60
+ 0.8654691616654395,
61
+ 0.35174238551805614,
62
+ -0.8763489654541017,
63
+ 0.052121943226456635,
64
+ 2.9720811741352082,
65
+ 1.6446517068386077,
66
+ 0.9998,
67
+ 0.7670312994003294,
68
+ 0.9607345881462095,
69
+ 0.6871046730995181,
70
+ -0.874981226503849,
71
+ 0.35838486022949234,
72
+ 2.977322695541382,
73
+ 0.7847790859222412,
74
+ 0.9998
75
+ ]
76
+ },
77
+ "actions": {
78
+ "mean": [
79
+ 0.03630240634083748,
80
+ 0.09624441713094711,
81
+ 0.01367407850921154,
82
+ -0.2224988043308258,
83
+ -0.027621738612651825,
84
+ 0.27498844265937805,
85
+ 0.08921869844198227,
86
+ 0.4565052390098572,
87
+ -0.0037808618508279324,
88
+ 0.09113847464323044,
89
+ -0.0037622663658112288,
90
+ -0.22537092864513397,
91
+ 0.008262325078248978,
92
+ 0.2679949402809143,
93
+ -0.574528694152832,
94
+ 0.7731654047966003
95
+ ],
96
+ "std": [
97
+ 0.049951743334531784,
98
+ 0.29268014430999756,
99
+ 0.06852161139249802,
100
+ 0.3647724986076355,
101
+ 0.07012807577848434,
102
+ 0.27129024267196655,
103
+ 0.11329206824302673,
104
+ 0.49810460209846497,
105
+ 0.09174609929323196,
106
+ 0.22704003751277924,
107
+ 0.10693909972906113,
108
+ 0.2572591006755829,
109
+ 0.11801817268133163,
110
+ 0.12355879694223404,
111
+ 0.35835781693458557,
112
+ 0.4187847375869751
113
+ ],
114
+ "q01": [
115
+ -0.0005862064361572272,
116
+ -0.38811749875545504,
117
+ -0.255800724029541,
118
+ -0.8347694625854493,
119
+ -0.35145472717285153,
120
+ -0.0015478792190551753,
121
+ -0.0005817413330078125,
122
+ 0.0,
123
+ -0.2644369697570801,
124
+ -0.351582763671875,
125
+ -0.3696939945220947,
126
+ -0.7309195489883423,
127
+ -0.3354415845870973,
128
+ -0.000662303924560547,
129
+ -0.934731125831604,
130
+ 0.0
131
+ ],
132
+ "q99": [
133
+ 0.20790743064880374,
134
+ 0.811985669732094,
135
+ 0.19612836360931396,
136
+ 0.3395817384719848,
137
+ 0.05568643188476563,
138
+ 0.7526534500122071,
139
+ 0.4252559995651245,
140
+ 0.9998,
141
+ 0.2558236026763916,
142
+ 0.5890134544372558,
143
+ 0.35822071075439466,
144
+ 0.18567593073844912,
145
+ 0.44035053730010976,
146
+ 0.4996662902832031,
147
+ 0.1265523338317871,
148
+ 0.9998
149
+ ]
150
+ }
151
+ }
152
+ }
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20/20/metadata.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e44c886f451c1d306910d471cc73daf51ad51fd7acd092e2174a44318747ee02
3
+ size 3044
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20/20/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c8bcb5f7990cf9d6623b6c55f4143be77d95fad7fb9e8edfcd3373f62d63ee8
3
+ size 17232228840
openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20/20/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2567d911c1956025db067b45ade3aebe0933c0db81550235a5190157aad6a0
3
+ size 29412930337
openpi/run_logs/split_communicating_real_smoke3.log ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 19:55:02.788 [I] Created experiment checkpoint directory: /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3 (22110:train_pytorch.py:533)
2
+ 19:55:02.789 [I] Using batch size per GPU: 1 (total batch size across 1 GPUs: 1) (22110:train_pytorch.py:552)
3
+ 19:55:02.865 [I] Loaded norm stats from /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/lsnu/twin_dual_push_128_train (22110:config.py:234)
4
+ 19:55:02.867 [I] data_config: DataConfig(repo_id='lsnu/twin_dual_push_128_train', asset_id='lsnu/twin_dual_push_128_train', norm_stats={'state': NormStats(mean=array([ 0.10604009, 0.20956482, 0.09184283, -1.98801565, -0.04930164,
5
+ 2.20065784, 1.07595289, 0.52742052, 0.01585805, 0.08288047,
6
+ -0.06887393, -1.906394 , 0.04810138, 2.01086807, -0.92902797,
7
+ 0.8440811 ]), std=array([0.09207697, 0.31317395, 0.08127229, 0.53812712, 0.06093267,
8
+ 0.51205784, 0.22527155, 0.49924755, 0.20230208, 0.31408131,
9
+ 0.21665592, 0.5264315 , 0.20170984, 0.4745712 , 1.17861438,
10
+ 0.36277843]), q01=array([-5.00321221e-06, -3.88026012e-01, -2.23782954e-05, -2.98962682e+00,
11
+ -2.38592355e-01, 1.22146201e+00, 7.85383821e-01, 0.00000000e+00,
12
+ -6.15615927e-01, -4.14941930e-01, -9.43696350e-01, -2.88397729e+00,
13
+ -9.05083556e-01, 1.22148895e+00, -2.79564499e+00, 0.00000000e+00]), q99=array([ 0.31251293, 0.86546916, 0.35174239, -0.87634897, 0.05212194,
14
+ 2.97208117, 1.64465171, 0.9998 , 0.7670313 , 0.96073459,
15
+ 0.68710467, -0.87498123, 0.35838486, 2.9773227 , 0.78477909,
16
+ 0.9998 ])), 'actions': NormStats(mean=array([ 0.03630241, 0.09624442, 0.01367408, -0.2224988 , -0.02762174,
17
+ 0.27498844, 0.0892187 , 0.45650524, -0.00378086, 0.09113847,
18
+ -0.00376227, -0.22537093, 0.00826233, 0.26799494, -0.57452869,
19
+ 0.7731654 ]), std=array([0.04995174, 0.29268014, 0.06852161, 0.3647725 , 0.07012808,
20
+ 0.27129024, 0.11329207, 0.4981046 , 0.0917461 , 0.22704004,
21
+ 0.1069391 , 0.2572591 , 0.11801817, 0.1235588 , 0.35835782,
22
+ 0.41878474]), q01=array([-5.86206436e-04, -3.88117499e-01, -2.55800724e-01, -8.34769463e-01,
23
+ -3.51454727e-01, -1.54787922e-03, -5.81741333e-04, 0.00000000e+00,
24
+ -2.64436970e-01, -3.51582764e-01, -3.69693995e-01, -7.30919549e-01,
25
+ -3.35441585e-01, -6.62303925e-04, -9.34731126e-01, 0.00000000e+00]), q99=array([0.20790743, 0.81198567, 0.19612836, 0.33958174, 0.05568643,
26
+ 0.75265345, 0.425256 , 0.9998 , 0.2558236 , 0.58901345,
27
+ 0.35822071, 0.18567593, 0.44035054, 0.49966629, 0.12655233,
28
+ 0.9998 ]))}, repack_transforms=Group(inputs=[RepackTransform(structure={'images': {'cam_high': 'front_image', 'cam_left_wrist': 'wrist_left_image', 'cam_right_wrist': 'wrist_right_image'}, 'state': 'state', 'actions': 'action', 'prompt': 'task'})], outputs=()), data_transforms=Group(inputs=[AlohaInputs(adapt_to_pi=False)], outputs=[]), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x7ec79fca8910>, discrete_state_input=True), PackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))], outputs=[UnpackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))]), use_quantile_norm=True, action_sequence_keys=('action',), prompt_from_task=False, rlds_data_dir=None, action_space=None, datasets=()) (22110:data_loader.py:284)
29
+ 19:55:09.225 [I] JAX version 0.5.3 available. (22110:config.py:125)
30
+ 19:55:34.099 [I] Using existing local LeRobot dataset mirror for lsnu/twin_dual_push_128_train: /workspace/lerobot/lsnu/twin_dual_push_128_train (22110:data_loader.py:148)
31
+ 19:55:34.205 [W] 'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder (22110:video_utils.py:36)
32
+ 19:56:38.376 [I] local_batch_size: 1 (22110:data_loader.py:365)
33
+ 19:58:25.969 [I] Enabled gradient checkpointing for PI0Pytorch model (22110:pi0_pytorch.py:138)
34
+ 19:58:25.971 [I] Enabled gradient checkpointing for memory optimization (22110:train_pytorch.py:624)
35
+ 19:58:25.972 [I] Step 0 (after_model_creation): GPU memory - allocated: 17.23GB, reserved: 17.23GB, free: 0.00GB, peak_allocated: 17.23GB, peak_reserved: 17.23GB (22110:train_pytorch.py:493)
36
+ 19:58:25.972 [I] Loading weights from: /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22110:train_pytorch.py:653)
37
+ 19:58:29.565 [I] Weight loading missing key count: 0 (22110:train_pytorch.py:657)
38
+ 19:58:29.566 [I] Weight loading missing keys: set() (22110:train_pytorch.py:658)
39
+ 19:58:29.566 [I] Weight loading unexpected key count: 0 (22110:train_pytorch.py:659)
40
+ 19:58:29.566 [I] Weight loading unexpected keys: [] (22110:train_pytorch.py:660)
41
+ 19:58:29.567 [I] Loaded PyTorch weights from /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22110:train_pytorch.py:661)
42
+ 19:58:29.571 [I] Running on: 963c158043aa | world_size=1 (22110:train_pytorch.py:701)
43
+ 19:58:29.571 [I] Training config: batch_size=1, effective_batch_size=1, num_train_steps=3 (22110:train_pytorch.py:702)
44
+ 19:58:29.572 [I] Memory optimizations: gradient_checkpointing=True (22110:train_pytorch.py:705)
45
+ 19:58:29.572 [I] DDP settings: find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True (22110:train_pytorch.py:706)
46
+ 19:58:29.573 [I] LR schedule: warmup=250, peak_lr=2.50e-05, decay_steps=5000, end_lr=2.50e-06 (22110:train_pytorch.py:707)
47
+ 19:58:29.573 [I] Optimizer: AdamW, weight_decay=1e-10, clip_norm=1.0 (22110:train_pytorch.py:710)
48
+ 19:58:29.573 [I] EMA is not supported for PyTorch training (22110:train_pytorch.py:713)
49
+ 19:58:29.574 [I] Training precision: float32 (22110:train_pytorch.py:714)
50
+ 19:58:29.590 [I] Resolved config name: pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k (22110:train_pytorch.py:308)
51
+ 19:58:29.590 [I] Dataset repo_id: lsnu/twin_dual_push_128_train (22110:train_pytorch.py:309)
52
+ 19:58:29.591 [I] Norm-stats file path: /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json (22110:train_pytorch.py:310)
53
+ 19:58:29.592 [I] Norm-stats summary: {'keys': ['actions', 'state'], 'state_mean_len': 16, 'state_std_len': 16, 'actions_mean_len': 16, 'actions_std_len': 16} (22110:train_pytorch.py:311)
54
+ 19:58:29.592 [I] Checkpoint source path: /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22110:train_pytorch.py:312)
55
+ 19:58:29.592 [I] Model type: split_communicating (22110:train_pytorch.py:313)
56
+ 19:58:29.593 [I] Packed transforms active: True (22110:train_pytorch.py:314)
57
+ 19:58:29.593 [I] World size: 1 (22110:train_pytorch.py:315)
58
+ 19:58:29.594 [I] Batch size: local=1, global=1 (22110:train_pytorch.py:316)
59
+ 19:58:29.594 [I] num_workers: 0 (22110:train_pytorch.py:317)
60
+ 19:58:29.595 [I] Precision: float32 (22110:train_pytorch.py:318)
61
+ 19:58:29.595 [I] LR schedule summary: warmup_steps=250, peak_lr=2.50e-05, decay_steps=5000, decay_lr=2.50e-06 (22110:train_pytorch.py:319)
62
+ 19:58:29.595 [I] Save/log intervals: save_interval=3, log_interval=1 (22110:train_pytorch.py:326)
63
+ 19:58:29.596 [I] Action-loss mask: (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (22110:train_pytorch.py:327)
64
+ 19:58:29.596 [I] Active mask dims: [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] (22110:train_pytorch.py:328)
65
+ 19:58:29.597 [I] Masked dims: [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] (22110:train_pytorch.py:329)
66
+ 19:58:29.597 [I] Gradient bucket diagnostics: left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm (22110:train_pytorch.py:722)
67
+
68
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
69
+ 19:58:31.354 [I] debug_step=1 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22110:train_pytorch.py:831)
70
+ 19:58:31.355 [I] debug_step=1 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22110:train_pytorch.py:835)
71
+ 19:58:31.356 [I] debug_step=1 prompt_token_lengths=[75] (22110:train_pytorch.py:838)
72
+ 19:58:31.356 [I] debug_step=1 state_stats min=-1.0000 max=1.0004 mean=0.0112 std=0.3876 (22110:train_pytorch.py:839)
73
+ 19:58:31.357 [I] debug_step=1 action_stats min=-1.0016 max=1.0004 mean=-0.0454 std=0.4716 (22110:train_pytorch.py:842)
74
+ 19:58:31.358 [I] debug_step=1 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22110:train_pytorch.py:845)
75
+ 19:58:31.372 [I] debug_step=1 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22110:train_pytorch.py:849)
76
+ 19:58:31.372 [I] debug_step=1 lr=9.96e-08 grad_norm=60.0472 data_time=0.3311s step_time=1.3966s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.25GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.25GB (22110:train_pytorch.py:854)
77
+ 19:58:31.373 [I] debug_step=1 grad_shared_backbone=36.9945 grad_left_action_in=2.3769 grad_right_action_in=1.7630 grad_left_expert=31.1244 grad_right_expert=27.8917 grad_action_out=13.0720 grad_cross_arm_comm=3.1067 cross_arm_comm_gate_layer_0=0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=0.0000 cross_arm_comm_gate_layer_5=0.0000 cross_arm_comm_gate_layer_6=0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=0.0000 cross_arm_comm_gate_layer_9=0.0000 cross_arm_comm_gate_layer_10=0.0000 cross_arm_comm_gate_layer_11=0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=0.0000 cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0050 cross_arm_attention_mass_layer_2=0.0217 cross_arm_attention_mass_layer_3=0.0086 cross_arm_attention_mass_layer_4=0.0279 cross_arm_attention_mass_layer_5=0.0355 cross_arm_attention_mass_layer_6=0.0179 cross_arm_attention_mass_layer_7=0.0369 cross_arm_attention_mass_layer_8=0.0183 cross_arm_attention_mass_layer_9=0.0153 cross_arm_attention_mass_layer_10=0.0188 cross_arm_attention_mass_layer_11=0.0278 cross_arm_attention_mass_layer_12=0.0052 cross_arm_attention_mass_layer_13=0.0161 cross_arm_attention_mass_layer_14=0.0091 cross_arm_attention_mass_layer_15=0.0342 cross_arm_attention_mass_layer_16=0.0457 cross_arm_attention_mass_layer_17=0.0454 (22110:train_pytorch.py:862)
78
+ 19:58:31.374 [I] step=1 loss=3.8411 smoothed_loss=3.8411 lr=9.96e-08 grad_norm=60.0472 step_time=1.3966s data_time=0.3311s it/s=0.555 eta_to_3=3.6s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0050 cross_arm_attention_mass_layer_10=0.0188 cross_arm_attention_mass_layer_11=0.0278 cross_arm_attention_mass_layer_12=0.0052 cross_arm_attention_mass_layer_13=0.0161 cross_arm_attention_mass_layer_14=0.0091 cross_arm_attention_mass_layer_15=0.0342 cross_arm_attention_mass_layer_16=0.0457 cross_arm_attention_mass_layer_17=0.0454 cross_arm_attention_mass_layer_2=0.0217 cross_arm_attention_mass_layer_3=0.0086 cross_arm_attention_mass_layer_4=0.0279 cross_arm_attention_mass_layer_5=0.0355 cross_arm_attention_mass_layer_6=0.0179 cross_arm_attention_mass_layer_7=0.0369 cross_arm_attention_mass_layer_8=0.0183 cross_arm_attention_mass_layer_9=0.0153 cross_arm_comm_gate_layer_0=0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=0.0000 cross_arm_comm_gate_layer_11=0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=0.0000 cross_arm_comm_gate_layer_2=0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=0.0000 cross_arm_comm_gate_layer_5=0.0000 cross_arm_comm_gate_layer_6=0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=0.0000 cross_arm_comm_gate_layer_9=0.0000 grad_action_out=13.0720 grad_cross_arm_comm=3.1067 grad_left_action_in=2.3769 grad_left_expert=31.1244 grad_right_action_in=1.7630 grad_right_expert=27.8917 grad_shared_backbone=36.9945 (22110:train_pytorch.py:882)
79
+
80
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
81
+ 19:58:32.164 [I] debug_step=2 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22110:train_pytorch.py:831)
82
+ 19:58:32.165 [I] debug_step=2 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22110:train_pytorch.py:835)
83
+ 19:58:32.166 [I] debug_step=2 prompt_token_lengths=[76] (22110:train_pytorch.py:838)
84
+ 19:58:32.166 [I] debug_step=2 state_stats min=-0.9415 max=1.0004 mean=-0.0010 std=0.4295 (22110:train_pytorch.py:839)
85
+ 19:58:32.167 [I] debug_step=2 action_stats min=-1.0000 max=1.1367 mean=0.0272 std=0.4576 (22110:train_pytorch.py:842)
86
+ 19:58:32.168 [I] debug_step=2 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22110:train_pytorch.py:845)
87
+ 19:58:32.168 [I] debug_step=2 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22110:train_pytorch.py:849)
88
+ 19:58:32.169 [I] debug_step=2 lr=1.99e-07 grad_norm=10.7300 data_time=0.1812s step_time=0.6234s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22110:train_pytorch.py:854)
89
+ 19:58:32.169 [I] debug_step=2 grad_shared_backbone=9.2018 grad_left_action_in=0.1651 grad_right_action_in=0.1485 grad_left_expert=2.5032 grad_right_expert=2.3988 grad_action_out=4.0772 grad_cross_arm_comm=0.0166 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=-0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0019 cross_arm_attention_mass_layer_2=0.0161 cross_arm_attention_mass_layer_3=0.0029 cross_arm_attention_mass_layer_4=0.0175 cross_arm_attention_mass_layer_5=0.0243 cross_arm_attention_mass_layer_6=0.0074 cross_arm_attention_mass_layer_7=0.0232 cross_arm_attention_mass_layer_8=0.0155 cross_arm_attention_mass_layer_9=0.0135 cross_arm_attention_mass_layer_10=0.0094 cross_arm_attention_mass_layer_11=0.0151 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0053 cross_arm_attention_mass_layer_14=0.0056 cross_arm_attention_mass_layer_15=0.0250 cross_arm_attention_mass_layer_16=0.0356 cross_arm_attention_mass_layer_17=0.0413 (22110:train_pytorch.py:862)
90
+ 19:58:32.170 [I] step=2 loss=1.1389 smoothed_loss=3.5709 lr=1.99e-07 grad_norm=10.7300 step_time=0.6234s data_time=0.1812s it/s=1.257 eta_to_3=0.8s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0019 cross_arm_attention_mass_layer_10=0.0094 cross_arm_attention_mass_layer_11=0.0151 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0053 cross_arm_attention_mass_layer_14=0.0056 cross_arm_attention_mass_layer_15=0.0250 cross_arm_attention_mass_layer_16=0.0356 cross_arm_attention_mass_layer_17=0.0413 cross_arm_attention_mass_layer_2=0.0161 cross_arm_attention_mass_layer_3=0.0029 cross_arm_attention_mass_layer_4=0.0175 cross_arm_attention_mass_layer_5=0.0243 cross_arm_attention_mass_layer_6=0.0074 cross_arm_attention_mass_layer_7=0.0232 cross_arm_attention_mass_layer_8=0.0155 cross_arm_attention_mass_layer_9=0.0135 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=-0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=4.0772 grad_cross_arm_comm=0.0166 grad_left_action_in=0.1651 grad_left_expert=2.5032 grad_right_action_in=0.1485 grad_right_expert=2.3988 grad_shared_backbone=9.2018 (22110:train_pytorch.py:882)
91
+
92
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
93
+ 19:58:32.708 [I] debug_step=3 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22110:train_pytorch.py:831)
94
+ 19:58:32.709 [I] debug_step=3 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22110:train_pytorch.py:835)
95
+ 19:58:32.709 [I] debug_step=3 prompt_token_lengths=[75] (22110:train_pytorch.py:838)
96
+ 19:58:32.710 [I] debug_step=3 state_stats min=-1.0000 max=1.0004 mean=0.0558 std=0.4300 (22110:train_pytorch.py:839)
97
+ 19:58:32.711 [I] debug_step=3 action_stats min=-1.0033 max=1.0004 mean=-0.0658 std=0.4704 (22110:train_pytorch.py:842)
98
+ 19:58:32.711 [I] debug_step=3 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22110:train_pytorch.py:845)
99
+ 19:58:32.712 [I] debug_step=3 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22110:train_pytorch.py:849)
100
+ 19:58:32.712 [I] debug_step=3 lr=2.99e-07 grad_norm=343.7256 data_time=0.1312s step_time=0.4126s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22110:train_pytorch.py:854)
101
+ 19:58:32.713 [I] debug_step=3 grad_shared_backbone=215.2880 grad_left_action_in=4.7981 grad_right_action_in=9.5346 grad_left_expert=72.6437 grad_right_expert=227.6029 grad_action_out=23.7709 grad_cross_arm_comm=3.3555 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0003 cross_arm_attention_mass_layer_1=0.0127 cross_arm_attention_mass_layer_2=0.0275 cross_arm_attention_mass_layer_3=0.0190 cross_arm_attention_mass_layer_4=0.0359 cross_arm_attention_mass_layer_5=0.0454 cross_arm_attention_mass_layer_6=0.0228 cross_arm_attention_mass_layer_7=0.0346 cross_arm_attention_mass_layer_8=0.0149 cross_arm_attention_mass_layer_9=0.0296 cross_arm_attention_mass_layer_10=0.0177 cross_arm_attention_mass_layer_11=0.0230 cross_arm_attention_mass_layer_12=0.0134 cross_arm_attention_mass_layer_13=0.0242 cross_arm_attention_mass_layer_14=0.0109 cross_arm_attention_mass_layer_15=0.0285 cross_arm_attention_mass_layer_16=0.0403 cross_arm_attention_mass_layer_17=0.0268 (22110:train_pytorch.py:862)
102
+ 19:58:32.713 [I] step=3 loss=5.0518 smoothed_loss=3.7190 lr=2.99e-07 grad_norm=343.7256 step_time=0.4126s data_time=0.1312s it/s=1.843 eta_to_3=0.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0003 cross_arm_attention_mass_layer_1=0.0127 cross_arm_attention_mass_layer_10=0.0177 cross_arm_attention_mass_layer_11=0.0230 cross_arm_attention_mass_layer_12=0.0134 cross_arm_attention_mass_layer_13=0.0242 cross_arm_attention_mass_layer_14=0.0109 cross_arm_attention_mass_layer_15=0.0285 cross_arm_attention_mass_layer_16=0.0403 cross_arm_attention_mass_layer_17=0.0268 cross_arm_attention_mass_layer_2=0.0275 cross_arm_attention_mass_layer_3=0.0190 cross_arm_attention_mass_layer_4=0.0359 cross_arm_attention_mass_layer_5=0.0454 cross_arm_attention_mass_layer_6=0.0228 cross_arm_attention_mass_layer_7=0.0346 cross_arm_attention_mass_layer_8=0.0149 cross_arm_attention_mass_layer_9=0.0296 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=23.7709 grad_cross_arm_comm=3.3555 grad_left_action_in=4.7981 grad_left_expert=72.6437 grad_right_action_in=9.5346 grad_right_expert=227.6029 grad_shared_backbone=215.2880 (22110:train_pytorch.py:882)
103
+ 20:01:38.475 [I] Saved checkpoint at step 3 -> /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_smoke3/3 (22110:train_pytorch.py:378)
104
+
openpi/run_logs/split_communicating_real_train20.log ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 20:03:03.480 [I] Created experiment checkpoint directory: /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20 (22938:train_pytorch.py:533)
2
+ 20:03:03.486 [I] Using batch size per GPU: 1 (total batch size across 1 GPUs: 1) (22938:train_pytorch.py:552)
3
+ 20:03:03.634 [I] Loaded norm stats from /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/lsnu/twin_dual_push_128_train (22938:config.py:234)
4
+ 20:03:03.637 [I] data_config: DataConfig(repo_id='lsnu/twin_dual_push_128_train', asset_id='lsnu/twin_dual_push_128_train', norm_stats={'state': NormStats(mean=array([ 0.10604009, 0.20956482, 0.09184283, -1.98801565, -0.04930164,
5
+ 2.20065784, 1.07595289, 0.52742052, 0.01585805, 0.08288047,
6
+ -0.06887393, -1.906394 , 0.04810138, 2.01086807, -0.92902797,
7
+ 0.8440811 ]), std=array([0.09207697, 0.31317395, 0.08127229, 0.53812712, 0.06093267,
8
+ 0.51205784, 0.22527155, 0.49924755, 0.20230208, 0.31408131,
9
+ 0.21665592, 0.5264315 , 0.20170984, 0.4745712 , 1.17861438,
10
+ 0.36277843]), q01=array([-5.00321221e-06, -3.88026012e-01, -2.23782954e-05, -2.98962682e+00,
11
+ -2.38592355e-01, 1.22146201e+00, 7.85383821e-01, 0.00000000e+00,
12
+ -6.15615927e-01, -4.14941930e-01, -9.43696350e-01, -2.88397729e+00,
13
+ -9.05083556e-01, 1.22148895e+00, -2.79564499e+00, 0.00000000e+00]), q99=array([ 0.31251293, 0.86546916, 0.35174239, -0.87634897, 0.05212194,
14
+ 2.97208117, 1.64465171, 0.9998 , 0.7670313 , 0.96073459,
15
+ 0.68710467, -0.87498123, 0.35838486, 2.9773227 , 0.78477909,
16
+ 0.9998 ])), 'actions': NormStats(mean=array([ 0.03630241, 0.09624442, 0.01367408, -0.2224988 , -0.02762174,
17
+ 0.27498844, 0.0892187 , 0.45650524, -0.00378086, 0.09113847,
18
+ -0.00376227, -0.22537093, 0.00826233, 0.26799494, -0.57452869,
19
+ 0.7731654 ]), std=array([0.04995174, 0.29268014, 0.06852161, 0.3647725 , 0.07012808,
20
+ 0.27129024, 0.11329207, 0.4981046 , 0.0917461 , 0.22704004,
21
+ 0.1069391 , 0.2572591 , 0.11801817, 0.1235588 , 0.35835782,
22
+ 0.41878474]), q01=array([-5.86206436e-04, -3.88117499e-01, -2.55800724e-01, -8.34769463e-01,
23
+ -3.51454727e-01, -1.54787922e-03, -5.81741333e-04, 0.00000000e+00,
24
+ -2.64436970e-01, -3.51582764e-01, -3.69693995e-01, -7.30919549e-01,
25
+ -3.35441585e-01, -6.62303925e-04, -9.34731126e-01, 0.00000000e+00]), q99=array([0.20790743, 0.81198567, 0.19612836, 0.33958174, 0.05568643,
26
+ 0.75265345, 0.425256 , 0.9998 , 0.2558236 , 0.58901345,
27
+ 0.35822071, 0.18567593, 0.44035054, 0.49966629, 0.12655233,
28
+ 0.9998 ]))}, repack_transforms=Group(inputs=[RepackTransform(structure={'images': {'cam_high': 'front_image', 'cam_left_wrist': 'wrist_left_image', 'cam_right_wrist': 'wrist_right_image'}, 'state': 'state', 'actions': 'action', 'prompt': 'task'})], outputs=()), data_transforms=Group(inputs=[AlohaInputs(adapt_to_pi=False)], outputs=[]), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x7303f4ce5b90>, discrete_state_input=True), PackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))], outputs=[UnpackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))]), use_quantile_norm=True, action_sequence_keys=('action',), prompt_from_task=False, rlds_data_dir=None, action_space=None, datasets=()) (22938:data_loader.py:393)
29
+ 20:03:15.223 [I] JAX version 0.5.3 available. (22938:config.py:125)
30
+ 20:04:19.283 [I] Using existing local LeRobot dataset mirror for lsnu/twin_dual_push_128_train: /workspace/lerobot/lsnu/twin_dual_push_128_train (22938:data_loader.py:148)
31
+ 20:04:19.378 [W] 'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder (22938:video_utils.py:36)
32
+ 20:09:10.375 [I] local_batch_size: 1 (22938:data_loader.py:474)
33
+ 20:11:59.735 [I] Enabled gradient checkpointing for PI0Pytorch model (22938:pi0_pytorch.py:138)
34
+ 20:11:59.737 [I] Enabled gradient checkpointing for memory optimization (22938:train_pytorch.py:624)
35
+ 20:11:59.738 [I] Step 0 (after_model_creation): GPU memory - allocated: 17.23GB, reserved: 17.23GB, free: 0.00GB, peak_allocated: 17.23GB, peak_reserved: 17.23GB (22938:train_pytorch.py:493)
36
+ 20:11:59.738 [I] Loading weights from: /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22938:train_pytorch.py:653)
37
+ 20:12:04.492 [I] Weight loading missing key count: 0 (22938:train_pytorch.py:657)
38
+ 20:12:04.492 [I] Weight loading missing keys: set() (22938:train_pytorch.py:658)
39
+ 20:12:04.492 [I] Weight loading unexpected key count: 0 (22938:train_pytorch.py:659)
40
+ 20:12:04.493 [I] Weight loading unexpected keys: [] (22938:train_pytorch.py:660)
41
+ 20:12:04.493 [I] Loaded PyTorch weights from /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22938:train_pytorch.py:661)
42
+ 20:12:04.497 [I] Running on: 963c158043aa | world_size=1 (22938:train_pytorch.py:701)
43
+ 20:12:04.498 [I] Training config: batch_size=1, effective_batch_size=1, num_train_steps=20 (22938:train_pytorch.py:702)
44
+ 20:12:04.498 [I] Memory optimizations: gradient_checkpointing=True (22938:train_pytorch.py:705)
45
+ 20:12:04.499 [I] DDP settings: find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True (22938:train_pytorch.py:706)
46
+ 20:12:04.499 [I] LR schedule: warmup=250, peak_lr=2.50e-05, decay_steps=5000, end_lr=2.50e-06 (22938:train_pytorch.py:707)
47
+ 20:12:04.499 [I] Optimizer: AdamW, weight_decay=1e-10, clip_norm=1.0 (22938:train_pytorch.py:710)
48
+ 20:12:04.500 [I] EMA is not supported for PyTorch training (22938:train_pytorch.py:713)
49
+ 20:12:04.500 [I] Training precision: float32 (22938:train_pytorch.py:714)
50
+ 20:12:04.509 [I] Resolved config name: pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k (22938:train_pytorch.py:308)
51
+ 20:12:04.509 [I] Dataset repo_id: lsnu/twin_dual_push_128_train (22938:train_pytorch.py:309)
52
+ 20:12:04.510 [I] Norm-stats file path: /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json (22938:train_pytorch.py:310)
53
+ 20:12:04.510 [I] Norm-stats summary: {'keys': ['actions', 'state'], 'state_mean_len': 16, 'state_std_len': 16, 'actions_mean_len': 16, 'actions_std_len': 16} (22938:train_pytorch.py:311)
54
+ 20:12:04.511 [I] Checkpoint source path: /workspace/checkpoints/pi05_base_split_communicating_packed_from_single (22938:train_pytorch.py:312)
55
+ 20:12:04.511 [I] Model type: split_communicating (22938:train_pytorch.py:313)
56
+ 20:12:04.511 [I] Packed transforms active: True (22938:train_pytorch.py:314)
57
+ 20:12:04.512 [I] World size: 1 (22938:train_pytorch.py:315)
58
+ 20:12:04.512 [I] Batch size: local=1, global=1 (22938:train_pytorch.py:316)
59
+ 20:12:04.512 [I] num_workers: 0 (22938:train_pytorch.py:317)
60
+ 20:12:04.513 [I] Precision: float32 (22938:train_pytorch.py:318)
61
+ 20:12:04.513 [I] LR schedule summary: warmup_steps=250, peak_lr=2.50e-05, decay_steps=5000, decay_lr=2.50e-06 (22938:train_pytorch.py:319)
62
+ 20:12:04.513 [I] Save/log intervals: save_interval=20, log_interval=1 (22938:train_pytorch.py:326)
63
+ 20:12:04.514 [I] Action-loss mask: (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (22938:train_pytorch.py:327)
64
+ 20:12:04.514 [I] Active mask dims: [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] (22938:train_pytorch.py:328)
65
+ 20:12:04.515 [I] Masked dims: [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] (22938:train_pytorch.py:329)
66
+ 20:12:04.515 [I] Gradient bucket diagnostics: left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm (22938:train_pytorch.py:722)
67
+
68
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
69
+ 20:12:06.079 [I] debug_step=1 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22938:train_pytorch.py:831)
70
+ 20:12:06.080 [I] debug_step=1 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22938:train_pytorch.py:835)
71
+ 20:12:06.080 [I] debug_step=1 prompt_token_lengths=[75] (22938:train_pytorch.py:838)
72
+ 20:12:06.081 [I] debug_step=1 state_stats min=-1.0000 max=1.0004 mean=0.0112 std=0.3876 (22938:train_pytorch.py:839)
73
+ 20:12:06.081 [I] debug_step=1 action_stats min=-1.0016 max=1.0004 mean=-0.0454 std=0.4716 (22938:train_pytorch.py:842)
74
+ 20:12:06.082 [I] debug_step=1 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22938:train_pytorch.py:845)
75
+ 20:12:06.097 [I] debug_step=1 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22938:train_pytorch.py:849)
76
+ 20:12:06.097 [I] debug_step=1 lr=9.96e-08 grad_norm=60.0473 data_time=0.2034s step_time=1.3216s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.25GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.25GB (22938:train_pytorch.py:854)
77
+ 20:12:06.098 [I] debug_step=1 grad_shared_backbone=36.9946 grad_left_action_in=2.3769 grad_right_action_in=1.7630 grad_left_expert=31.1244 grad_right_expert=27.8917 grad_action_out=13.0720 grad_cross_arm_comm=3.1067 cross_arm_comm_gate_layer_0=0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=0.0000 cross_arm_comm_gate_layer_5=0.0000 cross_arm_comm_gate_layer_6=0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=0.0000 cross_arm_comm_gate_layer_9=0.0000 cross_arm_comm_gate_layer_10=0.0000 cross_arm_comm_gate_layer_11=0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=0.0000 cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0050 cross_arm_attention_mass_layer_2=0.0217 cross_arm_attention_mass_layer_3=0.0086 cross_arm_attention_mass_layer_4=0.0279 cross_arm_attention_mass_layer_5=0.0355 cross_arm_attention_mass_layer_6=0.0179 cross_arm_attention_mass_layer_7=0.0369 cross_arm_attention_mass_layer_8=0.0183 cross_arm_attention_mass_layer_9=0.0153 cross_arm_attention_mass_layer_10=0.0188 cross_arm_attention_mass_layer_11=0.0278 cross_arm_attention_mass_layer_12=0.0052 cross_arm_attention_mass_layer_13=0.0161 cross_arm_attention_mass_layer_14=0.0091 cross_arm_attention_mass_layer_15=0.0342 cross_arm_attention_mass_layer_16=0.0457 cross_arm_attention_mass_layer_17=0.0454 (22938:train_pytorch.py:862)
78
+ 20:12:06.099 [I] step=1 loss=3.8411 smoothed_loss=3.8411 lr=9.96e-08 grad_norm=60.0473 step_time=1.3216s data_time=0.2034s it/s=0.625 eta_to_20=30.4s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0050 cross_arm_attention_mass_layer_10=0.0188 cross_arm_attention_mass_layer_11=0.0278 cross_arm_attention_mass_layer_12=0.0052 cross_arm_attention_mass_layer_13=0.0161 cross_arm_attention_mass_layer_14=0.0091 cross_arm_attention_mass_layer_15=0.0342 cross_arm_attention_mass_layer_16=0.0457 cross_arm_attention_mass_layer_17=0.0454 cross_arm_attention_mass_layer_2=0.0217 cross_arm_attention_mass_layer_3=0.0086 cross_arm_attention_mass_layer_4=0.0279 cross_arm_attention_mass_layer_5=0.0355 cross_arm_attention_mass_layer_6=0.0179 cross_arm_attention_mass_layer_7=0.0369 cross_arm_attention_mass_layer_8=0.0183 cross_arm_attention_mass_layer_9=0.0153 cross_arm_comm_gate_layer_0=0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=0.0000 cross_arm_comm_gate_layer_11=0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=0.0000 cross_arm_comm_gate_layer_2=0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=0.0000 cross_arm_comm_gate_layer_5=0.0000 cross_arm_comm_gate_layer_6=0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=0.0000 cross_arm_comm_gate_layer_9=0.0000 grad_action_out=13.0720 grad_cross_arm_comm=3.1067 grad_left_action_in=2.3769 grad_left_expert=31.1244 grad_right_action_in=1.7630 grad_right_expert=27.8917 grad_shared_backbone=36.9946 (22938:train_pytorch.py:882)
79
+
80
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
81
+ 20:12:07.067 [I] debug_step=2 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22938:train_pytorch.py:831)
82
+ 20:12:07.067 [I] debug_step=2 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22938:train_pytorch.py:835)
83
+ 20:12:07.068 [I] debug_step=2 prompt_token_lengths=[76] (22938:train_pytorch.py:838)
84
+ 20:12:07.069 [I] debug_step=2 state_stats min=-0.9415 max=1.0004 mean=-0.0010 std=0.4295 (22938:train_pytorch.py:839)
85
+ 20:12:07.069 [I] debug_step=2 action_stats min=-1.0000 max=1.1367 mean=0.0272 std=0.4576 (22938:train_pytorch.py:842)
86
+ 20:12:07.070 [I] debug_step=2 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22938:train_pytorch.py:845)
87
+ 20:12:07.070 [I] debug_step=2 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22938:train_pytorch.py:849)
88
+ 20:12:07.071 [I] debug_step=2 lr=1.99e-07 grad_norm=10.7247 data_time=0.2263s step_time=0.7585s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22938:train_pytorch.py:854)
89
+ 20:12:07.071 [I] debug_step=2 grad_shared_backbone=9.1973 grad_left_action_in=0.1651 grad_right_action_in=0.1484 grad_left_expert=2.5023 grad_right_expert=2.3935 grad_action_out=4.0770 grad_cross_arm_comm=0.0166 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=-0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0019 cross_arm_attention_mass_layer_2=0.0161 cross_arm_attention_mass_layer_3=0.0029 cross_arm_attention_mass_layer_4=0.0175 cross_arm_attention_mass_layer_5=0.0243 cross_arm_attention_mass_layer_6=0.0074 cross_arm_attention_mass_layer_7=0.0232 cross_arm_attention_mass_layer_8=0.0155 cross_arm_attention_mass_layer_9=0.0135 cross_arm_attention_mass_layer_10=0.0094 cross_arm_attention_mass_layer_11=0.0151 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0053 cross_arm_attention_mass_layer_14=0.0056 cross_arm_attention_mass_layer_15=0.0250 cross_arm_attention_mass_layer_16=0.0356 cross_arm_attention_mass_layer_17=0.0413 (22938:train_pytorch.py:862)
90
+ 20:12:07.072 [I] step=2 loss=1.1389 smoothed_loss=3.5709 lr=1.99e-07 grad_norm=10.7247 step_time=0.7585s data_time=0.2263s it/s=1.028 eta_to_20=17.5s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0019 cross_arm_attention_mass_layer_10=0.0094 cross_arm_attention_mass_layer_11=0.0151 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0053 cross_arm_attention_mass_layer_14=0.0056 cross_arm_attention_mass_layer_15=0.0250 cross_arm_attention_mass_layer_16=0.0356 cross_arm_attention_mass_layer_17=0.0413 cross_arm_attention_mass_layer_2=0.0161 cross_arm_attention_mass_layer_3=0.0029 cross_arm_attention_mass_layer_4=0.0175 cross_arm_attention_mass_layer_5=0.0243 cross_arm_attention_mass_layer_6=0.0074 cross_arm_attention_mass_layer_7=0.0232 cross_arm_attention_mass_layer_8=0.0155 cross_arm_attention_mass_layer_9=0.0135 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=-0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=4.0770 grad_cross_arm_comm=0.0166 grad_left_action_in=0.1651 grad_left_expert=2.5023 grad_right_action_in=0.1484 grad_right_expert=2.3935 grad_shared_backbone=9.1973 (22938:train_pytorch.py:882)
91
+
92
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
93
+ 20:12:07.689 [I] debug_step=3 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22938:train_pytorch.py:831)
94
+ 20:12:07.690 [I] debug_step=3 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22938:train_pytorch.py:835)
95
+ 20:12:07.690 [I] debug_step=3 prompt_token_lengths=[75] (22938:train_pytorch.py:838)
96
+ 20:12:07.691 [I] debug_step=3 state_stats min=-1.0000 max=1.0004 mean=0.0558 std=0.4300 (22938:train_pytorch.py:839)
97
+ 20:12:07.692 [I] debug_step=3 action_stats min=-1.0033 max=1.0004 mean=-0.0658 std=0.4704 (22938:train_pytorch.py:842)
98
+ 20:12:07.692 [I] debug_step=3 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22938:train_pytorch.py:845)
99
+ 20:12:07.693 [I] debug_step=3 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22938:train_pytorch.py:849)
100
+ 20:12:07.693 [I] debug_step=3 lr=2.99e-07 grad_norm=343.6402 data_time=0.1557s step_time=0.4654s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22938:train_pytorch.py:854)
101
+ 20:12:07.694 [I] debug_step=3 grad_shared_backbone=215.2410 grad_left_action_in=4.7969 grad_right_action_in=9.5325 grad_left_expert=72.6238 grad_right_expert=227.5470 grad_action_out=23.7695 grad_cross_arm_comm=3.3548 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0003 cross_arm_attention_mass_layer_1=0.0127 cross_arm_attention_mass_layer_2=0.0275 cross_arm_attention_mass_layer_3=0.0190 cross_arm_attention_mass_layer_4=0.0359 cross_arm_attention_mass_layer_5=0.0454 cross_arm_attention_mass_layer_6=0.0228 cross_arm_attention_mass_layer_7=0.0346 cross_arm_attention_mass_layer_8=0.0149 cross_arm_attention_mass_layer_9=0.0296 cross_arm_attention_mass_layer_10=0.0177 cross_arm_attention_mass_layer_11=0.0230 cross_arm_attention_mass_layer_12=0.0134 cross_arm_attention_mass_layer_13=0.0242 cross_arm_attention_mass_layer_14=0.0109 cross_arm_attention_mass_layer_15=0.0285 cross_arm_attention_mass_layer_16=0.0403 cross_arm_attention_mass_layer_17=0.0268 (22938:train_pytorch.py:862)
102
+ 20:12:07.694 [I] step=3 loss=5.0512 smoothed_loss=3.7189 lr=2.99e-07 grad_norm=343.6402 step_time=0.4654s data_time=0.1557s it/s=1.609 eta_to_20=10.6s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0003 cross_arm_attention_mass_layer_1=0.0127 cross_arm_attention_mass_layer_10=0.0177 cross_arm_attention_mass_layer_11=0.0230 cross_arm_attention_mass_layer_12=0.0134 cross_arm_attention_mass_layer_13=0.0242 cross_arm_attention_mass_layer_14=0.0109 cross_arm_attention_mass_layer_15=0.0285 cross_arm_attention_mass_layer_16=0.0403 cross_arm_attention_mass_layer_17=0.0268 cross_arm_attention_mass_layer_2=0.0275 cross_arm_attention_mass_layer_3=0.0190 cross_arm_attention_mass_layer_4=0.0359 cross_arm_attention_mass_layer_5=0.0454 cross_arm_attention_mass_layer_6=0.0228 cross_arm_attention_mass_layer_7=0.0346 cross_arm_attention_mass_layer_8=0.0149 cross_arm_attention_mass_layer_9=0.0296 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=23.7695 grad_cross_arm_comm=3.3548 grad_left_action_in=4.7969 grad_left_expert=72.6238 grad_right_action_in=9.5325 grad_right_expert=227.5470 grad_shared_backbone=215.2410 (22938:train_pytorch.py:882)
103
+
104
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
105
+ 20:12:08.256 [I] debug_step=4 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22938:train_pytorch.py:831)
106
+ 20:12:08.257 [I] debug_step=4 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22938:train_pytorch.py:835)
107
+ 20:12:08.257 [I] debug_step=4 prompt_token_lengths=[78] (22938:train_pytorch.py:838)
108
+ 20:12:08.258 [I] debug_step=4 state_stats min=-0.7017 max=1.0004 mean=0.0553 std=0.3507 (22938:train_pytorch.py:839)
109
+ 20:12:08.258 [I] debug_step=4 action_stats min=-1.0014 max=1.0004 mean=-0.0683 std=0.4561 (22938:train_pytorch.py:842)
110
+ 20:12:08.259 [I] debug_step=4 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22938:train_pytorch.py:845)
111
+ 20:12:08.259 [I] debug_step=4 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22938:train_pytorch.py:849)
112
+ 20:12:08.260 [I] debug_step=4 lr=3.98e-07 grad_norm=8.7944 data_time=0.1312s step_time=0.4359s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22938:train_pytorch.py:854)
113
+ 20:12:08.260 [I] debug_step=4 grad_shared_backbone=7.5903 grad_left_action_in=0.1438 grad_right_action_in=0.1015 grad_left_expert=2.4058 grad_right_expert=1.2982 grad_action_out=3.3839 grad_cross_arm_comm=0.0147 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0015 cross_arm_attention_mass_layer_2=0.0133 cross_arm_attention_mass_layer_3=0.0026 cross_arm_attention_mass_layer_4=0.0148 cross_arm_attention_mass_layer_5=0.0199 cross_arm_attention_mass_layer_6=0.0062 cross_arm_attention_mass_layer_7=0.0154 cross_arm_attention_mass_layer_8=0.0102 cross_arm_attention_mass_layer_9=0.0086 cross_arm_attention_mass_layer_10=0.0065 cross_arm_attention_mass_layer_11=0.0099 cross_arm_attention_mass_layer_12=0.0010 cross_arm_attention_mass_layer_13=0.0040 cross_arm_attention_mass_layer_14=0.0072 cross_arm_attention_mass_layer_15=0.0227 cross_arm_attention_mass_layer_16=0.0351 cross_arm_attention_mass_layer_17=0.0406 (22938:train_pytorch.py:862)
114
+ 20:12:08.261 [I] step=4 loss=1.1860 smoothed_loss=3.4656 lr=3.98e-07 grad_norm=8.7944 step_time=0.4359s data_time=0.1312s it/s=1.768 eta_to_20=9.1s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0015 cross_arm_attention_mass_layer_10=0.0065 cross_arm_attention_mass_layer_11=0.0099 cross_arm_attention_mass_layer_12=0.0010 cross_arm_attention_mass_layer_13=0.0040 cross_arm_attention_mass_layer_14=0.0072 cross_arm_attention_mass_layer_15=0.0227 cross_arm_attention_mass_layer_16=0.0351 cross_arm_attention_mass_layer_17=0.0406 cross_arm_attention_mass_layer_2=0.0133 cross_arm_attention_mass_layer_3=0.0026 cross_arm_attention_mass_layer_4=0.0148 cross_arm_attention_mass_layer_5=0.0199 cross_arm_attention_mass_layer_6=0.0062 cross_arm_attention_mass_layer_7=0.0154 cross_arm_attention_mass_layer_8=0.0102 cross_arm_attention_mass_layer_9=0.0086 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=3.3839 grad_cross_arm_comm=0.0147 grad_left_action_in=0.1438 grad_left_expert=2.4058 grad_right_action_in=0.1015 grad_right_expert=1.2982 grad_shared_backbone=7.5903 (22938:train_pytorch.py:882)
115
+
116
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
117
+ 20:12:08.933 [I] debug_step=5 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22938:train_pytorch.py:831)
118
+ 20:12:08.934 [I] debug_step=5 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22938:train_pytorch.py:835)
119
+ 20:12:08.934 [I] debug_step=5 prompt_token_lengths=[73] (22938:train_pytorch.py:838)
120
+ 20:12:08.935 [I] debug_step=5 state_stats min=-0.9599 max=1.0004 mean=0.0170 std=0.5364 (22938:train_pytorch.py:839)
121
+ 20:12:08.935 [I] debug_step=5 action_stats min=-1.0392 max=1.0004 mean=-0.0159 std=0.4488 (22938:train_pytorch.py:842)
122
+ 20:12:08.935 [I] debug_step=5 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22938:train_pytorch.py:845)
123
+ 20:12:08.936 [I] debug_step=5 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22938:train_pytorch.py:849)
124
+ 20:12:08.936 [I] debug_step=5 lr=4.98e-07 grad_norm=20.1429 data_time=0.2048s step_time=0.4721s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.30GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.30GB (22938:train_pytorch.py:854)
125
+ 20:12:08.937 [I] debug_step=5 grad_shared_backbone=16.7899 grad_left_action_in=0.2534 grad_right_action_in=0.3335 grad_left_expert=7.9047 grad_right_expert=3.6853 grad_action_out=6.0934 grad_cross_arm_comm=0.0735 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0020 cross_arm_attention_mass_layer_2=0.0178 cross_arm_attention_mass_layer_3=0.0039 cross_arm_attention_mass_layer_4=0.0203 cross_arm_attention_mass_layer_5=0.0294 cross_arm_attention_mass_layer_6=0.0106 cross_arm_attention_mass_layer_7=0.0286 cross_arm_attention_mass_layer_8=0.0175 cross_arm_attention_mass_layer_9=0.0157 cross_arm_attention_mass_layer_10=0.0148 cross_arm_attention_mass_layer_11=0.0181 cross_arm_attention_mass_layer_12=0.0023 cross_arm_attention_mass_layer_13=0.0128 cross_arm_attention_mass_layer_14=0.0072 cross_arm_attention_mass_layer_15=0.0232 cross_arm_attention_mass_layer_16=0.0437 cross_arm_attention_mass_layer_17=0.0451 (22938:train_pytorch.py:862)
126
+ 20:12:08.937 [I] step=5 loss=1.8898 smoothed_loss=3.3081 lr=4.98e-07 grad_norm=20.1429 step_time=0.4721s data_time=0.2048s it/s=1.481 eta_to_20=10.1s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0020 cross_arm_attention_mass_layer_10=0.0148 cross_arm_attention_mass_layer_11=0.0181 cross_arm_attention_mass_layer_12=0.0023 cross_arm_attention_mass_layer_13=0.0128 cross_arm_attention_mass_layer_14=0.0072 cross_arm_attention_mass_layer_15=0.0232 cross_arm_attention_mass_layer_16=0.0437 cross_arm_attention_mass_layer_17=0.0451 cross_arm_attention_mass_layer_2=0.0178 cross_arm_attention_mass_layer_3=0.0039 cross_arm_attention_mass_layer_4=0.0203 cross_arm_attention_mass_layer_5=0.0294 cross_arm_attention_mass_layer_6=0.0106 cross_arm_attention_mass_layer_7=0.0286 cross_arm_attention_mass_layer_8=0.0175 cross_arm_attention_mass_layer_9=0.0157 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=6.0934 grad_cross_arm_comm=0.0735 grad_left_action_in=0.2534 grad_left_expert=7.9047 grad_right_action_in=0.3335 grad_right_expert=3.6853 grad_shared_backbone=16.7899 (22938:train_pytorch.py:882)
127
+
128
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
129
+ 20:12:09.727 [I] step=6 loss=2.2855 smoothed_loss=3.2058 lr=5.98e-07 grad_norm=22.2605 step_time=0.5043s data_time=0.2901s it/s=1.267 eta_to_20=11.1s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0030 cross_arm_attention_mass_layer_10=0.0179 cross_arm_attention_mass_layer_11=0.0219 cross_arm_attention_mass_layer_12=0.0017 cross_arm_attention_mass_layer_13=0.0164 cross_arm_attention_mass_layer_14=0.0065 cross_arm_attention_mass_layer_15=0.0300 cross_arm_attention_mass_layer_16=0.0448 cross_arm_attention_mass_layer_17=0.0482 cross_arm_attention_mass_layer_2=0.0201 cross_arm_attention_mass_layer_3=0.0064 cross_arm_attention_mass_layer_4=0.0234 cross_arm_attention_mass_layer_5=0.0308 cross_arm_attention_mass_layer_6=0.0131 cross_arm_attention_mass_layer_7=0.0312 cross_arm_attention_mass_layer_8=0.0206 cross_arm_attention_mass_layer_9=0.0180 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=7.8420 grad_cross_arm_comm=0.1508 grad_left_action_in=0.2907 grad_left_expert=7.9865 grad_right_action_in=0.5407 grad_right_expert=5.3887 grad_shared_backbone=18.0209 (22938:train_pytorch.py:882)
130
+
131
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
132
+ 20:12:10.423 [I] step=7 loss=1.0335 smoothed_loss=2.9886 lr=6.97e-07 grad_norm=8.7208 step_time=0.4962s data_time=0.1999s it/s=1.439 eta_to_20=9.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0014 cross_arm_attention_mass_layer_10=0.0066 cross_arm_attention_mass_layer_11=0.0060 cross_arm_attention_mass_layer_12=0.0024 cross_arm_attention_mass_layer_13=0.0015 cross_arm_attention_mass_layer_14=0.0062 cross_arm_attention_mass_layer_15=0.0146 cross_arm_attention_mass_layer_16=0.0319 cross_arm_attention_mass_layer_17=0.0417 cross_arm_attention_mass_layer_2=0.0105 cross_arm_attention_mass_layer_3=0.0022 cross_arm_attention_mass_layer_4=0.0130 cross_arm_attention_mass_layer_5=0.0188 cross_arm_attention_mass_layer_6=0.0045 cross_arm_attention_mass_layer_7=0.0127 cross_arm_attention_mass_layer_8=0.0097 cross_arm_attention_mass_layer_9=0.0097 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=4.0753 grad_cross_arm_comm=0.0098 grad_left_action_in=0.1514 grad_left_expert=2.5886 grad_right_action_in=0.0879 grad_right_expert=1.9729 grad_shared_backbone=6.8576 (22938:train_pytorch.py:882)
133
+
134
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
135
+ 20:12:11.020 [I] step=8 loss=2.0034 smoothed_loss=2.8901 lr=7.97e-07 grad_norm=15.7969 step_time=0.4407s data_time=0.1564s it/s=1.677 eta_to_20=7.2s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0027 cross_arm_attention_mass_layer_10=0.0129 cross_arm_attention_mass_layer_11=0.0269 cross_arm_attention_mass_layer_12=0.0032 cross_arm_attention_mass_layer_13=0.0177 cross_arm_attention_mass_layer_14=0.0074 cross_arm_attention_mass_layer_15=0.0309 cross_arm_attention_mass_layer_16=0.0446 cross_arm_attention_mass_layer_17=0.0503 cross_arm_attention_mass_layer_2=0.0196 cross_arm_attention_mass_layer_3=0.0046 cross_arm_attention_mass_layer_4=0.0227 cross_arm_attention_mass_layer_5=0.0319 cross_arm_attention_mass_layer_6=0.0114 cross_arm_attention_mass_layer_7=0.0298 cross_arm_attention_mass_layer_8=0.0194 cross_arm_attention_mass_layer_9=0.0117 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=6.6005 grad_cross_arm_comm=0.1531 grad_left_action_in=0.1726 grad_left_expert=4.6426 grad_right_action_in=0.4530 grad_right_expert=3.8705 grad_shared_backbone=12.4324 (22938:train_pytorch.py:882)
136
+
137
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
138
+ 20:12:11.571 [I] step=9 loss=0.4132 smoothed_loss=2.6424 lr=8.96e-07 grad_norm=3.3497 step_time=0.4161s data_time=0.1347s it/s=1.820 eta_to_20=6.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0008 cross_arm_attention_mass_layer_10=0.0014 cross_arm_attention_mass_layer_11=0.0006 cross_arm_attention_mass_layer_12=0.0028 cross_arm_attention_mass_layer_13=0.0018 cross_arm_attention_mass_layer_14=0.0059 cross_arm_attention_mass_layer_15=0.0078 cross_arm_attention_mass_layer_16=0.0337 cross_arm_attention_mass_layer_17=0.0442 cross_arm_attention_mass_layer_2=0.0015 cross_arm_attention_mass_layer_3=0.0012 cross_arm_attention_mass_layer_4=0.0019 cross_arm_attention_mass_layer_5=0.0036 cross_arm_attention_mass_layer_6=0.0013 cross_arm_attention_mass_layer_7=0.0022 cross_arm_attention_mass_layer_8=0.0006 cross_arm_attention_mass_layer_9=0.0052 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=1.7915 grad_cross_arm_comm=0.0012 grad_left_action_in=0.0692 grad_left_expert=1.0033 grad_right_action_in=0.0554 grad_right_expert=0.7293 grad_shared_backbone=2.5249 (22938:train_pytorch.py:882)
139
+
140
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
141
+ 20:12:12.422 [I] step=10 loss=0.6162 smoothed_loss=2.4397 lr=9.96e-07 grad_norm=5.5674 step_time=0.6599s data_time=0.1905s it/s=1.178 eta_to_20=8.5s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0014 cross_arm_attention_mass_layer_10=0.0024 cross_arm_attention_mass_layer_11=0.0047 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0018 cross_arm_attention_mass_layer_14=0.0062 cross_arm_attention_mass_layer_15=0.0094 cross_arm_attention_mass_layer_16=0.0283 cross_arm_attention_mass_layer_17=0.0357 cross_arm_attention_mass_layer_2=0.0074 cross_arm_attention_mass_layer_3=0.0016 cross_arm_attention_mass_layer_4=0.0081 cross_arm_attention_mass_layer_5=0.0156 cross_arm_attention_mass_layer_6=0.0028 cross_arm_attention_mass_layer_7=0.0050 cross_arm_attention_mass_layer_8=0.0040 cross_arm_attention_mass_layer_9=0.0045 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=-0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=2.2079 grad_cross_arm_comm=0.0071 grad_left_action_in=0.0841 grad_left_expert=1.2018 grad_right_action_in=0.0868 grad_right_expert=1.2814 grad_shared_backbone=4.7763 (22938:train_pytorch.py:882)
142
+
143
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
144
+ 20:12:12.957 [I] step=11 loss=0.9030 smoothed_loss=2.2861 lr=1.10e-06 grad_norm=7.2282 step_time=0.4104s data_time=0.1251s it/s=1.872 eta_to_20=4.8s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0015 cross_arm_attention_mass_layer_10=0.0064 cross_arm_attention_mass_layer_11=0.0098 cross_arm_attention_mass_layer_12=0.0013 cross_arm_attention_mass_layer_13=0.0031 cross_arm_attention_mass_layer_14=0.0072 cross_arm_attention_mass_layer_15=0.0208 cross_arm_attention_mass_layer_16=0.0355 cross_arm_attention_mass_layer_17=0.0421 cross_arm_attention_mass_layer_2=0.0136 cross_arm_attention_mass_layer_3=0.0023 cross_arm_attention_mass_layer_4=0.0152 cross_arm_attention_mass_layer_5=0.0219 cross_arm_attention_mass_layer_6=0.0054 cross_arm_attention_mass_layer_7=0.0144 cross_arm_attention_mass_layer_8=0.0131 cross_arm_attention_mass_layer_9=0.0082 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=-0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=3.3357 grad_cross_arm_comm=0.0099 grad_left_action_in=0.1355 grad_left_expert=2.0379 grad_right_action_in=0.0836 grad_right_expert=1.1722 grad_shared_backbone=5.8293 (22938:train_pytorch.py:882)
145
+
146
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
147
+ 20:12:13.628 [I] step=12 loss=0.7531 smoothed_loss=2.1328 lr=1.20e-06 grad_norm=6.0473 step_time=0.4968s data_time=0.1739s it/s=1.493 eta_to_20=5.4s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0012 cross_arm_attention_mass_layer_10=0.0078 cross_arm_attention_mass_layer_11=0.0121 cross_arm_attention_mass_layer_12=0.0032 cross_arm_attention_mass_layer_13=0.0032 cross_arm_attention_mass_layer_14=0.0048 cross_arm_attention_mass_layer_15=0.0136 cross_arm_attention_mass_layer_16=0.0331 cross_arm_attention_mass_layer_17=0.0404 cross_arm_attention_mass_layer_2=0.0127 cross_arm_attention_mass_layer_3=0.0020 cross_arm_attention_mass_layer_4=0.0138 cross_arm_attention_mass_layer_5=0.0221 cross_arm_attention_mass_layer_6=0.0055 cross_arm_attention_mass_layer_7=0.0174 cross_arm_attention_mass_layer_8=0.0100 cross_arm_attention_mass_layer_9=0.0094 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=2.8673 grad_cross_arm_comm=0.0090 grad_left_action_in=0.1128 grad_left_expert=1.8561 grad_right_action_in=0.0739 grad_right_expert=1.0243 grad_shared_backbone=4.8443 (22938:train_pytorch.py:882)
148
+
149
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
150
+ 20:12:14.427 [I] step=13 loss=3.7746 smoothed_loss=2.2970 lr=1.29e-06 grad_norm=206.8044 step_time=0.5601s data_time=0.2394s it/s=1.252 eta_to_20=5.6s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0003 cross_arm_attention_mass_layer_1=0.0128 cross_arm_attention_mass_layer_10=0.0240 cross_arm_attention_mass_layer_11=0.0241 cross_arm_attention_mass_layer_12=0.0213 cross_arm_attention_mass_layer_13=0.0213 cross_arm_attention_mass_layer_14=0.0164 cross_arm_attention_mass_layer_15=0.0265 cross_arm_attention_mass_layer_16=0.0367 cross_arm_attention_mass_layer_17=0.0289 cross_arm_attention_mass_layer_2=0.0282 cross_arm_attention_mass_layer_3=0.0184 cross_arm_attention_mass_layer_4=0.0365 cross_arm_attention_mass_layer_5=0.0441 cross_arm_attention_mass_layer_6=0.0238 cross_arm_attention_mass_layer_7=0.0371 cross_arm_attention_mass_layer_8=0.0137 cross_arm_attention_mass_layer_9=0.0293 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=15.4957 grad_cross_arm_comm=2.1022 grad_left_action_in=2.3745 grad_left_expert=37.1536 grad_right_action_in=5.2568 grad_right_expert=138.8291 grad_shared_backbone=127.7336 (22938:train_pytorch.py:882)
151
+
152
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
153
+ 20:12:15.255 [I] step=14 loss=1.2933 smoothed_loss=2.1966 lr=1.39e-06 grad_norm=7.9182 step_time=0.5738s data_time=0.2541s it/s=1.210 eta_to_20=5.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0015 cross_arm_attention_mass_layer_10=0.0079 cross_arm_attention_mass_layer_11=0.0120 cross_arm_attention_mass_layer_12=0.0016 cross_arm_attention_mass_layer_13=0.0036 cross_arm_attention_mass_layer_14=0.0047 cross_arm_attention_mass_layer_15=0.0131 cross_arm_attention_mass_layer_16=0.0244 cross_arm_attention_mass_layer_17=0.0419 cross_arm_attention_mass_layer_2=0.0129 cross_arm_attention_mass_layer_3=0.0022 cross_arm_attention_mass_layer_4=0.0152 cross_arm_attention_mass_layer_5=0.0233 cross_arm_attention_mass_layer_6=0.0067 cross_arm_attention_mass_layer_7=0.0161 cross_arm_attention_mass_layer_8=0.0092 cross_arm_attention_mass_layer_9=0.0097 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=4.2052 grad_cross_arm_comm=0.0107 grad_left_action_in=0.1570 grad_left_expert=2.3411 grad_right_action_in=0.1025 grad_right_expert=1.1691 grad_shared_backbone=6.0836 (22938:train_pytorch.py:882)
154
+
155
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
156
+ 20:12:16.034 [I] step=15 loss=3.1068 smoothed_loss=2.2876 lr=1.49e-06 grad_norm=24.4182 step_time=0.5474s data_time=0.2314s it/s=1.286 eta_to_20=3.9s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0001 cross_arm_attention_mass_layer_1=0.0033 cross_arm_attention_mass_layer_10=0.0154 cross_arm_attention_mass_layer_11=0.0284 cross_arm_attention_mass_layer_12=0.0046 cross_arm_attention_mass_layer_13=0.0187 cross_arm_attention_mass_layer_14=0.0121 cross_arm_attention_mass_layer_15=0.0370 cross_arm_attention_mass_layer_16=0.0460 cross_arm_attention_mass_layer_17=0.0516 cross_arm_attention_mass_layer_2=0.0206 cross_arm_attention_mass_layer_3=0.0064 cross_arm_attention_mass_layer_4=0.0239 cross_arm_attention_mass_layer_5=0.0299 cross_arm_attention_mass_layer_6=0.0143 cross_arm_attention_mass_layer_7=0.0349 cross_arm_attention_mass_layer_8=0.0213 cross_arm_attention_mass_layer_9=0.0171 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=9.3484 grad_cross_arm_comm=0.3843 grad_left_action_in=0.3015 grad_left_expert=7.0086 grad_right_action_in=0.6660 grad_right_expert=6.4185 grad_shared_backbone=18.8039 (22938:train_pytorch.py:882)
157
+
158
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
159
+ 20:12:16.810 [I] step=16 loss=0.8710 smoothed_loss=2.1460 lr=1.59e-06 grad_norm=7.5162 step_time=0.5638s data_time=0.2117s it/s=1.292 eta_to_20=3.1s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0016 cross_arm_attention_mass_layer_10=0.0051 cross_arm_attention_mass_layer_11=0.0114 cross_arm_attention_mass_layer_12=0.0017 cross_arm_attention_mass_layer_13=0.0062 cross_arm_attention_mass_layer_14=0.0073 cross_arm_attention_mass_layer_15=0.0221 cross_arm_attention_mass_layer_16=0.0370 cross_arm_attention_mass_layer_17=0.0436 cross_arm_attention_mass_layer_2=0.0138 cross_arm_attention_mass_layer_3=0.0022 cross_arm_attention_mass_layer_4=0.0152 cross_arm_attention_mass_layer_5=0.0195 cross_arm_attention_mass_layer_6=0.0056 cross_arm_attention_mass_layer_7=0.0154 cross_arm_attention_mass_layer_8=0.0132 cross_arm_attention_mass_layer_9=0.0103 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=2.7344 grad_cross_arm_comm=0.0228 grad_left_action_in=0.1118 grad_left_expert=2.2761 grad_right_action_in=0.1234 grad_right_expert=1.1808 grad_shared_backbone=6.4124 (22938:train_pytorch.py:882)
160
+
161
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
162
+ 20:12:17.396 [I] step=17 loss=1.7002 smoothed_loss=2.1014 lr=1.69e-06 grad_norm=14.0785 step_time=0.4252s data_time=0.1614s it/s=1.708 eta_to_20=1.8s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0021 cross_arm_attention_mass_layer_10=0.0103 cross_arm_attention_mass_layer_11=0.0197 cross_arm_attention_mass_layer_12=0.0029 cross_arm_attention_mass_layer_13=0.0103 cross_arm_attention_mass_layer_14=0.0067 cross_arm_attention_mass_layer_15=0.0163 cross_arm_attention_mass_layer_16=0.0436 cross_arm_attention_mass_layer_17=0.0446 cross_arm_attention_mass_layer_2=0.0162 cross_arm_attention_mass_layer_3=0.0030 cross_arm_attention_mass_layer_4=0.0192 cross_arm_attention_mass_layer_5=0.0268 cross_arm_attention_mass_layer_6=0.0092 cross_arm_attention_mass_layer_7=0.0242 cross_arm_attention_mass_layer_8=0.0146 cross_arm_attention_mass_layer_9=0.0095 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=4.2605 grad_cross_arm_comm=0.0625 grad_left_action_in=0.1989 grad_left_expert=4.9518 grad_right_action_in=0.2156 grad_right_expert=2.1764 grad_shared_backbone=12.0796 (22938:train_pytorch.py:882)
163
+
164
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
165
+ 20:12:18.392 [I] step=18 loss=0.4844 smoothed_loss=1.9397 lr=1.79e-06 grad_norm=3.3459 step_time=0.6297s data_time=0.3660s it/s=1.005 eta_to_20=2.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0008 cross_arm_attention_mass_layer_10=0.0016 cross_arm_attention_mass_layer_11=0.0014 cross_arm_attention_mass_layer_12=0.0034 cross_arm_attention_mass_layer_13=0.0007 cross_arm_attention_mass_layer_14=0.0054 cross_arm_attention_mass_layer_15=0.0063 cross_arm_attention_mass_layer_16=0.0319 cross_arm_attention_mass_layer_17=0.0418 cross_arm_attention_mass_layer_2=0.0027 cross_arm_attention_mass_layer_3=0.0013 cross_arm_attention_mass_layer_4=0.0035 cross_arm_attention_mass_layer_5=0.0058 cross_arm_attention_mass_layer_6=0.0015 cross_arm_attention_mass_layer_7=0.0028 cross_arm_attention_mass_layer_8=0.0019 cross_arm_attention_mass_layer_9=0.0049 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=-0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=1.9561 grad_cross_arm_comm=0.0017 grad_left_action_in=0.0746 grad_left_expert=1.1140 grad_right_action_in=0.0388 grad_right_expert=0.5290 grad_shared_backbone=2.3985 (22938:train_pytorch.py:882)
166
+
167
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
168
+ 20:12:19.239 [I] step=19 loss=0.7633 smoothed_loss=1.8220 lr=1.89e-06 grad_norm=7.1468 step_time=0.5757s data_time=0.2714s it/s=1.182 eta_to_20=0.8s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0015 cross_arm_attention_mass_layer_10=0.0069 cross_arm_attention_mass_layer_11=0.0093 cross_arm_attention_mass_layer_12=0.0016 cross_arm_attention_mass_layer_13=0.0034 cross_arm_attention_mass_layer_14=0.0046 cross_arm_attention_mass_layer_15=0.0166 cross_arm_attention_mass_layer_16=0.0297 cross_arm_attention_mass_layer_17=0.0418 cross_arm_attention_mass_layer_2=0.0130 cross_arm_attention_mass_layer_3=0.0026 cross_arm_attention_mass_layer_4=0.0156 cross_arm_attention_mass_layer_5=0.0208 cross_arm_attention_mass_layer_6=0.0062 cross_arm_attention_mass_layer_7=0.0164 cross_arm_attention_mass_layer_8=0.0124 cross_arm_attention_mass_layer_9=0.0115 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=-0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=3.7548 grad_cross_arm_comm=0.0125 grad_left_action_in=0.1160 grad_left_expert=2.3520 grad_right_action_in=0.0799 grad_right_expert=1.3128 grad_shared_backbone=5.3838 (22938:train_pytorch.py:882)
169
+
170
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
171
+ 20:12:19.905 [I] step=20 loss=0.5943 smoothed_loss=1.6993 lr=1.99e-06 grad_norm=6.2792 step_time=0.4954s data_time=0.1707s it/s=1.504 eta_to_20=0.0s max_cuda_memory=76.13GB cross_arm_attention_mass_layer_0=0.0000 cross_arm_attention_mass_layer_1=0.0012 cross_arm_attention_mass_layer_10=0.0050 cross_arm_attention_mass_layer_11=0.0026 cross_arm_attention_mass_layer_12=0.0021 cross_arm_attention_mass_layer_13=0.0009 cross_arm_attention_mass_layer_14=0.0060 cross_arm_attention_mass_layer_15=0.0119 cross_arm_attention_mass_layer_16=0.0308 cross_arm_attention_mass_layer_17=0.0412 cross_arm_attention_mass_layer_2=0.0054 cross_arm_attention_mass_layer_3=0.0020 cross_arm_attention_mass_layer_4=0.0084 cross_arm_attention_mass_layer_5=0.0116 cross_arm_attention_mass_layer_6=0.0029 cross_arm_attention_mass_layer_7=0.0066 cross_arm_attention_mass_layer_8=0.0051 cross_arm_attention_mass_layer_9=0.0094 cross_arm_comm_gate_layer_0=-0.0000 cross_arm_comm_gate_layer_1=-0.0000 cross_arm_comm_gate_layer_10=-0.0000 cross_arm_comm_gate_layer_11=-0.0000 cross_arm_comm_gate_layer_12=0.0000 cross_arm_comm_gate_layer_13=0.0000 cross_arm_comm_gate_layer_14=0.0000 cross_arm_comm_gate_layer_15=0.0000 cross_arm_comm_gate_layer_16=0.0000 cross_arm_comm_gate_layer_17=-0.0000 cross_arm_comm_gate_layer_2=-0.0000 cross_arm_comm_gate_layer_3=-0.0000 cross_arm_comm_gate_layer_4=-0.0000 cross_arm_comm_gate_layer_5=-0.0000 cross_arm_comm_gate_layer_6=-0.0000 cross_arm_comm_gate_layer_7=-0.0000 cross_arm_comm_gate_layer_8=-0.0000 cross_arm_comm_gate_layer_9=-0.0000 grad_action_out=2.7963 grad_cross_arm_comm=0.0044 grad_left_action_in=0.0861 grad_left_expert=1.9694 grad_right_action_in=0.0578 grad_right_expert=1.3971 grad_shared_backbone=5.0478 (22938:train_pytorch.py:882)
172
+ 20:19:41.020 [I] Saved checkpoint at step 20 -> /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k/split_communicating_real_train20/20 (22938:train_pytorch.py:378)
173
+
openpi/run_logs/split_independent_real_smoke3_r2.log ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 19:45:11.253 [I] Created experiment checkpoint directory: /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2 (20567:train_pytorch.py:533)
2
+ 19:45:11.254 [I] Using batch size per GPU: 1 (total batch size across 1 GPUs: 1) (20567:train_pytorch.py:552)
3
+ 19:45:11.330 [I] Loaded norm stats from /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/lsnu/twin_dual_push_128_train (20567:config.py:234)
4
+ 19:45:11.331 [I] data_config: DataConfig(repo_id='lsnu/twin_dual_push_128_train', asset_id='lsnu/twin_dual_push_128_train', norm_stats={'state': NormStats(mean=array([ 0.10604009, 0.20956482, 0.09184283, -1.98801565, -0.04930164,
5
+ 2.20065784, 1.07595289, 0.52742052, 0.01585805, 0.08288047,
6
+ -0.06887393, -1.906394 , 0.04810138, 2.01086807, -0.92902797,
7
+ 0.8440811 ]), std=array([0.09207697, 0.31317395, 0.08127229, 0.53812712, 0.06093267,
8
+ 0.51205784, 0.22527155, 0.49924755, 0.20230208, 0.31408131,
9
+ 0.21665592, 0.5264315 , 0.20170984, 0.4745712 , 1.17861438,
10
+ 0.36277843]), q01=array([-5.00321221e-06, -3.88026012e-01, -2.23782954e-05, -2.98962682e+00,
11
+ -2.38592355e-01, 1.22146201e+00, 7.85383821e-01, 0.00000000e+00,
12
+ -6.15615927e-01, -4.14941930e-01, -9.43696350e-01, -2.88397729e+00,
13
+ -9.05083556e-01, 1.22148895e+00, -2.79564499e+00, 0.00000000e+00]), q99=array([ 0.31251293, 0.86546916, 0.35174239, -0.87634897, 0.05212194,
14
+ 2.97208117, 1.64465171, 0.9998 , 0.7670313 , 0.96073459,
15
+ 0.68710467, -0.87498123, 0.35838486, 2.9773227 , 0.78477909,
16
+ 0.9998 ])), 'actions': NormStats(mean=array([ 0.03630241, 0.09624442, 0.01367408, -0.2224988 , -0.02762174,
17
+ 0.27498844, 0.0892187 , 0.45650524, -0.00378086, 0.09113847,
18
+ -0.00376227, -0.22537093, 0.00826233, 0.26799494, -0.57452869,
19
+ 0.7731654 ]), std=array([0.04995174, 0.29268014, 0.06852161, 0.3647725 , 0.07012808,
20
+ 0.27129024, 0.11329207, 0.4981046 , 0.0917461 , 0.22704004,
21
+ 0.1069391 , 0.2572591 , 0.11801817, 0.1235588 , 0.35835782,
22
+ 0.41878474]), q01=array([-5.86206436e-04, -3.88117499e-01, -2.55800724e-01, -8.34769463e-01,
23
+ -3.51454727e-01, -1.54787922e-03, -5.81741333e-04, 0.00000000e+00,
24
+ -2.64436970e-01, -3.51582764e-01, -3.69693995e-01, -7.30919549e-01,
25
+ -3.35441585e-01, -6.62303925e-04, -9.34731126e-01, 0.00000000e+00]), q99=array([0.20790743, 0.81198567, 0.19612836, 0.33958174, 0.05568643,
26
+ 0.75265345, 0.425256 , 0.9998 , 0.2558236 , 0.58901345,
27
+ 0.35822071, 0.18567593, 0.44035054, 0.49966629, 0.12655233,
28
+ 0.9998 ]))}, repack_transforms=Group(inputs=[RepackTransform(structure={'images': {'cam_high': 'front_image', 'cam_left_wrist': 'wrist_left_image', 'cam_right_wrist': 'wrist_right_image'}, 'state': 'state', 'actions': 'action', 'prompt': 'task'})], outputs=()), data_transforms=Group(inputs=[AlohaInputs(adapt_to_pi=False)], outputs=[]), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x79458ad85b50>, discrete_state_input=True), PackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))], outputs=[UnpackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))]), use_quantile_norm=True, action_sequence_keys=('action',), prompt_from_task=False, rlds_data_dir=None, action_space=None, datasets=()) (20567:data_loader.py:284)
29
+ 19:45:16.791 [I] JAX version 0.5.3 available. (20567:config.py:125)
30
+ 19:45:40.542 [I] Using existing local LeRobot dataset mirror for lsnu/twin_dual_push_128_train: /workspace/lerobot/lsnu/twin_dual_push_128_train (20567:data_loader.py:148)
31
+ 19:45:40.654 [W] 'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder (20567:video_utils.py:36)
32
+ 19:46:47.372 [I] local_batch_size: 1 (20567:data_loader.py:365)
33
+ 19:50:09.799 [I] Enabled gradient checkpointing for PI0Pytorch model (20567:pi0_pytorch.py:138)
34
+ 19:50:09.802 [I] Enabled gradient checkpointing for memory optimization (20567:train_pytorch.py:624)
35
+ 19:50:09.803 [I] Step 0 (after_model_creation): GPU memory - allocated: 17.23GB, reserved: 17.23GB, free: 0.00GB, peak_allocated: 17.23GB, peak_reserved: 17.23GB (20567:train_pytorch.py:493)
36
+ 19:50:09.804 [I] Loading weights from: /workspace/checkpoints/pi05_base_split_independent_packed_from_single (20567:train_pytorch.py:653)
37
+ 19:50:13.559 [I] Weight loading missing key count: 0 (20567:train_pytorch.py:657)
38
+ 19:50:13.560 [I] Weight loading missing keys: set() (20567:train_pytorch.py:658)
39
+ 19:50:13.560 [I] Weight loading unexpected key count: 0 (20567:train_pytorch.py:659)
40
+ 19:50:13.560 [I] Weight loading unexpected keys: [] (20567:train_pytorch.py:660)
41
+ 19:50:13.560 [I] Loaded PyTorch weights from /workspace/checkpoints/pi05_base_split_independent_packed_from_single (20567:train_pytorch.py:661)
42
+ 19:50:13.565 [I] Running on: 963c158043aa | world_size=1 (20567:train_pytorch.py:701)
43
+ 19:50:13.565 [I] Training config: batch_size=1, effective_batch_size=1, num_train_steps=3 (20567:train_pytorch.py:702)
44
+ 19:50:13.565 [I] Memory optimizations: gradient_checkpointing=True (20567:train_pytorch.py:705)
45
+ 19:50:13.566 [I] DDP settings: find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True (20567:train_pytorch.py:706)
46
+ 19:50:13.566 [I] LR schedule: warmup=250, peak_lr=2.50e-05, decay_steps=5000, end_lr=2.50e-06 (20567:train_pytorch.py:707)
47
+ 19:50:13.567 [I] Optimizer: AdamW, weight_decay=1e-10, clip_norm=1.0 (20567:train_pytorch.py:710)
48
+ 19:50:13.567 [I] EMA is not supported for PyTorch training (20567:train_pytorch.py:713)
49
+ 19:50:13.567 [I] Training precision: float32 (20567:train_pytorch.py:714)
50
+ 19:50:13.576 [I] Resolved config name: pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k (20567:train_pytorch.py:308)
51
+ 19:50:13.576 [I] Dataset repo_id: lsnu/twin_dual_push_128_train (20567:train_pytorch.py:309)
52
+ 19:50:13.577 [I] Norm-stats file path: /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json (20567:train_pytorch.py:310)
53
+ 19:50:13.577 [I] Norm-stats summary: {'keys': ['actions', 'state'], 'state_mean_len': 16, 'state_std_len': 16, 'actions_mean_len': 16, 'actions_std_len': 16} (20567:train_pytorch.py:311)
54
+ 19:50:13.578 [I] Checkpoint source path: /workspace/checkpoints/pi05_base_split_independent_packed_from_single (20567:train_pytorch.py:312)
55
+ 19:50:13.578 [I] Model type: split_independent (20567:train_pytorch.py:313)
56
+ 19:50:13.578 [I] Packed transforms active: True (20567:train_pytorch.py:314)
57
+ 19:50:13.579 [I] World size: 1 (20567:train_pytorch.py:315)
58
+ 19:50:13.579 [I] Batch size: local=1, global=1 (20567:train_pytorch.py:316)
59
+ 19:50:13.580 [I] num_workers: 0 (20567:train_pytorch.py:317)
60
+ 19:50:13.580 [I] Precision: float32 (20567:train_pytorch.py:318)
61
+ 19:50:13.580 [I] LR schedule summary: warmup_steps=250, peak_lr=2.50e-05, decay_steps=5000, decay_lr=2.50e-06 (20567:train_pytorch.py:319)
62
+ 19:50:13.581 [I] Save/log intervals: save_interval=3, log_interval=1 (20567:train_pytorch.py:326)
63
+ 19:50:13.581 [I] Action-loss mask: (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (20567:train_pytorch.py:327)
64
+ 19:50:13.581 [I] Active mask dims: [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] (20567:train_pytorch.py:328)
65
+ 19:50:13.582 [I] Masked dims: [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] (20567:train_pytorch.py:329)
66
+ 19:50:13.582 [I] Gradient bucket diagnostics: left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm (20567:train_pytorch.py:722)
67
+
68
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
69
+ 19:50:15.125 [I] debug_step=1 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (20567:train_pytorch.py:831)
70
+ 19:50:15.126 [I] debug_step=1 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (20567:train_pytorch.py:835)
71
+ 19:50:15.126 [I] debug_step=1 prompt_token_lengths=[75] (20567:train_pytorch.py:838)
72
+ 19:50:15.127 [I] debug_step=1 state_stats min=-1.0000 max=1.0004 mean=0.0112 std=0.3876 (20567:train_pytorch.py:839)
73
+ 19:50:15.127 [I] debug_step=1 action_stats min=-1.0016 max=1.0004 mean=-0.0454 std=0.4716 (20567:train_pytorch.py:842)
74
+ 19:50:15.128 [I] debug_step=1 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (20567:train_pytorch.py:845)
75
+ 19:50:15.143 [I] debug_step=1 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (20567:train_pytorch.py:849)
76
+ 19:50:15.143 [I] debug_step=1 lr=9.96e-08 grad_norm=31.4779 data_time=0.2101s step_time=1.2943s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.25GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.25GB (20567:train_pytorch.py:854)
77
+ 19:50:15.144 [I] debug_step=1 grad_shared_backbone=25.5606 grad_left_action_in=0.2318 grad_right_action_in=0.9885 grad_left_expert=5.5978 grad_right_expert=12.3518 grad_action_out=9.6154 (20567:train_pytorch.py:862)
78
+ 19:50:15.144 [I] step=1 loss=2.6238 smoothed_loss=2.6238 lr=9.96e-08 grad_norm=31.4779 step_time=1.2943s data_time=0.2101s it/s=0.633 eta_to_3=3.2s max_cuda_memory=76.13GB grad_action_out=9.6154 grad_left_action_in=0.2318 grad_left_expert=5.5978 grad_right_action_in=0.9885 grad_right_expert=12.3518 grad_shared_backbone=25.5606 (20567:train_pytorch.py:882)
79
+
80
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
81
+ 19:50:16.012 [I] debug_step=2 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (20567:train_pytorch.py:831)
82
+ 19:50:16.013 [I] debug_step=2 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (20567:train_pytorch.py:835)
83
+ 19:50:16.013 [I] debug_step=2 prompt_token_lengths=[76] (20567:train_pytorch.py:838)
84
+ 19:50:16.014 [I] debug_step=2 state_stats min=-0.9415 max=1.0004 mean=-0.0010 std=0.4295 (20567:train_pytorch.py:839)
85
+ 19:50:16.015 [I] debug_step=2 action_stats min=-1.0000 max=1.1367 mean=0.0272 std=0.4576 (20567:train_pytorch.py:842)
86
+ 19:50:16.016 [I] debug_step=2 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (20567:train_pytorch.py:845)
87
+ 19:50:16.016 [I] debug_step=2 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (20567:train_pytorch.py:849)
88
+ 19:50:16.017 [I] debug_step=2 lr=1.99e-07 grad_norm=12.2770 data_time=0.2123s step_time=0.6695s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (20567:train_pytorch.py:854)
89
+ 19:50:16.017 [I] debug_step=2 grad_shared_backbone=10.3527 grad_left_action_in=0.1586 grad_right_action_in=0.1584 grad_left_expert=2.8415 grad_right_expert=4.0156 grad_action_out=4.1478 (20567:train_pytorch.py:862)
90
+ 19:50:16.018 [I] step=2 loss=1.1717 smoothed_loss=2.4786 lr=1.99e-07 grad_norm=12.2770 step_time=0.6695s data_time=0.2123s it/s=1.146 eta_to_3=0.9s max_cuda_memory=76.13GB grad_action_out=4.1478 grad_left_action_in=0.1586 grad_left_expert=2.8415 grad_right_action_in=0.1584 grad_right_expert=4.0156 grad_shared_backbone=10.3527 (20567:train_pytorch.py:882)
91
+
92
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
93
+ 19:50:16.906 [I] debug_step=3 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (20567:train_pytorch.py:831)
94
+ 19:50:16.907 [I] debug_step=3 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (20567:train_pytorch.py:835)
95
+ 19:50:16.908 [I] debug_step=3 prompt_token_lengths=[75] (20567:train_pytorch.py:838)
96
+ 19:50:16.908 [I] debug_step=3 state_stats min=-1.0000 max=1.0004 mean=0.0558 std=0.4300 (20567:train_pytorch.py:839)
97
+ 19:50:16.908 [I] debug_step=3 action_stats min=-1.0033 max=1.0004 mean=-0.0658 std=0.4704 (20567:train_pytorch.py:842)
98
+ 19:50:16.909 [I] debug_step=3 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (20567:train_pytorch.py:845)
99
+ 19:50:16.910 [I] debug_step=3 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (20567:train_pytorch.py:849)
100
+ 19:50:16.910 [I] debug_step=3 lr=2.99e-07 grad_norm=15.1079 data_time=0.2612s step_time=0.6330s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (20567:train_pytorch.py:854)
101
+ 19:50:16.911 [I] debug_step=3 grad_shared_backbone=8.6850 grad_left_action_in=0.2570 grad_right_action_in=0.3869 grad_left_expert=4.4422 grad_right_expert=10.5777 grad_action_out=3.5502 (20567:train_pytorch.py:862)
102
+ 19:50:16.911 [I] step=3 loss=0.9128 smoothed_loss=2.3220 lr=2.99e-07 grad_norm=15.1079 step_time=0.6330s data_time=0.2612s it/s=1.120 eta_to_3=0.0s max_cuda_memory=76.13GB grad_action_out=3.5502 grad_left_action_in=0.2570 grad_left_expert=4.4422 grad_right_action_in=0.3869 grad_right_expert=10.5777 grad_shared_backbone=8.6850 (20567:train_pytorch.py:882)
103
+ 19:53:54.052 [I] Saved checkpoint at step 3 -> /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_smoke3_r2/3 (20567:train_pytorch.py:378)
104
+
openpi/run_logs/split_independent_real_train20.log ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 20:03:03.080 [I] Created experiment checkpoint directory: /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20 (22934:train_pytorch.py:533)
2
+ 20:03:03.082 [I] Using batch size per GPU: 1 (total batch size across 1 GPUs: 1) (22934:train_pytorch.py:552)
3
+ 20:03:03.183 [I] Loaded norm stats from /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/lsnu/twin_dual_push_128_train (22934:config.py:234)
4
+ 20:03:03.185 [I] data_config: DataConfig(repo_id='lsnu/twin_dual_push_128_train', asset_id='lsnu/twin_dual_push_128_train', norm_stats={'state': NormStats(mean=array([ 0.10604009, 0.20956482, 0.09184283, -1.98801565, -0.04930164,
5
+ 2.20065784, 1.07595289, 0.52742052, 0.01585805, 0.08288047,
6
+ -0.06887393, -1.906394 , 0.04810138, 2.01086807, -0.92902797,
7
+ 0.8440811 ]), std=array([0.09207697, 0.31317395, 0.08127229, 0.53812712, 0.06093267,
8
+ 0.51205784, 0.22527155, 0.49924755, 0.20230208, 0.31408131,
9
+ 0.21665592, 0.5264315 , 0.20170984, 0.4745712 , 1.17861438,
10
+ 0.36277843]), q01=array([-5.00321221e-06, -3.88026012e-01, -2.23782954e-05, -2.98962682e+00,
11
+ -2.38592355e-01, 1.22146201e+00, 7.85383821e-01, 0.00000000e+00,
12
+ -6.15615927e-01, -4.14941930e-01, -9.43696350e-01, -2.88397729e+00,
13
+ -9.05083556e-01, 1.22148895e+00, -2.79564499e+00, 0.00000000e+00]), q99=array([ 0.31251293, 0.86546916, 0.35174239, -0.87634897, 0.05212194,
14
+ 2.97208117, 1.64465171, 0.9998 , 0.7670313 , 0.96073459,
15
+ 0.68710467, -0.87498123, 0.35838486, 2.9773227 , 0.78477909,
16
+ 0.9998 ])), 'actions': NormStats(mean=array([ 0.03630241, 0.09624442, 0.01367408, -0.2224988 , -0.02762174,
17
+ 0.27498844, 0.0892187 , 0.45650524, -0.00378086, 0.09113847,
18
+ -0.00376227, -0.22537093, 0.00826233, 0.26799494, -0.57452869,
19
+ 0.7731654 ]), std=array([0.04995174, 0.29268014, 0.06852161, 0.3647725 , 0.07012808,
20
+ 0.27129024, 0.11329207, 0.4981046 , 0.0917461 , 0.22704004,
21
+ 0.1069391 , 0.2572591 , 0.11801817, 0.1235588 , 0.35835782,
22
+ 0.41878474]), q01=array([-5.86206436e-04, -3.88117499e-01, -2.55800724e-01, -8.34769463e-01,
23
+ -3.51454727e-01, -1.54787922e-03, -5.81741333e-04, 0.00000000e+00,
24
+ -2.64436970e-01, -3.51582764e-01, -3.69693995e-01, -7.30919549e-01,
25
+ -3.35441585e-01, -6.62303925e-04, -9.34731126e-01, 0.00000000e+00]), q99=array([0.20790743, 0.81198567, 0.19612836, 0.33958174, 0.05568643,
26
+ 0.75265345, 0.425256 , 0.9998 , 0.2558236 , 0.58901345,
27
+ 0.35822071, 0.18567593, 0.44035054, 0.49966629, 0.12655233,
28
+ 0.9998 ]))}, repack_transforms=Group(inputs=[RepackTransform(structure={'images': {'cam_high': 'front_image', 'cam_left_wrist': 'wrist_left_image', 'cam_right_wrist': 'wrist_right_image'}, 'state': 'state', 'actions': 'action', 'prompt': 'task'})], outputs=()), data_transforms=Group(inputs=[AlohaInputs(adapt_to_pi=False)], outputs=[]), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x721cdf0dd610>, discrete_state_input=True), PackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))], outputs=[UnpackPerArmBlocks(real_arm_dims=(8, 8), block_dims=(16, 16))]), use_quantile_norm=True, action_sequence_keys=('action',), prompt_from_task=False, rlds_data_dir=None, action_space=None, datasets=()) (22934:data_loader.py:393)
29
+ 20:03:13.494 [I] JAX version 0.5.3 available. (22934:config.py:125)
30
+ 20:04:17.801 [I] Using existing local LeRobot dataset mirror for lsnu/twin_dual_push_128_train: /workspace/lerobot/lsnu/twin_dual_push_128_train (22934:data_loader.py:148)
31
+ 20:04:17.904 [W] 'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder (22934:video_utils.py:36)
32
+ 20:09:04.645 [I] local_batch_size: 1 (22934:data_loader.py:474)
33
+ 20:11:56.606 [I] Enabled gradient checkpointing for PI0Pytorch model (22934:pi0_pytorch.py:138)
34
+ 20:11:56.607 [I] Enabled gradient checkpointing for memory optimization (22934:train_pytorch.py:624)
35
+ 20:11:56.608 [I] Step 0 (after_model_creation): GPU memory - allocated: 17.23GB, reserved: 17.23GB, free: 0.00GB, peak_allocated: 17.23GB, peak_reserved: 17.23GB (22934:train_pytorch.py:493)
36
+ 20:11:56.609 [I] Loading weights from: /workspace/checkpoints/pi05_base_split_independent_packed_from_single (22934:train_pytorch.py:653)
37
+ 20:12:01.374 [I] Weight loading missing key count: 0 (22934:train_pytorch.py:657)
38
+ 20:12:01.375 [I] Weight loading missing keys: set() (22934:train_pytorch.py:658)
39
+ 20:12:01.375 [I] Weight loading unexpected key count: 0 (22934:train_pytorch.py:659)
40
+ 20:12:01.375 [I] Weight loading unexpected keys: [] (22934:train_pytorch.py:660)
41
+ 20:12:01.376 [I] Loaded PyTorch weights from /workspace/checkpoints/pi05_base_split_independent_packed_from_single (22934:train_pytorch.py:661)
42
+ 20:12:01.380 [I] Running on: 963c158043aa | world_size=1 (22934:train_pytorch.py:701)
43
+ 20:12:01.381 [I] Training config: batch_size=1, effective_batch_size=1, num_train_steps=20 (22934:train_pytorch.py:702)
44
+ 20:12:01.381 [I] Memory optimizations: gradient_checkpointing=True (22934:train_pytorch.py:705)
45
+ 20:12:01.381 [I] DDP settings: find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True (22934:train_pytorch.py:706)
46
+ 20:12:01.382 [I] LR schedule: warmup=250, peak_lr=2.50e-05, decay_steps=5000, end_lr=2.50e-06 (22934:train_pytorch.py:707)
47
+ 20:12:01.382 [I] Optimizer: AdamW, weight_decay=1e-10, clip_norm=1.0 (22934:train_pytorch.py:710)
48
+ 20:12:01.382 [I] EMA is not supported for PyTorch training (22934:train_pytorch.py:713)
49
+ 20:12:01.383 [I] Training precision: float32 (22934:train_pytorch.py:714)
50
+ 20:12:01.410 [I] Resolved config name: pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k (22934:train_pytorch.py:308)
51
+ 20:12:01.410 [I] Dataset repo_id: lsnu/twin_dual_push_128_train (22934:train_pytorch.py:309)
52
+ 20:12:01.411 [I] Norm-stats file path: /workspace/pi05tests/openpi/assets/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/lsnu/twin_dual_push_128_train/norm_stats.json (22934:train_pytorch.py:310)
53
+ 20:12:01.411 [I] Norm-stats summary: {'keys': ['actions', 'state'], 'state_mean_len': 16, 'state_std_len': 16, 'actions_mean_len': 16, 'actions_std_len': 16} (22934:train_pytorch.py:311)
54
+ 20:12:01.412 [I] Checkpoint source path: /workspace/checkpoints/pi05_base_split_independent_packed_from_single (22934:train_pytorch.py:312)
55
+ 20:12:01.412 [I] Model type: split_independent (22934:train_pytorch.py:313)
56
+ 20:12:01.412 [I] Packed transforms active: True (22934:train_pytorch.py:314)
57
+ 20:12:01.413 [I] World size: 1 (22934:train_pytorch.py:315)
58
+ 20:12:01.413 [I] Batch size: local=1, global=1 (22934:train_pytorch.py:316)
59
+ 20:12:01.414 [I] num_workers: 0 (22934:train_pytorch.py:317)
60
+ 20:12:01.414 [I] Precision: float32 (22934:train_pytorch.py:318)
61
+ 20:12:01.414 [I] LR schedule summary: warmup_steps=250, peak_lr=2.50e-05, decay_steps=5000, decay_lr=2.50e-06 (22934:train_pytorch.py:319)
62
+ 20:12:01.415 [I] Save/log intervals: save_interval=20, log_interval=1 (22934:train_pytorch.py:326)
63
+ 20:12:01.415 [I] Action-loss mask: (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (22934:train_pytorch.py:327)
64
+ 20:12:01.415 [I] Active mask dims: [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] (22934:train_pytorch.py:328)
65
+ 20:12:01.416 [I] Masked dims: [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] (22934:train_pytorch.py:329)
66
+ 20:12:01.416 [I] Gradient bucket diagnostics: left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm (22934:train_pytorch.py:722)
67
+
68
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
69
+ 20:12:03.701 [I] debug_step=1 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22934:train_pytorch.py:831)
70
+ 20:12:03.702 [I] debug_step=1 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22934:train_pytorch.py:835)
71
+ 20:12:03.702 [I] debug_step=1 prompt_token_lengths=[75] (22934:train_pytorch.py:838)
72
+ 20:12:03.702 [I] debug_step=1 state_stats min=-1.0000 max=1.0004 mean=0.0112 std=0.3876 (22934:train_pytorch.py:839)
73
+ 20:12:03.702 [I] debug_step=1 action_stats min=-1.0016 max=1.0004 mean=-0.0454 std=0.4716 (22934:train_pytorch.py:842)
74
+ 20:12:03.703 [I] debug_step=1 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22934:train_pytorch.py:845)
75
+ 20:12:03.729 [I] debug_step=1 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22934:train_pytorch.py:849)
76
+ 20:12:03.730 [I] debug_step=1 lr=9.96e-08 grad_norm=31.4779 data_time=0.5472s step_time=1.7166s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.25GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.25GB (22934:train_pytorch.py:854)
77
+ 20:12:03.730 [I] debug_step=1 grad_shared_backbone=25.5606 grad_left_action_in=0.2318 grad_right_action_in=0.9885 grad_left_expert=5.5978 grad_right_expert=12.3518 grad_action_out=9.6154 (22934:train_pytorch.py:862)
78
+ 20:12:03.731 [I] step=1 loss=2.6238 smoothed_loss=2.6238 lr=9.96e-08 grad_norm=31.4779 step_time=1.7166s data_time=0.5472s it/s=0.425 eta_to_20=44.7s max_cuda_memory=76.13GB grad_action_out=9.6154 grad_left_action_in=0.2318 grad_left_expert=5.5978 grad_right_action_in=0.9885 grad_right_expert=12.3518 grad_shared_backbone=25.5606 (22934:train_pytorch.py:882)
79
+
80
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
81
+ 20:12:05.012 [I] debug_step=2 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22934:train_pytorch.py:831)
82
+ 20:12:05.013 [I] debug_step=2 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22934:train_pytorch.py:835)
83
+ 20:12:05.014 [I] debug_step=2 prompt_token_lengths=[76] (22934:train_pytorch.py:838)
84
+ 20:12:05.014 [I] debug_step=2 state_stats min=-0.9415 max=1.0004 mean=-0.0010 std=0.4295 (22934:train_pytorch.py:839)
85
+ 20:12:05.015 [I] debug_step=2 action_stats min=-1.0000 max=1.1367 mean=0.0272 std=0.4576 (22934:train_pytorch.py:842)
86
+ 20:12:05.016 [I] debug_step=2 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22934:train_pytorch.py:845)
87
+ 20:12:05.016 [I] debug_step=2 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22934:train_pytorch.py:849)
88
+ 20:12:05.017 [I] debug_step=2 lr=1.99e-07 grad_norm=12.2749 data_time=0.5381s step_time=0.7692s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (22934:train_pytorch.py:854)
89
+ 20:12:05.017 [I] debug_step=2 grad_shared_backbone=10.3515 grad_left_action_in=0.1585 grad_right_action_in=0.1584 grad_left_expert=2.8412 grad_right_expert=4.0131 grad_action_out=4.1470 (22934:train_pytorch.py:862)
90
+ 20:12:05.018 [I] step=2 loss=1.1715 smoothed_loss=2.4786 lr=1.99e-07 grad_norm=12.2749 step_time=0.7692s data_time=0.5381s it/s=0.777 eta_to_20=23.2s max_cuda_memory=76.13GB grad_action_out=4.1470 grad_left_action_in=0.1585 grad_left_expert=2.8412 grad_right_action_in=0.1584 grad_right_expert=4.0131 grad_shared_backbone=10.3515 (22934:train_pytorch.py:882)
91
+
92
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
93
+ 20:12:05.585 [I] debug_step=3 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22934:train_pytorch.py:831)
94
+ 20:12:05.586 [I] debug_step=3 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22934:train_pytorch.py:835)
95
+ 20:12:05.586 [I] debug_step=3 prompt_token_lengths=[75] (22934:train_pytorch.py:838)
96
+ 20:12:05.586 [I] debug_step=3 state_stats min=-1.0000 max=1.0004 mean=0.0558 std=0.4300 (22934:train_pytorch.py:839)
97
+ 20:12:05.587 [I] debug_step=3 action_stats min=-1.0033 max=1.0004 mean=-0.0658 std=0.4704 (22934:train_pytorch.py:842)
98
+ 20:12:05.588 [I] debug_step=3 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22934:train_pytorch.py:845)
99
+ 20:12:05.588 [I] debug_step=3 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22934:train_pytorch.py:849)
100
+ 20:12:05.589 [I] debug_step=3 lr=2.99e-07 grad_norm=15.1205 data_time=0.1545s step_time=0.4182s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (22934:train_pytorch.py:854)
101
+ 20:12:05.589 [I] debug_step=3 grad_shared_backbone=8.6946 grad_left_action_in=0.2568 grad_right_action_in=0.3873 grad_left_expert=4.4408 grad_right_expert=10.5877 grad_action_out=3.5507 (22934:train_pytorch.py:862)
102
+ 20:12:05.590 [I] step=3 loss=0.9126 smoothed_loss=2.3220 lr=2.99e-07 grad_norm=15.1205 step_time=0.4182s data_time=0.1545s it/s=1.751 eta_to_20=9.7s max_cuda_memory=76.13GB grad_action_out=3.5507 grad_left_action_in=0.2568 grad_left_expert=4.4408 grad_right_action_in=0.3873 grad_right_expert=10.5877 grad_shared_backbone=8.6946 (22934:train_pytorch.py:882)
103
+
104
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
105
+ 20:12:06.414 [I] debug_step=4 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22934:train_pytorch.py:831)
106
+ 20:12:06.415 [I] debug_step=4 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22934:train_pytorch.py:835)
107
+ 20:12:06.416 [I] debug_step=4 prompt_token_lengths=[78] (22934:train_pytorch.py:838)
108
+ 20:12:06.416 [I] debug_step=4 state_stats min=-0.7017 max=1.0004 mean=0.0553 std=0.3507 (22934:train_pytorch.py:839)
109
+ 20:12:06.417 [I] debug_step=4 action_stats min=-1.0014 max=1.0004 mean=-0.0683 std=0.4561 (22934:train_pytorch.py:842)
110
+ 20:12:06.417 [I] debug_step=4 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22934:train_pytorch.py:845)
111
+ 20:12:06.418 [I] debug_step=4 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22934:train_pytorch.py:849)
112
+ 20:12:06.419 [I] debug_step=4 lr=3.98e-07 grad_norm=9.2670 data_time=0.2679s step_time=0.5621s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (22934:train_pytorch.py:854)
113
+ 20:12:06.419 [I] debug_step=4 grad_shared_backbone=7.8629 grad_left_action_in=0.1341 grad_right_action_in=0.0877 grad_left_expert=3.2369 grad_right_expert=1.0658 grad_action_out=3.4116 (22934:train_pytorch.py:862)
114
+ 20:12:06.420 [I] step=4 loss=1.1718 smoothed_loss=2.2070 lr=3.98e-07 grad_norm=9.2670 step_time=0.5621s data_time=0.2679s it/s=1.206 eta_to_20=13.3s max_cuda_memory=76.13GB grad_action_out=3.4116 grad_left_action_in=0.1341 grad_left_expert=3.2369 grad_right_action_in=0.0877 grad_right_expert=1.0658 grad_shared_backbone=7.8629 (22934:train_pytorch.py:882)
115
+
116
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
117
+ 20:12:07.218 [I] debug_step=5 observation.state shape=(1, 32) dtype=torch.float64 actions shape=(1, 16, 32) dtype=torch.float32 (22934:train_pytorch.py:831)
118
+ 20:12:07.219 [I] debug_step=5 image_keys=['base_0_rgb', 'left_wrist_0_rgb', 'right_wrist_0_rgb'] image_shapes={'base_0_rgb': (1, 3, 224, 224), 'left_wrist_0_rgb': (1, 3, 224, 224), 'right_wrist_0_rgb': (1, 3, 224, 224)} (22934:train_pytorch.py:835)
119
+ 20:12:07.219 [I] debug_step=5 prompt_token_lengths=[73] (22934:train_pytorch.py:838)
120
+ 20:12:07.219 [I] debug_step=5 state_stats min=-0.9599 max=1.0004 mean=0.0170 std=0.5364 (22934:train_pytorch.py:839)
121
+ 20:12:07.220 [I] debug_step=5 action_stats min=-1.0392 max=1.0004 mean=-0.0159 std=0.4488 (22934:train_pytorch.py:842)
122
+ 20:12:07.220 [I] debug_step=5 state_nonzero_counts_8d_blocks=[8, 0, 8, 0] action_nonzero_counts_8d_blocks=[128, 0, 128, 0] (22934:train_pytorch.py:845)
123
+ 20:12:07.221 [I] debug_step=5 masked_dims=[8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] active_dims=[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] masked_zero_counts state=16 actions=256 (22934:train_pytorch.py:849)
124
+ 20:12:07.221 [I] debug_step=5 lr=4.98e-07 grad_norm=18.8576 data_time=0.2330s step_time=0.5704s gpu_mem_allocated=46.71GB gpu_mem_reserved=76.34GB gpu_mem_max_allocated=76.13GB gpu_mem_max_reserved=76.34GB (22934:train_pytorch.py:854)
125
+ 20:12:07.222 [I] debug_step=5 grad_shared_backbone=15.0420 grad_left_action_in=0.2664 grad_right_action_in=0.2257 grad_left_expert=7.9881 grad_right_expert=3.7966 grad_action_out=6.1884 (22934:train_pytorch.py:862)
126
+ 20:12:07.223 [I] step=5 loss=1.6473 smoothed_loss=2.1510 lr=4.98e-07 grad_norm=18.8576 step_time=0.5704s data_time=0.2330s it/s=1.246 eta_to_20=12.0s max_cuda_memory=76.13GB grad_action_out=6.1884 grad_left_action_in=0.2664 grad_left_expert=7.9881 grad_right_action_in=0.2257 grad_right_expert=3.7966 grad_shared_backbone=15.0420 (22934:train_pytorch.py:882)
127
+
128
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
129
+ 20:12:07.822 [I] step=6 loss=1.6098 smoothed_loss=2.0969 lr=5.98e-07 grad_norm=20.9772 step_time=0.4435s data_time=0.1600s it/s=1.671 eta_to_20=8.4s max_cuda_memory=76.13GB grad_action_out=6.0592 grad_left_action_in=0.2873 grad_left_expert=8.8574 grad_right_action_in=0.4264 grad_right_expert=6.3071 grad_shared_backbone=16.1173 (22934:train_pytorch.py:882)
130
+
131
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
132
+ 20:12:08.395 [I] step=7 loss=1.0401 smoothed_loss=1.9912 lr=6.97e-07 grad_norm=9.5173 step_time=0.4240s data_time=0.1495s it/s=1.747 eta_to_20=7.4s max_cuda_memory=76.13GB grad_action_out=4.1689 grad_left_action_in=0.1489 grad_left_expert=3.1996 grad_right_action_in=0.0904 grad_right_expert=2.4983 grad_shared_backbone=7.4213 (22934:train_pytorch.py:882)
133
+
134
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
135
+ 20:12:08.914 [I] step=8 loss=1.7539 smoothed_loss=1.9675 lr=7.97e-07 grad_norm=12.9701 step_time=0.3829s data_time=0.1362s it/s=1.931 eta_to_20=6.2s max_cuda_memory=76.13GB grad_action_out=5.3617 grad_left_action_in=0.1890 grad_left_expert=3.6536 grad_right_action_in=0.3790 grad_right_expert=2.7904 grad_shared_backbone=10.5667 (22934:train_pytorch.py:882)
136
+
137
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
138
+ 20:12:09.692 [I] step=9 loss=0.4114 smoothed_loss=1.8119 lr=8.96e-07 grad_norm=3.5873 step_time=0.5166s data_time=0.2609s it/s=1.288 eta_to_20=8.5s max_cuda_memory=76.13GB grad_action_out=1.8283 grad_left_action_in=0.0689 grad_left_expert=1.3656 grad_right_action_in=0.0549 grad_right_expert=0.7330 grad_shared_backbone=2.6507 (22934:train_pytorch.py:882)
139
+
140
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
141
+ 20:12:10.646 [I] step=10 loss=0.6228 smoothed_loss=1.6930 lr=9.96e-07 grad_norm=6.7396 step_time=0.7100s data_time=0.2450s it/s=1.049 eta_to_20=9.5s max_cuda_memory=76.13GB grad_action_out=2.2553 grad_left_action_in=0.0813 grad_left_expert=1.3495 grad_right_action_in=0.0919 grad_right_expert=2.0906 grad_shared_backbone=5.8179 (22934:train_pytorch.py:882)
142
+
143
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
144
+ 20:12:11.288 [I] step=11 loss=0.8688 smoothed_loss=1.6105 lr=1.10e-06 grad_norm=7.2182 step_time=0.4823s data_time=0.1593s it/s=1.561 eta_to_20=5.8s max_cuda_memory=76.13GB grad_action_out=3.3031 grad_left_action_in=0.1262 grad_left_expert=2.5456 grad_right_action_in=0.0809 grad_right_expert=0.9216 grad_shared_backbone=5.7177 (22934:train_pytorch.py:882)
145
+
146
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
147
+ 20:12:11.903 [I] step=12 loss=0.7319 smoothed_loss=1.5227 lr=1.20e-06 grad_norm=6.1848 step_time=0.4468s data_time=0.1681s it/s=1.629 eta_to_20=4.9s max_cuda_memory=76.13GB grad_action_out=2.7925 grad_left_action_in=0.1038 grad_left_expert=2.4508 grad_right_action_in=0.0680 grad_right_expert=0.8716 grad_shared_backbone=4.8333 (22934:train_pytorch.py:882)
148
+
149
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
150
+ 20:12:12.684 [I] step=13 loss=0.8788 smoothed_loss=1.4583 lr=1.29e-06 grad_norm=20.2227 step_time=0.5649s data_time=0.2162s it/s=1.282 eta_to_20=5.5s max_cuda_memory=76.13GB grad_action_out=3.0176 grad_left_action_in=0.1300 grad_left_expert=2.8276 grad_right_action_in=0.4691 grad_right_expert=12.9156 grad_shared_backbone=11.2157 (22934:train_pytorch.py:882)
151
+
152
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
153
+ 20:12:13.370 [I] step=14 loss=1.2741 smoothed_loss=1.4399 lr=1.39e-06 grad_norm=7.8620 step_time=0.5100s data_time=0.1755s it/s=1.461 eta_to_20=4.1s max_cuda_memory=76.13GB grad_action_out=4.2194 grad_left_action_in=0.1433 grad_left_expert=2.8949 grad_right_action_in=0.0958 grad_right_expert=1.0096 grad_shared_backbone=5.8070 (22934:train_pytorch.py:882)
154
+
155
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
156
+ 20:12:14.027 [I] step=15 loss=2.3729 smoothed_loss=1.5332 lr=1.49e-06 grad_norm=19.3589 step_time=0.4678s data_time=0.1899s it/s=1.523 eta_to_20=3.3s max_cuda_memory=76.13GB grad_action_out=7.2135 grad_left_action_in=0.2665 grad_left_expert=7.5354 grad_right_action_in=0.5496 grad_right_expert=4.5295 grad_shared_backbone=15.2257 (22934:train_pytorch.py:882)
157
+
158
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
159
+ 20:12:14.874 [I] step=16 loss=0.8147 smoothed_loss=1.4613 lr=1.59e-06 grad_norm=7.7365 step_time=0.5547s data_time=0.2919s it/s=1.183 eta_to_20=3.4s max_cuda_memory=76.13GB grad_action_out=2.7237 grad_left_action_in=0.1192 grad_left_expert=2.8822 grad_right_action_in=0.0900 grad_right_expert=0.8615 grad_shared_backbone=6.4500 (22934:train_pytorch.py:882)
160
+
161
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
162
+ 20:12:15.664 [I] step=17 loss=1.4318 smoothed_loss=1.4584 lr=1.69e-06 grad_norm=19.5452 step_time=0.5511s data_time=0.2382s it/s=1.268 eta_to_20=2.4s max_cuda_memory=76.13GB grad_action_out=3.9684 grad_left_action_in=0.3767 grad_left_expert=7.8636 grad_right_action_in=0.1317 grad_right_expert=1.6847 grad_shared_backbone=16.9059 (22934:train_pytorch.py:882)
163
+
164
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
165
+ 20:12:16.588 [I] step=18 loss=0.4858 smoothed_loss=1.3611 lr=1.79e-06 grad_norm=3.4382 step_time=0.6846s data_time=0.2403s it/s=1.082 eta_to_20=1.8s max_cuda_memory=76.13GB grad_action_out=1.9985 grad_left_action_in=0.0749 grad_left_expert=1.4156 grad_right_action_in=0.0390 grad_right_expert=0.5210 grad_shared_backbone=2.3369 (22934:train_pytorch.py:882)
166
+
167
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
168
+ 20:12:17.216 [I] step=19 loss=0.7492 smoothed_loss=1.2999 lr=1.89e-06 grad_norm=6.9377 step_time=0.4815s data_time=0.1459s it/s=1.596 eta_to_20=0.6s max_cuda_memory=76.13GB grad_action_out=3.7478 grad_left_action_in=0.1113 grad_left_expert=2.8716 grad_right_action_in=0.0729 grad_right_expert=1.0784 grad_shared_backbone=4.9024 (22934:train_pytorch.py:882)
169
+
170
+ with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
171
+ 20:12:18.186 [I] step=20 loss=0.6038 smoothed_loss=1.2303 lr=1.99e-06 grad_norm=7.0090 step_time=0.7175s data_time=0.2525s it/s=1.032 eta_to_20=0.0s max_cuda_memory=76.13GB grad_action_out=2.8786 grad_left_action_in=0.0890 grad_left_expert=2.7778 grad_right_action_in=0.0549 grad_right_expert=1.4578 grad_shared_backbone=5.5395 (22934:train_pytorch.py:882)
172
+ 20:19:39.399 [I] Saved checkpoint at step 20 -> /workspace/pi05tests/openpi/checkpoints/pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k/split_independent_real_train20/20 (22934:train_pytorch.py:378)
173
+
openpi/scripts/check_parallel_warmstart_equivalence.py CHANGED
@@ -76,6 +76,13 @@ def main() -> None:
76
  )
77
  baseline_config = _config.get_config(args.baseline_config_name)
78
  parallel_config = _config.get_config(args.parallel_config_name)
 
 
 
 
 
 
 
79
 
80
  data_config = baseline_config.data.create(baseline_config.assets_dirs, baseline_config.model)
81
  data_config = dataclasses.replace(data_config, repo_id=args.repo_id)
 
76
  )
77
  baseline_config = _config.get_config(args.baseline_config_name)
78
  parallel_config = _config.get_config(args.parallel_config_name)
79
+ parallel_model_cfg = build_model_config(parallel_config)
80
+ if parallel_model_cfg.use_split_action_expert:
81
+ raise ValueError(
82
+ "Exact end-to-end warm-start equivalence is not expected for split action experts. "
83
+ "Use init_parallel_pi05_from_single_pytorch.py for branch copy checks and "
84
+ "check_split_expert_invariants.py for branch-local invariants."
85
+ )
86
 
87
  data_config = baseline_config.data.create(baseline_config.assets_dirs, baseline_config.model)
88
  data_config = dataclasses.replace(data_config, repo_id=args.repo_id)
openpi/scripts/check_split_expert_invariants.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import dataclasses
4
+
5
+ import safetensors.torch
6
+ import torch
7
+ import tyro
8
+
9
+ import openpi.models.pi0_config
10
+ import openpi.training.config as _config
11
+
12
+
13
+ @dataclasses.dataclass
14
+ class Args:
15
+ config_name: str
16
+ checkpoint_dir: str
17
+ tolerance: float = 1e-6
18
+ batch_size: int = 2
19
+ prefix_len: int = 12
20
+ seed: int = 123
21
+
22
+
23
+ def _build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
24
+ if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
25
+ return openpi.models.pi0_config.Pi0Config(
26
+ dtype="float32",
27
+ action_dim=config.model.action_dim,
28
+ action_horizon=config.model.action_horizon,
29
+ max_token_len=config.model.max_token_len,
30
+ paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
31
+ action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
32
+ pi05=getattr(config.model, "pi05", False),
33
+ arm_action_dims=getattr(config.model, "arm_action_dims", None),
34
+ action_expert_mode=getattr(config.model, "action_expert_mode", None),
35
+ )
36
+
37
+ model_cfg = dataclasses.replace(config.model)
38
+ object.__setattr__(model_cfg, "dtype", "float32")
39
+ return model_cfg
40
+
41
+
42
+ def _random_prefix_context(model, batch_size: int, prefix_len: int, seed: int):
43
+ generator = torch.Generator(device="cpu")
44
+ generator.manual_seed(seed)
45
+ prefix_width = model.paligemma_with_expert.paligemma.config.text_config.hidden_size
46
+ prefix_embs = torch.randn(batch_size, prefix_len, prefix_width, generator=generator, dtype=torch.float32)
47
+ prefix_pad_masks = torch.ones(batch_size, prefix_len, dtype=torch.bool)
48
+ prefix_att_masks = torch.zeros(batch_size, prefix_len, dtype=torch.bool)
49
+ return prefix_embs, prefix_pad_masks, prefix_att_masks
50
+
51
+
52
+ def _run_model(model, prefix_context, x_t, timestep):
53
+ prefix_embs, prefix_pad_masks, prefix_att_masks = prefix_context
54
+ state = torch.zeros(x_t.shape[0], model.config.action_dim, dtype=torch.float32)
55
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = model.embed_suffix(state, x_t, timestep)
56
+ suffix_outputs = model._run_joint_action_expert( # noqa: SLF001
57
+ prefix_embs,
58
+ prefix_pad_masks,
59
+ prefix_att_masks,
60
+ suffix_embs,
61
+ suffix_pad_masks,
62
+ suffix_att_masks,
63
+ adarms_cond,
64
+ )
65
+ suffix_outputs = [output[:, -model.config.action_horizon :].to(dtype=torch.float32) for output in suffix_outputs]
66
+ projected_actions = model._project_action_outputs(suffix_outputs) # noqa: SLF001
67
+ return suffix_outputs, projected_actions
68
+
69
+
70
+ def _run_identical_branch_inputs(model, prefix_context, timestep, seed: int):
71
+ generator = torch.Generator(device="cpu")
72
+ generator.manual_seed(seed)
73
+ width = model.action_expert_width
74
+ horizon = model.config.action_horizon
75
+ batch_size = prefix_context[0].shape[0]
76
+
77
+ shared_suffix = torch.randn(batch_size, horizon, width, generator=generator, dtype=torch.float32)
78
+ shared_cond = torch.randn(batch_size, width, generator=generator, dtype=torch.float32)
79
+ suffix_pad_masks = [torch.ones(batch_size, horizon, dtype=torch.bool) for _ in range(2)]
80
+ suffix_att_masks = [model._action_att_mask(batch_size, torch.device("cpu"), torch.float32) for _ in range(2)] # noqa: SLF001
81
+
82
+ suffix_outputs = model._run_joint_action_expert( # noqa: SLF001
83
+ prefix_context[0],
84
+ prefix_context[1],
85
+ prefix_context[2],
86
+ [shared_suffix.clone(), shared_suffix.clone()],
87
+ suffix_pad_masks,
88
+ suffix_att_masks,
89
+ [shared_cond.clone(), shared_cond.clone()],
90
+ )
91
+ return suffix_outputs
92
+
93
+
94
+ def main() -> None:
95
+ args = tyro.cli(Args)
96
+ config = _config.get_config(args.config_name)
97
+ model_cfg = _build_model_config(config)
98
+ if not model_cfg.use_split_action_expert:
99
+ raise ValueError(f"Config {args.config_name} is not a split-expert config.")
100
+
101
+ import openpi.models_pytorch.pi0_pytorch as pi0_pytorch
102
+
103
+ torch.manual_seed(args.seed)
104
+ model = pi0_pytorch.PI0Pytorch(model_cfg)
105
+ missing, unexpected = safetensors.torch.load_model(model, f"{args.checkpoint_dir}/model.safetensors", strict=False)
106
+ model.eval()
107
+
108
+ prefix_context = _random_prefix_context(model, args.batch_size, args.prefix_len, args.seed + 1)
109
+ x_t = torch.randn(args.batch_size, model.config.action_horizon, model.config.action_dim, dtype=torch.float32)
110
+ timestep = torch.full((args.batch_size,), 0.5, dtype=torch.float32)
111
+
112
+ identical_suffix_outputs = _run_identical_branch_inputs(model, prefix_context, timestep, args.seed + 2)
113
+ identical_branch_suffix_max_abs_diff = float(
114
+ (identical_suffix_outputs[0] - identical_suffix_outputs[1]).abs().max().item()
115
+ )
116
+
117
+ left_suffix_outputs, left_projected_actions = _run_model(model, prefix_context, x_t, timestep)
118
+ x_t_right_perturbed = x_t.clone()
119
+ x_t_right_perturbed[:, :, 16:32] += 0.5 * torch.randn_like(x_t_right_perturbed[:, :, 16:32])
120
+ _, right_perturbed_actions = _run_model(model, prefix_context, x_t_right_perturbed, timestep)
121
+ left_branch_invariance_max_abs_diff = float(
122
+ (left_projected_actions[:, :, 0:16] - right_perturbed_actions[:, :, 0:16]).abs().max().item()
123
+ )
124
+
125
+ x_t_left_perturbed = x_t.clone()
126
+ x_t_left_perturbed[:, :, 0:16] += 0.5 * torch.randn_like(x_t_left_perturbed[:, :, 0:16])
127
+ _, left_perturbed_actions = _run_model(model, prefix_context, x_t_left_perturbed, timestep)
128
+ right_branch_invariance_max_abs_diff = float(
129
+ (left_projected_actions[:, :, 16:32] - left_perturbed_actions[:, :, 16:32]).abs().max().item()
130
+ )
131
+
132
+ print(f"config_name: {args.config_name}")
133
+ print(f"checkpoint_dir: {args.checkpoint_dir}")
134
+ print(f"action_expert_mode: {model_cfg.action_expert_mode}")
135
+ print(f"weight_loading_missing_keys: {list(missing)}")
136
+ print(f"weight_loading_unexpected_keys: {list(unexpected)}")
137
+ print(f"identical_branch_suffix_max_abs_diff: {identical_branch_suffix_max_abs_diff:.8f}")
138
+ print(
139
+ f"identical_branch_suffix_match: "
140
+ f"{identical_branch_suffix_max_abs_diff <= args.tolerance}"
141
+ )
142
+
143
+ if model_cfg.action_expert_mode == "split_independent":
144
+ print(f"left_branch_invariance_max_abs_diff: {left_branch_invariance_max_abs_diff:.8f}")
145
+ print(f"right_branch_invariance_max_abs_diff: {right_branch_invariance_max_abs_diff:.8f}")
146
+ print(f"left_branch_invariant: {left_branch_invariance_max_abs_diff <= args.tolerance}")
147
+ print(f"right_branch_invariant: {right_branch_invariance_max_abs_diff <= args.tolerance}")
148
+ else:
149
+ print("left_branch_invariance_max_abs_diff: skipped_for_split_communicating")
150
+ print("right_branch_invariance_max_abs_diff: skipped_for_split_communicating")
151
+
152
+
153
+ if __name__ == "__main__":
154
+ main()
openpi/scripts/eval_twin_val_loss_pytorch.py CHANGED
@@ -44,7 +44,7 @@ class Args:
44
  eval_seed: int = 123
45
  sample_num_batches: int = 0
46
  sample_batch_size: int | None = None
47
- sample_num_steps: str = "4,10"
48
  sample_seed: int = 321
49
 
50
 
 
44
  eval_seed: int = 123
45
  sample_num_batches: int = 0
46
  sample_batch_size: int | None = None
47
+ sample_num_steps: str = "1,2,4,8,16"
48
  sample_seed: int = 321
49
 
50
 
openpi/scripts/init_parallel_pi05_from_single_pytorch.py CHANGED
@@ -33,6 +33,7 @@ def _build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config
33
  action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
34
  pi05=getattr(config.model, "pi05", False),
35
  arm_action_dims=getattr(config.model, "arm_action_dims", None),
 
36
  )
37
 
38
  model_cfg = config.model
@@ -40,12 +41,60 @@ def _build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config
40
  return model_cfg
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def main() -> None:
44
  args = tyro.cli(Args)
45
  config = _config.get_config(args.config_name)
46
  model_cfg = _build_model_config(config)
47
  if not model_cfg.use_parallel_action_heads:
48
- raise ValueError(f"Config {args.config_name} is not a parallel-head config.")
49
  if tuple(model_cfg.arm_action_dims) != (16, 16):
50
  raise ValueError(f"Expected arm_action_dims=(16, 16), got {model_cfg.arm_action_dims}.")
51
 
@@ -65,74 +114,139 @@ def main() -> None:
65
  f"Expected single-head checkpoint with packed 32-dim actions, got in={tuple(weight_in.shape)} out={tuple(weight_out.shape)}."
66
  )
67
 
68
- with torch.no_grad():
69
- parallel_model.action_in_proj_arms[0].weight.copy_(weight_in[:, 0:16])
70
- parallel_model.action_in_proj_arms[0].bias.zero_()
71
- parallel_model.action_in_proj_arms[1].weight.copy_(weight_in[:, 16:32])
72
- parallel_model.action_in_proj_arms[1].bias.zero_()
73
-
74
- fuse_weight = torch.zeros_like(parallel_model.arm_token_fuse.weight)
75
- identity = torch.eye(hidden_width, dtype=fuse_weight.dtype)
76
- fuse_weight[:, 0:hidden_width] = identity
77
- fuse_weight[:, hidden_width : 2 * hidden_width] = identity
78
- parallel_model.arm_token_fuse.weight.copy_(fuse_weight)
79
- parallel_model.arm_token_fuse.bias.copy_(bias_in)
80
-
81
- parallel_model.action_out_proj_arms[0].weight.copy_(weight_out[0:16, :])
82
- parallel_model.action_out_proj_arms[0].bias.copy_(bias_out[0:16])
83
- parallel_model.action_out_proj_arms[1].weight.copy_(weight_out[16:32, :])
84
- parallel_model.action_out_proj_arms[1].bias.copy_(bias_out[16:32])
85
 
86
  proj_in_dtype = parallel_model.action_in_proj_arms[0].weight.dtype
87
  proj_out_dtype = parallel_model.action_out_proj_arms[0].weight.dtype
88
  x = torch.randn(2, model_cfg.action_horizon, model_cfg.action_dim, dtype=proj_in_dtype)
 
 
89
  suffix = torch.randn(2, model_cfg.action_horizon, hidden_width, dtype=proj_out_dtype)
 
 
 
 
 
 
 
 
 
 
90
  with torch.no_grad():
91
- input_max_abs_diff = float(
 
 
 
 
 
 
 
 
 
92
  (
93
- F.linear(x, weight_in.to(proj_in_dtype), bias_in.to(proj_in_dtype))
94
- - parallel_model._project_action_inputs(x)
95
  )
96
  .abs()
97
  .max()
98
  .item()
99
  )
100
- output_max_abs_diff = float(
101
  (
102
- F.linear(suffix, weight_out.to(proj_out_dtype), bias_out.to(proj_out_dtype))
103
- - parallel_model._project_action_outputs(suffix)
 
 
 
 
 
 
 
 
 
104
  )
105
  .abs()
106
  .max()
107
  .item()
108
  )
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  output_dir = Path(args.output_path)
111
  output_dir.mkdir(parents=True, exist_ok=True)
112
  safetensors.torch.save_model(parallel_model, output_dir / "model.safetensors")
113
  (output_dir / "config.json").write_text(json.dumps(dataclasses.asdict(model_cfg), indent=2, sort_keys=True))
114
- metadata = {
115
- "config_name": args.config_name,
116
- "single_ckpt": args.single_ckpt,
117
- "output_path": args.output_path,
118
- "load_state_missing_keys": list(missing),
119
- "load_state_unexpected_keys": list(unexpected),
120
- "input_projection_max_abs_diff": input_max_abs_diff,
121
- "output_projection_max_abs_diff": output_max_abs_diff,
122
- "warm_start_exact": input_max_abs_diff == 0.0 and output_max_abs_diff == 0.0,
123
- }
124
  (output_dir / "init_parallel_metadata.json").write_text(json.dumps(metadata, indent=2, sort_keys=True))
125
 
126
  print(f"config_name: {args.config_name}")
 
127
  print(f"single_ckpt: {args.single_ckpt}")
128
  print(f"output_path: {args.output_path}")
129
  print(f"load_state_missing_keys_count: {len(missing)}")
130
  print(f"load_state_missing_keys: {list(missing)}")
131
  print(f"load_state_unexpected_keys_count: {len(unexpected)}")
132
  print(f"load_state_unexpected_keys: {list(unexpected)}")
133
- print(f"input_projection_max_abs_diff: {input_max_abs_diff:.8f}")
134
- print(f"output_projection_max_abs_diff: {output_max_abs_diff:.8f}")
135
- print(f"warm_start_exact: {metadata['warm_start_exact']}")
 
136
 
137
 
138
  if __name__ == "__main__":
 
33
  action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
34
  pi05=getattr(config.model, "pi05", False),
35
  arm_action_dims=getattr(config.model, "arm_action_dims", None),
36
+ action_expert_mode=getattr(config.model, "action_expert_mode", None),
37
  )
38
 
39
  model_cfg = config.model
 
41
  return model_cfg
42
 
43
 
44
+ def _copy_factorized_heads(model, weight_in, bias_in, weight_out, bias_out) -> None:
45
+ hidden_width = weight_in.shape[0]
46
+ with torch.no_grad():
47
+ model.action_in_proj_arms[0].weight.copy_(weight_in[:, 0:16])
48
+ model.action_in_proj_arms[0].bias.zero_()
49
+ model.action_in_proj_arms[1].weight.copy_(weight_in[:, 16:32])
50
+ model.action_in_proj_arms[1].bias.zero_()
51
+
52
+ if hasattr(model, "arm_token_fuse"):
53
+ fuse_weight = torch.zeros_like(model.arm_token_fuse.weight)
54
+ identity = torch.eye(hidden_width, dtype=fuse_weight.dtype)
55
+ fuse_weight[:, 0:hidden_width] = identity
56
+ fuse_weight[:, hidden_width : 2 * hidden_width] = identity
57
+ model.arm_token_fuse.weight.copy_(fuse_weight)
58
+ model.arm_token_fuse.bias.copy_(bias_in)
59
+
60
+ model.action_out_proj_arms[0].weight.copy_(weight_out[0:16, :])
61
+ model.action_out_proj_arms[0].bias.copy_(bias_out[0:16])
62
+ model.action_out_proj_arms[1].weight.copy_(weight_out[16:32, :])
63
+ model.action_out_proj_arms[1].bias.copy_(bias_out[16:32])
64
+
65
+
66
+ def _copy_split_expert_weights(model, single_state) -> None:
67
+ model_state = model.state_dict()
68
+ with torch.no_grad():
69
+ for key, value in single_state.items():
70
+ if not key.startswith("paligemma_with_expert.gemma_expert."):
71
+ continue
72
+ suffix = key.removeprefix("paligemma_with_expert.gemma_expert.")
73
+ left_key = f"paligemma_with_expert.left_gemma_expert.{suffix}"
74
+ right_key = f"paligemma_with_expert.right_gemma_expert.{suffix}"
75
+ model_state[left_key].copy_(value.to(dtype=model_state[left_key].dtype))
76
+ model_state[right_key].copy_(value.to(dtype=model_state[right_key].dtype))
77
+
78
+
79
+ def _expert_copy_max_abs_diff(model, single_state, target_prefix: str) -> float:
80
+ model_state = model.state_dict()
81
+ max_abs_diff = 0.0
82
+ for key, value in single_state.items():
83
+ if not key.startswith("paligemma_with_expert.gemma_expert."):
84
+ continue
85
+ suffix = key.removeprefix("paligemma_with_expert.gemma_expert.")
86
+ target_key = f"{target_prefix}{suffix}"
87
+ diff = (model_state[target_key].to(torch.float32) - value.to(torch.float32)).abs().max().item()
88
+ max_abs_diff = max(max_abs_diff, float(diff))
89
+ return max_abs_diff
90
+
91
+
92
  def main() -> None:
93
  args = tyro.cli(Args)
94
  config = _config.get_config(args.config_name)
95
  model_cfg = _build_model_config(config)
96
  if not model_cfg.use_parallel_action_heads:
97
+ raise ValueError(f"Config {args.config_name} does not use factorized or split action heads.")
98
  if tuple(model_cfg.arm_action_dims) != (16, 16):
99
  raise ValueError(f"Expected arm_action_dims=(16, 16), got {model_cfg.arm_action_dims}.")
100
 
 
114
  f"Expected single-head checkpoint with packed 32-dim actions, got in={tuple(weight_in.shape)} out={tuple(weight_out.shape)}."
115
  )
116
 
117
+ _copy_factorized_heads(parallel_model, weight_in, bias_in, weight_out, bias_out)
118
+ if model_cfg.use_split_action_expert:
119
+ _copy_split_expert_weights(parallel_model, single_state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  proj_in_dtype = parallel_model.action_in_proj_arms[0].weight.dtype
122
  proj_out_dtype = parallel_model.action_out_proj_arms[0].weight.dtype
123
  x = torch.randn(2, model_cfg.action_horizon, model_cfg.action_dim, dtype=proj_in_dtype)
124
+ x_left = x[:, :, 0:16]
125
+ x_right = x[:, :, 16:32]
126
  suffix = torch.randn(2, model_cfg.action_horizon, hidden_width, dtype=proj_out_dtype)
127
+
128
+ metadata = {
129
+ "config_name": args.config_name,
130
+ "action_expert_mode": model_cfg.action_expert_mode,
131
+ "single_ckpt": args.single_ckpt,
132
+ "output_path": args.output_path,
133
+ "load_state_missing_keys": list(missing),
134
+ "load_state_unexpected_keys": list(unexpected),
135
+ }
136
+
137
  with torch.no_grad():
138
+ left_input_projection_max_abs_diff = float(
139
+ (
140
+ F.linear(x_left, weight_in[:, 0:16].to(proj_in_dtype), None)
141
+ - parallel_model.action_in_proj_arms[0](x_left)
142
+ )
143
+ .abs()
144
+ .max()
145
+ .item()
146
+ )
147
+ right_input_projection_max_abs_diff = float(
148
  (
149
+ F.linear(x_right, weight_in[:, 16:32].to(proj_in_dtype), None)
150
+ - parallel_model.action_in_proj_arms[1](x_right)
151
  )
152
  .abs()
153
  .max()
154
  .item()
155
  )
156
+ left_output_projection_max_abs_diff = float(
157
  (
158
+ F.linear(suffix, weight_out[0:16, :].to(proj_out_dtype), bias_out[0:16].to(proj_out_dtype))
159
+ - parallel_model.action_out_proj_arms[0](suffix)
160
+ )
161
+ .abs()
162
+ .max()
163
+ .item()
164
+ )
165
+ right_output_projection_max_abs_diff = float(
166
+ (
167
+ F.linear(suffix, weight_out[16:32, :].to(proj_out_dtype), bias_out[16:32].to(proj_out_dtype))
168
+ - parallel_model.action_out_proj_arms[1](suffix)
169
  )
170
  .abs()
171
  .max()
172
  .item()
173
  )
174
 
175
+ metadata.update(
176
+ {
177
+ "left_input_projection_max_abs_diff": left_input_projection_max_abs_diff,
178
+ "right_input_projection_max_abs_diff": right_input_projection_max_abs_diff,
179
+ "left_output_projection_max_abs_diff": left_output_projection_max_abs_diff,
180
+ "right_output_projection_max_abs_diff": right_output_projection_max_abs_diff,
181
+ }
182
+ )
183
+
184
+ if model_cfg.action_expert_mode == "head_only_parallel":
185
+ input_max_abs_diff = float(
186
+ (
187
+ F.linear(x, weight_in.to(proj_in_dtype), bias_in.to(proj_in_dtype))
188
+ - parallel_model._project_action_inputs(x)
189
+ )
190
+ .abs()
191
+ .max()
192
+ .item()
193
+ )
194
+ output_max_abs_diff = float(
195
+ (
196
+ F.linear(suffix, weight_out.to(proj_out_dtype), bias_out.to(proj_out_dtype))
197
+ - parallel_model._project_action_outputs(suffix)
198
+ )
199
+ .abs()
200
+ .max()
201
+ .item()
202
+ )
203
+ metadata["input_projection_max_abs_diff"] = input_max_abs_diff
204
+ metadata["output_projection_max_abs_diff"] = output_max_abs_diff
205
+ metadata["warm_start_exact"] = input_max_abs_diff == 0.0 and output_max_abs_diff == 0.0
206
+ else:
207
+ left_expert_max_abs_diff = _expert_copy_max_abs_diff(
208
+ parallel_model,
209
+ single_state,
210
+ "paligemma_with_expert.left_gemma_expert.",
211
+ )
212
+ right_expert_max_abs_diff = _expert_copy_max_abs_diff(
213
+ parallel_model,
214
+ single_state,
215
+ "paligemma_with_expert.right_gemma_expert.",
216
+ )
217
+ metadata["left_expert_max_abs_diff"] = left_expert_max_abs_diff
218
+ metadata["right_expert_max_abs_diff"] = right_expert_max_abs_diff
219
+ if parallel_model.paligemma_with_expert.cross_arm_comm is not None:
220
+ metadata["cross_arm_comm_init"] = [
221
+ float(value) for value in parallel_model.paligemma_with_expert.cross_arm_comm.detach().cpu().tolist()
222
+ ]
223
+ metadata["warm_start_exact"] = (
224
+ left_input_projection_max_abs_diff == 0.0
225
+ and right_input_projection_max_abs_diff == 0.0
226
+ and left_output_projection_max_abs_diff == 0.0
227
+ and right_output_projection_max_abs_diff == 0.0
228
+ and left_expert_max_abs_diff == 0.0
229
+ and right_expert_max_abs_diff == 0.0
230
+ )
231
+
232
  output_dir = Path(args.output_path)
233
  output_dir.mkdir(parents=True, exist_ok=True)
234
  safetensors.torch.save_model(parallel_model, output_dir / "model.safetensors")
235
  (output_dir / "config.json").write_text(json.dumps(dataclasses.asdict(model_cfg), indent=2, sort_keys=True))
 
 
 
 
 
 
 
 
 
 
236
  (output_dir / "init_parallel_metadata.json").write_text(json.dumps(metadata, indent=2, sort_keys=True))
237
 
238
  print(f"config_name: {args.config_name}")
239
+ print(f"action_expert_mode: {model_cfg.action_expert_mode}")
240
  print(f"single_ckpt: {args.single_ckpt}")
241
  print(f"output_path: {args.output_path}")
242
  print(f"load_state_missing_keys_count: {len(missing)}")
243
  print(f"load_state_missing_keys: {list(missing)}")
244
  print(f"load_state_unexpected_keys_count: {len(unexpected)}")
245
  print(f"load_state_unexpected_keys: {list(unexpected)}")
246
+ for key in sorted(metadata):
247
+ if key in {"config_name", "action_expert_mode", "single_ckpt", "output_path", "load_state_missing_keys", "load_state_unexpected_keys"}:
248
+ continue
249
+ print(f"{key}: {metadata[key]}")
250
 
251
 
252
  if __name__ == "__main__":
openpi/scripts/run_twin_dual_push_128_packed_5k.sh CHANGED
@@ -26,8 +26,8 @@ PARALLEL_EXP="dual_push_128_packed_parallel_5k"
26
  VAL_REPO="lsnu/twin_dual_push_128_val"
27
  INTERMEDIATE_VAL_BATCHES=50
28
  FINAL_VAL_BATCHES=100
29
- SAMPLE_VAL_BATCHES=16
30
- SAMPLE_NUM_STEPS="4,10"
31
  RUN_WARMSTART_CHECK="${RUN_WARMSTART_CHECK:-0}"
32
 
33
  BASELINE_CKPT_ROOT="$ROOT/checkpoints/$BASELINE_CONFIG/$BASELINE_EXP"
 
26
  VAL_REPO="lsnu/twin_dual_push_128_val"
27
  INTERMEDIATE_VAL_BATCHES=50
28
  FINAL_VAL_BATCHES=100
29
+ SAMPLE_VAL_BATCHES=64
30
+ SAMPLE_NUM_STEPS="1,2,4,8,16"
31
  RUN_WARMSTART_CHECK="${RUN_WARMSTART_CHECK:-0}"
32
 
33
  BASELINE_CKPT_ROOT="$ROOT/checkpoints/$BASELINE_CONFIG/$BASELINE_EXP"
openpi/scripts/run_twin_handover_packed_10k.sh CHANGED
@@ -26,8 +26,8 @@ PARALLEL_EXP="handover_packed_parallel_10k"
26
  VAL_REPO="lsnu/twin_handover_256_val"
27
  INTERMEDIATE_VAL_BATCHES=50
28
  FINAL_VAL_BATCHES=100
29
- SAMPLE_VAL_BATCHES=16
30
- SAMPLE_NUM_STEPS="4,10"
31
 
32
  BASELINE_CKPT_ROOT="$ROOT/checkpoints/$BASELINE_CONFIG/$BASELINE_EXP"
33
  PARALLEL_CKPT_ROOT="$ROOT/checkpoints/$PARALLEL_CONFIG/$PARALLEL_EXP"
 
26
  VAL_REPO="lsnu/twin_handover_256_val"
27
  INTERMEDIATE_VAL_BATCHES=50
28
  FINAL_VAL_BATCHES=100
29
+ SAMPLE_VAL_BATCHES=64
30
+ SAMPLE_NUM_STEPS="1,2,4,8,16"
31
 
32
  BASELINE_CKPT_ROOT="$ROOT/checkpoints/$BASELINE_CONFIG/$BASELINE_EXP"
33
  PARALLEL_CKPT_ROOT="$ROOT/checkpoints/$PARALLEL_CONFIG/$PARALLEL_EXP"
openpi/scripts/train_pytorch.py CHANGED
@@ -216,6 +216,34 @@ def grad_norm_for_parameters(parameters) -> float:
216
 
217
  def collect_gradient_bucket_norms(model: torch.nn.Module) -> dict[str, float]:
218
  model_for_logging = unwrap_model(model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  metrics = {"grad_shared_expert": grad_norm_for_parameters(model_for_logging.paligemma_with_expert.parameters())}
220
  if model_for_logging.use_parallel_action_heads:
221
  metrics["grad_action_in_proj_arms"] = grad_norm_for_parameters(model_for_logging.action_in_proj_arms.parameters())
@@ -669,7 +697,7 @@ def train_loop(config: _config.TrainConfig):
669
  last_step_end = time.perf_counter()
670
  smoothed_loss = None
671
  if is_main:
672
- model_kind = "parallel" if model_cfg.use_parallel_action_heads else "baseline"
673
  logging.info(f"Running on: {platform.node()} | world_size={world_size}")
674
  logging.info(
675
  f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
@@ -696,7 +724,11 @@ def train_loop(config: _config.TrainConfig):
696
  + (
697
  "action_in_proj, action_out_proj, shared_expert"
698
  if not model_cfg.use_parallel_action_heads
699
- else "action_in_proj_arms, arm_token_fuse, action_out_proj_arms, shared_expert"
 
 
 
 
700
  )
701
  )
702
 
 
216
 
217
  def collect_gradient_bucket_norms(model: torch.nn.Module) -> dict[str, float]:
218
  model_for_logging = unwrap_model(model)
219
+ if model_for_logging.use_split_action_expert:
220
+ metrics = {
221
+ "grad_shared_backbone": grad_norm_for_parameters(
222
+ model_for_logging.paligemma_with_expert.paligemma.parameters()
223
+ ),
224
+ "grad_left_action_in": grad_norm_for_parameters(model_for_logging.action_in_proj_arms[0].parameters()),
225
+ "grad_right_action_in": grad_norm_for_parameters(model_for_logging.action_in_proj_arms[1].parameters()),
226
+ "grad_left_expert": grad_norm_for_parameters(
227
+ model_for_logging.paligemma_with_expert.left_gemma_expert.parameters()
228
+ ),
229
+ "grad_right_expert": grad_norm_for_parameters(
230
+ model_for_logging.paligemma_with_expert.right_gemma_expert.parameters()
231
+ ),
232
+ "grad_action_out": grad_norm_for_parameters(model_for_logging.action_out_proj_arms.parameters()),
233
+ }
234
+ if model_for_logging.use_communicating_action_expert:
235
+ metrics["grad_cross_arm_comm"] = grad_norm_for_parameters(
236
+ [model_for_logging.paligemma_with_expert.cross_arm_comm]
237
+ )
238
+ for layer_idx, gate_value in enumerate(model_for_logging.paligemma_with_expert.cross_arm_comm.detach().cpu()):
239
+ metrics[f"cross_arm_comm_gate_layer_{layer_idx}"] = float(gate_value.item())
240
+ if model_for_logging.paligemma_with_expert.latest_cross_arm_attention_mass is not None:
241
+ for layer_idx, attn_mass in enumerate(
242
+ model_for_logging.paligemma_with_expert.latest_cross_arm_attention_mass.detach().cpu()
243
+ ):
244
+ metrics[f"cross_arm_attention_mass_layer_{layer_idx}"] = float(attn_mass.item())
245
+ return metrics
246
+
247
  metrics = {"grad_shared_expert": grad_norm_for_parameters(model_for_logging.paligemma_with_expert.parameters())}
248
  if model_for_logging.use_parallel_action_heads:
249
  metrics["grad_action_in_proj_arms"] = grad_norm_for_parameters(model_for_logging.action_in_proj_arms.parameters())
 
697
  last_step_end = time.perf_counter()
698
  smoothed_loss = None
699
  if is_main:
700
+ model_kind = model_cfg.action_expert_mode
701
  logging.info(f"Running on: {platform.node()} | world_size={world_size}")
702
  logging.info(
703
  f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
 
724
  + (
725
  "action_in_proj, action_out_proj, shared_expert"
726
  if not model_cfg.use_parallel_action_heads
727
+ else (
728
+ "left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm"
729
+ if model_cfg.use_split_action_expert
730
+ else "action_in_proj_arms, arm_token_fuse, action_out_proj_arms, shared_expert"
731
+ )
732
  )
733
  )
734
 
openpi/src/openpi/models/pi0_config.py CHANGED
@@ -1,5 +1,5 @@
1
  import dataclasses
2
- from typing import TYPE_CHECKING
3
 
4
  import flax.nnx as nnx
5
  import jax
@@ -15,6 +15,9 @@ if TYPE_CHECKING:
15
  from openpi.models.pi0 import Pi0
16
 
17
 
 
 
 
18
  @dataclasses.dataclass(frozen=True)
19
  class Pi0Config(_model.BaseModelConfig):
20
  dtype: str = "bfloat16"
@@ -32,6 +35,8 @@ class Pi0Config(_model.BaseModelConfig):
32
  # Per-arm action dimensions for parallel arm heads. For a dual-arm setup with 32-dim actions,
33
  # this could be `(16, 16)`. The sum must equal `action_dim`.
34
  arm_action_dims: tuple[int, ...] | None = None
 
 
35
  # This config option is not used directly by the model, but it is read by the ModelTransformFactory.
36
  discrete_state_input: bool = None # type: ignore
37
 
@@ -44,6 +49,9 @@ class Pi0Config(_model.BaseModelConfig):
44
  object.__setattr__(self, "arm_action_dims", (self.action_dim,))
45
  else:
46
  object.__setattr__(self, "arm_action_dims", tuple(self.arm_action_dims))
 
 
 
47
 
48
  if not self.arm_action_dims:
49
  raise ValueError("arm_action_dims must contain at least one arm.")
@@ -55,10 +63,27 @@ class Pi0Config(_model.BaseModelConfig):
55
  )
56
  if len(self.arm_action_dims) > 1 and not self.pi05:
57
  raise ValueError("Parallel arm heads are only supported for pi05 models.")
 
 
 
 
 
 
 
 
 
58
 
59
  @property
60
  def use_parallel_action_heads(self) -> bool:
61
- return len(self.arm_action_dims) > 1
 
 
 
 
 
 
 
 
62
 
63
  @property
64
  @override
@@ -71,6 +96,8 @@ class Pi0Config(_model.BaseModelConfig):
71
  def create(self, rng: at.KeyArrayLike) -> "Pi0":
72
  from openpi.models.pi0 import Pi0
73
 
 
 
74
  return Pi0(self, rngs=nnx.Rngs(rng))
75
 
76
  @override
 
1
  import dataclasses
2
+ from typing import TYPE_CHECKING, Literal
3
 
4
  import flax.nnx as nnx
5
  import jax
 
15
  from openpi.models.pi0 import Pi0
16
 
17
 
18
+ ActionExpertMode = Literal["shared", "head_only_parallel", "split_independent", "split_communicating"]
19
+
20
+
21
  @dataclasses.dataclass(frozen=True)
22
  class Pi0Config(_model.BaseModelConfig):
23
  dtype: str = "bfloat16"
 
35
  # Per-arm action dimensions for parallel arm heads. For a dual-arm setup with 32-dim actions,
36
  # this could be `(16, 16)`. The sum must equal `action_dim`.
37
  arm_action_dims: tuple[int, ...] | None = None
38
+ # Defines whether the action expert is shared, split only at the factorized heads, or duplicated per arm.
39
+ action_expert_mode: ActionExpertMode | None = None
40
  # This config option is not used directly by the model, but it is read by the ModelTransformFactory.
41
  discrete_state_input: bool = None # type: ignore
42
 
 
49
  object.__setattr__(self, "arm_action_dims", (self.action_dim,))
50
  else:
51
  object.__setattr__(self, "arm_action_dims", tuple(self.arm_action_dims))
52
+ if self.action_expert_mode is None:
53
+ default_mode: ActionExpertMode = "head_only_parallel" if len(self.arm_action_dims) > 1 else "shared"
54
+ object.__setattr__(self, "action_expert_mode", default_mode)
55
 
56
  if not self.arm_action_dims:
57
  raise ValueError("arm_action_dims must contain at least one arm.")
 
63
  )
64
  if len(self.arm_action_dims) > 1 and not self.pi05:
65
  raise ValueError("Parallel arm heads are only supported for pi05 models.")
66
+ if self.action_expert_mode != "shared" and len(self.arm_action_dims) < 2:
67
+ raise ValueError(
68
+ f"action_expert_mode={self.action_expert_mode!r} requires at least two arm_action_dims, got {self.arm_action_dims}."
69
+ )
70
+ if self.action_expert_mode in ("split_independent", "split_communicating") and len(self.arm_action_dims) != 2:
71
+ raise ValueError(
72
+ "split action expert modes currently require exactly two arm_action_dims "
73
+ f"(left/right), got {self.arm_action_dims}."
74
+ )
75
 
76
  @property
77
  def use_parallel_action_heads(self) -> bool:
78
+ return self.action_expert_mode != "shared"
79
+
80
+ @property
81
+ def use_split_action_expert(self) -> bool:
82
+ return self.action_expert_mode in ("split_independent", "split_communicating")
83
+
84
+ @property
85
+ def use_communicating_action_expert(self) -> bool:
86
+ return self.action_expert_mode == "split_communicating"
87
 
88
  @property
89
  @override
 
96
  def create(self, rng: at.KeyArrayLike) -> "Pi0":
97
  from openpi.models.pi0 import Pi0
98
 
99
+ if self.use_split_action_expert:
100
+ raise NotImplementedError("Split action expert modes are currently supported only in the PyTorch model.")
101
  return Pi0(self, rngs=nnx.Rngs(rng))
102
 
103
  @override
openpi/src/openpi/models/utils/fsq_tokenizer.py CHANGED
@@ -1,7 +1,6 @@
1
  import math
2
  from typing import Any, Literal
3
 
4
- import chex
5
  from einops import einops
6
  from flax import linen as nn
7
  from flax.linen.module import Module
@@ -12,6 +11,20 @@ import jax
12
  import jax.numpy as jnp
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  class FsqCodebook(nn.Module):
16
  input_dim: int
17
  target_codebook_size: int
@@ -109,7 +122,7 @@ class FsqCodebook(nn.Module):
109
  z_q = digits / (bases - 1) * 2 - 1
110
 
111
  if z_grad is not None:
112
- chex.assert_equal_shape([z_q, z_grad])
113
  z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad
114
 
115
  return self.proj_up(z_q)
@@ -216,7 +229,7 @@ class LookupFreeQuantization(nn.Module):
216
  + token_bit_log_probs[..., 1] @ token_bit_expansions
217
  ) # (batch_size, num_tokens, 2 ** num_dims)
218
  token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
219
- chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))
220
 
221
  z_q = self.codebook[tokens]
222
  commit_loss = jnp.square(z - z_q).mean()
@@ -361,7 +374,7 @@ class TokenizerEncoderDecoder(nn.Module):
361
 
362
  if mask is not None:
363
  # mask is (batch_dims..., num_cross_tokens)
364
- chex.assert_equal_shape([y[..., 0], mask])
365
  attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
366
  else:
367
  attn_mask = jnp.ones((*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens))
 
1
  import math
2
  from typing import Any, Literal
3
 
 
4
  from einops import einops
5
  from flax import linen as nn
6
  from flax.linen.module import Module
 
11
  import jax.numpy as jnp
12
 
13
 
14
+ def _assert_equal_shape(arrays: list[jnp.ndarray]) -> None:
15
+ if not arrays:
16
+ return
17
+ expected_shape = tuple(arrays[0].shape)
18
+ for array in arrays[1:]:
19
+ if tuple(array.shape) != expected_shape:
20
+ raise ValueError(f"Expected equal shapes, got {expected_shape} and {tuple(array.shape)}.")
21
+
22
+
23
+ def _assert_shape(array: jnp.ndarray, expected_shape: tuple[int, ...]) -> None:
24
+ if tuple(array.shape) != expected_shape:
25
+ raise ValueError(f"Expected shape {expected_shape}, got {tuple(array.shape)}.")
26
+
27
+
28
  class FsqCodebook(nn.Module):
29
  input_dim: int
30
  target_codebook_size: int
 
122
  z_q = digits / (bases - 1) * 2 - 1
123
 
124
  if z_grad is not None:
125
+ _assert_equal_shape([z_q, z_grad])
126
  z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad
127
 
128
  return self.proj_up(z_q)
 
229
  + token_bit_log_probs[..., 1] @ token_bit_expansions
230
  ) # (batch_size, num_tokens, 2 ** num_dims)
231
  token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
232
+ _assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))
233
 
234
  z_q = self.codebook[tokens]
235
  commit_loss = jnp.square(z - z_q).mean()
 
374
 
375
  if mask is not None:
376
  # mask is (batch_dims..., num_cross_tokens)
377
+ _assert_equal_shape([y[..., 0], mask])
378
  attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
379
  else:
380
  attn_mask = jnp.ones((*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens))
openpi/src/openpi/models_pytorch/gemma_pytorch.py CHANGED
@@ -1,6 +1,6 @@
 
1
  from typing import Literal
2
 
3
- import pytest
4
  import torch
5
  from torch import nn
6
  from transformers import GemmaForCausalLM
@@ -16,11 +16,33 @@ class PaliGemmaWithExpertModel(nn.Module):
16
  action_expert_config,
17
  use_adarms=None,
18
  precision: Literal["bfloat16", "float32"] = "bfloat16",
 
 
 
19
  ):
20
  if use_adarms is None:
21
  use_adarms = [False, False]
22
  super().__init__()
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  vlm_config_hf = CONFIG_MAPPING["paligemma"]()
25
  vlm_config_hf._vocab_size = 257152 # noqa: SLF001
26
  vlm_config_hf.image_token_index = 257152
@@ -36,7 +58,6 @@ class PaliGemmaWithExpertModel(nn.Module):
36
  vlm_config_hf.text_config.use_adarms = use_adarms[0]
37
  vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
38
  vlm_config_hf.vision_config.intermediate_size = 4304
39
- # Keep image and language embedding dimensions aligned for all variants, including dummy.
40
  vlm_config_hf.vision_config.projection_dim = vlm_config.width
41
  vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
42
  vlm_config_hf.vision_config.torch_dtype = "float32"
@@ -51,16 +72,41 @@ class PaliGemmaWithExpertModel(nn.Module):
51
  vocab_size=257152,
52
  hidden_activation="gelu_pytorch_tanh",
53
  torch_dtype="float32",
54
- use_adarms=use_adarms[1],
55
- adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
56
  )
57
 
58
  self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
59
- self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
60
- self.gemma_expert.model.embed_tokens = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  self.to_bfloat16_for_selected_params(precision)
63
 
 
 
 
 
 
 
64
  def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
65
  if precision == "bfloat16":
66
  self.to(dtype=torch.bfloat16)
@@ -89,194 +135,214 @@ class PaliGemmaWithExpertModel(nn.Module):
89
  def embed_language_tokens(self, tokens: torch.Tensor):
90
  return self.paligemma.language_model.embed_tokens(tokens)
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def forward(
93
  self,
94
  attention_mask: torch.Tensor | None = None,
95
  position_ids: torch.LongTensor | None = None,
96
- past_key_values: list[torch.FloatTensor] | pytest.Cache | None = None,
97
- inputs_embeds: list[torch.FloatTensor] | None = None,
98
  use_cache: bool | None = None,
99
- adarms_cond: list[torch.Tensor] | None = None,
 
100
  ):
 
 
 
 
 
 
101
  if adarms_cond is None:
102
- adarms_cond = [None, None]
103
- if inputs_embeds[1] is None:
 
 
 
 
104
  prefix_output = self.paligemma.language_model.forward(
105
  inputs_embeds=inputs_embeds[0],
106
  attention_mask=attention_mask,
107
  position_ids=position_ids,
108
  past_key_values=past_key_values,
109
  use_cache=use_cache,
110
- adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
111
  )
112
- prefix_past_key_values = prefix_output.past_key_values
113
- prefix_output = prefix_output.last_hidden_state
114
- suffix_output = None
115
- elif inputs_embeds[0] is None:
116
- suffix_output = self.gemma_expert.model.forward(
117
- inputs_embeds=inputs_embeds[1],
 
 
118
  attention_mask=attention_mask,
119
  position_ids=position_ids,
120
  past_key_values=past_key_values,
121
  use_cache=use_cache,
122
- adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
123
  )
124
- suffix_output = suffix_output.last_hidden_state
125
- prefix_output = None
126
- prefix_past_key_values = None
127
- else:
128
- models = [self.paligemma.language_model, self.gemma_expert.model]
129
- num_layers = self.paligemma.config.text_config.num_hidden_layers
130
-
131
- # Check if gradient checkpointing is enabled for any of the models
132
- use_gradient_checkpointing = (
133
- hasattr(self.gemma_expert.model, "gradient_checkpointing")
134
- and self.gemma_expert.model.gradient_checkpointing
135
- and self.training
136
- ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
137
-
138
- # Force enable gradient checkpointing if we're in training mode and the model supports it
139
- if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
140
- if not self.gemma_expert.model.gradient_checkpointing:
141
- print("Forcing gradient checkpointing to be enabled for Gemma expert model")
142
- self.gemma_expert.model.gradient_checkpointing = True
143
- use_gradient_checkpointing = True
144
-
145
- # Debug gradient checkpointing status
146
- if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
147
- print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
148
- print(f"Model training mode: {self.training}")
149
- print(
150
- f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
151
- )
152
- if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
153
- print(
154
- f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
155
- )
156
- self._debug_gc_printed = True
157
-
158
- # Define the complete layer computation function for gradient checkpointing
159
- def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
160
- models = [self.paligemma.language_model, self.gemma_expert.model]
161
-
162
- query_states = []
163
- key_states = []
164
- value_states = []
165
- gates = []
166
- for i, hidden_states in enumerate(inputs_embeds):
167
- layer = models[i].layers[layer_idx]
168
- hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
169
- gates.append(gate)
170
-
171
- input_shape = hidden_states.shape[:-1]
172
- hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
173
- query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
174
- key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
175
- value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
176
-
177
- query_states.append(query_state)
178
- key_states.append(key_state)
179
- value_states.append(value_state)
180
-
181
- # Concatenate and process attention
182
- query_states = torch.cat(query_states, dim=2)
183
- key_states = torch.cat(key_states, dim=2)
184
- value_states = torch.cat(value_states, dim=2)
185
-
186
- dummy_tensor = torch.zeros(
187
- query_states.shape[0],
188
- query_states.shape[2],
189
- query_states.shape[-1],
190
- device=query_states.device,
191
- dtype=query_states.dtype,
192
- )
193
- cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
194
- query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
195
- query_states, key_states, cos, sin, unsqueeze_dim=1
196
- )
197
 
198
- batch_size = query_states.shape[0]
199
- scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
- # Attention computation
202
- att_output, _ = modeling_gemma.eager_attention_forward(
203
- self.paligemma.language_model.layers[layer_idx].self_attn,
204
- query_states,
205
- key_states,
206
- value_states,
207
- attention_mask,
208
- scaling,
209
  )
210
- # Get head_dim from the current layer, not from the model
211
- head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim
212
- att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
213
-
214
- # Process layer outputs
215
- outputs_embeds = []
216
- start_pos = 0
217
- for i, hidden_states in enumerate(inputs_embeds):
218
- layer = models[i].layers[layer_idx]
219
- end_pos = start_pos + hidden_states.shape[1]
220
-
221
- if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
222
- att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
223
- out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
224
-
225
- # first residual
226
- out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
227
- after_first_residual = out_emb.clone()
228
- out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
229
- # Convert to bfloat16 if the next layer (mlp) uses bfloat16
230
- if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
231
- out_emb = out_emb.to(dtype=torch.bfloat16)
232
-
233
- out_emb = layer.mlp(out_emb)
234
- # second residual
235
- out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
236
- outputs_embeds.append(out_emb)
237
- start_pos = end_pos
238
-
239
- return outputs_embeds
240
-
241
- # Process all layers with gradient checkpointing if enabled
242
- for layer_idx in range(num_layers):
243
- if use_gradient_checkpointing:
244
- inputs_embeds = torch.utils.checkpoint.checkpoint(
245
- compute_layer_complete,
246
- layer_idx,
247
- inputs_embeds,
248
- attention_mask,
249
- position_ids,
250
- adarms_cond,
251
- use_reentrant=False,
252
- preserve_rng_state=False,
253
- )
254
- else:
255
- inputs_embeds = compute_layer_complete(
256
- layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
257
- )
258
-
259
- # Old code removed - now using compute_layer_complete function above
260
-
261
- # final norm
262
- # Define final norm computation function for gradient checkpointing
263
- def compute_final_norms(inputs_embeds, adarms_cond):
264
- outputs_embeds = []
265
- for i, hidden_states in enumerate(inputs_embeds):
266
- out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
267
- outputs_embeds.append(out_emb)
268
- return outputs_embeds
269
-
270
- # Apply gradient checkpointing to final norm if enabled
271
  if use_gradient_checkpointing:
272
- outputs_embeds = torch.utils.checkpoint.checkpoint(
273
- compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False
 
 
 
 
 
 
 
 
274
  )
275
  else:
276
- outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
 
 
 
 
 
 
 
277
 
278
- prefix_output = outputs_embeds[0]
279
- suffix_output = outputs_embeds[1]
280
- prefix_past_key_values = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
- return [prefix_output, suffix_output], prefix_past_key_values
 
1
+ from typing import Any
2
  from typing import Literal
3
 
 
4
  import torch
5
  from torch import nn
6
  from transformers import GemmaForCausalLM
 
16
  action_expert_config,
17
  use_adarms=None,
18
  precision: Literal["bfloat16", "float32"] = "bfloat16",
19
+ *,
20
+ num_action_experts: int = 1,
21
+ enable_cross_arm_communication: bool = False,
22
  ):
23
  if use_adarms is None:
24
  use_adarms = [False, False]
25
  super().__init__()
26
 
27
+ if num_action_experts < 1:
28
+ raise ValueError(f"num_action_experts must be positive, got {num_action_experts}.")
29
+ if enable_cross_arm_communication and num_action_experts < 2:
30
+ raise ValueError("Cross-arm communication requires at least two action experts.")
31
+
32
+ if len(use_adarms) == 2 and num_action_experts > 1:
33
+ use_adarms = [use_adarms[0], *([use_adarms[1]] * num_action_experts)]
34
+ if len(use_adarms) != num_action_experts + 1:
35
+ raise ValueError(
36
+ f"use_adarms must have one entry per stream, got {len(use_adarms)} for {num_action_experts + 1} streams."
37
+ )
38
+
39
+ expert_use_adarms = use_adarms[1]
40
+ if any(expert_flag != expert_use_adarms for expert_flag in use_adarms[1:]):
41
+ raise ValueError(f"All action expert streams must agree on use_adarms, got {use_adarms[1:]}.")
42
+
43
+ self.num_action_experts = num_action_experts
44
+ self.enable_cross_arm_communication = enable_cross_arm_communication
45
+
46
  vlm_config_hf = CONFIG_MAPPING["paligemma"]()
47
  vlm_config_hf._vocab_size = 257152 # noqa: SLF001
48
  vlm_config_hf.image_token_index = 257152
 
58
  vlm_config_hf.text_config.use_adarms = use_adarms[0]
59
  vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
60
  vlm_config_hf.vision_config.intermediate_size = 4304
 
61
  vlm_config_hf.vision_config.projection_dim = vlm_config.width
62
  vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
63
  vlm_config_hf.vision_config.torch_dtype = "float32"
 
72
  vocab_size=257152,
73
  hidden_activation="gelu_pytorch_tanh",
74
  torch_dtype="float32",
75
+ use_adarms=expert_use_adarms,
76
+ adarms_cond_dim=action_expert_config.width if expert_use_adarms else None,
77
  )
78
 
79
  self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
80
+ if num_action_experts == 1:
81
+ self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
82
+ self.gemma_expert.model.embed_tokens = None
83
+ self._action_expert_names = ("gemma_expert",)
84
+ else:
85
+ self.left_gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
86
+ self.right_gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
87
+ self.left_gemma_expert.model.embed_tokens = None
88
+ self.right_gemma_expert.model.embed_tokens = None
89
+ self._action_expert_names = ("left_gemma_expert", "right_gemma_expert")
90
+
91
+ if self.enable_cross_arm_communication:
92
+ self.cross_arm_comm = nn.Parameter(torch.zeros(action_expert_config.depth, dtype=torch.float32))
93
+ self.register_buffer(
94
+ "latest_cross_arm_attention_mass",
95
+ torch.zeros(action_expert_config.depth, dtype=torch.float32),
96
+ persistent=False,
97
+ )
98
+ else:
99
+ self.cross_arm_comm = None
100
+ self.latest_cross_arm_attention_mass = None
101
 
102
  self.to_bfloat16_for_selected_params(precision)
103
 
104
+ def _get_action_experts(self) -> list[GemmaForCausalLM]:
105
+ return [getattr(self, name) for name in self._action_expert_names]
106
+
107
+ def _get_action_expert_models(self) -> list[nn.Module]:
108
+ return [expert.model for expert in self._get_action_experts()]
109
+
110
  def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
111
  if precision == "bfloat16":
112
  self.to(dtype=torch.bfloat16)
 
135
  def embed_language_tokens(self, tokens: torch.Tensor):
136
  return self.paligemma.language_model.embed_tokens(tokens)
137
 
138
+ def _make_outputs(self, prefix_output, suffix_outputs: list[torch.Tensor | None]) -> list[torch.Tensor | None]:
139
+ return [prefix_output, *suffix_outputs]
140
+
141
+ def _compute_cross_arm_attention_mass(
142
+ self,
143
+ layer_idx: int,
144
+ att_weights: torch.Tensor,
145
+ cross_attention_selector: torch.Tensor | None,
146
+ ) -> None:
147
+ if self.latest_cross_arm_attention_mass is None or cross_attention_selector is None:
148
+ return
149
+ selector = cross_attention_selector.to(device=att_weights.device, dtype=att_weights.dtype)
150
+ denom = selector.sum() * att_weights.shape[0] * att_weights.shape[1]
151
+ if float(denom.item()) <= 0:
152
+ self.latest_cross_arm_attention_mass[layer_idx].zero_()
153
+ return
154
+ mass = (att_weights * selector).sum() / denom
155
+ self.latest_cross_arm_attention_mass[layer_idx].copy_(mass.detach().to(torch.float32))
156
+
157
  def forward(
158
  self,
159
  attention_mask: torch.Tensor | None = None,
160
  position_ids: torch.LongTensor | None = None,
161
+ past_key_values: list[torch.FloatTensor] | Any | None = None,
162
+ inputs_embeds: list[torch.FloatTensor | None] | None = None,
163
  use_cache: bool | None = None,
164
+ adarms_cond: list[torch.Tensor | None] | None = None,
165
+ cross_attention_selector: torch.Tensor | None = None,
166
  ):
167
+ if inputs_embeds is None:
168
+ raise ValueError("inputs_embeds is required.")
169
+ if len(inputs_embeds) != self.num_action_experts + 1:
170
+ raise ValueError(
171
+ f"Expected {self.num_action_experts + 1} input streams, got {len(inputs_embeds)}."
172
+ )
173
  if adarms_cond is None:
174
+ adarms_cond = [None] * len(inputs_embeds)
175
+ if len(adarms_cond) != len(inputs_embeds):
176
+ raise ValueError(f"Expected {len(inputs_embeds)} adarms_cond entries, got {len(adarms_cond)}.")
177
+
178
+ suffix_inputs = inputs_embeds[1:]
179
+ if inputs_embeds[0] is not None and all(suffix is None for suffix in suffix_inputs):
180
  prefix_output = self.paligemma.language_model.forward(
181
  inputs_embeds=inputs_embeds[0],
182
  attention_mask=attention_mask,
183
  position_ids=position_ids,
184
  past_key_values=past_key_values,
185
  use_cache=use_cache,
186
+ adarms_cond=adarms_cond[0],
187
  )
188
+ outputs = self._make_outputs(prefix_output.last_hidden_state, [None] * self.num_action_experts)
189
+ return outputs, prefix_output.past_key_values
190
+
191
+ active_suffix_indices = [i for i, suffix in enumerate(suffix_inputs) if suffix is not None]
192
+ if inputs_embeds[0] is None and len(active_suffix_indices) == 1:
193
+ expert_idx = active_suffix_indices[0]
194
+ suffix_output = self._get_action_expert_models()[expert_idx].forward(
195
+ inputs_embeds=suffix_inputs[expert_idx],
196
  attention_mask=attention_mask,
197
  position_ids=position_ids,
198
  past_key_values=past_key_values,
199
  use_cache=use_cache,
200
+ adarms_cond=adarms_cond[expert_idx + 1],
201
  )
202
+ outputs = [None] * len(inputs_embeds)
203
+ outputs[expert_idx + 1] = suffix_output.last_hidden_state
204
+ return outputs, None
205
+
206
+ if inputs_embeds[0] is None:
207
+ raise NotImplementedError("Multi-stream suffix-only forward is not implemented.")
208
+ if any(suffix is None for suffix in suffix_inputs):
209
+ raise ValueError("Joint forward requires all suffix streams to be present.")
210
+
211
+ models = [self.paligemma.language_model, *self._get_action_expert_models()]
212
+ num_layers = self.paligemma.config.text_config.num_hidden_layers
213
+ use_gradient_checkpointing = self.training and any(
214
+ getattr(model, "gradient_checkpointing", False) for model in models
215
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
+ def compute_layer_complete(
218
+ layer_idx,
219
+ stream_inputs,
220
+ attention_mask,
221
+ position_ids,
222
+ adarms_cond,
223
+ cross_attention_selector,
224
+ ):
225
+ query_states = []
226
+ key_states = []
227
+ value_states = []
228
+ gates = []
229
+ for stream_idx, hidden_states in enumerate(stream_inputs):
230
+ layer = models[stream_idx].layers[layer_idx]
231
+ hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[stream_idx]) # noqa: PLW2901
232
+ gates.append(gate)
233
+
234
+ input_shape = hidden_states.shape[:-1]
235
+ hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
236
+ query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
237
+ key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
238
+ value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
239
+
240
+ query_states.append(query_state)
241
+ key_states.append(key_state)
242
+ value_states.append(value_state)
243
+
244
+ query_states = torch.cat(query_states, dim=2)
245
+ key_states = torch.cat(key_states, dim=2)
246
+ value_states = torch.cat(value_states, dim=2)
247
+
248
+ dummy_tensor = torch.zeros(
249
+ query_states.shape[0],
250
+ query_states.shape[2],
251
+ query_states.shape[-1],
252
+ device=query_states.device,
253
+ dtype=query_states.dtype,
254
+ )
255
+ cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
256
+ query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
257
+ query_states, key_states, cos, sin, unsqueeze_dim=1
258
+ )
259
 
260
+ layer_attention_mask = attention_mask
261
+ if self.cross_arm_comm is not None and cross_attention_selector is not None:
262
+ cross_bias = self.cross_arm_comm[layer_idx].to(device=attention_mask.device, dtype=attention_mask.dtype)
263
+ layer_attention_mask = attention_mask + cross_bias * cross_attention_selector.to(
264
+ device=attention_mask.device,
265
+ dtype=attention_mask.dtype,
 
 
266
  )
267
+
268
+ batch_size = query_states.shape[0]
269
+ scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling
270
+ att_output, att_weights = modeling_gemma.eager_attention_forward(
271
+ self.paligemma.language_model.layers[layer_idx].self_attn,
272
+ query_states,
273
+ key_states,
274
+ value_states,
275
+ layer_attention_mask,
276
+ scaling,
277
+ )
278
+ self._compute_cross_arm_attention_mass(layer_idx, att_weights, cross_attention_selector)
279
+
280
+ proj_dim = self.paligemma.language_model.layers[layer_idx].self_attn.o_proj.in_features
281
+ att_output = att_output.reshape(batch_size, -1, proj_dim)
282
+
283
+ outputs_embeds = []
284
+ start_pos = 0
285
+ for stream_idx, hidden_states in enumerate(stream_inputs):
286
+ layer = models[stream_idx].layers[layer_idx]
287
+ end_pos = start_pos + hidden_states.shape[1]
288
+
289
+ stream_att_output = att_output[:, start_pos:end_pos]
290
+ if stream_att_output.dtype != layer.self_attn.o_proj.weight.dtype:
291
+ stream_att_output = stream_att_output.to(layer.self_attn.o_proj.weight.dtype)
292
+ out_emb = layer.self_attn.o_proj(stream_att_output)
293
+
294
+ out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[stream_idx]) # noqa: SLF001
295
+ after_first_residual = out_emb.clone()
296
+ out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[stream_idx])
297
+ if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
298
+ out_emb = out_emb.to(dtype=torch.bfloat16)
299
+
300
+ out_emb = layer.mlp(out_emb)
301
+ out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
302
+ outputs_embeds.append(out_emb)
303
+ start_pos = end_pos
304
+
305
+ return outputs_embeds
306
+
307
+ for layer_idx in range(num_layers):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  if use_gradient_checkpointing:
309
+ inputs_embeds = torch.utils.checkpoint.checkpoint(
310
+ compute_layer_complete,
311
+ layer_idx,
312
+ inputs_embeds,
313
+ attention_mask,
314
+ position_ids,
315
+ adarms_cond,
316
+ cross_attention_selector,
317
+ use_reentrant=False,
318
+ preserve_rng_state=False,
319
  )
320
  else:
321
+ inputs_embeds = compute_layer_complete(
322
+ layer_idx,
323
+ inputs_embeds,
324
+ attention_mask,
325
+ position_ids,
326
+ adarms_cond,
327
+ cross_attention_selector,
328
+ )
329
 
330
+ def compute_final_norms(stream_inputs, adarms_cond):
331
+ outputs_embeds = []
332
+ for stream_idx, hidden_states in enumerate(stream_inputs):
333
+ out_emb, _ = models[stream_idx].norm(hidden_states, cond=adarms_cond[stream_idx])
334
+ outputs_embeds.append(out_emb)
335
+ return outputs_embeds
336
+
337
+ if use_gradient_checkpointing:
338
+ outputs_embeds = torch.utils.checkpoint.checkpoint(
339
+ compute_final_norms,
340
+ inputs_embeds,
341
+ adarms_cond,
342
+ use_reentrant=False,
343
+ preserve_rng_state=False,
344
+ )
345
+ else:
346
+ outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
347
 
348
+ return outputs_embeds, None
openpi/src/openpi/models_pytorch/pi0_pytorch.py CHANGED
@@ -15,7 +15,6 @@ import openpi.models_pytorch.preprocessing_pytorch as _preprocessing
15
  def get_safe_dtype(target_dtype, device_type):
16
  """Get a safe dtype for the given device type."""
17
  if device_type == "cpu":
18
- # CPU doesn't support bfloat16, use float32 instead
19
  if target_dtype == torch.bfloat16:
20
  return torch.float32
21
  if target_dtype == torch.float64:
@@ -29,15 +28,12 @@ def create_sinusoidal_pos_embedding(
29
  """Computes sine-cosine positional embedding vectors for scalar positions."""
30
  if dimension % 2 != 0:
31
  raise ValueError(f"dimension ({dimension}) must be divisible by 2")
32
-
33
  if time.ndim != 1:
34
  raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
35
 
36
  dtype = get_safe_dtype(torch.float64, device.type)
37
  fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
38
  period = min_period * (max_period / min_period) ** fraction
39
-
40
- # Compute the outer product
41
  scaling_factor = 1.0 / period * 2 * math.pi
42
  sin_input = scaling_factor[None, :] * time[:, None]
43
  return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
@@ -51,26 +47,7 @@ def sample_beta(alpha, beta, bsize, device):
51
 
52
 
53
  def make_att_2d_masks(pad_masks, att_masks):
54
- """Copied from big_vision.
55
-
56
- Tokens can attend to valid inputs tokens which have a cumulative mask_ar
57
- smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
58
- setup several types of attention, for example:
59
-
60
- [[1 1 1 1 1 1]]: pure causal attention.
61
-
62
- [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
63
- themselves and the last 3 tokens have a causal attention. The first
64
- entry could also be a 1 without changing behaviour.
65
-
66
- [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
67
- block can attend all previous blocks and all tokens on the same block.
68
-
69
- Args:
70
- input_mask: bool[B, N] true if its part of the input, false if padding.
71
- mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
72
- it and 0 where it shares the same attention mask as the previous token.
73
- """
74
  if att_masks.ndim != 2:
75
  raise ValueError(att_masks.ndim)
76
  if pad_masks.ndim != 2:
@@ -89,26 +66,36 @@ class PI0Pytorch(nn.Module):
89
  self.pi05 = config.pi05
90
  self.arm_action_dims = tuple(config.arm_action_dims)
91
  self.use_parallel_action_heads = config.use_parallel_action_heads
 
 
 
92
  self.action_split_dims = list(self.arm_action_dims)
93
 
 
 
 
94
  paligemma_config = _gemma.get_config(config.paligemma_variant)
95
  action_expert_config = _gemma.get_config(config.action_expert_variant)
96
  self.action_expert_width = action_expert_config.width
97
 
 
98
  self.paligemma_with_expert = PaliGemmaWithExpertModel(
99
  paligemma_config,
100
  action_expert_config,
101
  use_adarms=[False, True] if self.pi05 else [False, False],
102
  precision=config.dtype,
 
 
103
  )
104
 
105
  if self.use_parallel_action_heads:
106
  self.action_in_proj_arms = nn.ModuleList(
107
  [nn.Linear(arm_dim, action_expert_config.width) for arm_dim in self.arm_action_dims]
108
  )
109
- self.arm_token_fuse = nn.Linear(
110
- len(self.arm_action_dims) * action_expert_config.width, action_expert_config.width
111
- )
 
112
  self.action_out_proj_arms = nn.ModuleList(
113
  [nn.Linear(action_expert_config.width, arm_dim) for arm_dim in self.arm_action_dims]
114
  )
@@ -128,7 +115,6 @@ class PI0Pytorch(nn.Module):
128
  if os.environ.get("OPENPI_TORCH_COMPILE_SAMPLE_ACTIONS", "0") == "1":
129
  self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
130
 
131
- # Initialize gradient checkpointing flag
132
  self.gradient_checkpointing_enabled = False
133
 
134
  msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`."
@@ -140,30 +126,29 @@ class PI0Pytorch(nn.Module):
140
  except ImportError:
141
  raise ValueError(msg) from None
142
 
 
 
 
143
  def gradient_checkpointing_enable(self):
144
- """Enable gradient checkpointing for memory optimization."""
145
  self.gradient_checkpointing_enabled = True
146
  self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
147
  self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
148
- self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
149
-
150
  logging.info("Enabled gradient checkpointing for PI0Pytorch model")
151
 
152
  def gradient_checkpointing_disable(self):
153
- """Disable gradient checkpointing."""
154
  self.gradient_checkpointing_enabled = False
155
  self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
156
  self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
157
- self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
158
-
159
  logging.info("Disabled gradient checkpointing for PI0Pytorch model")
160
 
161
  def is_gradient_checkpointing_enabled(self):
162
- """Check if gradient checkpointing is enabled."""
163
  return self.gradient_checkpointing_enabled
164
 
165
  def _apply_checkpoint(self, func, *args, **kwargs):
166
- """Helper method to apply gradient checkpointing if enabled."""
167
  if self.gradient_checkpointing_enabled and self.training:
168
  return torch.utils.checkpoint.checkpoint(
169
  func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
@@ -171,12 +156,10 @@ class PI0Pytorch(nn.Module):
171
  return func(*args, **kwargs)
172
 
173
  def _prepare_attention_masks_4d(self, att_2d_masks):
174
- """Helper method to prepare 4D attention masks for transformer."""
175
  att_2d_masks_4d = att_2d_masks[:, None, :, :]
176
  return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
177
 
178
  def _preprocess_observation(self, observation, *, train=True):
179
- """Helper method to preprocess observation."""
180
  observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
181
  return (
182
  list(observation.images.values()),
@@ -187,13 +170,7 @@ class PI0Pytorch(nn.Module):
187
  )
188
 
189
  def sample_noise(self, shape, device):
190
- return torch.normal(
191
- mean=0.0,
192
- std=1.0,
193
- size=shape,
194
- dtype=torch.float32,
195
- device=device,
196
- )
197
 
198
  def sample_time(self, bsize, device):
199
  time_beta = sample_beta(1.5, 1.0, bsize, device)
@@ -217,6 +194,9 @@ class PI0Pytorch(nn.Module):
217
 
218
  per_arm_embeddings.append(self._apply_checkpoint(arm_proj_func, arm_actions))
219
 
 
 
 
220
  fused_inputs = torch.cat(per_arm_embeddings, dim=-1)
221
 
222
  def fuse_func(fused_inputs):
@@ -232,68 +212,69 @@ class PI0Pytorch(nn.Module):
232
 
233
  return self._apply_checkpoint(action_out_proj_func, suffix_out)
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  per_arm_outputs = []
236
- for arm_head in self.action_out_proj_arms:
237
 
238
- def arm_out_func(suffix_out, arm_head=arm_head):
239
- return arm_head(suffix_out)
240
 
241
- per_arm_outputs.append(self._apply_checkpoint(arm_out_func, suffix_out))
242
  return torch.cat(per_arm_outputs, dim=-1)
243
 
244
  def embed_prefix(
245
  self, images, img_masks, lang_tokens, lang_masks
246
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
247
- """Embed images with SigLIP and language tokens with embedding layer to prepare
248
- for PaliGemma transformer processing.
249
- """
250
  embs = []
251
  pad_masks = []
252
  att_masks = []
253
 
254
- # Process images
255
  for img, img_mask in zip(images, img_masks, strict=True):
256
 
257
  def image_embed_func(img):
258
  return self.paligemma_with_expert.embed_image(img)
259
 
260
  img_emb = self._apply_checkpoint(image_embed_func, img)
261
-
262
  bsize, num_img_embs = img_emb.shape[:2]
263
-
264
  embs.append(img_emb)
265
  pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
266
-
267
- # Create attention masks so that image tokens attend to each other
268
  att_masks += [0] * num_img_embs
269
 
270
- # Process language tokens
271
  def lang_embed_func(lang_tokens):
272
  lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
273
- lang_emb_dim = lang_emb.shape[-1]
274
- return lang_emb * math.sqrt(lang_emb_dim)
275
 
276
  lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
277
-
278
  embs.append(lang_emb)
279
  pad_masks.append(lang_masks)
280
-
281
- # full attention between image and language inputs
282
- num_lang_embs = lang_emb.shape[1]
283
- att_masks += [0] * num_lang_embs
284
 
285
  embs = torch.cat(embs, dim=1)
286
  pad_masks = torch.cat(pad_masks, dim=1)
287
  att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
288
-
289
- # Get batch size from the first dimension of the concatenated tensors
290
- bsize = pad_masks.shape[0]
291
- att_masks = att_masks[None, :].expand(bsize, len(att_masks))
292
-
293
  return embs, pad_masks, att_masks
294
 
 
 
 
 
295
  def embed_suffix(self, state, noisy_actions, timestep):
296
- """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
 
 
297
  embs = []
298
  pad_masks = []
299
  att_masks = []
@@ -302,79 +283,186 @@ class PI0Pytorch(nn.Module):
302
  if self.state_proj.weight.dtype == torch.float32:
303
  state = state.to(torch.float32)
304
 
305
- # Embed state
306
  def state_proj_func(state):
307
  return self.state_proj(state)
308
 
309
  state_emb = self._apply_checkpoint(state_proj_func, state)
310
-
311
  embs.append(state_emb[:, None, :])
312
  bsize = state_emb.shape[0]
313
  device = state_emb.device
314
-
315
- state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
316
- pad_masks.append(state_mask)
317
-
318
- # Set attention masks so that image and language inputs do not attend to state or actions
319
  att_masks += [1]
320
 
321
- # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
322
  time_emb = create_sinusoidal_pos_embedding(
323
  timestep, self.action_expert_width, min_period=4e-3, max_period=4.0, device=timestep.device
324
  )
325
  time_emb = time_emb.type(dtype=timestep.dtype)
326
-
327
- # Fuse timestep + action information using an MLP
328
  action_emb = self._project_action_inputs(noisy_actions)
329
 
330
  if not self.pi05:
331
  time_emb = time_emb[:, None, :].expand_as(action_emb)
332
  action_time_emb = torch.cat([action_emb, time_emb], dim=2)
333
 
334
- # Apply MLP layers
335
  def mlp_func(action_time_emb):
336
  x = self.action_time_mlp_in(action_time_emb)
337
- x = F.silu(x) # swish == silu
338
  return self.action_time_mlp_out(x)
339
 
340
  action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
341
  adarms_cond = None
342
  else:
343
- # time MLP (for adaRMS)
344
  def time_mlp_func(time_emb):
345
  x = self.time_mlp_in(time_emb)
346
- x = F.silu(x) # swish == silu
347
  x = self.time_mlp_out(x)
348
  return F.silu(x)
349
 
350
  time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
351
  action_time_emb = action_emb
352
- adarms_cond = time_emb
 
 
 
 
 
 
 
 
 
 
 
353
 
354
- # Add to input tokens
355
  embs.append(action_time_emb)
356
-
357
  bsize, action_time_dim = action_time_emb.shape[:2]
358
- action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
359
- pad_masks.append(action_time_mask)
360
-
361
- # Set attention masks so that image, language and state inputs do not attend to action tokens
362
  att_masks += [1] + ([0] * (self.config.action_horizon - 1))
363
 
364
  embs = torch.cat(embs, dim=1)
365
  pad_masks = torch.cat(pad_masks, dim=1)
366
  att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
367
  att_masks = att_masks[None, :].expand(bsize, len(att_masks))
368
-
369
  return embs, pad_masks, att_masks, adarms_cond
370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  def forward(self, observation, actions, noise=None, time=None) -> Tensor:
372
- """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
373
  images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True)
374
 
375
  if noise is None:
376
  noise = self.sample_noise(actions.shape, actions.device)
377
-
378
  if time is None:
379
  time = self.sample_time(actions.shape[0], actions.device)
380
 
@@ -384,60 +472,55 @@ class PI0Pytorch(nn.Module):
384
 
385
  prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
386
  suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
387
- if (
388
- self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
389
- == torch.bfloat16
390
- ):
391
- suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
392
- prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
393
-
394
- pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
395
- att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
396
-
397
- att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
398
- position_ids = torch.cumsum(pad_masks, dim=1) - 1
399
-
400
- # Prepare attention masks
401
- att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
402
-
403
- # Apply gradient checkpointing if enabled
404
- def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
405
- (_, suffix_out), _ = self.paligemma_with_expert.forward(
406
- attention_mask=att_2d_masks_4d,
407
- position_ids=position_ids,
408
- past_key_values=None,
409
- inputs_embeds=[prefix_embs, suffix_embs],
410
- use_cache=False,
411
- adarms_cond=[None, adarms_cond],
412
- )
413
- return suffix_out
414
-
415
- suffix_out = self._apply_checkpoint(
416
- forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
417
  )
418
 
419
- suffix_out = suffix_out[:, -self.config.action_horizon :]
420
- suffix_out = suffix_out.to(dtype=torch.float32)
 
 
421
 
422
  v_t = self._project_action_outputs(suffix_out)
423
-
424
  return F.mse_loss(u_t, v_t, reduction="none")
425
 
426
  @torch.no_grad()
427
  def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
428
- """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
429
  bsize = observation.state.shape[0]
430
  if noise is None:
431
  actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
432
  noise = self.sample_noise(actions_shape, device)
433
 
434
  images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)
435
-
436
  prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
438
  prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
439
-
440
- # Compute image and language key value cache
441
  prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
442
  self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
443
 
@@ -449,26 +532,35 @@ class PI0Pytorch(nn.Module):
449
  use_cache=True,
450
  )
451
 
452
- dt = -1.0 / num_steps
453
- dt = torch.tensor(dt, dtype=torch.float32, device=device)
454
-
455
- x_t = noise
456
- time = torch.tensor(1.0, dtype=torch.float32, device=device)
457
  while time >= -dt / 2:
458
  expanded_time = time.expand(bsize)
459
- v_t = self.denoise_step(
460
- state,
461
- prefix_pad_masks,
462
- past_key_values,
463
- x_t,
464
- expanded_time,
465
- )
466
-
467
- # Euler step - use new tensor assignment instead of in-place operation
468
  x_t = x_t + dt * v_t
469
  time += dt
470
  return x_t
471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  def denoise_step(
473
  self,
474
  state,
@@ -477,7 +569,6 @@ class PI0Pytorch(nn.Module):
477
  x_t,
478
  timestep,
479
  ):
480
- """Apply one denoising step of the noise `x_t` at a given timestep."""
481
  suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
482
 
483
  suffix_len = suffix_pad_masks.shape[1]
@@ -485,15 +576,12 @@ class PI0Pytorch(nn.Module):
485
  prefix_len = prefix_pad_masks.shape[1]
486
 
487
  prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
488
-
489
  suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
490
-
491
  full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
492
 
493
  prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
494
  position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
495
 
496
- # Prepare attention masks
497
  full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
498
  self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
499
 
 
15
  def get_safe_dtype(target_dtype, device_type):
16
  """Get a safe dtype for the given device type."""
17
  if device_type == "cpu":
 
18
  if target_dtype == torch.bfloat16:
19
  return torch.float32
20
  if target_dtype == torch.float64:
 
28
  """Computes sine-cosine positional embedding vectors for scalar positions."""
29
  if dimension % 2 != 0:
30
  raise ValueError(f"dimension ({dimension}) must be divisible by 2")
 
31
  if time.ndim != 1:
32
  raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
33
 
34
  dtype = get_safe_dtype(torch.float64, device.type)
35
  fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
36
  period = min_period * (max_period / min_period) ** fraction
 
 
37
  scaling_factor = 1.0 / period * 2 * math.pi
38
  sin_input = scaling_factor[None, :] * time[:, None]
39
  return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
 
47
 
48
 
49
  def make_att_2d_masks(pad_masks, att_masks):
50
+ """Copied from big_vision."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  if att_masks.ndim != 2:
52
  raise ValueError(att_masks.ndim)
53
  if pad_masks.ndim != 2:
 
66
  self.pi05 = config.pi05
67
  self.arm_action_dims = tuple(config.arm_action_dims)
68
  self.use_parallel_action_heads = config.use_parallel_action_heads
69
+ self.use_split_action_expert = config.use_split_action_expert
70
+ self.use_communicating_action_expert = config.use_communicating_action_expert
71
+ self.action_expert_mode = config.action_expert_mode
72
  self.action_split_dims = list(self.arm_action_dims)
73
 
74
+ if self.use_split_action_expert and not self.pi05:
75
+ raise NotImplementedError("Split action experts are currently implemented only for pi0.5 models.")
76
+
77
  paligemma_config = _gemma.get_config(config.paligemma_variant)
78
  action_expert_config = _gemma.get_config(config.action_expert_variant)
79
  self.action_expert_width = action_expert_config.width
80
 
81
+ num_action_experts = len(self.arm_action_dims) if self.use_split_action_expert else 1
82
  self.paligemma_with_expert = PaliGemmaWithExpertModel(
83
  paligemma_config,
84
  action_expert_config,
85
  use_adarms=[False, True] if self.pi05 else [False, False],
86
  precision=config.dtype,
87
+ num_action_experts=num_action_experts,
88
+ enable_cross_arm_communication=self.use_communicating_action_expert,
89
  )
90
 
91
  if self.use_parallel_action_heads:
92
  self.action_in_proj_arms = nn.ModuleList(
93
  [nn.Linear(arm_dim, action_expert_config.width) for arm_dim in self.arm_action_dims]
94
  )
95
+ if self.action_expert_mode == "head_only_parallel":
96
+ self.arm_token_fuse = nn.Linear(
97
+ len(self.arm_action_dims) * action_expert_config.width, action_expert_config.width
98
+ )
99
  self.action_out_proj_arms = nn.ModuleList(
100
  [nn.Linear(action_expert_config.width, arm_dim) for arm_dim in self.arm_action_dims]
101
  )
 
115
  if os.environ.get("OPENPI_TORCH_COMPILE_SAMPLE_ACTIONS", "0") == "1":
116
  self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
117
 
 
118
  self.gradient_checkpointing_enabled = False
119
 
120
  msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`."
 
126
  except ImportError:
127
  raise ValueError(msg) from None
128
 
129
+ def _expert_models(self) -> list[nn.Module]:
130
+ return self.paligemma_with_expert._get_action_expert_models()
131
+
132
  def gradient_checkpointing_enable(self):
 
133
  self.gradient_checkpointing_enabled = True
134
  self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
135
  self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
136
+ for expert_model in self._expert_models():
137
+ expert_model.gradient_checkpointing = True
138
  logging.info("Enabled gradient checkpointing for PI0Pytorch model")
139
 
140
  def gradient_checkpointing_disable(self):
 
141
  self.gradient_checkpointing_enabled = False
142
  self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
143
  self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
144
+ for expert_model in self._expert_models():
145
+ expert_model.gradient_checkpointing = False
146
  logging.info("Disabled gradient checkpointing for PI0Pytorch model")
147
 
148
  def is_gradient_checkpointing_enabled(self):
 
149
  return self.gradient_checkpointing_enabled
150
 
151
  def _apply_checkpoint(self, func, *args, **kwargs):
 
152
  if self.gradient_checkpointing_enabled and self.training:
153
  return torch.utils.checkpoint.checkpoint(
154
  func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
 
156
  return func(*args, **kwargs)
157
 
158
  def _prepare_attention_masks_4d(self, att_2d_masks):
 
159
  att_2d_masks_4d = att_2d_masks[:, None, :, :]
160
  return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
161
 
162
  def _preprocess_observation(self, observation, *, train=True):
 
163
  observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
164
  return (
165
  list(observation.images.values()),
 
170
  )
171
 
172
  def sample_noise(self, shape, device):
173
+ return torch.normal(mean=0.0, std=1.0, size=shape, dtype=torch.float32, device=device)
 
 
 
 
 
 
174
 
175
  def sample_time(self, bsize, device):
176
  time_beta = sample_beta(1.5, 1.0, bsize, device)
 
194
 
195
  per_arm_embeddings.append(self._apply_checkpoint(arm_proj_func, arm_actions))
196
 
197
+ if self.use_split_action_expert:
198
+ return per_arm_embeddings
199
+
200
  fused_inputs = torch.cat(per_arm_embeddings, dim=-1)
201
 
202
  def fuse_func(fused_inputs):
 
212
 
213
  return self._apply_checkpoint(action_out_proj_func, suffix_out)
214
 
215
+ if not self.use_split_action_expert:
216
+ per_arm_outputs = []
217
+ for arm_head in self.action_out_proj_arms:
218
+
219
+ def arm_out_func(suffix_out, arm_head=arm_head):
220
+ return arm_head(suffix_out)
221
+
222
+ per_arm_outputs.append(self._apply_checkpoint(arm_out_func, suffix_out))
223
+ return torch.cat(per_arm_outputs, dim=-1)
224
+
225
+ if len(suffix_out) != len(self.action_out_proj_arms):
226
+ raise ValueError(f"Expected {len(self.action_out_proj_arms)} arm outputs, got {len(suffix_out)}.")
227
+
228
  per_arm_outputs = []
229
+ for arm_head, arm_suffix_out in zip(self.action_out_proj_arms, suffix_out, strict=True):
230
 
231
+ def arm_out_func(arm_suffix_out, arm_head=arm_head):
232
+ return arm_head(arm_suffix_out)
233
 
234
+ per_arm_outputs.append(self._apply_checkpoint(arm_out_func, arm_suffix_out))
235
  return torch.cat(per_arm_outputs, dim=-1)
236
 
237
  def embed_prefix(
238
  self, images, img_masks, lang_tokens, lang_masks
239
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
 
 
240
  embs = []
241
  pad_masks = []
242
  att_masks = []
243
 
 
244
  for img, img_mask in zip(images, img_masks, strict=True):
245
 
246
  def image_embed_func(img):
247
  return self.paligemma_with_expert.embed_image(img)
248
 
249
  img_emb = self._apply_checkpoint(image_embed_func, img)
 
250
  bsize, num_img_embs = img_emb.shape[:2]
 
251
  embs.append(img_emb)
252
  pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
 
 
253
  att_masks += [0] * num_img_embs
254
 
 
255
  def lang_embed_func(lang_tokens):
256
  lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
257
+ return lang_emb * math.sqrt(lang_emb.shape[-1])
 
258
 
259
  lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
 
260
  embs.append(lang_emb)
261
  pad_masks.append(lang_masks)
262
+ att_masks += [0] * lang_emb.shape[1]
 
 
 
263
 
264
  embs = torch.cat(embs, dim=1)
265
  pad_masks = torch.cat(pad_masks, dim=1)
266
  att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
267
+ att_masks = att_masks[None, :].expand(pad_masks.shape[0], len(att_masks))
 
 
 
 
268
  return embs, pad_masks, att_masks
269
 
270
+ def _action_att_mask(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
271
+ att_mask = torch.tensor([1] + ([0] * (self.config.action_horizon - 1)), dtype=dtype, device=device)
272
+ return att_mask[None, :].expand(batch_size, -1)
273
+
274
  def embed_suffix(self, state, noisy_actions, timestep):
275
+ if self.use_split_action_expert and not self.pi05:
276
+ raise NotImplementedError("Split action experts are currently implemented only for pi0.5 models.")
277
+
278
  embs = []
279
  pad_masks = []
280
  att_masks = []
 
283
  if self.state_proj.weight.dtype == torch.float32:
284
  state = state.to(torch.float32)
285
 
 
286
  def state_proj_func(state):
287
  return self.state_proj(state)
288
 
289
  state_emb = self._apply_checkpoint(state_proj_func, state)
 
290
  embs.append(state_emb[:, None, :])
291
  bsize = state_emb.shape[0]
292
  device = state_emb.device
293
+ pad_masks.append(torch.ones(bsize, 1, dtype=torch.bool, device=device))
 
 
 
 
294
  att_masks += [1]
295
 
 
296
  time_emb = create_sinusoidal_pos_embedding(
297
  timestep, self.action_expert_width, min_period=4e-3, max_period=4.0, device=timestep.device
298
  )
299
  time_emb = time_emb.type(dtype=timestep.dtype)
 
 
300
  action_emb = self._project_action_inputs(noisy_actions)
301
 
302
  if not self.pi05:
303
  time_emb = time_emb[:, None, :].expand_as(action_emb)
304
  action_time_emb = torch.cat([action_emb, time_emb], dim=2)
305
 
 
306
  def mlp_func(action_time_emb):
307
  x = self.action_time_mlp_in(action_time_emb)
308
+ x = F.silu(x)
309
  return self.action_time_mlp_out(x)
310
 
311
  action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
312
  adarms_cond = None
313
  else:
314
+
315
  def time_mlp_func(time_emb):
316
  x = self.time_mlp_in(time_emb)
317
+ x = F.silu(x)
318
  x = self.time_mlp_out(x)
319
  return F.silu(x)
320
 
321
  time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
322
  action_time_emb = action_emb
323
+ adarms_cond = [time_emb] * len(action_time_emb) if self.use_split_action_expert else time_emb
324
+
325
+ if self.use_split_action_expert:
326
+ suffix_embs = []
327
+ suffix_pad_masks = []
328
+ suffix_att_masks = []
329
+ for arm_emb in action_time_emb:
330
+ bsize, action_time_dim = arm_emb.shape[:2]
331
+ suffix_embs.append(arm_emb)
332
+ suffix_pad_masks.append(torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device))
333
+ suffix_att_masks.append(self._action_att_mask(bsize, timestep.device, arm_emb.dtype))
334
+ return suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond
335
 
 
336
  embs.append(action_time_emb)
 
337
  bsize, action_time_dim = action_time_emb.shape[:2]
338
+ pad_masks.append(torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device))
 
 
 
339
  att_masks += [1] + ([0] * (self.config.action_horizon - 1))
340
 
341
  embs = torch.cat(embs, dim=1)
342
  pad_masks = torch.cat(pad_masks, dim=1)
343
  att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
344
  att_masks = att_masks[None, :].expand(bsize, len(att_masks))
 
345
  return embs, pad_masks, att_masks, adarms_cond
346
 
347
+ def _cast_joint_embs(self, prefix_embs, suffix_embs):
348
+ if (
349
+ self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
350
+ == torch.bfloat16
351
+ ):
352
+ prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
353
+ if isinstance(suffix_embs, (list, tuple)):
354
+ suffix_embs = [suffix_emb.to(dtype=torch.bfloat16) for suffix_emb in suffix_embs]
355
+ else:
356
+ suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
357
+ return prefix_embs, suffix_embs
358
+
359
+ def _build_split_joint_attention(self, prefix_pad_masks, prefix_att_masks, suffix_pad_masks):
360
+ batch_size = prefix_pad_masks.shape[0]
361
+ prefix_len = prefix_pad_masks.shape[1]
362
+ branch_lengths = [branch_pad.shape[1] for branch_pad in suffix_pad_masks]
363
+ total_len = prefix_len + sum(branch_lengths)
364
+ device = prefix_pad_masks.device
365
+
366
+ prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
367
+ full_att_2d_masks = torch.zeros((batch_size, total_len, total_len), dtype=torch.bool, device=device)
368
+ full_att_2d_masks[:, :prefix_len, :prefix_len] = prefix_att_2d_masks
369
+
370
+ cross_attention_selector = None
371
+ if self.use_communicating_action_expert:
372
+ cross_attention_selector = torch.zeros((1, 1, total_len, total_len), dtype=torch.float32, device=device)
373
+
374
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
375
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1, keepdim=True)
376
+ branch_position_ids = []
377
+
378
+ branch_starts = []
379
+ start = prefix_len
380
+ for branch_len in branch_lengths:
381
+ branch_starts.append(start)
382
+ start += branch_len
383
+
384
+ for branch_idx, branch_pad_masks in enumerate(suffix_pad_masks):
385
+ branch_start = branch_starts[branch_idx]
386
+ branch_len = branch_pad_masks.shape[1]
387
+ branch_position_ids.append(prefix_offsets + torch.cumsum(branch_pad_masks, dim=1) - 1)
388
+ full_att_2d_masks[:, branch_start : branch_start + branch_len, :prefix_len] = (
389
+ branch_pad_masks[:, :, None] & prefix_pad_masks[:, None, :]
390
+ )
391
+
392
+ q_positions = torch.arange(branch_len, device=device)[:, None]
393
+ k_positions = torch.arange(branch_len, device=device)[None, :]
394
+ same_branch_causal = k_positions <= q_positions
395
+ full_att_2d_masks[:, branch_start : branch_start + branch_len, branch_start : branch_start + branch_len] = (
396
+ branch_pad_masks[:, :, None] & branch_pad_masks[:, None, :] & same_branch_causal[None, :, :]
397
+ )
398
+
399
+ if self.use_communicating_action_expert:
400
+ for query_idx, query_pad_masks in enumerate(suffix_pad_masks):
401
+ query_start = branch_starts[query_idx]
402
+ query_len = query_pad_masks.shape[1]
403
+ query_positions = torch.arange(query_len, device=device)[:, None]
404
+ for key_idx, key_pad_masks in enumerate(suffix_pad_masks):
405
+ if query_idx == key_idx:
406
+ continue
407
+ key_start = branch_starts[key_idx]
408
+ key_len = key_pad_masks.shape[1]
409
+ key_positions = torch.arange(key_len, device=device)[None, :]
410
+ cross_causal = key_positions <= query_positions
411
+ cross_block = query_pad_masks[:, :, None] & key_pad_masks[:, None, :] & cross_causal[None, :, :]
412
+ full_att_2d_masks[:, query_start : query_start + query_len, key_start : key_start + key_len] = (
413
+ cross_block
414
+ )
415
+ cross_attention_selector[
416
+ :, :, query_start : query_start + query_len, key_start : key_start + key_len
417
+ ] = cross_causal[None, None, :, :].to(dtype=torch.float32)
418
+
419
+ position_ids = torch.cat([prefix_position_ids, *branch_position_ids], dim=1)
420
+ return full_att_2d_masks, position_ids, cross_attention_selector
421
+
422
+ def _run_joint_action_expert(
423
+ self,
424
+ prefix_embs,
425
+ prefix_pad_masks,
426
+ prefix_att_masks,
427
+ suffix_embs,
428
+ suffix_pad_masks,
429
+ suffix_att_masks,
430
+ adarms_cond,
431
+ ):
432
+ prefix_embs, suffix_embs = self._cast_joint_embs(prefix_embs, suffix_embs)
433
+
434
+ if self.use_split_action_expert:
435
+ att_2d_masks, position_ids, cross_attention_selector = self._build_split_joint_attention(
436
+ prefix_pad_masks, prefix_att_masks, suffix_pad_masks
437
+ )
438
+ inputs_embeds = [prefix_embs, *suffix_embs]
439
+ adarms_cond = [None, *adarms_cond]
440
+ else:
441
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
442
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
443
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
444
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
445
+ cross_attention_selector = None
446
+ inputs_embeds = [prefix_embs, suffix_embs]
447
+ adarms_cond = [None, adarms_cond]
448
+
449
+ att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
450
+ outputs_embeds, _ = self.paligemma_with_expert.forward(
451
+ attention_mask=att_2d_masks_4d,
452
+ position_ids=position_ids,
453
+ past_key_values=None,
454
+ inputs_embeds=inputs_embeds,
455
+ use_cache=False,
456
+ adarms_cond=adarms_cond,
457
+ cross_attention_selector=cross_attention_selector,
458
+ )
459
+ return outputs_embeds[1:]
460
+
461
  def forward(self, observation, actions, noise=None, time=None) -> Tensor:
 
462
  images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True)
463
 
464
  if noise is None:
465
  noise = self.sample_noise(actions.shape, actions.device)
 
466
  if time is None:
467
  time = self.sample_time(actions.shape[0], actions.device)
468
 
 
472
 
473
  prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
474
  suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
475
+ suffix_out = self._run_joint_action_expert(
476
+ prefix_embs,
477
+ prefix_pad_masks,
478
+ prefix_att_masks,
479
+ suffix_embs,
480
+ suffix_pad_masks,
481
+ suffix_att_masks,
482
+ adarms_cond,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  )
484
 
485
+ if self.use_split_action_expert:
486
+ suffix_out = [output[:, -self.config.action_horizon :].to(dtype=torch.float32) for output in suffix_out]
487
+ else:
488
+ suffix_out = suffix_out[0][:, -self.config.action_horizon :].to(dtype=torch.float32)
489
 
490
  v_t = self._project_action_outputs(suffix_out)
 
491
  return F.mse_loss(u_t, v_t, reduction="none")
492
 
493
  @torch.no_grad()
494
  def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
 
495
  bsize = observation.state.shape[0]
496
  if noise is None:
497
  actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
498
  noise = self.sample_noise(actions_shape, device)
499
 
500
  images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)
 
501
  prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
502
+
503
+ dt = torch.tensor(-1.0 / num_steps, dtype=torch.float32, device=device)
504
+ x_t = noise
505
+ time = torch.tensor(1.0, dtype=torch.float32, device=device)
506
+
507
+ if self.use_split_action_expert:
508
+ while time >= -dt / 2:
509
+ expanded_time = time.expand(bsize)
510
+ v_t = self._denoise_step_split(
511
+ prefix_embs,
512
+ prefix_pad_masks,
513
+ prefix_att_masks,
514
+ state,
515
+ x_t,
516
+ expanded_time,
517
+ )
518
+ x_t = x_t + dt * v_t
519
+ time += dt
520
+ return x_t
521
+
522
  prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
523
  prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
 
 
524
  prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
525
  self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
526
 
 
532
  use_cache=True,
533
  )
534
 
 
 
 
 
 
535
  while time >= -dt / 2:
536
  expanded_time = time.expand(bsize)
537
+ v_t = self.denoise_step(state, prefix_pad_masks, past_key_values, x_t, expanded_time)
 
 
 
 
 
 
 
 
538
  x_t = x_t + dt * v_t
539
  time += dt
540
  return x_t
541
 
542
+ def _denoise_step_split(
543
+ self,
544
+ prefix_embs,
545
+ prefix_pad_masks,
546
+ prefix_att_masks,
547
+ state,
548
+ x_t,
549
+ timestep,
550
+ ):
551
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
552
+ outputs = self._run_joint_action_expert(
553
+ prefix_embs,
554
+ prefix_pad_masks,
555
+ prefix_att_masks,
556
+ suffix_embs,
557
+ suffix_pad_masks,
558
+ suffix_att_masks,
559
+ adarms_cond,
560
+ )
561
+ outputs = [output[:, -self.config.action_horizon :].to(dtype=torch.float32) for output in outputs]
562
+ return self._project_action_outputs(outputs)
563
+
564
  def denoise_step(
565
  self,
566
  state,
 
569
  x_t,
570
  timestep,
571
  ):
 
572
  suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
573
 
574
  suffix_len = suffix_pad_masks.shape[1]
 
576
  prefix_len = prefix_pad_masks.shape[1]
577
 
578
  prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
 
579
  suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
 
580
  full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
581
 
582
  prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
583
  position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
584
 
 
585
  full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
586
  self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
587
 
openpi/src/openpi/training/config.py CHANGED
@@ -1050,6 +1050,7 @@ _CONFIGS = [
1050
  action_dim=32,
1051
  action_horizon=16,
1052
  arm_action_dims=(16, 16),
 
1053
  ),
1054
  data=LeRobotDROIDDataConfig(
1055
  repo_id="your_hf_username/my_multiarm_droid_dataset",
@@ -1087,6 +1088,7 @@ _CONFIGS = [
1087
  action_dim=32,
1088
  action_horizon=16,
1089
  arm_action_dims=(16, 16),
 
1090
  ),
1091
  data=LeRobotTWINBimanualDataConfig(
1092
  repo_id="your_hf_username/twin_bimanual_lerobot_train",
@@ -1131,6 +1133,7 @@ _CONFIGS = [
1131
  action_dim=32,
1132
  action_horizon=16,
1133
  arm_action_dims=(16, 16),
 
1134
  ),
1135
  data=LeRobotTWINBimanualPackedDataConfig(
1136
  repo_id="lsnu/twin_handover_256_train",
@@ -1153,6 +1156,66 @@ _CONFIGS = [
1153
  overwrite=True,
1154
  wandb_enabled=False,
1155
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1156
  TrainConfig(
1157
  name="pi05_twin_handover_256_packed_baseline_pytorch_10k",
1158
  model=pi0_config.Pi0Config(
@@ -1188,6 +1251,7 @@ _CONFIGS = [
1188
  action_dim=32,
1189
  action_horizon=16,
1190
  arm_action_dims=(16, 16),
 
1191
  ),
1192
  data=LeRobotTWINBimanualPackedDataConfig(
1193
  repo_id="lsnu/twin_handover_256_train",
@@ -1210,6 +1274,66 @@ _CONFIGS = [
1210
  overwrite=True,
1211
  wandb_enabled=False,
1212
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1213
  TrainConfig(
1214
  name="pi05_twin_dual_push_128_packed_baseline_pytorch_5k",
1215
  model=pi0_config.Pi0Config(
@@ -1245,6 +1369,7 @@ _CONFIGS = [
1245
  action_dim=32,
1246
  action_horizon=16,
1247
  arm_action_dims=(16, 16),
 
1248
  ),
1249
  data=LeRobotTWINBimanualPackedDataConfig(
1250
  repo_id="lsnu/twin_dual_push_128_train",
@@ -1267,6 +1392,66 @@ _CONFIGS = [
1267
  overwrite=True,
1268
  wandb_enabled=False,
1269
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1270
  #
1271
  # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
1272
  #
@@ -1327,6 +1512,7 @@ _CONFIGS = [
1327
  action_horizon=8,
1328
  max_token_len=32,
1329
  arm_action_dims=(16, 16),
 
1330
  ),
1331
  data=FakeDataConfig(),
1332
  batch_size=1,
@@ -1339,6 +1525,52 @@ _CONFIGS = [
1339
  wandb_enabled=False,
1340
  pytorch_training_precision="float32",
1341
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1342
  TrainConfig(
1343
  # Local smoke-test for converted TWIN LeRobot data.
1344
  name="debug_pi05_twin_bimanual_parallel_local_smoke",
@@ -1350,6 +1582,7 @@ _CONFIGS = [
1350
  action_horizon=8,
1351
  max_token_len=64,
1352
  arm_action_dims=(16, 16),
 
1353
  ),
1354
  data=LeRobotTWINBimanualDataConfig(
1355
  # This repo id is produced by scripts/convert_twin_squashfs_to_lerobot.py in local smoke mode.
 
1050
  action_dim=32,
1051
  action_horizon=16,
1052
  arm_action_dims=(16, 16),
1053
+ action_expert_mode="head_only_parallel",
1054
  ),
1055
  data=LeRobotDROIDDataConfig(
1056
  repo_id="your_hf_username/my_multiarm_droid_dataset",
 
1088
  action_dim=32,
1089
  action_horizon=16,
1090
  arm_action_dims=(16, 16),
1091
+ action_expert_mode="head_only_parallel",
1092
  ),
1093
  data=LeRobotTWINBimanualDataConfig(
1094
  repo_id="your_hf_username/twin_bimanual_lerobot_train",
 
1133
  action_dim=32,
1134
  action_horizon=16,
1135
  arm_action_dims=(16, 16),
1136
+ action_expert_mode="head_only_parallel",
1137
  ),
1138
  data=LeRobotTWINBimanualPackedDataConfig(
1139
  repo_id="lsnu/twin_handover_256_train",
 
1156
  overwrite=True,
1157
  wandb_enabled=False,
1158
  ),
1159
+ TrainConfig(
1160
+ name="pi05_twin_handover_256_packed_split_expert_independent_pytorch_2k",
1161
+ model=pi0_config.Pi0Config(
1162
+ pi05=True,
1163
+ action_dim=32,
1164
+ action_horizon=16,
1165
+ arm_action_dims=(16, 16),
1166
+ action_expert_mode="split_independent",
1167
+ ),
1168
+ data=LeRobotTWINBimanualPackedDataConfig(
1169
+ repo_id="lsnu/twin_handover_256_train",
1170
+ base_config=DataConfig(prompt_from_task=False),
1171
+ ),
1172
+ pytorch_weight_path="/workspace/checkpoints/pi05_base_split_independent_packed_from_single",
1173
+ pytorch_training_precision="bfloat16",
1174
+ action_loss_mask=(1.0,) * 8 + (0.0,) * 8 + (1.0,) * 8 + (0.0,) * 8,
1175
+ lr_schedule=_optimizer.CosineDecaySchedule(
1176
+ warmup_steps=200,
1177
+ peak_lr=2.5e-5,
1178
+ decay_steps=2_000,
1179
+ decay_lr=2.5e-6,
1180
+ ),
1181
+ batch_size=16,
1182
+ num_workers=8,
1183
+ num_train_steps=2_000,
1184
+ log_interval=10,
1185
+ save_interval=250,
1186
+ overwrite=True,
1187
+ wandb_enabled=False,
1188
+ ),
1189
+ TrainConfig(
1190
+ name="pi05_twin_handover_256_packed_split_expert_communicating_pytorch_2k",
1191
+ model=pi0_config.Pi0Config(
1192
+ pi05=True,
1193
+ action_dim=32,
1194
+ action_horizon=16,
1195
+ arm_action_dims=(16, 16),
1196
+ action_expert_mode="split_communicating",
1197
+ ),
1198
+ data=LeRobotTWINBimanualPackedDataConfig(
1199
+ repo_id="lsnu/twin_handover_256_train",
1200
+ base_config=DataConfig(prompt_from_task=False),
1201
+ ),
1202
+ pytorch_weight_path="/workspace/checkpoints/pi05_base_split_communicating_packed_from_single",
1203
+ pytorch_training_precision="bfloat16",
1204
+ action_loss_mask=(1.0,) * 8 + (0.0,) * 8 + (1.0,) * 8 + (0.0,) * 8,
1205
+ lr_schedule=_optimizer.CosineDecaySchedule(
1206
+ warmup_steps=200,
1207
+ peak_lr=2.5e-5,
1208
+ decay_steps=2_000,
1209
+ decay_lr=2.5e-6,
1210
+ ),
1211
+ batch_size=16,
1212
+ num_workers=8,
1213
+ num_train_steps=2_000,
1214
+ log_interval=10,
1215
+ save_interval=250,
1216
+ overwrite=True,
1217
+ wandb_enabled=False,
1218
+ ),
1219
  TrainConfig(
1220
  name="pi05_twin_handover_256_packed_baseline_pytorch_10k",
1221
  model=pi0_config.Pi0Config(
 
1251
  action_dim=32,
1252
  action_horizon=16,
1253
  arm_action_dims=(16, 16),
1254
+ action_expert_mode="head_only_parallel",
1255
  ),
1256
  data=LeRobotTWINBimanualPackedDataConfig(
1257
  repo_id="lsnu/twin_handover_256_train",
 
1274
  overwrite=True,
1275
  wandb_enabled=False,
1276
  ),
1277
+ TrainConfig(
1278
+ name="pi05_twin_handover_256_packed_split_expert_independent_pytorch_10k",
1279
+ model=pi0_config.Pi0Config(
1280
+ pi05=True,
1281
+ action_dim=32,
1282
+ action_horizon=16,
1283
+ arm_action_dims=(16, 16),
1284
+ action_expert_mode="split_independent",
1285
+ ),
1286
+ data=LeRobotTWINBimanualPackedDataConfig(
1287
+ repo_id="lsnu/twin_handover_256_train",
1288
+ base_config=DataConfig(prompt_from_task=False),
1289
+ ),
1290
+ pytorch_weight_path="/workspace/checkpoints/pi05_base_split_independent_packed_from_single",
1291
+ pytorch_training_precision="bfloat16",
1292
+ action_loss_mask=(1.0,) * 8 + (0.0,) * 8 + (1.0,) * 8 + (0.0,) * 8,
1293
+ lr_schedule=_optimizer.CosineDecaySchedule(
1294
+ warmup_steps=500,
1295
+ peak_lr=2.5e-5,
1296
+ decay_steps=10_000,
1297
+ decay_lr=2.5e-6,
1298
+ ),
1299
+ batch_size=16,
1300
+ num_workers=8,
1301
+ num_train_steps=10_000,
1302
+ log_interval=10,
1303
+ save_interval=1_000,
1304
+ overwrite=True,
1305
+ wandb_enabled=False,
1306
+ ),
1307
+ TrainConfig(
1308
+ name="pi05_twin_handover_256_packed_split_expert_communicating_pytorch_10k",
1309
+ model=pi0_config.Pi0Config(
1310
+ pi05=True,
1311
+ action_dim=32,
1312
+ action_horizon=16,
1313
+ arm_action_dims=(16, 16),
1314
+ action_expert_mode="split_communicating",
1315
+ ),
1316
+ data=LeRobotTWINBimanualPackedDataConfig(
1317
+ repo_id="lsnu/twin_handover_256_train",
1318
+ base_config=DataConfig(prompt_from_task=False),
1319
+ ),
1320
+ pytorch_weight_path="/workspace/checkpoints/pi05_base_split_communicating_packed_from_single",
1321
+ pytorch_training_precision="bfloat16",
1322
+ action_loss_mask=(1.0,) * 8 + (0.0,) * 8 + (1.0,) * 8 + (0.0,) * 8,
1323
+ lr_schedule=_optimizer.CosineDecaySchedule(
1324
+ warmup_steps=500,
1325
+ peak_lr=2.5e-5,
1326
+ decay_steps=10_000,
1327
+ decay_lr=2.5e-6,
1328
+ ),
1329
+ batch_size=16,
1330
+ num_workers=8,
1331
+ num_train_steps=10_000,
1332
+ log_interval=10,
1333
+ save_interval=1_000,
1334
+ overwrite=True,
1335
+ wandb_enabled=False,
1336
+ ),
1337
  TrainConfig(
1338
  name="pi05_twin_dual_push_128_packed_baseline_pytorch_5k",
1339
  model=pi0_config.Pi0Config(
 
1369
  action_dim=32,
1370
  action_horizon=16,
1371
  arm_action_dims=(16, 16),
1372
+ action_expert_mode="head_only_parallel",
1373
  ),
1374
  data=LeRobotTWINBimanualPackedDataConfig(
1375
  repo_id="lsnu/twin_dual_push_128_train",
 
1392
  overwrite=True,
1393
  wandb_enabled=False,
1394
  ),
1395
+ TrainConfig(
1396
+ name="pi05_twin_dual_push_128_packed_split_expert_independent_pytorch_5k",
1397
+ model=pi0_config.Pi0Config(
1398
+ pi05=True,
1399
+ action_dim=32,
1400
+ action_horizon=16,
1401
+ arm_action_dims=(16, 16),
1402
+ action_expert_mode="split_independent",
1403
+ ),
1404
+ data=LeRobotTWINBimanualPackedDataConfig(
1405
+ repo_id="lsnu/twin_dual_push_128_train",
1406
+ base_config=DataConfig(prompt_from_task=False),
1407
+ ),
1408
+ pytorch_weight_path="/workspace/checkpoints/pi05_base_split_independent_packed_from_single",
1409
+ pytorch_training_precision="bfloat16",
1410
+ action_loss_mask=(1.0,) * 8 + (0.0,) * 8 + (1.0,) * 8 + (0.0,) * 8,
1411
+ lr_schedule=_optimizer.CosineDecaySchedule(
1412
+ warmup_steps=250,
1413
+ peak_lr=2.5e-5,
1414
+ decay_steps=5_000,
1415
+ decay_lr=2.5e-6,
1416
+ ),
1417
+ batch_size=16,
1418
+ num_workers=8,
1419
+ num_train_steps=5_000,
1420
+ log_interval=10,
1421
+ save_interval=1_000,
1422
+ overwrite=True,
1423
+ wandb_enabled=False,
1424
+ ),
1425
+ TrainConfig(
1426
+ name="pi05_twin_dual_push_128_packed_split_expert_communicating_pytorch_5k",
1427
+ model=pi0_config.Pi0Config(
1428
+ pi05=True,
1429
+ action_dim=32,
1430
+ action_horizon=16,
1431
+ arm_action_dims=(16, 16),
1432
+ action_expert_mode="split_communicating",
1433
+ ),
1434
+ data=LeRobotTWINBimanualPackedDataConfig(
1435
+ repo_id="lsnu/twin_dual_push_128_train",
1436
+ base_config=DataConfig(prompt_from_task=False),
1437
+ ),
1438
+ pytorch_weight_path="/workspace/checkpoints/pi05_base_split_communicating_packed_from_single",
1439
+ pytorch_training_precision="bfloat16",
1440
+ action_loss_mask=(1.0,) * 8 + (0.0,) * 8 + (1.0,) * 8 + (0.0,) * 8,
1441
+ lr_schedule=_optimizer.CosineDecaySchedule(
1442
+ warmup_steps=250,
1443
+ peak_lr=2.5e-5,
1444
+ decay_steps=5_000,
1445
+ decay_lr=2.5e-6,
1446
+ ),
1447
+ batch_size=16,
1448
+ num_workers=8,
1449
+ num_train_steps=5_000,
1450
+ log_interval=10,
1451
+ save_interval=1_000,
1452
+ overwrite=True,
1453
+ wandb_enabled=False,
1454
+ ),
1455
  #
1456
  # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
1457
  #
 
1512
  action_horizon=8,
1513
  max_token_len=32,
1514
  arm_action_dims=(16, 16),
1515
+ action_expert_mode="head_only_parallel",
1516
  ),
1517
  data=FakeDataConfig(),
1518
  batch_size=1,
 
1525
  wandb_enabled=False,
1526
  pytorch_training_precision="float32",
1527
  ),
1528
+ TrainConfig(
1529
+ name="debug_pi05_split_independent_pytorch_smoke",
1530
+ model=pi0_config.Pi0Config(
1531
+ pi05=True,
1532
+ paligemma_variant="dummy",
1533
+ action_expert_variant="dummy",
1534
+ action_dim=32,
1535
+ action_horizon=8,
1536
+ max_token_len=32,
1537
+ arm_action_dims=(16, 16),
1538
+ action_expert_mode="split_independent",
1539
+ ),
1540
+ data=FakeDataConfig(),
1541
+ batch_size=1,
1542
+ num_workers=0,
1543
+ num_train_steps=2,
1544
+ log_interval=1,
1545
+ save_interval=1,
1546
+ overwrite=True,
1547
+ exp_name="debug_pi05_split_independent_pytorch_smoke",
1548
+ wandb_enabled=False,
1549
+ pytorch_training_precision="float32",
1550
+ ),
1551
+ TrainConfig(
1552
+ name="debug_pi05_split_communicating_pytorch_smoke",
1553
+ model=pi0_config.Pi0Config(
1554
+ pi05=True,
1555
+ paligemma_variant="dummy",
1556
+ action_expert_variant="dummy",
1557
+ action_dim=32,
1558
+ action_horizon=8,
1559
+ max_token_len=32,
1560
+ arm_action_dims=(16, 16),
1561
+ action_expert_mode="split_communicating",
1562
+ ),
1563
+ data=FakeDataConfig(),
1564
+ batch_size=1,
1565
+ num_workers=0,
1566
+ num_train_steps=2,
1567
+ log_interval=1,
1568
+ save_interval=1,
1569
+ overwrite=True,
1570
+ exp_name="debug_pi05_split_communicating_pytorch_smoke",
1571
+ wandb_enabled=False,
1572
+ pytorch_training_precision="float32",
1573
+ ),
1574
  TrainConfig(
1575
  # Local smoke-test for converted TWIN LeRobot data.
1576
  name="debug_pi05_twin_bimanual_parallel_local_smoke",
 
1582
  action_horizon=8,
1583
  max_token_len=64,
1584
  arm_action_dims=(16, 16),
1585
+ action_expert_mode="head_only_parallel",
1586
  ),
1587
  data=LeRobotTWINBimanualDataConfig(
1588
  # This repo id is produced by scripts/convert_twin_squashfs_to_lerobot.py in local smoke mode.
openpi/src/openpi/training/data_loader.py CHANGED
@@ -10,7 +10,6 @@ from typing import Literal, Protocol, SupportsIndex, TypeVar
10
  from huggingface_hub import snapshot_download
11
  import jax
12
  import jax.numpy as jnp
13
- import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
14
  import numpy as np
15
  import torch
16
 
@@ -164,6 +163,114 @@ def _ensure_local_lerobot_dataset(repo_id: str) -> Path:
164
  return root
165
 
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  def create_torch_dataset(
168
  data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
169
  ) -> Dataset:
@@ -174,6 +281,9 @@ def create_torch_dataset(
174
  if repo_id == "fake":
175
  return FakeDataset(model_config, num_samples=1024)
176
 
 
 
 
177
  dataset_root = _ensure_local_lerobot_dataset(repo_id)
178
  dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id, root=dataset_root, revision="main")
179
  dataset = lerobot_dataset.LeRobotDataset(
 
10
  from huggingface_hub import snapshot_download
11
  import jax
12
  import jax.numpy as jnp
 
13
  import numpy as np
14
  import torch
15
 
 
163
  return root
164
 
165
 
166
+ def _patch_lerobot_column_compat(lerobot_dataset) -> None:
167
+ if getattr(lerobot_dataset, "_openpi_column_compat_patched", False):
168
+ return
169
+
170
+ def _hf_column_to_numpy(column) -> np.ndarray:
171
+ if isinstance(column, torch.Tensor):
172
+ return column.detach().cpu().numpy()
173
+ if hasattr(column, "to_pylist"):
174
+ return np.asarray(column.to_pylist())
175
+ values = list(column)
176
+ if values and isinstance(values[0], torch.Tensor):
177
+ return torch.stack(values).detach().cpu().numpy()
178
+ return np.asarray(values)
179
+
180
+ def _hf_column_to_tensor(column) -> torch.Tensor:
181
+ if isinstance(column, torch.Tensor):
182
+ return column
183
+ values = column.to_pylist() if hasattr(column, "to_pylist") else list(column)
184
+ if values and isinstance(values[0], torch.Tensor):
185
+ return torch.stack(values)
186
+ return torch.as_tensor(values)
187
+
188
+ def _patched_init(
189
+ self,
190
+ repo_id: str,
191
+ root: str | Path | None = None,
192
+ episodes: list[int] | None = None,
193
+ image_transforms=None,
194
+ delta_timestamps: dict[list[float]] | None = None,
195
+ tolerance_s: float = 1e-4,
196
+ revision: str | None = None,
197
+ force_cache_sync: bool = False,
198
+ download_videos: bool = True,
199
+ video_backend: str | None = None,
200
+ ):
201
+ self.repo_id = repo_id
202
+ self.root = Path(root) if root else lerobot_dataset.HF_LEROBOT_HOME / repo_id
203
+ self.image_transforms = image_transforms
204
+ self.delta_timestamps = delta_timestamps
205
+ self.episodes = episodes
206
+ self.tolerance_s = tolerance_s
207
+ self.revision = revision if revision else lerobot_dataset.CODEBASE_VERSION
208
+ self.video_backend = video_backend if video_backend else lerobot_dataset.get_safe_default_codec()
209
+ self.delta_indices = None
210
+ self.image_writer = None
211
+ self.episode_buffer = None
212
+
213
+ self.root.mkdir(exist_ok=True, parents=True)
214
+
215
+ self.meta = lerobot_dataset.LeRobotDatasetMetadata(
216
+ self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
217
+ )
218
+ if self.episodes is not None and self.meta._version >= lerobot_dataset.packaging.version.parse("v2.1"):
219
+ episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
220
+ self.stats = lerobot_dataset.aggregate_stats(episodes_stats)
221
+
222
+ try:
223
+ if force_cache_sync:
224
+ raise FileNotFoundError
225
+ assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
226
+ self.hf_dataset = self.load_hf_dataset()
227
+ except (AssertionError, FileNotFoundError, NotADirectoryError):
228
+ self.revision = lerobot_dataset.get_safe_version(self.repo_id, self.revision)
229
+ self.download_episodes(download_videos)
230
+ self.hf_dataset = self.load_hf_dataset()
231
+
232
+ self.episode_data_index = lerobot_dataset.get_episode_data_index(self.meta.episodes, self.episodes)
233
+
234
+ timestamps = _hf_column_to_numpy(self.hf_dataset["timestamp"])
235
+ episode_indices = _hf_column_to_numpy(self.hf_dataset["episode_index"])
236
+ ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
237
+ lerobot_dataset.check_timestamps_sync(
238
+ timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s
239
+ )
240
+
241
+ if self.delta_timestamps is not None:
242
+ lerobot_dataset.check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
243
+ self.delta_indices = lerobot_dataset.get_delta_indices(self.delta_timestamps, self.fps)
244
+
245
+ def _patched_get_query_timestamps(
246
+ self,
247
+ current_ts: float,
248
+ query_indices: dict[str, list[int]] | None = None,
249
+ ) -> dict[str, list[float]]:
250
+ query_timestamps = {}
251
+ for key in self.meta.video_keys:
252
+ if query_indices is not None and key in query_indices:
253
+ timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
254
+ query_timestamps[key] = _hf_column_to_tensor(timestamps).tolist()
255
+ else:
256
+ query_timestamps[key] = [current_ts]
257
+ return query_timestamps
258
+
259
+ def _patched_query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
260
+ return {
261
+ key: _hf_column_to_tensor(self.hf_dataset.select(q_idx)[key])
262
+ for key, q_idx in query_indices.items()
263
+ if key not in self.meta.video_keys
264
+ }
265
+
266
+ lerobot_dataset._hf_column_to_numpy = _hf_column_to_numpy
267
+ lerobot_dataset._hf_column_to_tensor = _hf_column_to_tensor
268
+ lerobot_dataset.LeRobotDataset.__init__ = _patched_init
269
+ lerobot_dataset.LeRobotDataset._get_query_timestamps = _patched_get_query_timestamps
270
+ lerobot_dataset.LeRobotDataset._query_hf_dataset = _patched_query_hf_dataset
271
+ lerobot_dataset._openpi_column_compat_patched = True
272
+
273
+
274
  def create_torch_dataset(
275
  data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
276
  ) -> Dataset:
 
281
  if repo_id == "fake":
282
  return FakeDataset(model_config, num_samples=1024)
283
 
284
+ import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
285
+
286
+ _patch_lerobot_column_compat(lerobot_dataset)
287
  dataset_root = _ensure_local_lerobot_dataset(repo_id)
288
  dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id, root=dataset_root, revision="main")
289
  dataset = lerobot_dataset.LeRobotDataset(