| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from enum import Enum |
| from typing import Optional |
|
|
| from numpydantic import NDArray |
| from pydantic import BaseModel, Field, field_serializer |
|
|
| from .embodiment_tags import EmbodimentTag |
|
|
| |
|
|
|
|
class RotationType(Enum):
    """Enumerates the rotation encodings a state/action entry may declare.

    Every member's value is the lowercase string form of its name, which is
    what appears when the enum is looked up by value.
    """

    # Non-Euler parameterizations.
    AXIS_ANGLE = "axis_angle"
    QUATERNION = "quaternion"
    ROTATION_6D = "rotation_6d"
    MATRIX = "matrix"

    # Euler-angle variants: the suffix is one of the six orderings of the
    # letters r, p, y (presumably the roll/pitch/yaw axis order -- confirm
    # against the code that consumes these values).
    EULER_ANGLES_RPY = "euler_angles_rpy"
    EULER_ANGLES_RYP = "euler_angles_ryp"
    EULER_ANGLES_PRY = "euler_angles_pry"
    EULER_ANGLES_PYR = "euler_angles_pyr"
    EULER_ANGLES_YRP = "euler_angles_yrp"
    EULER_ANGLES_YPR = "euler_angles_ypr"
|
|
|
|
| |
|
|
|
|
class LeRobotModalityField(BaseModel):
    """Base metadata record shared by every LeRobot modality entry.

    It carries a single optional field, ``original_key``, which defaults to
    ``None`` here; subclasses in this module override that default.
    """

    original_key: Optional[str] = Field(
        description="The original key of the modality in the LeRobot dataset",
        default=None,
    )
|
|
|
|
class LeRobotStateActionMetadata(LeRobotModalityField):
    """Metadata shared by state and action entries of a LeRobot dataset.

    An entry describes one slice of the concatenated state/action vector:
    ``start`` and ``end`` are required indices, the remaining fields are
    optional and fall back to the defaults declared below.
    NOTE(review): whether ``end`` is inclusive or exclusive is not visible
    in this class -- confirm against the code that slices with it.
    """

    # Required slice bounds (no default, so validation fails when missing).
    start: int = Field(
        description="The start index of the modality in the concatenated state/action vector",
    )
    end: int = Field(
        description="The end index of the modality in the concatenated state/action vector",
    )

    # Optional descriptors.
    rotation_type: Optional[RotationType] = Field(
        description="The type of rotation for the modality",
        default=None,
    )
    absolute: bool = Field(
        description="Whether the modality is absolute",
        default=True,
    )
    dtype: str = Field(
        description="The data type of the modality. Defaults to float64.",
        default="float64",
    )
    range: Optional[tuple[float, float]] = Field(
        description="The range of the modality, if applicable. Defaults to None.",
        default=None,
    )

    # Re-declared from the parent class with the same default; only the
    # description text differs.
    original_key: Optional[str] = Field(
        description="The original key of the modality in the LeRobot dataset.",
        default=None,
    )
|
|
|
|
class LeRobotStateMetadata(LeRobotStateActionMetadata):
    """State-specific variant of the state/action metadata record.

    The only difference from the parent is that ``original_key`` defaults to
    ``"observation.state"`` instead of ``None``.
    """

    original_key: Optional[str] = Field(
        description="The original key of the state modality in the LeRobot dataset",
        default="observation.state",
    )
|
|
|
|
class LeRobotActionMetadata(LeRobotStateActionMetadata):
    """Action-specific variant of the state/action metadata record.

    The only difference from the parent is that ``original_key`` defaults to
    ``"action"`` instead of ``None``.
    """

    original_key: Optional[str] = Field(
        description="The original key of the action modality in the LeRobot dataset",
        default="action",
    )
|
|
|
|
class LeRobotModalityMetadata(BaseModel):
    """Per-modality metadata tables for a LeRobot dataset.

    ``state``, ``action`` and ``video`` are required mappings from an entry
    name to its metadata record; ``annotation`` is optional and defaults to
    ``None`` when the dataset has no annotations.
    """

    state: dict[str, LeRobotStateMetadata] = Field(
        ...,
        description="The metadata for the state modality. The keys are the names of each split of the state vector.",
    )
    action: dict[str, LeRobotActionMetadata] = Field(
        ...,
        description="The metadata for the action modality. The keys are the names of each split of the action vector.",
    )
    video: dict[str, LeRobotModalityField] = Field(
        ...,
        description="The metadata for the video modality. The keys are the new names of each video modality.",
    )
    annotation: Optional[dict[str, LeRobotModalityField]] = Field(
        default=None,
        description="The metadata for the annotation modality. The keys are the new names of each annotation modality.",
    )

    def get_key_meta(self, key: str) -> LeRobotModalityField:
        """Get the metadata for a key in the LeRobot modality metadata.

        The key has the form ``"<modality>.<subkey>"``. Everything before the
        first dot selects the table (``state``, ``action``, ``video`` or
        ``annotation``); everything after it, further dots included, is the
        name looked up inside that table.

        Args:
            key (str): The key to get the metadata for.

        Returns:
            LeRobotModalityField: The metadata for the key.

        Raises:
            ValueError: If the modality prefix is not one of the four known
                modalities, if annotation metadata is requested but the
                dataset has none, or if the subkey is absent from the
                selected table.

        Example:
            lerobot_modality_meta = LeRobotModalityMetadata.model_validate(U.load_json(modality_meta_path))
            lerobot_modality_meta.get_key_meta("state.joint_shoulder_y")
            lerobot_modality_meta.get_key_meta("video.main_camera")
            lerobot_modality_meta.get_key_meta("annotation.human.action.task_description")
        """
        # Split on the first dot only; a key without a dot yields an empty
        # subkey, which then fails the membership check below.
        modality, _, subkey = key.partition(".")

        tables = {
            "state": self.state,
            "action": self.action,
            "video": self.video,
            "annotation": self.annotation,
        }
        if modality not in tables:
            raise ValueError(f"Key: {key}, unexpected modality: {modality}")

        table = tables[modality]
        if modality == "annotation" and table is None:
            # Raised explicitly rather than asserted: an `assert` disappears
            # under `python -O`, after which the membership test below would
            # fail on None with an unrelated TypeError.
            raise ValueError("Trying to get annotation metadata for a dataset with no annotations")

        if subkey not in table:
            raise ValueError(
                f"Key: {key}, {modality} key {subkey} not found in metadata, available {modality} keys: {table.keys()}"
            )
        return table[subkey]
|
|
|
|
| |
|
|
|
|
class DatasetStatisticalValues(BaseModel):
    """Summary statistics held as arrays, one required field per statistic.

    When the model is dumped to JSON, every field goes through
    ``serialize_ndarray`` and is emitted as a plain list.
    """

    max: NDArray = Field(description="Maximum values")
    min: NDArray = Field(description="Minimum values")
    mean: NDArray = Field(description="Mean values")
    std: NDArray = Field(description="Standard deviation")
    q01: NDArray = Field(description="1st percentile values")
    q99: NDArray = Field(description="99th percentile values")

    @field_serializer("*", when_used="json")
    def serialize_ndarray(self, v: NDArray) -> list[float]:
        """Return the array's contents as a list for JSON output.

        Registered for all fields ("*") and only used in JSON mode.
        NOTE(review): the ``list[float]`` annotation presumes a 1-D array;
        ``tolist()`` would nest lists for higher ranks -- confirm the shapes
        stored here.
        """
        as_list = v.tolist()
        return as_list
|
|
|
|
class DatasetStatistics(BaseModel):
    """Statistics for the state and action modalities, keyed by entry name.

    Both mappings are required and map a string key to a
    ``DatasetStatisticalValues`` record.
    """

    state: dict[str, DatasetStatisticalValues] = Field(
        description="Statistics of the state",
    )
    action: dict[str, DatasetStatisticalValues] = Field(
        description="Statistics of the action",
    )
|
|
|
|
class VideoMetadata(BaseModel):
    """Describes one video stream: resolution, channel count and frame rate.

    All three fields are required; ``channels`` and ``fps`` must be strictly
    greater than zero.
    NOTE(review): the order of the two ``resolution`` entries (width/height
    vs. height/width) is not visible here -- confirm with the producer.
    """

    resolution: tuple[int, int] = Field(
        description="Resolution of the video",
    )
    channels: int = Field(
        gt=0,
        description="Number of channels in the video",
    )
    fps: float = Field(
        gt=0,
        description="Frames per second",
    )
|
|
|
|
class StateActionMetadata(BaseModel):
    """Describes a single state or action entry of the trainable dataset.

    ``absolute``, ``shape`` and ``continuous`` are required; ``rotation_type``
    is optional and defaults to ``None``.
    """

    absolute: bool = Field(
        description="Whether the state or action is absolute",
    )
    rotation_type: Optional[RotationType] = Field(
        default=None,
        description="Type of rotation, if any",
    )
    shape: tuple[int, ...] = Field(
        description="Shape of the state or action",
    )
    continuous: bool = Field(
        description="Whether the state or action is continuous",
    )
|
|
|
|
class DatasetModalities(BaseModel):
    """Metadata for the video, state and action modalities of a dataset.

    Each field is a required mapping from a string key to that modality's
    metadata record.
    """

    video: dict[str, VideoMetadata] = Field(
        description="Metadata of the video",
    )
    state: dict[str, StateActionMetadata] = Field(
        description="Metadata of the state",
    )
    action: dict[str, StateActionMetadata] = Field(
        description="Metadata of the action",
    )
|
|
|
|
class DatasetMetadata(BaseModel):
    """Top-level metadata of the trainable dataset.

    Bundles three required parts: the dataset statistics, the per-modality
    metadata, and the embodiment tag.

    NOTE(review): the previous docstring carried a changelog entry, "Update
    to use the new RawCommitHashMetadataMetadata_V1_2". No field in this
    class references that name, so the entry looks stale -- confirm before
    relying on it.
    """

    statistics: DatasetStatistics = Field(
        description="Statistics of the dataset",
    )
    modalities: DatasetModalities = Field(
        description="Metadata of the modalities",
    )
    embodiment_tag: EmbodimentTag = Field(
        description="Embodiment tag of the dataset",
    )