| | |
| |
|
| | import dataclasses |
| |
|
| | import safetensors.torch |
| | import torch |
| | import tyro |
| |
|
| | import openpi.models.pi0_config |
| | import openpi.training.config as _config |
| |
|
| |
|
@dataclasses.dataclass
class Args:
    """Command-line arguments for the split-action-expert checkpoint check."""

    # Name of the training config resolved via openpi.training.config.get_config.
    config_name: str
    # Directory containing the `model.safetensors` checkpoint to load.
    checkpoint_dir: str
    # Max allowed absolute difference for the branch-identity / invariance checks.
    tolerance: float = 1e-6
    # Batch size used for the synthetic inputs.
    batch_size: int = 2
    # Sequence length of the random prefix context.
    prefix_len: int = 12
    # Base RNG seed; derived seeds (seed+1, seed+2) are used for sub-steps.
    seed: int = 123
| |
|
| |
|
def _build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Derive a float32 ``Pi0Config`` from a training config.

    If ``config.model`` is not already a ``Pi0Config``, rebuild one field by
    field; ``getattr`` defaults cover attributes that older model configs may
    not define. Otherwise, return a copy with only ``dtype`` overridden.

    Args:
        config: Training config whose ``model`` section is converted.

    Returns:
        A ``Pi0Config`` with ``dtype`` forced to ``"float32"``.
    """
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        return openpi.models.pi0_config.Pi0Config(
            dtype="float32",
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
            arm_action_dims=getattr(config.model, "arm_action_dims", None),
            action_expert_mode=getattr(config.model, "action_expert_mode", None),
        )

    # Pass the override to replace() directly instead of mutating a frozen
    # dataclass via object.__setattr__, which would bypass __post_init__
    # validation and the frozen-instance contract.
    return dataclasses.replace(config.model, dtype="float32")
| |
|
| |
|
| | def _random_prefix_context(model, batch_size: int, prefix_len: int, seed: int): |
| | generator = torch.Generator(device="cpu") |
| | generator.manual_seed(seed) |
| | prefix_width = model.paligemma_with_expert.paligemma.config.text_config.hidden_size |
| | prefix_embs = torch.randn(batch_size, prefix_len, prefix_width, generator=generator, dtype=torch.float32) |
| | prefix_pad_masks = torch.ones(batch_size, prefix_len, dtype=torch.bool) |
| | prefix_att_masks = torch.zeros(batch_size, prefix_len, dtype=torch.bool) |
| | return prefix_embs, prefix_pad_masks, prefix_att_masks |
| |
|
| |
|
| | def _run_model(model, prefix_context, x_t, timestep): |
| | prefix_embs, prefix_pad_masks, prefix_att_masks = prefix_context |
| | state = torch.zeros(x_t.shape[0], model.config.action_dim, dtype=torch.float32) |
| | suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = model.embed_suffix(state, x_t, timestep) |
| | suffix_outputs = model._run_joint_action_expert( |
| | prefix_embs, |
| | prefix_pad_masks, |
| | prefix_att_masks, |
| | suffix_embs, |
| | suffix_pad_masks, |
| | suffix_att_masks, |
| | adarms_cond, |
| | ) |
| | suffix_outputs = [output[:, -model.config.action_horizon :].to(dtype=torch.float32) for output in suffix_outputs] |
| | projected_actions = model._project_action_outputs(suffix_outputs) |
| | return suffix_outputs, projected_actions |
| |
|
| |
|
| | def _run_identical_branch_inputs(model, prefix_context, timestep, seed: int): |
| | generator = torch.Generator(device="cpu") |
| | generator.manual_seed(seed) |
| | width = model.action_expert_width |
| | horizon = model.config.action_horizon |
| | batch_size = prefix_context[0].shape[0] |
| |
|
| | shared_suffix = torch.randn(batch_size, horizon, width, generator=generator, dtype=torch.float32) |
| | shared_cond = torch.randn(batch_size, width, generator=generator, dtype=torch.float32) |
| | suffix_pad_masks = [torch.ones(batch_size, horizon, dtype=torch.bool) for _ in range(2)] |
| | suffix_att_masks = [model._action_att_mask(batch_size, torch.device("cpu"), torch.float32) for _ in range(2)] |
| |
|
| | suffix_outputs = model._run_joint_action_expert( |
| | prefix_context[0], |
| | prefix_context[1], |
| | prefix_context[2], |
| | [shared_suffix.clone(), shared_suffix.clone()], |
| | suffix_pad_masks, |
| | suffix_att_masks, |
| | [shared_cond.clone(), shared_cond.clone()], |
| | ) |
| | return suffix_outputs |
| |
|
| |
|
def main() -> None:
    """CLI entry point: sanity-check a split-action-expert checkpoint.

    Loads the model from ``args.checkpoint_dir``, then verifies that
    (1) both expert branches produce identical outputs when given identical
    inputs, and (2) for ``split_independent`` configs, perturbing one arm's
    action dims leaves the other arm's projected actions unchanged.
    Results are printed; nothing is returned.

    Raises:
        ValueError: if the named config is not a split-expert config.
    """
    args = tyro.cli(Args)
    config = _config.get_config(args.config_name)
    model_cfg = _build_model_config(config)
    if not model_cfg.use_split_action_expert:
        raise ValueError(f"Config {args.config_name} is not a split-expert config.")

    # Imported lazily so the config validation above fails fast without
    # pulling in the heavyweight model code.
    import openpi.models_pytorch.pi0_pytorch as pi0_pytorch

    torch.manual_seed(args.seed)
    model = pi0_pytorch.PI0Pytorch(model_cfg)
    missing, unexpected = safetensors.torch.load_model(
        model, f"{args.checkpoint_dir}/model.safetensors", strict=False
    )
    model.eval()

    prefix_context = _random_prefix_context(model, args.batch_size, args.prefix_len, args.seed + 1)
    x_t = torch.randn(args.batch_size, model.config.action_horizon, model.config.action_dim, dtype=torch.float32)
    timestep = torch.full((args.batch_size,), 0.5, dtype=torch.float32)

    split_independent = model_cfg.action_expert_mode == "split_independent"

    # Pure inference checks: no_grad avoids building autograd graphs.
    with torch.no_grad():
        identical_suffix_outputs = _run_identical_branch_inputs(model, prefix_context, timestep, args.seed + 2)
        identical_branch_suffix_max_abs_diff = float(
            (identical_suffix_outputs[0] - identical_suffix_outputs[1]).abs().max().item()
        )

        if split_independent:
            # Only split_independent promises cross-branch invariance, so the
            # perturbation forward passes are skipped for other modes.
            _, left_projected_actions = _run_model(model, prefix_context, x_t, timestep)

            # Perturb the right arm's action dims (presumably dims 16:32 —
            # TODO confirm against arm_action_dims); left output must not move.
            x_t_right_perturbed = x_t.clone()
            x_t_right_perturbed[:, :, 16:32] += 0.5 * torch.randn_like(x_t_right_perturbed[:, :, 16:32])
            _, right_perturbed_actions = _run_model(model, prefix_context, x_t_right_perturbed, timestep)
            left_branch_invariance_max_abs_diff = float(
                (left_projected_actions[:, :, 0:16] - right_perturbed_actions[:, :, 0:16]).abs().max().item()
            )

            # Mirror check: perturb the left arm, right output must not move.
            x_t_left_perturbed = x_t.clone()
            x_t_left_perturbed[:, :, 0:16] += 0.5 * torch.randn_like(x_t_left_perturbed[:, :, 0:16])
            _, left_perturbed_actions = _run_model(model, prefix_context, x_t_left_perturbed, timestep)
            right_branch_invariance_max_abs_diff = float(
                (left_projected_actions[:, :, 16:32] - left_perturbed_actions[:, :, 16:32]).abs().max().item()
            )

    print(f"config_name: {args.config_name}")
    print(f"checkpoint_dir: {args.checkpoint_dir}")
    print(f"action_expert_mode: {model_cfg.action_expert_mode}")
    print(f"weight_loading_missing_keys: {list(missing)}")
    print(f"weight_loading_unexpected_keys: {list(unexpected)}")
    print(f"identical_branch_suffix_max_abs_diff: {identical_branch_suffix_max_abs_diff:.8f}")
    print(
        f"identical_branch_suffix_match: "
        f"{identical_branch_suffix_max_abs_diff <= args.tolerance}"
    )

    if split_independent:
        print(f"left_branch_invariance_max_abs_diff: {left_branch_invariance_max_abs_diff:.8f}")
        print(f"right_branch_invariance_max_abs_diff: {right_branch_invariance_max_abs_diff:.8f}")
        print(f"left_branch_invariant: {left_branch_invariance_max_abs_diff <= args.tolerance}")
        print(f"right_branch_invariant: {right_branch_invariance_max_abs_diff <= args.tolerance}")
    else:
        print("left_branch_invariance_max_abs_diff: skipped_for_split_communicating")
        print("right_branch_invariance_max_abs_diff: skipped_for_split_communicating")
| |
|
| |
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
| |
|