#!/usr/bin/env python3
"""Check end-to-end warm-start equivalence between a baseline and a parallel model.

Loads a single reference batch, runs the baseline and parallel checkpoints on
identical noise/time inputs, and reports elementwise diffs of the projected
action inputs and the per-element losses, plus masked scalar losses.
"""

import dataclasses

import jax
import numpy as np
import safetensors.torch
import torch
import tyro

import openpi.models.pi0_config
import openpi.training.config as _config
import openpi.training.data_loader as _data


@dataclasses.dataclass
class Args:
    """CLI arguments for the warm-start equivalence check."""

    # Train-config names to resolve via openpi's config registry.
    baseline_config_name: str = "pi05_twin_handover_256_packed_baseline_pytorch_10k"
    parallel_config_name: str = "pi05_twin_handover_256_packed_parallel_pytorch_10k"
    # Checkpoint directories containing model.safetensors.
    baseline_ckpt: str = "/workspace/checkpoints/pi05_base_single_pytorch"
    parallel_ckpt: str = "/workspace/checkpoints/pi05_base_parallel_packed_from_single"
    # Dataset repo used to build the single evaluation batch.
    repo_id: str = "lsnu/twin_handover_256_train"
    batch_size: int = 4
    num_workers: int = 0
    # Seed for the deterministic eval noise/time draws.
    eval_seed: int = 777
    # Max allowed elementwise abs diff for the equivalence verdict.
    tolerance: float = 1e-6


def build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Return a float32 ``Pi0Config`` mirroring ``config.model``.

    If ``config.model`` is already a ``Pi0Config``, a copy with ``dtype``
    forced to float32 is returned; otherwise a ``Pi0Config`` is built
    field-by-field, with conservative defaults for optional attributes.
    """
    if isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # dtype is an __init__ field (the other branch passes it to the
        # constructor), so dataclasses.replace is the correct way to
        # override it — no object.__setattr__ escape hatch needed, and it
        # works on frozen dataclasses too.
        return dataclasses.replace(config.model, dtype="float32")
    return openpi.models.pi0_config.Pi0Config(
        dtype="float32",
        action_dim=config.model.action_dim,
        action_horizon=config.model.action_horizon,
        max_token_len=config.model.max_token_len,
        paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
        action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
        pi05=getattr(config.model, "pi05", False),
        arm_action_dims=getattr(config.model, "arm_action_dims", None),
    )


def compute_masked_action_loss(
    losses: torch.Tensor,
    action_loss_mask: tuple[float, ...] | None,
) -> torch.Tensor:
    """Reduce per-element losses to a scalar, optionally weighting action dims.

    Args:
        losses: per-element loss tensor; the mask broadcast below assumes a
            (batch, horizon, action_dim) layout — TODO confirm with caller.
        action_loss_mask: per-action-dim weights, or None for a plain mean.

    Returns:
        Scalar tensor: the (weighted) mean loss.

    Raises:
        ValueError: if the mask sums to zero (the weighted mean is undefined).
    """
    if action_loss_mask is None:
        return losses.mean()
    mask = torch.as_tensor(action_loss_mask, device=losses.device, dtype=losses.dtype)
    denom = mask.sum() * losses.shape[0] * losses.shape[1]
    if denom == 0:
        # An all-zero mask would previously yield a silent nan/inf; fail loudly.
        raise ValueError("action_loss_mask sums to zero; masked mean loss is undefined")
    return (losses * mask.view(1, 1, -1)).sum() / denom


def make_eval_noise(actions: torch.Tensor, seed: int) -> torch.Tensor:
    """Draw deterministic standard-normal noise shaped like ``actions``.

    Uses numpy's Generator so the draw is reproducible and independent of
    torch's global RNG state.
    """
    rng = np.random.default_rng(seed)
    noise = rng.standard_normal(size=tuple(actions.shape)).astype(np.float32)
    return torch.as_tensor(noise, device=actions.device, dtype=torch.float32)


def make_eval_time(batch_size: int, device: torch.device, seed: int) -> torch.Tensor:
    """Draw deterministic flow-matching times in (0.001, 1.0] per batch element.

    Samples Beta(1.5, 1.0) and affinely maps it into [0.001, 1.0] so time
    never reaches exactly zero.
    """
    rng = np.random.default_rng(seed)
    time_beta = rng.beta(1.5, 1.0, size=batch_size).astype(np.float32)
    time_values = time_beta * 0.999 + 0.001
    return torch.as_tensor(time_values, device=device, dtype=torch.float32)


def main() -> None:
    """Run both checkpoints on one identical batch and report diffs."""
    args = tyro.cli(Args)
    print(
        f"starting_warmstart_equivalence baseline_config={args.baseline_config_name} "
        f"parallel_config={args.parallel_config_name} repo_id={args.repo_id}",
        flush=True,
    )

    baseline_config = _config.get_config(args.baseline_config_name)
    parallel_config = _config.get_config(args.parallel_config_name)
    parallel_model_cfg = build_model_config(parallel_config)
    if parallel_model_cfg.use_split_action_expert:
        raise ValueError(
            "Exact end-to-end warm-start equivalence is not expected for split action experts. "
            "Use init_parallel_pi05_from_single_pytorch.py for branch copy checks and "
            "check_split_expert_invariants.py for branch-local invariants."
        )

    # Build a single deterministic eval batch from the baseline's data config.
    data_config = baseline_config.data.create(baseline_config.assets_dirs, baseline_config.model)
    data_config = dataclasses.replace(data_config, repo_id=args.repo_id)
    loader = _data.create_torch_data_loader(
        data_config,
        model_config=baseline_config.model,
        action_horizon=baseline_config.model.action_horizon,
        batch_size=args.batch_size,
        shuffle=False,
        num_batches=1,
        num_workers=args.num_workers,
        seed=baseline_config.seed,
        framework="pytorch",
    )
    print("loaded_eval_dataloader", flush=True)

    # Imported lazily so dataloader startup happens before any CUDA init.
    import openpi.models_pytorch.pi0_pytorch as pi0_pytorch

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    observation, actions = next(iter(loader))
    print("loaded_reference_batch", flush=True)
    actions = actions.to(torch.float32)
    actions_device = actions.to(device)

    # Identical noise/time for both models so any diff comes from weights.
    noise = make_eval_noise(actions_device, args.eval_seed)
    time_values = make_eval_time(actions_device.shape[0], device, args.eval_seed + 10_000)
    time_expanded = time_values[:, None, None]
    # Flow-matching interpolant between noise (t=1) and actions (t=0).
    x_t = time_expanded * noise + (1 - time_expanded) * actions_device

    def run_model(config: _config.TrainConfig, checkpoint_dir: str):
        """Load a checkpoint, run one forward pass, and return CPU results."""
        print(f"loading_model config={config.name} checkpoint={checkpoint_dir}", flush=True)
        model = pi0_pytorch.PI0Pytorch(build_model_config(config)).to(device)
        # strict=False: key mismatches are reported to the caller, not fatal.
        missing, unexpected = safetensors.torch.load_model(
            model, f"{checkpoint_dir}/model.safetensors", strict=False
        )
        model.eval()
        print(f"running_forward config={config.name}", flush=True)
        observation_device = jax.tree.map(lambda x: x.to(device), observation)
        actions_local = actions.to(device)
        with torch.inference_mode():
            projected_inputs = model._project_action_inputs(x_t)
            losses = model(observation_device, actions_local, noise=noise, time=time_values).to(
                torch.float32
            )
        print(f"finished_forward config={config.name}", flush=True)
        # Free GPU memory before the second model is loaded.
        del model, observation_device, actions_local
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return projected_inputs.cpu(), losses.cpu(), list(missing), list(unexpected)

    baseline_in, baseline_losses, baseline_missing, baseline_unexpected = run_model(
        baseline_config, args.baseline_ckpt
    )
    parallel_in, parallel_losses, parallel_missing, parallel_unexpected = run_model(
        parallel_config, args.parallel_ckpt
    )

    projection_diff = (baseline_in - parallel_in).abs()
    loss_diff = (baseline_losses - parallel_losses).abs()
    baseline_masked_loss = compute_masked_action_loss(baseline_losses, baseline_config.action_loss_mask)
    parallel_masked_loss = compute_masked_action_loss(parallel_losses, parallel_config.action_loss_mask)
    masked_loss_abs_diff = abs(float(baseline_masked_loss.item()) - float(parallel_masked_loss.item()))

    print(f"baseline_config_name: {args.baseline_config_name}")
    print(f"parallel_config_name: {args.parallel_config_name}")
    print(f"repo_id_used: {args.repo_id}")
    print(f"baseline_ckpt: {args.baseline_ckpt}")
    print(f"parallel_ckpt: {args.parallel_ckpt}")
    print(f"batch_size: {args.batch_size}")
    print(f"eval_seed: {args.eval_seed}")
    print(f"tolerance: {args.tolerance}")
    print(f"baseline_missing_keys: {list(baseline_missing)}")
    print(f"baseline_unexpected_keys: {list(baseline_unexpected)}")
    print(f"parallel_missing_keys: {list(parallel_missing)}")
    print(f"parallel_unexpected_keys: {list(parallel_unexpected)}")
    print(f"input_projection_max_abs_diff: {float(projection_diff.max().item()):.8f}")
    print(f"input_projection_mean_abs_diff: {float(projection_diff.mean().item()):.8f}")
    print(f"loss_max_abs_diff: {float(loss_diff.max().item()):.8f}")
    print(f"loss_mean_abs_diff: {float(loss_diff.mean().item()):.8f}")
    print(f"baseline_masked_loss: {float(baseline_masked_loss.item()):.8f}")
    print(f"parallel_masked_loss: {float(parallel_masked_loss.item()):.8f}")
    print(f"masked_loss_abs_diff: {masked_loss_abs_diff:.8f}")
    print(
        "warmstart_equivalent: "
        f"{float(projection_diff.max().item()) <= args.tolerance and float(loss_diff.max().item()) <= args.tolerance}"
    )


if __name__ == "__main__":
    main()