File size: 4,197 Bytes
0b18f1f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | import dataclasses
import enum
import logging
import os
import socket
import tyro
from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.training import config as _config
class EnvMode(enum.Enum):
    """Environments this script can serve a policy for.

    Every member has a built-in checkpoint registered in ``DEFAULT_CHECKPOINT``; each member's value is
    the lower-case form of its name.
    """

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
    LIBERO_BASE = "libero_base"
@dataclasses.dataclass
class Checkpoint:
    """Policy source that loads weights from an explicitly given trained checkpoint."""

    # Name of the training config the checkpoint was produced with (e.g., "pi0_aloha_sim").
    config: str
    # Location of the checkpoint: a local path (e.g., "checkpoints/pi0_aloha_sim/exp/10000") or a "gs://" URL.
    dir: str
@dataclasses.dataclass
class Default:
    """Policy source that uses the built-in checkpoint registered for the selected environment."""
@dataclasses.dataclass
class Args:
    """Command-line arguments for the serve_policy_openpi script (parsed with tyro)."""

    # Environment to serve the policy for. Only used when `policy` is `Default`, to pick the built-in checkpoint.
    env: EnvMode = EnvMode.ALOHA_SIM
    # Fallback prompt, used when the "prompt" key is missing from the incoming data or the model has no
    # default prompt of its own.
    default_prompt: str | None = None
    # Port the websocket policy server listens on.
    port: int = 8000
    # Wrap the policy in a PolicyRecorder that writes to "policy_records", for debugging.
    record: bool = False
    # How to load the policy: an explicit `Checkpoint`, or `Default` (the default) to use the built-in
    # checkpoint for `env`.
    policy: Checkpoint | Default = dataclasses.field(default_factory=Default)

    # NOTE(review): a default of 7 presumably targets a specific multi-GPU host (it needs at least eight
    # CUDA devices); confirm this default is intended.

    # CUDA device ID to expose through CUDA_VISIBLE_DEVICES (e.g., 0, 1, 2, ...). Set to -1 (any negative
    # value works) to hide all GPUs and run on CPU.
    cuda_device: int = 7
# Built-in checkpoint for each environment, used when the requested policy is `Default()`.
# ALOHA and LIBERO_BASE pair an environment-specific config with the shared `pi05_base` checkpoint directory.
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
    EnvMode.ALOHA: Checkpoint(config="pi05_aloha", dir="gs://openpi-assets/checkpoints/pi05_base"),
    EnvMode.ALOHA_SIM: Checkpoint(config="pi0_aloha_sim", dir="gs://openpi-assets/checkpoints/pi0_aloha_sim"),
    EnvMode.DROID: Checkpoint(config="pi05_droid", dir="gs://openpi-assets/checkpoints/pi05_droid"),
    EnvMode.LIBERO: Checkpoint(config="pi05_libero", dir="gs://openpi-assets/checkpoints/pi05_libero"),
    EnvMode.LIBERO_BASE: Checkpoint(config="pi05_libero", dir="gs://openpi-assets/checkpoints/pi05_base"),
}
def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
    """Build the policy registered for ``env`` in ``DEFAULT_CHECKPOINT``.

    Args:
        env: Environment whose built-in checkpoint should be loaded.
        default_prompt: Fallback prompt forwarded to ``create_trained_policy``.

    Returns:
        The trained policy loaded from the environment's built-in checkpoint.

    Raises:
        ValueError: If ``env`` has no entry in ``DEFAULT_CHECKPOINT``.
    """
    checkpoint = DEFAULT_CHECKPOINT.get(env)
    if checkpoint is None:
        raise ValueError(f"Unsupported environment mode: {env}")

    train_config = _config.get_config(checkpoint.config)
    return _policy_config.create_trained_policy(train_config, checkpoint.dir, default_prompt=default_prompt)
def create_policy(args: Args) -> _policy.Policy:
"""Create a policy from the given arguments."""
match args.policy:
case Checkpoint():
return _policy_config.create_trained_policy(
_config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
)
case Default():
return create_default_policy(args.env, default_prompt=args.default_prompt)
def _resolve_local_ip(hostname: str) -> str:
    """Return the IPv4 address ``hostname`` resolves to, or ``"unknown"`` if the lookup fails.

    The address is only used in a startup log line, so a resolution failure must not stop the server.
    """
    try:
        return socket.gethostbyname(hostname)
    except OSError:
        # socket.gaierror (raised when the name cannot be resolved) is a subclass of OSError.
        return "unknown"


def main(args: Args) -> None:
    """Select the CUDA device, load the configured policy, and serve it over a websocket.

    Args:
        args: Parsed script arguments.
    """
    # Restrict GPU visibility before the policy is created.
    # NOTE(review): this assumes no accelerator backend was initialized by the module imports; confirm.
    if args.cuda_device >= 0:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_device)
        logging.info("Using CUDA device %d", args.cuda_device)
    else:
        # An empty CUDA_VISIBLE_DEVICES hides every GPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        logging.info("Using CPU (CUDA disabled)")

    policy = create_policy(args)
    # Read metadata from the unwrapped policy, before the optional PolicyRecorder wrapping below.
    policy_metadata = policy.metadata

    if args.record:
        policy = _policy.PolicyRecorder(policy, "policy_records")

    hostname = socket.gethostname()
    local_ip = _resolve_local_ip(hostname)
    logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)

    server = websocket_policy_server.WebsocketPolicyServer(
        policy=policy,
        host="0.0.0.0",  # listen on all interfaces
        port=args.port,
        metadata=policy_metadata,
    )
    server.serve_forever()
if __name__ == "__main__":
    # force=True replaces any handlers already attached to the root logger.
    logging.basicConfig(level=logging.INFO, force=True)
    cli_args = tyro.cli(Args)
    main(cli_args)
|