# openpi-oft / inference_test.py
# Uploaded by Sichang0621 using huggingface_hub (commit ce5618e).
import dataclasses
import jax
from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader
def main(
    config_name="pi0_fast_droid",
    checkpoint_url="gs://openpi-assets/checkpoints/pi0_fast_droid",
):
    """Run one inference pass of a trained openpi policy on a dummy DROID example.

    Loads the training config named ``config_name`` and fetches the checkpoint at
    ``checkpoint_url`` through ``download.maybe_download``. It then builds a
    trained policy, calls ``policy.infer`` on ``droid_policy.make_droid_example()``
    and prints the shape of the result's ``"actions"`` entry.

    Args:
        config_name: Name passed to ``_config.get_config``.
        checkpoint_url: Checkpoint location passed to ``download.maybe_download``.

    Returns:
        Whatever ``policy.infer`` returned, so callers can inspect the actions.
        It is indexed with ``"actions"`` here, so it is presumably a dict-like
        mapping (confirm against openpi's Policy API).
    """
    config = _config.get_config(config_name)
    checkpoint_dir = download.maybe_download(checkpoint_url)

    # Create a trained policy.
    policy = _policy_config.create_trained_policy(config, checkpoint_dir)

    # Run inference on a dummy example. This example corresponds to
    # observations produced by the DROID runtime.
    example = droid_policy.make_droid_example()
    result = policy.infer(example)

    # Drop our reference to the policy early so its memory can be reclaimed
    # (only effective if nothing else still holds a reference to it).
    del policy

    print("Actions shape:", result["actions"].shape)
    return result


# Guard the entry point. This file's name matches pytest's default `*_test.py`
# discovery pattern, so without the guard, collecting or importing this module
# would download the checkpoint and run full model inference as a side effect.
if __name__ == "__main__":
    main()