# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod

from starVLA.dataloader.gr00t_lerobot.datasets import ModalityConfig
from starVLA.dataloader.gr00t_lerobot.transform.base import ComposedModalityTransform, ModalityTransform
from starVLA.dataloader.gr00t_lerobot.transform.state_action import (
    StateActionSinCosTransform,
    StateActionToTensor,
    StateActionTransform,
)


class BaseDataConfig(ABC):
    """Per-embodiment dataset configuration: modality keys + transforms.

    Subclasses declare the modality key lists (``video_keys``, ``state_keys``,
    ``action_keys``, ``language_keys``) and implement :meth:`transform`;
    ``modality_config`` is shared and built from those class attributes.
    """

    # Subclasses override these with the embodiment-specific key lists.
    video_keys: list[str] = []
    state_keys: list[str] = []
    action_keys: list[str] = []
    language_keys: list[str] = []
    # Single-frame observation window (index 0 = current step).
    observation_indices = [0]

    def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0):
        """
        Args:
            chunk_size: length of the action chunk; action delta indices are
                ``range(chunk_size)``.
            state_use_action_chunk: when True, state uses the action indices so
                the state tensor has shape (L, state_dim) aligned with the
                action chunk.
            num_history_steps: when > 0, video observations include one history
                frame at index ``num_history_steps - 1`` in addition to frame 0.
        """
        self.chunk_size = chunk_size
        self.action_indices = list(range(chunk_size))
        self.state_use_action_chunk = state_use_action_chunk
        # ``or 0`` guards against None being passed through from configs.
        self.num_history_steps = int(num_history_steps or 0)
        self.video_observation_indices = (
            [0] if self.num_history_steps == 0 else [0, self.num_history_steps - 1]
        )

    def modality_config(self) -> dict[str, ModalityConfig]:
        """Build the per-modality delta-index/key configuration mapping."""
        state_delta = self.action_indices if self.state_use_action_chunk else self.observation_indices
        return {
            "video": ModalityConfig(
                delta_indices=self.video_observation_indices,
                modality_keys=self.video_keys,
            ),
            "state": ModalityConfig(
                delta_indices=state_delta,
                modality_keys=self.state_keys,
            ),
            "action": ModalityConfig(
                delta_indices=self.action_indices,
                modality_keys=self.action_keys,
            ),
            "language": ModalityConfig(
                delta_indices=self.observation_indices,
                modality_keys=self.language_keys,
            ),
        }

    @abstractmethod
    def transform(self) -> ModalityTransform:
        """Return the composed state/action transform for this embodiment."""


###########################################################################################


class Libero4in1DataConfig(BaseDataConfig):
    """LIBERO (Franka) end-effector config: xyz + rpy + pad + gripper state."""

    video_keys = [
        "video.primary_image",
        "video.wrist_image",
    ]
    state_keys = [
        "state.x",
        "state.y",
        "state.z",
        "state.roll",
        "state.pitch",
        "state.yaw",
        "state.pad",
        "state.gripper",
    ]
    action_keys = [
        "action.x",
        "action.y",
        "action.z",
        "action.roll",
        "action.pitch",
        "action.yaw",
        "action.gripper",
    ]
    language_keys = ["annotation.human.action.task_description"]
    # Kept for backward compatibility with class-level access; shadowed by the
    # instance attribute set in BaseDataConfig.__init__.
    action_indices = list(range(16))

    def transform(self) -> ModalityTransform:
        transforms = [
            # state transforms: gripper is deliberately left un-normalized.
            StateActionToTensor(apply_to=self.state_keys),
            StateActionTransform(
                apply_to=self.state_keys,
                normalization_modes={
                    "state.x": "min_max",
                    "state.y": "min_max",
                    "state.z": "min_max",
                    "state.roll": "min_max",
                    "state.pitch": "min_max",
                    "state.yaw": "min_max",
                    "state.pad": "min_max",
                    # "state.gripper": "binary",
                },
            ),
            # action transforms: gripper is deliberately left un-normalized.
            StateActionToTensor(apply_to=self.action_keys),
            StateActionTransform(
                apply_to=self.action_keys,
                normalization_modes={
                    "action.x": "min_max",
                    "action.y": "min_max",
                    "action.z": "min_max",
                    "action.roll": "min_max",
                    "action.pitch": "min_max",
                    "action.yaw": "min_max",
                    # "action.gripper": "binary",
                },
            ),
        ]
        return ComposedModalityTransform(transforms=transforms)


###########################################################################################


class RealWorldFrankaDataConfig(BaseDataConfig):
    """Real-world Panda robot: 7 joints + 1 gripper (8D), single-arm -> right slot [7:15]."""

    video_keys = [
        "video.exterior_image_1_left",
        "video.wrist_image_left",
    ]
    state_keys = [
        "state.joints",
        "state.gripper",
    ]
    action_keys = [
        "action.joints",
        "action.gripper",
    ]
    language_keys = ["annotation.human.action.task_description"]
    # Kept for backward compatibility with class-level access; shadowed by the
    # instance attribute set in BaseDataConfig.__init__.
    action_indices = list(range(16))

    def transform(self) -> ModalityTransform:
        transforms = [
            StateActionToTensor(apply_to=self.state_keys),
            StateActionTransform(
                apply_to=self.state_keys,
                normalization_modes={
                    "state.joints": "min_max",
                    # "state.gripper": "binary",
                },
            ),
            StateActionToTensor(apply_to=self.action_keys),
            StateActionTransform(
                apply_to=self.action_keys,
                normalization_modes={
                    "action.joints": "min_max",
                    # "action.gripper": "binary",
                },
            ),
        ]
        return ComposedModalityTransform(transforms=transforms)


class AgilexDataConfig(BaseDataConfig):
    """Agilex bimanual config: left/right joints + binary grippers."""

    video_keys = [
        "video.cam_high",
        "video.cam_left_wrist",
        "video.cam_right_wrist",
    ]
    state_keys = [
        "state.left_joints",
        "state.left_gripper",
        "state.right_joints",
        "state.right_gripper",
    ]
    action_keys = [
        "action.left_joints",
        "action.left_gripper",
        "action.right_joints",
        "action.right_gripper",
    ]
    language_keys = ["annotation.human.action.task_description"]

    def transform(self) -> ModalityTransform:
        transforms = [
            # state transforms
            StateActionToTensor(apply_to=self.state_keys),
            StateActionTransform(
                apply_to=self.state_keys,
                normalization_modes={
                    "state.left_joints": "min_max",
                    "state.left_gripper": "binary",
                    "state.right_joints": "min_max",
                    "state.right_gripper": "binary",
                },
            ),
            # action transforms
            StateActionToTensor(apply_to=self.action_keys),
            StateActionTransform(
                apply_to=self.action_keys,
                normalization_modes={
                    "action.left_joints": "min_max",
                    "action.left_gripper": "binary",
                    "action.right_joints": "min_max",
                    "action.right_gripper": "binary",
                },
            ),
        ]
        return ComposedModalityTransform(transforms=transforms)


class FourierGr1ArmsWaistDataConfig(BaseDataConfig):
    """Fourier GR-1 humanoid (arms + hands + waist), ego-view camera."""

    video_keys = ["video.ego_view"]
    state_keys = [
        "state.left_arm",
        "state.right_arm",
        "state.left_hand",
        "state.right_hand",
        "state.waist",
    ]
    action_keys = [
        "action.left_arm",
        "action.right_arm",
        "action.left_hand",
        "action.right_hand",
        "action.waist",
    ]
    language_keys = ["annotation.human.coarse_action"]

    def transform(self) -> ModalityTransform:
        transforms = [
            # state transforms: joint angles encoded as sin/cos pairs.
            StateActionToTensor(apply_to=self.state_keys),
            StateActionSinCosTransform(apply_to=self.state_keys),
            # action transforms
            StateActionToTensor(apply_to=self.action_keys),
            StateActionTransform(
                apply_to=self.action_keys,
                normalization_modes={key: "min_max" for key in self.action_keys},
            ),
        ]
        return ComposedModalityTransform(transforms=transforms)


###########################################################################################


def get_robot_type_config_map(
    chunk_size: int = 15,
    state_use_action_chunk: bool = True,
    num_history_steps: int = 0,
) -> dict[str, BaseDataConfig]:
    """Map robot-type names to their data configs.

    Args:
        chunk_size: action chunk length forwarded to every config.
        state_use_action_chunk: when True, state uses action_indices so state
            has shape (L, state_dim) aligned with action chunk.
        num_history_steps: history-frame setting forwarded to every config.
    """
    return {
        "libero_franka": Libero4in1DataConfig(
            chunk_size=chunk_size,
            state_use_action_chunk=state_use_action_chunk,
            num_history_steps=num_history_steps,
        ),
        "robotwin": AgilexDataConfig(
            chunk_size=chunk_size,
            state_use_action_chunk=state_use_action_chunk,
            num_history_steps=num_history_steps,
        ),
        "fourier_gr1_arms_waist": FourierGr1ArmsWaistDataConfig(
            chunk_size=chunk_size,
            state_use_action_chunk=state_use_action_chunk,
            num_history_steps=num_history_steps,
        ),
        "real_world_franka": RealWorldFrankaDataConfig(
            chunk_size=chunk_size,
            state_use_action_chunk=state_use_action_chunk,
            num_history_steps=num_history_steps,
        ),
    }