| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import functools |
| import random |
| from typing import Any, ClassVar |
|
|
| import numpy as np |
| import pytorch3d.transforms as pt |
| import torch |
| from pydantic import Field, PrivateAttr, field_validator, model_validator |
|
|
| from ..schema import DatasetMetadata, RotationType, StateActionMetadata |
| from .base import InvertibleModalityTransform, ModalityTransform |
|
|
|
|
class RotationTransform:
    """Adapted from https://github.com/real-stanford/diffusion_policy/blob/548a52bbb105518058e27bf34dcf90bf6f73681a/diffusion_policy/model/common/rotation_transformer.py"""

    valid_reps = ["axis_angle", "euler_angles", "quaternion", "rotation_6d", "matrix"]

    def __init__(self, from_rep="axis_angle", to_rep="rotation_6d"):
        """
        Valid representations

        Always use matrix as intermediate representation.
        """
        from_rep, from_convention = self._split_euler_convention(from_rep)
        to_rep, to_convention = self._split_euler_convention(to_rep)
        assert from_rep != to_rep, f"from_rep and to_rep cannot be the same: {from_rep}"
        assert from_rep in self.valid_reps, f"Invalid from_rep: {from_rep}"
        assert to_rep in self.valid_reps, f"Invalid to_rep: {to_rep}"

        forward_funcs = []
        inverse_funcs = []

        # Leg 1: source representation -> matrix (and matrix -> source for the inverse).
        if from_rep != "matrix":
            encode = getattr(pt, f"{from_rep}_to_matrix")
            decode = getattr(pt, f"matrix_to_{from_rep}")
            if from_convention is not None:
                encode = functools.partial(encode, convention=from_convention)
                decode = functools.partial(decode, convention=from_convention)
            forward_funcs.append(encode)
            inverse_funcs.append(decode)

        # Leg 2: matrix -> target representation (and target -> matrix for the inverse).
        if to_rep != "matrix":
            encode = getattr(pt, f"matrix_to_{to_rep}")
            decode = getattr(pt, f"{to_rep}_to_matrix")
            if to_convention is not None:
                encode = functools.partial(encode, convention=to_convention)
                decode = functools.partial(decode, convention=to_convention)
            forward_funcs.append(encode)
            inverse_funcs.append(decode)

        self.forward_funcs = forward_funcs
        # The inverse pipeline retraces the legs in the opposite order.
        self.inverse_funcs = list(reversed(inverse_funcs))

    @staticmethod
    def _split_euler_convention(rep):
        """Normalize e.g. 'euler_angles_rpy' to ('euler_angles', 'XYZ'); other reps pass through with convention None."""
        if not rep.startswith("euler_angles"):
            return rep, None
        convention = rep.split("_")[-1]
        convention = convention.replace("r", "X").replace("p", "Y").replace("y", "Z")
        return "euler_angles", convention

    @staticmethod
    def _apply_funcs(x: torch.Tensor, funcs: list) -> torch.Tensor:
        # Thread the tensor through each conversion function in sequence.
        assert isinstance(x, torch.Tensor)
        return functools.reduce(lambda acc, fn: fn(acc), funcs, x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert `x` from the source rotation representation to the target one."""
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"
        return self._apply_funcs(x, self.forward_funcs)

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """Convert `x` from the target rotation representation back to the source one."""
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"
        return self._apply_funcs(x, self.inverse_funcs)
|
|
|
|
class Normalizer:
    """Normalize/denormalize tensors using pre-computed per-dimension statistics.

    Supported modes:
        q99:      map [q01, q99] to [-1, 1], clamping outliers into that range.
        mean_std: subtract mean, divide by std.
        min_max:  map [min, max] to [-1, 1].
        scale:    divide by max(|min|, |max|) (preserves zero, no shift).
        binary:   threshold at 0.5 to {0, 1}.

    Degenerate dimensions (zero range / zero std) are passed through unchanged
    (q99, mean_std) or zeroed (min_max, scale) to avoid division by zero.
    """

    # Fix: "scale" was accepted by forward() but missing here (and from inverse()).
    valid_modes = ["q99", "mean_std", "min_max", "scale", "binary"]

    def __init__(self, mode: str, statistics: dict):
        """
        Args:
            mode: One of `valid_modes`.
            statistics: Mapping from statistic name (e.g. "mean", "q01") to values;
                values are converted to tensors.

        Raises:
            ValueError: If `mode` is not a supported normalization mode.
        """
        if mode not in self.valid_modes:
            raise ValueError(f"Invalid normalization mode: {mode}")
        self.mode = mode
        # Build a fresh dict so the caller's statistics dict is not mutated in place
        # (the original overwrote the caller's values with tensors).
        self.statistics = {key: torch.as_tensor(value) for key, value in statistics.items()}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize `x` according to `self.mode`; statistics index the last dimension."""
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"

        if self.mode == "q99":
            q01 = self.statistics["q01"].to(x.dtype)
            q99 = self.statistics["q99"].to(x.dtype)

            # Scale only dimensions with a non-degenerate percentile range.
            mask = q01 != q99
            normalized = torch.zeros_like(x)
            normalized[..., mask] = (x[..., mask] - q01[..., mask]) / (
                q99[..., mask] - q01[..., mask]
            )
            normalized[..., mask] = 2 * normalized[..., mask] - 1

            # Degenerate dimensions pass through unchanged.
            normalized[..., ~mask] = x[..., ~mask].to(x.dtype)

            # Clamp outliers beyond the 1st/99th percentiles.
            normalized = torch.clamp(normalized, -1, 1)

        elif self.mode == "mean_std":
            mean = self.statistics["mean"].to(x.dtype)
            std = self.statistics["std"].to(x.dtype)

            # Standardize only dimensions with non-zero std; others pass through.
            mask = std != 0
            normalized = torch.zeros_like(x)
            normalized[..., mask] = (x[..., mask] - mean[..., mask]) / std[..., mask]
            normalized[..., ~mask] = x[..., ~mask].to(x.dtype)

        elif self.mode == "min_max":
            lo = self.statistics["min"].to(x.dtype)
            hi = self.statistics["max"].to(x.dtype)

            # Scale only dimensions with a non-degenerate [min, max] range.
            mask = lo != hi
            normalized = torch.zeros_like(x)
            normalized[..., mask] = (x[..., mask] - lo[..., mask]) / (
                hi[..., mask] - lo[..., mask]
            )
            normalized[..., mask] = 2 * normalized[..., mask] - 1

            # Constant dimensions map to 0, the center of [-1, 1].
            normalized[..., ~mask] = 0

        elif self.mode == "scale":
            lo = self.statistics["min"].to(x.dtype)
            hi = self.statistics["max"].to(x.dtype)
            abs_max = torch.max(torch.abs(lo), torch.abs(hi))
            mask = abs_max != 0
            normalized = torch.zeros_like(x)
            normalized[..., mask] = x[..., mask] / abs_max[..., mask]
            normalized[..., ~mask] = 0

        elif self.mode == "binary":
            normalized = (x > 0.5).to(x.dtype)
        else:
            raise ValueError(f"Invalid normalization mode: {self.mode}")

        return normalized

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """Map a normalized tensor back toward the original scale.

        Note: dimensions that were passed through or zeroed in `forward` because of
        degenerate statistics are still run through the plain affine inverse here,
        matching the original behavior.
        """
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"
        if self.mode == "q99":
            q01 = self.statistics["q01"].to(x.dtype)
            q99 = self.statistics["q99"].to(x.dtype)
            return (x + 1) / 2 * (q99 - q01) + q01
        elif self.mode == "mean_std":
            mean = self.statistics["mean"].to(x.dtype)
            std = self.statistics["std"].to(x.dtype)
            return x * std + mean
        elif self.mode == "min_max":
            lo = self.statistics["min"].to(x.dtype)
            hi = self.statistics["max"].to(x.dtype)
            return (x + 1) / 2 * (hi - lo) + lo
        elif self.mode == "scale":
            # Fix: forward() supported "scale" but inverse() previously raised.
            lo = self.statistics["min"].to(x.dtype)
            hi = self.statistics["max"].to(x.dtype)
            abs_max = torch.max(torch.abs(lo), torch.abs(hi))
            return x * abs_max
        elif self.mode == "binary":
            return (x > 0.5).to(x.dtype)
        else:
            raise ValueError(f"Invalid normalization mode: {self.mode}")
|
|
|
|
class StateActionToTensor(InvertibleModalityTransform):
    """
    Transforms states and actions to tensors.
    """

    # Original numpy dtype per key; used by `unapply` to restore arrays.
    input_dtypes: dict[str, np.dtype] = Field(
        default_factory=dict, description="The input dtypes for each state key."
    )
    # Torch dtype to cast each key to after conversion in `apply`.
    output_dtypes: dict[str, torch.dtype] = Field(
        default_factory=dict, description="The output dtypes for each state key."
    )

    def model_dump(self, *args, **kwargs):
        """Serialize the transform; in JSON mode only `apply_to` is included (the dtype fields are omitted)."""
        if kwargs.get("mode", "python") == "json":
            include = {"apply_to"}
        else:
            include = kwargs.pop("include", None)

        return super().model_dump(*args, include=include, **kwargs)

    @field_validator("input_dtypes", "output_dtypes", mode="before")
    def validate_dtypes(cls, v):
        """Coerce string dtype specs like 'torch.float32' or 'np.float32' into real dtype objects.

        Raises:
            ValueError: For a string dtype without a 'torch.'/'np.'/'numpy.' prefix.
        """
        for key, dtype in v.items():
            if isinstance(dtype, str):
                if dtype.startswith("torch."):
                    dtype_split = dtype.split(".")[-1]
                    v[key] = getattr(torch, dtype_split)
                elif dtype.startswith("np.") or dtype.startswith("numpy."):
                    dtype_split = dtype.split(".")[-1]
                    v[key] = np.dtype(dtype_split)
                else:
                    raise ValueError(f"Invalid dtype: {dtype}")
        return v

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Convert each configured numpy array in `data` to a torch tensor in place.

        Keys missing from `data` are skipped; keys with a configured output dtype are cast.
        Note: `torch.from_numpy` shares memory with the source array.
        """
        for key in self.apply_to:
            if key not in data:
                continue
            value = data[key]
            assert isinstance(
                value, np.ndarray
            ), f"Unexpected input type: {type(value)}. Expected type: {np.ndarray}"
            data[key] = torch.from_numpy(value)
            if key in self.output_dtypes:
                data[key] = data[key].to(self.output_dtypes[key])
        return data

    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Convert each configured tensor back to a numpy array, restoring the
        configured input dtype when one is present. Keys missing from `data` are skipped.
        """
        for key in self.apply_to:
            if key not in data:
                continue
            value = data[key]
            assert isinstance(
                value, torch.Tensor
            ), f"Unexpected input type: {type(value)}. Expected type: {torch.Tensor}"
            data[key] = value.numpy()
            if key in self.input_dtypes:
                data[key] = data[key].astype(self.input_dtypes[key])
        return data
|
|
|
|
class StateActionTransform(InvertibleModalityTransform):
    """
    Class for state or action transform.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
        normalization_modes (dict[str, str]): The normalization modes for each state key.
            If a state key in apply_to is not present in the dictionary, it will not be normalized.
        target_rotations (dict[str, str]): The target representations for each state key.
            If a state key in apply_to is not present in the dictionary, it will not be rotated.
    """

    # Public, serializable configuration.
    apply_to: list[str] = Field(..., description="The keys in the modality to load and transform.")
    normalization_modes: dict[str, str] = Field(
        default_factory=dict, description="The normalization modes for each state key."
    )
    target_rotations: dict[str, str] = Field(
        default_factory=dict, description="The target representations for each state key."
    )
    normalization_statistics: dict[str, dict] = Field(
        default_factory=dict, description="The statistics for each state key."
    )
    modality_metadata: dict[str, StateActionMetadata] = Field(
        default_factory=dict, description="The modality metadata for each state key."
    )

    # Runtime helpers built in `set_metadata` (not serialized).
    _rotation_transformers: dict[str, RotationTransform] = PrivateAttr(default_factory=dict)
    _normalizers: dict[str, Normalizer] = PrivateAttr(default_factory=dict)
    # First-seen dtype per key, recorded in `apply` and used by `unapply` to restore it.
    _input_dtypes: dict[str, np.dtype | torch.dtype] = PrivateAttr(default_factory=dict)

    # Fallback min/max statistics for normalizing *converted* rotations: dataset
    # statistics are collected in the source representation, so after a rotation
    # conversion the representation's theoretical bounds are used instead.
    _DEFAULT_MIN_MAX_STATISTICS: ClassVar[dict] = {
        "rotation_6d": {
            "min": [-1, -1, -1, -1, -1, -1],
            "max": [1, 1, 1, 1, 1, 1],
        },
        "euler_angles": {
            "min": [-np.pi, -np.pi, -np.pi],
            "max": [np.pi, np.pi, np.pi],
        },
        "quaternion": {
            "min": [-1, -1, -1, -1],
            "max": [1, 1, 1, 1],
        },
        "axis_angle": {
            "min": [-np.pi, -np.pi, -np.pi],
            "max": [np.pi, np.pi, np.pi],
        },
    }

    def model_dump(self, *args, **kwargs):
        """Serialize the transform; in JSON mode only the user-facing config keys are kept."""
        if kwargs.get("mode", "python") == "json":
            include = {"apply_to", "normalization_modes", "target_rotations"}
        else:
            include = kwargs.pop("include", None)

        return super().model_dump(*args, include=include, **kwargs)

    @field_validator("modality_metadata", mode="before")
    def validate_modality_metadata(cls, v):
        """Coerce plain dict entries into `StateActionMetadata` instances."""
        for modality_key, config in v.items():
            if isinstance(config, dict):
                config = StateActionMetadata.model_validate(config)
            else:
                assert isinstance(
                    config, StateActionMetadata
                ), f"Invalid source rotation config: {config}"
            v[modality_key] = config
        return v

    @model_validator(mode="after")
    def validate_normalization_statistics(self):
        """Check that provided statistics match the requested normalization modes."""
        for modality_key, normalization_statistics in self.normalization_statistics.items():
            if modality_key in self.normalization_modes:
                normalization_mode = self.normalization_modes[modality_key]
                if normalization_mode == "min_max":
                    assert (
                        "min" in normalization_statistics and "max" in normalization_statistics
                    ), f"Min and max statistics are required for min_max normalization, but got {normalization_statistics}"
                    assert len(normalization_statistics["min"]) == len(
                        normalization_statistics["max"]
                    ), f"Min and max statistics must have the same length, but got {normalization_statistics['min']} and {normalization_statistics['max']}"
                elif normalization_mode == "mean_std":
                    assert (
                        "mean" in normalization_statistics and "std" in normalization_statistics
                    ), f"Mean and std statistics are required for mean_std normalization, but got {normalization_statistics}"
                    assert len(normalization_statistics["mean"]) == len(
                        normalization_statistics["std"]
                    ), f"Mean and std statistics must have the same length, but got {normalization_statistics['mean']} and {normalization_statistics['std']}"
                elif normalization_mode == "q99":
                    assert (
                        "q01" in normalization_statistics and "q99" in normalization_statistics
                    ), f"q01 and q99 statistics are required for q99 normalization, but got {normalization_statistics}"
                    assert len(normalization_statistics["q01"]) == len(
                        normalization_statistics["q99"]
                    ), f"q01 and q99 statistics must have the same length, but got {normalization_statistics['q01']} and {normalization_statistics['q99']}"
                elif normalization_mode == "binary":
                    assert (
                        len(normalization_statistics) == 1
                    ), f"Binary normalization should only have one value, but got {normalization_statistics}"
                    # NOTE(review): the field is typed dict[str, dict], yet this indexes with 0 —
                    # it assumes binary statistics arrive as a single-element sequence (or a dict
                    # keyed by 0). Confirm the upstream format; a plain dict would raise KeyError.
                    assert normalization_statistics[0] in [
                        0,
                        1,
                    ], f"Binary normalization should only have 0 or 1, but got {normalization_statistics[0]}"
                else:
                    raise ValueError(f"Invalid normalization mode: {normalization_mode}")
        return self

    def set_metadata(self, dataset_metadata: DatasetMetadata):
        """Bind this transform to a dataset: resolve per-key metadata and statistics,
        then build the rotation transformers and normalizers used by `apply`/`unapply`.
        """
        dataset_statistics = dataset_metadata.statistics
        modality_metadata = dataset_metadata.modalities

        # Resolve modality metadata for every key that was not configured explicitly.
        for key in self.apply_to:
            split_key = key.split(".", 1)
            assert len(split_key) == 2, "State keys should have two parts: 'modality.key'"
            if key not in self.modality_metadata:
                modality, state_key = split_key
                assert hasattr(modality_metadata, modality), f"{modality} config not found"
                assert state_key in getattr(
                    modality_metadata, modality
                ), f"{state_key} config not found"
                self.modality_metadata[key] = getattr(modality_metadata, modality)[state_key]

        # Pull dataset statistics for every key that will be normalized
        # (only 1-D state/action vectors are supported).
        for key in self.normalization_modes:
            split_key = key.split(".", 1)
            assert len(split_key) == 2, "State keys should have two parts: 'modality.key'"
            modality, state_key = split_key
            assert hasattr(dataset_statistics, modality), f"{modality} statistics not found"
            assert state_key in getattr(
                dataset_statistics, modality
            ), f"{state_key} statistics not found"
            assert (
                len(getattr(modality_metadata, modality)[state_key].shape) == 1
            ), f"{getattr(modality_metadata, modality)[state_key].shape=}"
            self.normalization_statistics[key] = getattr(dataset_statistics, modality)[
                state_key
            ].model_dump()

        # Build a rotation transformer for each key whose source and target reps differ.
        for key in self.target_rotations:
            from_rep = self.modality_metadata[key].rotation_type
            assert from_rep is not None, f"Source rotation type not found for {key}"

            to_rep = RotationType(self.target_rotations[key])

            if from_rep != to_rep:
                self._rotation_transformers[key] = RotationTransform(
                    from_rep=from_rep.value, to_rep=to_rep.value
                )

        # Build a normalizer for each key with a configured normalization mode.
        for key in self.normalization_modes:
            modality, state_key = key.split(".", 1)
            if key in self._rotation_transformers:
                # Converted rotations cannot use the dataset statistics directly:
                # those were gathered in the *source* representation.
                if self.modality_metadata[key].absolute:
                    assert (
                        self.normalization_modes[key] == "min_max"
                    ), "Absolute rotations that are converted to other formats must be normalized using `min_max` mode"
                    rotation_type = RotationType(self.target_rotations[key]).value
                    # All euler variants share the same bounds regardless of convention.
                    if rotation_type.startswith("euler_angles"):
                        rotation_type = "euler_angles"
                    # Use the representation's theoretical bounds instead.
                    statistics = self._DEFAULT_MIN_MAX_STATISTICS[rotation_type]
                else:
                    raise ValueError(
                        f"Cannot normalize relative rotations: {key} that's converted to {self.target_rotations[key]}"
                    )
            elif (
                not self.modality_metadata[key].continuous
                and self.normalization_modes[key] != "binary"
            ):
                raise ValueError(
                    f"{key} is not continuous, so it should be normalized using `binary` mode"
                )
            else:
                statistics = self.normalization_statistics[key]
            self._normalizers[key] = Normalizer(
                mode=self.normalization_modes[key], statistics=statistics
            )

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply rotation conversion, then normalization, to each configured key in `data`.

        Also records the dtype first seen for each key so `unapply` can restore it;
        subsequent calls must supply the same dtype. Keys missing from `data` are skipped.
        """
        for key in self.apply_to:
            if key not in data:
                continue
            if key not in self._input_dtypes:
                # First time we see this key: remember its dtype for `unapply`.
                input_dtype = data[key].dtype
                assert isinstance(
                    input_dtype, torch.dtype
                ), f"Unexpected input dtype: {input_dtype}. Expected type: {torch.dtype}"
                self._input_dtypes[key] = input_dtype
            else:
                assert (
                    data[key].dtype == self._input_dtypes[key]
                ), f"All states corresponding to the same key must be of the same dtype, input dtype: {data[key].dtype}, expected dtype: {self._input_dtypes[key]}"
            state = data[key]
            if key in self._rotation_transformers:
                state = self._rotation_transformers[key].forward(state)
            if key in self._normalizers:
                state = self._normalizers[key].forward(state)
            data[key] = state
        return data

    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Invert `apply`: denormalize, convert rotations back, and restore the dtype
        recorded during `apply`. Keys missing from `data` are skipped.
        """
        for key in self.apply_to:
            if key not in data:
                continue
            state = data[key]
            assert isinstance(
                state, torch.Tensor
            ), f"Unexpected state type: {type(state)}. Expected type: {torch.Tensor}"
            # Undo the forward pipeline in reverse order.
            if key in self._normalizers:
                state = self._normalizers[key].inverse(state)
            if key in self._rotation_transformers:
                state = self._rotation_transformers[key].inverse(state)
            assert isinstance(
                state, torch.Tensor
            ), f"State should be tensor after unapplying transformations, but got {type(state)}"
            # Restore the original dtype. NOTE(review): `apply` only ever records torch
            # dtypes, so the np.dtype branch presumably covers dtypes injected
            # externally into `_input_dtypes` — confirm against callers.
            if key in self._input_dtypes:
                original_dtype = self._input_dtypes[key]
                if isinstance(original_dtype, np.dtype):
                    state = state.numpy().astype(original_dtype)
                elif isinstance(original_dtype, torch.dtype):
                    state = state.to(original_dtype)
                else:
                    raise ValueError(f"Invalid input dtype: {original_dtype}")
            data[key] = state
        return data
|
|
|
|
class StateActionPerturbation(ModalityTransform):
    """
    Class for state or action perturbation.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
        std (float): Standard deviation of the noise to be added to the state or action.
    """

    std: float = Field(
        ..., description="Standard deviation of the noise to be added to the state or action."
    )

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Add Gaussian noise to each configured tensor, clamped to the tensor's
        original [min, max] range so the perturbation stays inside the observed envelope.

        Active only in training mode; a non-positive `std` disables the perturbation
        (negative disables it explicitly, zero would add no noise anyway).
        """
        if not self.training:
            # Perturbation is a training-time augmentation only.
            return data
        if self.std <= 0:
            return data
        for key in self.apply_to:
            state = data[key]
            assert isinstance(state, torch.Tensor)
            lo = torch.min(state)
            hi = torch.max(state)
            noise = torch.randn_like(state) * self.std
            # Fix: out-of-place addition — the original `state += noise` mutated the
            # caller's tensor in place, corrupting any other reference to it.
            perturbed = state + noise
            # Clamp so noisy values cannot exceed the pre-noise min/max of this tensor.
            data[key] = torch.clamp(perturbed, lo, hi)
        return data
|
|
|
|
class StateActionDropout(ModalityTransform):
    """
    Class for state or action dropout.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
        dropout_prob (float): Probability of dropping out a state or action.
    """

    dropout_prob: float = Field(..., description="Probability of dropping out a state or action.")

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """With probability `dropout_prob`, zero out every configured tensor (training only)."""
        # Dropout is a training-time augmentation; a negative probability disables it.
        if not self.training or self.dropout_prob < 0:
            return data
        # One coin flip drops (or keeps) all configured keys together.
        should_drop = self.dropout_prob > 1e-9 and random.random() < self.dropout_prob
        if should_drop:
            for key in self.apply_to:
                tensor = data[key]
                assert isinstance(tensor, torch.Tensor)
                data[key] = torch.zeros_like(tensor)
        return data
|
|
|
|
class StateActionSinCosTransform(ModalityTransform):
    """
    Class for state or action sin-cos transform.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
    """

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Replace each configured tensor x with [sin(x), cos(x)] concatenated on the last dim,
        doubling that dimension's size."""
        for key in self.apply_to:
            tensor = data[key]
            assert isinstance(tensor, torch.Tensor)
            data[key] = torch.cat((torch.sin(tensor), torch.cos(tensor)), dim=-1)
        return data
|
|