import dataclasses

import jax

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

# NOTE(review): `dataclasses`, `jax`, `_model` and `_data_loader` are not used in
# this section; presumably later parts of the file rely on them, so they are kept.

# --- Policy inference --------------------------------------------------------
# Load the "pi0_fast_droid" config and obtain its checkpoint directory from the
# GCS bucket (maybe_download presumably reuses a local copy if one exists — verify).
config = _config.get_config("pi0_fast_droid")
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_fast_droid")

# Create a trained policy from the config and the checkpoint.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example. This example corresponds to observations
# produced by the DROID runtime.
example = droid_policy.make_droid_example()
result = policy.infer(example)

# Delete the policy to free up memory. Only the policy reference is dropped;
# `result` is still used below.
del policy

# `result` is indexed as a mapping whose "actions" entry exposes `.shape`.
print("Actions shape:", result["actions"].shape)