# Source: pi05tests-openpi-multiarm / openpi/scripts/init_parallel_pi05_from_single_pytorch.py
# Uploaded by lsnu via the upload-large-folder tool (commit ccf25b1, verified).
#!/usr/bin/env python3
import dataclasses
import json
import os
from pathlib import Path
import safetensors.torch
import torch
import torch.nn.functional as F # noqa: N812
import tyro
import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
import openpi.training.config as _config
@dataclasses.dataclass
class Args:
    """Command-line arguments for the conversion script (parsed by tyro)."""

    # Path to the single-arm checkpoint directory; must contain model.safetensors.
    single_ckpt: str
    # Name of the openpi training config describing the target parallel model.
    config_name: str
    # Directory where the converted checkpoint and JSON metadata are written.
    output_path: str
def _build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Derive a ``Pi0Config`` carrying the PyTorch training dtype from a train config.

    If ``config.model`` is already a ``Pi0Config``, return a copy with ``dtype``
    overridden by ``config.pytorch_training_precision``; otherwise build a fresh
    ``Pi0Config`` from the model config's attributes.

    Args:
        config: The training config selected via ``--config-name``.

    Returns:
        A ``Pi0Config`` whose ``dtype`` matches the PyTorch training precision.
    """
    if isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # Fix: the original mutated the shared (frozen) config object in place
        # via object.__setattr__; return an independent copy instead so the
        # caller's config is left untouched.
        return dataclasses.replace(config.model, dtype=config.pytorch_training_precision)
    return openpi.models.pi0_config.Pi0Config(
        dtype=config.pytorch_training_precision,
        action_dim=config.model.action_dim,
        action_horizon=config.model.action_horizon,
        max_token_len=config.model.max_token_len,
        # getattr defaults cover model configs that predate these fields.
        paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
        action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
        pi05=getattr(config.model, "pi05", False),
        arm_action_dims=getattr(config.model, "arm_action_dims", None),
        action_expert_mode=getattr(config.model, "action_expert_mode", None),
    )
def _copy_factorized_heads(model, weight_in, bias_in, weight_out, bias_out) -> None:
hidden_width = weight_in.shape[0]
with torch.no_grad():
model.action_in_proj_arms[0].weight.copy_(weight_in[:, 0:16])
model.action_in_proj_arms[0].bias.zero_()
model.action_in_proj_arms[1].weight.copy_(weight_in[:, 16:32])
model.action_in_proj_arms[1].bias.zero_()
if hasattr(model, "arm_token_fuse"):
fuse_weight = torch.zeros_like(model.arm_token_fuse.weight)
identity = torch.eye(hidden_width, dtype=fuse_weight.dtype)
fuse_weight[:, 0:hidden_width] = identity
fuse_weight[:, hidden_width : 2 * hidden_width] = identity
model.arm_token_fuse.weight.copy_(fuse_weight)
model.arm_token_fuse.bias.copy_(bias_in)
model.action_out_proj_arms[0].weight.copy_(weight_out[0:16, :])
model.action_out_proj_arms[0].bias.copy_(bias_out[0:16])
model.action_out_proj_arms[1].weight.copy_(weight_out[16:32, :])
model.action_out_proj_arms[1].bias.copy_(bias_out[16:32])
def _copy_split_expert_weights(model, single_state) -> None:
model_state = model.state_dict()
with torch.no_grad():
for key, value in single_state.items():
if not key.startswith("paligemma_with_expert.gemma_expert."):
continue
suffix = key.removeprefix("paligemma_with_expert.gemma_expert.")
left_key = f"paligemma_with_expert.left_gemma_expert.{suffix}"
right_key = f"paligemma_with_expert.right_gemma_expert.{suffix}"
model_state[left_key].copy_(value.to(dtype=model_state[left_key].dtype))
model_state[right_key].copy_(value.to(dtype=model_state[right_key].dtype))
def _expert_copy_max_abs_diff(model, single_state, target_prefix: str) -> float:
model_state = model.state_dict()
max_abs_diff = 0.0
for key, value in single_state.items():
if not key.startswith("paligemma_with_expert.gemma_expert."):
continue
suffix = key.removeprefix("paligemma_with_expert.gemma_expert.")
target_key = f"{target_prefix}{suffix}"
diff = (model_state[target_key].to(torch.float32) - value.to(torch.float32)).abs().max().item()
max_abs_diff = max(max_abs_diff, float(diff))
return max_abs_diff
def main() -> None:
    """Warm-start a parallel two-arm pi0/pi0.5 model from a single-arm checkpoint.

    Steps: load the single-head safetensors checkpoint, copy its packed 32-dim
    action projections into the per-arm factorized heads (and, for split-expert
    configs, clone the shared expert into both arm towers), numerically verify
    the warm start with random probes, then write the converted checkpoint,
    model config, and verification metadata to ``--output-path``.
    """
    args = tyro.cli(Args)
    config = _config.get_config(args.config_name)
    model_cfg = _build_model_config(config)
    # Only configs with factorized or split action heads can be warm-started here.
    if not model_cfg.use_parallel_action_heads:
        raise ValueError(f"Config {args.config_name} does not use factorized or split action heads.")
    # All slicing below hard-codes a 16+16 split of the packed 32-dim actions.
    if tuple(model_cfg.arm_action_dims) != (16, 16):
        raise ValueError(f"Expected arm_action_dims=(16, 16), got {model_cfg.arm_action_dims}.")
    parallel_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg)
    single_state = safetensors.torch.load_file(os.path.join(args.single_ckpt, "model.safetensors"), device="cpu")
    # Non-strict load: shared weights transfer directly; parallel-only modules
    # keep their init values and show up in `missing`/`unexpected` for the log.
    missing, unexpected = parallel_model.load_state_dict(single_state, strict=False)
    weight_in = single_state["action_in_proj.weight"]
    bias_in = single_state["action_in_proj.bias"]
    weight_out = single_state["action_out_proj.weight"]
    bias_out = single_state["action_out_proj.bias"]
    # Hidden width of the action projection (rows of the input head).
    hidden_width = weight_in.shape[0]
    if weight_in.shape[1] != 32 or weight_out.shape[0] != 32:
        raise ValueError(
            f"Expected single-head checkpoint with packed 32-dim actions, got in={tuple(weight_in.shape)} out={tuple(weight_out.shape)}."
        )
    _copy_factorized_heads(parallel_model, weight_in, bias_in, weight_out, bias_out)
    if model_cfg.use_split_action_expert:
        _copy_split_expert_weights(parallel_model, single_state)
    proj_in_dtype = parallel_model.action_in_proj_arms[0].weight.dtype
    proj_out_dtype = parallel_model.action_out_proj_arms[0].weight.dtype
    # Random probes used to verify the copied projections numerically.
    x = torch.randn(2, model_cfg.action_horizon, model_cfg.action_dim, dtype=proj_in_dtype)
    x_left = x[:, :, 0:16]
    x_right = x[:, :, 16:32]
    # Probe for the output heads (hidden-width features per action step).
    suffix = torch.randn(2, model_cfg.action_horizon, hidden_width, dtype=proj_out_dtype)
    metadata = {
        "config_name": args.config_name,
        "action_expert_mode": model_cfg.action_expert_mode,
        "single_ckpt": args.single_ckpt,
        "output_path": args.output_path,
        "load_state_missing_keys": list(missing),
        "load_state_unexpected_keys": list(unexpected),
    }
    with torch.no_grad():
        # Each arm's input head must match the corresponding 16-column slice of
        # the single head. bias=None in the reference because the arm biases
        # were zeroed (the original bias is folded into arm_token_fuse).
        left_input_projection_max_abs_diff = float(
            (
                F.linear(x_left, weight_in[:, 0:16].to(proj_in_dtype), None)
                - parallel_model.action_in_proj_arms[0](x_left)
            )
            .abs()
            .max()
            .item()
        )
        right_input_projection_max_abs_diff = float(
            (
                F.linear(x_right, weight_in[:, 16:32].to(proj_in_dtype), None)
                - parallel_model.action_in_proj_arms[1](x_right)
            )
            .abs()
            .max()
            .item()
        )
        # Each arm's output head must match the corresponding 16-row slice of
        # the single output head, bias included.
        left_output_projection_max_abs_diff = float(
            (
                F.linear(suffix, weight_out[0:16, :].to(proj_out_dtype), bias_out[0:16].to(proj_out_dtype))
                - parallel_model.action_out_proj_arms[0](suffix)
            )
            .abs()
            .max()
            .item()
        )
        right_output_projection_max_abs_diff = float(
            (
                F.linear(suffix, weight_out[16:32, :].to(proj_out_dtype), bias_out[16:32].to(proj_out_dtype))
                - parallel_model.action_out_proj_arms[1](suffix)
            )
            .abs()
            .max()
            .item()
        )
        metadata.update(
            {
                "left_input_projection_max_abs_diff": left_input_projection_max_abs_diff,
                "right_input_projection_max_abs_diff": right_input_projection_max_abs_diff,
                "left_output_projection_max_abs_diff": left_output_projection_max_abs_diff,
                "right_output_projection_max_abs_diff": right_output_projection_max_abs_diff,
            }
        )
        if model_cfg.action_expert_mode == "head_only_parallel":
            # Shared-expert mode: the fused per-arm path must reproduce the
            # original packed projections end to end. NOTE(review): this calls
            # the model's private _project_action_* helpers — presumably the
            # full fused input/output projection path; verify against
            # PI0Pytorch before relying on this.
            input_max_abs_diff = float(
                (
                    F.linear(x, weight_in.to(proj_in_dtype), bias_in.to(proj_in_dtype))
                    - parallel_model._project_action_inputs(x)
                )
                .abs()
                .max()
                .item()
            )
            output_max_abs_diff = float(
                (
                    F.linear(suffix, weight_out.to(proj_out_dtype), bias_out.to(proj_out_dtype))
                    - parallel_model._project_action_outputs(suffix)
                )
                .abs()
                .max()
                .item()
            )
            metadata["input_projection_max_abs_diff"] = input_max_abs_diff
            metadata["output_projection_max_abs_diff"] = output_max_abs_diff
            metadata["warm_start_exact"] = input_max_abs_diff == 0.0 and output_max_abs_diff == 0.0
        else:
            # Split-expert mode: verify both arm towers are exact copies of the
            # single expert from the checkpoint.
            left_expert_max_abs_diff = _expert_copy_max_abs_diff(
                parallel_model,
                single_state,
                "paligemma_with_expert.left_gemma_expert.",
            )
            right_expert_max_abs_diff = _expert_copy_max_abs_diff(
                parallel_model,
                single_state,
                "paligemma_with_expert.right_gemma_expert.",
            )
            metadata["left_expert_max_abs_diff"] = left_expert_max_abs_diff
            metadata["right_expert_max_abs_diff"] = right_expert_max_abs_diff
            # Record the initial cross-arm communication parameters, if present,
            # so the starting point is reproducible from the metadata.
            if parallel_model.paligemma_with_expert.cross_arm_comm is not None:
                metadata["cross_arm_comm_init"] = [
                    float(value) for value in parallel_model.paligemma_with_expert.cross_arm_comm.detach().cpu().tolist()
                ]
            metadata["warm_start_exact"] = (
                left_input_projection_max_abs_diff == 0.0
                and right_input_projection_max_abs_diff == 0.0
                and left_output_projection_max_abs_diff == 0.0
                and right_output_projection_max_abs_diff == 0.0
                and left_expert_max_abs_diff == 0.0
                and right_expert_max_abs_diff == 0.0
            )
    # Persist checkpoint, model config, and verification metadata.
    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    safetensors.torch.save_model(parallel_model, output_dir / "model.safetensors")
    (output_dir / "config.json").write_text(json.dumps(dataclasses.asdict(model_cfg), indent=2, sort_keys=True))
    (output_dir / "init_parallel_metadata.json").write_text(json.dumps(metadata, indent=2, sort_keys=True))
    # Echo a summary to stdout; metadata keys already printed explicitly above
    # are skipped in the loop below.
    print(f"config_name: {args.config_name}")
    print(f"action_expert_mode: {model_cfg.action_expert_mode}")
    print(f"single_ckpt: {args.single_ckpt}")
    print(f"output_path: {args.output_path}")
    print(f"load_state_missing_keys_count: {len(missing)}")
    print(f"load_state_missing_keys: {list(missing)}")
    print(f"load_state_unexpected_keys_count: {len(unexpected)}")
    print(f"load_state_unexpected_keys: {list(unexpected)}")
    for key in sorted(metadata):
        if key in {"config_name", "action_expert_mode", "single_ckpt", "output_path", "load_state_missing_keys", "load_state_unexpected_keys"}:
            continue
        print(f"{key}: {metadata[key]}")
# Script entry point: run the conversion when executed directly.
if __name__ == "__main__":
    main()