Kevin Black commited on
Commit ·
8999e55
1
Parent(s): aa7853c
Fix lint errors
Browse files- examples/convert_jax_model_to_pytorch.py +282 -174
- examples/droid/convert_droid_data_to_lerobot.py +1 -1
- pyproject.toml +1 -2
- scripts/train_pytorch.py +484 -455
- src/openpi/models/model.py +3 -4
- src/openpi/models/pi0_config.py +2 -1
- src/openpi/models/tokenizer.py +1 -1
- src/openpi/models_pytorch/gemma_pytorch.py +75 -48
- src/openpi/models_pytorch/pi0_pytorch.py +55 -54
- src/openpi/models_pytorch/preprocessing_pytorch.py +12 -10
- src/openpi/policies/policy.py +3 -6
- src/openpi/policies/policy_config.py +7 -9
- src/openpi/shared/array_typing.py +1 -1
- src/openpi/shared/image_tools.py +7 -10
- src/openpi/training/config.py +2 -2
- src/openpi/training/data_loader.py +11 -13
examples/convert_jax_model_to_pytorch.py
CHANGED
|
@@ -10,13 +10,13 @@ Usage:
|
|
| 10 |
# Just inspect keys:
|
| 11 |
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
|
| 12 |
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
|
| 13 |
-
|
| 14 |
# Convert to PyTorch:
|
| 15 |
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
|
| 16 |
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
|
| 17 |
|
| 18 |
-
Example:
|
| 19 |
-
# pi0_droid
|
| 20 |
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
|
| 21 |
|
| 22 |
# pi0_aloha_sim
|
|
@@ -33,44 +33,45 @@ import pathlib
|
|
| 33 |
import shutil
|
| 34 |
import traceback
|
| 35 |
|
|
|
|
| 36 |
import jax
|
| 37 |
import jax.numpy as jnp
|
| 38 |
import jax.sharding
|
| 39 |
import numpy as np
|
| 40 |
import orbax.checkpoint as ocp
|
| 41 |
-
import torch
|
| 42 |
import safetensors
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
# Import our modules
|
| 46 |
import openpi.models_pytorch.pi0_pytorch
|
| 47 |
-
import openpi.models.pi0_config
|
| 48 |
-
import openpi.models.gemma
|
| 49 |
import openpi.shared.download
|
| 50 |
-
import openpi.models.model
|
| 51 |
|
| 52 |
|
| 53 |
def flatten_for_inspection(tree, separator="/"):
|
| 54 |
"""
|
| 55 |
Flatten a nested dictionary for easy inspection of keys using flax.nnx.traversals.flatten_mapping.
|
| 56 |
-
|
| 57 |
Args:
|
| 58 |
tree: The nested dictionary (JAX pytree)
|
| 59 |
separator: Separator to use between key levels
|
| 60 |
-
|
| 61 |
Returns:
|
| 62 |
Dictionary with flattened keys and array shapes as values
|
| 63 |
"""
|
| 64 |
flattened = flatten_mapping(tree, separator=separator)
|
| 65 |
-
|
| 66 |
# Convert values to shape/dtype information for inspection
|
| 67 |
result = {}
|
| 68 |
for key, value in flattened.items():
|
| 69 |
-
if hasattr(value,
|
| 70 |
result[key] = f"shape: {value.shape}, dtype: {value.dtype}"
|
| 71 |
else:
|
| 72 |
result[key] = f"type: {type(value)}"
|
| 73 |
-
|
| 74 |
return result
|
| 75 |
|
| 76 |
|
|
@@ -90,19 +91,15 @@ def slice_paligemma_state_dict(state_dict, config):
|
|
| 90 |
"""Convert PaliGemma JAX parameters to PyTorch format."""
|
| 91 |
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
| 92 |
|
| 93 |
-
|
| 94 |
# patch embeddings
|
| 95 |
jax_key = f"img/embedding/kernel{suffix}"
|
| 96 |
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
|
| 97 |
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
|
| 98 |
-
|
| 99 |
-
|
| 100 |
jax_key = f"img/embedding/bias{suffix}"
|
| 101 |
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
|
| 102 |
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
# positional embeddings
|
| 107 |
jax_key = f"img/pos_embedding{suffix}"
|
| 108 |
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
|
|
@@ -114,54 +111,101 @@ def slice_paligemma_state_dict(state_dict, config):
|
|
| 114 |
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
| 115 |
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
| 116 |
|
| 117 |
-
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
| 118 |
-
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
| 119 |
-
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
| 120 |
-
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
| 121 |
|
| 122 |
-
encoderblock_attention_0_key_kernel = state_dict.pop(
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
for i in range(config.vision_config.num_hidden_layers):
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
state_dict[
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
state_dict[
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
state_dict[
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
state_dict[
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
state_dict[
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
|
| 151 |
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
|
| 152 |
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
|
| 153 |
-
|
| 154 |
jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
|
| 155 |
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
|
| 156 |
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 157 |
|
| 158 |
# multimodal projector
|
| 159 |
jax_key = f"img/head/kernel{suffix}"
|
| 160 |
-
pytorch_key =
|
| 161 |
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
|
| 162 |
-
|
| 163 |
jax_key = f"img/head/bias{suffix}"
|
| 164 |
-
pytorch_key =
|
| 165 |
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 166 |
|
| 167 |
# text decoder (gemma)
|
|
@@ -181,24 +225,54 @@ def slice_paligemma_state_dict(state_dict, config):
|
|
| 181 |
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
| 182 |
|
| 183 |
for i in range(config.text_config.num_hidden_layers):
|
| 184 |
-
q_proj_weight_reshaped =
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
| 188 |
-
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] =
|
|
|
|
|
|
|
| 189 |
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
| 190 |
-
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
-
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
| 193 |
-
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
| 194 |
-
|
| 195 |
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
| 196 |
-
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] =
|
|
|
|
|
|
|
| 197 |
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
| 198 |
-
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] =
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
jax_key = f"llm/final_norm/scale{suffix}"
|
| 204 |
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
|
|
@@ -206,7 +280,7 @@ def slice_paligemma_state_dict(state_dict, config):
|
|
| 206 |
|
| 207 |
expert_dict = {}
|
| 208 |
final_state_dict = {}
|
| 209 |
-
|
| 210 |
# Expert-related keys to extract (including pi05 Dense layer parameters)
|
| 211 |
expert_keys = [
|
| 212 |
f"llm/final_norm_1/scale{suffix}",
|
|
@@ -224,7 +298,7 @@ def slice_paligemma_state_dict(state_dict, config):
|
|
| 224 |
f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
|
| 225 |
f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
|
| 226 |
]
|
| 227 |
-
|
| 228 |
for key, value in state_dict.items():
|
| 229 |
if key not in expert_keys:
|
| 230 |
final_state_dict[key] = torch.from_numpy(value)
|
|
@@ -237,13 +311,13 @@ def slice_paligemma_state_dict(state_dict, config):
|
|
| 237 |
def slice_gemma_state_dict(state_dict, config, num_expert=1, checkpoint_dir=None):
|
| 238 |
"""Convert Gemma JAX parameters to PyTorch format."""
|
| 239 |
# Add missing attributes to config if they don't exist
|
| 240 |
-
if not hasattr(config,
|
| 241 |
config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
|
| 242 |
-
if not hasattr(config,
|
| 243 |
config.hidden_size = config.width
|
| 244 |
-
if not hasattr(config,
|
| 245 |
config.num_hidden_layers = config.depth
|
| 246 |
-
if not hasattr(config,
|
| 247 |
config.num_attention_heads = config.num_heads
|
| 248 |
|
| 249 |
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
|
@@ -260,42 +334,79 @@ def slice_gemma_state_dict(state_dict, config, num_expert=1, checkpoint_dir=None
|
|
| 260 |
# Pi05 with adaptive normalization
|
| 261 |
llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
|
| 262 |
llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
|
| 263 |
-
llm_input_layernorm_kernel = state_dict.pop(
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
else:
|
| 266 |
# Regular pi0 with standard RMSNorm
|
| 267 |
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
| 268 |
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
| 269 |
|
| 270 |
-
|
| 271 |
for i in range(config.num_hidden_layers):
|
| 272 |
-
q_proj_weight_reshaped =
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
| 276 |
-
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] =
|
|
|
|
|
|
|
| 277 |
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
| 278 |
-
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
-
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)
|
| 281 |
-
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
| 282 |
-
|
| 283 |
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
| 284 |
-
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] =
|
|
|
|
|
|
|
| 285 |
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
| 286 |
-
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] =
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
if "pi05" in checkpoint_dir:
|
| 290 |
# Pi05 with adaptive normalization - use Dense layer parameters directly
|
| 291 |
-
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] =
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
else:
|
| 296 |
# Regular pi0 with standard RMSNorm
|
| 297 |
-
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] =
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
# Handle final norm layer
|
| 301 |
if "pi05" in checkpoint_dir:
|
|
@@ -306,9 +417,11 @@ def slice_gemma_state_dict(state_dict, config, num_expert=1, checkpoint_dir=None
|
|
| 306 |
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
|
| 307 |
else:
|
| 308 |
# Regular pi0 with standard RMSNorm
|
| 309 |
-
state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
|
| 310 |
-
|
| 311 |
-
|
|
|
|
|
|
|
| 312 |
|
| 313 |
final_state_dict = {}
|
| 314 |
for key, value in state_dict.items():
|
|
@@ -316,7 +429,6 @@ def slice_gemma_state_dict(state_dict, config, num_expert=1, checkpoint_dir=None
|
|
| 316 |
final_state_dict[key] = torch.from_numpy(value)
|
| 317 |
else:
|
| 318 |
final_state_dict[key] = value
|
| 319 |
-
|
| 320 |
|
| 321 |
return final_state_dict
|
| 322 |
|
|
@@ -339,11 +451,13 @@ def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str |
|
|
| 339 |
restore_dtype = dtype_map.get(restore_precision) if restore_precision else None
|
| 340 |
|
| 341 |
# Use CPU sharding to avoid GPU memory issues during checkpoint loading
|
| 342 |
-
cpu_device = jax.devices(
|
| 343 |
cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_device)
|
| 344 |
-
|
| 345 |
# Use repository restore utility to load a pure dict of params (value suffix removed)
|
| 346 |
-
params = openpi.models.model.restore_params(
|
|
|
|
|
|
|
| 347 |
|
| 348 |
# get params for PaliGemma
|
| 349 |
pali_params = params["PaliGemma"]
|
|
@@ -355,43 +469,43 @@ def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str |
|
|
| 355 |
def load_jax_model_and_print_keys(checkpoint_dir: str):
|
| 356 |
"""
|
| 357 |
Load JAX model from checkpoint and print all parameter keys.
|
| 358 |
-
|
| 359 |
Args:
|
| 360 |
checkpoint_dir: Path to the checkpoint directory
|
| 361 |
"""
|
| 362 |
params_path = pathlib.Path(checkpoint_dir).resolve()
|
| 363 |
-
|
| 364 |
if not params_path.exists():
|
| 365 |
print(f"Error: Checkpoint directory does not exist: {params_path}")
|
| 366 |
return
|
| 367 |
-
|
| 368 |
try:
|
| 369 |
# Initialize checkpointer
|
| 370 |
checkpointer = ocp.PyTreeCheckpointer()
|
| 371 |
-
|
| 372 |
# Load metadata to see available keys
|
| 373 |
metadata = checkpointer.metadata(params_path)
|
| 374 |
print("Available top-level keys in checkpoint:")
|
| 375 |
-
for key in metadata
|
| 376 |
print(f" - {key}")
|
| 377 |
print()
|
| 378 |
-
|
| 379 |
# Restore the parameters
|
| 380 |
params_name = "params"
|
| 381 |
if params_name not in metadata:
|
| 382 |
print(f"Warning: '{params_name}' not found in metadata. Available keys: {list(metadata.keys())}")
|
| 383 |
if metadata.keys():
|
| 384 |
-
params_name =
|
| 385 |
print(f"Using '{params_name}' instead.")
|
| 386 |
else:
|
| 387 |
print("No keys found in metadata!")
|
| 388 |
return
|
| 389 |
-
|
| 390 |
item = {params_name: metadata[params_name]}
|
| 391 |
# Use CPU device to avoid GPU memory issues
|
| 392 |
-
device = jax.devices(
|
| 393 |
sharding = jax.sharding.SingleDeviceSharding(device)
|
| 394 |
-
|
| 395 |
restored = checkpointer.restore(
|
| 396 |
params_path,
|
| 397 |
ocp.args.PyTreeRestore(
|
|
@@ -406,33 +520,33 @@ def load_jax_model_and_print_keys(checkpoint_dir: str):
|
|
| 406 |
transforms={},
|
| 407 |
),
|
| 408 |
)
|
| 409 |
-
|
| 410 |
params = restored[params_name]
|
| 411 |
-
|
| 412 |
# Flatten and print all keys
|
| 413 |
flat_params = flatten_for_inspection(params)
|
| 414 |
-
|
| 415 |
print(f"All parameter keys with shapes and dtypes ({len(flat_params)} total):")
|
| 416 |
print("=" * 80)
|
| 417 |
-
|
| 418 |
# Sort keys for better readability
|
| 419 |
sorted_keys = sorted(flat_params.keys())
|
| 420 |
-
|
| 421 |
for key in sorted_keys:
|
| 422 |
print(f"{key:<60} -> {flat_params[key]}")
|
| 423 |
-
|
| 424 |
print()
|
| 425 |
print("=" * 80)
|
| 426 |
print(f"Summary: Found {len(flat_params)} parameters")
|
| 427 |
-
|
| 428 |
# Print some high-level structure information
|
| 429 |
top_level_keys = set()
|
| 430 |
for key in sorted_keys:
|
| 431 |
-
top_level_key = key.split(
|
| 432 |
top_level_keys.add(top_level_key)
|
| 433 |
-
|
| 434 |
-
print(f"Top-level parameter groups: {sorted(
|
| 435 |
-
|
| 436 |
except Exception as e:
|
| 437 |
print(f"Error loading checkpoint: {e}")
|
| 438 |
traceback.print_exc()
|
|
@@ -441,29 +555,29 @@ def load_jax_model_and_print_keys(checkpoint_dir: str):
|
|
| 441 |
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str):
|
| 442 |
"""
|
| 443 |
Convert PI0 JAX checkpoint to PyTorch format.
|
| 444 |
-
|
| 445 |
Args:
|
| 446 |
checkpoint_dir: Path to the JAX checkpoint
|
| 447 |
precision: Model precision (float32, bfloat16, float16)
|
| 448 |
output_path: Path to save the converted PyTorch model
|
| 449 |
"""
|
| 450 |
print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
|
| 451 |
-
|
| 452 |
# Break down orbax ckpts by restoring via JAX to respect dtype
|
| 453 |
-
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision=
|
| 454 |
-
|
| 455 |
# Process projection params
|
| 456 |
if "pi05" in checkpoint_dir:
|
| 457 |
keys = [
|
| 458 |
-
"action_in_proj",
|
| 459 |
"action_out_proj",
|
| 460 |
-
"time_mlp_in",
|
| 461 |
"time_mlp_out",
|
| 462 |
]
|
| 463 |
else:
|
| 464 |
keys = [
|
| 465 |
"state_proj",
|
| 466 |
-
"action_in_proj",
|
| 467 |
"action_out_proj",
|
| 468 |
"action_time_mlp_in",
|
| 469 |
"action_time_mlp_out",
|
|
@@ -479,10 +593,10 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
|
|
| 479 |
else:
|
| 480 |
weight = kernel_params
|
| 481 |
bias = bias_params
|
| 482 |
-
|
| 483 |
pytorch_weight_key = f"{key}.weight"
|
| 484 |
pytorch_bias_key = f"{key}.bias"
|
| 485 |
-
|
| 486 |
projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
|
| 487 |
projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
|
| 488 |
|
|
@@ -490,22 +604,30 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
|
|
| 490 |
# All models use the same PaliGemma config structure
|
| 491 |
class PaliGemmaConfig:
|
| 492 |
def __init__(self):
|
| 493 |
-
self.vision_config = type(
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
paligemma_config = PaliGemmaConfig()
|
| 510 |
action_expert_config = openpi.models.gemma.get_config("gemma_300m")
|
| 511 |
|
|
@@ -513,27 +635,24 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
|
|
| 513 |
paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
|
| 514 |
|
| 515 |
# Process Gemma weights from expert_params
|
| 516 |
-
gemma_params = slice_gemma_state_dict(
|
|
|
|
|
|
|
| 517 |
|
| 518 |
# Create Pi0Config based on checkpoint path
|
| 519 |
-
if "pi0_aloha_sim" in checkpoint_dir:
|
| 520 |
-
pi0_config = openpi.models.pi0_config.Pi0Config(
|
| 521 |
-
action_dim=14, # ALOHA has 14 action dimensions
|
| 522 |
-
action_horizon=50,
|
| 523 |
-
)
|
| 524 |
-
elif "pi0_aloha_towel" in checkpoint_dir:
|
| 525 |
pi0_config = openpi.models.pi0_config.Pi0Config(
|
| 526 |
action_dim=14, # ALOHA has 14 action dimensions
|
| 527 |
action_horizon=50,
|
| 528 |
)
|
| 529 |
elif "pi0_base" in checkpoint_dir:
|
| 530 |
pi0_config = openpi.models.pi0_config.Pi0Config(
|
| 531 |
-
action_dim=8,
|
| 532 |
action_horizon=10,
|
| 533 |
)
|
| 534 |
elif "pi05_droid" in checkpoint_dir:
|
| 535 |
pi0_config = openpi.models.pi0_config.Pi0Config(
|
| 536 |
-
action_dim=8,
|
| 537 |
action_horizon=10,
|
| 538 |
pi05=True,
|
| 539 |
)
|
|
@@ -560,10 +679,10 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
|
|
| 560 |
|
| 561 |
# Combine all parameters (no prefix needed for our model structure)
|
| 562 |
all_params = {**paligemma_params, **gemma_params, **projection_params}
|
| 563 |
-
|
| 564 |
# Load state dict
|
| 565 |
pi0_model.load_state_dict(all_params, strict=False)
|
| 566 |
-
|
| 567 |
if precision == "float32":
|
| 568 |
pi0_model = pi0_model.to(torch.float32)
|
| 569 |
elif precision == "bfloat16":
|
|
@@ -573,10 +692,10 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
|
|
| 573 |
|
| 574 |
# Save the converted model using safetensors
|
| 575 |
os.makedirs(output_path, exist_ok=True)
|
| 576 |
-
|
| 577 |
# Save model weights as SafeTensors using save_model to handle tied weights
|
| 578 |
safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
|
| 579 |
-
|
| 580 |
# Copy assets folder if it exists
|
| 581 |
assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
|
| 582 |
if assets_source.exists():
|
|
@@ -584,7 +703,7 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
|
|
| 584 |
if assets_dest.exists():
|
| 585 |
shutil.rmtree(assets_dest)
|
| 586 |
shutil.copytree(assets_source, assets_dest)
|
| 587 |
-
|
| 588 |
# Save config as JSON for reference
|
| 589 |
config_dict = {
|
| 590 |
"action_dim": pi0_config.action_dim,
|
|
@@ -595,37 +714,26 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str
|
|
| 595 |
}
|
| 596 |
with open(os.path.join(output_path, "config.json"), "w") as f:
|
| 597 |
json.dump(config_dict, f, indent=2)
|
| 598 |
-
|
| 599 |
-
print(
|
| 600 |
print(f"Model saved to {output_path}")
|
| 601 |
|
| 602 |
|
| 603 |
def main():
|
| 604 |
parser = argparse.ArgumentParser(description="Load JAX model and optionally convert to PyTorch")
|
|
|
|
| 605 |
parser.add_argument(
|
| 606 |
-
"--
|
| 607 |
-
type=str,
|
| 608 |
-
required=True,
|
| 609 |
-
help="Path to the JAX checkpoint directory"
|
| 610 |
-
)
|
| 611 |
-
parser.add_argument(
|
| 612 |
-
"--output_path",
|
| 613 |
-
type=str,
|
| 614 |
-
help="Path to save converted PyTorch model (required for conversion)"
|
| 615 |
)
|
| 616 |
parser.add_argument(
|
| 617 |
"--precision",
|
| 618 |
choices=["float32", "bfloat16", "float16"],
|
| 619 |
default="bfloat16",
|
| 620 |
type=str,
|
| 621 |
-
help="Precision for model conversion"
|
| 622 |
-
)
|
| 623 |
-
parser.add_argument(
|
| 624 |
-
"--inspect_only",
|
| 625 |
-
action="store_true",
|
| 626 |
-
help="Only inspect parameter keys, don't convert"
|
| 627 |
)
|
| 628 |
-
|
|
|
|
| 629 |
args = parser.parse_args()
|
| 630 |
|
| 631 |
if not os.path.exists(args.checkpoint_dir):
|
|
@@ -633,7 +741,7 @@ def main():
|
|
| 633 |
checkpoint_dir = openpi.shared.download.maybe_download(f"gs://openpi-assets/checkpoints/{model_name}")
|
| 634 |
else:
|
| 635 |
checkpoint_dir = args.checkpoint_dir
|
| 636 |
-
|
| 637 |
if args.inspect_only:
|
| 638 |
load_jax_model_and_print_keys(args.checkpoint_dir)
|
| 639 |
else:
|
|
|
|
| 10 |
# Just inspect keys:
|
| 11 |
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
|
| 12 |
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
|
| 13 |
+
|
| 14 |
# Convert to PyTorch:
|
| 15 |
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
|
| 16 |
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
|
| 17 |
|
| 18 |
+
Example:
|
| 19 |
+
# pi0_droid
|
| 20 |
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
|
| 21 |
|
| 22 |
# pi0_aloha_sim
|
|
|
|
| 33 |
import shutil
|
| 34 |
import traceback
|
| 35 |
|
| 36 |
+
from flax.nnx.traversals import flatten_mapping
|
| 37 |
import jax
|
| 38 |
import jax.numpy as jnp
|
| 39 |
import jax.sharding
|
| 40 |
import numpy as np
|
| 41 |
import orbax.checkpoint as ocp
|
|
|
|
| 42 |
import safetensors
|
| 43 |
+
import torch
|
| 44 |
+
|
| 45 |
+
import openpi.models.gemma
|
| 46 |
+
import openpi.models.model
|
| 47 |
+
import openpi.models.pi0_config
|
| 48 |
|
| 49 |
# Import our modules
|
| 50 |
import openpi.models_pytorch.pi0_pytorch
|
|
|
|
|
|
|
| 51 |
import openpi.shared.download
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
def flatten_for_inspection(tree, separator="/"):
|
| 55 |
"""
|
| 56 |
Flatten a nested dictionary for easy inspection of keys using flax.nnx.traversals.flatten_mapping.
|
| 57 |
+
|
| 58 |
Args:
|
| 59 |
tree: The nested dictionary (JAX pytree)
|
| 60 |
separator: Separator to use between key levels
|
| 61 |
+
|
| 62 |
Returns:
|
| 63 |
Dictionary with flattened keys and array shapes as values
|
| 64 |
"""
|
| 65 |
flattened = flatten_mapping(tree, separator=separator)
|
| 66 |
+
|
| 67 |
# Convert values to shape/dtype information for inspection
|
| 68 |
result = {}
|
| 69 |
for key, value in flattened.items():
|
| 70 |
+
if hasattr(value, "shape") and hasattr(value, "dtype"):
|
| 71 |
result[key] = f"shape: {value.shape}, dtype: {value.dtype}"
|
| 72 |
else:
|
| 73 |
result[key] = f"type: {type(value)}"
|
| 74 |
+
|
| 75 |
return result
|
| 76 |
|
| 77 |
|
|
|
|
| 91 |
"""Convert PaliGemma JAX parameters to PyTorch format."""
|
| 92 |
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
| 93 |
|
|
|
|
| 94 |
# patch embeddings
|
| 95 |
jax_key = f"img/embedding/kernel{suffix}"
|
| 96 |
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
|
| 97 |
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
|
| 98 |
+
|
|
|
|
| 99 |
jax_key = f"img/embedding/bias{suffix}"
|
| 100 |
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
|
| 101 |
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 102 |
|
|
|
|
|
|
|
| 103 |
# positional embeddings
|
| 104 |
jax_key = f"img/pos_embedding{suffix}"
|
| 105 |
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
|
|
|
|
| 111 |
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
| 112 |
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
| 113 |
|
| 114 |
+
encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
| 115 |
+
encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
| 116 |
+
encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
| 117 |
+
encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
| 118 |
|
| 119 |
+
encoderblock_attention_0_key_kernel = state_dict.pop(
|
| 120 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
|
| 121 |
+
)
|
| 122 |
+
encoderblock_attention_0_key_bias = state_dict.pop(
|
| 123 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
|
| 124 |
+
)
|
| 125 |
+
encoderblock_attention_0_value_kernel = state_dict.pop(
|
| 126 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
|
| 127 |
+
)
|
| 128 |
+
encoderblock_attention_0_value_bias = state_dict.pop(
|
| 129 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
|
| 130 |
+
)
|
| 131 |
+
encoderblock_attention_0_query_kernel = state_dict.pop(
|
| 132 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
|
| 133 |
+
)
|
| 134 |
+
encoderblock_attention_0_query_bias = state_dict.pop(
|
| 135 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
|
| 136 |
+
)
|
| 137 |
+
encoderblock_attention_0_out_kernel = state_dict.pop(
|
| 138 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
|
| 139 |
+
)
|
| 140 |
+
encoderblock_attention_0_out_bias = state_dict.pop(
|
| 141 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
|
| 142 |
+
)
|
| 143 |
|
| 144 |
for i in range(config.vision_config.num_hidden_layers):
|
| 145 |
+
state_dict[
|
| 146 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
|
| 147 |
+
] = encoderblock_layernorm0_scale[i].transpose()
|
| 148 |
+
state_dict[
|
| 149 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
|
| 150 |
+
] = encoderblock_layernorm0_bias[i]
|
| 151 |
+
state_dict[
|
| 152 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
|
| 153 |
+
] = encoderblock_layernorm1_scale[i].transpose()
|
| 154 |
+
state_dict[
|
| 155 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
|
| 156 |
+
] = encoderblock_layernorm1_bias[i]
|
| 157 |
+
state_dict[
|
| 158 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
|
| 159 |
+
] = encoderblock_mlp_dense0_kernel[i].transpose()
|
| 160 |
+
state_dict[
|
| 161 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
|
| 162 |
+
] = encoderblock_mlp_dense0_bias[i]
|
| 163 |
+
state_dict[
|
| 164 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
|
| 165 |
+
] = encoderblock_mlp_dense1_kernel[i].transpose()
|
| 166 |
+
state_dict[
|
| 167 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
|
| 168 |
+
] = encoderblock_mlp_dense1_bias[i]
|
| 169 |
+
state_dict[
|
| 170 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
|
| 171 |
+
] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 172 |
+
state_dict[
|
| 173 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
|
| 174 |
+
] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 175 |
+
state_dict[
|
| 176 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
|
| 177 |
+
] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 178 |
+
state_dict[
|
| 179 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
|
| 180 |
+
] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 181 |
+
state_dict[
|
| 182 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
|
| 183 |
+
] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 184 |
+
state_dict[
|
| 185 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
|
| 186 |
+
] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 187 |
+
state_dict[
|
| 188 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
|
| 189 |
+
] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 190 |
+
state_dict[
|
| 191 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
|
| 192 |
+
] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 193 |
|
| 194 |
jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
|
| 195 |
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
|
| 196 |
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
|
| 197 |
+
|
| 198 |
jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
|
| 199 |
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
|
| 200 |
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 201 |
|
| 202 |
# multimodal projector
|
| 203 |
jax_key = f"img/head/kernel{suffix}"
|
| 204 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
|
| 205 |
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
|
| 206 |
+
|
| 207 |
jax_key = f"img/head/bias{suffix}"
|
| 208 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
|
| 209 |
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 210 |
|
| 211 |
# text decoder (gemma)
|
|
|
|
| 225 |
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
| 226 |
|
| 227 |
for i in range(config.text_config.num_hidden_layers):
|
| 228 |
+
q_proj_weight_reshaped = (
|
| 229 |
+
llm_attention_q_einsum[i]
|
| 230 |
+
.transpose(0, 2, 1)
|
| 231 |
+
.reshape(
|
| 232 |
+
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
|
| 233 |
+
)
|
| 234 |
+
)
|
| 235 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
|
| 236 |
+
q_proj_weight_reshaped
|
| 237 |
+
)
|
| 238 |
|
| 239 |
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
| 240 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
|
| 241 |
+
k_proj_weight_reshaped
|
| 242 |
+
)
|
| 243 |
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
| 244 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
|
| 245 |
+
v_proj_weight_reshaped
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
o_proj_weight_reshaped = (
|
| 249 |
+
llm_attention_attn_vec_einsum[i]
|
| 250 |
+
.transpose(2, 0, 1)
|
| 251 |
+
.reshape(
|
| 252 |
+
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
|
| 253 |
+
)
|
| 254 |
+
)
|
| 255 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
|
| 256 |
+
o_proj_weight_reshaped
|
| 257 |
+
)
|
| 258 |
|
|
|
|
|
|
|
|
|
|
| 259 |
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
| 260 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
|
| 261 |
+
gate_proj_weight.transpose()
|
| 262 |
+
)
|
| 263 |
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
| 264 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
|
| 265 |
+
up_proj_weight.transpose()
|
| 266 |
+
)
|
| 267 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
|
| 268 |
+
llm_mlp_linear[i].transpose()
|
| 269 |
+
)
|
| 270 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
|
| 271 |
+
llm_input_layernorm[i]
|
| 272 |
+
)
|
| 273 |
+
state_dict[
|
| 274 |
+
f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
|
| 275 |
+
] = llm_post_attention_layernorm[i]
|
| 276 |
|
| 277 |
jax_key = f"llm/final_norm/scale{suffix}"
|
| 278 |
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
|
|
|
|
| 280 |
|
| 281 |
expert_dict = {}
|
| 282 |
final_state_dict = {}
|
| 283 |
+
|
| 284 |
# Expert-related keys to extract (including pi05 Dense layer parameters)
|
| 285 |
expert_keys = [
|
| 286 |
f"llm/final_norm_1/scale{suffix}",
|
|
|
|
| 298 |
f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
|
| 299 |
f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
|
| 300 |
]
|
| 301 |
+
|
| 302 |
for key, value in state_dict.items():
|
| 303 |
if key not in expert_keys:
|
| 304 |
final_state_dict[key] = torch.from_numpy(value)
|
|
|
|
| 311 |
def slice_gemma_state_dict(state_dict, config, num_expert=1, checkpoint_dir=None):
|
| 312 |
"""Convert Gemma JAX parameters to PyTorch format."""
|
| 313 |
# Add missing attributes to config if they don't exist
|
| 314 |
+
if not hasattr(config, "vocab_size"):
|
| 315 |
config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
|
| 316 |
+
if not hasattr(config, "hidden_size"):
|
| 317 |
config.hidden_size = config.width
|
| 318 |
+
if not hasattr(config, "num_hidden_layers"):
|
| 319 |
config.num_hidden_layers = config.depth
|
| 320 |
+
if not hasattr(config, "num_attention_heads"):
|
| 321 |
config.num_attention_heads = config.num_heads
|
| 322 |
|
| 323 |
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
|
|
|
| 334 |
# Pi05 with adaptive normalization
|
| 335 |
llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
|
| 336 |
llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
|
| 337 |
+
llm_input_layernorm_kernel = state_dict.pop(
|
| 338 |
+
f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
|
| 339 |
+
)
|
| 340 |
+
llm_post_attention_layernorm_kernel = state_dict.pop(
|
| 341 |
+
f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
|
| 342 |
+
)
|
| 343 |
else:
|
| 344 |
# Regular pi0 with standard RMSNorm
|
| 345 |
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
| 346 |
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
| 347 |
|
|
|
|
| 348 |
for i in range(config.num_hidden_layers):
|
| 349 |
+
q_proj_weight_reshaped = (
|
| 350 |
+
llm_attention_q_einsum[i]
|
| 351 |
+
.transpose(0, 2, 1)
|
| 352 |
+
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
| 353 |
+
)
|
| 354 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
|
| 355 |
+
q_proj_weight_reshaped
|
| 356 |
+
)
|
| 357 |
|
| 358 |
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
| 359 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
|
| 360 |
+
k_proj_weight_reshaped
|
| 361 |
+
)
|
| 362 |
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
| 363 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
|
| 364 |
+
v_proj_weight_reshaped
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
o_proj_weight_reshaped = (
|
| 368 |
+
llm_attention_attn_vec_einsum[i]
|
| 369 |
+
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
| 370 |
+
.transpose(1, 0)
|
| 371 |
+
)
|
| 372 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
|
| 373 |
+
o_proj_weight_reshaped
|
| 374 |
+
)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
| 376 |
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
| 377 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
|
| 378 |
+
gate_proj_weight.transpose()
|
| 379 |
+
)
|
| 380 |
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
| 381 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
|
| 382 |
+
up_proj_weight.transpose()
|
| 383 |
+
)
|
| 384 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
|
| 385 |
+
i
|
| 386 |
+
].transpose()
|
| 387 |
|
| 388 |
if "pi05" in checkpoint_dir:
|
| 389 |
# Pi05 with adaptive normalization - use Dense layer parameters directly
|
| 390 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
|
| 391 |
+
llm_input_layernorm_bias[i]
|
| 392 |
+
)
|
| 393 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
|
| 394 |
+
llm_post_attention_layernorm_bias[i]
|
| 395 |
+
)
|
| 396 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
|
| 397 |
+
llm_input_layernorm_kernel[i].transpose()
|
| 398 |
+
)
|
| 399 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
|
| 400 |
+
llm_post_attention_layernorm_kernel[i].transpose()
|
| 401 |
+
)
|
| 402 |
else:
|
| 403 |
# Regular pi0 with standard RMSNorm
|
| 404 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
|
| 405 |
+
llm_input_layernorm[i]
|
| 406 |
+
)
|
| 407 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
|
| 408 |
+
llm_post_attention_layernorm[i]
|
| 409 |
+
)
|
| 410 |
|
| 411 |
# Handle final norm layer
|
| 412 |
if "pi05" in checkpoint_dir:
|
|
|
|
| 417 |
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
|
| 418 |
else:
|
| 419 |
# Regular pi0 with standard RMSNorm
|
| 420 |
+
state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
|
| 421 |
+
f"llm/final_norm_{num_expert}/scale{suffix}"
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
|
| 425 |
|
| 426 |
final_state_dict = {}
|
| 427 |
for key, value in state_dict.items():
|
|
|
|
| 429 |
final_state_dict[key] = torch.from_numpy(value)
|
| 430 |
else:
|
| 431 |
final_state_dict[key] = value
|
|
|
|
| 432 |
|
| 433 |
return final_state_dict
|
| 434 |
|
|
|
|
| 451 |
restore_dtype = dtype_map.get(restore_precision) if restore_precision else None
|
| 452 |
|
| 453 |
# Use CPU sharding to avoid GPU memory issues during checkpoint loading
|
| 454 |
+
cpu_device = jax.devices("cpu")[0]
|
| 455 |
cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_device)
|
| 456 |
+
|
| 457 |
# Use repository restore utility to load a pure dict of params (value suffix removed)
|
| 458 |
+
params = openpi.models.model.restore_params(
|
| 459 |
+
params_dir, restore_type=jax.Array, dtype=restore_dtype, sharding=cpu_sharding
|
| 460 |
+
)
|
| 461 |
|
| 462 |
# get params for PaliGemma
|
| 463 |
pali_params = params["PaliGemma"]
|
|
|
|
| 469 |
def load_jax_model_and_print_keys(checkpoint_dir: str):
|
| 470 |
"""
|
| 471 |
Load JAX model from checkpoint and print all parameter keys.
|
| 472 |
+
|
| 473 |
Args:
|
| 474 |
checkpoint_dir: Path to the checkpoint directory
|
| 475 |
"""
|
| 476 |
params_path = pathlib.Path(checkpoint_dir).resolve()
|
| 477 |
+
|
| 478 |
if not params_path.exists():
|
| 479 |
print(f"Error: Checkpoint directory does not exist: {params_path}")
|
| 480 |
return
|
| 481 |
+
|
| 482 |
try:
|
| 483 |
# Initialize checkpointer
|
| 484 |
checkpointer = ocp.PyTreeCheckpointer()
|
| 485 |
+
|
| 486 |
# Load metadata to see available keys
|
| 487 |
metadata = checkpointer.metadata(params_path)
|
| 488 |
print("Available top-level keys in checkpoint:")
|
| 489 |
+
for key in metadata:
|
| 490 |
print(f" - {key}")
|
| 491 |
print()
|
| 492 |
+
|
| 493 |
# Restore the parameters
|
| 494 |
params_name = "params"
|
| 495 |
if params_name not in metadata:
|
| 496 |
print(f"Warning: '{params_name}' not found in metadata. Available keys: {list(metadata.keys())}")
|
| 497 |
if metadata.keys():
|
| 498 |
+
params_name = next(iter(metadata.keys()))
|
| 499 |
print(f"Using '{params_name}' instead.")
|
| 500 |
else:
|
| 501 |
print("No keys found in metadata!")
|
| 502 |
return
|
| 503 |
+
|
| 504 |
item = {params_name: metadata[params_name]}
|
| 505 |
# Use CPU device to avoid GPU memory issues
|
| 506 |
+
device = jax.devices("cpu")[0]
|
| 507 |
sharding = jax.sharding.SingleDeviceSharding(device)
|
| 508 |
+
|
| 509 |
restored = checkpointer.restore(
|
| 510 |
params_path,
|
| 511 |
ocp.args.PyTreeRestore(
|
|
|
|
| 520 |
transforms={},
|
| 521 |
),
|
| 522 |
)
|
| 523 |
+
|
| 524 |
params = restored[params_name]
|
| 525 |
+
|
| 526 |
# Flatten and print all keys
|
| 527 |
flat_params = flatten_for_inspection(params)
|
| 528 |
+
|
| 529 |
print(f"All parameter keys with shapes and dtypes ({len(flat_params)} total):")
|
| 530 |
print("=" * 80)
|
| 531 |
+
|
| 532 |
# Sort keys for better readability
|
| 533 |
sorted_keys = sorted(flat_params.keys())
|
| 534 |
+
|
| 535 |
for key in sorted_keys:
|
| 536 |
print(f"{key:<60} -> {flat_params[key]}")
|
| 537 |
+
|
| 538 |
print()
|
| 539 |
print("=" * 80)
|
| 540 |
print(f"Summary: Found {len(flat_params)} parameters")
|
| 541 |
+
|
| 542 |
# Print some high-level structure information
|
| 543 |
top_level_keys = set()
|
| 544 |
for key in sorted_keys:
|
| 545 |
+
top_level_key = key.split("/")[0]
|
| 546 |
top_level_keys.add(top_level_key)
|
| 547 |
+
|
| 548 |
+
print(f"Top-level parameter groups: {sorted(top_level_keys)}")
|
| 549 |
+
|
| 550 |
except Exception as e:
|
| 551 |
print(f"Error loading checkpoint: {e}")
|
| 552 |
traceback.print_exc()
|
|
|
|
| 555 |
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, output_path: str):
|
| 556 |
"""
|
| 557 |
Convert PI0 JAX checkpoint to PyTorch format.
|
| 558 |
+
|
| 559 |
Args:
|
| 560 |
checkpoint_dir: Path to the JAX checkpoint
|
| 561 |
precision: Model precision (float32, bfloat16, float16)
|
| 562 |
output_path: Path to save the converted PyTorch model
|
| 563 |
"""
|
| 564 |
print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
|
| 565 |
+
|
| 566 |
# Break down orbax ckpts by restoring via JAX to respect dtype
|
| 567 |
+
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
|
| 568 |
+
|
| 569 |
# Process projection params
|
| 570 |
if "pi05" in checkpoint_dir:
|
| 571 |
keys = [
|
| 572 |
+
"action_in_proj",
|
| 573 |
"action_out_proj",
|
| 574 |
+
"time_mlp_in",
|
| 575 |
"time_mlp_out",
|
| 576 |
]
|
| 577 |
else:
|
| 578 |
keys = [
|
| 579 |
"state_proj",
|
| 580 |
+
"action_in_proj",
|
| 581 |
"action_out_proj",
|
| 582 |
"action_time_mlp_in",
|
| 583 |
"action_time_mlp_out",
|
|
|
|
| 593 |
else:
|
| 594 |
weight = kernel_params
|
| 595 |
bias = bias_params
|
| 596 |
+
|
| 597 |
pytorch_weight_key = f"{key}.weight"
|
| 598 |
pytorch_bias_key = f"{key}.bias"
|
| 599 |
+
|
| 600 |
projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
|
| 601 |
projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
|
| 602 |
|
|
|
|
| 604 |
# All models use the same PaliGemma config structure
|
| 605 |
class PaliGemmaConfig:
|
| 606 |
def __init__(self):
|
| 607 |
+
self.vision_config = type(
|
| 608 |
+
"obj",
|
| 609 |
+
(object,),
|
| 610 |
+
{
|
| 611 |
+
"hidden_size": 1152,
|
| 612 |
+
"num_hidden_layers": 27,
|
| 613 |
+
"num_attention_heads": 16,
|
| 614 |
+
"intermediate_size": 4304,
|
| 615 |
+
"patch_size": 14,
|
| 616 |
+
"projection_dim": 2048,
|
| 617 |
+
},
|
| 618 |
+
)()
|
| 619 |
+
self.text_config = type(
|
| 620 |
+
"obj",
|
| 621 |
+
(object,),
|
| 622 |
+
{
|
| 623 |
+
"hidden_size": 2048,
|
| 624 |
+
"num_hidden_layers": 18,
|
| 625 |
+
"num_attention_heads": 8,
|
| 626 |
+
"head_dim": 256,
|
| 627 |
+
"intermediate_size": 16384,
|
| 628 |
+
},
|
| 629 |
+
)()
|
| 630 |
+
|
| 631 |
paligemma_config = PaliGemmaConfig()
|
| 632 |
action_expert_config = openpi.models.gemma.get_config("gemma_300m")
|
| 633 |
|
|
|
|
| 635 |
paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
|
| 636 |
|
| 637 |
# Process Gemma weights from expert_params
|
| 638 |
+
gemma_params = slice_gemma_state_dict(
|
| 639 |
+
expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir
|
| 640 |
+
)
|
| 641 |
|
| 642 |
# Create Pi0Config based on checkpoint path
|
| 643 |
+
if "pi0_aloha_sim" in checkpoint_dir or "pi0_aloha_towel" in checkpoint_dir:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
pi0_config = openpi.models.pi0_config.Pi0Config(
|
| 645 |
action_dim=14, # ALOHA has 14 action dimensions
|
| 646 |
action_horizon=50,
|
| 647 |
)
|
| 648 |
elif "pi0_base" in checkpoint_dir:
|
| 649 |
pi0_config = openpi.models.pi0_config.Pi0Config(
|
| 650 |
+
action_dim=8, # Base droid has 8 action dimensions
|
| 651 |
action_horizon=10,
|
| 652 |
)
|
| 653 |
elif "pi05_droid" in checkpoint_dir:
|
| 654 |
pi0_config = openpi.models.pi0_config.Pi0Config(
|
| 655 |
+
action_dim=8, # Base droid has 8 action dimensions
|
| 656 |
action_horizon=10,
|
| 657 |
pi05=True,
|
| 658 |
)
|
|
|
|
| 679 |
|
| 680 |
# Combine all parameters (no prefix needed for our model structure)
|
| 681 |
all_params = {**paligemma_params, **gemma_params, **projection_params}
|
| 682 |
+
|
| 683 |
# Load state dict
|
| 684 |
pi0_model.load_state_dict(all_params, strict=False)
|
| 685 |
+
|
| 686 |
if precision == "float32":
|
| 687 |
pi0_model = pi0_model.to(torch.float32)
|
| 688 |
elif precision == "bfloat16":
|
|
|
|
| 692 |
|
| 693 |
# Save the converted model using safetensors
|
| 694 |
os.makedirs(output_path, exist_ok=True)
|
| 695 |
+
|
| 696 |
# Save model weights as SafeTensors using save_model to handle tied weights
|
| 697 |
safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
|
| 698 |
+
|
| 699 |
# Copy assets folder if it exists
|
| 700 |
assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
|
| 701 |
if assets_source.exists():
|
|
|
|
| 703 |
if assets_dest.exists():
|
| 704 |
shutil.rmtree(assets_dest)
|
| 705 |
shutil.copytree(assets_source, assets_dest)
|
| 706 |
+
|
| 707 |
# Save config as JSON for reference
|
| 708 |
config_dict = {
|
| 709 |
"action_dim": pi0_config.action_dim,
|
|
|
|
| 714 |
}
|
| 715 |
with open(os.path.join(output_path, "config.json"), "w") as f:
|
| 716 |
json.dump(config_dict, f, indent=2)
|
| 717 |
+
|
| 718 |
+
print("Model conversion completed successfully!")
|
| 719 |
print(f"Model saved to {output_path}")
|
| 720 |
|
| 721 |
|
| 722 |
def main():
|
| 723 |
parser = argparse.ArgumentParser(description="Load JAX model and optionally convert to PyTorch")
|
| 724 |
+
parser.add_argument("--checkpoint_dir", type=str, required=True, help="Path to the JAX checkpoint directory")
|
| 725 |
parser.add_argument(
|
| 726 |
+
"--output_path", type=str, help="Path to save converted PyTorch model (required for conversion)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 727 |
)
|
| 728 |
parser.add_argument(
|
| 729 |
"--precision",
|
| 730 |
choices=["float32", "bfloat16", "float16"],
|
| 731 |
default="bfloat16",
|
| 732 |
type=str,
|
| 733 |
+
help="Precision for model conversion",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 734 |
)
|
| 735 |
+
parser.add_argument("--inspect_only", action="store_true", help="Only inspect parameter keys, don't convert")
|
| 736 |
+
|
| 737 |
args = parser.parse_args()
|
| 738 |
|
| 739 |
if not os.path.exists(args.checkpoint_dir):
|
|
|
|
| 741 |
checkpoint_dir = openpi.shared.download.maybe_download(f"gs://openpi-assets/checkpoints/{model_name}")
|
| 742 |
else:
|
| 743 |
checkpoint_dir = args.checkpoint_dir
|
| 744 |
+
|
| 745 |
if args.inspect_only:
|
| 746 |
load_jax_model_and_print_keys(args.checkpoint_dir)
|
| 747 |
else:
|
examples/droid/convert_droid_data_to_lerobot.py
CHANGED
|
@@ -277,7 +277,7 @@ class RecordedMultiCameraWrapper:
|
|
| 277 |
self.camera_kwargs = camera_kwargs
|
| 278 |
|
| 279 |
# Open Camera Readers #
|
| 280 |
-
mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
|
| 281 |
all_filepaths = mp4_filepaths
|
| 282 |
|
| 283 |
self.camera_dict = {}
|
|
|
|
| 277 |
self.camera_kwargs = camera_kwargs
|
| 278 |
|
| 279 |
# Open Camera Readers #
|
| 280 |
+
mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
|
| 281 |
all_filepaths = mp4_filepaths
|
| 282 |
|
| 283 |
self.camera_dict = {}
|
pyproject.toml
CHANGED
|
@@ -73,7 +73,7 @@ members = ["packages/*"]
|
|
| 73 |
[tool.ruff]
|
| 74 |
line-length = 120
|
| 75 |
target-version = "py311"
|
| 76 |
-
extend-exclude = ["docker", "third_party"]
|
| 77 |
|
| 78 |
[tool.ruff.lint]
|
| 79 |
# https://docs.astral.sh/ruff/rules/
|
|
@@ -101,7 +101,6 @@ select = [
|
|
| 101 |
"PLR5",
|
| 102 |
"PLW",
|
| 103 |
"PT",
|
| 104 |
-
"PTH",
|
| 105 |
"Q",
|
| 106 |
"RET",
|
| 107 |
"RUF",
|
|
|
|
| 73 |
[tool.ruff]
|
| 74 |
line-length = 120
|
| 75 |
target-version = "py311"
|
| 76 |
+
extend-exclude = ["docker", "third_party", "src/openpi/models_pytorch/transformers_replace/*"]
|
| 77 |
|
| 78 |
[tool.ruff.lint]
|
| 79 |
# https://docs.astral.sh/ruff/rules/
|
|
|
|
| 101 |
"PLR5",
|
| 102 |
"PLW",
|
| 103 |
"PT",
|
|
|
|
| 104 |
"Q",
|
| 105 |
"RET",
|
| 106 |
"RUF",
|
scripts/train_pytorch.py
CHANGED
|
@@ -23,7 +23,6 @@ Multi-Node Training:
|
|
| 23 |
|
| 24 |
"""
|
| 25 |
|
| 26 |
-
import argparse
|
| 27 |
import dataclasses
|
| 28 |
import gc
|
| 29 |
import logging
|
|
@@ -31,10 +30,10 @@ import os
|
|
| 31 |
import platform
|
| 32 |
import shutil
|
| 33 |
import time
|
| 34 |
-
from typing import Any, Dict
|
| 35 |
|
| 36 |
import jax
|
| 37 |
import numpy as np
|
|
|
|
| 38 |
import torch
|
| 39 |
import torch.distributed as dist
|
| 40 |
import torch.nn.parallel
|
|
@@ -42,162 +41,169 @@ import torch.utils.data
|
|
| 42 |
import torch.utils.data.distributed
|
| 43 |
import tqdm
|
| 44 |
import wandb
|
| 45 |
-
import safetensors.torch
|
| 46 |
|
|
|
|
|
|
|
| 47 |
import openpi.training.config as _config
|
| 48 |
import openpi.training.data_loader as _data
|
| 49 |
-
import openpi.models.model as _model
|
| 50 |
-
import openpi.models_pytorch.pi0_pytorch
|
| 51 |
-
import openpi.models.pi0_config
|
| 52 |
|
| 53 |
|
| 54 |
def init_logging():
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
|
| 75 |
|
| 76 |
def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
|
| 97 |
|
| 98 |
def setup_ddp():
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
|
| 115 |
|
| 116 |
def cleanup_ddp():
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
|
| 121 |
|
| 122 |
def set_seed(seed: int, local_rank: int):
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
|
| 128 |
|
| 129 |
def build_datasets(config: _config.TrainConfig):
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
|
| 134 |
|
| 135 |
def get_model_state_dict(model):
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
def get_model_parameters(model):
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
def save_checkpoint(model, optimizer, global_step, config, is_main):
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
|
| 187 |
|
| 188 |
def load_checkpoint(model, optimizer, checkpoint_dir, device):
|
| 189 |
"""Load the latest checkpoint and return the global step."""
|
| 190 |
-
checkpoint_steps = [
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
|
|
|
| 195 |
if not checkpoint_steps:
|
| 196 |
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
| 197 |
-
|
| 198 |
latest_step = max(checkpoint_steps)
|
| 199 |
ckpt_dir = checkpoint_dir / f"{latest_step}"
|
| 200 |
-
|
| 201 |
# Clear memory before loading checkpoints
|
| 202 |
if torch.cuda.is_available():
|
| 203 |
torch.cuda.empty_cache()
|
|
@@ -208,35 +214,34 @@ def load_checkpoint(model, optimizer, checkpoint_dir, device):
|
|
| 208 |
# Load model state with error handling
|
| 209 |
logging.info("Loading model state...")
|
| 210 |
safetensors_path = ckpt_dir / "pytorch_model.safetensors"
|
| 211 |
-
|
| 212 |
if safetensors_path.exists():
|
| 213 |
model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
|
| 214 |
safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
|
| 215 |
logging.info("Loaded model state from safetensors format")
|
| 216 |
else:
|
| 217 |
raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
|
| 218 |
-
|
| 219 |
torch.cuda.empty_cache()
|
| 220 |
gc.collect()
|
| 221 |
log_memory_usage(device, latest_step, "after_loading_model")
|
| 222 |
-
|
| 223 |
# Load optimizer state with error handling
|
| 224 |
logging.info("Loading optimizer state...")
|
| 225 |
optimizer_path = ckpt_dir / "optimizer.pt"
|
| 226 |
-
|
| 227 |
if optimizer_path.exists():
|
| 228 |
optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
|
| 229 |
logging.info("Loaded optimizer state from pt format")
|
| 230 |
else:
|
| 231 |
raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
|
| 232 |
-
|
| 233 |
optimizer.load_state_dict(optimizer_state_dict)
|
| 234 |
del optimizer_state_dict
|
| 235 |
torch.cuda.empty_cache()
|
| 236 |
gc.collect()
|
| 237 |
log_memory_usage(device, latest_step, "after_loading_optimizer")
|
| 238 |
-
|
| 239 |
-
|
| 240 |
# Load metadata
|
| 241 |
logging.info("Loading metadata...")
|
| 242 |
metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
|
|
@@ -245,355 +250,379 @@ def load_checkpoint(model, optimizer, checkpoint_dir, device):
|
|
| 245 |
torch.cuda.empty_cache()
|
| 246 |
gc.collect()
|
| 247 |
log_memory_usage(device, latest_step, "after_loading_metadata")
|
| 248 |
-
|
| 249 |
logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
|
| 250 |
return global_step
|
| 251 |
-
|
| 252 |
except RuntimeError as e:
|
| 253 |
if "out of memory" in str(e):
|
| 254 |
# Clear memory and provide detailed error message
|
| 255 |
torch.cuda.empty_cache()
|
| 256 |
gc.collect()
|
| 257 |
-
logging.error(f"Out of memory error while loading checkpoint: {
|
| 258 |
log_memory_usage(device, latest_step, "after_oom_error")
|
| 259 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 260 |
raise
|
| 261 |
|
| 262 |
|
| 263 |
def get_latest_checkpoint_step(checkpoint_dir):
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
|
| 272 |
|
| 273 |
def log_memory_usage(device, step, phase="unknown"):
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
| 294 |
|
| 295 |
|
| 296 |
def train_loop(config: _config.TrainConfig):
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
|
| 591 |
|
| 592 |
def main():
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
|
| 597 |
|
| 598 |
if __name__ == "__main__":
|
| 599 |
-
|
|
|
|
| 23 |
|
| 24 |
"""
|
| 25 |
|
|
|
|
| 26 |
import dataclasses
|
| 27 |
import gc
|
| 28 |
import logging
|
|
|
|
| 30 |
import platform
|
| 31 |
import shutil
|
| 32 |
import time
|
|
|
|
| 33 |
|
| 34 |
import jax
|
| 35 |
import numpy as np
|
| 36 |
+
import safetensors.torch
|
| 37 |
import torch
|
| 38 |
import torch.distributed as dist
|
| 39 |
import torch.nn.parallel
|
|
|
|
| 41 |
import torch.utils.data.distributed
|
| 42 |
import tqdm
|
| 43 |
import wandb
|
|
|
|
| 44 |
|
| 45 |
+
import openpi.models.pi0_config
|
| 46 |
+
import openpi.models_pytorch.pi0_pytorch
|
| 47 |
import openpi.training.config as _config
|
| 48 |
import openpi.training.data_loader as _data
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
|
| 51 |
def init_logging():
    """Configure the root logger with a compact, single-letter level format."""
    short_levels = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class _ShortLevelFormatter(logging.Formatter):
        # Abbreviate the level name before delegating to the standard formatter.
        def format(self, record):
            record.levelname = short_levels.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = _ShortLevelFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    if root.handlers:
        # Reuse the existing handler instead of stacking a duplicate.
        root.handlers[0].setFormatter(formatter)
    else:
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        root.addHandler(handler)
|
| 71 |
|
| 72 |
|
| 73 |
def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Initialize wandb logging.

    The run id is persisted to ``wandb_id.txt`` inside the checkpoint
    directory so that a resumed training run reattaches to the same
    wandb run.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    run_id_file = ckpt_dir / "wandb_id.txt"
    if resuming:
        # Reattach to the run recorded when this experiment was first created.
        wandb.init(id=run_id_file.read_text().strip(), resume="must", project=config.project_name)
        return

    wandb.init(
        name=config.exp_name,
        config=dataclasses.asdict(config),
        project=config.project_name,
    )
    run_id_file.write_text(wandb.run.id)
|
| 93 |
|
| 94 |
|
| 95 |
def setup_ddp():
    """Initialize torch.distributed when launched with WORLD_SIZE > 1.

    Returns:
        (use_ddp, local_rank, device) — whether DDP is active, this
        process's local rank, and the torch device to train on.
    """
    use_ddp = int(os.environ.get("WORLD_SIZE", "1")) > 1
    if use_ddp and not torch.distributed.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, init_method="env://")

    # Set up debugging environment variables for DDP issues.
    os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "INFO")

    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    return use_ddp, local_rank, device
|
| 111 |
|
| 112 |
|
| 113 |
def cleanup_ddp():
    """Synchronize and tear down the process group; no-op when DDP is inactive."""
    if not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
|
| 117 |
|
| 118 |
|
| 119 |
def set_seed(seed: int, local_rank: int):
    """Seed torch, numpy and CUDA RNGs, offset by rank so workers decorrelate."""
    rank_seed = seed + local_rank
    torch.manual_seed(rank_seed)
    np.random.seed(rank_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(rank_seed)
|
| 124 |
|
| 125 |
|
| 126 |
def build_datasets(config: _config.TrainConfig):
    """Build the unified PyTorch data loader.

    Returns:
        (loader, data_config) for the configured dataset.
    """
    loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    return loader, loader.data_config()
|
| 130 |
|
| 131 |
|
| 132 |
def get_model_state_dict(model):
    """Return the state dict of the underlying model, unwrapping a DDP wrapper."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module.state_dict()
    return model.state_dict()
|
| 139 |
|
| 140 |
|
| 141 |
def get_model_parameters(model):
    """Return an iterator over the underlying model's parameters, unwrapping DDP."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module.parameters()
    return model.parameters()
|
| 148 |
|
| 149 |
|
| 150 |
def save_checkpoint(model, optimizer, global_step, config, is_main):
    """Atomically save model weights, optimizer state and training metadata.

    Only the main process writes. Everything is written into a temporary
    directory first and renamed into place, so a crash mid-save never
    leaves a half-written checkpoint that resume could pick up.

    Args:
        model: the (possibly DDP-wrapped) model to save.
        optimizer: optimizer whose state dict is persisted.
        global_step: current step. NOTE: the train loop increments the step
            *before* calling this function, so values range over
            1..num_train_steps.
        config: training config (checkpoint_dir, save_interval,
            num_train_steps, wandb_enabled).
        is_main: whether this process is rank 0.
    """
    if not is_main:
        return

    # Save on the regular interval, and always at the end of training.
    # BUG FIX: the previous condition compared against num_train_steps - 1,
    # but the caller passes the *post-increment* step, so the final call
    # arrives with global_step == num_train_steps and the last checkpoint
    # was skipped whenever num_train_steps wasn't a multiple of
    # save_interval. `>= num_train_steps - 1` keeps the old trigger and
    # also fires on the actual final call.
    is_interval_step = global_step % config.save_interval == 0 and global_step > 0
    is_final_step = global_step >= config.num_train_steps - 1
    if not (is_interval_step or is_final_step):
        return

    # Create a temporary directory for atomic checkpoint saving.
    final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
    tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"

    # Remove any existing temp directory and create a fresh one.
    if tmp_ckpt_dir.exists():
        shutil.rmtree(tmp_ckpt_dir)
    tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Save model state using safetensors (handles shared tensors).
    model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "pytorch_model.safetensors")

    # Save optimizer state using PyTorch format.
    torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")

    # Save training metadata; the config is stored as a plain dict so
    # loading does not depend on JAX/Flax class definitions.
    metadata = {
        "global_step": global_step,
        "config": dataclasses.asdict(config),
        "timestamp": time.time(),
    }
    torch.save(metadata, tmp_ckpt_dir / "metadata.pt")

    # Atomically move the temp directory to its final location.
    if final_ckpt_dir.exists():
        shutil.rmtree(final_ckpt_dir)
    tmp_ckpt_dir.rename(final_ckpt_dir)

    logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")

    # Log checkpoint to wandb.
    if config.wandb_enabled:
        wandb.log({"checkpoint_step": global_step}, step=global_step)
|
| 191 |
|
| 192 |
|
| 193 |
def load_checkpoint(model, optimizer, checkpoint_dir, device):
|
| 194 |
"""Load the latest checkpoint and return the global step."""
|
| 195 |
+
checkpoint_steps = [
|
| 196 |
+
int(d.name)
|
| 197 |
+
for d in checkpoint_dir.iterdir()
|
| 198 |
+
if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
|
| 199 |
+
]
|
| 200 |
+
|
| 201 |
if not checkpoint_steps:
|
| 202 |
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
| 203 |
+
|
| 204 |
latest_step = max(checkpoint_steps)
|
| 205 |
ckpt_dir = checkpoint_dir / f"{latest_step}"
|
| 206 |
+
|
| 207 |
# Clear memory before loading checkpoints
|
| 208 |
if torch.cuda.is_available():
|
| 209 |
torch.cuda.empty_cache()
|
|
|
|
| 214 |
# Load model state with error handling
|
| 215 |
logging.info("Loading model state...")
|
| 216 |
safetensors_path = ckpt_dir / "pytorch_model.safetensors"
|
| 217 |
+
|
| 218 |
if safetensors_path.exists():
|
| 219 |
model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
|
| 220 |
safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
|
| 221 |
logging.info("Loaded model state from safetensors format")
|
| 222 |
else:
|
| 223 |
raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
|
| 224 |
+
|
| 225 |
torch.cuda.empty_cache()
|
| 226 |
gc.collect()
|
| 227 |
log_memory_usage(device, latest_step, "after_loading_model")
|
| 228 |
+
|
| 229 |
# Load optimizer state with error handling
|
| 230 |
logging.info("Loading optimizer state...")
|
| 231 |
optimizer_path = ckpt_dir / "optimizer.pt"
|
| 232 |
+
|
| 233 |
if optimizer_path.exists():
|
| 234 |
optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
|
| 235 |
logging.info("Loaded optimizer state from pt format")
|
| 236 |
else:
|
| 237 |
raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
|
| 238 |
+
|
| 239 |
optimizer.load_state_dict(optimizer_state_dict)
|
| 240 |
del optimizer_state_dict
|
| 241 |
torch.cuda.empty_cache()
|
| 242 |
gc.collect()
|
| 243 |
log_memory_usage(device, latest_step, "after_loading_optimizer")
|
| 244 |
+
|
|
|
|
| 245 |
# Load metadata
|
| 246 |
logging.info("Loading metadata...")
|
| 247 |
metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
|
|
|
|
| 250 |
torch.cuda.empty_cache()
|
| 251 |
gc.collect()
|
| 252 |
log_memory_usage(device, latest_step, "after_loading_metadata")
|
| 253 |
+
|
| 254 |
logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
|
| 255 |
return global_step
|
| 256 |
+
|
| 257 |
except RuntimeError as e:
|
| 258 |
if "out of memory" in str(e):
|
| 259 |
# Clear memory and provide detailed error message
|
| 260 |
torch.cuda.empty_cache()
|
| 261 |
gc.collect()
|
| 262 |
+
logging.error(f"Out of memory error while loading checkpoint: {e!s}")
|
| 263 |
log_memory_usage(device, latest_step, "after_oom_error")
|
| 264 |
+
raise RuntimeError(
|
| 265 |
+
"Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
|
| 266 |
+
) from e
|
| 267 |
raise
|
| 268 |
|
| 269 |
|
| 270 |
def get_latest_checkpoint_step(checkpoint_dir):
    """Return the highest completed checkpoint step in checkpoint_dir, or None.

    Only fully-written checkpoints count: directory names must be purely
    numeric, which excludes in-progress "tmp_<step>" directories.
    """
    steps = (
        int(entry.name)
        for entry in checkpoint_dir.iterdir()
        if entry.is_dir() and entry.name.isdigit() and not entry.name.startswith("tmp_")
    )
    return max(steps, default=None)
|
| 278 |
|
| 279 |
|
| 280 |
def log_memory_usage(device, step, phase="unknown"):
    """Log detailed CUDA allocator statistics for `device`; no-op without CUDA.

    Args:
        device: torch device whose allocator stats are reported.
        step: current training step, for log correlation.
        phase: free-form tag describing when the measurement was taken.
    """
    if not torch.cuda.is_available():
        return

    allocated_gb = torch.cuda.memory_allocated(device) / 1e9
    reserved_gb = torch.cuda.memory_reserved(device) / 1e9
    # FIX: this quantity (reserved - allocated) is the caching allocator's
    # internal headroom, NOT free GPU memory — it was previously logged as
    # "free", which is misleading. True device-free memory would come from
    # torch.cuda.mem_get_info.
    reserved_free_gb = (torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / 1e9

    # Peak values since process start (or last reset_peak_memory_stats).
    memory_stats = torch.cuda.memory_stats(device)
    peak_allocated_gb = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
    peak_reserved_gb = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9

    # Append DDP rank information when running distributed.
    ddp_info = ""
    if dist.is_initialized():
        ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"

    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {allocated_gb:.2f}GB, "
        f"reserved: {reserved_gb:.2f}GB, reserved_free: {reserved_free_gb:.2f}GB, "
        f"peak_allocated: {peak_allocated_gb:.2f}GB, peak_reserved: {peak_reserved_gb:.2f}GB{ddp_info}"
    )
|
| 303 |
|
| 304 |
|
| 305 |
def train_loop(config: _config.TrainConfig):
    """Run the full PyTorch training loop: DDP setup, data, model, optimization.

    Handles checkpoint-directory creation / resume / overwrite, wandb
    initialization on the main process, model construction, optional DDP
    wrapping, a warmup+cosine LR schedule, periodic logging, and atomic
    checkpointing via `save_checkpoint`.
    """
    use_ddp, local_rank, device = setup_ddp()
    is_main = (not use_ddp) or (dist.get_rank() == 0)
    set_seed(config.seed, local_rank)

    # Initialize checkpoint directory and wandb
    resuming = False
    if config.resume:
        # Find checkpoint directory based on experiment name
        exp_checkpoint_dir = config.checkpoint_dir
        if exp_checkpoint_dir.exists():
            # Use validation to find the latest working checkpoint
            latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
            if latest_step is not None:
                resuming = True
                logging.info(
                    f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
                )
            else:
                raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
        else:
            raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
    elif config.overwrite and config.checkpoint_dir.exists():
        shutil.rmtree(config.checkpoint_dir)
        logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")

    # Create checkpoint directory with experiment name
    if not resuming:
        # For new runs, create experiment-specific checkpoint directory
        exp_checkpoint_dir = config.checkpoint_dir
        exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
    else:
        # For resume, checkpoint_dir is already set to the experiment directory
        logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")

    # Initialize wandb (only on main process)
    if is_main:
        init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # Build data loader using the unified data loader.
    # For N GPUs, each GPU should get batch_size/N samples, so the total
    # across all GPUs is batch_size.
    world_size = torch.distributed.get_world_size() if use_ddp else 1
    effective_batch_size = config.batch_size // world_size
    logging.info(
        f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
    )

    # Pass the original batch size to data loader - it will handle DDP splitting internally
    loader, _ = build_datasets(config)

    # Log sample images to wandb on first batch
    if is_main and config.wandb_enabled and not resuming:
        # Create a separate data loader for sample batch to avoid consuming the main loader
        sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
        sample_batch = next(iter(sample_data_loader))
        # Convert observation and actions to torch tensors
        observation, actions = sample_batch
        sample_batch = observation.to_dict()
        sample_batch["actions"] = actions

        # Create sample images for wandb
        images_to_log = []
        # Get batch size from the first image tensor
        batch_size = next(iter(sample_batch["image"].values())).shape[0]
        for i in range(min(5, batch_size)):
            # Concatenate all camera views horizontally for this batch item.
            # Convert from NCHW to NHWC format for wandb.
            img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
            img_concatenated = img_concatenated.cpu().numpy()
            images_to_log.append(wandb.Image(img_concatenated))

        wandb.log({"camera_views": images_to_log}, step=0)

        # Clear sample batch from memory aggressively
        del sample_batch, observation, actions, images_to_log, img_concatenated
        del sample_data_loader  # Also delete the sample data loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info("Cleared sample batch and data loader from memory")

    # Build model
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # Convert dataclass to Pi0Config if needed
        model_cfg = openpi.models.pi0_config.Pi0Config(
            dtype=config.pytorch_training_precision,
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
        )
    else:
        model_cfg = config.model
        # Update dtype to match pytorch_training_precision (config is a
        # frozen dataclass, hence object.__setattr__).
        object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)

    model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)

    if hasattr(model, "gradient_checkpointing_enable"):
        enable_gradient_checkpointing = True
        model.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing for memory optimization")
    else:
        enable_gradient_checkpointing = False
        logging.info("Gradient checkpointing is not supported for this model")

    # Log initial memory usage after model creation
    if is_main and torch.cuda.is_available():
        log_memory_usage(device, 0, "after_model_creation")

    # Enable memory optimizations for large-scale training
    if world_size >= 8:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Set memory allocation configuration
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
        logging.info("Enabled memory optimizations for 8+ GPU training")

    if use_ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=True,  # NOTE(review): True *costs* memory/time; comment said "disable" — confirm intent
            gradient_as_bucket_view=True,  # Enable for memory efficiency
            static_graph=world_size >= 8,  # Enable for 8+ GPUs
        )

    # Load weights from weight_loader if specified (for fine-tuning)
    if config.pytorch_weight_path is not None:
        logging.info(f"Loading weights from: {config.pytorch_weight_path}")

        model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
        safetensors.torch.load_model(
            (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
        )
        logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")

    # Optimizer + learning rate schedule from config
    warmup_steps = config.lr_schedule.warmup_steps
    peak_lr = config.lr_schedule.peak_lr
    decay_steps = config.lr_schedule.decay_steps
    end_lr = config.lr_schedule.decay_lr

    # Create optimizer with config parameters
    optim = torch.optim.AdamW(
        model.parameters(),
        lr=peak_lr,
        betas=(config.optimizer.b1, config.optimizer.b2),
        eps=config.optimizer.eps,
        weight_decay=config.optimizer.weight_decay,
    )

    # Load checkpoint if resuming
    global_step = 0
    if resuming:
        global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
        logging.info(f"Resumed training from step {global_step}")

    def lr_schedule(step: int):
        # Linear warmup followed by cosine decay to end_lr.
        if step < warmup_steps:
            # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
            init_lr = peak_lr / (warmup_steps + 1)
            return init_lr + (peak_lr - init_lr) * step / warmup_steps
        # cosine decay
        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        cos = 0.5 * (1 + np.cos(np.pi * progress))
        return end_lr + (peak_lr - end_lr) * cos

    model.train()
    start_time = time.time()
    infos = []  # Collect stats over log interval
    if is_main:
        logging.info(
            f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
        )
        logging.info(
            f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
        )
        logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
        logging.info(
            f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
        )
        logging.info(
            f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
        )
        logging.info("EMA is not supported for PyTorch training")
        logging.info(f"Training precision: {model_cfg.dtype}")

    # Training loop - iterate until we reach num_train_steps
    pbar = (
        tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
        if is_main
        else None
    )

    while global_step < config.num_train_steps:
        # Set epoch for distributed training
        if use_ddp and hasattr(loader, "set_epoch"):
            loader.set_epoch(global_step // len(loader))

        for observation, actions in loader:
            # Check if we've reached the target number of steps
            if global_step >= config.num_train_steps:
                break

            # The unified data loader returns (observation, actions) tuple
            observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
            actions = actions.to(torch.float32)  # noqa: PLW2901
            actions = actions.to(device)  # noqa: PLW2901

            # Update LR
            for pg in optim.param_groups:
                pg["lr"] = lr_schedule(global_step)

            # Forward pass
            losses = model(observation, actions)
            # Ensure losses is a tensor and handle different return types
            if isinstance(losses, list | tuple):
                losses = torch.stack(losses)
            elif not isinstance(losses, torch.Tensor):
                losses = torch.tensor(losses, device=device, dtype=torch.float32)

            loss = losses.mean()

            # Backward pass
            loss.backward()

            # Log memory usage after backward pass (first few steps only)
            if global_step < 5 and is_main and torch.cuda.is_available():
                log_memory_usage(device, global_step, "after_backward")

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)

            # Optimizer step
            optim.step()
            optim.zero_grad(set_to_none=True)

            # Clear gradients more aggressively
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.detach_()
                    param.grad = None

            # Collect stats
            if is_main:
                infos.append(
                    {
                        "loss": loss.item(),
                        "learning_rate": optim.param_groups[0]["lr"],
                        "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
                    }
                )

            if is_main and (global_step % config.log_interval == 0):
                elapsed = time.time() - start_time

                # Average stats over log interval
                avg_loss = sum(info["loss"] for info in infos) / len(infos)
                avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)

                avg_grad_norm = None
                if any("grad_norm" in info for info in infos):
                    vals = [
                        info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
                    ]
                    if len(vals) > 0:
                        avg_grad_norm = sum(vals) / len(vals)
                logging.info(
                    f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
                    if avg_grad_norm is not None
                    else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
                )

                # Log to wandb
                if config.wandb_enabled and len(infos) > 0:
                    log_payload = {
                        "loss": avg_loss,
                        "learning_rate": avg_lr,
                        "step": global_step,
                        "time_per_step": elapsed / config.log_interval,
                    }
                    if avg_grad_norm is not None:
                        log_payload["grad_norm"] = avg_grad_norm
                    wandb.log(log_payload, step=global_step)

                start_time = time.time()
                infos = []  # Reset stats collection

            global_step += 1
            # Save checkpoint using the new mechanism (note: called with the
            # post-increment step value).
            save_checkpoint(model, optim, global_step, config, is_main)

            # Update progress bar
            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(
                    {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
                )

    # Close progress bar
    if pbar is not None:
        pbar.close()

    # Finish wandb run
    if is_main and config.wandb_enabled:
        wandb.finish()

    cleanup_ddp()
|
| 619 |
|
| 620 |
|
| 621 |
def main():
    """Entry point: configure logging, parse the CLI config, run training."""
    init_logging()
    train_loop(_config.cli())
|
| 625 |
|
| 626 |
|
| 627 |
# Allow running this module directly as a training script.
if __name__ == "__main__":
    main()
|
src/openpi/models/model.py
CHANGED
|
@@ -4,7 +4,7 @@ import dataclasses
|
|
| 4 |
import enum
|
| 5 |
import logging
|
| 6 |
import pathlib
|
| 7 |
-
from typing import Generic, TypeVar
|
| 8 |
|
| 9 |
import augmax
|
| 10 |
from flax import nnx
|
|
@@ -12,7 +12,6 @@ from flax import struct
|
|
| 12 |
from flax import traverse_util
|
| 13 |
import jax
|
| 14 |
import jax.numpy as jnp
|
| 15 |
-
import logging
|
| 16 |
import numpy as np
|
| 17 |
import orbax.checkpoint as ocp
|
| 18 |
import safetensors
|
|
@@ -25,7 +24,7 @@ import openpi.shared.array_typing as at
|
|
| 25 |
logger = logging.getLogger("openpi")
|
| 26 |
|
| 27 |
# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
|
| 28 |
-
ArrayT = TypeVar("ArrayT", bound=
|
| 29 |
|
| 30 |
|
| 31 |
class ModelType(enum.Enum):
|
|
@@ -117,7 +116,7 @@ class Observation(Generic[ArrayT]):
|
|
| 117 |
for key in data["image"]:
|
| 118 |
if data["image"][key].dtype == np.uint8:
|
| 119 |
data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
|
| 120 |
-
elif hasattr(data["image"][key],
|
| 121 |
data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
|
| 122 |
return cls(
|
| 123 |
images=data["image"],
|
|
|
|
| 4 |
import enum
|
| 5 |
import logging
|
| 6 |
import pathlib
|
| 7 |
+
from typing import Generic, TypeVar
|
| 8 |
|
| 9 |
import augmax
|
| 10 |
from flax import nnx
|
|
|
|
| 12 |
from flax import traverse_util
|
| 13 |
import jax
|
| 14 |
import jax.numpy as jnp
|
|
|
|
| 15 |
import numpy as np
|
| 16 |
import orbax.checkpoint as ocp
|
| 17 |
import safetensors
|
|
|
|
| 24 |
logger = logging.getLogger("openpi")
|
| 25 |
|
| 26 |
# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
|
| 27 |
+
ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray)
|
| 28 |
|
| 29 |
|
| 30 |
class ModelType(enum.Enum):
|
|
|
|
| 116 |
for key in data["image"]:
|
| 117 |
if data["image"][key].dtype == np.uint8:
|
| 118 |
data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
|
| 119 |
+
elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8:
|
| 120 |
data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
|
| 121 |
return cls(
|
| 122 |
images=data["image"],
|
src/openpi/models/pi0_config.py
CHANGED
|
@@ -48,6 +48,7 @@ class Pi0Config(_model.BaseModelConfig):
|
|
| 48 |
@override
|
| 49 |
def create(self, rng: at.KeyArrayLike) -> "Pi0":
|
| 50 |
from openpi.models.pi0 import Pi0
|
|
|
|
| 51 |
return Pi0(self, rngs=nnx.Rngs(rng))
|
| 52 |
|
| 53 |
@override
|
|
@@ -104,4 +105,4 @@ class Pi0Config(_model.BaseModelConfig):
|
|
| 104 |
)
|
| 105 |
if not filters:
|
| 106 |
return nnx.Nothing
|
| 107 |
-
return nnx.All(*filters)
|
|
|
|
| 48 |
@override
|
| 49 |
def create(self, rng: at.KeyArrayLike) -> "Pi0":
|
| 50 |
from openpi.models.pi0 import Pi0
|
| 51 |
+
|
| 52 |
return Pi0(self, rngs=nnx.Rngs(rng))
|
| 53 |
|
| 54 |
@override
|
|
|
|
| 105 |
)
|
| 106 |
if not filters:
|
| 107 |
return nnx.Nothing
|
| 108 |
+
return nnx.All(*filters)
|
src/openpi/models/tokenizer.py
CHANGED
|
@@ -254,7 +254,7 @@ class FSQTokenizer:
|
|
| 254 |
assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
|
| 255 |
# Download tokenizer
|
| 256 |
path = download.maybe_download(fsq_tokenizer_path)
|
| 257 |
-
tok_path = os.path.join(path, os.listdir(path)[0])
|
| 258 |
|
| 259 |
# Split step from path
|
| 260 |
step = int(tok_path.split("/")[-1])
|
|
|
|
| 254 |
assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
|
| 255 |
# Download tokenizer
|
| 256 |
path = download.maybe_download(fsq_tokenizer_path)
|
| 257 |
+
tok_path = os.path.join(path, os.listdir(path)[0])
|
| 258 |
|
| 259 |
# Split step from path
|
| 260 |
step = int(tok_path.split("/")[-1])
|
src/openpi/models_pytorch/gemma_pytorch.py
CHANGED
|
@@ -1,19 +1,28 @@
|
|
| 1 |
-
from
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
from torch import nn
|
| 4 |
-
from transformers import GemmaForCausalLM
|
| 5 |
-
from transformers
|
| 6 |
-
|
| 7 |
from transformers.models.auto import CONFIG_MAPPING
|
| 8 |
-
from
|
| 9 |
|
| 10 |
|
| 11 |
class PaliGemmaWithExpertModel(nn.Module):
|
| 12 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
super().__init__()
|
| 14 |
|
| 15 |
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
| 16 |
-
vlm_config_hf._vocab_size = 257152
|
| 17 |
vlm_config_hf.image_token_index = 257152
|
| 18 |
vlm_config_hf.text_config.hidden_size = vlm_config.width
|
| 19 |
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
|
|
@@ -53,9 +62,9 @@ class PaliGemmaWithExpertModel(nn.Module):
|
|
| 53 |
|
| 54 |
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
| 55 |
if precision == "bfloat16":
|
| 56 |
-
self
|
| 57 |
elif precision == "float32":
|
| 58 |
-
self
|
| 59 |
return
|
| 60 |
else:
|
| 61 |
raise ValueError(f"Invalid precision: {precision}")
|
|
@@ -83,11 +92,13 @@ class PaliGemmaWithExpertModel(nn.Module):
|
|
| 83 |
self,
|
| 84 |
attention_mask: torch.Tensor | None = None,
|
| 85 |
position_ids: torch.LongTensor | None = None,
|
| 86 |
-
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
| 87 |
-
inputs_embeds: list[torch.FloatTensor] = None,
|
| 88 |
use_cache: bool | None = None,
|
| 89 |
-
adarms_cond: list[torch.Tensor]
|
| 90 |
):
|
|
|
|
|
|
|
| 91 |
if inputs_embeds[1] is None:
|
| 92 |
prefix_output = self.paligemma.language_model.forward(
|
| 93 |
inputs_embeds=inputs_embeds[0],
|
|
@@ -115,45 +126,45 @@ class PaliGemmaWithExpertModel(nn.Module):
|
|
| 115 |
else:
|
| 116 |
models = [self.paligemma.language_model, self.gemma_expert.model]
|
| 117 |
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
| 118 |
-
|
| 119 |
# Check if gradient checkpointing is enabled for any of the models
|
| 120 |
use_gradient_checkpointing = (
|
| 121 |
-
hasattr(self.gemma_expert.model,
|
| 122 |
-
self.gemma_expert.model.gradient_checkpointing
|
| 123 |
-
self.training
|
| 124 |
-
) or (
|
| 125 |
-
|
| 126 |
-
self.gradient_checkpointing and
|
| 127 |
-
self.training
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
# Force enable gradient checkpointing if we're in training mode and the model supports it
|
| 131 |
-
if self.training and hasattr(self.gemma_expert.model,
|
| 132 |
if not self.gemma_expert.model.gradient_checkpointing:
|
| 133 |
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
|
| 134 |
self.gemma_expert.model.gradient_checkpointing = True
|
| 135 |
use_gradient_checkpointing = True
|
| 136 |
-
|
| 137 |
# Debug gradient checkpointing status
|
| 138 |
-
if hasattr(self,
|
| 139 |
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
|
| 140 |
print(f"Model training mode: {self.training}")
|
| 141 |
-
print(
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
self._debug_gc_printed = True
|
| 145 |
-
|
| 146 |
# Define the complete layer computation function for gradient checkpointing
|
| 147 |
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
|
| 148 |
models = [self.paligemma.language_model, self.gemma_expert.model]
|
| 149 |
-
|
| 150 |
query_states = []
|
| 151 |
key_states = []
|
| 152 |
value_states = []
|
| 153 |
gates = []
|
| 154 |
for i, hidden_states in enumerate(inputs_embeds):
|
| 155 |
layer = models[i].layers[layer_idx]
|
| 156 |
-
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])
|
| 157 |
gates.append(gate)
|
| 158 |
|
| 159 |
input_shape = hidden_states.shape[:-1]
|
|
@@ -171,16 +182,29 @@ class PaliGemmaWithExpertModel(nn.Module):
|
|
| 171 |
key_states = torch.cat(key_states, dim=2)
|
| 172 |
value_states = torch.cat(value_states, dim=2)
|
| 173 |
|
| 174 |
-
dummy_tensor = torch.zeros(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
| 176 |
-
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
|
|
|
|
|
|
| 177 |
|
| 178 |
batch_size = query_states.shape[0]
|
| 179 |
scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling
|
| 180 |
-
|
| 181 |
# Attention computation
|
| 182 |
att_output, _ = modeling_gemma.eager_attention_forward(
|
| 183 |
-
self.paligemma.language_model.layers[layer_idx].self_attn,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
)
|
| 185 |
# Get head_dim from the current layer, not from the model
|
| 186 |
head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
|
@@ -195,10 +219,10 @@ class PaliGemmaWithExpertModel(nn.Module):
|
|
| 195 |
|
| 196 |
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
| 197 |
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
| 198 |
-
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
| 199 |
|
| 200 |
# first residual
|
| 201 |
-
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])
|
| 202 |
after_first_residual = out_emb.clone()
|
| 203 |
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
| 204 |
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
|
@@ -207,10 +231,10 @@ class PaliGemmaWithExpertModel(nn.Module):
|
|
| 207 |
|
| 208 |
out_emb = layer.mlp(out_emb)
|
| 209 |
# second residual
|
| 210 |
-
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)
|
| 211 |
outputs_embeds.append(out_emb)
|
| 212 |
start_pos = end_pos
|
| 213 |
-
|
| 214 |
return outputs_embeds
|
| 215 |
|
| 216 |
# Process all layers with gradient checkpointing if enabled
|
|
@@ -218,12 +242,18 @@ class PaliGemmaWithExpertModel(nn.Module):
|
|
| 218 |
if use_gradient_checkpointing:
|
| 219 |
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
| 220 |
compute_layer_complete,
|
| 221 |
-
layer_idx,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
use_reentrant=False,
|
| 223 |
-
preserve_rng_state=False
|
| 224 |
)
|
| 225 |
else:
|
| 226 |
-
inputs_embeds = compute_layer_complete(
|
|
|
|
|
|
|
| 227 |
|
| 228 |
# Old code removed - now using compute_layer_complete function above
|
| 229 |
|
|
@@ -235,14 +265,11 @@ class PaliGemmaWithExpertModel(nn.Module):
|
|
| 235 |
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
| 236 |
outputs_embeds.append(out_emb)
|
| 237 |
return outputs_embeds
|
| 238 |
-
|
| 239 |
# Apply gradient checkpointing to final norm if enabled
|
| 240 |
if use_gradient_checkpointing:
|
| 241 |
outputs_embeds = torch.utils.checkpoint.checkpoint(
|
| 242 |
-
compute_final_norms,
|
| 243 |
-
inputs_embeds, adarms_cond,
|
| 244 |
-
use_reentrant=False,
|
| 245 |
-
preserve_rng_state=False
|
| 246 |
)
|
| 247 |
else:
|
| 248 |
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
|
|
@@ -251,4 +278,4 @@ class PaliGemmaWithExpertModel(nn.Module):
|
|
| 251 |
suffix_output = outputs_embeds[1]
|
| 252 |
prefix_past_key_values = None
|
| 253 |
|
| 254 |
-
return [prefix_output, suffix_output], prefix_past_key_values
|
|
|
|
| 1 |
+
from typing import Literal
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
import torch
|
| 5 |
from torch import nn
|
| 6 |
+
from transformers import GemmaForCausalLM
|
| 7 |
+
from transformers import PaliGemmaForConditionalGeneration
|
|
|
|
| 8 |
from transformers.models.auto import CONFIG_MAPPING
|
| 9 |
+
from transformers.models.gemma import modeling_gemma
|
| 10 |
|
| 11 |
|
| 12 |
class PaliGemmaWithExpertModel(nn.Module):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
vlm_config,
|
| 16 |
+
action_expert_config,
|
| 17 |
+
use_adarms=None,
|
| 18 |
+
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
| 19 |
+
):
|
| 20 |
+
if use_adarms is None:
|
| 21 |
+
use_adarms = [False, False]
|
| 22 |
super().__init__()
|
| 23 |
|
| 24 |
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
| 25 |
+
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
| 26 |
vlm_config_hf.image_token_index = 257152
|
| 27 |
vlm_config_hf.text_config.hidden_size = vlm_config.width
|
| 28 |
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
|
|
|
|
| 62 |
|
| 63 |
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
| 64 |
if precision == "bfloat16":
|
| 65 |
+
self.to(dtype=torch.bfloat16)
|
| 66 |
elif precision == "float32":
|
| 67 |
+
self.to(dtype=torch.float32)
|
| 68 |
return
|
| 69 |
else:
|
| 70 |
raise ValueError(f"Invalid precision: {precision}")
|
|
|
|
| 92 |
self,
|
| 93 |
attention_mask: torch.Tensor | None = None,
|
| 94 |
position_ids: torch.LongTensor | None = None,
|
| 95 |
+
past_key_values: list[torch.FloatTensor] | pytest.Cache | None = None,
|
| 96 |
+
inputs_embeds: list[torch.FloatTensor] | None = None,
|
| 97 |
use_cache: bool | None = None,
|
| 98 |
+
adarms_cond: list[torch.Tensor] | None = None,
|
| 99 |
):
|
| 100 |
+
if adarms_cond is None:
|
| 101 |
+
adarms_cond = [None, None]
|
| 102 |
if inputs_embeds[1] is None:
|
| 103 |
prefix_output = self.paligemma.language_model.forward(
|
| 104 |
inputs_embeds=inputs_embeds[0],
|
|
|
|
| 126 |
else:
|
| 127 |
models = [self.paligemma.language_model, self.gemma_expert.model]
|
| 128 |
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
| 129 |
+
|
| 130 |
# Check if gradient checkpointing is enabled for any of the models
|
| 131 |
use_gradient_checkpointing = (
|
| 132 |
+
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
| 133 |
+
and self.gemma_expert.model.gradient_checkpointing
|
| 134 |
+
and self.training
|
| 135 |
+
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
| 136 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
# Force enable gradient checkpointing if we're in training mode and the model supports it
|
| 138 |
+
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
| 139 |
if not self.gemma_expert.model.gradient_checkpointing:
|
| 140 |
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
|
| 141 |
self.gemma_expert.model.gradient_checkpointing = True
|
| 142 |
use_gradient_checkpointing = True
|
| 143 |
+
|
| 144 |
# Debug gradient checkpointing status
|
| 145 |
+
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
|
| 146 |
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
|
| 147 |
print(f"Model training mode: {self.training}")
|
| 148 |
+
print(
|
| 149 |
+
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
|
| 150 |
+
)
|
| 151 |
+
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
| 152 |
+
print(
|
| 153 |
+
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
|
| 154 |
+
)
|
| 155 |
self._debug_gc_printed = True
|
| 156 |
+
|
| 157 |
# Define the complete layer computation function for gradient checkpointing
|
| 158 |
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
|
| 159 |
models = [self.paligemma.language_model, self.gemma_expert.model]
|
| 160 |
+
|
| 161 |
query_states = []
|
| 162 |
key_states = []
|
| 163 |
value_states = []
|
| 164 |
gates = []
|
| 165 |
for i, hidden_states in enumerate(inputs_embeds):
|
| 166 |
layer = models[i].layers[layer_idx]
|
| 167 |
+
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
| 168 |
gates.append(gate)
|
| 169 |
|
| 170 |
input_shape = hidden_states.shape[:-1]
|
|
|
|
| 182 |
key_states = torch.cat(key_states, dim=2)
|
| 183 |
value_states = torch.cat(value_states, dim=2)
|
| 184 |
|
| 185 |
+
dummy_tensor = torch.zeros(
|
| 186 |
+
query_states.shape[0],
|
| 187 |
+
query_states.shape[2],
|
| 188 |
+
query_states.shape[-1],
|
| 189 |
+
device=query_states.device,
|
| 190 |
+
dtype=query_states.dtype,
|
| 191 |
+
)
|
| 192 |
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
| 193 |
+
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
| 194 |
+
query_states, key_states, cos, sin, unsqueeze_dim=1
|
| 195 |
+
)
|
| 196 |
|
| 197 |
batch_size = query_states.shape[0]
|
| 198 |
scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling
|
| 199 |
+
|
| 200 |
# Attention computation
|
| 201 |
att_output, _ = modeling_gemma.eager_attention_forward(
|
| 202 |
+
self.paligemma.language_model.layers[layer_idx].self_attn,
|
| 203 |
+
query_states,
|
| 204 |
+
key_states,
|
| 205 |
+
value_states,
|
| 206 |
+
attention_mask,
|
| 207 |
+
scaling,
|
| 208 |
)
|
| 209 |
# Get head_dim from the current layer, not from the model
|
| 210 |
head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
|
|
|
| 219 |
|
| 220 |
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
| 221 |
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
| 222 |
+
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
| 223 |
|
| 224 |
# first residual
|
| 225 |
+
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
| 226 |
after_first_residual = out_emb.clone()
|
| 227 |
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
| 228 |
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
|
|
|
| 231 |
|
| 232 |
out_emb = layer.mlp(out_emb)
|
| 233 |
# second residual
|
| 234 |
+
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
| 235 |
outputs_embeds.append(out_emb)
|
| 236 |
start_pos = end_pos
|
| 237 |
+
|
| 238 |
return outputs_embeds
|
| 239 |
|
| 240 |
# Process all layers with gradient checkpointing if enabled
|
|
|
|
| 242 |
if use_gradient_checkpointing:
|
| 243 |
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
| 244 |
compute_layer_complete,
|
| 245 |
+
layer_idx,
|
| 246 |
+
inputs_embeds,
|
| 247 |
+
attention_mask,
|
| 248 |
+
position_ids,
|
| 249 |
+
adarms_cond,
|
| 250 |
use_reentrant=False,
|
| 251 |
+
preserve_rng_state=False,
|
| 252 |
)
|
| 253 |
else:
|
| 254 |
+
inputs_embeds = compute_layer_complete(
|
| 255 |
+
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
|
| 256 |
+
)
|
| 257 |
|
| 258 |
# Old code removed - now using compute_layer_complete function above
|
| 259 |
|
|
|
|
| 265 |
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
| 266 |
outputs_embeds.append(out_emb)
|
| 267 |
return outputs_embeds
|
| 268 |
+
|
| 269 |
# Apply gradient checkpointing to final norm if enabled
|
| 270 |
if use_gradient_checkpointing:
|
| 271 |
outputs_embeds = torch.utils.checkpoint.checkpoint(
|
| 272 |
+
compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False
|
|
|
|
|
|
|
|
|
|
| 273 |
)
|
| 274 |
else:
|
| 275 |
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
|
|
|
|
| 278 |
suffix_output = outputs_embeds[1]
|
| 279 |
prefix_past_key_values = None
|
| 280 |
|
| 281 |
+
return [prefix_output, suffix_output], prefix_past_key_values
|
src/openpi/models_pytorch/pi0_pytorch.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
-
import math
|
| 2 |
import logging
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
from torch import Tensor
|
| 6 |
from torch import nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
|
| 9 |
import openpi.models.gemma as _gemma
|
| 10 |
from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
|
|
@@ -17,7 +17,7 @@ def get_safe_dtype(target_dtype, device_type):
|
|
| 17 |
# CPU doesn't support bfloat16, use float32 instead
|
| 18 |
if target_dtype == torch.bfloat16:
|
| 19 |
return torch.float32
|
| 20 |
-
|
| 21 |
return torch.float64
|
| 22 |
return target_dtype
|
| 23 |
|
|
@@ -39,16 +39,14 @@ def create_sinusoidal_pos_embedding(
|
|
| 39 |
# Compute the outer product
|
| 40 |
scaling_factor = 1.0 / period * 2 * math.pi
|
| 41 |
sin_input = scaling_factor[None, :] * time[:, None]
|
| 42 |
-
|
| 43 |
-
return pos_emb
|
| 44 |
|
| 45 |
|
| 46 |
def sample_beta(alpha, beta, bsize, device):
|
| 47 |
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
| 48 |
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
| 49 |
dist = torch.distributions.Beta(alpha_t, beta_t)
|
| 50 |
-
|
| 51 |
-
return samples
|
| 52 |
|
| 53 |
|
| 54 |
def make_att_2d_masks(pad_masks, att_masks):
|
|
@@ -80,8 +78,7 @@ def make_att_2d_masks(pad_masks, att_masks):
|
|
| 80 |
cumsum = torch.cumsum(att_masks, dim=1)
|
| 81 |
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
| 82 |
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
| 83 |
-
|
| 84 |
-
return att_2d_masks
|
| 85 |
|
| 86 |
|
| 87 |
class PI0Pytorch(nn.Module):
|
|
@@ -93,7 +90,12 @@ class PI0Pytorch(nn.Module):
|
|
| 93 |
paligemma_config = _gemma.get_config(config.paligemma_variant)
|
| 94 |
action_expert_config = _gemma.get_config(config.action_expert_variant)
|
| 95 |
|
| 96 |
-
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
self.action_in_proj = nn.Linear(32, action_expert_config.width)
|
| 99 |
self.action_out_proj = nn.Linear(action_expert_config.width, 32)
|
|
@@ -106,17 +108,20 @@ class PI0Pytorch(nn.Module):
|
|
| 106 |
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
|
| 107 |
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
| 108 |
|
| 109 |
-
torch.set_float32_matmul_precision(
|
| 110 |
self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
|
| 111 |
-
|
| 112 |
# Initialize gradient checkpointing flag
|
| 113 |
self.gradient_checkpointing_enabled = False
|
|
|
|
|
|
|
| 114 |
try:
|
| 115 |
from transformers.models.siglip import check
|
|
|
|
| 116 |
if not check.check_whether_transformers_replace_is_installed_correctly():
|
| 117 |
-
raise ValueError(
|
| 118 |
except ImportError:
|
| 119 |
-
raise ValueError(
|
| 120 |
|
| 121 |
def gradient_checkpointing_enable(self):
|
| 122 |
"""Enable gradient checkpointing for memory optimization."""
|
|
@@ -124,7 +129,7 @@ class PI0Pytorch(nn.Module):
|
|
| 124 |
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
| 125 |
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
| 126 |
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
| 127 |
-
|
| 128 |
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
| 129 |
|
| 130 |
def gradient_checkpointing_disable(self):
|
|
@@ -133,7 +138,7 @@ class PI0Pytorch(nn.Module):
|
|
| 133 |
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
| 134 |
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
| 135 |
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
| 136 |
-
|
| 137 |
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
| 138 |
|
| 139 |
def is_gradient_checkpointing_enabled(self):
|
|
@@ -146,15 +151,14 @@ class PI0Pytorch(nn.Module):
|
|
| 146 |
return torch.utils.checkpoint.checkpoint(
|
| 147 |
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
| 148 |
)
|
| 149 |
-
|
| 150 |
-
return func(*args, **kwargs)
|
| 151 |
|
| 152 |
def _prepare_attention_masks_4d(self, att_2d_masks):
|
| 153 |
"""Helper method to prepare 4D attention masks for transformer."""
|
| 154 |
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
| 155 |
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
|
| 156 |
|
| 157 |
-
def _preprocess_observation(self, observation, train=True):
|
| 158 |
"""Helper method to preprocess observation."""
|
| 159 |
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
|
| 160 |
return (
|
|
@@ -162,18 +166,17 @@ class PI0Pytorch(nn.Module):
|
|
| 162 |
list(observation.image_masks.values()),
|
| 163 |
observation.tokenized_prompt,
|
| 164 |
observation.tokenized_prompt_mask,
|
| 165 |
-
observation.state
|
| 166 |
)
|
| 167 |
|
| 168 |
def sample_noise(self, shape, device):
|
| 169 |
-
|
| 170 |
mean=0.0,
|
| 171 |
std=1.0,
|
| 172 |
size=shape,
|
| 173 |
dtype=torch.float32,
|
| 174 |
device=device,
|
| 175 |
)
|
| 176 |
-
return noise
|
| 177 |
|
| 178 |
def sample_time(self, bsize, device):
|
| 179 |
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
|
@@ -189,19 +192,19 @@ class PI0Pytorch(nn.Module):
|
|
| 189 |
embs = []
|
| 190 |
pad_masks = []
|
| 191 |
att_masks = []
|
| 192 |
-
|
| 193 |
# Process images
|
| 194 |
-
for img, img_mask in zip(images, img_masks):
|
|
|
|
| 195 |
def image_embed_func(img):
|
| 196 |
return self.paligemma_with_expert.embed_image(img)
|
| 197 |
-
|
| 198 |
img_emb = self._apply_checkpoint(image_embed_func, img)
|
| 199 |
|
| 200 |
bsize, num_img_embs = img_emb.shape[:2]
|
| 201 |
-
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
| 202 |
|
| 203 |
embs.append(img_emb)
|
| 204 |
-
pad_masks.append(img_mask)
|
| 205 |
|
| 206 |
# Create attention masks so that image tokens attend to each other
|
| 207 |
att_masks += [0] * num_img_embs
|
|
@@ -211,7 +214,7 @@ class PI0Pytorch(nn.Module):
|
|
| 211 |
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
| 212 |
lang_emb_dim = lang_emb.shape[-1]
|
| 213 |
return lang_emb * math.sqrt(lang_emb_dim)
|
| 214 |
-
|
| 215 |
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
| 216 |
|
| 217 |
embs.append(lang_emb)
|
|
@@ -239,16 +242,16 @@ class PI0Pytorch(nn.Module):
|
|
| 239 |
|
| 240 |
if not self.pi05:
|
| 241 |
if self.state_proj.weight.dtype == torch.float32:
|
| 242 |
-
|
|
|
|
| 243 |
# Embed state
|
| 244 |
def state_proj_func(state):
|
| 245 |
return self.state_proj(state)
|
| 246 |
-
|
| 247 |
state_emb = self._apply_checkpoint(state_proj_func, state)
|
| 248 |
-
|
| 249 |
embs.append(state_emb[:, None, :])
|
| 250 |
bsize = state_emb.shape[0]
|
| 251 |
-
dtype = state_emb.dtype
|
| 252 |
device = state_emb.device
|
| 253 |
|
| 254 |
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
|
@@ -266,20 +269,19 @@ class PI0Pytorch(nn.Module):
|
|
| 266 |
# Fuse timestep + action information using an MLP
|
| 267 |
def action_proj_func(noisy_actions):
|
| 268 |
return self.action_in_proj(noisy_actions)
|
| 269 |
-
|
| 270 |
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
| 271 |
|
| 272 |
if not self.pi05:
|
| 273 |
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
| 274 |
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
| 275 |
-
|
| 276 |
# Apply MLP layers
|
| 277 |
def mlp_func(action_time_emb):
|
| 278 |
x = self.action_time_mlp_in(action_time_emb)
|
| 279 |
x = F.silu(x) # swish == silu
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
|
| 284 |
adarms_cond = None
|
| 285 |
else:
|
|
@@ -288,9 +290,8 @@ class PI0Pytorch(nn.Module):
|
|
| 288 |
x = self.time_mlp_in(time_emb)
|
| 289 |
x = F.silu(x) # swish == silu
|
| 290 |
x = self.time_mlp_out(x)
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
|
| 295 |
action_time_emb = action_emb
|
| 296 |
adarms_cond = time_emb
|
|
@@ -328,7 +329,10 @@ class PI0Pytorch(nn.Module):
|
|
| 328 |
|
| 329 |
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
|
| 330 |
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
| 331 |
-
if
|
|
|
|
|
|
|
|
|
|
| 332 |
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
| 333 |
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
| 334 |
|
|
@@ -349,25 +353,24 @@ class PI0Pytorch(nn.Module):
|
|
| 349 |
past_key_values=None,
|
| 350 |
inputs_embeds=[prefix_embs, suffix_embs],
|
| 351 |
use_cache=False,
|
| 352 |
-
adarms_cond=[None, adarms_cond]
|
| 353 |
)
|
| 354 |
return suffix_out
|
| 355 |
-
|
| 356 |
suffix_out = self._apply_checkpoint(
|
| 357 |
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
| 358 |
)
|
| 359 |
-
|
| 360 |
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
| 361 |
suffix_out = suffix_out.to(dtype=torch.float32)
|
| 362 |
|
| 363 |
# Apply gradient checkpointing to final action projection if enabled
|
| 364 |
def action_out_proj_func(suffix_out):
|
| 365 |
return self.action_out_proj(suffix_out)
|
| 366 |
-
|
| 367 |
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
| 368 |
|
| 369 |
-
|
| 370 |
-
return losses
|
| 371 |
|
| 372 |
@torch.no_grad()
|
| 373 |
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
|
|
@@ -376,7 +379,7 @@ class PI0Pytorch(nn.Module):
|
|
| 376 |
if noise is None:
|
| 377 |
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
|
| 378 |
noise = self.sample_noise(actions_shape, device)
|
| 379 |
-
|
| 380 |
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)
|
| 381 |
|
| 382 |
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
|
|
@@ -385,7 +388,7 @@ class PI0Pytorch(nn.Module):
|
|
| 385 |
|
| 386 |
# Compute image and language key value cache
|
| 387 |
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
| 388 |
-
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"
|
| 389 |
|
| 390 |
_, past_key_values = self.paligemma_with_expert.forward(
|
| 391 |
attention_mask=prefix_att_2d_masks_4d,
|
|
@@ -441,7 +444,7 @@ class PI0Pytorch(nn.Module):
|
|
| 441 |
|
| 442 |
# Prepare attention masks
|
| 443 |
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
| 444 |
-
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"
|
| 445 |
|
| 446 |
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
| 447 |
attention_mask=full_att_2d_masks_4d,
|
|
@@ -449,12 +452,10 @@ class PI0Pytorch(nn.Module):
|
|
| 449 |
past_key_values=past_key_values,
|
| 450 |
inputs_embeds=[None, suffix_embs],
|
| 451 |
use_cache=False,
|
| 452 |
-
adarms_cond=[None, adarms_cond]
|
| 453 |
)
|
| 454 |
|
| 455 |
suffix_out = outputs_embeds[1]
|
| 456 |
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
| 457 |
suffix_out = suffix_out.to(dtype=torch.float32)
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
return v_t
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import math
|
| 3 |
|
| 4 |
import torch
|
| 5 |
from torch import Tensor
|
| 6 |
from torch import nn
|
| 7 |
+
import torch.nn.functional as F # noqa: N812
|
| 8 |
|
| 9 |
import openpi.models.gemma as _gemma
|
| 10 |
from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
|
|
|
|
| 17 |
# CPU doesn't support bfloat16, use float32 instead
|
| 18 |
if target_dtype == torch.bfloat16:
|
| 19 |
return torch.float32
|
| 20 |
+
if target_dtype == torch.float64:
|
| 21 |
return torch.float64
|
| 22 |
return target_dtype
|
| 23 |
|
|
|
|
| 39 |
# Compute the outer product
|
| 40 |
scaling_factor = 1.0 / period * 2 * math.pi
|
| 41 |
sin_input = scaling_factor[None, :] * time[:, None]
|
| 42 |
+
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def sample_beta(alpha, beta, bsize, device):
|
| 46 |
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
| 47 |
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
| 48 |
dist = torch.distributions.Beta(alpha_t, beta_t)
|
| 49 |
+
return dist.sample((bsize,))
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
def make_att_2d_masks(pad_masks, att_masks):
|
|
|
|
| 78 |
cumsum = torch.cumsum(att_masks, dim=1)
|
| 79 |
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
| 80 |
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
| 81 |
+
return att_2d_masks & pad_2d_masks
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
class PI0Pytorch(nn.Module):
|
|
|
|
| 90 |
paligemma_config = _gemma.get_config(config.paligemma_variant)
|
| 91 |
action_expert_config = _gemma.get_config(config.action_expert_variant)
|
| 92 |
|
| 93 |
+
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
| 94 |
+
paligemma_config,
|
| 95 |
+
action_expert_config,
|
| 96 |
+
use_adarms=[False, True] if self.pi05 else [False, False],
|
| 97 |
+
precision=config.dtype,
|
| 98 |
+
)
|
| 99 |
|
| 100 |
self.action_in_proj = nn.Linear(32, action_expert_config.width)
|
| 101 |
self.action_out_proj = nn.Linear(action_expert_config.width, 32)
|
|
|
|
| 108 |
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
|
| 109 |
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
| 110 |
|
| 111 |
+
torch.set_float32_matmul_precision("high")
|
| 112 |
self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
|
| 113 |
+
|
| 114 |
# Initialize gradient checkpointing flag
|
| 115 |
self.gradient_checkpointing_enabled = False
|
| 116 |
+
|
| 117 |
+
msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`."
|
| 118 |
try:
|
| 119 |
from transformers.models.siglip import check
|
| 120 |
+
|
| 121 |
if not check.check_whether_transformers_replace_is_installed_correctly():
|
| 122 |
+
raise ValueError(msg)
|
| 123 |
except ImportError:
|
| 124 |
+
raise ValueError(msg) from None
|
| 125 |
|
| 126 |
def gradient_checkpointing_enable(self):
|
| 127 |
"""Enable gradient checkpointing for memory optimization."""
|
|
|
|
| 129 |
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
| 130 |
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
| 131 |
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
| 132 |
+
|
| 133 |
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
| 134 |
|
| 135 |
def gradient_checkpointing_disable(self):
|
|
|
|
| 138 |
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
| 139 |
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
| 140 |
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
| 141 |
+
|
| 142 |
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
| 143 |
|
| 144 |
def is_gradient_checkpointing_enabled(self):
|
|
|
|
| 151 |
return torch.utils.checkpoint.checkpoint(
|
| 152 |
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
| 153 |
)
|
| 154 |
+
return func(*args, **kwargs)
|
|
|
|
| 155 |
|
| 156 |
def _prepare_attention_masks_4d(self, att_2d_masks):
|
| 157 |
"""Helper method to prepare 4D attention masks for transformer."""
|
| 158 |
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
| 159 |
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
|
| 160 |
|
| 161 |
+
def _preprocess_observation(self, observation, *, train=True):
|
| 162 |
"""Helper method to preprocess observation."""
|
| 163 |
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
|
| 164 |
return (
|
|
|
|
| 166 |
list(observation.image_masks.values()),
|
| 167 |
observation.tokenized_prompt,
|
| 168 |
observation.tokenized_prompt_mask,
|
| 169 |
+
observation.state,
|
| 170 |
)
|
| 171 |
|
| 172 |
def sample_noise(self, shape, device):
|
| 173 |
+
return torch.normal(
|
| 174 |
mean=0.0,
|
| 175 |
std=1.0,
|
| 176 |
size=shape,
|
| 177 |
dtype=torch.float32,
|
| 178 |
device=device,
|
| 179 |
)
|
|
|
|
| 180 |
|
| 181 |
def sample_time(self, bsize, device):
|
| 182 |
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
|
|
|
| 192 |
embs = []
|
| 193 |
pad_masks = []
|
| 194 |
att_masks = []
|
| 195 |
+
|
| 196 |
# Process images
|
| 197 |
+
for img, img_mask in zip(images, img_masks, strict=True):
|
| 198 |
+
|
| 199 |
def image_embed_func(img):
|
| 200 |
return self.paligemma_with_expert.embed_image(img)
|
| 201 |
+
|
| 202 |
img_emb = self._apply_checkpoint(image_embed_func, img)
|
| 203 |
|
| 204 |
bsize, num_img_embs = img_emb.shape[:2]
|
|
|
|
| 205 |
|
| 206 |
embs.append(img_emb)
|
| 207 |
+
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
| 208 |
|
| 209 |
# Create attention masks so that image tokens attend to each other
|
| 210 |
att_masks += [0] * num_img_embs
|
|
|
|
| 214 |
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
| 215 |
lang_emb_dim = lang_emb.shape[-1]
|
| 216 |
return lang_emb * math.sqrt(lang_emb_dim)
|
| 217 |
+
|
| 218 |
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
| 219 |
|
| 220 |
embs.append(lang_emb)
|
|
|
|
| 242 |
|
| 243 |
if not self.pi05:
|
| 244 |
if self.state_proj.weight.dtype == torch.float32:
|
| 245 |
+
state = state.to(torch.float32)
|
| 246 |
+
|
| 247 |
# Embed state
|
| 248 |
def state_proj_func(state):
|
| 249 |
return self.state_proj(state)
|
| 250 |
+
|
| 251 |
state_emb = self._apply_checkpoint(state_proj_func, state)
|
| 252 |
+
|
| 253 |
embs.append(state_emb[:, None, :])
|
| 254 |
bsize = state_emb.shape[0]
|
|
|
|
| 255 |
device = state_emb.device
|
| 256 |
|
| 257 |
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
|
|
|
| 269 |
# Fuse timestep + action information using an MLP
|
| 270 |
def action_proj_func(noisy_actions):
|
| 271 |
return self.action_in_proj(noisy_actions)
|
| 272 |
+
|
| 273 |
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
| 274 |
|
| 275 |
if not self.pi05:
|
| 276 |
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
| 277 |
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
| 278 |
+
|
| 279 |
# Apply MLP layers
|
| 280 |
def mlp_func(action_time_emb):
|
| 281 |
x = self.action_time_mlp_in(action_time_emb)
|
| 282 |
x = F.silu(x) # swish == silu
|
| 283 |
+
return self.action_time_mlp_out(x)
|
| 284 |
+
|
|
|
|
| 285 |
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
|
| 286 |
adarms_cond = None
|
| 287 |
else:
|
|
|
|
| 290 |
x = self.time_mlp_in(time_emb)
|
| 291 |
x = F.silu(x) # swish == silu
|
| 292 |
x = self.time_mlp_out(x)
|
| 293 |
+
return F.silu(x)
|
| 294 |
+
|
|
|
|
| 295 |
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
|
| 296 |
action_time_emb = action_emb
|
| 297 |
adarms_cond = time_emb
|
|
|
|
| 329 |
|
| 330 |
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
|
| 331 |
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
| 332 |
+
if (
|
| 333 |
+
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
| 334 |
+
== torch.bfloat16
|
| 335 |
+
):
|
| 336 |
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
| 337 |
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
| 338 |
|
|
|
|
| 353 |
past_key_values=None,
|
| 354 |
inputs_embeds=[prefix_embs, suffix_embs],
|
| 355 |
use_cache=False,
|
| 356 |
+
adarms_cond=[None, adarms_cond],
|
| 357 |
)
|
| 358 |
return suffix_out
|
| 359 |
+
|
| 360 |
suffix_out = self._apply_checkpoint(
|
| 361 |
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
| 362 |
)
|
| 363 |
+
|
| 364 |
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
| 365 |
suffix_out = suffix_out.to(dtype=torch.float32)
|
| 366 |
|
| 367 |
# Apply gradient checkpointing to final action projection if enabled
|
| 368 |
def action_out_proj_func(suffix_out):
|
| 369 |
return self.action_out_proj(suffix_out)
|
| 370 |
+
|
| 371 |
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
| 372 |
|
| 373 |
+
return F.mse_loss(u_t, v_t, reduction="none")
|
|
|
|
| 374 |
|
| 375 |
@torch.no_grad()
|
| 376 |
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
|
|
|
|
| 379 |
if noise is None:
|
| 380 |
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
|
| 381 |
noise = self.sample_noise(actions_shape, device)
|
| 382 |
+
|
| 383 |
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)
|
| 384 |
|
| 385 |
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
|
|
|
|
| 388 |
|
| 389 |
# Compute image and language key value cache
|
| 390 |
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
| 391 |
+
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
| 392 |
|
| 393 |
_, past_key_values = self.paligemma_with_expert.forward(
|
| 394 |
attention_mask=prefix_att_2d_masks_4d,
|
|
|
|
| 444 |
|
| 445 |
# Prepare attention masks
|
| 446 |
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
| 447 |
+
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
| 448 |
|
| 449 |
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
| 450 |
attention_mask=full_att_2d_masks_4d,
|
|
|
|
| 452 |
past_key_values=past_key_values,
|
| 453 |
inputs_embeds=[None, suffix_embs],
|
| 454 |
use_cache=False,
|
| 455 |
+
adarms_cond=[None, adarms_cond],
|
| 456 |
)
|
| 457 |
|
| 458 |
suffix_out = outputs_embeds[1]
|
| 459 |
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
| 460 |
suffix_out = suffix_out.to(dtype=torch.float32)
|
| 461 |
+
return self.action_out_proj(suffix_out)
|
|
|
|
|
|
src/openpi/models_pytorch/preprocessing_pytorch.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
-
import logging
|
| 2 |
from collections.abc import Sequence
|
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
|
| 5 |
from openpi.shared import image_tools
|
|
@@ -15,6 +16,7 @@ IMAGE_KEYS = (
|
|
| 15 |
|
| 16 |
IMAGE_RESOLUTION = (224, 224)
|
| 17 |
|
|
|
|
| 18 |
def preprocess_observation_pytorch(
|
| 19 |
observation,
|
| 20 |
*,
|
|
@@ -23,7 +25,7 @@ def preprocess_observation_pytorch(
|
|
| 23 |
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
|
| 24 |
):
|
| 25 |
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
|
| 26 |
-
|
| 27 |
This function avoids complex type annotations that can cause torch.compile issues.
|
| 28 |
"""
|
| 29 |
if not set(image_keys).issubset(observation.images):
|
|
@@ -67,14 +69,14 @@ def preprocess_observation_pytorch(
|
|
| 67 |
# Use tensor operations instead of .item() for torch.compile compatibility
|
| 68 |
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
|
| 69 |
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
|
| 70 |
-
image = image[:, start_h:start_h + crop_height, start_w:start_w + crop_width, :]
|
| 71 |
|
| 72 |
# Resize back to original size
|
| 73 |
image = torch.nn.functional.interpolate(
|
| 74 |
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
| 75 |
size=(height, width),
|
| 76 |
-
mode=
|
| 77 |
-
align_corners=False
|
| 78 |
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
| 79 |
|
| 80 |
# Random rotation (small angles)
|
|
@@ -93,7 +95,7 @@ def preprocess_observation_pytorch(
|
|
| 93 |
grid_y = torch.linspace(-1, 1, height, device=image.device)
|
| 94 |
|
| 95 |
# Create meshgrid
|
| 96 |
-
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing=
|
| 97 |
|
| 98 |
# Expand to batch dimension
|
| 99 |
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
|
|
@@ -109,9 +111,9 @@ def preprocess_observation_pytorch(
|
|
| 109 |
image = torch.nn.functional.grid_sample(
|
| 110 |
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
| 111 |
grid,
|
| 112 |
-
mode=
|
| 113 |
-
padding_mode=
|
| 114 |
-
align_corners=False
|
| 115 |
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
| 116 |
|
| 117 |
# Color augmentations for all cameras
|
|
@@ -159,7 +161,7 @@ def preprocess_observation_pytorch(
|
|
| 159 |
def __init__(self, **kwargs):
|
| 160 |
for key, value in kwargs.items():
|
| 161 |
setattr(self, key, value)
|
| 162 |
-
|
| 163 |
return SimpleProcessedObservation(
|
| 164 |
images=out_images,
|
| 165 |
image_masks=out_masks,
|
|
|
|
|
|
|
| 1 |
from collections.abc import Sequence
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
import torch
|
| 5 |
|
| 6 |
from openpi.shared import image_tools
|
|
|
|
| 16 |
|
| 17 |
IMAGE_RESOLUTION = (224, 224)
|
| 18 |
|
| 19 |
+
|
| 20 |
def preprocess_observation_pytorch(
|
| 21 |
observation,
|
| 22 |
*,
|
|
|
|
| 25 |
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
|
| 26 |
):
|
| 27 |
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
|
| 28 |
+
|
| 29 |
This function avoids complex type annotations that can cause torch.compile issues.
|
| 30 |
"""
|
| 31 |
if not set(image_keys).issubset(observation.images):
|
|
|
|
| 69 |
# Use tensor operations instead of .item() for torch.compile compatibility
|
| 70 |
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
|
| 71 |
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
|
| 72 |
+
image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
|
| 73 |
|
| 74 |
# Resize back to original size
|
| 75 |
image = torch.nn.functional.interpolate(
|
| 76 |
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
| 77 |
size=(height, width),
|
| 78 |
+
mode="bilinear",
|
| 79 |
+
align_corners=False,
|
| 80 |
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
| 81 |
|
| 82 |
# Random rotation (small angles)
|
|
|
|
| 95 |
grid_y = torch.linspace(-1, 1, height, device=image.device)
|
| 96 |
|
| 97 |
# Create meshgrid
|
| 98 |
+
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
|
| 99 |
|
| 100 |
# Expand to batch dimension
|
| 101 |
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
|
|
|
|
| 111 |
image = torch.nn.functional.grid_sample(
|
| 112 |
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
| 113 |
grid,
|
| 114 |
+
mode="bilinear",
|
| 115 |
+
padding_mode="zeros",
|
| 116 |
+
align_corners=False,
|
| 117 |
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
| 118 |
|
| 119 |
# Color augmentations for all cameras
|
|
|
|
| 161 |
def __init__(self, **kwargs):
|
| 162 |
for key, value in kwargs.items():
|
| 163 |
setattr(self, key, value)
|
| 164 |
+
|
| 165 |
return SimpleProcessedObservation(
|
| 166 |
images=out_images,
|
| 167 |
image_masks=out_masks,
|
src/openpi/policies/policy.py
CHANGED
|
@@ -35,7 +35,7 @@ class Policy(BasePolicy):
|
|
| 35 |
is_pytorch: bool = False,
|
| 36 |
):
|
| 37 |
"""Initialize the Policy.
|
| 38 |
-
|
| 39 |
Args:
|
| 40 |
model: The model to use for action sampling.
|
| 41 |
rng: Random number generator key for JAX models. Ignored for PyTorch models.
|
|
@@ -43,7 +43,7 @@ class Policy(BasePolicy):
|
|
| 43 |
output_transforms: Output data transformations to apply after inference.
|
| 44 |
sample_kwargs: Additional keyword arguments to pass to model.sample_actions.
|
| 45 |
metadata: Additional metadata to store with the policy.
|
| 46 |
-
pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0").
|
| 47 |
Only relevant when is_pytorch=True.
|
| 48 |
is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.
|
| 49 |
"""
|
|
@@ -81,10 +81,7 @@ class Policy(BasePolicy):
|
|
| 81 |
# Prepare kwargs for sample_actions
|
| 82 |
sample_kwargs = dict(self._sample_kwargs)
|
| 83 |
if noise is not None:
|
| 84 |
-
if self._is_pytorch_model
|
| 85 |
-
noise = torch.from_numpy(noise).to(self._pytorch_device)
|
| 86 |
-
else:
|
| 87 |
-
noise = jnp.asarray(noise)
|
| 88 |
|
| 89 |
if noise.ndim == 2: # If noise is (action_horizon, action_dim), add batch dimension
|
| 90 |
noise = noise[None, ...] # Make it (1, action_horizon, action_dim)
|
|
|
|
| 35 |
is_pytorch: bool = False,
|
| 36 |
):
|
| 37 |
"""Initialize the Policy.
|
| 38 |
+
|
| 39 |
Args:
|
| 40 |
model: The model to use for action sampling.
|
| 41 |
rng: Random number generator key for JAX models. Ignored for PyTorch models.
|
|
|
|
| 43 |
output_transforms: Output data transformations to apply after inference.
|
| 44 |
sample_kwargs: Additional keyword arguments to pass to model.sample_actions.
|
| 45 |
metadata: Additional metadata to store with the policy.
|
| 46 |
+
pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0").
|
| 47 |
Only relevant when is_pytorch=True.
|
| 48 |
is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.
|
| 49 |
"""
|
|
|
|
| 81 |
# Prepare kwargs for sample_actions
|
| 82 |
sample_kwargs = dict(self._sample_kwargs)
|
| 83 |
if noise is not None:
|
| 84 |
+
noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise)
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
if noise.ndim == 2: # If noise is (action_horizon, action_dim), add batch dimension
|
| 87 |
noise = noise[None, ...] # Make it (1, action_horizon, action_dim)
|
src/openpi/policies/policy_config.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import logging
|
| 2 |
-
import pathlib
|
| 3 |
import os
|
|
|
|
| 4 |
from typing import Any
|
| 5 |
|
| 6 |
import jax.numpy as jnp
|
|
@@ -35,9 +35,9 @@ def create_trained_policy(
|
|
| 35 |
data if it doesn't already exist.
|
| 36 |
norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
|
| 37 |
from the checkpoint directory.
|
| 38 |
-
pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
|
| 39 |
If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu".
|
| 40 |
-
|
| 41 |
Note:
|
| 42 |
The function automatically detects whether the model is PyTorch-based by checking for the
|
| 43 |
presence of "model.safensors" in the checkpoint directory.
|
|
@@ -52,7 +52,7 @@ def create_trained_policy(
|
|
| 52 |
logging.info("Loading model...")
|
| 53 |
if is_pytorch:
|
| 54 |
model = train_config.model.load_pytorch(train_config, weight_path)
|
| 55 |
-
model.paligemma_with_expert.to_bfloat16_for_selected_params(
|
| 56 |
else:
|
| 57 |
model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
|
| 58 |
data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
|
|
@@ -67,13 +67,11 @@ def create_trained_policy(
|
|
| 67 |
if is_pytorch and pytorch_device is None:
|
| 68 |
try:
|
| 69 |
import torch
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
else:
|
| 73 |
-
pytorch_device = "cpu"
|
| 74 |
except ImportError:
|
| 75 |
pytorch_device = "cpu"
|
| 76 |
-
|
| 77 |
return _policy.Policy(
|
| 78 |
model,
|
| 79 |
transforms=[
|
|
|
|
| 1 |
import logging
|
|
|
|
| 2 |
import os
|
| 3 |
+
import pathlib
|
| 4 |
from typing import Any
|
| 5 |
|
| 6 |
import jax.numpy as jnp
|
|
|
|
| 35 |
data if it doesn't already exist.
|
| 36 |
norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
|
| 37 |
from the checkpoint directory.
|
| 38 |
+
pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
|
| 39 |
If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu".
|
| 40 |
+
|
| 41 |
Note:
|
| 42 |
The function automatically detects whether the model is PyTorch-based by checking for the
|
| 43 |
presence of "model.safensors" in the checkpoint directory.
|
|
|
|
| 52 |
logging.info("Loading model...")
|
| 53 |
if is_pytorch:
|
| 54 |
model = train_config.model.load_pytorch(train_config, weight_path)
|
| 55 |
+
model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
|
| 56 |
else:
|
| 57 |
model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
|
| 58 |
data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
|
|
|
|
| 67 |
if is_pytorch and pytorch_device is None:
|
| 68 |
try:
|
| 69 |
import torch
|
| 70 |
+
|
| 71 |
+
pytorch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
| 72 |
except ImportError:
|
| 73 |
pytorch_device = "cpu"
|
| 74 |
+
|
| 75 |
return _policy.Policy(
|
| 76 |
model,
|
| 77 |
transforms=[
|
src/openpi/shared/array_typing.py
CHANGED
|
@@ -7,7 +7,6 @@ import beartype
|
|
| 7 |
import jax
|
| 8 |
import jax._src.tree_util as private_tree_util
|
| 9 |
import jax.core
|
| 10 |
-
from jaxtyping import Array # noqa: F401
|
| 11 |
from jaxtyping import ArrayLike
|
| 12 |
from jaxtyping import Bool # noqa: F401
|
| 13 |
from jaxtyping import DTypeLike # noqa: F401
|
|
@@ -31,6 +30,7 @@ _original_check_dataclass_annotations = jaxtyping._decorator._check_dataclass_an
|
|
| 31 |
# Redefine Array to include both JAX arrays and PyTorch tensors
|
| 32 |
Array = jax.Array | torch.Tensor
|
| 33 |
|
|
|
|
| 34 |
def _check_dataclass_annotations(self, typechecker):
|
| 35 |
if not any(
|
| 36 |
frame.frame.f_globals.get("__name__") in {"jax._src.tree_util", "flax.nnx.transforms.compilation"}
|
|
|
|
| 7 |
import jax
|
| 8 |
import jax._src.tree_util as private_tree_util
|
| 9 |
import jax.core
|
|
|
|
| 10 |
from jaxtyping import ArrayLike
|
| 11 |
from jaxtyping import Bool # noqa: F401
|
| 12 |
from jaxtyping import DTypeLike # noqa: F401
|
|
|
|
| 30 |
# Redefine Array to include both JAX arrays and PyTorch tensors
|
| 31 |
Array = jax.Array | torch.Tensor
|
| 32 |
|
| 33 |
+
|
| 34 |
def _check_dataclass_annotations(self, typechecker):
|
| 35 |
if not any(
|
| 36 |
frame.frame.f_globals.get("__name__") in {"jax._src.tree_util", "flax.nnx.transforms.compilation"}
|
src/openpi/shared/image_tools.py
CHANGED
|
@@ -3,7 +3,7 @@ import functools
|
|
| 3 |
import jax
|
| 4 |
import jax.numpy as jnp
|
| 5 |
import torch
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
|
| 8 |
import openpi.shared.array_typing as at
|
| 9 |
|
|
@@ -60,13 +60,13 @@ def resize_with_pad_torch(
|
|
| 60 |
) -> torch.Tensor:
|
| 61 |
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
|
| 62 |
by padding with black. If the image is float32, it must be in the range [-1, 1].
|
| 63 |
-
|
| 64 |
Args:
|
| 65 |
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
|
| 66 |
height: Target height
|
| 67 |
width: Target width
|
| 68 |
mode: Interpolation mode ('bilinear', 'nearest', etc.)
|
| 69 |
-
|
| 70 |
Returns:
|
| 71 |
Resized and padded tensor with same shape format as input
|
| 72 |
"""
|
|
@@ -91,10 +91,7 @@ def resize_with_pad_torch(
|
|
| 91 |
|
| 92 |
# Resize
|
| 93 |
resized_images = F.interpolate(
|
| 94 |
-
images,
|
| 95 |
-
size=(resized_height, resized_width),
|
| 96 |
-
mode=mode,
|
| 97 |
-
align_corners=False if mode == "bilinear" else None
|
| 98 |
)
|
| 99 |
|
| 100 |
# Handle dtype-specific clipping
|
|
@@ -116,8 +113,8 @@ def resize_with_pad_torch(
|
|
| 116 |
padded_images = F.pad(
|
| 117 |
resized_images,
|
| 118 |
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
| 119 |
-
mode=
|
| 120 |
-
value=constant_value
|
| 121 |
)
|
| 122 |
|
| 123 |
# Convert back to original format if needed
|
|
@@ -126,4 +123,4 @@ def resize_with_pad_torch(
|
|
| 126 |
if batch_size == 1 and images.shape[0] == 1:
|
| 127 |
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
|
| 128 |
|
| 129 |
-
return padded_images
|
|
|
|
| 3 |
import jax
|
| 4 |
import jax.numpy as jnp
|
| 5 |
import torch
|
| 6 |
+
import torch.nn.functional as F # noqa: N812
|
| 7 |
|
| 8 |
import openpi.shared.array_typing as at
|
| 9 |
|
|
|
|
| 60 |
) -> torch.Tensor:
|
| 61 |
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
|
| 62 |
by padding with black. If the image is float32, it must be in the range [-1, 1].
|
| 63 |
+
|
| 64 |
Args:
|
| 65 |
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
|
| 66 |
height: Target height
|
| 67 |
width: Target width
|
| 68 |
mode: Interpolation mode ('bilinear', 'nearest', etc.)
|
| 69 |
+
|
| 70 |
Returns:
|
| 71 |
Resized and padded tensor with same shape format as input
|
| 72 |
"""
|
|
|
|
| 91 |
|
| 92 |
# Resize
|
| 93 |
resized_images = F.interpolate(
|
| 94 |
+
images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None
|
|
|
|
|
|
|
|
|
|
| 95 |
)
|
| 96 |
|
| 97 |
# Handle dtype-specific clipping
|
|
|
|
| 113 |
padded_images = F.pad(
|
| 114 |
resized_images,
|
| 115 |
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
| 116 |
+
mode="constant",
|
| 117 |
+
value=constant_value,
|
| 118 |
)
|
| 119 |
|
| 120 |
# Convert back to original format if needed
|
|
|
|
| 123 |
if batch_size == 1 and images.shape[0] == 1:
|
| 124 |
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
|
| 125 |
|
| 126 |
+
return padded_images
|
src/openpi/training/config.py
CHANGED
|
@@ -6,7 +6,7 @@ import dataclasses
|
|
| 6 |
import difflib
|
| 7 |
import logging
|
| 8 |
import pathlib
|
| 9 |
-
from typing import Any,
|
| 10 |
|
| 11 |
import etils.epath as epath
|
| 12 |
import flax.nnx as nnx
|
|
@@ -623,7 +623,7 @@ _CONFIGS = [
|
|
| 623 |
data=SimpleDataConfig(
|
| 624 |
assets=AssetsConfig(asset_id="droid"),
|
| 625 |
data_transforms=lambda model: _transforms.Group(
|
| 626 |
-
inputs=[droid_policy.DroidInputs(
|
| 627 |
outputs=[droid_policy.DroidOutputs()],
|
| 628 |
),
|
| 629 |
base_config=DataConfig(
|
|
|
|
| 6 |
import difflib
|
| 7 |
import logging
|
| 8 |
import pathlib
|
| 9 |
+
from typing import Any, Literal, Protocol, TypeAlias
|
| 10 |
|
| 11 |
import etils.epath as epath
|
| 12 |
import flax.nnx as nnx
|
|
|
|
| 623 |
data=SimpleDataConfig(
|
| 624 |
assets=AssetsConfig(asset_id="droid"),
|
| 625 |
data_transforms=lambda model: _transforms.Group(
|
| 626 |
+
inputs=[droid_policy.DroidInputs(model_type=ModelType.PI05)],
|
| 627 |
outputs=[droid_policy.DroidOutputs()],
|
| 628 |
),
|
| 629 |
base_config=DataConfig(
|
src/openpi/training/data_loader.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
| 1 |
from collections.abc import Iterator, Sequence
|
| 2 |
-
|
| 3 |
import multiprocessing
|
| 4 |
import os
|
| 5 |
import typing
|
| 6 |
-
from typing import Protocol, SupportsIndex, TypeVar
|
| 7 |
|
| 8 |
import jax
|
| 9 |
import jax.numpy as jnp
|
| 10 |
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
|
| 11 |
-
import logging
|
| 12 |
import numpy as np
|
| 13 |
import torch
|
| 14 |
|
|
@@ -231,7 +230,7 @@ def create_data_loader(
|
|
| 231 |
framework: Literal["jax", "pytorch"],
|
| 232 |
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
|
| 233 |
"""Create a data loader for training.
|
| 234 |
-
|
| 235 |
Args:
|
| 236 |
config: The training configuration.
|
| 237 |
sharding: The sharding to use for the data loader (JAX only).
|
|
@@ -367,22 +366,21 @@ def create_rlds_data_loader(
|
|
| 367 |
"""
|
| 368 |
if framework == "pytorch":
|
| 369 |
raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
|
| 373 |
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
|
| 380 |
return DataLoaderImpl(data_config, data_loader)
|
| 381 |
|
| 382 |
|
| 383 |
class TorchDataLoader:
|
| 384 |
"""Torch data loader implementation."""
|
| 385 |
-
|
| 386 |
def __init__(
|
| 387 |
self,
|
| 388 |
dataset,
|
|
|
|
| 1 |
from collections.abc import Iterator, Sequence
|
| 2 |
+
import logging
|
| 3 |
import multiprocessing
|
| 4 |
import os
|
| 5 |
import typing
|
| 6 |
+
from typing import Literal, Protocol, SupportsIndex, TypeVar
|
| 7 |
|
| 8 |
import jax
|
| 9 |
import jax.numpy as jnp
|
| 10 |
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
| 13 |
|
|
|
|
| 230 |
framework: Literal["jax", "pytorch"],
|
| 231 |
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
|
| 232 |
"""Create a data loader for training.
|
| 233 |
+
|
| 234 |
Args:
|
| 235 |
config: The training configuration.
|
| 236 |
sharding: The sharding to use for the data loader (JAX only).
|
|
|
|
| 366 |
"""
|
| 367 |
if framework == "pytorch":
|
| 368 |
raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
|
| 369 |
+
dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
|
| 370 |
+
dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
|
|
|
|
| 371 |
|
| 372 |
+
data_loader = RLDSDataLoader(
|
| 373 |
+
dataset,
|
| 374 |
+
sharding=sharding,
|
| 375 |
+
num_batches=num_batches,
|
| 376 |
+
)
|
| 377 |
|
| 378 |
return DataLoaderImpl(data_config, data_loader)
|
| 379 |
|
| 380 |
|
| 381 |
class TorchDataLoader:
|
| 382 |
"""Torch data loader implementation."""
|
| 383 |
+
|
| 384 |
def __init__(
|
| 385 |
self,
|
| 386 |
dataset,
|