| | import dataclasses |
| |
|
| | import einops |
| | import numpy as np |
| |
|
| | from openpi import transforms |
| | from openpi.models import model as _model |
| |
|
| |
|
def make_droid_example() -> dict:
    """Build a randomly populated DROID-style observation dict for smoke tests.

    Values are drawn from NumPy's global random state, so the output is only
    reproducible if the caller seeds ``np.random`` beforehand.

    Returns:
        A dict holding two 224x224x3 uint8 images (exterior and wrist camera),
        a 7-element joint-position vector, a 1-element gripper-position vector
        and a text prompt.
    """
    image_shape = (224, 224, 3)
    example = {}
    # Draw order matters for seeded reproducibility: exterior, wrist, joints, gripper.
    example["observation/exterior_image_1_left"] = np.random.randint(256, size=image_shape, dtype=np.uint8)
    example["observation/wrist_image_left"] = np.random.randint(256, size=image_shape, dtype=np.uint8)
    example["observation/joint_position"] = np.random.rand(7)
    example["observation/gripper_position"] = np.random.rand(1)
    example["prompt"] = "do something"
    return example
| |
|
| |
|
def _parse_image(image) -> np.ndarray:
    """Normalize an image to a uint8, channel-last NumPy array.

    Float images are treated as lying in [0, 1]. They are scaled by 255,
    clamped to [0, 255] and cast to uint8; fractional parts are truncated by
    the cast, not rounded. Arrays whose first axis has length 3 are treated as
    channel-first ``(C, H, W)`` and rearranged to ``(H, W, C)``. Any other
    input is returned as ``np.asarray(image)``.

    Args:
        image: Array-like image.

    Returns:
        The image as a NumPy array.
    """
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.floating):
        # Casting out-of-range floats to uint8 is undefined in NumPy and may wrap
        # (e.g. 1.01 * 255 -> 1), so clamp first. In-range values are unaffected.
        image = np.clip(255 * image, 0, 255).astype(np.uint8)
    # NOTE(review): this heuristic would also transpose a channel-last image whose
    # height is 3; callers are assumed to pass real camera frames.
    if image.shape[0] == 3:
        image = einops.rearrange(image, "c h w -> h w c")
    return image
| |
|
| |
|
| | @dataclasses.dataclass(frozen=True) |
| | class DroidInputs(transforms.DataTransformFn): |
| | |
| | action_dim: int |
| |
|
| | |
| | model_type: _model.ModelType = _model.ModelType.PI0 |
| |
|
| | def __call__(self, data: dict) -> dict: |
| | state = np.concatenate([data["observation/joint_position"], data["observation/gripper_position"]]) |
| | state = transforms.pad_to_dim(state, self.action_dim) |
| |
|
| | |
| | |
| | base_image = _parse_image(data["observation/exterior_image_1_left"]) |
| | wrist_image = _parse_image(data["observation/wrist_image_left"]) |
| |
|
| | match self.model_type: |
| | case _model.ModelType.PI0: |
| | names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb") |
| | images = (base_image, wrist_image, np.zeros_like(base_image)) |
| | image_masks = (np.True_, np.True_, np.False_) |
| | case _model.ModelType.PI0_FAST: |
| | names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb") |
| | |
| | images = (base_image, np.zeros_like(base_image), wrist_image) |
| | image_masks = (np.True_, np.True_, np.True_) |
| | case _: |
| | raise ValueError(f"Unsupported model type: {self.model_type}") |
| |
|
| | inputs = { |
| | "state": state, |
| | "image": dict(zip(names, images, strict=True)), |
| | "image_mask": dict(zip(names, image_masks, strict=True)), |
| | } |
| |
|
| | if "actions" in data: |
| | inputs["actions"] = np.array(data["actions"]) |
| |
|
| | if "prompt" in data: |
| | if isinstance(data["prompt"], bytes): |
| | data["prompt"] = data["prompt"].decode("utf-8") |
| | inputs["prompt"] = data["prompt"] |
| |
|
| | return inputs |
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class DroidOutputs(transforms.DataTransformFn):
    """Trim model action outputs down to the first 8 action dimensions."""

    def __call__(self, data: dict) -> dict:
        """Return ``{"actions": ...}`` keeping only the first 8 entries of the last axis.

        The width 8 presumably corresponds to 7 joint positions + 1 gripper
        position (the sizes used in ``make_droid_example``). TODO confirm against
        the DROID action spec.

        Slicing the last axis with ``...`` gives the same result as before for a
        2-D ``(horizon, dim)`` chunk. Arrays with extra leading batch axes are now
        trimmed on the action axis rather than the horizon axis.

        Args:
            data: Dict whose ``actions`` entry is array-like with actions on the last axis.

        Returns:
            Dict with the trimmed ``actions`` as a NumPy array.
        """
        return {"actions": np.asarray(data["actions"])[..., :8]}
| |
|