#!/usr/bin/env python3
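"""Numerical sanity checks for a split action-expert PI0 PyTorch checkpoint.

Loads a checkpoint, builds a float32 copy of its model config, and runs two
checks against a random prefix context:

1. Identical-branch check: both expert branches are fed the same suffix
   embeddings and conditioning, so their outputs should match within the
   given tolerance.
2. Branch-invariance check ("split_independent" mode only): perturbing one
   arm's 16-dim action slice should leave the other arm's projected actions
   unchanged.
"""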

import dataclasses

import safetensors.torch
import torch
import tyro

import openpi.models.pi0_config
import openpi.training.config as _config


@dataclasses.dataclass
class Args:
    config_name: str
    checkpoint_dir: str
    tolerance: float = 1e-6
    batch_size: int = 2
    prefix_len: int = 12
    seed: int = 123


def _build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # Convert other model configs into an equivalent Pi0Config; fields the
        # source config lacks fall back to defaults.
        return openpi.models.pi0_config.Pi0Config(
            dtype="float32",
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
            arm_action_dims=getattr(config.model, "arm_action_dims", None),
            action_expert_mode=getattr(config.model, "action_expert_mode", None),
        )
    # The config dataclass is frozen, so bypass __setattr__ to force float32.
    model_cfg = dataclasses.replace(config.model)
    object.__setattr__(model_cfg, "dtype", "float32")
    return model_cfg


def _random_prefix_context(model, batch_size: int, prefix_len: int, seed: int):
    # Build a random stand-in for the PaliGemma prefix so the checks need no
    # real images or tokens.
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    prefix_width = model.paligemma_with_expert.paligemma.config.text_config.hidden_size
    prefix_embs = torch.randn(batch_size, prefix_len, prefix_width, generator=generator, dtype=torch.float32)
    prefix_pad_masks = torch.ones(batch_size, prefix_len, dtype=torch.bool)
    prefix_att_masks = torch.zeros(batch_size, prefix_len, dtype=torch.bool)
    return prefix_embs, prefix_pad_masks, prefix_att_masks


def _run_model(model, prefix_context, x_t, timestep):
    prefix_embs, prefix_pad_masks, prefix_att_masks = prefix_context
    state = torch.zeros(x_t.shape[0], model.config.action_dim, dtype=torch.float32)
    suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = model.embed_suffix(state, x_t, timestep)
    suffix_outputs = model._run_joint_action_expert(  # noqa: SLF001
        prefix_embs,
        prefix_pad_masks,
        prefix_att_masks,
        suffix_embs,
        suffix_pad_masks,
        suffix_att_masks,
        adarms_cond,
    )
    # Keep only the trailing action tokens of each branch's suffix output.
    suffix_outputs = [output[:, -model.config.action_horizon :].to(dtype=torch.float32) for output in suffix_outputs]
    projected_actions = model._project_action_outputs(suffix_outputs)  # noqa: SLF001
    return suffix_outputs, projected_actions


def _run_identical_branch_inputs(model, prefix_context, timestep, seed: int):
    # Feed both expert branches the exact same suffix embeddings, masks, and
    # conditioning, so any difference between the two branch outputs comes
    # from the branch weights themselves.
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    width = model.action_expert_width
    horizon = model.config.action_horizon
    batch_size = prefix_context[0].shape[0]
    shared_suffix = torch.randn(batch_size, horizon, width, generator=generator, dtype=torch.float32)
    shared_cond = torch.randn(batch_size, width, generator=generator, dtype=torch.float32)
    suffix_pad_masks = [torch.ones(batch_size, horizon, dtype=torch.bool) for _ in range(2)]
    suffix_att_masks = [model._action_att_mask(batch_size, torch.device("cpu"), torch.float32) for _ in range(2)]  # noqa: SLF001
    suffix_outputs = model._run_joint_action_expert(  # noqa: SLF001
        prefix_context[0],
        prefix_context[1],
        prefix_context[2],
        [shared_suffix.clone(), shared_suffix.clone()],
        suffix_pad_masks,
        suffix_att_masks,
        [shared_cond.clone(), shared_cond.clone()],
    )
    return suffix_outputs


def main() -> None:
    args = tyro.cli(Args)
    config = _config.get_config(args.config_name)
    model_cfg = _build_model_config(config)
    if not model_cfg.use_split_action_expert:
        raise ValueError(f"Config {args.config_name} is not a split-expert config.")

    # Deferred import: only pull in the PyTorch model code once the config
    # has been validated.
    import openpi.models_pytorch.pi0_pytorch as pi0_pytorch

    torch.manual_seed(args.seed)
    model = pi0_pytorch.PI0Pytorch(model_cfg)
    # Load non-strictly so missing/unexpected keys are reported rather than
    # raising.
    missing, unexpected = safetensors.torch.load_model(model, f"{args.checkpoint_dir}/model.safetensors", strict=False)
    model.eval()

    prefix_context = _random_prefix_context(model, args.batch_size, args.prefix_len, args.seed + 1)
    x_t = torch.randn(args.batch_size, model.config.action_horizon, model.config.action_dim, dtype=torch.float32)
    timestep = torch.full((args.batch_size,), 0.5, dtype=torch.float32)

    # Check 1: identical inputs to both branches must yield identical outputs.
    identical_suffix_outputs = _run_identical_branch_inputs(model, prefix_context, timestep, args.seed + 2)
    identical_branch_suffix_max_abs_diff = float(
        (identical_suffix_outputs[0] - identical_suffix_outputs[1]).abs().max().item()
    )

    # Check 2: perturbing the right-arm action slice (dims 16:32) must not
    # change the left arm's projected actions (dims 0:16), and vice versa.
    _, left_projected_actions = _run_model(model, prefix_context, x_t, timestep)
    x_t_right_perturbed = x_t.clone()
    x_t_right_perturbed[:, :, 16:32] += 0.5 * torch.randn_like(x_t_right_perturbed[:, :, 16:32])
    _, right_perturbed_actions = _run_model(model, prefix_context, x_t_right_perturbed, timestep)
    left_branch_invariance_max_abs_diff = float(
        (left_projected_actions[:, :, 0:16] - right_perturbed_actions[:, :, 0:16]).abs().max().item()
    )
    x_t_left_perturbed = x_t.clone()
    x_t_left_perturbed[:, :, 0:16] += 0.5 * torch.randn_like(x_t_left_perturbed[:, :, 0:16])
    _, left_perturbed_actions = _run_model(model, prefix_context, x_t_left_perturbed, timestep)
    right_branch_invariance_max_abs_diff = float(
        (left_projected_actions[:, :, 16:32] - left_perturbed_actions[:, :, 16:32]).abs().max().item()
    )

    print(f"config_name: {args.config_name}")
    print(f"checkpoint_dir: {args.checkpoint_dir}")
    print(f"action_expert_mode: {model_cfg.action_expert_mode}")
    print(f"weight_loading_missing_keys: {list(missing)}")
    print(f"weight_loading_unexpected_keys: {list(unexpected)}")
    print(f"identical_branch_suffix_max_abs_diff: {identical_branch_suffix_max_abs_diff:.8f}")
    print(
        f"identical_branch_suffix_match: "
        f"{identical_branch_suffix_max_abs_diff <= args.tolerance}"
    )
    if model_cfg.action_expert_mode == "split_independent":
        print(f"left_branch_invariance_max_abs_diff: {left_branch_invariance_max_abs_diff:.8f}")
        print(f"right_branch_invariance_max_abs_diff: {right_branch_invariance_max_abs_diff:.8f}")
        print(f"left_branch_invariant: {left_branch_invariance_max_abs_diff <= args.tolerance}")
        print(f"right_branch_invariant: {right_branch_invariance_max_abs_diff <= args.tolerance}")
    else:
        print("left_branch_invariance_max_abs_diff: skipped_for_split_communicating")
        print("right_branch_invariance_max_abs_diff: skipped_for_split_communicating")


if __name__ == "__main__":
    main()
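
# Example invocation (config name, checkpoint dir, and the name this file is
# saved under are placeholders):
#   python check_split_expert.py --config-name <split_expert_config> --checkpoint-dir <checkpoint_dir>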