Timsty commited on
Commit
0b18f1f
·
verified ·
1 Parent(s): 48a771a

Create serve_policy_openpi.py

Browse files
Files changed (1) hide show
  1. serve_policy_openpi.py +137 -0
serve_policy_openpi.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import enum
3
+ import logging
4
+ import os
5
+ import socket
6
+
7
+ import tyro
8
+
9
+ from openpi.policies import policy as _policy
10
+ from openpi.policies import policy_config as _policy_config
11
+ from openpi.serving import websocket_policy_server
12
+ from openpi.training import config as _config
13
+
14
+
15
+ class EnvMode(enum.Enum):
16
+ """Supported environments."""
17
+
18
+ ALOHA = "aloha"
19
+ ALOHA_SIM = "aloha_sim"
20
+ DROID = "droid"
21
+ LIBERO = "libero"
22
+ LIBERO_BASE = "libero_base"
23
+
24
+
25
+ @dataclasses.dataclass
26
+ class Checkpoint:
27
+ """Load a policy from a trained checkpoint."""
28
+
29
+ # Training config name (e.g., "pi0_aloha_sim").
30
+ config: str
31
+ # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
32
+ dir: str
33
+
34
+
35
+ @dataclasses.dataclass
36
+ class Default:
37
+ """Use the default policy for the given environment."""
38
+
39
+
40
+ @dataclasses.dataclass
41
+ class Args:
42
+ """Arguments for the serve_policy_openpi script."""
43
+
44
+ # Environment to serve the policy for. This is only used when serving default policies.
45
+ env: EnvMode = EnvMode.ALOHA_SIM
46
+
47
+ # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
48
+ # prompt.
49
+ default_prompt: str | None = None
50
+
51
+ # Port to serve the policy on.
52
+ port: int = 8000
53
+ # Record the policy's behavior for debugging.
54
+ record: bool = False
55
+
56
+ # Specifies how to load the policy. If not provided, the default policy for the environment will be used.
57
+ policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
58
+
59
+ # CUDA device ID to use (e.g., 0, 1, 2, ...). Set to -1 to use CPU.
60
+ cuda_device: int = 7
61
+
62
+
63
+ # Default checkpoints that should be used for each environment.
64
+ DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
65
+ EnvMode.ALOHA: Checkpoint(
66
+ config="pi05_aloha",
67
+ dir="gs://openpi-assets/checkpoints/pi05_base",
68
+ ),
69
+ EnvMode.ALOHA_SIM: Checkpoint(
70
+ config="pi0_aloha_sim",
71
+ dir="gs://openpi-assets/checkpoints/pi0_aloha_sim",
72
+ ),
73
+ EnvMode.DROID: Checkpoint(
74
+ config="pi05_droid",
75
+ dir="gs://openpi-assets/checkpoints/pi05_droid",
76
+ ),
77
+ EnvMode.LIBERO: Checkpoint(
78
+ config="pi05_libero",
79
+ dir="gs://openpi-assets/checkpoints/pi05_libero",
80
+ ),
81
+ EnvMode.LIBERO_BASE: Checkpoint(
82
+ config="pi05_libero",
83
+ dir="gs://openpi-assets/checkpoints/pi05_base",
84
+ ),
85
+ }
86
+
87
+
88
+ def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
89
+ """Create a default policy for the given environment."""
90
+ if checkpoint := DEFAULT_CHECKPOINT.get(env):
91
+ return _policy_config.create_trained_policy(
92
+ _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt
93
+ )
94
+ raise ValueError(f"Unsupported environment mode: {env}")
95
+
96
+
97
+ def create_policy(args: Args) -> _policy.Policy:
98
+ """Create a policy from the given arguments."""
99
+ match args.policy:
100
+ case Checkpoint():
101
+ return _policy_config.create_trained_policy(
102
+ _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
103
+ )
104
+ case Default():
105
+ return create_default_policy(args.env, default_prompt=args.default_prompt)
106
+
107
+
108
+ def main(args: Args) -> None:
109
+ if args.cuda_device >= 0:
110
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_device)
111
+ logging.info("Using CUDA device %d", args.cuda_device)
112
+ else:
113
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
114
+ logging.info("Using CPU (CUDA disabled)")
115
+
116
+ policy = create_policy(args)
117
+ policy_metadata = policy.metadata
118
+
119
+ if args.record:
120
+ policy = _policy.PolicyRecorder(policy, "policy_records")
121
+
122
+ hostname = socket.gethostname()
123
+ local_ip = socket.gethostbyname(hostname)
124
+ logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)
125
+
126
+ server = websocket_policy_server.WebsocketPolicyServer(
127
+ policy=policy,
128
+ host="0.0.0.0",
129
+ port=args.port,
130
+ metadata=policy_metadata,
131
+ )
132
+ server.serve_forever()
133
+
134
+
135
+ if __name__ == "__main__":
136
+ logging.basicConfig(level=logging.INFO, force=True)
137
+ main(tyro.cli(Args))