# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import random
from typing import Any, ClassVar

import numpy as np
import pytorch3d.transforms as pt
import torch
from pydantic import Field, PrivateAttr, field_validator, model_validator

from ..schema import DatasetMetadata, RotationType, StateActionMetadata
from .base import InvertibleModalityTransform, ModalityTransform


class RotationTransform:
    """Convert rotations between representations, always via a matrix intermediate.

    Adapted from
    https://github.com/real-stanford/diffusion_policy/blob/548a52bbb105518058e27bf34dcf90bf6f73681a/diffusion_policy/model/common/rotation_transformer.py
    """

    valid_reps = ["axis_angle", "euler_angles", "quaternion", "rotation_6d", "matrix"]

    def __init__(self, from_rep="axis_angle", to_rep="rotation_6d"):
        """Build forward/inverse conversion pipelines between two representations.

        Euler representations are spelled ``euler_angles_<conv>`` where ``<conv>``
        uses ``r``/``p``/``y`` (roll/pitch/yaw), mapped here to pytorch3d's
        ``X``/``Y``/``Z`` convention string.

        Args:
            from_rep: Source representation (one of ``valid_reps``, or an
                ``euler_angles_*`` variant).
            to_rep: Target representation (same options as ``from_rep``).
        """
        if from_rep.startswith("euler_angles"):
            from_convention = from_rep.split("_")[-1]
            from_rep = "euler_angles"
            from_convention = from_convention.replace("r", "X").replace("p", "Y").replace("y", "Z")
        else:
            from_convention = None
        if to_rep.startswith("euler_angles"):
            to_convention = to_rep.split("_")[-1]
            to_rep = "euler_angles"
            to_convention = to_convention.replace("r", "X").replace("p", "Y").replace("y", "Z")
        else:
            to_convention = None
        assert from_rep != to_rep, f"from_rep and to_rep cannot be the same: {from_rep}"
        assert from_rep in self.valid_reps, f"Invalid from_rep: {from_rep}"
        assert to_rep in self.valid_reps, f"Invalid to_rep: {to_rep}"

        # Each pipeline is at most two pytorch3d conversion functions:
        # from_rep -> matrix, then matrix -> to_rep (skipping the identity legs).
        forward_funcs = list()
        inverse_funcs = list()
        if from_rep != "matrix":
            funcs = [getattr(pt, f"{from_rep}_to_matrix"), getattr(pt, f"matrix_to_{from_rep}")]
            if from_convention is not None:
                funcs = [functools.partial(func, convention=from_convention) for func in funcs]
            forward_funcs.append(funcs[0])
            inverse_funcs.append(funcs[1])
        if to_rep != "matrix":
            funcs = [getattr(pt, f"matrix_to_{to_rep}"), getattr(pt, f"{to_rep}_to_matrix")]
            if to_convention is not None:
                funcs = [functools.partial(func, convention=to_convention) for func in funcs]
            forward_funcs.append(funcs[0])
            inverse_funcs.append(funcs[1])
        # The inverse pipeline runs the reversed conversions in reverse order.
        inverse_funcs = inverse_funcs[::-1]

        self.forward_funcs = forward_funcs
        self.inverse_funcs = inverse_funcs

    @staticmethod
    def _apply_funcs(x: torch.Tensor, funcs: list) -> torch.Tensor:
        """Apply each conversion function in sequence."""
        assert isinstance(x, torch.Tensor)
        for func in funcs:
            x = func(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert ``x`` from the source representation to the target one."""
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"
        return self._apply_funcs(x, self.forward_funcs)

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """Convert ``x`` from the target representation back to the source one."""
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"
        return self._apply_funcs(x, self.inverse_funcs)


class Normalizer:
    """Normalize / unnormalize tensors using precomputed per-dimension statistics.

    Supported modes:
        - ``q99``: scale [q01, q99] to [-1, 1], clipped.
        - ``mean_std``: standard score ``(x - mean) / std``.
        - ``min_max``: scale [min, max] to [-1, 1].
        - ``scale``: divide by max(|min|, |max|), i.e. preserve zero.
        - ``binary``: threshold at 0.5.
    """

    # FIX: "scale" was handled by `forward` but missing here and in `inverse`.
    valid_modes = ["q99", "mean_std", "min_max", "scale", "binary"]

    def __init__(self, mode: str, statistics: dict):
        self.mode = mode
        # FIX: build a NEW dict instead of overwriting the caller's values in
        # place. The original mutated the passed-in dict, which corrupted the
        # shared class-level default statistics table
        # (StateActionTransform._DEFAULT_MIN_MAX_STATISTICS) on first use.
        self.statistics = {key: torch.as_tensor(value) for key, value in statistics.items()}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` according to ``self.mode``."""
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"
        # Normalize the tensor
        if self.mode == "q99":
            # Range of q99 is [-1, 1]
            q01 = self.statistics["q01"].to(x.dtype)
            q99 = self.statistics["q99"].to(x.dtype)
            # In the case of q01 == q99, the normalization will be undefined
            # So we set the normalized values to the original values
            mask = q01 != q99
            normalized = torch.zeros_like(x)
            # Normalize the values where q01 != q99
            # Formula: 2 * (x - q01) / (q99 - q01) - 1
            normalized[..., mask] = (x[..., mask] - q01[..., mask]) / (
                q99[..., mask] - q01[..., mask]
            )
            normalized[..., mask] = 2 * normalized[..., mask] - 1
            # Set the normalized values to the original values where q01 == q99
            normalized[..., ~mask] = x[..., ~mask].to(x.dtype)
            # Clip the normalized values to be between -1 and 1
            normalized = torch.clamp(normalized, -1, 1)
        elif self.mode == "mean_std":
            # Range of mean_std is not fixed, but can be positive or negative
            mean = self.statistics["mean"].to(x.dtype)
            std = self.statistics["std"].to(x.dtype)
            # In the case of std == 0, the normalization will be undefined
            # So we set the normalized values to the original values
            mask = std != 0
            normalized = torch.zeros_like(x)
            # Normalize the values where std != 0
            # Formula: (x - mean) / std
            normalized[..., mask] = (x[..., mask] - mean[..., mask]) / std[..., mask]
            # Set the normalized values to the original values where std == 0
            normalized[..., ~mask] = x[..., ~mask].to(x.dtype)
        elif self.mode == "min_max":
            # Range of min_max is [-1, 1]
            min_val = self.statistics["min"].to(x.dtype)
            max_val = self.statistics["max"].to(x.dtype)
            # In the case of min == max, the normalization will be undefined
            # So we set the normalized values to the original values
            mask = min_val != max_val
            normalized = torch.zeros_like(x)
            # Normalize the values where min != max
            # Formula: 2 * (x - min) / (max - min) - 1
            normalized[..., mask] = (x[..., mask] - min_val[..., mask]) / (
                max_val[..., mask] - min_val[..., mask]
            )
            normalized[..., mask] = 2 * normalized[..., mask] - 1
            # Set the normalized values to 0 where min == max
            normalized[..., ~mask] = 0
        elif self.mode == "scale":
            # Divide by the largest absolute bound, so zero stays zero
            min_val = self.statistics["min"].to(x.dtype)
            max_val = self.statistics["max"].to(x.dtype)
            abs_max = torch.max(torch.abs(min_val), torch.abs(max_val))
            mask = abs_max != 0
            normalized = torch.zeros_like(x)
            normalized[..., mask] = x[..., mask] / abs_max[..., mask]
            normalized[..., ~mask] = 0
        elif self.mode == "binary":
            # Range of binary is [0, 1]
            normalized = (x > 0.5).to(x.dtype)
        else:
            raise ValueError(f"Invalid normalization mode: {self.mode}")
        return normalized

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """Undo :meth:`forward`. ``binary`` is idempotent rather than invertible."""
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"
        if self.mode == "q99":
            q01 = self.statistics["q01"].to(x.dtype)
            q99 = self.statistics["q99"].to(x.dtype)
            return (x + 1) / 2 * (q99 - q01) + q01
        elif self.mode == "mean_std":
            mean = self.statistics["mean"].to(x.dtype)
            std = self.statistics["std"].to(x.dtype)
            return x * std + mean
        elif self.mode == "min_max":
            min_val = self.statistics["min"].to(x.dtype)
            max_val = self.statistics["max"].to(x.dtype)
            return (x + 1) / 2 * (max_val - min_val) + min_val
        elif self.mode == "scale":
            # FIX: `forward` supported "scale" but `inverse` raised ValueError.
            # Where abs_max == 0 the forward pass emitted 0, and x * 0 == 0,
            # so a plain multiply round-trips correctly.
            min_val = self.statistics["min"].to(x.dtype)
            max_val = self.statistics["max"].to(x.dtype)
            abs_max = torch.max(torch.abs(min_val), torch.abs(max_val))
            return x * abs_max
        elif self.mode == "binary":
            return (x > 0.5).to(x.dtype)
        else:
            raise ValueError(f"Invalid normalization mode: {self.mode}")


class StateActionToTensor(InvertibleModalityTransform):
    """Transforms states and actions between numpy arrays and torch tensors."""

    # Original numpy dtypes per key, used to restore arrays in `unapply`.
    input_dtypes: dict[str, np.dtype] = Field(
        default_factory=dict, description="The input dtypes for each state key."
    )
    # Target torch dtypes per key, applied after conversion in `apply`.
    output_dtypes: dict[str, torch.dtype] = Field(
        default_factory=dict, description="The output dtypes for each state key."
    )

    def model_dump(self, *args, **kwargs):
        """Serialize only `apply_to` in JSON mode; full dump otherwise."""
        if kwargs.get("mode", "python") == "json":
            include = {"apply_to"}
        else:
            include = kwargs.pop("include", None)
        return super().model_dump(*args, include=include, **kwargs)

    @field_validator("input_dtypes", "output_dtypes", mode="before")
    def validate_dtypes(cls, v):
        """Accept "torch.*" / "np.*" / "numpy.*" dtype strings and resolve them."""
        for key, dtype in v.items():
            if isinstance(dtype, str):
                if dtype.startswith("torch."):
                    dtype_split = dtype.split(".")[-1]
                    v[key] = getattr(torch, dtype_split)
                elif dtype.startswith("np.") or dtype.startswith("numpy."):
                    dtype_split = dtype.split(".")[-1]
                    v[key] = np.dtype(dtype_split)
                else:
                    raise ValueError(f"Invalid dtype: {dtype}")
        return v

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Convert each present key from np.ndarray to torch.Tensor (casting if configured)."""
        for key in self.apply_to:
            if key not in data:
                continue
            value = data[key]
            assert isinstance(
                value, np.ndarray
            ), f"Unexpected input type: {type(value)}. Expected type: {np.ndarray}"
            data[key] = torch.from_numpy(value)
            if key in self.output_dtypes:
                data[key] = data[key].to(self.output_dtypes[key])
        return data

    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Convert each present key from torch.Tensor back to np.ndarray."""
        for key in self.apply_to:
            if key not in data:
                continue
            value = data[key]
            assert isinstance(
                value, torch.Tensor
            ), f"Unexpected input type: {type(value)}. Expected type: {torch.Tensor}"
            data[key] = value.numpy()
            if key in self.input_dtypes:
                data[key] = data[key].astype(self.input_dtypes[key])
        return data


class StateActionTransform(InvertibleModalityTransform):
    """Class for state or action transform.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
        normalization_modes (dict[str, str]): The normalization modes for each state key.
            If a state key in apply_to is not present in the dictionary, it will not be
            normalized.
        target_rotations (dict[str, str]): The target representations for each state key.
            If a state key in apply_to is not present in the dictionary, it will not be
            rotated.
    """

    # Configurable attributes
    apply_to: list[str] = Field(..., description="The keys in the modality to load and transform.")
    normalization_modes: dict[str, str] = Field(
        default_factory=dict, description="The normalization modes for each state key."
    )
    target_rotations: dict[str, str] = Field(
        default_factory=dict, description="The target representations for each state key."
    )
    normalization_statistics: dict[str, dict] = Field(
        default_factory=dict, description="The statistics for each state key."
    )
    modality_metadata: dict[str, StateActionMetadata] = Field(
        default_factory=dict, description="The modality metadata for each state key."
    )

    # Model variables (populated in set_metadata / apply)
    _rotation_transformers: dict[str, RotationTransform] = PrivateAttr(default_factory=dict)
    _normalizers: dict[str, Normalizer] = PrivateAttr(default_factory=dict)
    _input_dtypes: dict[str, np.dtype | torch.dtype] = PrivateAttr(default_factory=dict)

    # Model constants: known bounds per rotation representation, used to
    # min_max-normalize absolute rotations without dataset statistics.
    _DEFAULT_MIN_MAX_STATISTICS: ClassVar[dict] = {
        "rotation_6d": {
            "min": [-1, -1, -1, -1, -1, -1],
            "max": [1, 1, 1, 1, 1, 1],
        },
        "euler_angles": {
            "min": [-np.pi, -np.pi, -np.pi],
            "max": [np.pi, np.pi, np.pi],
        },
        "quaternion": {
            "min": [-1, -1, -1, -1],
            "max": [1, 1, 1, 1],
        },
        "axis_angle": {
            "min": [-np.pi, -np.pi, -np.pi],
            "max": [np.pi, np.pi, np.pi],
        },
    }

    def model_dump(self, *args, **kwargs):
        """Serialize only the user-facing config in JSON mode; full dump otherwise."""
        if kwargs.get("mode", "python") == "json":
            include = {"apply_to", "normalization_modes", "target_rotations"}
        else:
            include = kwargs.pop("include", None)
        return super().model_dump(*args, include=include, **kwargs)

    @field_validator("modality_metadata", mode="before")
    def validate_modality_metadata(cls, v):
        """Coerce dict entries into StateActionMetadata instances."""
        for modality_key, config in v.items():
            if isinstance(config, dict):
                config = StateActionMetadata.model_validate(config)
            else:
                assert isinstance(
                    config, StateActionMetadata
                ), f"Invalid source rotation config: {config}"
            v[modality_key] = config
        return v

    @model_validator(mode="after")
    def validate_normalization_statistics(self):
        """Check each configured normalization mode has the statistics it requires."""
        for modality_key, normalization_statistics in self.normalization_statistics.items():
            if modality_key in self.normalization_modes:
                normalization_mode = self.normalization_modes[modality_key]
                if normalization_mode == "min_max":
                    assert (
                        "min" in normalization_statistics and "max" in normalization_statistics
                    ), f"Min and max statistics are required for min_max normalization, but got {normalization_statistics}"
                    assert len(normalization_statistics["min"]) == len(
                        normalization_statistics["max"]
                    ), f"Min and max statistics must have the same length, but got {normalization_statistics['min']} and {normalization_statistics['max']}"
                elif normalization_mode == "scale":
                    # FIX: "scale" is supported by Normalizer but was rejected
                    # here; it needs the same min/max statistics as min_max.
                    assert (
                        "min" in normalization_statistics and "max" in normalization_statistics
                    ), f"Min and max statistics are required for scale normalization, but got {normalization_statistics}"
                    assert len(normalization_statistics["min"]) == len(
                        normalization_statistics["max"]
                    ), f"Min and max statistics must have the same length, but got {normalization_statistics['min']} and {normalization_statistics['max']}"
                elif normalization_mode == "mean_std":
                    assert (
                        "mean" in normalization_statistics and "std" in normalization_statistics
                    ), f"Mean and std statistics are required for mean_std normalization, but got {normalization_statistics}"
                    assert len(normalization_statistics["mean"]) == len(
                        normalization_statistics["std"]
                    ), f"Mean and std statistics must have the same length, but got {normalization_statistics['mean']} and {normalization_statistics['std']}"
                elif normalization_mode == "q99":
                    assert (
                        "q01" in normalization_statistics and "q99" in normalization_statistics
                    ), f"q01 and q99 statistics are required for q99 normalization, but got {normalization_statistics}"
                    assert len(normalization_statistics["q01"]) == len(
                        normalization_statistics["q99"]
                    ), f"q01 and q99 statistics must have the same length, but got {normalization_statistics['q01']} and {normalization_statistics['q99']}"
                elif normalization_mode == "binary":
                    # NOTE(review): `normalization_statistics` is a dict, so
                    # `normalization_statistics[0]` looks up the key `0` rather
                    # than the first element — confirm this is the intended
                    # shape for binary statistics.
                    assert (
                        len(normalization_statistics) == 1
                    ), f"Binary normalization should only have one value, but got {normalization_statistics}"
                    assert normalization_statistics[0] in [
                        0,
                        1,
                    ], f"Binary normalization should only have 0 or 1, but got {normalization_statistics[0]}"
                else:
                    raise ValueError(f"Invalid normalization mode: {normalization_mode}")
        return self

    def set_metadata(self, dataset_metadata: DatasetMetadata):
        """Resolve metadata and statistics, then build rotation transformers and normalizers.

        Args:
            dataset_metadata: Dataset-level metadata providing per-key statistics
                and modality configs.
        """
        dataset_statistics = dataset_metadata.statistics
        modality_metadata = dataset_metadata.modalities
        self._resolve_modality_metadata(modality_metadata)
        self._resolve_normalization_statistics(dataset_statistics, modality_metadata)
        self._init_rotation_transformers()
        self._init_normalizers()

    def _resolve_modality_metadata(self, modality_metadata):
        """Ensure every key in apply_to has its StateActionMetadata resolved."""
        for key in self.apply_to:
            split_key = key.split(".", 1)
            assert len(split_key) == 2, "State keys should have two parts: 'modality.key'"
            if key not in self.modality_metadata:
                modality, state_key = split_key
                assert hasattr(modality_metadata, modality), f"{modality} config not found"
                assert state_key in getattr(
                    modality_metadata, modality
                ), f"{state_key} config not found"
                self.modality_metadata[key] = getattr(modality_metadata, modality)[state_key]

    def _resolve_normalization_statistics(self, dataset_statistics, modality_metadata):
        """Pull dataset statistics for every key that has a normalization mode."""
        for key in self.normalization_modes:
            split_key = key.split(".", 1)
            assert len(split_key) == 2, "State keys should have two parts: 'modality.key'"
            modality, state_key = split_key
            assert hasattr(dataset_statistics, modality), f"{modality} statistics not found"
            assert state_key in getattr(
                dataset_statistics, modality
            ), f"{state_key} statistics not found"
            # Only 1-D (flat) state/action vectors are supported here.
            assert (
                len(getattr(modality_metadata, modality)[state_key].shape) == 1
            ), f"{getattr(modality_metadata, modality)[state_key].shape=}"
            self.normalization_statistics[key] = getattr(dataset_statistics, modality)[
                state_key
            ].model_dump()

    def _init_rotation_transformers(self):
        """Create a RotationTransform for every key whose representation changes."""
        for key in self.target_rotations:
            # Get the original representation of the state
            from_rep = self.modality_metadata[key].rotation_type
            assert from_rep is not None, f"Source rotation type not found for {key}"
            # Get the target representation of the state; raises if invalid
            to_rep = RotationType(self.target_rotations[key])
            # Only build a transformer when a conversion is actually needed
            if from_rep != to_rep:
                self._rotation_transformers[key] = RotationTransform(
                    from_rep=from_rep.value, to_rep=to_rep.value
                )

    def _init_normalizers(self):
        """Create a Normalizer for every key that has a normalization mode.

        Rotations need special handling:
        - Absolute rotations are normalized with min_max using bounds implied by
          the target representation.
        - Relative rotations cannot be normalized (unknown bounds) and raise.
        """
        for key in self.normalization_modes:
            if key in self._rotation_transformers:
                # Case 1: Absolute rotation
                if self.modality_metadata[key].absolute:
                    assert (
                        self.normalization_modes[key] == "min_max"
                    ), "Absolute rotations that are converted to other formats must be normalized using `min_max` mode"
                    rotation_type = RotationType(self.target_rotations[key]).value
                    # Euler variants share the same bounds regardless of convention
                    if rotation_type.startswith("euler_angles"):
                        rotation_type = "euler_angles"
                    statistics = self._DEFAULT_MIN_MAX_STATISTICS[rotation_type]
                # Case 2: Relative rotation
                else:
                    raise ValueError(
                        f"Cannot normalize relative rotations: {key} that's converted to {self.target_rotations[key]}"
                    )
            # Non-continuous states may only use binary normalization
            elif (
                not self.modality_metadata[key].continuous
                and self.normalization_modes[key] != "binary"
            ):
                raise ValueError(
                    f"{key} is not continuous, so it should be normalized using `binary` mode"
                )
            else:
                statistics = self.normalization_statistics[key]
            self._normalizers[key] = Normalizer(
                mode=self.normalization_modes[key], statistics=statistics
            )

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Rotate then normalize each present key, recording input dtypes for unapply."""
        for key in self.apply_to:
            if key not in data:
                # We allow some keys to be missing in the data, and only process
                # the keys that are present
                continue
            if key not in self._input_dtypes:
                input_dtype = data[key].dtype
                assert isinstance(
                    input_dtype, torch.dtype
                ), f"Unexpected input dtype: {input_dtype}. Expected type: {torch.dtype}"
                self._input_dtypes[key] = input_dtype
            else:
                assert (
                    data[key].dtype == self._input_dtypes[key]
                ), f"All states corresponding to the same key must be of the same dtype, input dtype: {data[key].dtype}, expected dtype: {self._input_dtypes[key]}"
            # Rotate the state
            state = data[key]
            if key in self._rotation_transformers:
                state = self._rotation_transformers[key].forward(state)
            # Normalize the state
            if key in self._normalizers:
                state = self._normalizers[key].forward(state)
            data[key] = state
        return data

    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Unnormalize and rotate back each present key, restoring the original dtype."""
        for key in self.apply_to:
            if key not in data:
                continue
            state = data[key]
            assert isinstance(
                state, torch.Tensor
            ), f"Unexpected state type: {type(state)}. Expected type: {torch.Tensor}"
            # Unnormalize the state
            if key in self._normalizers:
                state = self._normalizers[key].inverse(state)
            # Change the state back to its original representation
            if key in self._rotation_transformers:
                state = self._rotation_transformers[key].inverse(state)
            assert isinstance(
                state, torch.Tensor
            ), f"State should be tensor after unapplying transformations, but got {type(state)}"
            # Only convert back to the original dtype if it's known, i.e. `apply`
            # was called before. If not, we don't know it, so we don't convert.
            if key in self._input_dtypes:
                original_dtype = self._input_dtypes[key]
                if isinstance(original_dtype, np.dtype):
                    state = state.numpy().astype(original_dtype)
                elif isinstance(original_dtype, torch.dtype):
                    state = state.to(original_dtype)
                else:
                    raise ValueError(f"Invalid input dtype: {original_dtype}")
            data[key] = state
        return data


class StateActionPerturbation(ModalityTransform):
    """Class for state or action perturbation.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
        std (float): Standard deviation of the noise to be added to the state or action.
    """

    # Configurable attributes
    std: float = Field(
        ..., description="Standard deviation of the noise to be added to the state or action."
    )

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Add Gaussian noise (training only), clipped to the pre-noise value range."""
        if not self.training:
            # Don't perturb the data in eval mode
            return data
        if self.std < 0:
            # If the std is negative, we don't add any noise
            return data
        for key in self.apply_to:
            state = data[key]
            assert isinstance(state, torch.Tensor)
            transformed_data_min = torch.min(state)
            transformed_data_max = torch.max(state)
            # FIX: out-of-place add — the original `state += noise` mutated the
            # input tensor in place, which also changed any other references to
            # it; data[key] is replaced by the assignment below anyway.
            noise = torch.randn_like(state) * self.std
            state = state + noise
            # Clip to the original range
            state = torch.clamp(state, transformed_data_min, transformed_data_max)
            data[key] = state
        return data


class StateActionDropout(ModalityTransform):
    """Class for state or action dropout.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
        dropout_prob (float): Probability of dropping out a state or action.
    """

    # Configurable attributes
    dropout_prob: float = Field(..., description="Probability of dropping out a state or action.")

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Zero out ALL configured keys together with probability dropout_prob (training only)."""
        if not self.training:
            # Don't drop out the data in eval mode
            return data
        if self.dropout_prob < 0:
            # If the dropout probability is negative, we don't drop out any states
            return data
        if self.dropout_prob > 1e-9 and random.random() < self.dropout_prob:
            for key in self.apply_to:
                state = data[key]
                assert isinstance(state, torch.Tensor)
                state = torch.zeros_like(state)
                data[key] = state
        return data


class StateActionSinCosTransform(ModalityTransform):
    """Class for state or action sin-cos transform.

    Replaces each configured tensor with [sin(x), cos(x)] concatenated along the
    last dimension (doubling its size), a common encoding for angles.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
    """

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        for key in self.apply_to:
            state = data[key]
            assert isinstance(state, torch.Tensor)
            sin_state = torch.sin(state)
            cos_state = torch.cos(state)
            data[key] = torch.cat([sin_state, cos_state], dim=-1)
        return data