# Timsty's picture
# Upload folder using huggingface_hub
# e94400c verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from starVLA.dataloader.gr00t_lerobot.datasets import ModalityConfig
from starVLA.dataloader.gr00t_lerobot.transform.base import ComposedModalityTransform, ModalityTransform
from starVLA.dataloader.gr00t_lerobot.transform.state_action import (
StateActionSinCosTransform,
StateActionToTensor,
StateActionTransform,
)
class BaseDataConfig(ABC):
    """Interface that every robot-specific data config implements.

    A concrete config supplies two things: a per-modality ``ModalityConfig``
    mapping (which keys to load and at which delta indices) and the transform
    pipeline applied to the sampled data.
    """

    @abstractmethod
    def modality_config(self) -> dict[str, ModalityConfig]:
        """Return the modality-name -> ModalityConfig mapping for this robot."""
        ...

    @abstractmethod
    def transform(self) -> ModalityTransform:
        """Return the composed transform applied to sampled modalities."""
        ...
###########################################################################################
class Libero4in1DataConfig:
    """Data config for the LIBERO 4-in-1 Franka setup.

    Two camera streams, an 8-D end-effector state (pose + pad + gripper) and a
    7-D end-effector action (pose + gripper).

    NOTE(review): implements the ``BaseDataConfig`` interface but does not
    inherit from it — confirm whether the base class should be declared.
    """

    video_keys = [
        "video.primary_image",
        "video.wrist_image",
    ]
    state_keys = [
        "state.x",
        "state.y",
        "state.z",
        "state.roll",
        "state.pitch",
        "state.yaw",
        "state.pad",
        "state.gripper",
    ]
    action_keys = [
        "action.x",
        "action.y",
        "action.z",
        "action.roll",
        "action.pitch",
        "action.yaw",
        "action.gripper",
    ]
    language_keys = ["annotation.human.action.task_description"]
    observation_indices = [0]
    # Class-level default; overridden per instance in __init__.
    action_indices = list(range(16))

    def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0):
        """chunk_size: length of the action chunk; state_use_action_chunk:
        align state sampling with the action chunk; num_history_steps: how far
        back the extra video frame reaches (0 disables history)."""
        self.chunk_size = chunk_size
        self.action_indices = list(range(chunk_size))
        self.state_use_action_chunk = state_use_action_chunk
        self.num_history_steps = int(num_history_steps or 0)
        # Current frame only, or current frame plus one history frame.
        # NOTE(review): num_history_steps == 1 yields the duplicate [0, 0] — confirm intended.
        if self.num_history_steps == 0:
            self.video_observation_indices = [0]
        else:
            self.video_observation_indices = [0, self.num_history_steps - 1]

    def modality_config(self):
        """Build the modality-name -> ModalityConfig mapping."""
        use_chunk = getattr(self, "state_use_action_chunk", False)
        # When enabled, state shares the action delta indices so it comes out
        # shaped (L, state_dim), aligned with the action chunk.
        state_indices = self.action_indices if use_chunk else self.observation_indices
        return {
            "video": ModalityConfig(
                delta_indices=self.video_observation_indices,
                modality_keys=self.video_keys,
            ),
            "state": ModalityConfig(
                delta_indices=state_indices,
                modality_keys=self.state_keys,
            ),
            "action": ModalityConfig(
                delta_indices=self.action_indices,
                modality_keys=self.action_keys,
            ),
            "language": ModalityConfig(
                delta_indices=self.observation_indices,
                modality_keys=self.language_keys,
            ),
        }

    def transform(self):
        """Compose tensor conversion and min-max normalization for state/action.

        Gripper channels are deliberately left out of the normalization maps
        (a "binary" mode is disabled upstream).
        """
        state_modes = {key: "min_max" for key in self.state_keys if key != "state.gripper"}
        action_modes = {key: "min_max" for key in self.action_keys if key != "action.gripper"}
        pipeline = [
            StateActionToTensor(apply_to=self.state_keys),
            StateActionTransform(apply_to=self.state_keys, normalization_modes=state_modes),
            StateActionToTensor(apply_to=self.action_keys),
            StateActionTransform(apply_to=self.action_keys, normalization_modes=action_modes),
        ]
        return ComposedModalityTransform(transforms=pipeline)
###########################################################################################
class RealWorldFrankaDataConfig:
    """Real-world Panda robot: 7 joints + 1 gripper (8D), single-arm -> right slot [7:15].

    NOTE(review): implements the ``BaseDataConfig`` interface but does not
    inherit from it — confirm whether the base class should be declared.
    """

    video_keys = [
        "video.exterior_image_1_left",
        "video.wrist_image_left",
    ]
    state_keys = [
        "state.joints",
        "state.gripper",
    ]
    action_keys = [
        "action.joints",
        "action.gripper",
    ]
    language_keys = ["annotation.human.action.task_description"]
    observation_indices = [0]
    # Class-level default; overridden per instance in __init__.
    action_indices = list(range(16))

    def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0):
        """chunk_size: length of the action chunk; state_use_action_chunk:
        align state sampling with the action chunk; num_history_steps: how far
        back the extra video frame reaches (0 disables history)."""
        self.chunk_size = chunk_size
        self.action_indices = list(range(chunk_size))
        self.state_use_action_chunk = state_use_action_chunk
        self.num_history_steps = int(num_history_steps or 0)
        # Current frame only, or current frame plus one history frame.
        if self.num_history_steps == 0:
            self.video_observation_indices = [0]
        else:
            self.video_observation_indices = [0, self.num_history_steps - 1]

    def modality_config(self):
        """Build the modality-name -> ModalityConfig mapping."""
        use_chunk = getattr(self, "state_use_action_chunk", False)
        # When enabled, state shares the action delta indices so it is aligned
        # with the action chunk as (L, state_dim).
        state_indices = self.action_indices if use_chunk else self.observation_indices
        return {
            "video": ModalityConfig(
                delta_indices=self.video_observation_indices,
                modality_keys=self.video_keys,
            ),
            "state": ModalityConfig(
                delta_indices=state_indices,
                modality_keys=self.state_keys,
            ),
            "action": ModalityConfig(
                delta_indices=self.action_indices,
                modality_keys=self.action_keys,
            ),
            "language": ModalityConfig(
                delta_indices=self.observation_indices,
                modality_keys=self.language_keys,
            ),
        }

    def transform(self):
        """Compose tensor conversion and min-max normalization for state/action.

        Gripper channels are deliberately left out of the normalization maps
        (a "binary" mode is disabled upstream).
        """
        state_modes = {key: "min_max" for key in self.state_keys if key != "state.gripper"}
        action_modes = {key: "min_max" for key in self.action_keys if key != "action.gripper"}
        pipeline = [
            StateActionToTensor(apply_to=self.state_keys),
            StateActionTransform(apply_to=self.state_keys, normalization_modes=state_modes),
            StateActionToTensor(apply_to=self.action_keys),
            StateActionTransform(apply_to=self.action_keys, normalization_modes=action_modes),
        ]
        return ComposedModalityTransform(transforms=pipeline)
class AgilexDataConfig:
    """Data config for the dual-arm Agilex robot (RoboTwin): three cameras,
    per-arm joint + gripper state and action.

    NOTE(review): implements the ``BaseDataConfig`` interface but does not
    inherit from it — confirm whether the base class should be declared.
    """

    video_keys = [
        "video.cam_high",
        "video.cam_left_wrist",
        "video.cam_right_wrist",
    ]
    state_keys = [
        "state.left_joints",
        "state.left_gripper",
        "state.right_joints",
        "state.right_gripper",
    ]
    action_keys = [
        "action.left_joints",
        "action.left_gripper",
        "action.right_joints",
        "action.right_gripper",
    ]
    language_keys = ["annotation.human.action.task_description"]
    observation_indices = [0]

    def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0):
        """chunk_size: length of the action chunk; state_use_action_chunk:
        align state sampling with the action chunk; num_history_steps: how far
        back the extra video frame reaches (0 disables history)."""
        self.chunk_size = chunk_size
        self.action_indices = list(range(chunk_size))
        self.state_use_action_chunk = state_use_action_chunk
        self.num_history_steps = int(num_history_steps or 0)
        # Current frame only, or current frame plus one history frame.
        if self.num_history_steps == 0:
            self.video_observation_indices = [0]
        else:
            self.video_observation_indices = [0, self.num_history_steps - 1]

    def modality_config(self):
        """Build the modality-name -> ModalityConfig mapping."""
        use_chunk = getattr(self, "state_use_action_chunk", False)
        # When enabled, state shares the action delta indices so it is aligned
        # with the action chunk as (L, state_dim).
        state_indices = self.action_indices if use_chunk else self.observation_indices
        return {
            "video": ModalityConfig(
                delta_indices=self.video_observation_indices,
                modality_keys=self.video_keys,
            ),
            "state": ModalityConfig(
                delta_indices=state_indices,
                modality_keys=self.state_keys,
            ),
            "action": ModalityConfig(
                delta_indices=self.action_indices,
                modality_keys=self.action_keys,
            ),
            "language": ModalityConfig(
                delta_indices=self.observation_indices,
                modality_keys=self.language_keys,
            ),
        }

    def transform(self):
        """Compose tensor conversion and normalization for state/action.

        Joint channels are min-max normalized; gripper channels use the
        "binary" mode.
        """
        def _modes(keys):
            # One mode per key: "binary" for grippers, "min_max" for joints.
            return {key: ("binary" if key.endswith("gripper") else "min_max") for key in keys}

        pipeline = [
            StateActionToTensor(apply_to=self.state_keys),
            StateActionTransform(apply_to=self.state_keys, normalization_modes=_modes(self.state_keys)),
            StateActionToTensor(apply_to=self.action_keys),
            StateActionTransform(apply_to=self.action_keys, normalization_modes=_modes(self.action_keys)),
        ]
        return ComposedModalityTransform(transforms=pipeline)
class FourierGr1ArmsWaistDataConfig:
    """Data config for the Fourier GR-1 humanoid (arms + hands + waist) with a
    single egocentric camera.

    NOTE(review): implements the ``BaseDataConfig`` interface but does not
    inherit from it — confirm whether the base class should be declared.
    """

    video_keys = ["video.ego_view"]
    state_keys = [
        "state.left_arm",
        "state.right_arm",
        "state.left_hand",
        "state.right_hand",
        "state.waist",
    ]
    action_keys = [
        "action.left_arm",
        "action.right_arm",
        "action.left_hand",
        "action.right_hand",
        "action.waist",
    ]
    language_keys = ["annotation.human.coarse_action"]
    observation_indices = [0]

    def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0):
        """chunk_size: length of the action chunk; state_use_action_chunk:
        align state sampling with the action chunk; num_history_steps: how far
        back the extra video frame reaches (0 disables history)."""
        self.chunk_size = chunk_size
        self.action_indices = list(range(chunk_size))
        self.state_use_action_chunk = state_use_action_chunk
        self.num_history_steps = int(num_history_steps or 0)
        # Current frame only, or current frame plus one history frame.
        if self.num_history_steps == 0:
            self.video_observation_indices = [0]
        else:
            self.video_observation_indices = [0, self.num_history_steps - 1]

    def modality_config(self):
        """Build the modality-name -> ModalityConfig mapping."""
        use_chunk = getattr(self, "state_use_action_chunk", False)
        # When enabled, state shares the action delta indices so it is aligned
        # with the action chunk as (L, state_dim).
        state_indices = self.action_indices if use_chunk else self.observation_indices
        return {
            "video": ModalityConfig(
                delta_indices=self.video_observation_indices,
                modality_keys=self.video_keys,
            ),
            "state": ModalityConfig(
                delta_indices=state_indices,
                modality_keys=self.state_keys,
            ),
            "action": ModalityConfig(
                delta_indices=self.action_indices,
                modality_keys=self.action_keys,
            ),
            "language": ModalityConfig(
                delta_indices=self.observation_indices,
                modality_keys=self.language_keys,
            ),
        }

    def transform(self) -> ModalityTransform:
        """Compose the pipeline: sin/cos encoding for state, min-max
        normalization for every action key."""
        state_ops = [
            StateActionToTensor(apply_to=self.state_keys),
            StateActionSinCosTransform(apply_to=self.state_keys),
        ]
        action_ops = [
            StateActionToTensor(apply_to=self.action_keys),
            StateActionTransform(
                apply_to=self.action_keys,
                normalization_modes=dict.fromkeys(self.action_keys, "min_max"),
            ),
        ]
        return ComposedModalityTransform(transforms=state_ops + action_ops)
###########################################################################################
def get_robot_type_config_map(
    chunk_size: int = 15,
    state_use_action_chunk: bool = True,
    num_history_steps: int = 0,
) -> dict[str, BaseDataConfig]:
    """Build the robot-type -> data-config mapping, all configs sharing the
    same chunk/history settings.

    state_use_action_chunk: when True, state uses action_indices so state has
    shape (L, state_dim) aligned with the action chunk.
    """
    registry = {
        "libero_franka": Libero4in1DataConfig,
        "robotwin": AgilexDataConfig,
        "fourier_gr1_arms_waist": FourierGr1ArmsWaistDataConfig,
        "real_world_franka": RealWorldFrankaDataConfig,
    }
    return {
        robot_type: config_cls(
            chunk_size=chunk_size,
            state_use_action_chunk=state_use_action_chunk,
            num_history_steps=num_history_steps,
        )
        for robot_type, config_cls in registry.items()
    }