File size: 5,846 Bytes
93a69f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
# UR5 Example
Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.
First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment's observations to the model's input format and the model's outputs back to UR5 actions. Check the corresponding file, `src/openpi/policies/libero_policy.py`, for comments explaining each line.
```python
@dataclasses.dataclass(frozen=True)
class UR5Inputs(transforms.DataTransformFn):
    """Converts raw UR5 observations into the input dict the model expects."""

    # Dimension that state/actions are zero-padded to for the model. This
    # field is required: the data config below instantiates
    # `UR5Inputs(action_dim=model_config.action_dim, ...)`, which would raise
    # a TypeError without it.
    action_dim: int

    # Which model variant consumes these inputs; pi0 and pi0-FAST differ in
    # how they mask unused camera slots (see image_mask below).
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        # First, concatenate the joints (6) and gripper (1) into the state
        # vector, then zero-pad it to the model's action dimension.
        state = np.concatenate([data["joints"], data["gripper"]])
        state = transforms.pad_to_dim(state, self.action_dim)

        # Possibly need to parse images to uint8 (H,W,C) since LeRobot
        # automatically stores as float32 (C,H,W); gets skipped for policy
        # inference.
        base_image = _parse_image(data["base_rgb"])
        wrist_image = _parse_image(data["wrist_rgb"])

        # Create inputs dict.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # Since there is no right wrist camera, fill the slot with
                # zeros of the same shape as the base image.
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # The right-wrist slot is padding: pi0 masks it out (False),
                # while pi0-FAST keeps padding images unmasked (True) by
                # convention.
                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
            },
        }

        # Actions are only present during training; pad them to the model's
        # action dimension like the state (UR5Outputs slices back to 7 dims).
        if "actions" in data:
            inputs["actions"] = transforms.pad_to_dim(data["actions"], self.action_dim)

        # Pass the prompt (aka language instruction) to the model.
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
@dataclasses.dataclass(frozen=True)
class UR5Outputs(transforms.DataTransformFn):
    """Maps model outputs back into the UR5 action space."""

    def __call__(self, data: dict) -> dict:
        # The UR5 has 7 action dimensions (6 DoF + 1 gripper), so keep only
        # the leading 7 dims of each predicted action.
        ur5_action_dim = 7
        actions = np.asarray(data["actions"][:, :ur5_action_dim])
        return {"actions": actions}
```
Next, we will define the `UR5DataConfig` class, which defines how to process raw UR5 data from LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
```python
@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(DataConfigFactory):
    """Builds the data-processing pipeline for a UR5 LeRobot dataset."""

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        """Assemble repack, data, and model transforms into a DataConfig.

        The resulting pipeline is used both for training and for policy
        inference.
        """
        # Remap raw LeRobot dataset keys to the names consumed by UR5Inputs,
        # e.g. the dataset's "image" becomes "base_rgb". (The dict appears to
        # be {new_key: old_key} -- confirm against RepackTransform's contract.)
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "base_rgb": "image",
                        "wrist_rgb": "wrist_image",
                        "joints": "joints",
                        "gripper": "gripper",
                        "prompt": "prompt",
                    }
                )
            ]
        )
        # These transforms are the UR5Inputs/UR5Outputs classes defined earlier.
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )
        # Convert absolute actions to delta actions for the first 6 dims only.
        # By convention, we do not convert the gripper action (7th dimension),
        # so the mask is [True]*6 + [False].
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )
        # Model transforms include things like tokenizing the prompt and action
        # targets. You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)
        # We return all data transforms for training and inference. No need to
        # change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
```
Finally, we define the TrainConfig for our UR5 dataset. Here, we define a config for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
```python
TrainConfig(
    # Unique name used to select this config when launching training.
    name="pi0_ur5",
    model=pi0.Pi0Config(),
    data=LeRobotUR5DataConfig(
        # The LeRobot repo id of your dataset; replace with your own.
        repo_id="your_username/ur5_dataset",
        # This config lets us reload the UR5 normalization stats from the base
        # model checkpoint. Reloading normalization stats can help transfer
        # pre-trained models to new environments.
        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="ur5e",
        ),
        base_config=DataConfig(
            # This flag determines whether we load the prompt (i.e. the task
            # instruction) from the ``task`` field in the LeRobot dataset.
            # The recommended setting is True.
            prompt_from_task=True,
        ),
    ),
    # Load the pi0 base model checkpoint.
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
)
```
|