Timsty commited on
Commit
631e5cc
·
verified ·
1 Parent(s): c58cdb9

Upload config.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. config.py +687 -0
config.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """See _CONFIGS for the list of available configs."""
2
+
3
+ import abc
4
+ from collections.abc import Sequence
5
+ import dataclasses
6
+ import difflib
7
+ import logging
8
+ import pathlib
9
+ from typing import Any, Literal, Protocol, TypeAlias
10
+
11
+ import etils.epath as epath
12
+ import flax.nnx as nnx
13
+ from typing_extensions import override
14
+ import tyro
15
+
16
+ import openpi.models.model as _model
17
+ import openpi.models.pi0_config as pi0_config
18
+ import openpi.models.pi0moh_config as pi0gate_config
19
+ import openpi.models.tokenizer as _tokenizer
20
+ import openpi.policies.aloha_policy as aloha_policy
21
+ import openpi.policies.droid_policy as droid_policy
22
+ import openpi.policies.libero_policy as libero_policy
23
+ import openpi.shared.download as _download
24
+ import openpi.shared.normalize as _normalize
25
+ import openpi.training.droid_rlds_dataset as droid_rlds_dataset
26
+ import openpi.training.optimizer as _optimizer
27
+ import openpi.training.weight_loaders as weight_loaders
28
+ import openpi.transforms as _transforms
29
+
30
+ ModelType: TypeAlias = _model.ModelType
31
+ # Work around a tyro issue with using nnx.filterlib.Filter directly.
32
+ Filter: TypeAlias = nnx.filterlib.Filter
33
+ import numpy as np
34
+ from openpi.transforms import DataTransformFn
35
+
36
+
37
+ @dataclasses.dataclass(frozen=True)
38
+ class AssetsConfig:
39
+ """Determines the location of assets (e.g., norm stats) that will be used to set up the data pipeline.
40
+
41
+ These assets will be replicated inside the checkpoint under the `assets/asset_id` directory.
42
+
43
+ This can be used to load assets from a different checkpoint (e.g., base model checkpoint) or some other
44
+ centralized location. For example, to load the norm stats for the Trossen robot from the base model checkpoint
45
+ during fine-tuning, use:
46
+
47
+ ```
48
+ AssetsConfig(
49
+ assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
50
+ asset_id="trossen",
51
+ )
52
+ ```
53
+ """
54
+
55
+ # Assets directory. If not provided, the config assets_dirs will be used. This is useful to load assets from
56
+ # a different checkpoint (e.g., base model checkpoint) or some other centralized location.
57
+ assets_dir: str | None = None
58
+
59
+ # Asset id. If not provided, the repo id will be used. This allows users to reference assets that describe
60
+ # different robot platforms.
61
+ asset_id: str | None = None
62
+
63
+
64
+ @dataclasses.dataclass(frozen=True)
65
+ class DataConfig:
66
+ # LeRobot repo id. If None, fake data will be created.
67
+ repo_id: str | None = None
68
+ # Directory within the assets directory containing the data assets.
69
+ asset_id: str | None = None
70
+ # Contains precomputed normalization stats. If None, normalization will not be performed.
71
+ norm_stats: dict[str, _transforms.NormStats] | None = None
72
+
73
+ # Used to adopt the inputs from a dataset specific format to a common format
74
+ # which is expected by the data transforms.
75
+ repack_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
76
+ # Data transforms, typically include robot specific transformations. Will be applied
77
+ # before the data is normalized. See `model.Observation` and `model.Actions` to learn about the
78
+ # normalized data.
79
+ data_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
80
+ # Model specific transforms. Will be applied after the data is normalized.
81
+ model_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)
82
+ # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.
83
+ use_quantile_norm: bool = False
84
+
85
+ # Names of keys that will be used by the data loader to generate the action sequence. The length of the
86
+ # sequence is defined by the `action_horizon` field in the model config. This should be adjusted if your
87
+ # LeRobot dataset is using different keys to represent the action.
88
+ action_sequence_keys: Sequence[str] = ("actions",)
89
+
90
+ # If true, will use the LeRobot dataset task to define the prompt.
91
+ prompt_from_task: bool = False
92
+
93
+ # Only used for RLDS data loader (ie currently only used for DROID).
94
+ rlds_data_dir: str | None = None
95
+ # Action space for DROID dataset.
96
+ action_space: droid_rlds_dataset.DroidActionSpace | None = None
97
+ # Path to the data filter file for DROID dataset
98
+ filter_dict_path: str | None = None
99
+
100
+
101
+ class GroupFactory(Protocol):
102
+ def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
103
+ """Create a group."""
104
+
105
+
106
+ @dataclasses.dataclass(frozen=True)
107
+ class ModelTransformFactory(GroupFactory):
108
+ """Creates model transforms for standard pi0 models."""
109
+
110
+ # If provided, will determine the default prompt that be used by the model.
111
+ default_prompt: str | None = None
112
+
113
+ def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
114
+ match model_config.model_type:
115
+ case _model.ModelType.PI0:
116
+ return _transforms.Group(
117
+ inputs=[
118
+ _transforms.InjectDefaultPrompt(self.default_prompt),
119
+ _transforms.ResizeImages(224, 224),
120
+ _transforms.TokenizePrompt(
121
+ _tokenizer.PaligemmaTokenizer(model_config.max_token_len),
122
+ ),
123
+ _transforms.PadStatesAndActions(model_config.action_dim),
124
+ ],
125
+ )
126
+ case _model.ModelType.PI05:
127
+ assert isinstance(model_config, pi0_config.Pi0Config) or isinstance(model_config, pi0gate_config.Pi0GatedConfig)
128
+ return _transforms.Group(
129
+ inputs=[
130
+ _transforms.InjectDefaultPrompt(self.default_prompt),
131
+ _transforms.ResizeImages(224, 224),
132
+ _transforms.TokenizePrompt(
133
+ _tokenizer.PaligemmaTokenizer(model_config.max_token_len),
134
+ discrete_state_input=model_config.discrete_state_input,
135
+ ),
136
+ _transforms.PadStatesAndActions(model_config.action_dim),
137
+ ],
138
+ )
139
+ case _model.ModelType.PI0_FAST:
140
+ tokenizer_cls = (
141
+ _tokenizer.FASTTokenizer
142
+ if model_config.fast_model_tokenizer is None
143
+ else model_config.fast_model_tokenizer
144
+ )
145
+ tokenizer_kwargs = (
146
+ {} if model_config.fast_model_tokenizer_kwargs is None else model_config.fast_model_tokenizer_kwargs
147
+ )
148
+ return _transforms.Group(
149
+ inputs=[
150
+ _transforms.InjectDefaultPrompt(self.default_prompt),
151
+ _transforms.ResizeImages(224, 224),
152
+ _transforms.TokenizeFASTInputs(
153
+ tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
154
+ ),
155
+ ],
156
+ outputs=[
157
+ _transforms.ExtractFASTActions(
158
+ tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
159
+ action_horizon=model_config.action_horizon,
160
+ action_dim=model_config.action_dim,
161
+ )
162
+ ],
163
+ )
164
+
165
+
166
+ @dataclasses.dataclass(frozen=True)
167
+ class DataConfigFactory(abc.ABC):
168
+ # The LeRobot repo id.
169
+ repo_id: str = tyro.MISSING
170
+ # Determines how the assets will be loaded.
171
+ assets: AssetsConfig = dataclasses.field(default_factory=AssetsConfig)
172
+ # Base config that will be updated by the factory.
173
+ base_config: tyro.conf.Suppress[DataConfig | None] = None
174
+
175
+ @abc.abstractmethod
176
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
177
+ """Create a data config."""
178
+
179
+ def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
180
+ repo_id = self.repo_id if self.repo_id is not tyro.MISSING else None
181
+ asset_id = self.assets.asset_id or repo_id
182
+ return dataclasses.replace(
183
+ self.base_config or DataConfig(),
184
+ repo_id=repo_id,
185
+ asset_id=asset_id,
186
+ norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id),
187
+ use_quantile_norm=model_config.model_type != ModelType.PI0,
188
+ )
189
+
190
+ def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None:
191
+ if asset_id is None:
192
+ return None
193
+ try:
194
+ data_assets_dir = str(assets_dir / asset_id)
195
+ norm_stats = _normalize.load(_download.maybe_download(data_assets_dir))
196
+ logging.info(f"Loaded norm stats from {data_assets_dir}")
197
+ return norm_stats
198
+ except FileNotFoundError:
199
+ logging.info(f"Norm stats not found in {data_assets_dir}, skipping.")
200
+ return None
201
+
202
+
203
+ @dataclasses.dataclass(frozen=True)
204
+ class FakeDataConfig(DataConfigFactory):
205
+ repo_id: str = "fake"
206
+
207
+ @override
208
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
209
+ return DataConfig(repo_id=self.repo_id)
210
+
211
+
212
+ @dataclasses.dataclass(frozen=True)
213
+ class SimpleDataConfig(DataConfigFactory):
214
+ # Factory for the data transforms.
215
+ data_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=GroupFactory)
216
+ # Factory for the model transforms.
217
+ model_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=ModelTransformFactory)
218
+
219
+ @override
220
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
221
+ return dataclasses.replace(
222
+ self.create_base_config(assets_dirs, model_config),
223
+ data_transforms=self.data_transforms(model_config),
224
+ model_transforms=self.model_transforms(model_config),
225
+ )
226
+
227
+
228
+ @dataclasses.dataclass(frozen=True)
229
+ class LeRobotAlohaDataConfig(DataConfigFactory):
230
+ # If true, will convert joint dimensions to deltas with respect to the current state before passing to the model.
231
+ # Gripper dimensions will remain in absolute values.
232
+ use_delta_joint_actions: bool = True
233
+ # If provided, will be injected into the input data if the "prompt" key is not present.
234
+ default_prompt: str | None = None
235
+ # If true, this will convert the joint and gripper values from the standard Aloha space to
236
+ # the space used by the pi internal runtime which was used to train the base model. People who
237
+ # use standard Aloha data should set this to true.
238
+ adapt_to_pi: bool = True
239
+
240
+ # Repack transforms.
241
+ repack_transforms: tyro.conf.Suppress[_transforms.Group] = dataclasses.field(
242
+ default=_transforms.Group(
243
+ inputs=[
244
+ _transforms.RepackTransform(
245
+ {
246
+ "images": {"cam_high": "observation.images.top"},
247
+ "state": "observation.state",
248
+ "actions": "action",
249
+ }
250
+ )
251
+ ]
252
+ )
253
+ )
254
+ # Action keys that will be used to read the action sequence from the dataset.
255
+ action_sequence_keys: Sequence[str] = ("action",)
256
+
257
+ @override
258
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
259
+ data_transforms = _transforms.Group(
260
+ inputs=[aloha_policy.AlohaInputs(adapt_to_pi=self.adapt_to_pi)],
261
+ outputs=[aloha_policy.AlohaOutputs(adapt_to_pi=self.adapt_to_pi)],
262
+ )
263
+ if self.use_delta_joint_actions:
264
+ delta_action_mask = _transforms.make_bool_mask(6, -1, 6, -1)
265
+ data_transforms = data_transforms.push(
266
+ inputs=[_transforms.DeltaActions(delta_action_mask)],
267
+ outputs=[_transforms.AbsoluteActions(delta_action_mask)],
268
+ )
269
+
270
+ model_transforms = ModelTransformFactory(default_prompt=self.default_prompt)(model_config)
271
+
272
+ return dataclasses.replace(
273
+ self.create_base_config(assets_dirs, model_config),
274
+ repack_transforms=self.repack_transforms,
275
+ data_transforms=data_transforms,
276
+ model_transforms=model_transforms,
277
+ action_sequence_keys=self.action_sequence_keys,
278
+ )
279
+
280
+
281
+ @dataclasses.dataclass(frozen=True)
282
+ class LeRobotLiberoDataConfig(DataConfigFactory):
283
+ """
284
+ This config is used to configure transforms that are applied at various parts of the data pipeline.
285
+ For your own dataset, you can copy this class and modify the transforms to match your dataset based on the
286
+ comments below.
287
+ """
288
+
289
+ extra_delta_transform: bool = False
290
+
291
+ @override
292
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
293
+ # The repack transform is *only* applied to the data coming from the dataset,
294
+ # and *not* during inference. We can use it to make inputs from the dataset look
295
+ # as close as possible to those coming from the inference environment (e.g. match the keys).
296
+ # Below, we match the keys in the dataset (which we defined in the data conversion script) to
297
+ # the keys we use in our inference pipeline (defined in the inference script for libero_scripts).
298
+ # For your own dataset, first figure out what keys your environment passes to the policy server
299
+ # and then modify the mappings below so your dataset's keys get matched to those target keys.
300
+ # The repack transform simply remaps key names here.
301
+ repack_transform = _transforms.Group(
302
+ inputs=[
303
+ _transforms.RepackTransform(
304
+ {
305
+ "observation/image": "image",
306
+ "observation/wrist_image": "wrist_image",
307
+ "observation/state": "state",
308
+ "actions": "actions",
309
+ "prompt": "prompt",
310
+ }
311
+ )
312
+ ]
313
+ )
314
+
315
+ # The data transforms are applied to the data coming from the dataset *and* during inference.
316
+ # Below, we define the transforms for data going into the model (``inputs``) and the transforms
317
+ # for data coming out of the model (``outputs``) (the latter is only used during inference).
318
+ # We defined these transforms in `libero_policy.py`. You can check the detailed comments there for
319
+ # how to modify the transforms to match your dataset. Once you created your own transforms, you can
320
+ # replace the transforms below with your own.
321
+ data_transforms = _transforms.Group(
322
+ inputs=[libero_policy.LiberoInputs(model_type=model_config.model_type)],
323
+ outputs=[libero_policy.LiberoOutputs()],
324
+ )
325
+
326
+ # One additional data transform: pi0 models are trained on delta actions (relative to the first
327
+ # state in each action chunk). IF your data has ``absolute`` actions (e.g. target joint angles)
328
+ # you can uncomment the following line to convert the actions to delta actions. The only exception
329
+ # is for the gripper actions which are always absolute.
330
+ # In the example below, we would apply the delta conversion to the first 6 actions (joints) and
331
+ # leave the 7th action (gripper) unchanged, i.e. absolute.
332
+ # In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to
333
+ # apply a separate delta conversion (that's why it's commented out). Choose whether to apply this
334
+ # transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box.
335
+
336
+ # LIBERO already represents actions as deltas, but we have some old Pi0 checkpoints that are trained with this
337
+ # extra delta transform.
338
+ if self.extra_delta_transform:
339
+ delta_action_mask = _transforms.make_bool_mask(6, -1)
340
+ data_transforms = data_transforms.push(
341
+ inputs=[_transforms.DeltaActions(delta_action_mask)],
342
+ outputs=[_transforms.AbsoluteActions(delta_action_mask)],
343
+ )
344
+
345
+ # Model transforms include things like tokenizing the prompt and action targets
346
+ # You do not need to change anything here for your own dataset.
347
+ model_transforms = ModelTransformFactory()(model_config)
348
+
349
+ # We return all data transforms for training and inference. No need to change anything here.
350
+ return dataclasses.replace(
351
+ self.create_base_config(assets_dirs, model_config),
352
+ repack_transforms=repack_transform,
353
+ data_transforms=data_transforms,
354
+ model_transforms=model_transforms,
355
+ )
356
+
357
+
358
+ @dataclasses.dataclass(frozen=True)
359
+ class RLDSDroidDataConfig(DataConfigFactory):
360
+ """
361
+ Config for training on DROID, using RLDS data format (for efficient training on larger datasets).
362
+ """
363
+
364
+ rlds_data_dir: str | None = None
365
+ action_space: droid_rlds_dataset.DroidActionSpace | None = None
366
+
367
+ # Filtering options. Can pass a path to a dictionary that maps episodes to timestep ranges
368
+ # to tuples denoting ranges of time steps to keep (start, end). Episodes are uniquely identified with
369
+ # f"{recording_folderpath}--{file_path}", both of which are present in the RLDS episode metadata.
370
+ # Path to the filter dictionary file.
371
+ filter_dict_path: str | None = "gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json"
372
+
373
+ @override
374
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
375
+ repack_transform = _transforms.Group(
376
+ inputs=[
377
+ _transforms.RepackTransform(
378
+ {
379
+ "observation/exterior_image_1_left": "observation/image",
380
+ "observation/wrist_image_left": "observation/wrist_image",
381
+ "observation/joint_position": "observation/joint_position",
382
+ "observation/gripper_position": "observation/gripper_position",
383
+ "actions": "actions",
384
+ "prompt": "prompt",
385
+ }
386
+ )
387
+ ]
388
+ )
389
+
390
+ data_transforms = _transforms.Group(
391
+ inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],
392
+ outputs=[droid_policy.DroidOutputs()],
393
+ )
394
+
395
+ if self.action_space == droid_rlds_dataset.DroidActionSpace.JOINT_POSITION:
396
+ # Data loader returns absolute joint position actions -- convert to delta actions for training.
397
+ delta_action_mask = _transforms.make_bool_mask(7, -1)
398
+ data_transforms = data_transforms.push(
399
+ inputs=[_transforms.DeltaActions(delta_action_mask)],
400
+ outputs=[_transforms.AbsoluteActions(delta_action_mask)],
401
+ )
402
+
403
+ model_transforms = ModelTransformFactory()(model_config)
404
+
405
+ assert self.rlds_data_dir is not None, "Need to set rlds data dir for RLDS data loader."
406
+
407
+ return dataclasses.replace(
408
+ self.create_base_config(assets_dirs, model_config),
409
+ repack_transforms=repack_transform,
410
+ data_transforms=data_transforms,
411
+ model_transforms=model_transforms,
412
+ rlds_data_dir=self.rlds_data_dir,
413
+ action_space=self.action_space,
414
+ filter_dict_path=self.filter_dict_path,
415
+ )
416
+
417
+
418
+ @dataclasses.dataclass(frozen=True)
419
+ class LeRobotDROIDDataConfig(DataConfigFactory):
420
+ """
421
+ Example data config for custom DROID dataset in LeRobot format.
422
+ To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py
423
+ """
424
+
425
+ @override
426
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
427
+ repack_transform = _transforms.Group(
428
+ inputs=[
429
+ _transforms.RepackTransform(
430
+ {
431
+ "observation/exterior_image_1_left": "exterior_image_1_left",
432
+ # "observation/exterior_image_2_left": "exterior_image_2_left",
433
+ "observation/wrist_image_left": "wrist_image_left",
434
+ "observation/joint_position": "joint_position",
435
+ "observation/gripper_position": "gripper_position",
436
+ "actions": "actions",
437
+ "prompt": "prompt",
438
+ }
439
+ )
440
+ ]
441
+ )
442
+ # We assume joint *velocity* actions, so we should *not* apply an additional delta transform.
443
+ data_transforms = _transforms.Group(
444
+ inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],
445
+ outputs=[droid_policy.DroidOutputs()],
446
+ )
447
+ model_transforms = ModelTransformFactory()(model_config)
448
+
449
+ return dataclasses.replace(
450
+ self.create_base_config(assets_dirs, model_config),
451
+ repack_transforms=repack_transform,
452
+ data_transforms=data_transforms,
453
+ model_transforms=model_transforms,
454
+ )
455
+
456
+
457
+ @dataclasses.dataclass(frozen=False)
458
+ class TrainConfig:
459
+ # Name of the config. Must be unique. Will be used to reference this config.
460
+ name: tyro.conf.Suppress[str]
461
+ # Project name.
462
+ project_name: str = "openpi"
463
+ # Experiment name. Will be used to name the metadata and checkpoint directories.
464
+ exp_name: str = tyro.MISSING
465
+
466
+ # Defines the model config. Some attributes (action_dim, action_horizon, and max_token_len) are shared by all models
467
+ # -- see BaseModelConfig. Specific model implementations (e.g., Pi0Config) inherit from BaseModelConfig and may
468
+ # define additional attributes.
469
+ model: _model.BaseModelConfig = dataclasses.field(default_factory=pi0_config.Pi0Config)
470
+
471
+ # A weight loader can optionally load (possibly partial) weights from disk after the model is initialized.
472
+ weight_loader: weight_loaders.WeightLoader = dataclasses.field(default_factory=weight_loaders.NoOpWeightLoader)
473
+
474
+ # Optional path to a PyTorch checkpoint to load weights from.
475
+ pytorch_weight_path: str | None = None
476
+
477
+ # Precision for PyTorch training.
478
+ pytorch_training_precision: Literal["bfloat16", "float32"] = "bfloat16"
479
+
480
+ lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule)
481
+ optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW)
482
+ ema_decay: float | None = 0.99
483
+
484
+ # Specifies which weights should be frozen.
485
+ freeze_filter: tyro.conf.Suppress[Filter] = dataclasses.field(default_factory=nnx.Nothing)
486
+
487
+ # Determines the data to be trained on.
488
+ data: DataConfigFactory = dataclasses.field(default_factory=FakeDataConfig)
489
+
490
+ # Base directory for config assets (e.g., norm stats).
491
+ assets_base_dir: str = "./assets"
492
+ # Base directory for checkpoints.
493
+ checkpoint_base_dir: str = "./checkpoints"
494
+
495
+ # Random seed that will be used by random generators during training.
496
+ seed: int = 42
497
+ # Global batch size.
498
+ batch_size: int = 32
499
+ # Number of workers to use for the data loader. Increasing this number will speed up data loading but
500
+ # will increase memory and CPU usage.
501
+ num_workers: int = 16
502
+ # Number of train steps (batches) to run.
503
+ num_train_steps: int = 30_000
504
+ learning_rate: float = 5e-5
505
+
506
+ # How often (in steps) to log training metrics.
507
+ log_interval: int = 100
508
+ # How often (in steps) to save checkpoints.
509
+ save_interval: int = 5000
510
+ # If set, any existing checkpoints matching step % keep_period == 0 will not be deleted.
511
+ keep_period: int | None = 5000
512
+
513
+ # If true, will overwrite the checkpoint directory if it already exists.
514
+ overwrite: bool = True
515
+ # If true, will resume training from the last checkpoint.
516
+ resume: bool = False
517
+
518
+ # If true, will enable wandb logging.
519
+ wandb_enabled: bool = True
520
+
521
+ # Used to pass metadata to the policy server.
522
+ policy_metadata: dict[str, Any] | None = None
523
+
524
+ # If the value is greater than 1, FSDP will be enabled and shard across number of specified devices; overall
525
+ # device memory will be reduced but training could potentially be slower.
526
+ # eg. if total device is 4 and fsdp devices is 2; then the model will shard to 2 devices and run
527
+ # data parallel between 2 groups of devices.
528
+ fsdp_devices: int = 1
529
+
530
+ training_mode: str = "warmup" # warmup: train ca&proj; finetune: freeze vlm; full_finetune
531
+ horizons: list[int] = dataclasses.field(default_factory=lambda: [10, 20, 30])
532
+
533
+
534
+ @property
535
+ def assets_dirs(self) -> pathlib.Path:
536
+ """Get the assets directory for this config."""
537
+ return (pathlib.Path(self.assets_base_dir) / self.name).resolve()
538
+
539
+ @property
540
+ def checkpoint_dir(self) -> pathlib.Path:
541
+ """Get the checkpoint directory for this config."""
542
+ if not self.exp_name:
543
+ raise ValueError("--exp_name must be set")
544
+ return (pathlib.Path(self.checkpoint_base_dir) / self.name / self.exp_name).resolve()
545
+
546
+ @property
547
+ def trainable_filter(self) -> nnx.filterlib.Filter:
548
+ """Get the filter for the trainable parameters."""
549
+ return nnx.All(nnx.Param, nnx.Not(self.freeze_filter))
550
+
551
+ def __post_init__(self) -> None:
552
+ if self.resume and self.overwrite:
553
+ raise ValueError("Cannot resume and overwrite at the same time.")
554
+
555
+
556
+ # Use `get_config` if you need to get a config by name in your code.
557
+ _CONFIGS = [
558
+ #
559
+ # Fine-tuning Libero configs.
560
+ #
561
+ TrainConfig(
562
+ # Change the name to reflect your model and dataset.
563
+ name="pi0_libero",
564
+ model=pi0_config.Pi0Config(action_horizon=30),
565
+ data=LeRobotLiberoDataConfig(
566
+ repo_id="/mnt/data/fangyu/dataset/physical-intelligence/libero", # Download from hf physical-intelligence/libero
567
+ base_config=DataConfig(
568
+ # This flag determines whether we load the prompt (i.e. the task instruction) from the
569
+ # ``task`` field in the LeRobot dataset. If set to True, the prompt will show up in
570
+ # a field called ``prompt`` in the input dict. The recommended setting is True.
571
+ prompt_from_task=True,
572
+ ),
573
+ extra_delta_transform=True,
574
+ ),
575
+ lr_schedule=_optimizer.CosineDecaySchedule(
576
+ warmup_steps=1_000,
577
+ peak_lr=5e-5,
578
+ decay_steps=30_000,
579
+ decay_lr=1e-6,
580
+ ),
581
+ optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), # New Add
582
+ num_train_steps=30_000,
583
+ pytorch_weight_path="/mnt/data/fangyu/model/Timsty/pi_base_models_torch/pi0_base_torch/model.pt",
584
+ training_mode="finetune",
585
+ save_interval=30_000,
586
+ ),
587
+ TrainConfig(
588
+ name="pi05_libero",
589
+ model=pi0_config.Pi0Config(pi05=True, action_horizon=20, discrete_state_input=False),
590
+ data=LeRobotLiberoDataConfig(
591
+ repo_id="/mnt/data/fangyu/dataset/physical-intelligence/libero", # Download from hf physical-intelligence/libero
592
+ base_config=DataConfig(prompt_from_task=True),
593
+ extra_delta_transform=False,
594
+ ),
595
+ batch_size=32,
596
+ lr_schedule=_optimizer.CosineDecaySchedule(
597
+ warmup_steps=1_000,
598
+ peak_lr=5e-5,
599
+ decay_steps=30_000,
600
+ decay_lr=1e-6,
601
+ ),
602
+ optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
603
+ ema_decay=0.999,
604
+ # weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
605
+ pytorch_weight_path="/mnt/data/fangyu/model/Timsty/pi_base_models_torch/pi05_base_torch/model.pt",
606
+ num_train_steps=30_000,
607
+ save_interval=30000,
608
+ ),
609
+ TrainConfig(
610
+ name="pi05_droid_fold_towel",
611
+ model=pi0_config.Pi0Config(
612
+ pi05=True,
613
+ action_dim=32, # pi05 is trained with 32-dim actions
614
+ action_horizon=30,
615
+ ),
616
+ data=LeRobotDROIDDataConfig(
617
+ # Replace with your custom DROID LeRobot dataset repo id.
618
+ repo_id="/mnt/data/fangyu/dataset/real_world/fold_towel",
619
+ base_config=DataConfig(prompt_from_task=True),
620
+ assets=AssetsConfig(
621
+ # Important: reuse the original DROID norm stats during fine-tuning!
622
+ assets_dir="/mnt/data/fangyu/model/pi05_droid/assets",
623
+ asset_id="droid",
624
+ ),
625
+ ),
626
+ lr_schedule=_optimizer.CosineDecaySchedule(
627
+ warmup_steps=1_000,
628
+ peak_lr=5e-5,
629
+ decay_steps=10_000,
630
+ decay_lr=1e-6,
631
+ ),
632
+ optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
633
+ weight_loader=weight_loaders.CheckpointWeightLoader("/mnt/data/fangyu/model/pi05_droid/params"),
634
+ num_train_steps=10000,
635
+ batch_size=32,
636
+ ),
637
+ # Pi0.5 Mixture-of-Horizons (JAX `Pi0Gated` in pi0_moh.py): same data / init as pi05_droid_fold_towel,
638
+ # with multi-horizon heads; ema_decay like pi05_libero.
639
+ TrainConfig(
640
+ name="pi05_moh_droid_fold_towel",
641
+ model=pi0gate_config.Pi0GatedConfig(
642
+ pi05=True,
643
+ action_dim=32,
644
+ action_horizon=30,
645
+ horizons=[3, 6, 9, 12, 15, 18, 21, 24, 27, 30],
646
+ ),
647
+ data=LeRobotDROIDDataConfig(
648
+ repo_id="/mnt/data/fangyu/dataset/real_world/fold_towel",
649
+ base_config=DataConfig(prompt_from_task=True),
650
+ assets=AssetsConfig(
651
+ assets_dir="/mnt/data/fangyu/model/pi05_droid/assets",
652
+ asset_id="droid",
653
+ ),
654
+ ),
655
+ lr_schedule=_optimizer.CosineDecaySchedule(
656
+ warmup_steps=1_000,
657
+ peak_lr=5e-5,
658
+ decay_steps=10_000,
659
+ decay_lr=1e-6,
660
+ ),
661
+ optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
662
+ weight_loader=weight_loaders.CheckpointWeightLoader("/mnt/data/fangyu/model/pi05_droid/params"),
663
+ num_train_steps=10_000,
664
+ batch_size=32,
665
+ ema_decay=0.999,
666
+ save_interval=10_000,
667
+ horizons=[3, 6, 9, 12, 15, 18, 21, 24, 27, 30],
668
+ ),
669
+ ]
670
+
671
+ if len({config.name for config in _CONFIGS}) != len(_CONFIGS):
672
+ raise ValueError("Config names must be unique.")
673
+ _CONFIGS_DICT = {config.name: config for config in _CONFIGS}
674
+
675
+
676
+ def cli() -> TrainConfig:
677
+ return tyro.extras.overridable_config_cli({k: (k, v) for k, v in _CONFIGS_DICT.items()})
678
+
679
+
680
+ def get_config(config_name: str) -> TrainConfig:
681
+ """Get a config by name."""
682
+ if config_name not in _CONFIGS_DICT:
683
+ closest = difflib.get_close_matches(config_name, _CONFIGS_DICT.keys(), n=1, cutoff=0.0)
684
+ closest_str = f" Did you mean '{closest[0]}'? " if closest else ""
685
+ raise ValueError(f"Config '{config_name}' not found.{closest_str}")
686
+
687
+ return _CONFIGS_DICT[config_name]