Kevin Black committed on
Commit
8999e55
·
1 Parent(s): aa7853c

Fix lint errors

Browse files
examples/convert_jax_model_to_pytorch.py CHANGED
@@ -10,13 +10,13 @@ Usage:
10
  # Just inspect keys:
11
  python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
12
  python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
13
-
14
  # Convert to PyTorch:
15
  python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
16
  python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
17
 
18
- Example:
19
- # pi0_droid
20
  python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
21
 
22
  # pi0_aloha_sim
@@ -33,44 +33,45 @@ import pathlib
33
  import shutil
34
  import traceback
35
 
 
36
  import jax
37
  import jax.numpy as jnp
38
  import jax.sharding
39
  import numpy as np
40
  import orbax.checkpoint as ocp
41
- import torch
42
  import safetensors
43
- from flax.nnx.traversals import flatten_mapping
 
 
 
 
44
 
45
  # Import our modules
46
  import openpi.models_pytorch.pi0_pytorch
47
- import openpi.models.pi0_config
48
- import openpi.models.gemma
49
  import openpi.shared.download
50
- import openpi.models.model
51
 
52
 
53
def flatten_for_inspection(tree, separator="/"):
    """
    Flatten a nested dictionary for easy inspection of keys using flax.nnx.traversals.flatten_mapping.

    Args:
        tree: The nested dictionary (JAX pytree)
        separator: Separator to use between key levels

    Returns:
        Dictionary with flattened keys and array shapes as values
    """
    flat = flatten_mapping(tree, separator=separator)

    # Describe each leaf for inspection: array-like values (anything exposing
    # both .shape and .dtype) are summarized by shape/dtype; everything else
    # falls back to its Python type.
    def _describe(leaf):
        if hasattr(leaf, "shape") and hasattr(leaf, "dtype"):
            return f"shape: {leaf.shape}, dtype: {leaf.dtype}"
        return f"type: {type(leaf)}"

    return {key: _describe(value) for key, value in flat.items()}
75
 
76
 
@@ -90,19 +91,15 @@ def slice_paligemma_state_dict(state_dict, config):
90
  """Convert PaliGemma JAX parameters to PyTorch format."""
91
  suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
92
 
93
-
94
  # patch embeddings
95
  jax_key = f"img/embedding/kernel{suffix}"
96
  pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
97
  state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
98
-
99
-
100
  jax_key = f"img/embedding/bias{suffix}"
101
  pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
102
  state_dict[pytorch_key] = state_dict.pop(jax_key)
103
 
104
-
105
-
106
  # positional embeddings
107
  jax_key = f"img/pos_embedding{suffix}"
108
  pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
@@ -114,54 +111,101 @@ def slice_paligemma_state_dict(state_dict, config):
114
  encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
115
  encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
116
 
117
- encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
118
- encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
119
- encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
120
- encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
121
 
122
- encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
123
- encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
124
- encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
125
- encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
126
- encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
127
- encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
128
- encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
129
- encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  for i in range(config.vision_config.num_hidden_layers):
132
-
133
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
134
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
135
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
136
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
137
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
138
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
139
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
140
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
141
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
142
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
143
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
144
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
145
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
146
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
147
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
148
- state_dict[f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
151
  pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
152
  state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
153
-
154
  jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
155
  pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
156
  state_dict[pytorch_key] = state_dict.pop(jax_key)
157
 
158
  # multimodal projector
159
  jax_key = f"img/head/kernel{suffix}"
160
- pytorch_key = 'paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight'
161
  state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
162
-
163
  jax_key = f"img/head/bias{suffix}"
164
- pytorch_key = 'paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias'
165
  state_dict[pytorch_key] = state_dict.pop(jax_key)
166
 
167
  # text decoder (gemma)
@@ -181,24 +225,54 @@ def slice_paligemma_state_dict(state_dict, config):
181
  llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
182
 
183
  for i in range(config.text_config.num_hidden_layers):
184
- q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
185
- state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
 
 
 
 
 
 
 
 
186
 
187
  k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
188
- state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
 
 
189
  v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
190
- state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
193
- state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
194
-
195
  gate_proj_weight = llm_mlp_gating_einsum[i, 0]
196
- state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
 
 
197
  up_proj_weight = llm_mlp_gating_einsum[i, 1]
198
- state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
199
- state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
200
- state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
201
- state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
 
 
 
 
 
 
 
 
202
 
203
  jax_key = f"llm/final_norm/scale{suffix}"
204
  pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
@@ -206,7 +280,7 @@ def slice_paligemma_state_dict(state_dict, config):
206
 
207
  expert_dict = {}
208
  final_state_dict = {}
209
-
210
  # Expert-related keys to extract (including pi05 Dense layer parameters)
211
  expert_keys = [
212
  f"llm/final_norm_1/scale{suffix}",
@@ -224,7 +298,7 @@ def slice_paligemma_state_dict(state_dict, config):
224
  f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
225
  f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
226
  ]
227
-
228
  for key, value in state_dict.items():
229
  if key not in expert_keys:
230
  final_state_dict[key] = torch.from_numpy(value)
@@ -237,13 +311,13 @@ def slice_paligemma_state_dict(state_dict, config):
237
  def slice_gemma_state_dict(state_dict, config, num_expert=1, checkpoint_dir=None):
238
  """Convert Gemma JAX parameters to PyTorch format."""
239
  # Add missing attributes to config if they don't exist
240
- if not hasattr(config, 'vocab_size'):
241
  config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
242
- if not hasattr(config, 'hidden_size'):
243
  config.hidden_size = config.width
244
- if not hasattr(config, 'num_hidden_layers'):
245
  config.num_hidden_layers = config.depth
246
- if not hasattr(config, 'num_attention_heads'):
247
  config.num_attention_heads = config.num_heads
248
 
249
  suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
@@ -260,42 +334,79 @@ def slice_gemma_state_dict(state_dict, config, num_expert=1, checkpoint_dir=None
260
  # Pi05 with adaptive normalization
261
  llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
262
  llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
263
- llm_input_layernorm_kernel = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}")
264
- llm_post_attention_layernorm_kernel = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}")
 
 
 
 
265
  else:
266
  # Regular pi0 with standard RMSNorm
267
  llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
268
  llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
269
 
270
-
271
  for i in range(config.num_hidden_layers):
272
- q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
273
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
 
 
 
 
 
 
274
 
275
  k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
276
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
 
 
277
  v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
278
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
 
 
 
 
 
 
 
 
 
 
 
279
 
280
- o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)
281
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
282
-
283
  gate_proj_weight = llm_mlp_gating_einsum[i, 0]
284
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
 
 
285
  up_proj_weight = llm_mlp_gating_einsum[i, 1]
286
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
287
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
 
 
 
 
288
 
289
  if "pi05" in checkpoint_dir:
290
  # Pi05 with adaptive normalization - use Dense layer parameters directly
291
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = llm_input_layernorm_bias[i]
292
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = llm_post_attention_layernorm_bias[i]
293
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = llm_input_layernorm_kernel[i].transpose()
294
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = llm_post_attention_layernorm_kernel[i].transpose()
 
 
 
 
 
 
 
 
295
  else:
296
  # Regular pi0 with standard RMSNorm
297
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
298
- state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
 
 
 
 
299
 
300
  # Handle final norm layer
301
  if "pi05" in checkpoint_dir:
@@ -306,9 +417,11 @@ def slice_gemma_state_dict(state_dict, config, num_expert=1, checkpoint_dir=None
306
  state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
307
  else:
308
  # Regular pi0 with standard RMSNorm
309
- state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
310
-
311
- #state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
 
 
312
 
313
  final_state_dict = {}
314
  for key, value in state_dict.items():
@@ -316,7 +429,6 @@ def slice_gemma_state_dict(state_dict, config, num_expert=1, checkpoint_dir=None
316
  final_state_dict[key] = torch.from_numpy(value)
317
  else:
318
  final_state_dict[key] = value
319
-
320
 
321
  return final_state_dict
322
 
@@ -339,11 +451,13 @@ def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str |
339
  restore_dtype = dtype_map.get(restore_precision) if restore_precision else None
340
 
341
  # Use CPU sharding to avoid GPU memory issues during checkpoint loading
342
- cpu_device = jax.devices('cpu')[0]
343
  cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_device)
344
-
345
  # Use repository restore utility to load a pure dict of params (value suffix removed)
346
- params = openpi.models.model.restore_params(params_dir, restore_type=jax.Array, dtype=restore_dtype, sharding=cpu_sharding)
 
 
347
 
348
  # get params for PaliGemma
349
  pali_params = params["PaliGemma"]
@@ -355,43 +469,43 @@ def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str |
355
  def load_jax_model_and_print_keys(checkpoint_dir: str):
356
  """
357
  Load JAX model from checkpoint and print all parameter keys.
358
-
359
  Args:
360
  checkpoint_dir: Path to the checkpoint directory
361
  """
362
  params_path = pathlib.Path(checkpoint_dir).resolve()
363
-
364
  if not params_path.exists():
365
  print(f"Error: Checkpoint directory does not exist: {params_path}")
366
  return
367
-
368
  try:
369
  # Initialize checkpointer
370
  checkpointer = ocp.PyTreeCheckpointer()
371
-
372
  # Load metadata to see available keys
373
  metadata = checkpointer.metadata(params_path)
374
  print("Available top-level keys in checkpoint:")
375
- for key in metadata.keys():
376
  print(f" - {key}")
377
  print()
378
-
379
  # Restore the parameters
380
  params_name = "params"
381
  if params_name not in metadata:
382
  print(f"Warning: '{params_name}' not found in metadata. Available keys: {list(metadata.keys())}")
383
  if metadata.keys():
384
- params_name = list(metadata.keys())[0]
385
  print(f"Using '{params_name}' instead.")
386
  else:
387
  print("No keys found in metadata!")
388
  return
389
-
390
  item = {params_name: metadata[params_name]}
391
  # Use CPU device to avoid GPU memory issues
392
- device = jax.devices('cpu')[0]
393
  sharding = jax.sharding.SingleDeviceSharding(device)
394
-
395
  restored = checkpointer.restore(
396
  params_path,
397
  ocp.args.PyTreeRestore(
@@ -406,33 +520,33 @@ def load_jax_model_and_print_keys(checkpoint_dir: str):
406
  transforms={},
407
  ),
408
  )
409
-
410
  params = restored[params_name]
411
-
412
  # Flatten and print all keys
413
  flat_params = flatten_for_inspection(params)
414
-
415
  print(f"All parameter keys with shapes and dtypes ({len(flat_params)} total):")
416
  print("=" * 80)
417
-
418
  # Sort keys for better readability
419
  sorted_keys = sorted(flat_params.keys())
420
-
421
  for key in sorted_keys:
422
  print(f"{key:<60} -> {flat_params[key]}")
423
-
424
  print()
425
  print("=" * 80)
426
  print(f"Summary: Found {len(flat_params)} parameters")
427
-
428
  # Print some high-level structure information
429
  top_level_keys = set()
430
  for key in sorted_keys:
431
- top_level_key = key.split('/')[0]
432
  top_level_keys.add(top_level_key)
433
-
434
- print(f"Top-level parameter groups: {sorted(list(top_level_keys))}")
435
-
436
  except Exception as e:
437
  print(f"Error loading checkpoint: {e}")
438
  traceback.print_exc()
@@ -441,29 +555,29 @@ def load_jax_model_and_print_keys(checkpoint_dir: str):
441
  def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str):
442
  """
443
  Convert PI0 JAX checkpoint to PyTorch format.
444
-
445
  Args:
446
  checkpoint_dir: Path to the JAX checkpoint
447
  precision: Model precision (float32, bfloat16, float16)
448
  output_path: Path to save the converted PyTorch model
449
  """
450
  print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
451
-
452
  # Break down orbax ckpts by restoring via JAX to respect dtype
453
- initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision='float32')
454
-
455
  # Process projection params
456
  if "pi05" in checkpoint_dir:
457
  keys = [
458
- "action_in_proj",
459
  "action_out_proj",
460
- "time_mlp_in",
461
  "time_mlp_out",
462
  ]
463
  else:
464
  keys = [
465
  "state_proj",
466
- "action_in_proj",
467
  "action_out_proj",
468
  "action_time_mlp_in",
469
  "action_time_mlp_out",
@@ -479,10 +593,10 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
479
  else:
480
  weight = kernel_params
481
  bias = bias_params
482
-
483
  pytorch_weight_key = f"{key}.weight"
484
  pytorch_bias_key = f"{key}.bias"
485
-
486
  projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
487
  projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
488
 
@@ -490,22 +604,30 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
490
  # All models use the same PaliGemma config structure
491
  class PaliGemmaConfig:
492
  def __init__(self):
493
- self.vision_config = type('obj', (object,), {
494
- 'hidden_size': 1152,
495
- 'num_hidden_layers': 27,
496
- 'num_attention_heads': 16,
497
- 'intermediate_size': 4304,
498
- 'patch_size': 14,
499
- 'projection_dim': 2048
500
- })()
501
- self.text_config = type('obj', (object,), {
502
- 'hidden_size': 2048,
503
- 'num_hidden_layers': 18,
504
- 'num_attention_heads': 8,
505
- 'head_dim': 256,
506
- 'intermediate_size': 16384
507
- })()
508
-
 
 
 
 
 
 
 
 
509
  paligemma_config = PaliGemmaConfig()
510
  action_expert_config = openpi.models.gemma.get_config("gemma_300m")
511
 
@@ -513,27 +635,24 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
513
  paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
514
 
515
  # Process Gemma weights from expert_params
516
- gemma_params = slice_gemma_state_dict(expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir)
 
 
517
 
518
  # Create Pi0Config based on checkpoint path
519
- if "pi0_aloha_sim" in checkpoint_dir:
520
- pi0_config = openpi.models.pi0_config.Pi0Config(
521
- action_dim=14, # ALOHA has 14 action dimensions
522
- action_horizon=50,
523
- )
524
- elif "pi0_aloha_towel" in checkpoint_dir:
525
  pi0_config = openpi.models.pi0_config.Pi0Config(
526
  action_dim=14, # ALOHA has 14 action dimensions
527
  action_horizon=50,
528
  )
529
  elif "pi0_base" in checkpoint_dir:
530
  pi0_config = openpi.models.pi0_config.Pi0Config(
531
- action_dim=8, # Base droid has 8 action dimensions
532
  action_horizon=10,
533
  )
534
  elif "pi05_droid" in checkpoint_dir:
535
  pi0_config = openpi.models.pi0_config.Pi0Config(
536
- action_dim=8, # Base droid has 8 action dimensions
537
  action_horizon=10,
538
  pi05=True,
539
  )
@@ -560,10 +679,10 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
560
 
561
  # Combine all parameters (no prefix needed for our model structure)
562
  all_params = {**paligemma_params, **gemma_params, **projection_params}
563
-
564
  # Load state dict
565
  pi0_model.load_state_dict(all_params, strict=False)
566
-
567
  if precision == "float32":
568
  pi0_model = pi0_model.to(torch.float32)
569
  elif precision == "bfloat16":
@@ -573,10 +692,10 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
573
 
574
  # Save the converted model using safetensors
575
  os.makedirs(output_path, exist_ok=True)
576
-
577
  # Save model weights as SafeTensors using save_model to handle tied weights
578
  safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
579
-
580
  # Copy assets folder if it exists
581
  assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
582
  if assets_source.exists():
@@ -584,7 +703,7 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
584
  if assets_dest.exists():
585
  shutil.rmtree(assets_dest)
586
  shutil.copytree(assets_source, assets_dest)
587
-
588
  # Save config as JSON for reference
589
  config_dict = {
590
  "action_dim": pi0_config.action_dim,
@@ -595,37 +714,26 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
595
  }
596
  with open(os.path.join(output_path, "config.json"), "w") as f:
597
  json.dump(config_dict, f, indent=2)
598
-
599
- print(f"Model conversion completed successfully!")
600
  print(f"Model saved to {output_path}")
601
 
602
 
603
  def main():
604
  parser = argparse.ArgumentParser(description="Load JAX model and optionally convert to PyTorch")
 
605
  parser.add_argument(
606
- "--checkpoint_dir",
607
- type=str,
608
- required=True,
609
- help="Path to the JAX checkpoint directory"
610
- )
611
- parser.add_argument(
612
- "--output_path",
613
- type=str,
614
- help="Path to save converted PyTorch model (required for conversion)"
615
  )
616
  parser.add_argument(
617
  "--precision",
618
  choices=["float32", "bfloat16", "float16"],
619
  default="bfloat16",
620
  type=str,
621
- help="Precision for model conversion"
622
- )
623
- parser.add_argument(
624
- "--inspect_only",
625
- action="store_true",
626
- help="Only inspect parameter keys, don't convert"
627
  )
628
-
 
629
  args = parser.parse_args()
630
 
631
  if not os.path.exists(args.checkpoint_dir):
@@ -633,7 +741,7 @@ def main():
633
  checkpoint_dir = openpi.shared.download.maybe_download(f"gs://openpi-assets/checkpoints/{model_name}")
634
  else:
635
  checkpoint_dir = args.checkpoint_dir
636
-
637
  if args.inspect_only:
638
  load_jax_model_and_print_keys(args.checkpoint_dir)
639
  else:
 
10
  # Just inspect keys:
11
  python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
12
  python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
13
+
14
  # Convert to PyTorch:
15
  python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
16
  python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
17
 
18
+ Example:
19
+ # pi0_droid
20
  python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
21
 
22
  # pi0_aloha_sim
 
33
  import shutil
34
  import traceback
35
 
36
+ from flax.nnx.traversals import flatten_mapping
37
  import jax
38
  import jax.numpy as jnp
39
  import jax.sharding
40
  import numpy as np
41
  import orbax.checkpoint as ocp
 
42
  import safetensors
43
+ import torch
44
+
45
+ import openpi.models.gemma
46
+ import openpi.models.model
47
+ import openpi.models.pi0_config
48
 
49
  # Import our modules
50
  import openpi.models_pytorch.pi0_pytorch
 
 
51
  import openpi.shared.download
 
52
 
53
 
54
  def flatten_for_inspection(tree, separator="/"):
55
  """
56
  Flatten a nested dictionary for easy inspection of keys using flax.nnx.traversals.flatten_mapping.
57
+
58
  Args:
59
  tree: The nested dictionary (JAX pytree)
60
  separator: Separator to use between key levels
61
+
62
  Returns:
63
  Dictionary with flattened keys and array shapes as values
64
  """
65
  flattened = flatten_mapping(tree, separator=separator)
66
+
67
  # Convert values to shape/dtype information for inspection
68
  result = {}
69
  for key, value in flattened.items():
70
+ if hasattr(value, "shape") and hasattr(value, "dtype"):
71
  result[key] = f"shape: {value.shape}, dtype: {value.dtype}"
72
  else:
73
  result[key] = f"type: {type(value)}"
74
+
75
  return result
76
 
77
 
 
91
  """Convert PaliGemma JAX parameters to PyTorch format."""
92
  suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
93
 
 
94
  # patch embeddings
95
  jax_key = f"img/embedding/kernel{suffix}"
96
  pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
97
  state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
98
+
 
99
  jax_key = f"img/embedding/bias{suffix}"
100
  pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
101
  state_dict[pytorch_key] = state_dict.pop(jax_key)
102
 
 
 
103
  # positional embeddings
104
  jax_key = f"img/pos_embedding{suffix}"
105
  pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
 
111
  encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
112
  encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
113
 
114
+ encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
115
+ encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
116
+ encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
117
+ encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
118
 
119
+ encoderblock_attention_0_key_kernel = state_dict.pop(
120
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
121
+ )
122
+ encoderblock_attention_0_key_bias = state_dict.pop(
123
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
124
+ )
125
+ encoderblock_attention_0_value_kernel = state_dict.pop(
126
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
127
+ )
128
+ encoderblock_attention_0_value_bias = state_dict.pop(
129
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
130
+ )
131
+ encoderblock_attention_0_query_kernel = state_dict.pop(
132
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
133
+ )
134
+ encoderblock_attention_0_query_bias = state_dict.pop(
135
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
136
+ )
137
+ encoderblock_attention_0_out_kernel = state_dict.pop(
138
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
139
+ )
140
+ encoderblock_attention_0_out_bias = state_dict.pop(
141
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
142
+ )
143
 
144
  for i in range(config.vision_config.num_hidden_layers):
145
+ state_dict[
146
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
147
+ ] = encoderblock_layernorm0_scale[i].transpose()
148
+ state_dict[
149
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
150
+ ] = encoderblock_layernorm0_bias[i]
151
+ state_dict[
152
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
153
+ ] = encoderblock_layernorm1_scale[i].transpose()
154
+ state_dict[
155
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
156
+ ] = encoderblock_layernorm1_bias[i]
157
+ state_dict[
158
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
159
+ ] = encoderblock_mlp_dense0_kernel[i].transpose()
160
+ state_dict[
161
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
162
+ ] = encoderblock_mlp_dense0_bias[i]
163
+ state_dict[
164
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
165
+ ] = encoderblock_mlp_dense1_kernel[i].transpose()
166
+ state_dict[
167
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
168
+ ] = encoderblock_mlp_dense1_bias[i]
169
+ state_dict[
170
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
171
+ ] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
172
+ state_dict[
173
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
174
+ ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
175
+ state_dict[
176
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
177
+ ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
178
+ state_dict[
179
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
180
+ ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
181
+ state_dict[
182
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
183
+ ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
184
+ state_dict[
185
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
186
+ ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
187
+ state_dict[
188
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
189
+ ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
190
+ state_dict[
191
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
192
+ ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
193
 
194
  jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
195
  pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
196
  state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
197
+
198
  jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
199
  pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
200
  state_dict[pytorch_key] = state_dict.pop(jax_key)
201
 
202
  # multimodal projector
203
  jax_key = f"img/head/kernel{suffix}"
204
+ pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
205
  state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
206
+
207
  jax_key = f"img/head/bias{suffix}"
208
+ pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
209
  state_dict[pytorch_key] = state_dict.pop(jax_key)
210
 
211
  # text decoder (gemma)
 
225
  llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
226
 
227
  for i in range(config.text_config.num_hidden_layers):
228
+ q_proj_weight_reshaped = (
229
+ llm_attention_q_einsum[i]
230
+ .transpose(0, 2, 1)
231
+ .reshape(
232
+ config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
233
+ )
234
+ )
235
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
236
+ q_proj_weight_reshaped
237
+ )
238
 
239
  k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
240
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
241
+ k_proj_weight_reshaped
242
+ )
243
  v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
244
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
245
+ v_proj_weight_reshaped
246
+ )
247
+
248
+ o_proj_weight_reshaped = (
249
+ llm_attention_attn_vec_einsum[i]
250
+ .transpose(2, 0, 1)
251
+ .reshape(
252
+ config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
253
+ )
254
+ )
255
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
256
+ o_proj_weight_reshaped
257
+ )
258
 
 
 
 
259
  gate_proj_weight = llm_mlp_gating_einsum[i, 0]
260
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
261
+ gate_proj_weight.transpose()
262
+ )
263
  up_proj_weight = llm_mlp_gating_einsum[i, 1]
264
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
265
+ up_proj_weight.transpose()
266
+ )
267
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
268
+ llm_mlp_linear[i].transpose()
269
+ )
270
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
271
+ llm_input_layernorm[i]
272
+ )
273
+ state_dict[
274
+ f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
275
+ ] = llm_post_attention_layernorm[i]
276
 
277
  jax_key = f"llm/final_norm/scale{suffix}"
278
  pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
 
280
 
281
  expert_dict = {}
282
  final_state_dict = {}
283
+
284
  # Expert-related keys to extract (including pi05 Dense layer parameters)
285
  expert_keys = [
286
  f"llm/final_norm_1/scale{suffix}",
 
298
  f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
299
  f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
300
  ]
301
+
302
  for key, value in state_dict.items():
303
  if key not in expert_keys:
304
  final_state_dict[key] = torch.from_numpy(value)
 
311
  def slice_gemma_state_dict(state_dict, config, num_expert=1, checkpoint_dir=None):
312
  """Convert Gemma JAX parameters to PyTorch format."""
313
  # Add missing attributes to config if they don't exist
314
+ if not hasattr(config, "vocab_size"):
315
  config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
316
+ if not hasattr(config, "hidden_size"):
317
  config.hidden_size = config.width
318
+ if not hasattr(config, "num_hidden_layers"):
319
  config.num_hidden_layers = config.depth
320
+ if not hasattr(config, "num_attention_heads"):
321
  config.num_attention_heads = config.num_heads
322
 
323
  suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
 
334
  # Pi05 with adaptive normalization
335
  llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
336
  llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
337
+ llm_input_layernorm_kernel = state_dict.pop(
338
+ f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
339
+ )
340
+ llm_post_attention_layernorm_kernel = state_dict.pop(
341
+ f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
342
+ )
343
  else:
344
  # Regular pi0 with standard RMSNorm
345
  llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
346
  llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
347
 
 
348
  for i in range(config.num_hidden_layers):
349
+ q_proj_weight_reshaped = (
350
+ llm_attention_q_einsum[i]
351
+ .transpose(0, 2, 1)
352
+ .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
353
+ )
354
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
355
+ q_proj_weight_reshaped
356
+ )
357
 
358
  k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
359
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
360
+ k_proj_weight_reshaped
361
+ )
362
  v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
363
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
364
+ v_proj_weight_reshaped
365
+ )
366
+
367
+ o_proj_weight_reshaped = (
368
+ llm_attention_attn_vec_einsum[i]
369
+ .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
370
+ .transpose(1, 0)
371
+ )
372
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
373
+ o_proj_weight_reshaped
374
+ )
375
 
 
 
 
376
  gate_proj_weight = llm_mlp_gating_einsum[i, 0]
377
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
378
+ gate_proj_weight.transpose()
379
+ )
380
  up_proj_weight = llm_mlp_gating_einsum[i, 1]
381
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
382
+ up_proj_weight.transpose()
383
+ )
384
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
385
+ i
386
+ ].transpose()
387
 
388
  if "pi05" in checkpoint_dir:
389
  # Pi05 with adaptive normalization - use Dense layer parameters directly
390
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
391
+ llm_input_layernorm_bias[i]
392
+ )
393
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
394
+ llm_post_attention_layernorm_bias[i]
395
+ )
396
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
397
+ llm_input_layernorm_kernel[i].transpose()
398
+ )
399
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
400
+ llm_post_attention_layernorm_kernel[i].transpose()
401
+ )
402
  else:
403
  # Regular pi0 with standard RMSNorm
404
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
405
+ llm_input_layernorm[i]
406
+ )
407
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
408
+ llm_post_attention_layernorm[i]
409
+ )
410
 
411
  # Handle final norm layer
412
  if "pi05" in checkpoint_dir:
 
417
  state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
418
  else:
419
  # Regular pi0 with standard RMSNorm
420
+ state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
421
+ f"llm/final_norm_{num_expert}/scale{suffix}"
422
+ )
423
+
424
+ # state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
425
 
426
  final_state_dict = {}
427
  for key, value in state_dict.items():
 
429
  final_state_dict[key] = torch.from_numpy(value)
430
  else:
431
  final_state_dict[key] = value
 
432
 
433
  return final_state_dict
434
 
 
451
  restore_dtype = dtype_map.get(restore_precision) if restore_precision else None
452
 
453
  # Use CPU sharding to avoid GPU memory issues during checkpoint loading
454
+ cpu_device = jax.devices("cpu")[0]
455
  cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_device)
456
+
457
  # Use repository restore utility to load a pure dict of params (value suffix removed)
458
+ params = openpi.models.model.restore_params(
459
+ params_dir, restore_type=jax.Array, dtype=restore_dtype, sharding=cpu_sharding
460
+ )
461
 
462
  # get params for PaliGemma
463
  pali_params = params["PaliGemma"]
 
469
  def load_jax_model_and_print_keys(checkpoint_dir: str):
470
  """
471
  Load JAX model from checkpoint and print all parameter keys.
472
+
473
  Args:
474
  checkpoint_dir: Path to the checkpoint directory
475
  """
476
  params_path = pathlib.Path(checkpoint_dir).resolve()
477
+
478
  if not params_path.exists():
479
  print(f"Error: Checkpoint directory does not exist: {params_path}")
480
  return
481
+
482
  try:
483
  # Initialize checkpointer
484
  checkpointer = ocp.PyTreeCheckpointer()
485
+
486
  # Load metadata to see available keys
487
  metadata = checkpointer.metadata(params_path)
488
  print("Available top-level keys in checkpoint:")
489
+ for key in metadata:
490
  print(f" - {key}")
491
  print()
492
+
493
  # Restore the parameters
494
  params_name = "params"
495
  if params_name not in metadata:
496
  print(f"Warning: '{params_name}' not found in metadata. Available keys: {list(metadata.keys())}")
497
  if metadata.keys():
498
+ params_name = next(iter(metadata.keys()))
499
  print(f"Using '{params_name}' instead.")
500
  else:
501
  print("No keys found in metadata!")
502
  return
503
+
504
  item = {params_name: metadata[params_name]}
505
  # Use CPU device to avoid GPU memory issues
506
+ device = jax.devices("cpu")[0]
507
  sharding = jax.sharding.SingleDeviceSharding(device)
508
+
509
  restored = checkpointer.restore(
510
  params_path,
511
  ocp.args.PyTreeRestore(
 
520
  transforms={},
521
  ),
522
  )
523
+
524
  params = restored[params_name]
525
+
526
  # Flatten and print all keys
527
  flat_params = flatten_for_inspection(params)
528
+
529
  print(f"All parameter keys with shapes and dtypes ({len(flat_params)} total):")
530
  print("=" * 80)
531
+
532
  # Sort keys for better readability
533
  sorted_keys = sorted(flat_params.keys())
534
+
535
  for key in sorted_keys:
536
  print(f"{key:<60} -> {flat_params[key]}")
537
+
538
  print()
539
  print("=" * 80)
540
  print(f"Summary: Found {len(flat_params)} parameters")
541
+
542
  # Print some high-level structure information
543
  top_level_keys = set()
544
  for key in sorted_keys:
545
+ top_level_key = key.split("/")[0]
546
  top_level_keys.add(top_level_key)
547
+
548
+ print(f"Top-level parameter groups: {sorted(top_level_keys)}")
549
+
550
  except Exception as e:
551
  print(f"Error loading checkpoint: {e}")
552
  traceback.print_exc()
 
555
  def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str):
556
  """
557
  Convert PI0 JAX checkpoint to PyTorch format.
558
+
559
  Args:
560
  checkpoint_dir: Path to the JAX checkpoint
561
  precision: Model precision (float32, bfloat16, float16)
562
  output_path: Path to save the converted PyTorch model
563
  """
564
  print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
565
+
566
  # Break down orbax ckpts by restoring via JAX to respect dtype
567
+ initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
568
+
569
  # Process projection params
570
  if "pi05" in checkpoint_dir:
571
  keys = [
572
+ "action_in_proj",
573
  "action_out_proj",
574
+ "time_mlp_in",
575
  "time_mlp_out",
576
  ]
577
  else:
578
  keys = [
579
  "state_proj",
580
+ "action_in_proj",
581
  "action_out_proj",
582
  "action_time_mlp_in",
583
  "action_time_mlp_out",
 
593
  else:
594
  weight = kernel_params
595
  bias = bias_params
596
+
597
  pytorch_weight_key = f"{key}.weight"
598
  pytorch_bias_key = f"{key}.bias"
599
+
600
  projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
601
  projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
602
 
 
604
  # All models use the same PaliGemma config structure
605
  class PaliGemmaConfig:
606
  def __init__(self):
607
+ self.vision_config = type(
608
+ "obj",
609
+ (object,),
610
+ {
611
+ "hidden_size": 1152,
612
+ "num_hidden_layers": 27,
613
+ "num_attention_heads": 16,
614
+ "intermediate_size": 4304,
615
+ "patch_size": 14,
616
+ "projection_dim": 2048,
617
+ },
618
+ )()
619
+ self.text_config = type(
620
+ "obj",
621
+ (object,),
622
+ {
623
+ "hidden_size": 2048,
624
+ "num_hidden_layers": 18,
625
+ "num_attention_heads": 8,
626
+ "head_dim": 256,
627
+ "intermediate_size": 16384,
628
+ },
629
+ )()
630
+
631
  paligemma_config = PaliGemmaConfig()
632
  action_expert_config = openpi.models.gemma.get_config("gemma_300m")
633
 
 
635
  paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
636
 
637
  # Process Gemma weights from expert_params
638
+ gemma_params = slice_gemma_state_dict(
639
+ expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir
640
+ )
641
 
642
  # Create Pi0Config based on checkpoint path
643
+ if "pi0_aloha_sim" in checkpoint_dir or "pi0_aloha_towel" in checkpoint_dir:
 
 
 
 
 
644
  pi0_config = openpi.models.pi0_config.Pi0Config(
645
  action_dim=14, # ALOHA has 14 action dimensions
646
  action_horizon=50,
647
  )
648
  elif "pi0_base" in checkpoint_dir:
649
  pi0_config = openpi.models.pi0_config.Pi0Config(
650
+ action_dim=8, # Base droid has 8 action dimensions
651
  action_horizon=10,
652
  )
653
  elif "pi05_droid" in checkpoint_dir:
654
  pi0_config = openpi.models.pi0_config.Pi0Config(
655
+ action_dim=8, # Base droid has 8 action dimensions
656
  action_horizon=10,
657
  pi05=True,
658
  )
 
679
 
680
  # Combine all parameters (no prefix needed for our model structure)
681
  all_params = {**paligemma_params, **gemma_params, **projection_params}
682
+
683
  # Load state dict
684
  pi0_model.load_state_dict(all_params, strict=False)
685
+
686
  if precision == "float32":
687
  pi0_model = pi0_model.to(torch.float32)
688
  elif precision == "bfloat16":
 
692
 
693
  # Save the converted model using safetensors
694
  os.makedirs(output_path, exist_ok=True)
695
+
696
  # Save model weights as SafeTensors using save_model to handle tied weights
697
  safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
698
+
699
  # Copy assets folder if it exists
700
  assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
701
  if assets_source.exists():
 
703
  if assets_dest.exists():
704
  shutil.rmtree(assets_dest)
705
  shutil.copytree(assets_source, assets_dest)
706
+
707
  # Save config as JSON for reference
708
  config_dict = {
709
  "action_dim": pi0_config.action_dim,
 
714
  }
715
  with open(os.path.join(output_path, "config.json"), "w") as f:
716
  json.dump(config_dict, f, indent=2)
717
+
718
+ print("Model conversion completed successfully!")
719
  print(f"Model saved to {output_path}")
720
 
721
 
722
  def main():
723
  parser = argparse.ArgumentParser(description="Load JAX model and optionally convert to PyTorch")
724
+ parser.add_argument("--checkpoint_dir", type=str, required=True, help="Path to the JAX checkpoint directory")
725
  parser.add_argument(
726
+ "--output_path", type=str, help="Path to save converted PyTorch model (required for conversion)"
 
 
 
 
 
 
 
 
727
  )
728
  parser.add_argument(
729
  "--precision",
730
  choices=["float32", "bfloat16", "float16"],
731
  default="bfloat16",
732
  type=str,
733
+ help="Precision for model conversion",
 
 
 
 
 
734
  )
735
+ parser.add_argument("--inspect_only", action="store_true", help="Only inspect parameter keys, don't convert")
736
+
737
  args = parser.parse_args()
738
 
739
  if not os.path.exists(args.checkpoint_dir):
 
741
  checkpoint_dir = openpi.shared.download.maybe_download(f"gs://openpi-assets/checkpoints/{model_name}")
742
  else:
743
  checkpoint_dir = args.checkpoint_dir
744
+
745
  if args.inspect_only:
746
  load_jax_model_and_print_keys(args.checkpoint_dir)
747
  else:
examples/droid/convert_droid_data_to_lerobot.py CHANGED
@@ -277,7 +277,7 @@ class RecordedMultiCameraWrapper:
277
  self.camera_kwargs = camera_kwargs
278
 
279
  # Open Camera Readers #
280
- mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4") # noqa: PTH207
281
  all_filepaths = mp4_filepaths
282
 
283
  self.camera_dict = {}
 
277
  self.camera_kwargs = camera_kwargs
278
 
279
  # Open Camera Readers #
280
+ mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
281
  all_filepaths = mp4_filepaths
282
 
283
  self.camera_dict = {}
pyproject.toml CHANGED
@@ -73,7 +73,7 @@ members = ["packages/*"]
73
  [tool.ruff]
74
  line-length = 120
75
  target-version = "py311"
76
- extend-exclude = ["docker", "third_party"]
77
 
78
  [tool.ruff.lint]
79
  # https://docs.astral.sh/ruff/rules/
@@ -101,7 +101,6 @@ select = [
101
  "PLR5",
102
  "PLW",
103
  "PT",
104
- "PTH",
105
  "Q",
106
  "RET",
107
  "RUF",
 
73
  [tool.ruff]
74
  line-length = 120
75
  target-version = "py311"
76
+ extend-exclude = ["docker", "third_party", "src/openpi/models_pytorch/transformers_replace/*"]
77
 
78
  [tool.ruff.lint]
79
  # https://docs.astral.sh/ruff/rules/
 
101
  "PLR5",
102
  "PLW",
103
  "PT",
 
104
  "Q",
105
  "RET",
106
  "RUF",
scripts/train_pytorch.py CHANGED
@@ -23,7 +23,6 @@ Multi-Node Training:
23
 
24
  """
25
 
26
- import argparse
27
  import dataclasses
28
  import gc
29
  import logging
@@ -31,10 +30,10 @@ import os
31
  import platform
32
  import shutil
33
  import time
34
- from typing import Any, Dict
35
 
36
  import jax
37
  import numpy as np
 
38
  import torch
39
  import torch.distributed as dist
40
  import torch.nn.parallel
@@ -42,162 +41,169 @@ import torch.utils.data
42
  import torch.utils.data.distributed
43
  import tqdm
44
  import wandb
45
- import safetensors.torch
46
 
 
 
47
  import openpi.training.config as _config
48
  import openpi.training.data_loader as _data
49
- import openpi.models.model as _model
50
- import openpi.models_pytorch.pi0_pytorch
51
- import openpi.models.pi0_config
52
 
53
 
54
  def init_logging():
55
- level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
56
-
57
- class CustomFormatter(logging.Formatter):
58
- def format(self, record):
59
- record.levelname = level_mapping.get(record.levelname, record.levelname)
60
- return super().format(record)
61
-
62
- formatter = CustomFormatter(
63
- fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
64
- datefmt="%H:%M:%S",
65
- )
66
- logger = logging.getLogger()
67
- logger.setLevel(logging.INFO)
68
- if not logger.handlers:
69
- ch = logging.StreamHandler()
70
- ch.setFormatter(formatter)
71
- logger.addHandler(ch)
72
- else:
73
- logger.handlers[0].setFormatter(formatter)
74
 
75
 
76
  def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
77
- """Initialize wandb logging."""
78
- if not enabled:
79
- wandb.init(mode="disabled")
80
- return
81
-
82
- ckpt_dir = config.checkpoint_dir
83
- if not ckpt_dir.exists():
84
- raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
85
-
86
- if resuming:
87
- run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
88
- wandb.init(id=run_id, resume="must", project=config.project_name)
89
- else:
90
- wandb.init(
91
- name=config.exp_name,
92
- config=dataclasses.asdict(config),
93
- project=config.project_name,
94
- )
95
- (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
96
 
97
 
98
  def setup_ddp():
99
- world_size = int(os.environ.get("WORLD_SIZE", "1"))
100
- use_ddp = world_size > 1
101
- if use_ddp and not torch.distributed.is_initialized():
102
- backend = "nccl" if torch.cuda.is_available() else "gloo"
103
- torch.distributed.init_process_group(backend=backend, init_method="env://")
104
-
105
- # Set up debugging environment variables for DDP issues
106
- if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
107
- os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
108
-
109
- local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
110
- device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
111
- if torch.cuda.is_available():
112
- torch.cuda.set_device(device)
113
- return use_ddp, local_rank, device
114
 
115
 
116
  def cleanup_ddp():
117
- if torch.distributed.is_initialized():
118
- torch.distributed.barrier()
119
- torch.distributed.destroy_process_group()
120
 
121
 
122
  def set_seed(seed: int, local_rank: int):
123
- torch.manual_seed(seed + local_rank)
124
- np.random.seed(seed + local_rank)
125
- if torch.cuda.is_available():
126
- torch.cuda.manual_seed_all(seed + local_rank)
127
 
128
 
129
  def build_datasets(config: _config.TrainConfig):
130
- # Use the unified data loader with PyTorch framework
131
- data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
132
- return data_loader, data_loader.data_config()
133
 
134
 
135
  def get_model_state_dict(model):
136
- """Get state dict from model, handling DDP wrapper."""
137
- return model.module.state_dict() if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.state_dict()
 
 
 
 
138
 
139
 
140
  def get_model_parameters(model):
141
- """Get parameters from model, handling DDP wrapper."""
142
- return model.module.parameters() if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.parameters()
 
 
 
 
143
 
144
 
145
  def save_checkpoint(model, optimizer, global_step, config, is_main):
146
- """Save a checkpoint with model state, optimizer state, and metadata."""
147
- if not is_main:
148
- return
149
-
150
- # Only save if it's time to save or if it's the final step
151
- if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
152
- # Create temporary directory for atomic checkpoint saving
153
- final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
154
- tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"
155
-
156
- # Remove any existing temp directory and create new one
157
- if tmp_ckpt_dir.exists():
158
- shutil.rmtree(tmp_ckpt_dir)
159
- tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)
160
-
161
- # Save model state using safetensors (handle shared tensors)
162
- model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
163
- safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "pytorch_model.safetensors")
164
-
165
- # Save optimizer state using PyTorch format
166
- torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")
167
-
168
- # Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues)
169
- metadata = {
170
- "global_step": global_step,
171
- "config": dataclasses.asdict(config),
172
- "timestamp": time.time(),
173
- }
174
- torch.save(metadata, tmp_ckpt_dir / "metadata.pt")
175
-
176
- # Atomically move temp directory to final location
177
- if final_ckpt_dir.exists():
178
- shutil.rmtree(final_ckpt_dir)
179
- tmp_ckpt_dir.rename(final_ckpt_dir)
180
-
181
- logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")
182
-
183
- # Log checkpoint to wandb
184
- if config.wandb_enabled:
185
- wandb.log({"checkpoint_step": global_step}, step=global_step)
186
 
187
 
188
  def load_checkpoint(model, optimizer, checkpoint_dir, device):
189
  """Load the latest checkpoint and return the global step."""
190
- checkpoint_steps = []
191
- for d in checkpoint_dir.iterdir():
192
- if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_"):
193
- checkpoint_steps.append(int(d.name))
194
-
 
195
  if not checkpoint_steps:
196
  raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
197
-
198
  latest_step = max(checkpoint_steps)
199
  ckpt_dir = checkpoint_dir / f"{latest_step}"
200
-
201
  # Clear memory before loading checkpoints
202
  if torch.cuda.is_available():
203
  torch.cuda.empty_cache()
@@ -208,35 +214,34 @@ def load_checkpoint(model, optimizer, checkpoint_dir, device):
208
  # Load model state with error handling
209
  logging.info("Loading model state...")
210
  safetensors_path = ckpt_dir / "pytorch_model.safetensors"
211
-
212
  if safetensors_path.exists():
213
  model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
214
  safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
215
  logging.info("Loaded model state from safetensors format")
216
  else:
217
  raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
218
-
219
  torch.cuda.empty_cache()
220
  gc.collect()
221
  log_memory_usage(device, latest_step, "after_loading_model")
222
-
223
  # Load optimizer state with error handling
224
  logging.info("Loading optimizer state...")
225
  optimizer_path = ckpt_dir / "optimizer.pt"
226
-
227
  if optimizer_path.exists():
228
  optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
229
  logging.info("Loaded optimizer state from pt format")
230
  else:
231
  raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
232
-
233
  optimizer.load_state_dict(optimizer_state_dict)
234
  del optimizer_state_dict
235
  torch.cuda.empty_cache()
236
  gc.collect()
237
  log_memory_usage(device, latest_step, "after_loading_optimizer")
238
-
239
-
240
  # Load metadata
241
  logging.info("Loading metadata...")
242
  metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
@@ -245,355 +250,379 @@ def load_checkpoint(model, optimizer, checkpoint_dir, device):
245
  torch.cuda.empty_cache()
246
  gc.collect()
247
  log_memory_usage(device, latest_step, "after_loading_metadata")
248
-
249
  logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
250
  return global_step
251
-
252
  except RuntimeError as e:
253
  if "out of memory" in str(e):
254
  # Clear memory and provide detailed error message
255
  torch.cuda.empty_cache()
256
  gc.collect()
257
- logging.error(f"Out of memory error while loading checkpoint: {str(e)}")
258
  log_memory_usage(device, latest_step, "after_oom_error")
259
- raise RuntimeError(f"Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True") from e
 
 
260
  raise
261
 
262
 
263
  def get_latest_checkpoint_step(checkpoint_dir):
264
- """Get the latest checkpoint step number from a checkpoint directory."""
265
- checkpoint_steps = []
266
- for d in checkpoint_dir.iterdir():
267
- if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_"):
268
- checkpoint_steps.append(int(d.name))
269
-
270
- return max(checkpoint_steps) if checkpoint_steps else None
271
 
272
 
273
  def log_memory_usage(device, step, phase="unknown"):
274
- """Log detailed memory usage information."""
275
- if not torch.cuda.is_available():
276
- return
277
-
278
- memory_allocated = torch.cuda.memory_allocated(device) / 1e9
279
- memory_reserved = torch.cuda.memory_reserved(device) / 1e9
280
- memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)
281
- memory_free = memory_free / 1e9
282
-
283
- # Get more detailed memory info
284
- memory_stats = torch.cuda.memory_stats(device)
285
- max_memory_allocated = memory_stats.get('allocated_bytes.all.peak', 0) / 1e9
286
- max_memory_reserved = memory_stats.get('reserved_bytes.all.peak', 0) / 1e9
287
-
288
- # Get DDP info if available
289
- ddp_info = ""
290
- if dist.is_initialized():
291
- ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"
292
-
293
- logging.info(f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}")
 
 
294
 
295
 
296
  def train_loop(config: _config.TrainConfig):
297
- use_ddp, local_rank, device = setup_ddp()
298
- is_main = (not use_ddp) or (dist.get_rank() == 0)
299
- set_seed(config.seed, local_rank)
300
-
301
- # Initialize checkpoint directory and wandb
302
- resuming = False
303
- if config.resume:
304
- # Find checkpoint directory based on experiment name
305
- exp_checkpoint_dir = config.checkpoint_dir
306
- if exp_checkpoint_dir.exists():
307
- # Use validation to find the latest working checkpoint
308
- latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
309
- if latest_step is not None:
310
- resuming = True
311
- logging.info(f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}")
312
- else:
313
- raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
314
- else:
315
- raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
316
- elif config.overwrite and config.checkpoint_dir.exists():
317
- shutil.rmtree(config.checkpoint_dir)
318
- logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")
319
-
320
- # Create checkpoint directory with experiment name
321
- if not resuming:
322
- # For new runs, create experiment-specific checkpoint directory
323
- exp_checkpoint_dir = config.checkpoint_dir
324
- exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
325
- logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
326
- else:
327
- # For resume, checkpoint_dir is already set to the experiment directory
328
- logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")
329
-
330
- # Initialize wandb (only on main process)
331
- if is_main:
332
- init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
333
-
334
- # Build data loader using the unified data loader
335
- # Calculate effective batch size per GPU for DDP
336
- # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size
337
- world_size = torch.distributed.get_world_size() if use_ddp else 1
338
- effective_batch_size = config.batch_size // world_size
339
- logging.info(f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})")
340
-
341
- # Pass the original batch size to data loader - it will handle DDP splitting internally
342
- loader, _ = build_datasets(config)
343
-
344
- # Log sample images to wandb on first batch
345
- if is_main and config.wandb_enabled and not resuming:
346
- # Create a separate data loader for sample batch to avoid consuming the main loader
347
- sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
348
- sample_batch = next(iter(sample_data_loader))
349
- # Convert observation and actions to torch tensors
350
- observation, actions = sample_batch
351
- sample_batch = observation.to_dict()
352
- sample_batch["actions"] = actions
353
-
354
- # Create sample images for wandb
355
- images_to_log = []
356
- # Get batch size from the first image tensor
357
- batch_size = next(iter(sample_batch['image'].values())).shape[0]
358
- for i in range(min(5, batch_size)):
359
- # Concatenate all camera views horizontally for this batch item
360
- # Convert from NCHW to NHWC format for wandb
361
- img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch['image'].values()], axis=1)
362
- img_concatenated = img_concatenated.cpu().numpy()
363
- images_to_log.append(wandb.Image(img_concatenated))
364
-
365
- wandb.log({"camera_views": images_to_log}, step=0)
366
-
367
- # Clear sample batch from memory aggressively
368
- del sample_batch, observation, actions, images_to_log, img_concatenated
369
- del sample_data_loader # Also delete the sample data loader
370
- gc.collect()
371
- if torch.cuda.is_available():
372
- torch.cuda.empty_cache()
373
- logging.info("Cleared sample batch and data loader from memory")
374
-
375
- # Build model
376
- if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
377
- # Convert dataclass to Pi0Config if needed
378
- model_cfg = openpi.models.pi0_config.Pi0Config(
379
- dtype=config.pytorch_training_precision,
380
- action_dim=config.model.action_dim,
381
- action_horizon=config.model.action_horizon,
382
- max_token_len=config.model.max_token_len,
383
- paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
384
- action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
385
- pi05=getattr(config.model, "pi05", False),
386
- )
387
- else:
388
- model_cfg = config.model
389
- # Update dtype to match pytorch_training_precision
390
- object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)
391
-
392
- model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)
393
-
394
-
395
- if hasattr(model, 'gradient_checkpointing_enable'):
396
- enable_gradient_checkpointing = True
397
- model.gradient_checkpointing_enable()
398
- logging.info("Enabled gradient checkpointing for memory optimization")
399
- else:
400
- enable_gradient_checkpointing = False
401
- logging.info("Gradient checkpointing is not supported for this model")
402
-
403
- # Log initial memory usage after model creation
404
- if is_main and torch.cuda.is_available():
405
- log_memory_usage(device, 0, "after_model_creation")
406
-
407
- # Enable memory optimizations for large-scale training
408
- if world_size >= 8:
409
- torch.backends.cudnn.benchmark = True
410
- torch.backends.cuda.matmul.allow_tf32 = True
411
- torch.backends.cudnn.allow_tf32 = True
412
- # Set memory allocation configuration
413
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
414
- logging.info("Enabled memory optimizations for 8+ GPU training")
415
-
416
- if use_ddp:
417
- model = torch.nn.parallel.DistributedDataParallel(
418
- model,
419
- device_ids=[device.index] if device.type == "cuda" else None,
420
- find_unused_parameters=True, # Disable for memory efficiency
421
- gradient_as_bucket_view=True, # Enable for memory efficiency
422
- static_graph=True if world_size >= 8 else False, # Enable for 8+ GPUs
423
- )
424
-
425
- # Load weights from weight_loader if specified (for fine-tuning)
426
- if config.pytorch_weight_path is not None:
427
- logging.info(f"Loading weights from: {config.pytorch_weight_path}")
428
-
429
- model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
430
- safetensors.torch.load_model((model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path)
431
- logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")
432
-
433
- # Optimizer + learning rate schedule from config
434
- warmup_steps = config.lr_schedule.warmup_steps
435
- peak_lr = config.lr_schedule.peak_lr
436
- decay_steps = config.lr_schedule.decay_steps
437
- end_lr = config.lr_schedule.decay_lr
438
-
439
- # Create optimizer with config parameters
440
- optim = torch.optim.AdamW(
441
- model.parameters(),
442
- lr=peak_lr,
443
- betas=(config.optimizer.b1, config.optimizer.b2),
444
- eps=config.optimizer.eps,
445
- weight_decay=config.optimizer.weight_decay
446
- )
447
-
448
-
449
- # Load checkpoint if resuming
450
- global_step = 0
451
- if resuming:
452
- global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
453
- logging.info(f"Resumed training from step {global_step}")
454
-
455
- def lr_schedule(step: int):
456
- if step < warmup_steps:
457
- # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
458
- init_lr = peak_lr / (warmup_steps + 1)
459
- return init_lr + (peak_lr - init_lr) * step / warmup_steps
460
- # cosine decay
461
- progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
462
- cos = 0.5 * (1 + np.cos(np.pi * progress))
463
- return end_lr + (peak_lr - end_lr) * cos
464
-
465
- model.train()
466
- start_time = time.time()
467
- infos = [] # Collect stats over log interval
468
- if is_main:
469
- logging.info(f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}")
470
- logging.info(f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}")
471
- logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
472
- logging.info(f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}")
473
- logging.info(f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}")
474
- logging.info(f"EMA is not supported for PyTorch training")
475
- logging.info(f"Training precision: {model_cfg.dtype}")
476
-
477
- # Training loop - iterate until we reach num_train_steps
478
- pbar = tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) if is_main else None
479
-
480
- while global_step < config.num_train_steps:
481
- # Set epoch for distributed training
482
- if use_ddp and hasattr(loader, 'set_epoch'):
483
- loader.set_epoch(global_step // len(loader))
484
-
485
- for observation, actions in loader:
486
- # Check if we've reached the target number of steps
487
- if global_step >= config.num_train_steps:
488
- break
489
-
490
- # The unified data loader returns (observation, actions) tuple
491
- observation = jax.tree.map(lambda x: x.to(device), observation)
492
- actions = actions.to(torch.float32)
493
- actions = actions.to(device)
494
-
495
- # Update LR
496
- for pg in optim.param_groups:
497
- pg["lr"] = lr_schedule(global_step)
498
-
499
- # Forward pass
500
- losses = model(observation, actions)
501
- # Ensure losses is a tensor and handle different return types
502
- if isinstance(losses, (list, tuple)):
503
- losses = torch.stack(losses)
504
- elif not isinstance(losses, torch.Tensor):
505
- losses = torch.tensor(losses, device=device, dtype=torch.float32)
506
-
507
- loss = losses.mean()
508
-
509
- # Backward pass
510
- loss.backward()
511
-
512
- # Log memory usage after backward pass
513
- if global_step < 5 and is_main:
514
- if torch.cuda.is_available():
515
- log_memory_usage(device, global_step, "after_backward")
516
-
517
- # Gradient clipping
518
- grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)
519
-
520
- # Optimizer step
521
- optim.step()
522
- optim.zero_grad(set_to_none=True)
523
-
524
- # Clear gradients more aggressively
525
- for param in model.parameters():
526
- if param.grad is not None:
527
- param.grad.detach_()
528
- param.grad = None
529
-
530
-
531
- # Collect stats
532
- if is_main:
533
- infos.append({
534
- "loss": loss.item(),
535
- "learning_rate": optim.param_groups[0]['lr'],
536
- "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
537
- })
538
-
539
- if is_main and (global_step % config.log_interval == 0):
540
- elapsed = time.time() - start_time
541
-
542
- # Average stats over log interval
543
- avg_loss = sum(info["loss"] for info in infos) / len(infos)
544
- avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)
545
-
546
- avg_grad_norm = None
547
- if any('grad_norm' in info for info in infos):
548
- vals = [info['grad_norm'] for info in infos if 'grad_norm' in info and info['grad_norm'] is not None]
549
- if len(vals) > 0:
550
- avg_grad_norm = sum(vals) / len(vals)
551
- logging.info(f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s" if avg_grad_norm is not None else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s")
552
-
553
- # Log to wandb
554
- if config.wandb_enabled and len(infos) > 0:
555
- log_payload = {
556
- "loss": avg_loss,
557
- "learning_rate": avg_lr,
558
- "step": global_step,
559
- "time_per_step": elapsed / config.log_interval,
560
- }
561
- if avg_grad_norm is not None:
562
- log_payload["grad_norm"] = avg_grad_norm
563
- wandb.log(log_payload, step=global_step)
564
-
565
- start_time = time.time()
566
- infos = [] # Reset stats collection
567
-
568
- global_step += 1
569
- # Save checkpoint using the new mechanism
570
- save_checkpoint(model, optim, global_step, config, is_main)
571
-
572
- # Update progress bar
573
- if pbar is not None:
574
- pbar.update(1)
575
- pbar.set_postfix({
576
- 'loss': f'{loss.item():.4f}',
577
- 'lr': f'{optim.param_groups[0]["lr"]:.2e}',
578
- 'step': global_step
579
- })
580
-
581
- # Close progress bar
582
- if pbar is not None:
583
- pbar.close()
584
-
585
- # Finish wandb run
586
- if is_main and config.wandb_enabled:
587
- wandb.finish()
588
-
589
- cleanup_ddp()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
 
592
  def main():
593
- init_logging()
594
- config = _config.cli()
595
- train_loop(config)
596
 
597
 
598
  if __name__ == "__main__":
599
- main()
 
23
 
24
  """
25
 
 
26
  import dataclasses
27
  import gc
28
  import logging
 
30
  import platform
31
  import shutil
32
  import time
 
33
 
34
  import jax
35
  import numpy as np
36
+ import safetensors.torch
37
  import torch
38
  import torch.distributed as dist
39
  import torch.nn.parallel
 
41
  import torch.utils.data.distributed
42
  import tqdm
43
  import wandb
 
44
 
45
+ import openpi.models.pi0_config
46
+ import openpi.models_pytorch.pi0_pytorch
47
  import openpi.training.config as _config
48
  import openpi.training.data_loader as _data
 
 
 
49
 
50
 
51
def init_logging():
    """Configure the root logger with a compact, single-letter level format."""
    short_levels = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class _ShortLevelFormatter(logging.Formatter):
        # Rewrite the level name to its one-letter alias before formatting.
        def format(self, record):
            record.levelname = short_levels.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = _ShortLevelFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    if root.handlers:
        # Reuse the existing handler so repeated calls don't duplicate output.
        root.handlers[0].setFormatter(formatter)
    else:
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        root.addHandler(handler)
71
 
72
 
73
def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Initialize wandb logging."""
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    wandb_id_file = ckpt_dir / "wandb_id.txt"
    if resuming:
        # Reattach to the run that originally wrote this checkpoint directory.
        wandb.init(id=wandb_id_file.read_text().strip(), resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        # Persist the run id so a later resume can reattach to the same run.
        wandb_id_file.write_text(wandb.run.id)
93
 
94
 
95
def setup_ddp():
    """Initialize torch.distributed when launched with multiple processes.

    Returns:
        Tuple of (use_ddp, local_rank, device); use_ddp is True iff WORLD_SIZE > 1.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    use_ddp = world_size > 1
    if use_ddp and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend="nccl" if torch.cuda.is_available() else "gloo",
            init_method="env://",
        )
        # Surface DDP diagnostics unless the caller already configured them.
        os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "INFO")

    # torchrun sets LOCAL_RANK; fall back to RANK, then to rank 0.
    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    return use_ddp, local_rank, device
111
 
112
 
113
def cleanup_ddp():
    """Tear down the distributed process group; no-op outside of DDP."""
    if not torch.distributed.is_initialized():
        return
    # Synchronize all ranks before destroying the group.
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
117
 
118
 
119
def set_seed(seed: int, local_rank: int):
    """Seed the torch and numpy RNGs, offset by rank so DDP workers differ."""
    effective_seed = seed + local_rank
    torch.manual_seed(effective_seed)
    np.random.seed(effective_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(effective_seed)
124
 
125
 
126
def build_datasets(config: _config.TrainConfig):
    """Create the PyTorch training data loader for this training config.

    Returns:
        Tuple of (data_loader, data_config) — the loader plus its resolved data config.
    """
    # Use the unified data loader with PyTorch framework
    data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    return data_loader, data_loader.data_config()
130
 
131
 
132
def get_model_state_dict(model):
    """Get state dict from model, handling DDP wrapper."""
    # Unwrap DDP so checkpoint keys are not prefixed with "module.".
    target = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    return target.state_dict()
139
 
140
 
141
def get_model_parameters(model):
    """Get parameters from model, handling DDP wrapper."""
    # Unwrap DDP so callers iterate the underlying module's parameters.
    target = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    return target.parameters()
148
 
149
 
150
def save_checkpoint(model, optimizer, global_step, config, is_main):
    """Save a checkpoint with model state, optimizer state, and metadata.

    Only the main process writes; saving happens at every `config.save_interval`
    steps (skipping step 0) and at the final training step. The checkpoint is
    written to a temp directory first and then renamed into place so a partial
    write never appears under the final step directory.
    """
    if not is_main:
        return

    # Only save if it's time to save or if it's the final step
    if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
        # Create temporary directory for atomic checkpoint saving
        # NOTE(review): rename() is only atomic when tmp and final live on the
        # same filesystem — true here since both are under checkpoint_dir.
        final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
        tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"

        # Remove any existing temp directory and create new one
        if tmp_ckpt_dir.exists():
            shutil.rmtree(tmp_ckpt_dir)
        tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Save model state using safetensors (handle shared tensors)
        model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "pytorch_model.safetensors")

        # Save optimizer state using PyTorch format
        torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")

        # Save training metadata.
        # NOTE(review): despite the original "avoid saving full config" intent,
        # this still embeds dataclasses.asdict(config) — confirm asdict() output
        # is free of JAX/Flax objects that torch.save could fail to pickle.
        metadata = {
            "global_step": global_step,
            "config": dataclasses.asdict(config),
            "timestamp": time.time(),
        }
        torch.save(metadata, tmp_ckpt_dir / "metadata.pt")

        # Atomically move temp directory to final location
        if final_ckpt_dir.exists():
            shutil.rmtree(final_ckpt_dir)
        tmp_ckpt_dir.rename(final_ckpt_dir)

        logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")

        # Log checkpoint to wandb
        if config.wandb_enabled:
            wandb.log({"checkpoint_step": global_step}, step=global_step)
191
 
192
 
193
  def load_checkpoint(model, optimizer, checkpoint_dir, device):
194
  """Load the latest checkpoint and return the global step."""
195
+ checkpoint_steps = [
196
+ int(d.name)
197
+ for d in checkpoint_dir.iterdir()
198
+ if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
199
+ ]
200
+
201
  if not checkpoint_steps:
202
  raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
203
+
204
  latest_step = max(checkpoint_steps)
205
  ckpt_dir = checkpoint_dir / f"{latest_step}"
206
+
207
  # Clear memory before loading checkpoints
208
  if torch.cuda.is_available():
209
  torch.cuda.empty_cache()
 
214
  # Load model state with error handling
215
  logging.info("Loading model state...")
216
  safetensors_path = ckpt_dir / "pytorch_model.safetensors"
217
+
218
  if safetensors_path.exists():
219
  model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
220
  safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
221
  logging.info("Loaded model state from safetensors format")
222
  else:
223
  raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
224
+
225
  torch.cuda.empty_cache()
226
  gc.collect()
227
  log_memory_usage(device, latest_step, "after_loading_model")
228
+
229
  # Load optimizer state with error handling
230
  logging.info("Loading optimizer state...")
231
  optimizer_path = ckpt_dir / "optimizer.pt"
232
+
233
  if optimizer_path.exists():
234
  optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
235
  logging.info("Loaded optimizer state from pt format")
236
  else:
237
  raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
238
+
239
  optimizer.load_state_dict(optimizer_state_dict)
240
  del optimizer_state_dict
241
  torch.cuda.empty_cache()
242
  gc.collect()
243
  log_memory_usage(device, latest_step, "after_loading_optimizer")
244
+
 
245
  # Load metadata
246
  logging.info("Loading metadata...")
247
  metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
 
250
  torch.cuda.empty_cache()
251
  gc.collect()
252
  log_memory_usage(device, latest_step, "after_loading_metadata")
253
+
254
  logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
255
  return global_step
256
+
257
  except RuntimeError as e:
258
  if "out of memory" in str(e):
259
  # Clear memory and provide detailed error message
260
  torch.cuda.empty_cache()
261
  gc.collect()
262
+ logging.error(f"Out of memory error while loading checkpoint: {e!s}")
263
  log_memory_usage(device, latest_step, "after_oom_error")
264
+ raise RuntimeError(
265
+ "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
266
+ ) from e
267
  raise
268
 
269
 
270
def get_latest_checkpoint_step(checkpoint_dir):
    """Return the highest numeric step-directory name under *checkpoint_dir*, or None if there is none."""
    steps = (
        int(entry.name)
        for entry in checkpoint_dir.iterdir()
        if entry.is_dir() and entry.name.isdigit() and not entry.name.startswith("tmp_")
    )
    return max(steps, default=None)
278
 
279
 
280
def log_memory_usage(device, step, phase="unknown"):
    """Log detailed memory usage information; no-op when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1e9  # reported in decimal GB
    allocated = torch.cuda.memory_allocated(device) / bytes_per_gb
    reserved = torch.cuda.memory_reserved(device) / bytes_per_gb
    # "free" here means reserved-but-unallocated within the caching allocator.
    free = (torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / bytes_per_gb

    # Peak statistics from the caching allocator.
    stats = torch.cuda.memory_stats(device)
    peak_allocated = stats.get("allocated_bytes.all.peak", 0) / bytes_per_gb
    peak_reserved = stats.get("reserved_bytes.all.peak", 0) / bytes_per_gb

    # Append rank/world-size only when running under torch.distributed.
    ddp_info = ""
    if dist.is_initialized():
        ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"

    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {allocated:.2f}GB, reserved: {reserved:.2f}GB, free: {free:.2f}GB, peak_allocated: {peak_allocated:.2f}GB, peak_reserved: {peak_reserved:.2f}GB{ddp_info}"
    )
303
 
304
 
305
  def train_loop(config: _config.TrainConfig):
306
+ use_ddp, local_rank, device = setup_ddp()
307
+ is_main = (not use_ddp) or (dist.get_rank() == 0)
308
+ set_seed(config.seed, local_rank)
309
+
310
+ # Initialize checkpoint directory and wandb
311
+ resuming = False
312
+ if config.resume:
313
+ # Find checkpoint directory based on experiment name
314
+ exp_checkpoint_dir = config.checkpoint_dir
315
+ if exp_checkpoint_dir.exists():
316
+ # Use validation to find the latest working checkpoint
317
+ latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
318
+ if latest_step is not None:
319
+ resuming = True
320
+ logging.info(
321
+ f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
322
+ )
323
+ else:
324
+ raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
325
+ else:
326
+ raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
327
+ elif config.overwrite and config.checkpoint_dir.exists():
328
+ shutil.rmtree(config.checkpoint_dir)
329
+ logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")
330
+
331
+ # Create checkpoint directory with experiment name
332
+ if not resuming:
333
+ # For new runs, create experiment-specific checkpoint directory
334
+ exp_checkpoint_dir = config.checkpoint_dir
335
+ exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
336
+ logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
337
+ else:
338
+ # For resume, checkpoint_dir is already set to the experiment directory
339
+ logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")
340
+
341
+ # Initialize wandb (only on main process)
342
+ if is_main:
343
+ init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
344
+
345
+ # Build data loader using the unified data loader
346
+ # Calculate effective batch size per GPU for DDP
347
+ # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size
348
+ world_size = torch.distributed.get_world_size() if use_ddp else 1
349
+ effective_batch_size = config.batch_size // world_size
350
+ logging.info(
351
+ f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
352
+ )
353
+
354
+ # Pass the original batch size to data loader - it will handle DDP splitting internally
355
+ loader, _ = build_datasets(config)
356
+
357
+ # Log sample images to wandb on first batch
358
+ if is_main and config.wandb_enabled and not resuming:
359
+ # Create a separate data loader for sample batch to avoid consuming the main loader
360
+ sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
361
+ sample_batch = next(iter(sample_data_loader))
362
+ # Convert observation and actions to torch tensors
363
+ observation, actions = sample_batch
364
+ sample_batch = observation.to_dict()
365
+ sample_batch["actions"] = actions
366
+
367
+ # Create sample images for wandb
368
+ images_to_log = []
369
+ # Get batch size from the first image tensor
370
+ batch_size = next(iter(sample_batch["image"].values())).shape[0]
371
+ for i in range(min(5, batch_size)):
372
+ # Concatenate all camera views horizontally for this batch item
373
+ # Convert from NCHW to NHWC format for wandb
374
+ img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
375
+ img_concatenated = img_concatenated.cpu().numpy()
376
+ images_to_log.append(wandb.Image(img_concatenated))
377
+
378
+ wandb.log({"camera_views": images_to_log}, step=0)
379
+
380
+ # Clear sample batch from memory aggressively
381
+ del sample_batch, observation, actions, images_to_log, img_concatenated
382
+ del sample_data_loader # Also delete the sample data loader
383
+ gc.collect()
384
+ if torch.cuda.is_available():
385
+ torch.cuda.empty_cache()
386
+ logging.info("Cleared sample batch and data loader from memory")
387
+
388
+ # Build model
389
+ if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
390
+ # Convert dataclass to Pi0Config if needed
391
+ model_cfg = openpi.models.pi0_config.Pi0Config(
392
+ dtype=config.pytorch_training_precision,
393
+ action_dim=config.model.action_dim,
394
+ action_horizon=config.model.action_horizon,
395
+ max_token_len=config.model.max_token_len,
396
+ paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
397
+ action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
398
+ pi05=getattr(config.model, "pi05", False),
399
+ )
400
+ else:
401
+ model_cfg = config.model
402
+ # Update dtype to match pytorch_training_precision
403
+ object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)
404
+
405
+ model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)
406
+
407
+ if hasattr(model, "gradient_checkpointing_enable"):
408
+ enable_gradient_checkpointing = True
409
+ model.gradient_checkpointing_enable()
410
+ logging.info("Enabled gradient checkpointing for memory optimization")
411
+ else:
412
+ enable_gradient_checkpointing = False
413
+ logging.info("Gradient checkpointing is not supported for this model")
414
+
415
+ # Log initial memory usage after model creation
416
+ if is_main and torch.cuda.is_available():
417
+ log_memory_usage(device, 0, "after_model_creation")
418
+
419
+ # Enable memory optimizations for large-scale training
420
+ if world_size >= 8:
421
+ torch.backends.cudnn.benchmark = True
422
+ torch.backends.cuda.matmul.allow_tf32 = True
423
+ torch.backends.cudnn.allow_tf32 = True
424
+ # Set memory allocation configuration
425
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
426
+ logging.info("Enabled memory optimizations for 8+ GPU training")
427
+
428
+ if use_ddp:
429
+ model = torch.nn.parallel.DistributedDataParallel(
430
+ model,
431
+ device_ids=[device.index] if device.type == "cuda" else None,
432
+ find_unused_parameters=True, # Disable for memory efficiency
433
+ gradient_as_bucket_view=True, # Enable for memory efficiency
434
+ static_graph=world_size >= 8, # Enable for 8+ GPUs
435
+ )
436
+
437
+ # Load weights from weight_loader if specified (for fine-tuning)
438
+ if config.pytorch_weight_path is not None:
439
+ logging.info(f"Loading weights from: {config.pytorch_weight_path}")
440
+
441
+ model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
442
+ safetensors.torch.load_model(
443
+ (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
444
+ )
445
+ logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")
446
+
447
+ # Optimizer + learning rate schedule from config
448
+ warmup_steps = config.lr_schedule.warmup_steps
449
+ peak_lr = config.lr_schedule.peak_lr
450
+ decay_steps = config.lr_schedule.decay_steps
451
+ end_lr = config.lr_schedule.decay_lr
452
+
453
+ # Create optimizer with config parameters
454
+ optim = torch.optim.AdamW(
455
+ model.parameters(),
456
+ lr=peak_lr,
457
+ betas=(config.optimizer.b1, config.optimizer.b2),
458
+ eps=config.optimizer.eps,
459
+ weight_decay=config.optimizer.weight_decay,
460
+ )
461
+
462
+ # Load checkpoint if resuming
463
+ global_step = 0
464
+ if resuming:
465
+ global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
466
+ logging.info(f"Resumed training from step {global_step}")
467
+
468
+ def lr_schedule(step: int):
469
+ if step < warmup_steps:
470
+ # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
471
+ init_lr = peak_lr / (warmup_steps + 1)
472
+ return init_lr + (peak_lr - init_lr) * step / warmup_steps
473
+ # cosine decay
474
+ progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
475
+ cos = 0.5 * (1 + np.cos(np.pi * progress))
476
+ return end_lr + (peak_lr - end_lr) * cos
477
+
478
+ model.train()
479
+ start_time = time.time()
480
+ infos = [] # Collect stats over log interval
481
+ if is_main:
482
+ logging.info(
483
+ f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
484
+ )
485
+ logging.info(
486
+ f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
487
+ )
488
+ logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
489
+ logging.info(
490
+ f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
491
+ )
492
+ logging.info(
493
+ f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
494
+ )
495
+ logging.info("EMA is not supported for PyTorch training")
496
+ logging.info(f"Training precision: {model_cfg.dtype}")
497
+
498
+ # Training loop - iterate until we reach num_train_steps
499
+ pbar = (
500
+ tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
501
+ if is_main
502
+ else None
503
+ )
504
+
505
+ while global_step < config.num_train_steps:
506
+ # Set epoch for distributed training
507
+ if use_ddp and hasattr(loader, "set_epoch"):
508
+ loader.set_epoch(global_step // len(loader))
509
+
510
+ for observation, actions in loader:
511
+ # Check if we've reached the target number of steps
512
+ if global_step >= config.num_train_steps:
513
+ break
514
+
515
+ # The unified data loader returns (observation, actions) tuple
516
+ observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901
517
+ actions = actions.to(torch.float32) # noqa: PLW2901
518
+ actions = actions.to(device) # noqa: PLW2901
519
+
520
+ # Update LR
521
+ for pg in optim.param_groups:
522
+ pg["lr"] = lr_schedule(global_step)
523
+
524
+ # Forward pass
525
+ losses = model(observation, actions)
526
+ # Ensure losses is a tensor and handle different return types
527
+ if isinstance(losses, list | tuple):
528
+ losses = torch.stack(losses)
529
+ elif not isinstance(losses, torch.Tensor):
530
+ losses = torch.tensor(losses, device=device, dtype=torch.float32)
531
+
532
+ loss = losses.mean()
533
+
534
+ # Backward pass
535
+ loss.backward()
536
+
537
+ # Log memory usage after backward pass
538
+ if global_step < 5 and is_main and torch.cuda.is_available():
539
+ log_memory_usage(device, global_step, "after_backward")
540
+
541
+ # Gradient clipping
542
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)
543
+
544
+ # Optimizer step
545
+ optim.step()
546
+ optim.zero_grad(set_to_none=True)
547
+
548
+ # Clear gradients more aggressively
549
+ for param in model.parameters():
550
+ if param.grad is not None:
551
+ param.grad.detach_()
552
+ param.grad = None
553
+
554
+ # Collect stats
555
+ if is_main:
556
+ infos.append(
557
+ {
558
+ "loss": loss.item(),
559
+ "learning_rate": optim.param_groups[0]["lr"],
560
+ "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
561
+ }
562
+ )
563
+
564
+ if is_main and (global_step % config.log_interval == 0):
565
+ elapsed = time.time() - start_time
566
+
567
+ # Average stats over log interval
568
+ avg_loss = sum(info["loss"] for info in infos) / len(infos)
569
+ avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)
570
+
571
+ avg_grad_norm = None
572
+ if any("grad_norm" in info for info in infos):
573
+ vals = [
574
+ info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
575
+ ]
576
+ if len(vals) > 0:
577
+ avg_grad_norm = sum(vals) / len(vals)
578
+ logging.info(
579
+ f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
580
+ if avg_grad_norm is not None
581
+ else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
582
+ )
583
+
584
+ # Log to wandb
585
+ if config.wandb_enabled and len(infos) > 0:
586
+ log_payload = {
587
+ "loss": avg_loss,
588
+ "learning_rate": avg_lr,
589
+ "step": global_step,
590
+ "time_per_step": elapsed / config.log_interval,
591
+ }
592
+ if avg_grad_norm is not None:
593
+ log_payload["grad_norm"] = avg_grad_norm
594
+ wandb.log(log_payload, step=global_step)
595
+
596
+ start_time = time.time()
597
+ infos = [] # Reset stats collection
598
+
599
+ global_step += 1
600
+ # Save checkpoint using the new mechanism
601
+ save_checkpoint(model, optim, global_step, config, is_main)
602
+
603
+ # Update progress bar
604
+ if pbar is not None:
605
+ pbar.update(1)
606
+ pbar.set_postfix(
607
+ {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
608
+ )
609
+
610
+ # Close progress bar
611
+ if pbar is not None:
612
+ pbar.close()
613
+
614
+ # Finish wandb run
615
+ if is_main and config.wandb_enabled:
616
+ wandb.finish()
617
+
618
+ cleanup_ddp()
619
 
620
 
621
  def main():
622
+ init_logging()
623
+ config = _config.cli()
624
+ train_loop(config)
625
 
626
 
627
  if __name__ == "__main__":
628
+ main()
src/openpi/models/model.py CHANGED
@@ -4,7 +4,7 @@ import dataclasses
4
  import enum
5
  import logging
6
  import pathlib
7
- from typing import Generic, TypeVar, Union
8
 
9
  import augmax
10
  from flax import nnx
@@ -12,7 +12,6 @@ from flax import struct
12
  from flax import traverse_util
13
  import jax
14
  import jax.numpy as jnp
15
- import logging
16
  import numpy as np
17
  import orbax.checkpoint as ocp
18
  import safetensors
@@ -25,7 +24,7 @@ import openpi.shared.array_typing as at
25
  logger = logging.getLogger("openpi")
26
 
27
  # Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
28
- ArrayT = TypeVar("ArrayT", bound=Union[jax.Array, torch.Tensor, np.ndarray])
29
 
30
 
31
  class ModelType(enum.Enum):
@@ -117,7 +116,7 @@ class Observation(Generic[ArrayT]):
117
  for key in data["image"]:
118
  if data["image"][key].dtype == np.uint8:
119
  data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
120
- elif hasattr(data["image"][key], 'dtype') and data["image"][key].dtype == torch.uint8:
121
  data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
122
  return cls(
123
  images=data["image"],
 
4
  import enum
5
  import logging
6
  import pathlib
7
+ from typing import Generic, TypeVar
8
 
9
  import augmax
10
  from flax import nnx
 
12
  from flax import traverse_util
13
  import jax
14
  import jax.numpy as jnp
 
15
  import numpy as np
16
  import orbax.checkpoint as ocp
17
  import safetensors
 
24
  logger = logging.getLogger("openpi")
25
 
26
  # Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
27
+ ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray)
28
 
29
 
30
  class ModelType(enum.Enum):
 
116
  for key in data["image"]:
117
  if data["image"][key].dtype == np.uint8:
118
  data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
119
+ elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8:
120
  data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
121
  return cls(
122
  images=data["image"],
src/openpi/models/pi0_config.py CHANGED
@@ -48,6 +48,7 @@ class Pi0Config(_model.BaseModelConfig):
48
  @override
49
  def create(self, rng: at.KeyArrayLike) -> "Pi0":
50
  from openpi.models.pi0 import Pi0
 
51
  return Pi0(self, rngs=nnx.Rngs(rng))
52
 
53
  @override
@@ -104,4 +105,4 @@ class Pi0Config(_model.BaseModelConfig):
104
  )
105
  if not filters:
106
  return nnx.Nothing
107
- return nnx.All(*filters)
 
48
  @override
49
  def create(self, rng: at.KeyArrayLike) -> "Pi0":
50
  from openpi.models.pi0 import Pi0
51
+
52
  return Pi0(self, rngs=nnx.Rngs(rng))
53
 
54
  @override
 
105
  )
106
  if not filters:
107
  return nnx.Nothing
108
+ return nnx.All(*filters)
src/openpi/models/tokenizer.py CHANGED
@@ -254,7 +254,7 @@ class FSQTokenizer:
254
  assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
255
  # Download tokenizer
256
  path = download.maybe_download(fsq_tokenizer_path)
257
- tok_path = os.path.join(path, os.listdir(path)[0]) # noqa: PTH118
258
 
259
  # Split step from path
260
  step = int(tok_path.split("/")[-1])
 
254
  assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
255
  # Download tokenizer
256
  path = download.maybe_download(fsq_tokenizer_path)
257
+ tok_path = os.path.join(path, os.listdir(path)[0])
258
 
259
  # Split step from path
260
  step = int(tok_path.split("/")[-1])
src/openpi/models_pytorch/gemma_pytorch.py CHANGED
@@ -1,19 +1,28 @@
1
- from pytest import Cache
 
 
2
  import torch
3
  from torch import nn
4
- from transformers import GemmaForCausalLM, PaliGemmaForConditionalGeneration
5
- from transformers.models.gemma import modeling_gemma
6
-
7
  from transformers.models.auto import CONFIG_MAPPING
8
- from typing import Literal
9
 
10
 
11
  class PaliGemmaWithExpertModel(nn.Module):
12
- def __init__(self, vlm_config, action_expert_config, use_adarms=[False, False], precision: Literal["bfloat16", "float32"] = "bfloat16"):
 
 
 
 
 
 
 
 
13
  super().__init__()
14
 
15
  vlm_config_hf = CONFIG_MAPPING["paligemma"]()
16
- vlm_config_hf._vocab_size = 257152
17
  vlm_config_hf.image_token_index = 257152
18
  vlm_config_hf.text_config.hidden_size = vlm_config.width
19
  vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
@@ -53,9 +62,9 @@ class PaliGemmaWithExpertModel(nn.Module):
53
 
54
  def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
55
  if precision == "bfloat16":
56
- self = self.to(dtype=torch.bfloat16)
57
  elif precision == "float32":
58
- self = self.to(dtype=torch.float32)
59
  return
60
  else:
61
  raise ValueError(f"Invalid precision: {precision}")
@@ -83,11 +92,13 @@ class PaliGemmaWithExpertModel(nn.Module):
83
  self,
84
  attention_mask: torch.Tensor | None = None,
85
  position_ids: torch.LongTensor | None = None,
86
- past_key_values: list[torch.FloatTensor] | Cache | None = None,
87
- inputs_embeds: list[torch.FloatTensor] = None,
88
  use_cache: bool | None = None,
89
- adarms_cond: list[torch.Tensor] = [None, None],
90
  ):
 
 
91
  if inputs_embeds[1] is None:
92
  prefix_output = self.paligemma.language_model.forward(
93
  inputs_embeds=inputs_embeds[0],
@@ -115,45 +126,45 @@ class PaliGemmaWithExpertModel(nn.Module):
115
  else:
116
  models = [self.paligemma.language_model, self.gemma_expert.model]
117
  num_layers = self.paligemma.config.text_config.num_hidden_layers
118
-
119
  # Check if gradient checkpointing is enabled for any of the models
120
  use_gradient_checkpointing = (
121
- hasattr(self.gemma_expert.model, 'gradient_checkpointing') and
122
- self.gemma_expert.model.gradient_checkpointing and
123
- self.training
124
- ) or (
125
- hasattr(self, 'gradient_checkpointing') and
126
- self.gradient_checkpointing and
127
- self.training
128
- )
129
-
130
  # Force enable gradient checkpointing if we're in training mode and the model supports it
131
- if self.training and hasattr(self.gemma_expert.model, 'gradient_checkpointing'):
132
  if not self.gemma_expert.model.gradient_checkpointing:
133
  print("Forcing gradient checkpointing to be enabled for Gemma expert model")
134
  self.gemma_expert.model.gradient_checkpointing = True
135
  use_gradient_checkpointing = True
136
-
137
  # Debug gradient checkpointing status
138
- if hasattr(self, '_debug_gc_printed') and not self._debug_gc_printed:
139
  print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
140
  print(f"Model training mode: {self.training}")
141
- print(f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}")
142
- if hasattr(self.gemma_expert.model, 'gradient_checkpointing'):
143
- print(f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}")
 
 
 
 
144
  self._debug_gc_printed = True
145
-
146
  # Define the complete layer computation function for gradient checkpointing
147
  def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
148
  models = [self.paligemma.language_model, self.gemma_expert.model]
149
-
150
  query_states = []
151
  key_states = []
152
  value_states = []
153
  gates = []
154
  for i, hidden_states in enumerate(inputs_embeds):
155
  layer = models[i].layers[layer_idx]
156
- hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])
157
  gates.append(gate)
158
 
159
  input_shape = hidden_states.shape[:-1]
@@ -171,16 +182,29 @@ class PaliGemmaWithExpertModel(nn.Module):
171
  key_states = torch.cat(key_states, dim=2)
172
  value_states = torch.cat(value_states, dim=2)
173
 
174
- dummy_tensor = torch.zeros(query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype)
 
 
 
 
 
 
175
  cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
176
- query_states, key_states = modeling_gemma.apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
 
 
177
 
178
  batch_size = query_states.shape[0]
179
  scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling
180
-
181
  # Attention computation
182
  att_output, _ = modeling_gemma.eager_attention_forward(
183
- self.paligemma.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, attention_mask, scaling
 
 
 
 
 
184
  )
185
  # Get head_dim from the current layer, not from the model
186
  head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim
@@ -195,10 +219,10 @@ class PaliGemmaWithExpertModel(nn.Module):
195
 
196
  if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
197
  att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
198
- out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
199
 
200
  # first residual
201
- out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])
202
  after_first_residual = out_emb.clone()
203
  out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
204
  # Convert to bfloat16 if the next layer (mlp) uses bfloat16
@@ -207,10 +231,10 @@ class PaliGemmaWithExpertModel(nn.Module):
207
 
208
  out_emb = layer.mlp(out_emb)
209
  # second residual
210
- out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)
211
  outputs_embeds.append(out_emb)
212
  start_pos = end_pos
213
-
214
  return outputs_embeds
215
 
216
  # Process all layers with gradient checkpointing if enabled
@@ -218,12 +242,18 @@ class PaliGemmaWithExpertModel(nn.Module):
218
  if use_gradient_checkpointing:
219
  inputs_embeds = torch.utils.checkpoint.checkpoint(
220
  compute_layer_complete,
221
- layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
 
 
 
 
222
  use_reentrant=False,
223
- preserve_rng_state=False
224
  )
225
  else:
226
- inputs_embeds = compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond)
 
 
227
 
228
  # Old code removed - now using compute_layer_complete function above
229
 
@@ -235,14 +265,11 @@ class PaliGemmaWithExpertModel(nn.Module):
235
  out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
236
  outputs_embeds.append(out_emb)
237
  return outputs_embeds
238
-
239
  # Apply gradient checkpointing to final norm if enabled
240
  if use_gradient_checkpointing:
241
  outputs_embeds = torch.utils.checkpoint.checkpoint(
242
- compute_final_norms,
243
- inputs_embeds, adarms_cond,
244
- use_reentrant=False,
245
- preserve_rng_state=False
246
  )
247
  else:
248
  outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
@@ -251,4 +278,4 @@ class PaliGemmaWithExpertModel(nn.Module):
251
  suffix_output = outputs_embeds[1]
252
  prefix_past_key_values = None
253
 
254
- return [prefix_output, suffix_output], prefix_past_key_values
 
1
+ from typing import Literal
2
+
3
+ import pytest
4
  import torch
5
  from torch import nn
6
+ from transformers import GemmaForCausalLM
7
+ from transformers import PaliGemmaForConditionalGeneration
 
8
  from transformers.models.auto import CONFIG_MAPPING
9
+ from transformers.models.gemma import modeling_gemma
10
 
11
 
12
  class PaliGemmaWithExpertModel(nn.Module):
13
+ def __init__(
14
+ self,
15
+ vlm_config,
16
+ action_expert_config,
17
+ use_adarms=None,
18
+ precision: Literal["bfloat16", "float32"] = "bfloat16",
19
+ ):
20
+ if use_adarms is None:
21
+ use_adarms = [False, False]
22
  super().__init__()
23
 
24
  vlm_config_hf = CONFIG_MAPPING["paligemma"]()
25
+ vlm_config_hf._vocab_size = 257152 # noqa: SLF001
26
  vlm_config_hf.image_token_index = 257152
27
  vlm_config_hf.text_config.hidden_size = vlm_config.width
28
  vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
 
62
 
63
  def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
64
  if precision == "bfloat16":
65
+ self.to(dtype=torch.bfloat16)
66
  elif precision == "float32":
67
+ self.to(dtype=torch.float32)
68
  return
69
  else:
70
  raise ValueError(f"Invalid precision: {precision}")
 
92
  self,
93
  attention_mask: torch.Tensor | None = None,
94
  position_ids: torch.LongTensor | None = None,
95
+ past_key_values: list[torch.FloatTensor] | pytest.Cache | None = None,
96
+ inputs_embeds: list[torch.FloatTensor] | None = None,
97
  use_cache: bool | None = None,
98
+ adarms_cond: list[torch.Tensor] | None = None,
99
  ):
100
+ if adarms_cond is None:
101
+ adarms_cond = [None, None]
102
  if inputs_embeds[1] is None:
103
  prefix_output = self.paligemma.language_model.forward(
104
  inputs_embeds=inputs_embeds[0],
 
126
  else:
127
  models = [self.paligemma.language_model, self.gemma_expert.model]
128
  num_layers = self.paligemma.config.text_config.num_hidden_layers
129
+
130
  # Check if gradient checkpointing is enabled for any of the models
131
  use_gradient_checkpointing = (
132
+ hasattr(self.gemma_expert.model, "gradient_checkpointing")
133
+ and self.gemma_expert.model.gradient_checkpointing
134
+ and self.training
135
+ ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
136
+
 
 
 
 
137
  # Force enable gradient checkpointing if we're in training mode and the model supports it
138
+ if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
139
  if not self.gemma_expert.model.gradient_checkpointing:
140
  print("Forcing gradient checkpointing to be enabled for Gemma expert model")
141
  self.gemma_expert.model.gradient_checkpointing = True
142
  use_gradient_checkpointing = True
143
+
144
  # Debug gradient checkpointing status
145
+ if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
146
  print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
147
  print(f"Model training mode: {self.training}")
148
+ print(
149
+ f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
150
+ )
151
+ if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
152
+ print(
153
+ f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
154
+ )
155
  self._debug_gc_printed = True
156
+
157
  # Define the complete layer computation function for gradient checkpointing
158
  def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
159
  models = [self.paligemma.language_model, self.gemma_expert.model]
160
+
161
  query_states = []
162
  key_states = []
163
  value_states = []
164
  gates = []
165
  for i, hidden_states in enumerate(inputs_embeds):
166
  layer = models[i].layers[layer_idx]
167
+ hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
168
  gates.append(gate)
169
 
170
  input_shape = hidden_states.shape[:-1]
 
182
  key_states = torch.cat(key_states, dim=2)
183
  value_states = torch.cat(value_states, dim=2)
184
 
185
+ dummy_tensor = torch.zeros(
186
+ query_states.shape[0],
187
+ query_states.shape[2],
188
+ query_states.shape[-1],
189
+ device=query_states.device,
190
+ dtype=query_states.dtype,
191
+ )
192
  cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
193
+ query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
194
+ query_states, key_states, cos, sin, unsqueeze_dim=1
195
+ )
196
 
197
  batch_size = query_states.shape[0]
198
  scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling
199
+
200
  # Attention computation
201
  att_output, _ = modeling_gemma.eager_attention_forward(
202
+ self.paligemma.language_model.layers[layer_idx].self_attn,
203
+ query_states,
204
+ key_states,
205
+ value_states,
206
+ attention_mask,
207
+ scaling,
208
  )
209
  # Get head_dim from the current layer, not from the model
210
  head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim
 
219
 
220
  if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
221
  att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
222
+ out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
223
 
224
  # first residual
225
+ out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
226
  after_first_residual = out_emb.clone()
227
  out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
228
  # Convert to bfloat16 if the next layer (mlp) uses bfloat16
 
231
 
232
  out_emb = layer.mlp(out_emb)
233
  # second residual
234
+ out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
235
  outputs_embeds.append(out_emb)
236
  start_pos = end_pos
237
+
238
  return outputs_embeds
239
 
240
  # Process all layers with gradient checkpointing if enabled
 
242
  if use_gradient_checkpointing:
243
  inputs_embeds = torch.utils.checkpoint.checkpoint(
244
  compute_layer_complete,
245
+ layer_idx,
246
+ inputs_embeds,
247
+ attention_mask,
248
+ position_ids,
249
+ adarms_cond,
250
  use_reentrant=False,
251
+ preserve_rng_state=False,
252
  )
253
  else:
254
+ inputs_embeds = compute_layer_complete(
255
+ layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
256
+ )
257
 
258
  # Old code removed - now using compute_layer_complete function above
259
 
 
265
  out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
266
  outputs_embeds.append(out_emb)
267
  return outputs_embeds
268
+
269
  # Apply gradient checkpointing to final norm if enabled
270
  if use_gradient_checkpointing:
271
  outputs_embeds = torch.utils.checkpoint.checkpoint(
272
+ compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False
 
 
 
273
  )
274
  else:
275
  outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
 
278
  suffix_output = outputs_embeds[1]
279
  prefix_past_key_values = None
280
 
281
+ return [prefix_output, suffix_output], prefix_past_key_values
src/openpi/models_pytorch/pi0_pytorch.py CHANGED
@@ -1,10 +1,10 @@
1
- import math
2
  import logging
 
3
 
4
  import torch
5
  from torch import Tensor
6
  from torch import nn
7
- import torch.nn.functional as F
8
 
9
  import openpi.models.gemma as _gemma
10
  from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
@@ -17,7 +17,7 @@ def get_safe_dtype(target_dtype, device_type):
17
  # CPU doesn't support bfloat16, use float32 instead
18
  if target_dtype == torch.bfloat16:
19
  return torch.float32
20
- elif target_dtype == torch.float64:
21
  return torch.float64
22
  return target_dtype
23
 
@@ -39,16 +39,14 @@ def create_sinusoidal_pos_embedding(
39
  # Compute the outer product
40
  scaling_factor = 1.0 / period * 2 * math.pi
41
  sin_input = scaling_factor[None, :] * time[:, None]
42
- pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
43
- return pos_emb
44
 
45
 
46
  def sample_beta(alpha, beta, bsize, device):
47
  alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
48
  beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
49
  dist = torch.distributions.Beta(alpha_t, beta_t)
50
- samples = dist.sample((bsize,))
51
- return samples
52
 
53
 
54
  def make_att_2d_masks(pad_masks, att_masks):
@@ -80,8 +78,7 @@ def make_att_2d_masks(pad_masks, att_masks):
80
  cumsum = torch.cumsum(att_masks, dim=1)
81
  att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
82
  pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
83
- att_2d_masks = att_2d_masks & pad_2d_masks
84
- return att_2d_masks
85
 
86
 
87
  class PI0Pytorch(nn.Module):
@@ -93,7 +90,12 @@ class PI0Pytorch(nn.Module):
93
  paligemma_config = _gemma.get_config(config.paligemma_variant)
94
  action_expert_config = _gemma.get_config(config.action_expert_variant)
95
 
96
- self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_config, action_expert_config, use_adarms=[False, True] if self.pi05 else [False, False], precision=config.dtype)
 
 
 
 
 
97
 
98
  self.action_in_proj = nn.Linear(32, action_expert_config.width)
99
  self.action_out_proj = nn.Linear(action_expert_config.width, 32)
@@ -106,17 +108,20 @@ class PI0Pytorch(nn.Module):
106
  self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
107
  self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
108
 
109
- torch.set_float32_matmul_precision('high')
110
  self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
111
-
112
  # Initialize gradient checkpointing flag
113
  self.gradient_checkpointing_enabled = False
 
 
114
  try:
115
  from transformers.models.siglip import check
 
116
  if not check.check_whether_transformers_replace_is_installed_correctly():
117
- raise ValueError("TransformersReplace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`.")
118
  except ImportError:
119
- raise ValueError("TransformersReplace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`.")
120
 
121
  def gradient_checkpointing_enable(self):
122
  """Enable gradient checkpointing for memory optimization."""
@@ -124,7 +129,7 @@ class PI0Pytorch(nn.Module):
124
  self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
125
  self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
126
  self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
127
-
128
  logging.info("Enabled gradient checkpointing for PI0Pytorch model")
129
 
130
  def gradient_checkpointing_disable(self):
@@ -133,7 +138,7 @@ class PI0Pytorch(nn.Module):
133
  self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
134
  self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
135
  self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
136
-
137
  logging.info("Disabled gradient checkpointing for PI0Pytorch model")
138
 
139
  def is_gradient_checkpointing_enabled(self):
@@ -146,15 +151,14 @@ class PI0Pytorch(nn.Module):
146
  return torch.utils.checkpoint.checkpoint(
147
  func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
148
  )
149
- else:
150
- return func(*args, **kwargs)
151
 
152
  def _prepare_attention_masks_4d(self, att_2d_masks):
153
  """Helper method to prepare 4D attention masks for transformer."""
154
  att_2d_masks_4d = att_2d_masks[:, None, :, :]
155
  return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
156
 
157
- def _preprocess_observation(self, observation, train=True):
158
  """Helper method to preprocess observation."""
159
  observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
160
  return (
@@ -162,18 +166,17 @@ class PI0Pytorch(nn.Module):
162
  list(observation.image_masks.values()),
163
  observation.tokenized_prompt,
164
  observation.tokenized_prompt_mask,
165
- observation.state
166
  )
167
 
168
  def sample_noise(self, shape, device):
169
- noise = torch.normal(
170
  mean=0.0,
171
  std=1.0,
172
  size=shape,
173
  dtype=torch.float32,
174
  device=device,
175
  )
176
- return noise
177
 
178
  def sample_time(self, bsize, device):
179
  time_beta = sample_beta(1.5, 1.0, bsize, device)
@@ -189,19 +192,19 @@ class PI0Pytorch(nn.Module):
189
  embs = []
190
  pad_masks = []
191
  att_masks = []
192
-
193
  # Process images
194
- for img, img_mask in zip(images, img_masks):
 
195
  def image_embed_func(img):
196
  return self.paligemma_with_expert.embed_image(img)
197
-
198
  img_emb = self._apply_checkpoint(image_embed_func, img)
199
 
200
  bsize, num_img_embs = img_emb.shape[:2]
201
- img_mask = img_mask[:, None].expand(bsize, num_img_embs)
202
 
203
  embs.append(img_emb)
204
- pad_masks.append(img_mask)
205
 
206
  # Create attention masks so that image tokens attend to each other
207
  att_masks += [0] * num_img_embs
@@ -211,7 +214,7 @@ class PI0Pytorch(nn.Module):
211
  lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
212
  lang_emb_dim = lang_emb.shape[-1]
213
  return lang_emb * math.sqrt(lang_emb_dim)
214
-
215
  lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
216
 
217
  embs.append(lang_emb)
@@ -239,16 +242,16 @@ class PI0Pytorch(nn.Module):
239
 
240
  if not self.pi05:
241
  if self.state_proj.weight.dtype == torch.float32:
242
- state = state.to(torch.float32)
 
243
  # Embed state
244
  def state_proj_func(state):
245
  return self.state_proj(state)
246
-
247
  state_emb = self._apply_checkpoint(state_proj_func, state)
248
-
249
  embs.append(state_emb[:, None, :])
250
  bsize = state_emb.shape[0]
251
- dtype = state_emb.dtype
252
  device = state_emb.device
253
 
254
  state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
@@ -266,20 +269,19 @@ class PI0Pytorch(nn.Module):
266
  # Fuse timestep + action information using an MLP
267
  def action_proj_func(noisy_actions):
268
  return self.action_in_proj(noisy_actions)
269
-
270
  action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
271
 
272
  if not self.pi05:
273
  time_emb = time_emb[:, None, :].expand_as(action_emb)
274
  action_time_emb = torch.cat([action_emb, time_emb], dim=2)
275
-
276
  # Apply MLP layers
277
  def mlp_func(action_time_emb):
278
  x = self.action_time_mlp_in(action_time_emb)
279
  x = F.silu(x) # swish == silu
280
- x = self.action_time_mlp_out(x)
281
- return x
282
-
283
  action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
284
  adarms_cond = None
285
  else:
@@ -288,9 +290,8 @@ class PI0Pytorch(nn.Module):
288
  x = self.time_mlp_in(time_emb)
289
  x = F.silu(x) # swish == silu
290
  x = self.time_mlp_out(x)
291
- x = F.silu(x)
292
- return x
293
-
294
  time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
295
  action_time_emb = action_emb
296
  adarms_cond = time_emb
@@ -328,7 +329,10 @@ class PI0Pytorch(nn.Module):
328
 
329
  prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
330
  suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
331
- if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
 
 
 
332
  suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
333
  prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
334
 
@@ -349,25 +353,24 @@ class PI0Pytorch(nn.Module):
349
  past_key_values=None,
350
  inputs_embeds=[prefix_embs, suffix_embs],
351
  use_cache=False,
352
- adarms_cond=[None, adarms_cond]
353
  )
354
  return suffix_out
355
-
356
  suffix_out = self._apply_checkpoint(
357
  forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
358
  )
359
-
360
  suffix_out = suffix_out[:, -self.config.action_horizon :]
361
  suffix_out = suffix_out.to(dtype=torch.float32)
362
 
363
  # Apply gradient checkpointing to final action projection if enabled
364
  def action_out_proj_func(suffix_out):
365
  return self.action_out_proj(suffix_out)
366
-
367
  v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
368
 
369
- losses = F.mse_loss(u_t, v_t, reduction="none")
370
- return losses
371
 
372
  @torch.no_grad()
373
  def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
@@ -376,7 +379,7 @@ class PI0Pytorch(nn.Module):
376
  if noise is None:
377
  actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
378
  noise = self.sample_noise(actions_shape, device)
379
-
380
  images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)
381
 
382
  prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
@@ -385,7 +388,7 @@ class PI0Pytorch(nn.Module):
385
 
386
  # Compute image and language key value cache
387
  prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
388
- self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"
389
 
390
  _, past_key_values = self.paligemma_with_expert.forward(
391
  attention_mask=prefix_att_2d_masks_4d,
@@ -441,7 +444,7 @@ class PI0Pytorch(nn.Module):
441
 
442
  # Prepare attention masks
443
  full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
444
- self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"
445
 
446
  outputs_embeds, _ = self.paligemma_with_expert.forward(
447
  attention_mask=full_att_2d_masks_4d,
@@ -449,12 +452,10 @@ class PI0Pytorch(nn.Module):
449
  past_key_values=past_key_values,
450
  inputs_embeds=[None, suffix_embs],
451
  use_cache=False,
452
- adarms_cond=[None, adarms_cond]
453
  )
454
 
455
  suffix_out = outputs_embeds[1]
456
  suffix_out = suffix_out[:, -self.config.action_horizon :]
457
  suffix_out = suffix_out.to(dtype=torch.float32)
458
- v_t = self.action_out_proj(suffix_out)
459
-
460
- return v_t
 
 
1
  import logging
2
+ import math
3
 
4
  import torch
5
  from torch import Tensor
6
  from torch import nn
7
+ import torch.nn.functional as F # noqa: N812
8
 
9
  import openpi.models.gemma as _gemma
10
  from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
 
17
  # CPU doesn't support bfloat16, use float32 instead
18
  if target_dtype == torch.bfloat16:
19
  return torch.float32
20
+ if target_dtype == torch.float64:
21
  return torch.float64
22
  return target_dtype
23
 
 
39
  # Compute the outer product
40
  scaling_factor = 1.0 / period * 2 * math.pi
41
  sin_input = scaling_factor[None, :] * time[:, None]
42
+ return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
 
43
 
44
 
45
  def sample_beta(alpha, beta, bsize, device):
46
  alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
47
  beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
48
  dist = torch.distributions.Beta(alpha_t, beta_t)
49
+ return dist.sample((bsize,))
 
50
 
51
 
52
  def make_att_2d_masks(pad_masks, att_masks):
 
78
  cumsum = torch.cumsum(att_masks, dim=1)
79
  att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
80
  pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
81
+ return att_2d_masks & pad_2d_masks
 
82
 
83
 
84
  class PI0Pytorch(nn.Module):
 
90
  paligemma_config = _gemma.get_config(config.paligemma_variant)
91
  action_expert_config = _gemma.get_config(config.action_expert_variant)
92
 
93
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(
94
+ paligemma_config,
95
+ action_expert_config,
96
+ use_adarms=[False, True] if self.pi05 else [False, False],
97
+ precision=config.dtype,
98
+ )
99
 
100
  self.action_in_proj = nn.Linear(32, action_expert_config.width)
101
  self.action_out_proj = nn.Linear(action_expert_config.width, 32)
 
108
  self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
109
  self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
110
 
111
+ torch.set_float32_matmul_precision("high")
112
  self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
113
+
114
  # Initialize gradient checkpointing flag
115
  self.gradient_checkpointing_enabled = False
116
+
117
+ msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`."
118
  try:
119
  from transformers.models.siglip import check
120
+
121
  if not check.check_whether_transformers_replace_is_installed_correctly():
122
+ raise ValueError(msg)
123
  except ImportError:
124
+ raise ValueError(msg) from None
125
 
126
  def gradient_checkpointing_enable(self):
127
  """Enable gradient checkpointing for memory optimization."""
 
129
  self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
130
  self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
131
  self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
132
+
133
  logging.info("Enabled gradient checkpointing for PI0Pytorch model")
134
 
135
  def gradient_checkpointing_disable(self):
 
138
  self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
139
  self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
140
  self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
141
+
142
  logging.info("Disabled gradient checkpointing for PI0Pytorch model")
143
 
144
  def is_gradient_checkpointing_enabled(self):
 
151
  return torch.utils.checkpoint.checkpoint(
152
  func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
153
  )
154
+ return func(*args, **kwargs)
 
155
 
156
  def _prepare_attention_masks_4d(self, att_2d_masks):
157
  """Helper method to prepare 4D attention masks for transformer."""
158
  att_2d_masks_4d = att_2d_masks[:, None, :, :]
159
  return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
160
 
161
+ def _preprocess_observation(self, observation, *, train=True):
162
  """Helper method to preprocess observation."""
163
  observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
164
  return (
 
166
  list(observation.image_masks.values()),
167
  observation.tokenized_prompt,
168
  observation.tokenized_prompt_mask,
169
+ observation.state,
170
  )
171
 
172
  def sample_noise(self, shape, device):
173
+ return torch.normal(
174
  mean=0.0,
175
  std=1.0,
176
  size=shape,
177
  dtype=torch.float32,
178
  device=device,
179
  )
 
180
 
181
  def sample_time(self, bsize, device):
182
  time_beta = sample_beta(1.5, 1.0, bsize, device)
 
192
  embs = []
193
  pad_masks = []
194
  att_masks = []
195
+
196
  # Process images
197
+ for img, img_mask in zip(images, img_masks, strict=True):
198
+
199
  def image_embed_func(img):
200
  return self.paligemma_with_expert.embed_image(img)
201
+
202
  img_emb = self._apply_checkpoint(image_embed_func, img)
203
 
204
  bsize, num_img_embs = img_emb.shape[:2]
 
205
 
206
  embs.append(img_emb)
207
+ pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
208
 
209
  # Create attention masks so that image tokens attend to each other
210
  att_masks += [0] * num_img_embs
 
214
  lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
215
  lang_emb_dim = lang_emb.shape[-1]
216
  return lang_emb * math.sqrt(lang_emb_dim)
217
+
218
  lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
219
 
220
  embs.append(lang_emb)
 
242
 
243
  if not self.pi05:
244
  if self.state_proj.weight.dtype == torch.float32:
245
+ state = state.to(torch.float32)
246
+
247
  # Embed state
248
  def state_proj_func(state):
249
  return self.state_proj(state)
250
+
251
  state_emb = self._apply_checkpoint(state_proj_func, state)
252
+
253
  embs.append(state_emb[:, None, :])
254
  bsize = state_emb.shape[0]
 
255
  device = state_emb.device
256
 
257
  state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
 
269
  # Fuse timestep + action information using an MLP
270
  def action_proj_func(noisy_actions):
271
  return self.action_in_proj(noisy_actions)
272
+
273
  action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
274
 
275
  if not self.pi05:
276
  time_emb = time_emb[:, None, :].expand_as(action_emb)
277
  action_time_emb = torch.cat([action_emb, time_emb], dim=2)
278
+
279
  # Apply MLP layers
280
  def mlp_func(action_time_emb):
281
  x = self.action_time_mlp_in(action_time_emb)
282
  x = F.silu(x) # swish == silu
283
+ return self.action_time_mlp_out(x)
284
+
 
285
  action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
286
  adarms_cond = None
287
  else:
 
290
  x = self.time_mlp_in(time_emb)
291
  x = F.silu(x) # swish == silu
292
  x = self.time_mlp_out(x)
293
+ return F.silu(x)
294
+
 
295
  time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
296
  action_time_emb = action_emb
297
  adarms_cond = time_emb
 
329
 
330
  prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
331
  suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
332
+ if (
333
+ self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
334
+ == torch.bfloat16
335
+ ):
336
  suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
337
  prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
338
 
 
353
  past_key_values=None,
354
  inputs_embeds=[prefix_embs, suffix_embs],
355
  use_cache=False,
356
+ adarms_cond=[None, adarms_cond],
357
  )
358
  return suffix_out
359
+
360
  suffix_out = self._apply_checkpoint(
361
  forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
362
  )
363
+
364
  suffix_out = suffix_out[:, -self.config.action_horizon :]
365
  suffix_out = suffix_out.to(dtype=torch.float32)
366
 
367
  # Apply gradient checkpointing to final action projection if enabled
368
  def action_out_proj_func(suffix_out):
369
  return self.action_out_proj(suffix_out)
370
+
371
  v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
372
 
373
+ return F.mse_loss(u_t, v_t, reduction="none")
 
374
 
375
  @torch.no_grad()
376
  def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
 
379
  if noise is None:
380
  actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
381
  noise = self.sample_noise(actions_shape, device)
382
+
383
  images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)
384
 
385
  prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
 
388
 
389
  # Compute image and language key value cache
390
  prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
391
+ self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
392
 
393
  _, past_key_values = self.paligemma_with_expert.forward(
394
  attention_mask=prefix_att_2d_masks_4d,
 
444
 
445
  # Prepare attention masks
446
  full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
447
+ self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
448
 
449
  outputs_embeds, _ = self.paligemma_with_expert.forward(
450
  attention_mask=full_att_2d_masks_4d,
 
452
  past_key_values=past_key_values,
453
  inputs_embeds=[None, suffix_embs],
454
  use_cache=False,
455
+ adarms_cond=[None, adarms_cond],
456
  )
457
 
458
  suffix_out = outputs_embeds[1]
459
  suffix_out = suffix_out[:, -self.config.action_horizon :]
460
  suffix_out = suffix_out.to(dtype=torch.float32)
461
+ return self.action_out_proj(suffix_out)
 
 
src/openpi/models_pytorch/preprocessing_pytorch.py CHANGED
@@ -1,5 +1,6 @@
1
- import logging
2
  from collections.abc import Sequence
 
 
3
  import torch
4
 
5
  from openpi.shared import image_tools
@@ -15,6 +16,7 @@ IMAGE_KEYS = (
15
 
16
  IMAGE_RESOLUTION = (224, 224)
17
 
 
18
  def preprocess_observation_pytorch(
19
  observation,
20
  *,
@@ -23,7 +25,7 @@ def preprocess_observation_pytorch(
23
  image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
24
  ):
25
  """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
26
-
27
  This function avoids complex type annotations that can cause torch.compile issues.
28
  """
29
  if not set(image_keys).issubset(observation.images):
@@ -67,14 +69,14 @@ def preprocess_observation_pytorch(
67
  # Use tensor operations instead of .item() for torch.compile compatibility
68
  start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
69
  start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
70
- image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :]
71
 
72
  # Resize back to original size
73
  image = torch.nn.functional.interpolate(
74
  image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
75
  size=(height, width),
76
- mode='bilinear',
77
- align_corners=False
78
  ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
79
 
80
  # Random rotation (small angles)
@@ -93,7 +95,7 @@ def preprocess_observation_pytorch(
93
  grid_y = torch.linspace(-1, 1, height, device=image.device)
94
 
95
  # Create meshgrid
96
- grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij')
97
 
98
  # Expand to batch dimension
99
  grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
@@ -109,9 +111,9 @@ def preprocess_observation_pytorch(
109
  image = torch.nn.functional.grid_sample(
110
  image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
111
  grid,
112
- mode='bilinear',
113
- padding_mode='zeros',
114
- align_corners=False
115
  ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
116
 
117
  # Color augmentations for all cameras
@@ -159,7 +161,7 @@ def preprocess_observation_pytorch(
159
  def __init__(self, **kwargs):
160
  for key, value in kwargs.items():
161
  setattr(self, key, value)
162
-
163
  return SimpleProcessedObservation(
164
  images=out_images,
165
  image_masks=out_masks,
 
 
1
  from collections.abc import Sequence
2
+ import logging
3
+
4
  import torch
5
 
6
  from openpi.shared import image_tools
 
16
 
17
  IMAGE_RESOLUTION = (224, 224)
18
 
19
+
20
  def preprocess_observation_pytorch(
21
  observation,
22
  *,
 
25
  image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
26
  ):
27
  """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
28
+
29
  This function avoids complex type annotations that can cause torch.compile issues.
30
  """
31
  if not set(image_keys).issubset(observation.images):
 
69
  # Use tensor operations instead of .item() for torch.compile compatibility
70
  start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
71
  start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
72
+ image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
73
 
74
  # Resize back to original size
75
  image = torch.nn.functional.interpolate(
76
  image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
77
  size=(height, width),
78
+ mode="bilinear",
79
+ align_corners=False,
80
  ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
81
 
82
  # Random rotation (small angles)
 
95
  grid_y = torch.linspace(-1, 1, height, device=image.device)
96
 
97
  # Create meshgrid
98
+ grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
99
 
100
  # Expand to batch dimension
101
  grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
 
111
  image = torch.nn.functional.grid_sample(
112
  image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
113
  grid,
114
+ mode="bilinear",
115
+ padding_mode="zeros",
116
+ align_corners=False,
117
  ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
118
 
119
  # Color augmentations for all cameras
 
161
  def __init__(self, **kwargs):
162
  for key, value in kwargs.items():
163
  setattr(self, key, value)
164
+
165
  return SimpleProcessedObservation(
166
  images=out_images,
167
  image_masks=out_masks,
src/openpi/policies/policy.py CHANGED
@@ -35,7 +35,7 @@ class Policy(BasePolicy):
35
  is_pytorch: bool = False,
36
  ):
37
  """Initialize the Policy.
38
-
39
  Args:
40
  model: The model to use for action sampling.
41
  rng: Random number generator key for JAX models. Ignored for PyTorch models.
@@ -43,7 +43,7 @@ class Policy(BasePolicy):
43
  output_transforms: Output data transformations to apply after inference.
44
  sample_kwargs: Additional keyword arguments to pass to model.sample_actions.
45
  metadata: Additional metadata to store with the policy.
46
- pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0").
47
  Only relevant when is_pytorch=True.
48
  is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.
49
  """
@@ -81,10 +81,7 @@ class Policy(BasePolicy):
81
  # Prepare kwargs for sample_actions
82
  sample_kwargs = dict(self._sample_kwargs)
83
  if noise is not None:
84
- if self._is_pytorch_model:
85
- noise = torch.from_numpy(noise).to(self._pytorch_device)
86
- else:
87
- noise = jnp.asarray(noise)
88
 
89
  if noise.ndim == 2: # If noise is (action_horizon, action_dim), add batch dimension
90
  noise = noise[None, ...] # Make it (1, action_horizon, action_dim)
 
35
  is_pytorch: bool = False,
36
  ):
37
  """Initialize the Policy.
38
+
39
  Args:
40
  model: The model to use for action sampling.
41
  rng: Random number generator key for JAX models. Ignored for PyTorch models.
 
43
  output_transforms: Output data transformations to apply after inference.
44
  sample_kwargs: Additional keyword arguments to pass to model.sample_actions.
45
  metadata: Additional metadata to store with the policy.
46
+ pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0").
47
  Only relevant when is_pytorch=True.
48
  is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.
49
  """
 
81
  # Prepare kwargs for sample_actions
82
  sample_kwargs = dict(self._sample_kwargs)
83
  if noise is not None:
84
+ noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise)
 
 
 
85
 
86
  if noise.ndim == 2: # If noise is (action_horizon, action_dim), add batch dimension
87
  noise = noise[None, ...] # Make it (1, action_horizon, action_dim)
src/openpi/policies/policy_config.py CHANGED
@@ -1,6 +1,6 @@
1
  import logging
2
- import pathlib
3
  import os
 
4
  from typing import Any
5
 
6
  import jax.numpy as jnp
@@ -35,9 +35,9 @@ def create_trained_policy(
35
  data if it doesn't already exist.
36
  norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
37
  from the checkpoint directory.
38
- pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
39
  If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu".
40
-
41
  Note:
42
  The function automatically detects whether the model is PyTorch-based by checking for the
43
  presence of "model.safensors" in the checkpoint directory.
@@ -52,7 +52,7 @@ def create_trained_policy(
52
  logging.info("Loading model...")
53
  if is_pytorch:
54
  model = train_config.model.load_pytorch(train_config, weight_path)
55
- model.paligemma_with_expert.to_bfloat16_for_selected_params('bfloat16')
56
  else:
57
  model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
58
  data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
@@ -67,13 +67,11 @@ def create_trained_policy(
67
  if is_pytorch and pytorch_device is None:
68
  try:
69
  import torch
70
- if torch.cuda.is_available():
71
- pytorch_device = "cuda"
72
- else:
73
- pytorch_device = "cpu"
74
  except ImportError:
75
  pytorch_device = "cpu"
76
-
77
  return _policy.Policy(
78
  model,
79
  transforms=[
 
1
  import logging
 
2
  import os
3
+ import pathlib
4
  from typing import Any
5
 
6
  import jax.numpy as jnp
 
35
  data if it doesn't already exist.
36
  norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
37
  from the checkpoint directory.
38
+ pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
39
  If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu".
40
+
41
  Note:
42
  The function automatically detects whether the model is PyTorch-based by checking for the
43
  presence of "model.safensors" in the checkpoint directory.
 
52
  logging.info("Loading model...")
53
  if is_pytorch:
54
  model = train_config.model.load_pytorch(train_config, weight_path)
55
+ model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
56
  else:
57
  model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
58
  data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
 
67
  if is_pytorch and pytorch_device is None:
68
  try:
69
  import torch
70
+
71
+ pytorch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
72
  except ImportError:
73
  pytorch_device = "cpu"
74
+
75
  return _policy.Policy(
76
  model,
77
  transforms=[
src/openpi/shared/array_typing.py CHANGED
@@ -7,7 +7,6 @@ import beartype
7
  import jax
8
  import jax._src.tree_util as private_tree_util
9
  import jax.core
10
- from jaxtyping import Array # noqa: F401
11
  from jaxtyping import ArrayLike
12
  from jaxtyping import Bool # noqa: F401
13
  from jaxtyping import DTypeLike # noqa: F401
@@ -31,6 +30,7 @@ _original_check_dataclass_annotations = jaxtyping._decorator._check_dataclass_an
31
  # Redefine Array to include both JAX arrays and PyTorch tensors
32
  Array = jax.Array | torch.Tensor
33
 
 
34
  def _check_dataclass_annotations(self, typechecker):
35
  if not any(
36
  frame.frame.f_globals.get("__name__") in {"jax._src.tree_util", "flax.nnx.transforms.compilation"}
 
7
  import jax
8
  import jax._src.tree_util as private_tree_util
9
  import jax.core
 
10
  from jaxtyping import ArrayLike
11
  from jaxtyping import Bool # noqa: F401
12
  from jaxtyping import DTypeLike # noqa: F401
 
30
  # Redefine Array to include both JAX arrays and PyTorch tensors
31
  Array = jax.Array | torch.Tensor
32
 
33
+
34
  def _check_dataclass_annotations(self, typechecker):
35
  if not any(
36
  frame.frame.f_globals.get("__name__") in {"jax._src.tree_util", "flax.nnx.transforms.compilation"}
src/openpi/shared/image_tools.py CHANGED
@@ -3,7 +3,7 @@ import functools
3
  import jax
4
  import jax.numpy as jnp
5
  import torch
6
- import torch.nn.functional as F
7
 
8
  import openpi.shared.array_typing as at
9
 
@@ -60,13 +60,13 @@ def resize_with_pad_torch(
60
  ) -> torch.Tensor:
61
  """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
62
  by padding with black. If the image is float32, it must be in the range [-1, 1].
63
-
64
  Args:
65
  images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
66
  height: Target height
67
  width: Target width
68
  mode: Interpolation mode ('bilinear', 'nearest', etc.)
69
-
70
  Returns:
71
  Resized and padded tensor with same shape format as input
72
  """
@@ -91,10 +91,7 @@ def resize_with_pad_torch(
91
 
92
  # Resize
93
  resized_images = F.interpolate(
94
- images,
95
- size=(resized_height, resized_width),
96
- mode=mode,
97
- align_corners=False if mode == "bilinear" else None
98
  )
99
 
100
  # Handle dtype-specific clipping
@@ -116,8 +113,8 @@ def resize_with_pad_torch(
116
  padded_images = F.pad(
117
  resized_images,
118
  (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
119
- mode='constant',
120
- value=constant_value
121
  )
122
 
123
  # Convert back to original format if needed
@@ -126,4 +123,4 @@ def resize_with_pad_torch(
126
  if batch_size == 1 and images.shape[0] == 1:
127
  padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
128
 
129
- return padded_images
 
3
  import jax
4
  import jax.numpy as jnp
5
  import torch
6
+ import torch.nn.functional as F # noqa: N812
7
 
8
  import openpi.shared.array_typing as at
9
 
 
60
  ) -> torch.Tensor:
61
  """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
62
  by padding with black. If the image is float32, it must be in the range [-1, 1].
63
+
64
  Args:
65
  images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
66
  height: Target height
67
  width: Target width
68
  mode: Interpolation mode ('bilinear', 'nearest', etc.)
69
+
70
  Returns:
71
  Resized and padded tensor with same shape format as input
72
  """
 
91
 
92
  # Resize
93
  resized_images = F.interpolate(
94
+ images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None
 
 
 
95
  )
96
 
97
  # Handle dtype-specific clipping
 
113
  padded_images = F.pad(
114
  resized_images,
115
  (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
116
+ mode="constant",
117
+ value=constant_value,
118
  )
119
 
120
  # Convert back to original format if needed
 
123
  if batch_size == 1 and images.shape[0] == 1:
124
  padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
125
 
126
+ return padded_images
src/openpi/training/config.py CHANGED
@@ -6,7 +6,7 @@ import dataclasses
6
  import difflib
7
  import logging
8
  import pathlib
9
- from typing import Any, Protocol, TypeAlias, Literal
10
 
11
  import etils.epath as epath
12
  import flax.nnx as nnx
@@ -623,7 +623,7 @@ _CONFIGS = [
623
  data=SimpleDataConfig(
624
  assets=AssetsConfig(asset_id="droid"),
625
  data_transforms=lambda model: _transforms.Group(
626
- inputs=[droid_policy.DroidInputs( model_type=ModelType.PI05)],
627
  outputs=[droid_policy.DroidOutputs()],
628
  ),
629
  base_config=DataConfig(
 
6
  import difflib
7
  import logging
8
  import pathlib
9
+ from typing import Any, Literal, Protocol, TypeAlias
10
 
11
  import etils.epath as epath
12
  import flax.nnx as nnx
 
623
  data=SimpleDataConfig(
624
  assets=AssetsConfig(asset_id="droid"),
625
  data_transforms=lambda model: _transforms.Group(
626
+ inputs=[droid_policy.DroidInputs(model_type=ModelType.PI05)],
627
  outputs=[droid_policy.DroidOutputs()],
628
  ),
629
  base_config=DataConfig(
src/openpi/training/data_loader.py CHANGED
@@ -1,14 +1,13 @@
1
  from collections.abc import Iterator, Sequence
2
- from typing import Literal
3
  import multiprocessing
4
  import os
5
  import typing
6
- from typing import Protocol, SupportsIndex, TypeVar
7
 
8
  import jax
9
  import jax.numpy as jnp
10
  import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
11
- import logging
12
  import numpy as np
13
  import torch
14
 
@@ -231,7 +230,7 @@ def create_data_loader(
231
  framework: Literal["jax", "pytorch"],
232
  ) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
233
  """Create a data loader for training.
234
-
235
  Args:
236
  config: The training configuration.
237
  sharding: The sharding to use for the data loader (JAX only).
@@ -367,22 +366,21 @@ def create_rlds_data_loader(
367
  """
368
  if framework == "pytorch":
369
  raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
370
- else:
371
- dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
372
- dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
373
 
374
- data_loader = RLDSDataLoader(
375
- dataset,
376
- sharding=sharding,
377
- num_batches=num_batches,
378
- )
379
 
380
  return DataLoaderImpl(data_config, data_loader)
381
 
382
 
383
  class TorchDataLoader:
384
  """Torch data loader implementation."""
385
-
386
  def __init__(
387
  self,
388
  dataset,
 
1
  from collections.abc import Iterator, Sequence
2
+ import logging
3
  import multiprocessing
4
  import os
5
  import typing
6
+ from typing import Literal, Protocol, SupportsIndex, TypeVar
7
 
8
  import jax
9
  import jax.numpy as jnp
10
  import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
 
11
  import numpy as np
12
  import torch
13
 
 
230
  framework: Literal["jax", "pytorch"],
231
  ) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
232
  """Create a data loader for training.
233
+
234
  Args:
235
  config: The training configuration.
236
  sharding: The sharding to use for the data loader (JAX only).
 
366
  """
367
  if framework == "pytorch":
368
  raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
369
+ dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
370
+ dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
 
371
 
372
+ data_loader = RLDSDataLoader(
373
+ dataset,
374
+ sharding=sharding,
375
+ num_batches=num_batches,
376
+ )
377
 
378
  return DataLoaderImpl(data_config, data_loader)
379
 
380
 
381
  class TorchDataLoader:
382
  """Torch data loader implementation."""
383
+
384
  def __init__(
385
  self,
386
  dataset,