#!/usr/bin/env python3
"""Verify branch independence of a split-action-expert PI0 checkpoint.

Loads a checkpoint, runs random prefix/suffix inputs through the joint
action expert, and reports:

* whether feeding byte-identical inputs to both branches yields identical
  branch outputs, and
* (for ``split_independent`` mode only) whether perturbing one branch's
  action slice leaves the other branch's projected actions unchanged,

each up to ``--tolerance``.
"""

import dataclasses

import safetensors.torch
import torch
import tyro

import openpi.models.pi0_config
import openpi.training.config as _config


@dataclasses.dataclass
class Args:
    # Name of the training config to resolve via _config.get_config.
    config_name: str
    # Directory containing ``model.safetensors``.
    checkpoint_dir: str
    # Max abs difference treated as "equal" in the invariance checks.
    tolerance: float = 1e-6
    batch_size: int = 2
    prefix_len: int = 12
    seed: int = 123


def _build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Return a float32 ``Pi0Config`` derived from *config*.

    A non-Pi0 model config is converted by copying the fields Pi0Config
    understands; optional fields fall back to Pi0 defaults via ``getattr``.
    An existing Pi0Config is copied with only its dtype forced to float32.
    """
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        return openpi.models.pi0_config.Pi0Config(
            dtype="float32",
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
            arm_action_dims=getattr(config.model, "arm_action_dims", None),
            action_expert_mode=getattr(config.model, "action_expert_mode", None),
        )
    # dataclasses.replace works on frozen dataclasses; no need for the
    # object.__setattr__ escape hatch the previous version used.
    return dataclasses.replace(config.model, dtype="float32")


def _random_prefix_context(model, batch_size: int, prefix_len: int, seed: int):
    """Build a deterministic random prefix (embeddings + masks) on CPU."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    prefix_width = model.paligemma_with_expert.paligemma.config.text_config.hidden_size
    prefix_embs = torch.randn(
        batch_size, prefix_len, prefix_width, generator=generator, dtype=torch.float32
    )
    # All prefix tokens are valid; none attend causally (full attention block).
    prefix_pad_masks = torch.ones(batch_size, prefix_len, dtype=torch.bool)
    prefix_att_masks = torch.zeros(batch_size, prefix_len, dtype=torch.bool)
    return prefix_embs, prefix_pad_masks, prefix_att_masks


@torch.no_grad()  # inference only — identical outputs, no autograd graph
def _run_model(model, prefix_context, x_t, timestep):
    """Run one denoising step and return per-branch suffix outputs plus projected actions."""
    prefix_embs, prefix_pad_masks, prefix_att_masks = prefix_context
    state = torch.zeros(x_t.shape[0], model.config.action_dim, dtype=torch.float32)
    suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = model.embed_suffix(
        state, x_t, timestep
    )
    suffix_outputs = model._run_joint_action_expert(  # noqa: SLF001
        prefix_embs,
        prefix_pad_masks,
        prefix_att_masks,
        suffix_embs,
        suffix_pad_masks,
        suffix_att_masks,
        adarms_cond,
    )
    # Keep only the trailing action tokens of each branch, in float32.
    suffix_outputs = [
        output[:, -model.config.action_horizon :].to(dtype=torch.float32)
        for output in suffix_outputs
    ]
    projected_actions = model._project_action_outputs(suffix_outputs)  # noqa: SLF001
    return suffix_outputs, projected_actions


@torch.no_grad()  # inference only — identical outputs, no autograd graph
def _run_identical_branch_inputs(model, prefix_context, timestep, seed: int):
    """Feed byte-identical random inputs to both branches and return their outputs.

    ``timestep`` is currently unused (kept for call-site symmetry with
    ``_run_model``); the suffix embeddings are sampled directly instead of
    being derived from a noisy action.
    """
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    width = model.action_expert_width
    horizon = model.config.action_horizon
    batch_size = prefix_context[0].shape[0]
    shared_suffix = torch.randn(
        batch_size, horizon, width, generator=generator, dtype=torch.float32
    )
    shared_cond = torch.randn(batch_size, width, generator=generator, dtype=torch.float32)
    suffix_pad_masks = [torch.ones(batch_size, horizon, dtype=torch.bool) for _ in range(2)]
    suffix_att_masks = [
        model._action_att_mask(batch_size, torch.device("cpu"), torch.float32)  # noqa: SLF001
        for _ in range(2)
    ]
    # Clones guarantee the two branches cannot alias each other's tensors.
    suffix_outputs = model._run_joint_action_expert(  # noqa: SLF001
        prefix_context[0],
        prefix_context[1],
        prefix_context[2],
        [shared_suffix.clone(), shared_suffix.clone()],
        suffix_pad_masks,
        suffix_att_masks,
        [shared_cond.clone(), shared_cond.clone()],
    )
    return suffix_outputs


def main() -> None:
    args = tyro.cli(Args)
    config = _config.get_config(args.config_name)
    model_cfg = _build_model_config(config)
    if not model_cfg.use_split_action_expert:
        raise ValueError(f"Config {args.config_name} is not a split-expert config.")

    # Imported lazily: pulling in the PyTorch model module is only needed
    # once we know the config is a split-expert one.
    import openpi.models_pytorch.pi0_pytorch as pi0_pytorch

    torch.manual_seed(args.seed)
    model = pi0_pytorch.PI0Pytorch(model_cfg)
    # strict=False on purpose: missing/unexpected keys are reported below
    # rather than aborting the check.
    missing, unexpected = safetensors.torch.load_model(
        model, f"{args.checkpoint_dir}/model.safetensors", strict=False
    )
    model.eval()

    prefix_context = _random_prefix_context(model, args.batch_size, args.prefix_len, args.seed + 1)
    x_t = torch.randn(
        args.batch_size, model.config.action_horizon, model.config.action_dim, dtype=torch.float32
    )
    timestep = torch.full((args.batch_size,), 0.5, dtype=torch.float32)

    # Check 1: identical inputs to both branches must produce identical outputs.
    identical_suffix_outputs = _run_identical_branch_inputs(
        model, prefix_context, timestep, args.seed + 2
    )
    identical_branch_suffix_max_abs_diff = float(
        (identical_suffix_outputs[0] - identical_suffix_outputs[1]).abs().max().item()
    )

    # Check 2: cross-branch invariance. Dims 0:16 are assumed to be the left
    # arm's action slice and 16:32 the right arm's — TODO confirm against
    # arm_action_dims in the config.
    left_suffix_outputs, left_projected_actions = _run_model(model, prefix_context, x_t, timestep)

    x_t_right_perturbed = x_t.clone()
    x_t_right_perturbed[:, :, 16:32] += 0.5 * torch.randn_like(x_t_right_perturbed[:, :, 16:32])
    _, right_perturbed_actions = _run_model(model, prefix_context, x_t_right_perturbed, timestep)
    left_branch_invariance_max_abs_diff = float(
        (left_projected_actions[:, :, 0:16] - right_perturbed_actions[:, :, 0:16]).abs().max().item()
    )

    x_t_left_perturbed = x_t.clone()
    x_t_left_perturbed[:, :, 0:16] += 0.5 * torch.randn_like(x_t_left_perturbed[:, :, 0:16])
    _, left_perturbed_actions = _run_model(model, prefix_context, x_t_left_perturbed, timestep)
    right_branch_invariance_max_abs_diff = float(
        (left_projected_actions[:, :, 16:32] - left_perturbed_actions[:, :, 16:32]).abs().max().item()
    )

    print(f"config_name: {args.config_name}")
    print(f"checkpoint_dir: {args.checkpoint_dir}")
    print(f"action_expert_mode: {model_cfg.action_expert_mode}")
    print(f"weight_loading_missing_keys: {list(missing)}")
    print(f"weight_loading_unexpected_keys: {list(unexpected)}")
    print(f"identical_branch_suffix_max_abs_diff: {identical_branch_suffix_max_abs_diff:.8f}")
    print(
        f"identical_branch_suffix_match: "
        f"{identical_branch_suffix_max_abs_diff <= args.tolerance}"
    )
    if model_cfg.action_expert_mode == "split_independent":
        print(f"left_branch_invariance_max_abs_diff: {left_branch_invariance_max_abs_diff:.8f}")
        print(f"right_branch_invariance_max_abs_diff: {right_branch_invariance_max_abs_diff:.8f}")
        print(f"left_branch_invariant: {left_branch_invariance_max_abs_diff <= args.tolerance}")
        print(f"right_branch_invariant: {right_branch_invariance_max_abs_diff <= args.tolerance}")
    else:
        # Communicating branches are expected to influence each other, so the
        # invariance checks are meaningless there.
        print("left_branch_invariance_max_abs_diff: skipped_for_split_communicating")
        print("right_branch_invariance_max_abs_diff: skipped_for_split_communicating")


if __name__ == "__main__":
    main()