| | |
| |
|
| | import dataclasses |
| | import json |
| | import os |
| | from pathlib import Path |
| |
|
| | import safetensors.torch |
| | import torch |
| | import torch.nn.functional as F |
| | import tyro |
| |
|
| | import openpi.models.pi0_config |
| | import openpi.models_pytorch.pi0_pytorch |
| | import openpi.training.config as _config |
| |
|
| |
|
@dataclasses.dataclass
class Args:
    """Command-line arguments, parsed by tyro in main()."""

    # Directory containing the single-head checkpoint's model.safetensors.
    single_ckpt: str
    # Name of the training config to resolve via _config.get_config().
    config_name: str
    # Output directory for the warm-started parallel checkpoint and metadata.
    output_path: str
|
| |
|
def _build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Derive the Pi0Config to instantiate the parallel model with.

    If ``config.model`` is not already a Pi0Config, assemble one from its
    attributes (with defaults for fields that may be absent on other model
    config types). Otherwise reuse the existing Pi0Config with the PyTorch
    training precision applied.

    Args:
        config: Training config supplying the model architecture fields.

    Returns:
        A Pi0Config whose ``dtype`` is ``config.pytorch_training_precision``.
    """
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        return openpi.models.pi0_config.Pi0Config(
            dtype=config.pytorch_training_precision,
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            # getattr fallbacks: these fields may not exist on non-Pi0 configs.
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
            arm_action_dims=getattr(config.model, "arm_action_dims", None),
            action_expert_mode=getattr(config.model, "action_expert_mode", None),
        )

    # Previously this forced the dtype onto the (frozen) dataclass in place via
    # object.__setattr__, silently mutating the shared TrainConfig's model.
    # dataclasses.replace returns a fresh config and leaves the input untouched.
    return dataclasses.replace(config.model, dtype=config.pytorch_training_precision)
|
| |
|
| | def _copy_factorized_heads(model, weight_in, bias_in, weight_out, bias_out) -> None: |
| | hidden_width = weight_in.shape[0] |
| | with torch.no_grad(): |
| | model.action_in_proj_arms[0].weight.copy_(weight_in[:, 0:16]) |
| | model.action_in_proj_arms[0].bias.zero_() |
| | model.action_in_proj_arms[1].weight.copy_(weight_in[:, 16:32]) |
| | model.action_in_proj_arms[1].bias.zero_() |
| |
|
| | if hasattr(model, "arm_token_fuse"): |
| | fuse_weight = torch.zeros_like(model.arm_token_fuse.weight) |
| | identity = torch.eye(hidden_width, dtype=fuse_weight.dtype) |
| | fuse_weight[:, 0:hidden_width] = identity |
| | fuse_weight[:, hidden_width : 2 * hidden_width] = identity |
| | model.arm_token_fuse.weight.copy_(fuse_weight) |
| | model.arm_token_fuse.bias.copy_(bias_in) |
| |
|
| | model.action_out_proj_arms[0].weight.copy_(weight_out[0:16, :]) |
| | model.action_out_proj_arms[0].bias.copy_(bias_out[0:16]) |
| | model.action_out_proj_arms[1].weight.copy_(weight_out[16:32, :]) |
| | model.action_out_proj_arms[1].bias.copy_(bias_out[16:32]) |
| |
|
| |
|
| | def _copy_split_expert_weights(model, single_state) -> None: |
| | model_state = model.state_dict() |
| | with torch.no_grad(): |
| | for key, value in single_state.items(): |
| | if not key.startswith("paligemma_with_expert.gemma_expert."): |
| | continue |
| | suffix = key.removeprefix("paligemma_with_expert.gemma_expert.") |
| | left_key = f"paligemma_with_expert.left_gemma_expert.{suffix}" |
| | right_key = f"paligemma_with_expert.right_gemma_expert.{suffix}" |
| | model_state[left_key].copy_(value.to(dtype=model_state[left_key].dtype)) |
| | model_state[right_key].copy_(value.to(dtype=model_state[right_key].dtype)) |
| |
|
| |
|
| | def _expert_copy_max_abs_diff(model, single_state, target_prefix: str) -> float: |
| | model_state = model.state_dict() |
| | max_abs_diff = 0.0 |
| | for key, value in single_state.items(): |
| | if not key.startswith("paligemma_with_expert.gemma_expert."): |
| | continue |
| | suffix = key.removeprefix("paligemma_with_expert.gemma_expert.") |
| | target_key = f"{target_prefix}{suffix}" |
| | diff = (model_state[target_key].to(torch.float32) - value.to(torch.float32)).abs().max().item() |
| | max_abs_diff = max(max_abs_diff, float(diff)) |
| | return max_abs_diff |
| |
|
| |
|
def main() -> None:
    """Warm-start a parallel two-arm PI0 checkpoint from a single-head one.

    Pipeline: resolve the training config, build the parallel model, load the
    single-head safetensors state non-strictly, copy the packed 32-dim action
    projection heads into the per-arm heads (and duplicate the action expert
    for split-expert configs), numerically verify the warm start with random
    probes, then write model + config + metadata to ``output_path``.
    """
    args = tyro.cli(Args)
    config = _config.get_config(args.config_name)
    model_cfg = _build_model_config(config)
    # Only configs with factorized or split action heads can be warm-started.
    if not model_cfg.use_parallel_action_heads:
        raise ValueError(f"Config {args.config_name} does not use factorized or split action heads.")
    # The head-copy logic assumes a fixed 16/16 left/right action split.
    if tuple(model_cfg.arm_action_dims) != (16, 16):
        raise ValueError(f"Expected arm_action_dims=(16, 16), got {model_cfg.arm_action_dims}.")

    parallel_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg)
    single_state = safetensors.torch.load_file(os.path.join(args.single_ckpt, "model.safetensors"), device="cpu")

    # Non-strict load: the parallel model has extra per-arm modules that are
    # populated manually below; record the mismatches for the metadata report.
    missing, unexpected = parallel_model.load_state_dict(single_state, strict=False)

    # Packed single-head projections. nn.Linear weights are (out, in), so
    # weight_in is (hidden, 32) and weight_out is (32, hidden).
    weight_in = single_state["action_in_proj.weight"]
    bias_in = single_state["action_in_proj.bias"]
    weight_out = single_state["action_out_proj.weight"]
    bias_out = single_state["action_out_proj.bias"]

    hidden_width = weight_in.shape[0]
    if weight_in.shape[1] != 32 or weight_out.shape[0] != 32:
        raise ValueError(
            f"Expected single-head checkpoint with packed 32-dim actions, got in={tuple(weight_in.shape)} out={tuple(weight_out.shape)}."
        )

    _copy_factorized_heads(parallel_model, weight_in, bias_in, weight_out, bias_out)
    if model_cfg.use_split_action_expert:
        _copy_split_expert_weights(parallel_model, single_state)

    # Random probes (in the model's parameter dtypes) for the warm-start checks.
    proj_in_dtype = parallel_model.action_in_proj_arms[0].weight.dtype
    proj_out_dtype = parallel_model.action_out_proj_arms[0].weight.dtype
    x = torch.randn(2, model_cfg.action_horizon, model_cfg.action_dim, dtype=proj_in_dtype)
    x_left = x[:, :, 0:16]
    x_right = x[:, :, 16:32]
    suffix = torch.randn(2, model_cfg.action_horizon, hidden_width, dtype=proj_out_dtype)

    metadata = {
        "config_name": args.config_name,
        "action_expert_mode": model_cfg.action_expert_mode,
        "single_ckpt": args.single_ckpt,
        "output_path": args.output_path,
        "load_state_missing_keys": list(missing),
        "load_state_unexpected_keys": list(unexpected),
    }

    # Per-arm input heads carry zero bias (the bias lives in arm_token_fuse
    # when present), so the references use bias=None.
    with torch.no_grad():
        left_input_projection_max_abs_diff = float(
            (
                F.linear(x_left, weight_in[:, 0:16].to(proj_in_dtype), None)
                - parallel_model.action_in_proj_arms[0](x_left)
            )
            .abs()
            .max()
            .item()
        )
        right_input_projection_max_abs_diff = float(
            (
                F.linear(x_right, weight_in[:, 16:32].to(proj_in_dtype), None)
                - parallel_model.action_in_proj_arms[1](x_right)
            )
            .abs()
            .max()
            .item()
        )
        left_output_projection_max_abs_diff = float(
            (
                F.linear(suffix, weight_out[0:16, :].to(proj_out_dtype), bias_out[0:16].to(proj_out_dtype))
                - parallel_model.action_out_proj_arms[0](suffix)
            )
            .abs()
            .max()
            .item()
        )
        right_output_projection_max_abs_diff = float(
            (
                F.linear(suffix, weight_out[16:32, :].to(proj_out_dtype), bias_out[16:32].to(proj_out_dtype))
                - parallel_model.action_out_proj_arms[1](suffix)
            )
            .abs()
            .max()
            .item()
        )

    metadata.update(
        {
            "left_input_projection_max_abs_diff": left_input_projection_max_abs_diff,
            "right_input_projection_max_abs_diff": right_input_projection_max_abs_diff,
            "left_output_projection_max_abs_diff": left_output_projection_max_abs_diff,
            "right_output_projection_max_abs_diff": right_output_projection_max_abs_diff,
        }
    )

    if model_cfg.action_expert_mode == "head_only_parallel":
        # NOTE(review): _project_action_inputs/_project_action_outputs are
        # private model helpers; presumably they recombine the per-arm heads
        # into the packed projection, hence comparing against the original
        # single head with its full bias — confirm against the model code.
        input_max_abs_diff = float(
            (
                F.linear(x, weight_in.to(proj_in_dtype), bias_in.to(proj_in_dtype))
                - parallel_model._project_action_inputs(x)
            )
            .abs()
            .max()
            .item()
        )
        output_max_abs_diff = float(
            (
                F.linear(suffix, weight_out.to(proj_out_dtype), bias_out.to(proj_out_dtype))
                - parallel_model._project_action_outputs(suffix)
            )
            .abs()
            .max()
            .item()
        )
        metadata["input_projection_max_abs_diff"] = input_max_abs_diff
        metadata["output_projection_max_abs_diff"] = output_max_abs_diff
        metadata["warm_start_exact"] = input_max_abs_diff == 0.0 and output_max_abs_diff == 0.0
    else:
        # Split-expert mode: also verify the duplicated expert weights match
        # the single-expert source exactly.
        left_expert_max_abs_diff = _expert_copy_max_abs_diff(
            parallel_model,
            single_state,
            "paligemma_with_expert.left_gemma_expert.",
        )
        right_expert_max_abs_diff = _expert_copy_max_abs_diff(
            parallel_model,
            single_state,
            "paligemma_with_expert.right_gemma_expert.",
        )
        metadata["left_expert_max_abs_diff"] = left_expert_max_abs_diff
        metadata["right_expert_max_abs_diff"] = right_expert_max_abs_diff
        if parallel_model.paligemma_with_expert.cross_arm_comm is not None:
            # Record the freshly-initialized cross-arm communication values
            # (these have no counterpart in the single-head checkpoint).
            metadata["cross_arm_comm_init"] = [
                float(value) for value in parallel_model.paligemma_with_expert.cross_arm_comm.detach().cpu().tolist()
            ]
        metadata["warm_start_exact"] = (
            left_input_projection_max_abs_diff == 0.0
            and right_input_projection_max_abs_diff == 0.0
            and left_output_projection_max_abs_diff == 0.0
            and right_output_projection_max_abs_diff == 0.0
            and left_expert_max_abs_diff == 0.0
            and right_expert_max_abs_diff == 0.0
        )

    # Persist the warm-started model, its config, and the verification report.
    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    safetensors.torch.save_model(parallel_model, output_dir / "model.safetensors")
    (output_dir / "config.json").write_text(json.dumps(dataclasses.asdict(model_cfg), indent=2, sort_keys=True))
    (output_dir / "init_parallel_metadata.json").write_text(json.dumps(metadata, indent=2, sort_keys=True))

    # Console summary: identity fields first, then every metric in metadata.
    print(f"config_name: {args.config_name}")
    print(f"action_expert_mode: {model_cfg.action_expert_mode}")
    print(f"single_ckpt: {args.single_ckpt}")
    print(f"output_path: {args.output_path}")
    print(f"load_state_missing_keys_count: {len(missing)}")
    print(f"load_state_missing_keys: {list(missing)}")
    print(f"load_state_unexpected_keys_count: {len(unexpected)}")
    print(f"load_state_unexpected_keys: {list(unexpected)}")
    for key in sorted(metadata):
        if key in {"config_name", "action_expert_mode", "single_ckpt", "output_path", "load_state_missing_keys", "load_state_unexpected_keys"}:
            continue
        print(f"{key}: {metadata[key]}")
| |
|
# Script entry point.
if __name__ == "__main__":
    main()
|