#!/usr/bin/env python3
"""Warm-start a parallel-action-head PI0 PyTorch model from a single-head checkpoint.

The single-head checkpoint's ``action_in_proj`` / ``action_out_proj`` operate on packed
32-dim actions, which this script treats as a left arm (dims 0-15) and a right arm
(dims 16-31). Those projections are split into the per-arm projections of the parallel
model; when the config uses a split action expert, the single ``gemma_expert`` weights
are duplicated into the left and right experts. The warm-started ``model.safetensors``,
its ``config.json`` and an ``init_parallel_metadata.json`` with equivalence diagnostics
are written to ``output_path``.
"""

import dataclasses
import json
import os
from pathlib import Path

import safetensors.torch
import torch
import torch.nn.functional as F  # noqa: N812
import tyro

import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
import openpi.training.config as _config

# Slices of the packed 32-dim action vector that belong to each arm.
_LEFT_ARM = slice(0, 16)
_RIGHT_ARM = slice(16, 32)
# (metadata key prefix, index into action_in_proj_arms / action_out_proj_arms, packed-action slice).
_ARMS = (("left", 0, _LEFT_ARM), ("right", 1, _RIGHT_ARM))

_SINGLE_EXPERT_PREFIX = "paligemma_with_expert.gemma_expert."
_LEFT_EXPERT_PREFIX = "paligemma_with_expert.left_gemma_expert."
_RIGHT_EXPERT_PREFIX = "paligemma_with_expert.right_gemma_expert."

# Single-head projection tensors, in the order (weight_in, bias_in, weight_out, bias_out).
_SINGLE_HEAD_KEYS = (
    "action_in_proj.weight",
    "action_in_proj.bias",
    "action_out_proj.weight",
    "action_out_proj.bias",
)

# Metadata keys printed in the summary header rather than in the sorted diagnostics list.
_SUMMARY_HEADER_KEYS = frozenset(
    {
        "config_name",
        "action_expert_mode",
        "single_ckpt",
        "output_path",
        "load_state_missing_keys",
        "load_state_unexpected_keys",
    }
)

# Diagnostics that must all be exactly zero for a split-expert warm start to count as exact.
_SPLIT_EXACTNESS_KEYS = (
    "left_input_projection_max_abs_diff",
    "right_input_projection_max_abs_diff",
    "left_output_projection_max_abs_diff",
    "right_output_projection_max_abs_diff",
    "left_expert_max_abs_diff",
    "right_expert_max_abs_diff",
)


@dataclasses.dataclass
class Args:
    """Command-line arguments."""

    # Directory containing the single-head ``model.safetensors`` checkpoint.
    single_ckpt: str
    # Training config name (resolved with ``_config.get_config``) describing the parallel model.
    config_name: str
    # Directory that receives the warm-started model, its config and the metadata file.
    output_path: str


def _build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Return a ``Pi0Config`` for ``config`` whose ``dtype`` is the PyTorch training precision.

    If ``config.model`` is not a ``Pi0Config``, a new one is assembled from its attributes,
    using defaults for attributes it lacks. Otherwise a copy of ``config.model`` with only
    ``dtype`` replaced is returned, so the object obtained from ``get_config`` is not modified.
    """
    model = config.model
    if isinstance(model, openpi.models.pi0_config.Pi0Config):
        # Copy rather than object.__setattr__ on config.model, which would alter the config in place.
        return dataclasses.replace(model, dtype=config.pytorch_training_precision)
    return openpi.models.pi0_config.Pi0Config(
        dtype=config.pytorch_training_precision,
        action_dim=model.action_dim,
        action_horizon=model.action_horizon,
        max_token_len=model.max_token_len,
        paligemma_variant=getattr(model, "paligemma_variant", "gemma_2b"),
        action_expert_variant=getattr(model, "action_expert_variant", "gemma_300m"),
        pi05=getattr(model, "pi05", False),
        arm_action_dims=getattr(model, "arm_action_dims", None),
        action_expert_mode=getattr(model, "action_expert_mode", None),
    )


def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return ``max(|a - b|)`` as a Python float, computed in float32."""
    return float((a.to(torch.float32) - b.to(torch.float32)).abs().max().item())


def _single_head_projections(
    single_state: dict[str, torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return ``(weight_in, bias_in, weight_out, bias_out)`` from a single-head state dict.

    Raises:
        ValueError: if a projection tensor is missing, or if the input projection does not
            take 32 features / the output projection does not produce 32 features.
    """
    absent = [key for key in _SINGLE_HEAD_KEYS if key not in single_state]
    if absent:
        raise ValueError(f"Single-head checkpoint is missing action projection tensors: {absent}.")
    weight_in, bias_in, weight_out, bias_out = (single_state[key] for key in _SINGLE_HEAD_KEYS)
    if weight_in.shape[1] != 32 or weight_out.shape[0] != 32:
        raise ValueError(
            f"Expected single-head checkpoint with packed 32-dim actions, got in={tuple(weight_in.shape)} out={tuple(weight_out.shape)}."
        )
    return weight_in, bias_in, weight_out, bias_out


def _copy_factorized_heads(model, weight_in, bias_in, weight_out, bias_out) -> None:
    """Split the packed single-head action projections into the per-arm projections of ``model``.

    Input side: each arm's input projection gets that arm's 16 columns of ``weight_in`` and a
    zero bias. If ``model`` has a non-None ``arm_token_fuse``, its weight is set to ``[I | I]``
    (any further columns zero) and its bias to ``bias_in``. Assuming that layer is applied to
    the concatenation ``[left_token, right_token]``, the fused result equals
    ``weight_in @ x + bias_in``.
    Output side: each arm's output projection gets that arm's 16 rows of ``weight_out`` and
    ``bias_out``.
    """
    hidden_width = weight_in.shape[0]
    with torch.no_grad():
        for _, index, arm in _ARMS:
            in_proj = model.action_in_proj_arms[index]
            in_proj.weight.copy_(weight_in[:, arm])
            # The shared input bias is applied once by arm_token_fuse (below), not per arm.
            in_proj.bias.zero_()
            out_proj = model.action_out_proj_arms[index]
            out_proj.weight.copy_(weight_out[arm, :])
            out_proj.bias.copy_(bias_out[arm])

        fuse = getattr(model, "arm_token_fuse", None)
        if fuse is not None:
            identity = torch.eye(hidden_width, dtype=fuse.weight.dtype, device=fuse.weight.device)
            fuse_weight = torch.zeros_like(fuse.weight)
            fuse_weight[:, 0:hidden_width] = identity
            fuse_weight[:, hidden_width : 2 * hidden_width] = identity
            fuse.weight.copy_(fuse_weight)
            fuse.bias.copy_(bias_in)
        # NOTE(review): without arm_token_fuse, bias_in is not copied into any parameter here — confirm intended.


def _copy_split_expert_weights(model, single_state: dict[str, torch.Tensor]) -> None:
    """Copy every single ``gemma_expert`` tensor into both the left and right experts of ``model``.

    Values are cast to each target tensor's dtype. Writes go through ``model.state_dict()``
    tensors with ``copy_``.

    Raises:
        ValueError: if ``single_state`` holds no ``gemma_expert`` tensors, since the split
            experts would then not be warm-started from the checkpoint at all.
        KeyError: if a copied tensor has no counterpart under a split-expert prefix.
    """
    model_state = model.state_dict()
    copied = 0
    with torch.no_grad():
        for key, value in single_state.items():
            if not key.startswith(_SINGLE_EXPERT_PREFIX):
                continue
            suffix = key.removeprefix(_SINGLE_EXPERT_PREFIX)
            for prefix in (_LEFT_EXPERT_PREFIX, _RIGHT_EXPERT_PREFIX):
                target = model_state[f"{prefix}{suffix}"]
                target.copy_(value.to(dtype=target.dtype))
            copied += 1
    if copied == 0:
        raise ValueError(
            f"Single-head checkpoint has no '{_SINGLE_EXPERT_PREFIX}*' tensors to copy into the split experts."
        )


def _expert_copy_max_abs_diff(model, single_state: dict[str, torch.Tensor], target_prefix: str) -> float:
    """Return the largest float32 absolute difference between single-expert tensors and their copies.

    Each ``gemma_expert`` tensor in ``single_state`` is compared with the tensor that has the
    same suffix under ``target_prefix`` in ``model.state_dict()``. Returns 0.0 when
    ``single_state`` contains no ``gemma_expert`` tensors.
    """
    model_state = model.state_dict()
    return max(
        (
            _max_abs_diff(model_state[f"{target_prefix}{key.removeprefix(_SINGLE_EXPERT_PREFIX)}"], value)
            for key, value in single_state.items()
            if key.startswith(_SINGLE_EXPERT_PREFIX)
        ),
        default=0.0,
    )


def _projection_diagnostics(
    model,
    model_cfg: openpi.models.pi0_config.Pi0Config,
    weight_in: torch.Tensor,
    bias_in: torch.Tensor,
    weight_out: torch.Tensor,
    bias_out: torch.Tensor,
    *,
    include_fused: bool,
) -> dict[str, float]:
    """Compare ``model``'s parallel action projections against the packed single-head weights.

    Random probe tensors (from a fixed-seed local generator, so results are reproducible and
    the global RNG is untouched) go through both a reference ``F.linear`` with the single-head
    weights and the matching parallel module. The max absolute difference of each pair is
    returned under ``{left,right}_{input,output}_projection_max_abs_diff``. When
    ``include_fused`` is set, ``input_projection_max_abs_diff`` /
    ``output_projection_max_abs_diff`` also compare the full packed projections against
    ``model._project_action_inputs`` / ``model._project_action_outputs``.
    """
    hidden_width = weight_in.shape[0]
    in_dtype = model.action_in_proj_arms[0].weight.dtype
    out_dtype = model.action_out_proj_arms[0].weight.dtype
    generator = torch.Generator().manual_seed(0)
    x = torch.randn(2, model_cfg.action_horizon, model_cfg.action_dim, dtype=in_dtype, generator=generator)
    suffix = torch.randn(2, model_cfg.action_horizon, hidden_width, dtype=out_dtype, generator=generator)

    diagnostics: dict[str, float] = {}
    with torch.no_grad():
        for side, index, arm in _ARMS:
            x_arm = x[:, :, arm]
            # Arm input projections carry no bias (see _copy_factorized_heads), so the reference has none.
            diagnostics[f"{side}_input_projection_max_abs_diff"] = _max_abs_diff(
                F.linear(x_arm, weight_in[:, arm].to(in_dtype), None),
                model.action_in_proj_arms[index](x_arm),
            )
            diagnostics[f"{side}_output_projection_max_abs_diff"] = _max_abs_diff(
                F.linear(suffix, weight_out[arm, :].to(out_dtype), bias_out[arm].to(out_dtype)),
                model.action_out_proj_arms[index](suffix),
            )
        if include_fused:
            diagnostics["input_projection_max_abs_diff"] = _max_abs_diff(
                F.linear(x, weight_in.to(in_dtype), bias_in.to(in_dtype)),
                model._project_action_inputs(x),
            )
            diagnostics["output_projection_max_abs_diff"] = _max_abs_diff(
                F.linear(suffix, weight_out.to(out_dtype), bias_out.to(out_dtype)),
                model._project_action_outputs(suffix),
            )
    return diagnostics


def _expert_diagnostics(model, single_state: dict[str, torch.Tensor]) -> dict:
    """Return split-expert copy diffs and, if ``model`` has one, the initial ``cross_arm_comm`` values."""
    diagnostics: dict = {
        "left_expert_max_abs_diff": _expert_copy_max_abs_diff(model, single_state, _LEFT_EXPERT_PREFIX),
        "right_expert_max_abs_diff": _expert_copy_max_abs_diff(model, single_state, _RIGHT_EXPERT_PREFIX),
    }
    cross_arm_comm = getattr(model.paligemma_with_expert, "cross_arm_comm", None)
    if cross_arm_comm is not None:
        # flatten() so a scalar (0-d) tensor still yields a list.
        diagnostics["cross_arm_comm_init"] = [float(value) for value in cross_arm_comm.detach().flatten().cpu().tolist()]
    return diagnostics


def _print_summary(metadata: dict) -> None:
    """Print the run identity and load-state report first, then every other metadata entry sorted by key."""
    for key in ("config_name", "action_expert_mode", "single_ckpt", "output_path"):
        print(f"{key}: {metadata[key]}")
    for kind in ("missing", "unexpected"):
        keys = metadata[f"load_state_{kind}_keys"]
        print(f"load_state_{kind}_keys_count: {len(keys)}")
        print(f"load_state_{kind}_keys: {keys}")
    for key in sorted(metadata.keys() - _SUMMARY_HEADER_KEYS):
        print(f"{key}: {metadata[key]}")


def main() -> None:
    """Warm-start the parallel model for ``Args.config_name`` and write it to ``Args.output_path``.

    Raises:
        ValueError: if the config does not use parallel action heads, does not use
            ``arm_action_dims=(16, 16)``, or the single-head checkpoint is not a packed
            32-dim checkpoint (see ``_single_head_projections``).
    """
    args = tyro.cli(Args)
    model_cfg = _build_model_config(_config.get_config(args.config_name))
    if not model_cfg.use_parallel_action_heads:
        raise ValueError(f"Config {args.config_name} does not use factorized or split action heads.")
    # `or ()` so an unset (None) arm_action_dims reaches this ValueError instead of a TypeError.
    if tuple(model_cfg.arm_action_dims or ()) != (16, 16):
        raise ValueError(f"Expected arm_action_dims=(16, 16), got {model_cfg.arm_action_dims}.")

    single_state = safetensors.torch.load_file(os.path.join(args.single_ckpt, "model.safetensors"), device="cpu")
    # Validate the checkpoint before loading it into the model, so shape problems get a clear message.
    weight_in, bias_in, weight_out, bias_out = _single_head_projections(single_state)

    parallel_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg)
    # strict=False: parameters absent from the single-head checkpoint are reported, not fatal.
    missing, unexpected = parallel_model.load_state_dict(single_state, strict=False)
    _copy_factorized_heads(parallel_model, weight_in, bias_in, weight_out, bias_out)
    if model_cfg.use_split_action_expert:
        _copy_split_expert_weights(parallel_model, single_state)

    metadata: dict = {
        "config_name": args.config_name,
        "action_expert_mode": model_cfg.action_expert_mode,
        "single_ckpt": args.single_ckpt,
        "output_path": args.output_path,
        "load_state_missing_keys": list(missing),
        "load_state_unexpected_keys": list(unexpected),
    }
    head_only = model_cfg.action_expert_mode == "head_only_parallel"
    metadata.update(
        _projection_diagnostics(
            parallel_model, model_cfg, weight_in, bias_in, weight_out, bias_out, include_fused=head_only
        )
    )
    if head_only:
        metadata["warm_start_exact"] = (
            metadata["input_projection_max_abs_diff"] == 0.0 and metadata["output_projection_max_abs_diff"] == 0.0
        )
    else:
        metadata.update(_expert_diagnostics(parallel_model, single_state))
        metadata["warm_start_exact"] = all(metadata[key] == 0.0 for key in _SPLIT_EXACTNESS_KEYS)

    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    # save_model (rather than save_file) handles tensors that share storage.
    safetensors.torch.save_model(parallel_model, str(output_dir / "model.safetensors"))
    (output_dir / "config.json").write_text(json.dumps(dataclasses.asdict(model_cfg), indent=2, sort_keys=True))
    (output_dir / "init_parallel_metadata.json").write_text(json.dumps(metadata, indent=2, sort_keys=True))

    _print_summary(metadata)


if __name__ == "__main__":
    main()