Create serve_policy_openpi.py
Browse files- serve_policy_openpi.py +137 -0
serve_policy_openpi.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import enum
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import socket
|
| 6 |
+
|
| 7 |
+
import tyro
|
| 8 |
+
|
| 9 |
+
from openpi.policies import policy as _policy
|
| 10 |
+
from openpi.policies import policy_config as _policy_config
|
| 11 |
+
from openpi.serving import websocket_policy_server
|
| 12 |
+
from openpi.training import config as _config
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class EnvMode(enum.Enum):
|
| 16 |
+
"""Supported environments."""
|
| 17 |
+
|
| 18 |
+
ALOHA = "aloha"
|
| 19 |
+
ALOHA_SIM = "aloha_sim"
|
| 20 |
+
DROID = "droid"
|
| 21 |
+
LIBERO = "libero"
|
| 22 |
+
LIBERO_BASE = "libero_base"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclasses.dataclass
|
| 26 |
+
class Checkpoint:
|
| 27 |
+
"""Load a policy from a trained checkpoint."""
|
| 28 |
+
|
| 29 |
+
# Training config name (e.g., "pi0_aloha_sim").
|
| 30 |
+
config: str
|
| 31 |
+
# Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
|
| 32 |
+
dir: str
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclasses.dataclass
|
| 36 |
+
class Default:
|
| 37 |
+
"""Use the default policy for the given environment."""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclasses.dataclass
|
| 41 |
+
class Args:
|
| 42 |
+
"""Arguments for the serve_policy_openpi script."""
|
| 43 |
+
|
| 44 |
+
# Environment to serve the policy for. This is only used when serving default policies.
|
| 45 |
+
env: EnvMode = EnvMode.ALOHA_SIM
|
| 46 |
+
|
| 47 |
+
# If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
|
| 48 |
+
# prompt.
|
| 49 |
+
default_prompt: str | None = None
|
| 50 |
+
|
| 51 |
+
# Port to serve the policy on.
|
| 52 |
+
port: int = 8000
|
| 53 |
+
# Record the policy's behavior for debugging.
|
| 54 |
+
record: bool = False
|
| 55 |
+
|
| 56 |
+
# Specifies how to load the policy. If not provided, the default policy for the environment will be used.
|
| 57 |
+
policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
|
| 58 |
+
|
| 59 |
+
# CUDA device ID to use (e.g., 0, 1, 2, ...). Set to -1 to use CPU.
|
| 60 |
+
cuda_device: int = 7
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Default checkpoints that should be used for each environment.
|
| 64 |
+
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
|
| 65 |
+
EnvMode.ALOHA: Checkpoint(
|
| 66 |
+
config="pi05_aloha",
|
| 67 |
+
dir="gs://openpi-assets/checkpoints/pi05_base",
|
| 68 |
+
),
|
| 69 |
+
EnvMode.ALOHA_SIM: Checkpoint(
|
| 70 |
+
config="pi0_aloha_sim",
|
| 71 |
+
dir="gs://openpi-assets/checkpoints/pi0_aloha_sim",
|
| 72 |
+
),
|
| 73 |
+
EnvMode.DROID: Checkpoint(
|
| 74 |
+
config="pi05_droid",
|
| 75 |
+
dir="gs://openpi-assets/checkpoints/pi05_droid",
|
| 76 |
+
),
|
| 77 |
+
EnvMode.LIBERO: Checkpoint(
|
| 78 |
+
config="pi05_libero",
|
| 79 |
+
dir="gs://openpi-assets/checkpoints/pi05_libero",
|
| 80 |
+
),
|
| 81 |
+
EnvMode.LIBERO_BASE: Checkpoint(
|
| 82 |
+
config="pi05_libero",
|
| 83 |
+
dir="gs://openpi-assets/checkpoints/pi05_base",
|
| 84 |
+
),
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
|
| 89 |
+
"""Create a default policy for the given environment."""
|
| 90 |
+
if checkpoint := DEFAULT_CHECKPOINT.get(env):
|
| 91 |
+
return _policy_config.create_trained_policy(
|
| 92 |
+
_config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt
|
| 93 |
+
)
|
| 94 |
+
raise ValueError(f"Unsupported environment mode: {env}")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def create_policy(args: Args) -> _policy.Policy:
|
| 98 |
+
"""Create a policy from the given arguments."""
|
| 99 |
+
match args.policy:
|
| 100 |
+
case Checkpoint():
|
| 101 |
+
return _policy_config.create_trained_policy(
|
| 102 |
+
_config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
|
| 103 |
+
)
|
| 104 |
+
case Default():
|
| 105 |
+
return create_default_policy(args.env, default_prompt=args.default_prompt)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def main(args: Args) -> None:
|
| 109 |
+
if args.cuda_device >= 0:
|
| 110 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_device)
|
| 111 |
+
logging.info("Using CUDA device %d", args.cuda_device)
|
| 112 |
+
else:
|
| 113 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
| 114 |
+
logging.info("Using CPU (CUDA disabled)")
|
| 115 |
+
|
| 116 |
+
policy = create_policy(args)
|
| 117 |
+
policy_metadata = policy.metadata
|
| 118 |
+
|
| 119 |
+
if args.record:
|
| 120 |
+
policy = _policy.PolicyRecorder(policy, "policy_records")
|
| 121 |
+
|
| 122 |
+
hostname = socket.gethostname()
|
| 123 |
+
local_ip = socket.gethostbyname(hostname)
|
| 124 |
+
logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)
|
| 125 |
+
|
| 126 |
+
server = websocket_policy_server.WebsocketPolicyServer(
|
| 127 |
+
policy=policy,
|
| 128 |
+
host="0.0.0.0",
|
| 129 |
+
port=args.port,
|
| 130 |
+
metadata=policy_metadata,
|
| 131 |
+
)
|
| 132 |
+
server.serve_forever()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
|
| 136 |
+
logging.basicConfig(level=logging.INFO, force=True)
|
| 137 |
+
main(tyro.cli(Args))
|