| import dataclasses |
|
|
| import einops |
| import numpy as np |
|
|
| from openpi import transforms |
| from openpi.models import model as _model |
|
|
|
|
def make_droid_example() -> dict:
    """Build a randomly populated observation dict in the Droid policy input format.

    Returns:
        A dict holding two random 224x224x3 uint8 camera frames, a 7-element
        joint position vector, a 1-element gripper position vector and a fixed
        text prompt.
    """
    frame_shape = (224, 224, 3)

    def _random_frame() -> np.ndarray:
        # Integer pixel values drawn from [0, 256).
        return np.random.randint(256, size=frame_shape, dtype=np.uint8)

    # The draws happen in the same order as the keys below, so the result is
    # reproducible under a fixed global NumPy seed.
    exterior_frame = _random_frame()
    wrist_frame = _random_frame()
    joint_position = np.random.rand(7)
    gripper_position = np.random.rand(1)

    example: dict = {}
    example["observation/exterior_image_1_left"] = exterior_frame
    example["observation/wrist_image_left"] = wrist_frame
    example["observation/joint_position"] = joint_position
    example["observation/gripper_position"] = gripper_position
    example["prompt"] = "do something"
    return example
|
|
|
|
def _parse_image(image) -> np.ndarray:
    """Return ``image`` as an array, converting floats to uint8 and CHW to HWC.

    Args:
        image: Array-like image. Floating-point input is scaled by 255, i.e. it
            is presumably expected to lie in [0, 1] (confirm against the data
            source). Any other dtype is passed through without conversion.

    Returns:
        The image as a NumPy array. Floating-point input comes back as uint8,
        saturated to [0, 255]. If the first axis has length 3, the array is
        treated as (C, H, W) and rearranged to (H, W, C).
    """
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.floating):
        # Clip before casting: an out-of-range float -> uint8 cast wraps around
        # (or is undefined), so a slightly over-bright pixel would turn dark.
        image = np.clip(255 * image, 0, 255).astype(np.uint8)
    # NOTE: the layout is inferred from the first axis alone, so an (H, W, C)
    # image whose height is exactly 3 is transposed as well.
    if image.shape[0] == 3:
        image = einops.rearrange(image, "c h w -> h w c")
    return image
|
|
|
|
| @dataclasses.dataclass(frozen=True) |
| class DroidInputs(transforms.DataTransformFn): |
| |
| model_type: _model.ModelType |
|
|
| def __call__(self, data: dict) -> dict: |
| gripper_pos = np.asarray(data["observation/gripper_position"]) |
| if gripper_pos.ndim == 0: |
| |
| gripper_pos = gripper_pos[np.newaxis] |
| state = np.concatenate([data["observation/joint_position"], gripper_pos]) |
|
|
| |
| |
| base_image = _parse_image(data["observation/exterior_image_1_left"]) |
| wrist_image = _parse_image(data["observation/wrist_image_left"]) |
|
|
| match self.model_type: |
| case _model.ModelType.PI0 | _model.ModelType.PI05: |
| names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb") |
| images = (base_image, wrist_image, np.zeros_like(base_image)) |
| image_masks = (np.True_, np.True_, np.False_) |
| case _model.ModelType.PI0_FAST: |
| names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb") |
| |
| images = (base_image, np.zeros_like(base_image), wrist_image) |
| image_masks = (np.True_, np.True_, np.True_) |
| case _: |
| raise ValueError(f"Unsupported model type: {self.model_type}") |
|
|
| inputs = { |
| "state": state, |
| "image": dict(zip(names, images, strict=True)), |
| "image_mask": dict(zip(names, image_masks, strict=True)), |
| } |
|
|
| if "actions" in data: |
| inputs["actions"] = np.asarray(data["actions"]) |
|
|
| if "prompt" in data: |
| if isinstance(data["prompt"], bytes): |
| data["prompt"] = data["prompt"].decode("utf-8") |
| inputs["prompt"] = data["prompt"] |
|
|
| return inputs |
|
|
|
|
@dataclasses.dataclass(frozen=True)
class DroidOutputs(transforms.DataTransformFn):
    """Reduce an output dict to its ``"actions"`` entry, trimmed to 8 columns."""

    def __call__(self, data: dict) -> dict:
        """Return a new dict containing only the trimmed actions.

        Args:
            data: Dict whose ``"actions"`` value supports two-axis slicing.

        Returns:
            ``{"actions": ...}`` with every row kept and only the first eight
            columns retained, as a NumPy array.
        """
        raw_actions = data["actions"]
        # Keep all rows, first 8 columns (presumably 7 joint values plus
        # 1 gripper value — confirm against the action space).
        trimmed = raw_actions[:, :8]
        return {"actions": np.asarray(trimmed)}
|
|