# UR5 Example
Below, we outline how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for fine-tuning on a UR5 dataset.
First, we define the `UR5Inputs` and `UR5Outputs` classes, which map data from the UR5 environment to the model inputs and map the model outputs back to the UR5 environment. See the corresponding classes in `src/openpi/policies/libero_policy.py` for comments explaining each line.
```python
import dataclasses

import numpy as np

from openpi import transforms
from openpi.models import model as _model


def _parse_image(image) -> np.ndarray:
    # Mirrors the `_parse_image` helper in src/openpi/policies/libero_policy.py:
    # float images in [0, 1] become uint8, and (C, H, W) becomes (H, W, C).
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.floating):
        image = (255 * image).astype(np.uint8)
    if image.shape[0] == 3:
        image = image.transpose(1, 2, 0)
    return image


@dataclasses.dataclass(frozen=True)
class UR5Inputs(transforms.DataTransformFn):
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        # First, concatenate the joints and gripper into the state vector.
        state = np.concatenate([data["joints"], data["gripper"]])

        # LeRobot stores images as float32 (C, H, W), so convert them to uint8 (H, W, C).
        # At policy inference time the images are typically already uint8 (H, W, C),
        # and the conversion is a no-op.
        base_image = _parse_image(data["base_rgb"])
        wrist_image = _parse_image(data["wrist_rgb"])

        # Create inputs dict.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # There is no right wrist camera, so fill the slot with a black image.
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # The right-wrist slot is unused, so pi0 masks it out (False).
                # pi0-FAST does not mask padding images, so the mask stays True there.
                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
            },
        }

        # Actions are only present during training.
        if "actions" in data:
            inputs["actions"] = data["actions"]

        # Pass the prompt (aka language instruction) to the model.
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class UR5Outputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # The robot has 7 action dimensions (6 DoF + gripper), so return only the first 7 dims.
        return {"actions": np.asarray(data["actions"][:, :7])}
```
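To make the shapes involved concrete, here is a standalone NumPy walk-through of what these two transforms do to a single sample. It does not import openpi; the sample contents, the 224 x 224 image size, and the (50, 32) model action chunk are illustrative assumptions, not values taken from the openpi code.

```python
import numpy as np

# Hypothetical UR5 sample (assumption): 6 joint angles, 1 gripper value, and a
# float32 (C, H, W) image in [0, 1], as LeRobot would store it.
rng = np.random.default_rng(0)
sample = {
    "joints": np.zeros(6, dtype=np.float32),
    "gripper": np.ones(1, dtype=np.float32),
    "base_rgb": rng.random((3, 224, 224), dtype=np.float32),
}

# Input side: joints + gripper become one 7-dim state vector.
state = np.concatenate([sample["joints"], sample["gripper"]])

# Input side: float (C, H, W) in [0, 1] -> uint8 (H, W, C).
image = (255 * sample["base_rgb"]).astype(np.uint8).transpose(1, 2, 0)

# Input side: the unused right-wrist slot is filled with a black image.
right_wrist = np.zeros_like(image)

# Output side: the model returns an action chunk padded to its own action
# dimension (50 x 32 is an assumed shape); UR5Outputs keeps the first 7 columns.
model_actions = np.zeros((50, 32), dtype=np.float32)
ur5_actions = model_actions[:, :7]

print(state.shape, image.shape, image.dtype, ur5_actions.shape)
# prints: (7,) (224, 224, 3) uint8 (50, 7)
```

The slicing in `UR5Outputs` is what discards the padded action dimensions the model produces beyond the robot's 7.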
Next, we define the `LeRobotUR5DataConfig` class, which specifies how to process raw UR5 data from a LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
```python
import dataclasses
import pathlib

from typing_extensions import override

import openpi.models.model as _model
import openpi.transforms as _transforms

# Assumes this class lives in src/openpi/training/config.py, which defines
# DataConfigFactory, DataConfig, and ModelTransformFactory, and that the
# UR5Inputs/UR5Outputs classes above are importable there.


@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Remap keys from the LeRobot dataset to the names UR5Inputs expects.
        # Each entry maps new key -> key in the dataset, e.g. the dataset's "image"
        # becomes "base_rgb". Adjust the right-hand side to match your dataset.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "base_rgb": "image",
                        "wrist_rgb": "wrist_image",
                        "joints": "joints",
                        "gripper": "gripper",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # These are the transforms we wrote earlier.
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )

        # Convert absolute actions to delta actions (relative to the current state).
        # By convention, the gripper action (7th dimension) stays absolute:
        # make_bool_mask(6, -1) gives six True entries followed by one False.
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        # Model transforms include things like tokenizing the prompt and action targets.
        # You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)

        # Return all data transforms for training and inference. No need to change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
```
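The repacking and delta-action steps above can be illustrated with a small NumPy sketch. The functions `repack`, `make_bool_mask`, `to_delta`, and `to_absolute` here are hypothetical stand-ins written for illustration, based on our reading of what `RepackTransform`, `make_bool_mask`, `DeltaActions`, and `AbsoluteActions` do; they are not openpi's implementations (for instance, `repack` only handles flat dicts).

```python
import numpy as np


def repack(sample: dict, structure: dict) -> dict:
    # Build a new dict whose keys are the left-hand names and whose values are
    # read from the sample under the right-hand names (flat dicts only).
    return {new_key: sample[old_key] for new_key, old_key in structure.items()}


def make_bool_mask(*dims: int) -> tuple:
    # n > 0 adds n True entries, n < 0 adds |n| False entries.
    mask = []
    for n in dims:
        mask.extend([n > 0] * abs(n))
    return tuple(mask)


def to_delta(actions: np.ndarray, state: np.ndarray, mask) -> np.ndarray:
    # Subtract the current state from every action in the chunk on masked
    # dimensions only; unmasked dimensions (the gripper) stay absolute.
    m = np.asarray(mask)
    n = m.shape[-1]
    out = actions.copy()
    out[..., :n] -= np.expand_dims(np.where(m, state[..., :n], 0), axis=-2)
    return out


def to_absolute(actions: np.ndarray, state: np.ndarray, mask) -> np.ndarray:
    # Inverse of to_delta: add the state back on the same dimensions.
    m = np.asarray(mask)
    n = m.shape[-1]
    out = actions.copy()
    out[..., :n] += np.expand_dims(np.where(m, state[..., :n], 0), axis=-2)
    return out


# Hypothetical raw sample using the dataset-side key names.
raw = {"image": "img", "wrist_image": "wrist", "joints": [0.0] * 6}
repacked = repack(raw, {"base_rgb": "image", "wrist_rgb": "wrist_image", "joints": "joints"})

# Hypothetical state (6 joints + gripper) and a 2-step action chunk.
state = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1.0])
actions = np.array([
    [0.2, 0.2, 0.3, 0.4, 0.5, 0.7, 0.0],
    [0.3, 0.2, 0.3, 0.4, 0.5, 0.8, 1.0],
])
mask = make_bool_mask(6, -1)  # six True, then one False for the gripper
delta = to_delta(actions, state, mask)
restored = to_absolute(delta, state, mask)
```

With this mask, joint targets become offsets from the current joint positions during training, while the gripper command keeps its absolute meaning; the output-side inverse restores absolute joint targets before they are sent to the robot.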
Finally, we define a `TrainConfig` for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
```python
TrainConfig(
    name="pi0_ur5",
    model=pi0.Pi0Config(),
    data=LeRobotUR5DataConfig(
        repo_id="your_username/ur5_dataset",
        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
        # Reloading normalization stats can help transfer pre-trained models to new environments.
        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="ur5e",
        ),
        base_config=DataConfig(
            # This flag determines whether we load the prompt (i.e. the task instruction) from the
            # ``task`` field in the LeRobot dataset. The recommended setting is True.
            prompt_from_task=True,
        ),
    ),
    # Load the pi0 base model checkpoint.
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
)
```
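As a rough illustration of what reusing normalization stats means, here is a sketch assuming per-dimension z-score normalization with a mean, a standard deviation, and a small epsilon. The exact scheme openpi applies may differ (check `src/openpi/transforms.py` and the norm stats docs), and the 7-dim statistics below are made-up values, not the contents of the `ur5e` assets.

```python
import numpy as np

EPS = 1e-6  # assumed small constant to avoid division by zero


def normalize(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    # Z-score normalization per dimension (assumed scheme).
    return (x - mean) / (std + EPS)


def unnormalize(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    # Inverse of normalize: map model-space values back to robot units.
    return x * (std + EPS) + mean


# Hypothetical 7-dim stats (6 joints + gripper); real values come from the
# norm stats stored under the configured assets directory.
mean = np.array([0.0, -1.2, 1.5, -1.8, -1.6, 0.0, 0.5])
std = np.array([0.5, 0.3, 0.4, 0.3, 0.2, 0.6, 0.5])

state = np.array([0.1, -1.0, 1.4, -1.9, -1.6, 0.3, 1.0])
z = normalize(state, mean, std)
recovered = unnormalize(z, mean, std)
```

Under this assumption, loading the base checkpoint's `ur5e` stats means UR5 states and actions are scaled the same way the base model saw UR5 data during pre-training, rather than with statistics recomputed from a possibly small fine-tuning dataset.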