| #!/usr/bin/env python3 | |
| import dataclasses | |
| import jax | |
| import numpy as np | |
| import safetensors.torch | |
| import torch | |
| import tyro | |
| import openpi.models.pi0_config | |
| import openpi.training.config as _config | |
| import openpi.training.data_loader as _data | |
@dataclasses.dataclass
class Args:
    """CLI arguments for the baseline-vs-parallel warm-start equivalence check.

    The decorator is required: ``tyro.cli(Args)`` (see ``main``) parses the
    command line against a dataclass-like type; a plain class with only
    annotated class attributes is not constructible by tyro.
    """

    # Names of the two TrainConfigs to compare (resolved via _config.get_config).
    baseline_config_name: str = "pi05_twin_handover_256_packed_baseline_pytorch_10k"
    parallel_config_name: str = "pi05_twin_handover_256_packed_parallel_pytorch_10k"
    # Checkpoint directories; each must contain a model.safetensors file.
    baseline_ckpt: str = "/workspace/checkpoints/pi05_base_single_pytorch"
    parallel_ckpt: str = "/workspace/checkpoints/pi05_base_parallel_packed_from_single"
    # Dataset repo used for the single evaluation batch.
    repo_id: str = "lsnu/twin_handover_256_train"
    batch_size: int = 4
    num_workers: int = 0
    # Seed for the shared noise/time tensors fed to both models.
    eval_seed: int = 777
    # Max allowed absolute difference for declaring the models equivalent.
    tolerance: float = 1e-6
def build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Return a float32 ``Pi0Config`` derived from *config*'s model settings.

    If the config's model is already a ``Pi0Config``, a copy is returned with
    only ``dtype`` forced to ``"float32"`` (so both checkpoints are compared at
    full precision).  Otherwise a fresh ``Pi0Config`` is assembled from the
    fields the comparison needs, with defaults for fields the source model may
    not define.
    """
    if isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # dataclasses.replace handles frozen dataclasses correctly in one step;
        # no need for the replace() + object.__setattr__ bypass.
        return dataclasses.replace(config.model, dtype="float32")
    return openpi.models.pi0_config.Pi0Config(
        dtype="float32",
        action_dim=config.model.action_dim,
        action_horizon=config.model.action_horizon,
        max_token_len=config.model.max_token_len,
        # getattr with defaults: the non-Pi0 model config may lack these fields.
        paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
        action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
        pi05=getattr(config.model, "pi05", False),
        arm_action_dims=getattr(config.model, "arm_action_dims", None),
    )
| def compute_masked_action_loss( | |
| losses: torch.Tensor, action_loss_mask: tuple[float, ...] | None | |
| ) -> torch.Tensor: | |
| if action_loss_mask is None: | |
| return losses.mean() | |
| mask = torch.as_tensor(action_loss_mask, device=losses.device, dtype=losses.dtype) | |
| denom = mask.sum() * losses.shape[0] * losses.shape[1] | |
| return (losses * mask.view(1, 1, -1)).sum() / denom | |
def make_eval_noise(actions: torch.Tensor, seed: int) -> torch.Tensor:
    """Deterministic standard-normal noise matching *actions* in shape and device.

    A fixed numpy generator (not torch's global RNG) keeps the noise identical
    across model runs and independent of any other random state.
    """
    generator = np.random.default_rng(seed)
    sample = generator.standard_normal(size=tuple(actions.shape)).astype(np.float32)
    return torch.as_tensor(sample, device=actions.device, dtype=torch.float32)
def make_eval_time(batch_size: int, device: torch.device, seed: int) -> torch.Tensor:
    """Deterministic per-sample timesteps in (0.001, 1.0), drawn from Beta(1.5, 1).

    The affine rescale keeps every value strictly above zero so downstream
    interpolation never hits the t == 0 endpoint exactly.
    """
    draws = np.random.default_rng(seed).beta(1.5, 1.0, size=batch_size).astype(np.float32)
    rescaled = draws * 0.999 + 0.001
    return torch.as_tensor(rescaled, device=device, dtype=torch.float32)
def main() -> None:
    """Check that a warm-started parallel checkpoint reproduces the baseline model.

    Loads one deterministic batch, runs both checkpoints on byte-identical
    inputs (shared seeded noise/time), then prints max/mean absolute
    differences for the projected action inputs and the per-element losses.
    The final line reports whether both max differences are within
    ``args.tolerance``.
    """
    args = tyro.cli(Args)
    print(
        f"starting_warmstart_equivalence baseline_config={args.baseline_config_name} "
        f"parallel_config={args.parallel_config_name} repo_id={args.repo_id}",
        flush=True,
    )
    baseline_config = _config.get_config(args.baseline_config_name)
    parallel_config = _config.get_config(args.parallel_config_name)
    parallel_model_cfg = build_model_config(parallel_config)
    # Split action experts intentionally diverge from the single-expert model,
    # so an exact-equivalence check would always fail for them.
    if parallel_model_cfg.use_split_action_expert:
        raise ValueError(
            "Exact end-to-end warm-start equivalence is not expected for split action experts. "
            "Use init_parallel_pi05_from_single_pytorch.py for branch copy checks and "
            "check_split_expert_invariants.py for branch-local invariants."
        )
    # Both models consume the same batch built from the *baseline* data pipeline;
    # only the dataset repo_id is overridden from the CLI.
    data_config = baseline_config.data.create(baseline_config.assets_dirs, baseline_config.model)
    data_config = dataclasses.replace(data_config, repo_id=args.repo_id)
    loader = _data.create_torch_data_loader(
        data_config,
        model_config=baseline_config.model,
        action_horizon=baseline_config.model.action_horizon,
        batch_size=args.batch_size,
        shuffle=False,  # deterministic single batch
        num_batches=1,
        num_workers=args.num_workers,
        seed=baseline_config.seed,
        framework="pytorch",
    )
    print("loaded_eval_dataloader", flush=True)
    # Imported late, presumably to defer the heavy torch-model import until the
    # dataloader is known to work — TODO confirm.
    import openpi.models_pytorch.pi0_pytorch as pi0_pytorch
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    observation, actions = next(iter(loader))
    print("loaded_reference_batch", flush=True)
    actions = actions.to(torch.float32)
    actions_device = actions.to(device)
    # Shared, seeded noise/time so both models see identical flow-matching inputs.
    noise = make_eval_noise(actions_device, args.eval_seed)
    time_values = make_eval_time(actions_device.shape[0], device, args.eval_seed + 10_000)
    time_expanded = time_values[:, None, None]
    # Interpolant between clean actions (t=0) and pure noise (t=1).
    x_t = time_expanded * noise + (1 - time_expanded) * actions_device
    def run_model(
        config: _config.TrainConfig, checkpoint_dir: str
    ) -> tuple[torch.Tensor, torch.Tensor, list[str], list[str]]:
        """Load one checkpoint, run one forward pass, and return
        (projected action inputs, per-element losses, missing keys, unexpected keys)."""
        print(f"loading_model config={config.name} checkpoint={checkpoint_dir}", flush=True)
        model = pi0_pytorch.PI0Pytorch(build_model_config(config)).to(device)
        # strict=False: key mismatches are returned (and printed in the summary
        # below) instead of aborting the load.
        missing, unexpected = safetensors.torch.load_model(model, f"{checkpoint_dir}/model.safetensors", strict=False)
        model.eval()
        print(f"running_forward config={config.name}", flush=True)
        # jax.tree.map is used purely as a pytree mapper over the observation
        # structure; the leaves are torch tensors.
        observation_device = jax.tree.map(lambda x: x.to(device), observation)
        actions_local = actions.to(device)
        with torch.inference_mode():
            # NOTE(review): private API — compares the action-input projection directly.
            projected_inputs = model._project_action_inputs(x_t)
            losses = model(observation_device, actions_local, noise=noise, time=time_values).to(torch.float32)
        print(f"finished_forward config={config.name}", flush=True)
        # Free GPU memory before the second model is instantiated.
        del model, observation_device, actions_local
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return projected_inputs.cpu(), losses.cpu(), list(missing), list(unexpected)
    baseline_in, baseline_losses, baseline_missing, baseline_unexpected = run_model(
        baseline_config, args.baseline_ckpt
    )
    parallel_in, parallel_losses, parallel_missing, parallel_unexpected = run_model(
        parallel_config, args.parallel_ckpt
    )
    projection_diff = (baseline_in - parallel_in).abs()
    loss_diff = (baseline_losses - parallel_losses).abs()
    baseline_masked_loss = compute_masked_action_loss(baseline_losses, baseline_config.action_loss_mask)
    parallel_masked_loss = compute_masked_action_loss(parallel_losses, parallel_config.action_loss_mask)
    masked_loss_abs_diff = abs(float(baseline_masked_loss.item()) - float(parallel_masked_loss.item()))
    # Summary: echo the inputs first, then the comparison metrics.
    print(f"baseline_config_name: {args.baseline_config_name}")
    print(f"parallel_config_name: {args.parallel_config_name}")
    print(f"repo_id_used: {args.repo_id}")
    print(f"baseline_ckpt: {args.baseline_ckpt}")
    print(f"parallel_ckpt: {args.parallel_ckpt}")
    print(f"batch_size: {args.batch_size}")
    print(f"eval_seed: {args.eval_seed}")
    print(f"tolerance: {args.tolerance}")
    print(f"baseline_missing_keys: {list(baseline_missing)}")
    print(f"baseline_unexpected_keys: {list(baseline_unexpected)}")
    print(f"parallel_missing_keys: {list(parallel_missing)}")
    print(f"parallel_unexpected_keys: {list(parallel_unexpected)}")
    print(f"input_projection_max_abs_diff: {float(projection_diff.max().item()):.8f}")
    print(f"input_projection_mean_abs_diff: {float(projection_diff.mean().item()):.8f}")
    print(f"loss_max_abs_diff: {float(loss_diff.max().item()):.8f}")
    print(f"loss_mean_abs_diff: {float(loss_diff.mean().item()):.8f}")
    print(f"baseline_masked_loss: {float(baseline_masked_loss.item()):.8f}")
    print(f"parallel_masked_loss: {float(parallel_masked_loss.item()):.8f}")
    print(f"masked_loss_abs_diff: {masked_loss_abs_diff:.8f}")
    # Verdict: both the input projection and the losses must match within tolerance.
    print(
        "warmstart_equivalent: "
        f"{float(projection_diff.max().item()) <= args.tolerance and float(loss_diff.max().item()) <= args.tolerance}"
    )
if __name__ == "__main__":
    main()