# UR5 Example

Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.

First, we define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model and vice versa. See the corresponding classes in `src/openpi/policies/libero_policy.py` for comments explaining each line.

```python
# Imports as used in src/openpi/policies/libero_policy.py.
import dataclasses

import numpy as np

from openpi import transforms
from openpi.models import model as _model


@dataclasses.dataclass(frozen=True)
class UR5Inputs(transforms.DataTransformFn):
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        # First, concatenate the joints and gripper into the state vector.
        state = np.concatenate([data["joints"], data["gripper"]])

        # LeRobot stores images as float32 (C,H,W), so during training they may need
        # to be converted to uint8 (H,W,C). The conversion is skipped during policy
        # inference. `_parse_image` is a helper like the one in libero_policy.py.
        base_image = _parse_image(data["base_rgb"])
        wrist_image = _parse_image(data["wrist_rgb"])

        # Create inputs dict.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # Since there is no right wrist, replace with zeros.
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # The right-wrist "slot" is unused, so it is masked out (False)
                # for pi0. For pi0-FAST the mask is kept True.
                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
            },
        }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions"] = data["actions"]

        # Pass the prompt (aka language instruction) to the model.
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class UR5Outputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims.
        return {"actions": np.asarray(data["actions"][:, :7])}
```

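To make the shape conventions concrete, here is a minimal NumPy-only sketch of what the two transforms do to a single sample. It is an illustration under stated assumptions, not openpi code: `parse_image_sketch` is a hypothetical stand-in for `_parse_image` (assumed to rescale floats in [0, 1] to uint8 and move the channel axis last), and the 50-step horizon and 32-dimensional padded action width are made-up values.

```python
import numpy as np


def parse_image_sketch(image) -> np.ndarray:
    """Hypothetical stand-in for `_parse_image`: float (C,H,W) -> uint8 (H,W,C)."""
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.floating):
        # Assumption: float images are scaled to [0, 1].
        image = (255 * image).astype(np.uint8)
    if image.ndim == 3 and image.shape[0] == 3:
        # Move the channel axis last: (C,H,W) -> (H,W,C).
        image = np.transpose(image, (1, 2, 0))
    return image


# A fake training sample using the keys that UR5Inputs reads.
sample = {
    "joints": np.zeros(6, dtype=np.float32),
    "gripper": np.zeros(1, dtype=np.float32),
    "base_rgb": np.ones((3, 224, 224), dtype=np.float32),
}

# Input side: 6 joints + 1 gripper value form a 7-dimensional state.
state = np.concatenate([sample["joints"], sample["gripper"]])
base_image = parse_image_sketch(sample["base_rgb"])
print(state.shape, base_image.shape, base_image.dtype)

# Output side: keep only the first 7 dimensions of a padded action chunk.
padded_actions = np.zeros((50, 32), dtype=np.float32)  # illustrative sizes
ur5_actions = np.asarray(padded_actions[:, :7])
print(ur5_actions.shape)
```

Note that `gripper` must be a 1-D array (shape `(1,)`), not a scalar, for `np.concatenate` to work.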
Next, we define the `LeRobotUR5DataConfig` class, which specifies how to process raw UR5 data from a LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).

```python
# Assumes this class is added to the training config file, where DataConfigFactory,
# DataConfig, ModelTransformFactory, _transforms, and _model are already in scope,
# and that UR5Inputs and UR5Outputs have been imported there.
@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Remap keys from the LeRobot dataset to the keys that UR5Inputs expects.
        # Each entry maps a new key (left) to the dataset key it is read from (right).
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "base_rgb": "image",
                        "wrist_rgb": "wrist_image",
                        "joints": "joints",
                        "gripper": "gripper",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # These transforms are the ones we wrote earlier.
        # UR5Inputs only declares `model_type`, so that is the only argument passed.
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )

        # Convert absolute actions to delta actions.
        # By convention, we do not convert the gripper action (7th dimension).
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        # Model transforms include things like tokenizing the prompt and action targets.
        # You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)

        # We return all data transforms for training and inference. No need to change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
```

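The delta-action step can be illustrated with plain NumPy. This is a sketch of the idea rather than openpi's implementation: `make_bool_mask_sketch`, `to_delta`, and `to_absolute` are hypothetical helpers. The sketch assumes that a positive count marks dimensions to convert, a negative count marks dimensions to leave untouched, and that deltas are taken relative to the current state.

```python
import numpy as np


def make_bool_mask_sketch(*dims: int) -> tuple[bool, ...]:
    """Hypothetical helper: a positive n gives n Trues, a negative n gives |n| Falses."""
    mask: list[bool] = []
    for dim in dims:
        mask.extend([dim > 0] * abs(dim))
    return tuple(mask)


def to_delta(actions: np.ndarray, state: np.ndarray, mask) -> np.ndarray:
    """Subtract the current state from the masked action dimensions."""
    mask = np.asarray(mask)
    n = mask.shape[-1]
    out = np.array(actions, dtype=np.float64)  # copy, so the input is not modified
    out[..., :n] -= np.where(mask, state[..., :n], 0.0)
    return out


def to_absolute(deltas: np.ndarray, state: np.ndarray, mask) -> np.ndarray:
    """Inverse of `to_delta`: add the current state back to the masked dimensions."""
    mask = np.asarray(mask)
    n = mask.shape[-1]
    out = np.array(deltas, dtype=np.float64)
    out[..., :n] += np.where(mask, state[..., :n], 0.0)
    return out


# Six joint dimensions are converted; the gripper (7th dimension) is left absolute.
mask = make_bool_mask_sketch(6, -1)

state = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1.0])
actions = np.array(
    [
        [0.2, 0.2, 0.3, 0.4, 0.5, 0.6, 0.0],
        [0.3, 0.2, 0.3, 0.4, 0.5, 0.6, 0.0],
    ]
)  # shape (horizon, 7)

deltas = to_delta(actions, state, mask)
restored = to_absolute(deltas, state, mask)
print(mask)
print(deltas[:, 0], deltas[:, 6])
```

In this sketch the first joint's deltas are the offsets from the current joint position, while the gripper column passes through unchanged, and applying `to_absolute` to the deltas recovers the original actions.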
Finally, we define a `TrainConfig` for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.

```python
TrainConfig(
    name="pi0_ur5",
    model=pi0.Pi0Config(),
    data=LeRobotUR5DataConfig(
        repo_id="your_username/ur5_dataset",
        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
        # Reloading normalization stats can help transfer pre-trained models to new environments.
        # See docs/norm_stats.md for more details.
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="ur5e",
        ),
        base_config=DataConfig(
            # This flag determines whether we load the prompt (i.e. the task instruction) from the
            # ``task`` field in the LeRobot dataset. The recommended setting is True.
            prompt_from_task=True,
        ),
    ),
    # Load the pi0 base model checkpoint.
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
)
```