#!/usr/bin/env python3
import dataclasses
import jax
import numpy as np
import safetensors.torch
import torch
import tyro
import openpi.models.pi0_config
import openpi.training.config as _config
import openpi.training.data_loader as _data
@dataclasses.dataclass
class Args:
    """CLI arguments (parsed by tyro in main) for the warm-start equivalence check."""

    # Train-config names looked up via openpi.training.config.get_config.
    baseline_config_name: str = "pi05_twin_handover_256_packed_baseline_pytorch_10k"
    parallel_config_name: str = "pi05_twin_handover_256_packed_parallel_pytorch_10k"
    # Checkpoint directories; each must contain a model.safetensors file.
    baseline_ckpt: str = "/workspace/checkpoints/pi05_base_single_pytorch"
    parallel_ckpt: str = "/workspace/checkpoints/pi05_base_parallel_packed_from_single"
    # Dataset repo substituted into the baseline data config for the reference batch.
    repo_id: str = "lsnu/twin_handover_256_train"
    batch_size: int = 4
    num_workers: int = 0
    # Seed for the deterministic eval noise; timesteps use eval_seed + 10_000.
    eval_seed: int = 777
    # Max allowed element-wise abs diff for the final equivalence verdict.
    tolerance: float = 1e-6
def build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Return a float32 ``Pi0Config`` mirroring ``config.model``.

    If the train config's model is already a ``Pi0Config``, a copy is made with
    ``dtype`` forced to float32; otherwise a ``Pi0Config`` is assembled from the
    overlapping fields (optional ones fall back via ``getattr`` defaults).
    Forcing float32 keeps both checkpoints comparable at full precision.
    """
    if isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # dataclasses.replace builds a new instance through __init__, which works
        # for frozen dataclasses too — no need to mutate via object.__setattr__.
        return dataclasses.replace(config.model, dtype="float32")
    return openpi.models.pi0_config.Pi0Config(
        dtype="float32",
        action_dim=config.model.action_dim,
        action_horizon=config.model.action_horizon,
        max_token_len=config.model.max_token_len,
        paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
        action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
        pi05=getattr(config.model, "pi05", False),
        arm_action_dims=getattr(config.model, "arm_action_dims", None),
    )
def compute_masked_action_loss(
losses: torch.Tensor, action_loss_mask: tuple[float, ...] | None
) -> torch.Tensor:
if action_loss_mask is None:
return losses.mean()
mask = torch.as_tensor(action_loss_mask, device=losses.device, dtype=losses.dtype)
denom = mask.sum() * losses.shape[0] * losses.shape[1]
return (losses * mask.view(1, 1, -1)).sum() / denom
def make_eval_noise(actions: torch.Tensor, seed: int) -> torch.Tensor:
    """Deterministic standard-normal noise with the same shape as ``actions``.

    Drawn from numpy's seeded Generator (not torch's RNG) so the values are
    reproducible independently of torch's global random state, then moved to
    ``actions``' device as float32.
    """
    sample = np.random.default_rng(seed).standard_normal(size=tuple(actions.shape))
    return torch.as_tensor(sample.astype(np.float32), device=actions.device, dtype=torch.float32)
def make_eval_time(batch_size: int, device: torch.device, seed: int) -> torch.Tensor:
    """Deterministic per-sample timesteps in (0.001, 1.0], shape ``(batch_size,)``.

    Samples a Beta(1.5, 1.0) draw per batch element from numpy's seeded
    Generator, then affinely maps it into (0.001, 1.0] to keep timesteps
    strictly positive.
    """
    raw = np.random.default_rng(seed).beta(1.5, 1.0, size=batch_size)
    shifted = raw.astype(np.float32) * 0.999 + 0.001
    return torch.as_tensor(shifted, device=device, dtype=torch.float32)
def main() -> None:
    """Check that a baseline and a parallel checkpoint are warm-start equivalent.

    Loads one deterministic reference batch, runs both models with identical
    noise and timesteps, and compares the action-input projections and the
    per-element losses against ``args.tolerance``. All results are printed as
    ``key: value`` lines so the log is self-contained.
    """
    args = tyro.cli(Args)
    print(
        f"starting_warmstart_equivalence baseline_config={args.baseline_config_name} "
        f"parallel_config={args.parallel_config_name} repo_id={args.repo_id}",
        flush=True,
    )
    baseline_config = _config.get_config(args.baseline_config_name)
    parallel_config = _config.get_config(args.parallel_config_name)
    parallel_model_cfg = build_model_config(parallel_config)
    # Split action experts intentionally diverge from single-expert weights, so
    # an exact-equivalence check would always fail; bail out with pointers to
    # the dedicated tools instead.
    if parallel_model_cfg.use_split_action_expert:
        raise ValueError(
            "Exact end-to-end warm-start equivalence is not expected for split action experts. "
            "Use init_parallel_pi05_from_single_pytorch.py for branch copy checks and "
            "check_split_expert_invariants.py for branch-local invariants."
        )
    # Both models consume the same batch, so build one loader from the baseline
    # config and only swap in the requested dataset repo.
    data_config = baseline_config.data.create(baseline_config.assets_dirs, baseline_config.model)
    data_config = dataclasses.replace(data_config, repo_id=args.repo_id)
    loader = _data.create_torch_data_loader(
        data_config,
        model_config=baseline_config.model,
        action_horizon=baseline_config.model.action_horizon,
        batch_size=args.batch_size,
        shuffle=False,  # keep the single reference batch reproducible
        num_batches=1,
        num_workers=args.num_workers,
        seed=baseline_config.seed,
        framework="pytorch",
    )
    print("loaded_eval_dataloader", flush=True)
    # Lazy import — presumably to defer the heavy torch-model import until the
    # config sanity checks above have passed; TODO confirm.
    import openpi.models_pytorch.pi0_pytorch as pi0_pytorch
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    observation, actions = next(iter(loader))
    print("loaded_reference_batch", flush=True)
    actions = actions.to(torch.float32)
    actions_device = actions.to(device)
    # Fix noise and timesteps once so both models see byte-identical inputs.
    noise = make_eval_noise(actions_device, args.eval_seed)
    time_values = make_eval_time(actions_device.shape[0], device, args.eval_seed + 10_000)
    time_expanded = time_values[:, None, None]
    # Interpolate between noise and clean actions at each sample's timestep
    # (t=1 -> pure noise, t=0 -> clean actions).
    x_t = time_expanded * noise + (1 - time_expanded) * actions_device
    def run_model(config: _config.TrainConfig, checkpoint_dir: str):
        """Load one checkpoint, run a forward pass on the shared batch, and
        return (projected_inputs, losses) on CPU plus the missing/unexpected
        key lists from the state-dict load."""
        print(f"loading_model config={config.name} checkpoint={checkpoint_dir}", flush=True)
        model = pi0_pytorch.PI0Pytorch(build_model_config(config)).to(device)
        # strict=False so architectures with extra or absent branches still
        # load; the key lists are reported below for manual inspection.
        missing, unexpected = safetensors.torch.load_model(model, f"{checkpoint_dir}/model.safetensors", strict=False)
        model.eval()
        print(f"running_forward config={config.name}", flush=True)
        observation_device = jax.tree.map(lambda x: x.to(device), observation)
        actions_local = actions.to(device)
        with torch.inference_mode():
            # NOTE(review): relies on the private _project_action_inputs hook —
            # confirm it remains available across model versions.
            projected_inputs = model._project_action_inputs(x_t)
            losses = model(observation_device, actions_local, noise=noise, time=time_values).to(torch.float32)
        print(f"finished_forward config={config.name}", flush=True)
        # Drop references and flush the CUDA cache before the second model is
        # instantiated, so both fit sequentially on one GPU.
        del model, observation_device, actions_local
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return projected_inputs.cpu(), losses.cpu(), list(missing), list(unexpected)
    baseline_in, baseline_losses, baseline_missing, baseline_unexpected = run_model(
        baseline_config, args.baseline_ckpt
    )
    parallel_in, parallel_losses, parallel_missing, parallel_unexpected = run_model(
        parallel_config, args.parallel_ckpt
    )
    # Element-wise absolute differences between the two runs.
    projection_diff = (baseline_in - parallel_in).abs()
    loss_diff = (baseline_losses - parallel_losses).abs()
    baseline_masked_loss = compute_masked_action_loss(baseline_losses, baseline_config.action_loss_mask)
    parallel_masked_loss = compute_masked_action_loss(parallel_losses, parallel_config.action_loss_mask)
    masked_loss_abs_diff = abs(float(baseline_masked_loss.item()) - float(parallel_masked_loss.item()))
    # Echo every input knob plus every measured diff for a self-contained log.
    print(f"baseline_config_name: {args.baseline_config_name}")
    print(f"parallel_config_name: {args.parallel_config_name}")
    print(f"repo_id_used: {args.repo_id}")
    print(f"baseline_ckpt: {args.baseline_ckpt}")
    print(f"parallel_ckpt: {args.parallel_ckpt}")
    print(f"batch_size: {args.batch_size}")
    print(f"eval_seed: {args.eval_seed}")
    print(f"tolerance: {args.tolerance}")
    print(f"baseline_missing_keys: {list(baseline_missing)}")
    print(f"baseline_unexpected_keys: {list(baseline_unexpected)}")
    print(f"parallel_missing_keys: {list(parallel_missing)}")
    print(f"parallel_unexpected_keys: {list(parallel_unexpected)}")
    print(f"input_projection_max_abs_diff: {float(projection_diff.max().item()):.8f}")
    print(f"input_projection_mean_abs_diff: {float(projection_diff.mean().item()):.8f}")
    print(f"loss_max_abs_diff: {float(loss_diff.max().item()):.8f}")
    print(f"loss_mean_abs_diff: {float(loss_diff.mean().item()):.8f}")
    print(f"baseline_masked_loss: {float(baseline_masked_loss.item()):.8f}")
    print(f"parallel_masked_loss: {float(parallel_masked_loss.item()):.8f}")
    print(f"masked_loss_abs_diff: {masked_loss_abs_diff:.8f}")
    # The verdict gates only on the max diffs; the masked-loss diff is
    # informational.
    print(
        "warmstart_equivalent: "
        f"{float(projection_diff.max().item()) <= args.tolerance and float(loss_diff.max().item()) <= args.tolerance}"
    )
# Script entry point (e.g. `python <this file> --tolerance 1e-5`).
if __name__ == "__main__":
    main()