import dataclasses
import enum
import logging
import pathlib
import time

import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import polars as pl
import rich
import rich.console
import rich.table
import tqdm
import tyro
| |
|
# Module logger. When run as a script, its records reach the root handler
# installed by logging.basicConfig in the __main__ block.
logger: logging.Logger = logging.getLogger(__name__)
| |
|
| |
|
class EnvMode(enum.Enum):
    """Environments whose observation layout the benchmark client can imitate."""

    # Each value is the lowercase spelling of its member name.
    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
| |
|
| |
|
# Parsed by tyro.cli in the __main__ block; tyro uses the comment directly
# above each field as that flag's --help text, so keep those comments user-facing.
@dataclasses.dataclass
class Args:
    """Command line arguments."""

    # Address of the policy server to connect to.
    host: str = "0.0.0.0"
    # Port of the policy server (passed unchanged to the websocket client; None is allowed).
    port: int | None = 8000
    # Optional API key passed to the websocket client.
    api_key: str | None = None
    # Number of timed inference requests to send.
    num_steps: int = 20
    # If set, write the raw timing samples to this parquet file.
    timing_file: pathlib.Path | None = None
    # Environment whose observation format is used to build the random requests.
    env: EnvMode = EnvMode.ALOHA_SIM
| |
|
| |
|
class TimingRecorder:
    """Collect timing samples per metric name and summarize them.

    Samples are kept as plain lists in recording order. ``main`` records the
    client-side latency in milliseconds and forwards whatever timing values the
    server returns (presumably also milliseconds, since ``record`` names its
    argument ``time_ms`` -- not verified here).
    """

    # (statistic key, quantile) for every quantile reported by get_stats, in output order.
    _QUANTILES: tuple[tuple[str, float], ...] = (
        ("p25", 0.25),
        ("p50", 0.50),
        ("p75", 0.75),
        ("p90", 0.90),
        ("p95", 0.95),
        ("p99", 0.99),
    )

    # (column header, rich style, statistic key) for each numeric column of the summary table.
    _STAT_COLUMNS: tuple[tuple[str, str, str], ...] = (
        ("Mean", "yellow", "mean"),
        ("Std", "yellow", "std"),
        ("P25", "magenta", "p25"),
        ("P50", "magenta", "p50"),
        ("P75", "magenta", "p75"),
        ("P90", "magenta", "p90"),
        ("P95", "magenta", "p95"),
        ("P99", "magenta", "p99"),
    )

    def __init__(self) -> None:
        # Metric name -> samples in the order they were recorded.
        self._timings: dict[str, list[float]] = {}

    def record(self, key: str, time_ms: float) -> None:
        """Append one timing sample for ``key``, creating its series on first use."""
        self._timings.setdefault(key, []).append(time_ms)

    def get_stats(self, key: str) -> dict[str, float]:
        """Return summary statistics for the samples recorded under ``key``.

        Args:
            key: A metric name previously passed to ``record``.

        Returns:
            A dict with ``mean``, ``std`` (population std, numpy's default
            ``ddof=0``) and the quantiles ``p25``/``p50``/``p75``/``p90``/
            ``p95``/``p99``, in that order, all as Python floats.

        Raises:
            KeyError: If nothing has been recorded under ``key``.
        """
        times = self._timings[key]
        # A single vectorized call evaluates every quantile instead of one pass per quantile.
        quantile_values = np.quantile(times, [q for _, q in self._QUANTILES])
        stats = {"mean": float(np.mean(times)), "std": float(np.std(times))}
        for (stat_name, _), value in zip(self._QUANTILES, quantile_values):
            stats[stat_name] = float(value)
        return stats

    def print_all_stats(self) -> None:
        """Print one table row per recorded metric (sorted by name) with every statistic.

        Values are formatted with one decimal place.
        """
        table = rich.table.Table(
            title="[bold blue]Timing Statistics[/bold blue]",
            show_header=True,
            header_style="bold white",
            border_style="blue",
            title_justify="center",
        )

        # First column holds the metric name; the rest come from _STAT_COLUMNS.
        table.add_column("Metric", style="cyan", justify="left", no_wrap=True)
        for header, style, _ in self._STAT_COLUMNS:
            table.add_column(header, justify="right", style=style, no_wrap=True)

        for metric in sorted(self._timings):
            stats = self.get_stats(metric)
            cells = [f"{stats[stat_name]:.1f}" for _, _, stat_name in self._STAT_COLUMNS]
            table.add_row(metric, *cells)

        console = rich.console.Console(width=None, highlight=True)
        console.print(table)

    def write_parquet(self, path: pathlib.Path) -> None:
        """Write every recorded sample to ``path`` as parquet, one column per metric.

        Missing parent directories are created.

        NOTE(review): the frame is built column-wise from the per-metric lists,
        so every metric must have the same number of samples; polars raises if
        the series are ragged.
        """
        logger.info("Writing timings to %s", path)
        frame = pl.DataFrame(self._timings)
        path.parent.mkdir(parents=True, exist_ok=True)
        frame.write_parquet(path)
| |
|
| |
|
def main(args: Args) -> None:
    """Benchmark a remote policy server with random observations.

    Connects a websocket policy client to ``args.host``/``args.port``, sends two
    untimed warm-up requests, then times ``args.num_steps`` inference round
    trips. Any ``"server_timing"`` / ``"policy_timing"`` dicts found in the
    responses are recorded next to the client-side latency. Statistics are
    printed, and the raw samples are written to ``args.timing_file`` when set.
    """
    obs_fn = {
        EnvMode.ALOHA: _random_observation_aloha,
        EnvMode.ALOHA_SIM: _random_observation_aloha,
        EnvMode.DROID: _random_observation_droid,
        EnvMode.LIBERO: _random_observation_libero,
    }[args.env]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
        api_key=args.api_key,
    )
    logger.info("Server metadata: %s", policy.get_server_metadata())

    # Untimed warm-up requests so one-off startup costs are excluded from the statistics.
    for _ in range(2):
        policy.infer(obs_fn())

    timing_recorder = TimingRecorder()

    for _ in tqdm.trange(args.num_steps, desc="Running policy"):
        # Build the observation before starting the clock: generating the random
        # images is client-side setup, not part of the inference round trip.
        obs = obs_fn()
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which follows the adjustable wall clock.
        inference_start = time.perf_counter()
        action = policy.infer(obs)
        timing_recorder.record("client_infer_ms", 1000 * (time.perf_counter() - inference_start))
        # Forward server-reported timings under prefixed names so they cannot collide.
        for key, value in action.get("server_timing", {}).items():
            timing_recorder.record(f"server_{key}", value)
        for key, value in action.get("policy_timing", {}).items():
            timing_recorder.record(f"policy_{key}", value)

    timing_recorder.print_all_stats()

    if args.timing_file is not None:
        timing_recorder.write_parquet(args.timing_file)
| |
|
| |
|
| | def _random_observation_aloha() -> dict: |
| | return { |
| | "state": np.ones((14,)), |
| | "images": { |
| | "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), |
| | "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), |
| | "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), |
| | "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), |
| | }, |
| | "prompt": "do something", |
| | } |
| |
|
| |
|
| | def _random_observation_droid() -> dict: |
| | return { |
| | "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), |
| | "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), |
| | "observation/joint_position": np.random.rand(7), |
| | "observation/gripper_position": np.random.rand(1), |
| | "prompt": "do something", |
| | } |
| |
|
| |
|
| | def _random_observation_libero() -> dict: |
| | return { |
| | "observation/state": np.random.rand(8), |
| | "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), |
| | "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), |
| | "prompt": "do something", |
| | } |
| |
|
| |
|
if __name__ == "__main__":
    # Send INFO-level records (e.g. server metadata) to stderr via the root logger.
    logging.basicConfig(level=logging.INFO)
    cli_args = tyro.cli(Args)
    main(cli_args)
| |
|