# pi05tests-openpi-multiarm / openpi / scripts / check_parallel_warmstart_equivalence.py
# (Hugging Face upload-page residue, preserved as comments: "lsnu's picture",
# "Add files using upload-large-folder tool", "ccf25b1 verified")
#!/usr/bin/env python3
import dataclasses
import jax
import numpy as np
import safetensors.torch
import torch
import tyro
import openpi.models.pi0_config
import openpi.training.config as _config
import openpi.training.data_loader as _data
@dataclasses.dataclass
class Args:
    """CLI arguments for the warm-start equivalence check (parsed via tyro)."""

    # Registered train-config names (baseline single-expert vs. parallel variant).
    baseline_config_name: str = "pi05_twin_handover_256_packed_baseline_pytorch_10k"
    parallel_config_name: str = "pi05_twin_handover_256_packed_parallel_pytorch_10k"
    # Checkpoint directories; each must contain a model.safetensors file.
    baseline_ckpt: str = "/workspace/checkpoints/pi05_base_single_pytorch"
    parallel_ckpt: str = "/workspace/checkpoints/pi05_base_parallel_packed_from_single"
    # Dataset repo id used to pull the single reference batch.
    repo_id: str = "lsnu/twin_handover_256_train"
    batch_size: int = 4
    # 0 disables DataLoader worker processes (loads in the main process).
    num_workers: int = 0
    # Seed for the deterministic eval noise/time values.
    eval_seed: int = 777
    # Max allowed absolute difference for the equivalence verdict.
    tolerance: float = 1e-6
def build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Derive a float32 Pi0Config from *config*'s model settings."""
    source = config.model
    if isinstance(source, openpi.models.pi0_config.Pi0Config):
        # Copy the existing config, then force float32 on the copy.
        # NOTE(review): object.__setattr__ is used (not dataclasses.replace with
        # dtype=...) — presumably dtype is not an init field on this frozen
        # dataclass; keep the same mechanism.
        copied = dataclasses.replace(source)
        object.__setattr__(copied, "dtype", "float32")
        return copied
    # Rebuild a Pi0Config from a foreign model config, falling back to
    # defaults for fields the source may not define.
    return openpi.models.pi0_config.Pi0Config(
        dtype="float32",
        action_dim=source.action_dim,
        action_horizon=source.action_horizon,
        max_token_len=source.max_token_len,
        paligemma_variant=getattr(source, "paligemma_variant", "gemma_2b"),
        action_expert_variant=getattr(source, "action_expert_variant", "gemma_300m"),
        pi05=getattr(source, "pi05", False),
        arm_action_dims=getattr(source, "arm_action_dims", None),
    )
def compute_masked_action_loss(
losses: torch.Tensor, action_loss_mask: tuple[float, ...] | None
) -> torch.Tensor:
if action_loss_mask is None:
return losses.mean()
mask = torch.as_tensor(action_loss_mask, device=losses.device, dtype=losses.dtype)
denom = mask.sum() * losses.shape[0] * losses.shape[1]
return (losses * mask.view(1, 1, -1)).sum() / denom
def make_eval_noise(actions: torch.Tensor, seed: int) -> torch.Tensor:
    """Return deterministic standard-normal noise shaped like *actions*."""
    generator = np.random.default_rng(seed)
    sample = generator.standard_normal(size=tuple(actions.shape)).astype(np.float32)
    # Place the noise on the same device as the reference actions.
    return torch.as_tensor(sample, device=actions.device, dtype=torch.float32)
def make_eval_time(batch_size: int, device: torch.device, seed: int) -> torch.Tensor:
    """Return deterministic Beta(1.5, 1.0)-distributed timesteps in [0.001, 1.0)."""
    generator = np.random.default_rng(seed)
    raw = generator.beta(1.5, 1.0, size=batch_size).astype(np.float32)
    # Affine-rescale away from 0 so time never collapses to exactly zero.
    rescaled = raw * 0.999 + 0.001
    return torch.as_tensor(rescaled, device=device, dtype=torch.float32)
def main() -> None:
    """Verify end-to-end warm-start equivalence between two checkpoints.

    Loads one deterministic reference batch, runs the identical fixed
    noise/time through the baseline and the parallel models, and compares
    projected action inputs and per-element losses against args.tolerance.
    Results are printed as key: value lines for log scraping.
    """
    args = tyro.cli(Args)
    print(
        f"starting_warmstart_equivalence baseline_config={args.baseline_config_name} "
        f"parallel_config={args.parallel_config_name} repo_id={args.repo_id}",
        flush=True,
    )
    baseline_config = _config.get_config(args.baseline_config_name)
    parallel_config = _config.get_config(args.parallel_config_name)
    parallel_model_cfg = build_model_config(parallel_config)
    # Split action experts use distinct branch weights, so exact equivalence
    # with the single-expert baseline cannot hold; fail fast with guidance.
    if parallel_model_cfg.use_split_action_expert:
        raise ValueError(
            "Exact end-to-end warm-start equivalence is not expected for split action experts. "
            "Use init_parallel_pi05_from_single_pytorch.py for branch copy checks and "
            "check_split_expert_invariants.py for branch-local invariants."
        )
    # Build the eval data config from the baseline, overriding only the repo id.
    data_config = baseline_config.data.create(baseline_config.assets_dirs, baseline_config.model)
    data_config = dataclasses.replace(data_config, repo_id=args.repo_id)
    loader = _data.create_torch_data_loader(
        data_config,
        model_config=baseline_config.model,
        action_horizon=baseline_config.model.action_horizon,
        batch_size=args.batch_size,
        shuffle=False,  # deterministic single batch for a reproducible comparison
        num_batches=1,
        num_workers=args.num_workers,
        seed=baseline_config.seed,
        framework="pytorch",
    )
    print("loaded_eval_dataloader", flush=True)
    # Deferred import: the torch model module is only needed past this point.
    import openpi.models_pytorch.pi0_pytorch as pi0_pytorch
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    observation, actions = next(iter(loader))
    print("loaded_reference_batch", flush=True)
    actions = actions.to(torch.float32)
    actions_device = actions.to(device)
    # Fixed noise/time so both models see identical flow-matching inputs.
    noise = make_eval_noise(actions_device, args.eval_seed)
    time_values = make_eval_time(actions_device.shape[0], device, args.eval_seed + 10_000)
    time_expanded = time_values[:, None, None]
    # Interpolate between clean actions and noise at time t (flow-matching x_t).
    x_t = time_expanded * noise + (1 - time_expanded) * actions_device
    def run_model(config: _config.TrainConfig, checkpoint_dir: str):
        """Load one checkpoint, run a single forward pass, return CPU-side results.

        Returns (projected_inputs, losses, missing_keys, unexpected_keys).
        """
        print(f"loading_model config={config.name} checkpoint={checkpoint_dir}", flush=True)
        model = pi0_pytorch.PI0Pytorch(build_model_config(config)).to(device)
        # strict=False: key mismatches are collected and reported, not raised.
        missing, unexpected = safetensors.torch.load_model(model, f"{checkpoint_dir}/model.safetensors", strict=False)
        model.eval()
        print(f"running_forward config={config.name}", flush=True)
        observation_device = jax.tree.map(lambda x: x.to(device), observation)
        actions_local = actions.to(device)
        with torch.inference_mode():
            # NOTE(review): _project_action_inputs is a private hook; presumably
            # maps x_t into the action-expert input space — confirm upstream.
            projected_inputs = model._project_action_inputs(x_t)
            losses = model(observation_device, actions_local, noise=noise, time=time_values).to(torch.float32)
        print(f"finished_forward config={config.name}", flush=True)
        # Free GPU memory before loading the next checkpoint.
        del model, observation_device, actions_local
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return projected_inputs.cpu(), losses.cpu(), list(missing), list(unexpected)
    baseline_in, baseline_losses, baseline_missing, baseline_unexpected = run_model(
        baseline_config, args.baseline_ckpt
    )
    parallel_in, parallel_losses, parallel_missing, parallel_unexpected = run_model(
        parallel_config, args.parallel_ckpt
    )
    # Element-wise absolute differences between the two runs.
    projection_diff = (baseline_in - parallel_in).abs()
    loss_diff = (baseline_losses - parallel_losses).abs()
    baseline_masked_loss = compute_masked_action_loss(baseline_losses, baseline_config.action_loss_mask)
    parallel_masked_loss = compute_masked_action_loss(parallel_losses, parallel_config.action_loss_mask)
    masked_loss_abs_diff = abs(float(baseline_masked_loss.item()) - float(parallel_masked_loss.item()))
    # Emit everything needed to reproduce and interpret the comparison.
    print(f"baseline_config_name: {args.baseline_config_name}")
    print(f"parallel_config_name: {args.parallel_config_name}")
    print(f"repo_id_used: {args.repo_id}")
    print(f"baseline_ckpt: {args.baseline_ckpt}")
    print(f"parallel_ckpt: {args.parallel_ckpt}")
    print(f"batch_size: {args.batch_size}")
    print(f"eval_seed: {args.eval_seed}")
    print(f"tolerance: {args.tolerance}")
    print(f"baseline_missing_keys: {list(baseline_missing)}")
    print(f"baseline_unexpected_keys: {list(baseline_unexpected)}")
    print(f"parallel_missing_keys: {list(parallel_missing)}")
    print(f"parallel_unexpected_keys: {list(parallel_unexpected)}")
    print(f"input_projection_max_abs_diff: {float(projection_diff.max().item()):.8f}")
    print(f"input_projection_mean_abs_diff: {float(projection_diff.mean().item()):.8f}")
    print(f"loss_max_abs_diff: {float(loss_diff.max().item()):.8f}")
    print(f"loss_mean_abs_diff: {float(loss_diff.mean().item()):.8f}")
    print(f"baseline_masked_loss: {float(baseline_masked_loss.item()):.8f}")
    print(f"parallel_masked_loss: {float(parallel_masked_loss.item()):.8f}")
    print(f"masked_loss_abs_diff: {masked_loss_abs_diff:.8f}")
    # Final verdict: both max diffs must be within tolerance.
    print(
        "warmstart_equivalent: "
        f"{float(projection_diff.max().item()) <= args.tolerance and float(loss_diff.max().item()) <= args.tolerance}"
    )
if __name__ == "__main__":  # Script entry point.
    main()