Upload config.py with huggingface_hub
Browse files
config.py
ADDED
|
@@ -0,0 +1,687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""See _CONFIGS for the list of available configs."""
|
| 2 |
+
|
| 3 |
+
import abc
|
| 4 |
+
from collections.abc import Sequence
|
| 5 |
+
import dataclasses
|
| 6 |
+
import difflib
|
| 7 |
+
import logging
|
| 8 |
+
import pathlib
|
| 9 |
+
from typing import Any, Literal, Protocol, TypeAlias
|
| 10 |
+
|
| 11 |
+
import etils.epath as epath
|
| 12 |
+
import flax.nnx as nnx
|
| 13 |
+
from typing_extensions import override
|
| 14 |
+
import tyro
|
| 15 |
+
|
| 16 |
+
import openpi.models.model as _model
|
| 17 |
+
import openpi.models.pi0_config as pi0_config
|
| 18 |
+
import openpi.models.pi0moh_config as pi0gate_config
|
| 19 |
+
import openpi.models.tokenizer as _tokenizer
|
| 20 |
+
import openpi.policies.aloha_policy as aloha_policy
|
| 21 |
+
import openpi.policies.droid_policy as droid_policy
|
| 22 |
+
import openpi.policies.libero_policy as libero_policy
|
| 23 |
+
import openpi.shared.download as _download
|
| 24 |
+
import openpi.shared.normalize as _normalize
|
| 25 |
+
import openpi.training.droid_rlds_dataset as droid_rlds_dataset
|
| 26 |
+
import openpi.training.optimizer as _optimizer
|
| 27 |
+
import openpi.training.weight_loaders as weight_loaders
|
| 28 |
+
import openpi.transforms as _transforms
|
| 29 |
+
|
| 30 |
+
ModelType: TypeAlias = _model.ModelType
|
| 31 |
+
# Work around a tyro issue with using nnx.filterlib.Filter directly.
|
| 32 |
+
Filter: TypeAlias = nnx.filterlib.Filter
|
| 33 |
+
import numpy as np
|
| 34 |
+
from openpi.transforms import DataTransformFn
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclasses.dataclass(frozen=True)
|
| 38 |
+
class AssetsConfig:
|
| 39 |
+
"""Determines the location of assets (e.g., norm stats) that will be used to set up the data pipeline.
|
| 40 |
+
|
| 41 |
+
These assets will be replicated inside the checkpoint under the `assets/asset_id` directory.
|
| 42 |
+
|
| 43 |
+
This can be used to load assets from a different checkpoint (e.g., base model checkpoint) or some other
|
| 44 |
+
centralized location. For example, to load the norm stats for the Trossen robot from the base model checkpoint
|
| 45 |
+
during fine-tuning, use:
|
| 46 |
+
|
| 47 |
+
```
|
| 48 |
+
AssetsConfig(
|
| 49 |
+
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
|
| 50 |
+
asset_id="trossen",
|
| 51 |
+
)
|
| 52 |
+
```
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
# Assets directory. If not provided, the config assets_dirs will be used. This is useful to load assets from
|
| 56 |
+
# a different checkpoint (e.g., base model checkpoint) or some other centralized location.
|
| 57 |
+
assets_dir: str | None = None
|
| 58 |
+
|
| 59 |
+
# Asset id. If not provided, the repo id will be used. This allows users to reference assets that describe
|
| 60 |
+
# different robot platforms.
|
| 61 |
+
asset_id: str | None = None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclasses.dataclass(frozen=True)
|
| 65 |
+
class DataConfig:
|
| 66 |
+
# LeRobot repo id. If None, fake data will be created.
|
| 67 |
+
repo_id: str | None = None
|
| 68 |
+
# Directory within the assets directory containing the data assets.
|
| 69 |
+
asset_id: str | None = None
|
| 70 |
+
# Contains precomputed normalization stats. If None, normalization will not be performed.
|
| 71 |
+
norm_stats: dict[str, _transforms.NormStats] | None = None
|
| 72 |
+
|
| 73 |
+
# Used to adopt the inputs from a dataset specific format to a common format
|
| 74 |
+
# which is expected by the data transforms.
|
| 75 |
+
repack_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
|
| 76 |
+
# Data transforms, typically include robot specific transformations. Will be applied
|
| 77 |
+
# before the data is normalized. See `model.Observation` and `model.Actions` to learn about the
|
| 78 |
+
# normalized data.
|
| 79 |
+
data_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
|
| 80 |
+
# Model specific transforms. Will be applied after the data is normalized.
|
| 81 |
+
model_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
|
| 82 |
+
# If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.
|
| 83 |
+
use_quantile_norm: bool = False
|
| 84 |
+
|
| 85 |
+
# Names of keys that will be used by the data loader to generate the action sequence. The length of the
|
| 86 |
+
# sequence is defined by the `action_horizon` field in the model config. This should be adjusted if your
|
| 87 |
+
# LeRobot dataset is using different keys to represent the action.
|
| 88 |
+
action_sequence_keys: Sequence[str] = ("actions",)
|
| 89 |
+
|
| 90 |
+
# If true, will use the LeRobot dataset task to define the prompt.
|
| 91 |
+
prompt_from_task: bool = False
|
| 92 |
+
|
| 93 |
+
# Only used for RLDS data loader (ie currently only used for DROID).
|
| 94 |
+
rlds_data_dir: str | None = None
|
| 95 |
+
# Action space for DROID dataset.
|
| 96 |
+
action_space: droid_rlds_dataset.DroidActionSpace | None = None
|
| 97 |
+
# Path to the data filter file for DROID dataset
|
| 98 |
+
filter_dict_path: str | None = None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class GroupFactory(Protocol):
|
| 102 |
+
def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
|
| 103 |
+
"""Create a group."""
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@dataclasses.dataclass(frozen=True)
|
| 107 |
+
class ModelTransformFactory(GroupFactory):
|
| 108 |
+
"""Creates model transforms for standard pi0 models."""
|
| 109 |
+
|
| 110 |
+
# If provided, will determine the default prompt that be used by the model.
|
| 111 |
+
default_prompt: str | None = None
|
| 112 |
+
|
| 113 |
+
def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
|
| 114 |
+
match model_config.model_type:
|
| 115 |
+
case _model.ModelType.PI0:
|
| 116 |
+
return _transforms.Group(
|
| 117 |
+
inputs=[
|
| 118 |
+
_transforms.InjectDefaultPrompt(self.default_prompt),
|
| 119 |
+
_transforms.ResizeImages(224, 224),
|
| 120 |
+
_transforms.TokenizePrompt(
|
| 121 |
+
_tokenizer.PaligemmaTokenizer(model_config.max_token_len),
|
| 122 |
+
),
|
| 123 |
+
_transforms.PadStatesAndActions(model_config.action_dim),
|
| 124 |
+
],
|
| 125 |
+
)
|
| 126 |
+
case _model.ModelType.PI05:
|
| 127 |
+
assert isinstance(model_config, pi0_config.Pi0Config) or isinstance(model_config, pi0gate_config.Pi0GatedConfig)
|
| 128 |
+
return _transforms.Group(
|
| 129 |
+
inputs=[
|
| 130 |
+
_transforms.InjectDefaultPrompt(self.default_prompt),
|
| 131 |
+
_transforms.ResizeImages(224, 224),
|
| 132 |
+
_transforms.TokenizePrompt(
|
| 133 |
+
_tokenizer.PaligemmaTokenizer(model_config.max_token_len),
|
| 134 |
+
discrete_state_input=model_config.discrete_state_input,
|
| 135 |
+
),
|
| 136 |
+
_transforms.PadStatesAndActions(model_config.action_dim),
|
| 137 |
+
],
|
| 138 |
+
)
|
| 139 |
+
case _model.ModelType.PI0_FAST:
|
| 140 |
+
tokenizer_cls = (
|
| 141 |
+
_tokenizer.FASTTokenizer
|
| 142 |
+
if model_config.fast_model_tokenizer is None
|
| 143 |
+
else model_config.fast_model_tokenizer
|
| 144 |
+
)
|
| 145 |
+
tokenizer_kwargs = (
|
| 146 |
+
{} if model_config.fast_model_tokenizer_kwargs is None else model_config.fast_model_tokenizer_kwargs
|
| 147 |
+
)
|
| 148 |
+
return _transforms.Group(
|
| 149 |
+
inputs=[
|
| 150 |
+
_transforms.InjectDefaultPrompt(self.default_prompt),
|
| 151 |
+
_transforms.ResizeImages(224, 224),
|
| 152 |
+
_transforms.TokenizeFASTInputs(
|
| 153 |
+
tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
|
| 154 |
+
),
|
| 155 |
+
],
|
| 156 |
+
outputs=[
|
| 157 |
+
_transforms.ExtractFASTActions(
|
| 158 |
+
tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
|
| 159 |
+
action_horizon=model_config.action_horizon,
|
| 160 |
+
action_dim=model_config.action_dim,
|
| 161 |
+
)
|
| 162 |
+
],
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@dataclasses.dataclass(frozen=True)
|
| 167 |
+
class DataConfigFactory(abc.ABC):
|
| 168 |
+
# The LeRobot repo id.
|
| 169 |
+
repo_id: str = tyro.MISSING
|
| 170 |
+
# Determines how the assets will be loaded.
|
| 171 |
+
assets: AssetsConfig = dataclasses.field(default_factory=AssetsConfig)
|
| 172 |
+
# Base config that will be updated by the factory.
|
| 173 |
+
base_config: tyro.conf.Suppress[DataConfig | None] = None
|
| 174 |
+
|
| 175 |
+
@abc.abstractmethod
|
| 176 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 177 |
+
"""Create a data config."""
|
| 178 |
+
|
| 179 |
+
def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 180 |
+
repo_id = self.repo_id if self.repo_id is not tyro.MISSING else None
|
| 181 |
+
asset_id = self.assets.asset_id or repo_id
|
| 182 |
+
return dataclasses.replace(
|
| 183 |
+
self.base_config or DataConfig(),
|
| 184 |
+
repo_id=repo_id,
|
| 185 |
+
asset_id=asset_id,
|
| 186 |
+
norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id),
|
| 187 |
+
use_quantile_norm=model_config.model_type != ModelType.PI0,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None:
|
| 191 |
+
if asset_id is None:
|
| 192 |
+
return None
|
| 193 |
+
try:
|
| 194 |
+
data_assets_dir = str(assets_dir / asset_id)
|
| 195 |
+
norm_stats = _normalize.load(_download.maybe_download(data_assets_dir))
|
| 196 |
+
logging.info(f"Loaded norm stats from {data_assets_dir}")
|
| 197 |
+
return norm_stats
|
| 198 |
+
except FileNotFoundError:
|
| 199 |
+
logging.info(f"Norm stats not found in {data_assets_dir}, skipping.")
|
| 200 |
+
return None
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@dataclasses.dataclass(frozen=True)
|
| 204 |
+
class FakeDataConfig(DataConfigFactory):
|
| 205 |
+
repo_id: str = "fake"
|
| 206 |
+
|
| 207 |
+
@override
|
| 208 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 209 |
+
return DataConfig(repo_id=self.repo_id)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@dataclasses.dataclass(frozen=True)
|
| 213 |
+
class SimpleDataConfig(DataConfigFactory):
|
| 214 |
+
# Factory for the data transforms.
|
| 215 |
+
data_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=GroupFactory)
|
| 216 |
+
# Factory for the model transforms.
|
| 217 |
+
model_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=ModelTransformFactory)
|
| 218 |
+
|
| 219 |
+
@override
|
| 220 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 221 |
+
return dataclasses.replace(
|
| 222 |
+
self.create_base_config(assets_dirs, model_config),
|
| 223 |
+
data_transforms=self.data_transforms(model_config),
|
| 224 |
+
model_transforms=self.model_transforms(model_config),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@dataclasses.dataclass(frozen=True)
|
| 229 |
+
class LeRobotAlohaDataConfig(DataConfigFactory):
|
| 230 |
+
# If true, will convert joint dimensions to deltas with respect to the current state before passing to the model.
|
| 231 |
+
# Gripper dimensions will remain in absolute values.
|
| 232 |
+
use_delta_joint_actions: bool = True
|
| 233 |
+
# If provided, will be injected into the input data if the "prompt" key is not present.
|
| 234 |
+
default_prompt: str | None = None
|
| 235 |
+
# If true, this will convert the joint and gripper values from the standard Aloha space to
|
| 236 |
+
# the space used by the pi internal runtime which was used to train the base model. People who
|
| 237 |
+
# use standard Aloha data should set this to true.
|
| 238 |
+
adapt_to_pi: bool = True
|
| 239 |
+
|
| 240 |
+
# Repack transforms.
|
| 241 |
+
repack_transforms: tyro.conf.Suppress[_transforms.Group] = dataclasses.field(
|
| 242 |
+
default=_transforms.Group(
|
| 243 |
+
inputs=[
|
| 244 |
+
_transforms.RepackTransform(
|
| 245 |
+
{
|
| 246 |
+
"images": {"cam_high": "observation.images.top"},
|
| 247 |
+
"state": "observation.state",
|
| 248 |
+
"actions": "action",
|
| 249 |
+
}
|
| 250 |
+
)
|
| 251 |
+
]
|
| 252 |
+
)
|
| 253 |
+
)
|
| 254 |
+
# Action keys that will be used to read the action sequence from the dataset.
|
| 255 |
+
action_sequence_keys: Sequence[str] = ("action",)
|
| 256 |
+
|
| 257 |
+
@override
|
| 258 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 259 |
+
data_transforms = _transforms.Group(
|
| 260 |
+
inputs=[aloha_policy.AlohaInputs(adapt_to_pi=self.adapt_to_pi)],
|
| 261 |
+
outputs=[aloha_policy.AlohaOutputs(adapt_to_pi=self.adapt_to_pi)],
|
| 262 |
+
)
|
| 263 |
+
if self.use_delta_joint_actions:
|
| 264 |
+
delta_action_mask = _transforms.make_bool_mask(6, -1, 6, -1)
|
| 265 |
+
data_transforms = data_transforms.push(
|
| 266 |
+
inputs=[_transforms.DeltaActions(delta_action_mask)],
|
| 267 |
+
outputs=[_transforms.AbsoluteActions(delta_action_mask)],
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
model_transforms = ModelTransformFactory(default_prompt=self.default_prompt)(model_config)
|
| 271 |
+
|
| 272 |
+
return dataclasses.replace(
|
| 273 |
+
self.create_base_config(assets_dirs, model_config),
|
| 274 |
+
repack_transforms=self.repack_transforms,
|
| 275 |
+
data_transforms=data_transforms,
|
| 276 |
+
model_transforms=model_transforms,
|
| 277 |
+
action_sequence_keys=self.action_sequence_keys,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
@dataclasses.dataclass(frozen=True)
|
| 282 |
+
class LeRobotLiberoDataConfig(DataConfigFactory):
|
| 283 |
+
"""
|
| 284 |
+
This config is used to configure transforms that are applied at various parts of the data pipeline.
|
| 285 |
+
For your own dataset, you can copy this class and modify the transforms to match your dataset based on the
|
| 286 |
+
comments below.
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
extra_delta_transform: bool = False
|
| 290 |
+
|
| 291 |
+
@override
|
| 292 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 293 |
+
# The repack transform is *only* applied to the data coming from the dataset,
|
| 294 |
+
# and *not* during inference. We can use it to make inputs from the dataset look
|
| 295 |
+
# as close as possible to those coming from the inference environment (e.g. match the keys).
|
| 296 |
+
# Below, we match the keys in the dataset (which we defined in the data conversion script) to
|
| 297 |
+
# the keys we use in our inference pipeline (defined in the inference script for libero_scripts).
|
| 298 |
+
# For your own dataset, first figure out what keys your environment passes to the policy server
|
| 299 |
+
# and then modify the mappings below so your dataset's keys get matched to those target keys.
|
| 300 |
+
# The repack transform simply remaps key names here.
|
| 301 |
+
repack_transform = _transforms.Group(
|
| 302 |
+
inputs=[
|
| 303 |
+
_transforms.RepackTransform(
|
| 304 |
+
{
|
| 305 |
+
"observation/image": "image",
|
| 306 |
+
"observation/wrist_image": "wrist_image",
|
| 307 |
+
"observation/state": "state",
|
| 308 |
+
"actions": "actions",
|
| 309 |
+
"prompt": "prompt",
|
| 310 |
+
}
|
| 311 |
+
)
|
| 312 |
+
]
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
# The data transforms are applied to the data coming from the dataset *and* during inference.
|
| 316 |
+
# Below, we define the transforms for data going into the model (``inputs``) and the transforms
|
| 317 |
+
# for data coming out of the model (``outputs``) (the latter is only used during inference).
|
| 318 |
+
# We defined these transforms in `libero_policy.py`. You can check the detailed comments there for
|
| 319 |
+
# how to modify the transforms to match your dataset. Once you created your own transforms, you can
|
| 320 |
+
# replace the transforms below with your own.
|
| 321 |
+
data_transforms = _transforms.Group(
|
| 322 |
+
inputs=[libero_policy.LiberoInputs(model_type=model_config.model_type)],
|
| 323 |
+
outputs=[libero_policy.LiberoOutputs()],
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
# One additional data transform: pi0 models are trained on delta actions (relative to the first
|
| 327 |
+
# state in each action chunk). IF your data has ``absolute`` actions (e.g. target joint angles)
|
| 328 |
+
# you can uncomment the following line to convert the actions to delta actions. The only exception
|
| 329 |
+
# is for the gripper actions which are always absolute.
|
| 330 |
+
# In the example below, we would apply the delta conversion to the first 6 actions (joints) and
|
| 331 |
+
# leave the 7th action (gripper) unchanged, i.e. absolute.
|
| 332 |
+
# In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to
|
| 333 |
+
# apply a separate delta conversion (that's why it's commented out). Choose whether to apply this
|
| 334 |
+
# transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box.
|
| 335 |
+
|
| 336 |
+
# LIBERO already represents actions as deltas, but we have some old Pi0 checkpoints that are trained with this
|
| 337 |
+
# extra delta transform.
|
| 338 |
+
if self.extra_delta_transform:
|
| 339 |
+
delta_action_mask = _transforms.make_bool_mask(6, -1)
|
| 340 |
+
data_transforms = data_transforms.push(
|
| 341 |
+
inputs=[_transforms.DeltaActions(delta_action_mask)],
|
| 342 |
+
outputs=[_transforms.AbsoluteActions(delta_action_mask)],
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# Model transforms include things like tokenizing the prompt and action targets
|
| 346 |
+
# You do not need to change anything here for your own dataset.
|
| 347 |
+
model_transforms = ModelTransformFactory()(model_config)
|
| 348 |
+
|
| 349 |
+
# We return all data transforms for training and inference. No need to change anything here.
|
| 350 |
+
return dataclasses.replace(
|
| 351 |
+
self.create_base_config(assets_dirs, model_config),
|
| 352 |
+
repack_transforms=repack_transform,
|
| 353 |
+
data_transforms=data_transforms,
|
| 354 |
+
model_transforms=model_transforms,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
@dataclasses.dataclass(frozen=True)
|
| 359 |
+
class RLDSDroidDataConfig(DataConfigFactory):
|
| 360 |
+
"""
|
| 361 |
+
Config for training on DROID, using RLDS data format (for efficient training on larger datasets).
|
| 362 |
+
"""
|
| 363 |
+
|
| 364 |
+
rlds_data_dir: str | None = None
|
| 365 |
+
action_space: droid_rlds_dataset.DroidActionSpace | None = None
|
| 366 |
+
|
| 367 |
+
# Filtering options. Can pass a path to a dictionary that maps episodes to timestep ranges
|
| 368 |
+
# to tuples denoting ranges of time steps to keep (start, end). Episodes are uniquely identified with
|
| 369 |
+
# f"{recording_folderpath}--{file_path}", both of which are present in the RLDS episode metadata.
|
| 370 |
+
# Path to the filter dictionary file.
|
| 371 |
+
filter_dict_path: str | None = "gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json"
|
| 372 |
+
|
| 373 |
+
@override
|
| 374 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 375 |
+
repack_transform = _transforms.Group(
|
| 376 |
+
inputs=[
|
| 377 |
+
_transforms.RepackTransform(
|
| 378 |
+
{
|
| 379 |
+
"observation/exterior_image_1_left": "observation/image",
|
| 380 |
+
"observation/wrist_image_left": "observation/wrist_image",
|
| 381 |
+
"observation/joint_position": "observation/joint_position",
|
| 382 |
+
"observation/gripper_position": "observation/gripper_position",
|
| 383 |
+
"actions": "actions",
|
| 384 |
+
"prompt": "prompt",
|
| 385 |
+
}
|
| 386 |
+
)
|
| 387 |
+
]
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
data_transforms = _transforms.Group(
|
| 391 |
+
inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],
|
| 392 |
+
outputs=[droid_policy.DroidOutputs()],
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
if self.action_space == droid_rlds_dataset.DroidActionSpace.JOINT_POSITION:
|
| 396 |
+
# Data loader returns absolute joint position actions -- convert to delta actions for training.
|
| 397 |
+
delta_action_mask = _transforms.make_bool_mask(7, -1)
|
| 398 |
+
data_transforms = data_transforms.push(
|
| 399 |
+
inputs=[_transforms.DeltaActions(delta_action_mask)],
|
| 400 |
+
outputs=[_transforms.AbsoluteActions(delta_action_mask)],
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
model_transforms = ModelTransformFactory()(model_config)
|
| 404 |
+
|
| 405 |
+
assert self.rlds_data_dir is not None, "Need to set rlds data dir for RLDS data loader."
|
| 406 |
+
|
| 407 |
+
return dataclasses.replace(
|
| 408 |
+
self.create_base_config(assets_dirs, model_config),
|
| 409 |
+
repack_transforms=repack_transform,
|
| 410 |
+
data_transforms=data_transforms,
|
| 411 |
+
model_transforms=model_transforms,
|
| 412 |
+
rlds_data_dir=self.rlds_data_dir,
|
| 413 |
+
action_space=self.action_space,
|
| 414 |
+
filter_dict_path=self.filter_dict_path,
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
@dataclasses.dataclass(frozen=True)
|
| 419 |
+
class LeRobotDROIDDataConfig(DataConfigFactory):
|
| 420 |
+
"""
|
| 421 |
+
Example data config for custom DROID dataset in LeRobot format.
|
| 422 |
+
To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py
|
| 423 |
+
"""
|
| 424 |
+
|
| 425 |
+
@override
|
| 426 |
+
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
| 427 |
+
repack_transform = _transforms.Group(
|
| 428 |
+
inputs=[
|
| 429 |
+
_transforms.RepackTransform(
|
| 430 |
+
{
|
| 431 |
+
"observation/exterior_image_1_left": "exterior_image_1_left",
|
| 432 |
+
# "observation/exterior_image_2_left": "exterior_image_2_left",
|
| 433 |
+
"observation/wrist_image_left": "wrist_image_left",
|
| 434 |
+
"observation/joint_position": "joint_position",
|
| 435 |
+
"observation/gripper_position": "gripper_position",
|
| 436 |
+
"actions": "actions",
|
| 437 |
+
"prompt": "prompt",
|
| 438 |
+
}
|
| 439 |
+
)
|
| 440 |
+
]
|
| 441 |
+
)
|
| 442 |
+
# We assume joint *velocity* actions, so we should *not* apply an additional delta transform.
|
| 443 |
+
data_transforms = _transforms.Group(
|
| 444 |
+
inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],
|
| 445 |
+
outputs=[droid_policy.DroidOutputs()],
|
| 446 |
+
)
|
| 447 |
+
model_transforms = ModelTransformFactory()(model_config)
|
| 448 |
+
|
| 449 |
+
return dataclasses.replace(
|
| 450 |
+
self.create_base_config(assets_dirs, model_config),
|
| 451 |
+
repack_transforms=repack_transform,
|
| 452 |
+
data_transforms=data_transforms,
|
| 453 |
+
model_transforms=model_transforms,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
@dataclasses.dataclass(frozen=False)
|
| 458 |
+
class TrainConfig:
|
| 459 |
+
# Name of the config. Must be unique. Will be used to reference this config.
|
| 460 |
+
name: tyro.conf.Suppress[str]
|
| 461 |
+
# Project name.
|
| 462 |
+
project_name: str = "openpi"
|
| 463 |
+
# Experiment name. Will be used to name the metadata and checkpoint directories.
|
| 464 |
+
exp_name: str = tyro.MISSING
|
| 465 |
+
|
| 466 |
+
# Defines the model config. Some attributes (action_dim, action_horizon, and max_token_len) are shared by all models
|
| 467 |
+
# -- see BaseModelConfig. Specific model implementations (e.g., Pi0Config) inherit from BaseModelConfig and may
|
| 468 |
+
# define additional attributes.
|
| 469 |
+
model: _model.BaseModelConfig = dataclasses.field(default_factory=pi0_config.Pi0Config)
|
| 470 |
+
|
| 471 |
+
# A weight loader can optionally load (possibly partial) weights from disk after the model is initialized.
|
| 472 |
+
weight_loader: weight_loaders.WeightLoader = dataclasses.field(default_factory=weight_loaders.NoOpWeightLoader)
|
| 473 |
+
|
| 474 |
+
# Optional path to a PyTorch checkpoint to load weights from.
|
| 475 |
+
pytorch_weight_path: str | None = None
|
| 476 |
+
|
| 477 |
+
# Precision for PyTorch training.
|
| 478 |
+
pytorch_training_precision: Literal["bfloat16", "float32"] = "bfloat16"
|
| 479 |
+
|
| 480 |
+
lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule)
|
| 481 |
+
optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW)
|
| 482 |
+
ema_decay: float | None = 0.99
|
| 483 |
+
|
| 484 |
+
# Specifies which weights should be frozen.
|
| 485 |
+
freeze_filter: tyro.conf.Suppress[Filter] = dataclasses.field(default_factory=nnx.Nothing)
|
| 486 |
+
|
| 487 |
+
# Determines the data to be trained on.
|
| 488 |
+
data: DataConfigFactory = dataclasses.field(default_factory=FakeDataConfig)
|
| 489 |
+
|
| 490 |
+
# Base directory for config assets (e.g., norm stats).
|
| 491 |
+
assets_base_dir: str = "./assets"
|
| 492 |
+
# Base directory for checkpoints.
|
| 493 |
+
checkpoint_base_dir: str = "./checkpoints"
|
| 494 |
+
|
| 495 |
+
# Random seed that will be used by random generators during training.
|
| 496 |
+
seed: int = 42
|
| 497 |
+
# Global batch size.
|
| 498 |
+
batch_size: int = 32
|
| 499 |
+
# Number of workers to use for the data loader. Increasing this number will speed up data loading but
|
| 500 |
+
# will increase memory and CPU usage.
|
| 501 |
+
num_workers: int = 16
|
| 502 |
+
# Number of train steps (batches) to run.
|
| 503 |
+
num_train_steps: int = 30_000
|
| 504 |
+
learning_rate: float = 5e-5
|
| 505 |
+
|
| 506 |
+
# How often (in steps) to log training metrics.
|
| 507 |
+
log_interval: int = 100
|
| 508 |
+
# How often (in steps) to save checkpoints.
|
| 509 |
+
save_interval: int = 5000
|
| 510 |
+
# If set, any existing checkpoints matching step % keep_period == 0 will not be deleted.
|
| 511 |
+
keep_period: int | None = 5000
|
| 512 |
+
|
| 513 |
+
# If true, will overwrite the checkpoint directory if it already exists.
|
| 514 |
+
overwrite: bool = True
|
| 515 |
+
# If true, will resume training from the last checkpoint.
|
| 516 |
+
resume: bool = False
|
| 517 |
+
|
| 518 |
+
# If true, will enable wandb logging.
|
| 519 |
+
wandb_enabled: bool = True
|
| 520 |
+
|
| 521 |
+
# Used to pass metadata to the policy server.
|
| 522 |
+
policy_metadata: dict[str, Any] | None = None
|
| 523 |
+
|
| 524 |
+
# If the value is greater than 1, FSDP will be enabled and shard across number of specified devices; overall
|
| 525 |
+
# device memory will be reduced but training could potentially be slower.
|
| 526 |
+
# eg. if total device is 4 and fsdp devices is 2; then the model will shard to 2 devices and run
|
| 527 |
+
# data parallel between 2 groups of devices.
|
| 528 |
+
fsdp_devices: int = 1
|
| 529 |
+
|
| 530 |
+
training_mode: str = "warmup" # warmup: train ca&proj; finetune: freeze vlm; full_finetune
|
| 531 |
+
horizons: list[int] = dataclasses.field(default_factory=lambda: [10, 20, 30])
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
@property
|
| 535 |
+
def assets_dirs(self) -> pathlib.Path:
|
| 536 |
+
"""Get the assets directory for this config."""
|
| 537 |
+
return (pathlib.Path(self.assets_base_dir) / self.name).resolve()
|
| 538 |
+
|
| 539 |
+
@property
|
| 540 |
+
def checkpoint_dir(self) -> pathlib.Path:
|
| 541 |
+
"""Get the checkpoint directory for this config."""
|
| 542 |
+
if not self.exp_name:
|
| 543 |
+
raise ValueError("--exp_name must be set")
|
| 544 |
+
return (pathlib.Path(self.checkpoint_base_dir) / self.name / self.exp_name).resolve()
|
| 545 |
+
|
| 546 |
+
@property
|
| 547 |
+
def trainable_filter(self) -> nnx.filterlib.Filter:
|
| 548 |
+
"""Get the filter for the trainable parameters."""
|
| 549 |
+
return nnx.All(nnx.Param, nnx.Not(self.freeze_filter))
|
| 550 |
+
|
| 551 |
+
def __post_init__(self) -> None:
|
| 552 |
+
if self.resume and self.overwrite:
|
| 553 |
+
raise ValueError("Cannot resume and overwrite at the same time.")
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
# Use `get_config` if you need to get a config by name in your code.
|
| 557 |
+
_CONFIGS = [
|
| 558 |
+
#
|
| 559 |
+
# Fine-tuning Libero configs.
|
| 560 |
+
#
|
| 561 |
+
TrainConfig(
|
| 562 |
+
# Change the name to reflect your model and dataset.
|
| 563 |
+
name="pi0_libero",
|
| 564 |
+
model=pi0_config.Pi0Config(action_horizon=30),
|
| 565 |
+
data=LeRobotLiberoDataConfig(
|
| 566 |
+
repo_id="/mnt/data/fangyu/dataset/physical-intelligence/libero", # Download from hf physical-intelligence/libero
|
| 567 |
+
base_config=DataConfig(
|
| 568 |
+
# This flag determines whether we load the prompt (i.e. the task instruction) from the
|
| 569 |
+
# ``task`` field in the LeRobot dataset. If set to True, the prompt will show up in
|
| 570 |
+
# a field called ``prompt`` in the input dict. The recommended setting is True.
|
| 571 |
+
prompt_from_task=True,
|
| 572 |
+
),
|
| 573 |
+
extra_delta_transform=True,
|
| 574 |
+
),
|
| 575 |
+
lr_schedule=_optimizer.CosineDecaySchedule(
|
| 576 |
+
warmup_steps=1_000,
|
| 577 |
+
peak_lr=5e-5,
|
| 578 |
+
decay_steps=30_000,
|
| 579 |
+
decay_lr=1e-6,
|
| 580 |
+
),
|
| 581 |
+
optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), # New Add
|
| 582 |
+
num_train_steps=30_000,
|
| 583 |
+
pytorch_weight_path="/mnt/data/fangyu/model/Timsty/pi_base_models_torch/pi0_base_torch/model.pt",
|
| 584 |
+
training_mode="finetune",
|
| 585 |
+
save_interval=30_000,
|
| 586 |
+
),
|
| 587 |
+
TrainConfig(
|
| 588 |
+
name="pi05_libero",
|
| 589 |
+
model=pi0_config.Pi0Config(pi05=True, action_horizon=20, discrete_state_input=False),
|
| 590 |
+
data=LeRobotLiberoDataConfig(
|
| 591 |
+
repo_id="/mnt/data/fangyu/dataset/physical-intelligence/libero", # Download from hf physical-intelligence/libero
|
| 592 |
+
base_config=DataConfig(prompt_from_task=True),
|
| 593 |
+
extra_delta_transform=False,
|
| 594 |
+
),
|
| 595 |
+
batch_size=32,
|
| 596 |
+
lr_schedule=_optimizer.CosineDecaySchedule(
|
| 597 |
+
warmup_steps=1_000,
|
| 598 |
+
peak_lr=5e-5,
|
| 599 |
+
decay_steps=30_000,
|
| 600 |
+
decay_lr=1e-6,
|
| 601 |
+
),
|
| 602 |
+
optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
|
| 603 |
+
ema_decay=0.999,
|
| 604 |
+
# weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
|
| 605 |
+
pytorch_weight_path="/mnt/data/fangyu/model/Timsty/pi_base_models_torch/pi05_base_torch/model.pt",
|
| 606 |
+
num_train_steps=30_000,
|
| 607 |
+
save_interval=30000,
|
| 608 |
+
),
|
| 609 |
+
TrainConfig(
|
| 610 |
+
name="pi05_droid_fold_towel",
|
| 611 |
+
model=pi0_config.Pi0Config(
|
| 612 |
+
pi05=True,
|
| 613 |
+
action_dim=32, # pi05 is trained with 32-dim actions
|
| 614 |
+
action_horizon=30,
|
| 615 |
+
),
|
| 616 |
+
data=LeRobotDROIDDataConfig(
|
| 617 |
+
# Replace with your custom DROID LeRobot dataset repo id.
|
| 618 |
+
repo_id="/mnt/data/fangyu/dataset/real_world/fold_towel",
|
| 619 |
+
base_config=DataConfig(prompt_from_task=True),
|
| 620 |
+
assets=AssetsConfig(
|
| 621 |
+
# Important: reuse the original DROID norm stats during fine-tuning!
|
| 622 |
+
assets_dir="/mnt/data/fangyu/model/pi05_droid/assets",
|
| 623 |
+
asset_id="droid",
|
| 624 |
+
),
|
| 625 |
+
),
|
| 626 |
+
lr_schedule=_optimizer.CosineDecaySchedule(
|
| 627 |
+
warmup_steps=1_000,
|
| 628 |
+
peak_lr=5e-5,
|
| 629 |
+
decay_steps=10_000,
|
| 630 |
+
decay_lr=1e-6,
|
| 631 |
+
),
|
| 632 |
+
optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
|
| 633 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("/mnt/data/fangyu/model/pi05_droid/params"),
|
| 634 |
+
num_train_steps=10000,
|
| 635 |
+
batch_size=32,
|
| 636 |
+
),
|
| 637 |
+
# Pi0.5 Mixture-of-Horizons (JAX `Pi0Gated` in pi0_moh.py): same data / init as pi05_droid_fold_towel,
|
| 638 |
+
# with multi-horizon heads; ema_decay like pi05_libero.
|
| 639 |
+
TrainConfig(
|
| 640 |
+
name="pi05_moh_droid_fold_towel",
|
| 641 |
+
model=pi0gate_config.Pi0GatedConfig(
|
| 642 |
+
pi05=True,
|
| 643 |
+
action_dim=32,
|
| 644 |
+
action_horizon=30,
|
| 645 |
+
horizons=[3, 6, 9, 12, 15, 18, 21, 24, 27, 30],
|
| 646 |
+
),
|
| 647 |
+
data=LeRobotDROIDDataConfig(
|
| 648 |
+
repo_id="/mnt/data/fangyu/dataset/real_world/fold_towel",
|
| 649 |
+
base_config=DataConfig(prompt_from_task=True),
|
| 650 |
+
assets=AssetsConfig(
|
| 651 |
+
assets_dir="/mnt/data/fangyu/model/pi05_droid/assets",
|
| 652 |
+
asset_id="droid",
|
| 653 |
+
),
|
| 654 |
+
),
|
| 655 |
+
lr_schedule=_optimizer.CosineDecaySchedule(
|
| 656 |
+
warmup_steps=1_000,
|
| 657 |
+
peak_lr=5e-5,
|
| 658 |
+
decay_steps=10_000,
|
| 659 |
+
decay_lr=1e-6,
|
| 660 |
+
),
|
| 661 |
+
optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
|
| 662 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("/mnt/data/fangyu/model/pi05_droid/params"),
|
| 663 |
+
num_train_steps=10_000,
|
| 664 |
+
batch_size=32,
|
| 665 |
+
ema_decay=0.999,
|
| 666 |
+
save_interval=10_000,
|
| 667 |
+
horizons=[3, 6, 9, 12, 15, 18, 21, 24, 27, 30],
|
| 668 |
+
),
|
| 669 |
+
]
|
| 670 |
+
|
| 671 |
+
if len({config.name for config in _CONFIGS}) != len(_CONFIGS):
|
| 672 |
+
raise ValueError("Config names must be unique.")
|
| 673 |
+
_CONFIGS_DICT = {config.name: config for config in _CONFIGS}
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def cli() -> TrainConfig:
|
| 677 |
+
return tyro.extras.overridable_config_cli({k: (k, v) for k, v in _CONFIGS_DICT.items()})
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
def get_config(config_name: str) -> TrainConfig:
|
| 681 |
+
"""Get a config by name."""
|
| 682 |
+
if config_name not in _CONFIGS_DICT:
|
| 683 |
+
closest = difflib.get_close_matches(config_name, _CONFIGS_DICT.keys(), n=1, cutoff=0.0)
|
| 684 |
+
closest_str = f" Did you mean '{closest[0]}'? " if closest else ""
|
| 685 |
+
raise ValueError(f"Config '{config_name}' not found.{closest_str}")
|
| 686 |
+
|
| 687 |
+
return _CONFIGS_DICT[config_name]
|