# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr

from ..schema import DatasetMetadata


class ModalityTransform(BaseModel, ABC):
    """
    Abstract class for transforming data modalities, e.g. video frame augmentation
    or action normalization.

    Subclasses implement `apply()`; calling the instance forwards to `apply()`.
    """

    apply_to: list[str] = Field(..., description="The keys to apply the transform to.")
    training: bool = Field(
        default=True, description="Whether to apply the transform in training mode."
    )
    _dataset_metadata: DatasetMetadata | None = PrivateAttr(default=None)
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @property
    def dataset_metadata(self) -> DatasetMetadata:
        """Return the dataset metadata previously stored on this transform.

        Raises:
            AssertionError: If no metadata has been set yet.
        """
        # Raised explicitly instead of using an `assert` statement so the check
        # is still enforced when Python runs with -O (which strips asserts).
        # The exception type and message are unchanged.
        if self._dataset_metadata is None:
            raise AssertionError(
                "Dataset metadata is not set. Please call set_metadata() before calling apply()."
            )
        return self._dataset_metadata

    @dataset_metadata.setter
    def dataset_metadata(self, value: DatasetMetadata) -> None:
        self._dataset_metadata = value

    def set_metadata(self, dataset_metadata: DatasetMetadata) -> None:
        """
        Set the dataset metadata. This is useful for transforms that need to know
        the dataset metadata, e.g. to normalize actions.
        Subclasses can override this method if they need to do something more complex.
        """
        self.dataset_metadata = dataset_metadata

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply the transformation to the data and return the processed data.

        This is a thin wrapper that forwards to `apply()`.

        Args:
            data (dict[str, Any]): The data to transform.
                example: data = {
                    "video.image_side_0": np.ndarray,
                    "action.eef_position": np.ndarray,
                    ...
                }

        Returns:
            dict[str, Any]: The transformed data.
                example: transformed_data = {
                    "video.image_side_0": np.ndarray,
                    "action.eef_position": torch.Tensor,  # Normalized and converted to tensor
                    ...
                }
        """
        return self.apply(data)

    @abstractmethod
    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply the transformation to the data for the keys listed in `apply_to`
        and return the processed data."""

    def train(self) -> None:
        """Switch the transform to training mode (`training = True`)."""
        self.training = True

    def eval(self) -> None:
        """Switch the transform to evaluation mode (`training = False`)."""
        self.training = False


class InvertibleModalityTransform(ModalityTransform):
    """A modality transform that can also be reversed via `unapply()`."""

    @abstractmethod
    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Reverse the transformation on the data for the keys listed in `apply_to`
        and return the processed data."""


class ComposedModalityTransform(ModalityTransform):
    """Compose multiple modality transforms.

    `apply()` runs the child transforms in list order. `unapply()` walks them in
    reverse order and reverses those that support it.
    """

    transforms: list[ModalityTransform] = Field(..., description="The transforms to compose.")
    apply_to: list[str] = Field(
        default_factory=list, description="Will be ignored for composed transforms."
    )
    training: bool = Field(
        default=True, description="Whether to apply the transform in training mode."
    )
    model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True)

    def set_metadata(self, dataset_metadata: DatasetMetadata) -> None:
        """Forward the dataset metadata to every child transform."""
        for transform in self.transforms:
            transform.set_metadata(dataset_metadata)

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply every child transform in order, feeding each one the previous output.

        Raises:
            ValueError: If a child transform raises; the message carries the index
                of the failing transform and the original exception is chained.
        """
        for i, transform in enumerate(self.transforms):
            try:
                data = transform(data)
            except Exception as e:
                raise ValueError(f"Error applying transform {i} to data: {e}") from e
        return data

    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Reverse the child transforms in reverse order.

        Children that are neither invertible nor composed are skipped.

        Raises:
            ValueError: If a child's `unapply` raises; the message carries the index
                of the failing transform (in original list order) and the original
                exception is chained.
        """
        # ComposedModalityTransform is not an InvertibleModalityTransform subclass,
        # so it must be accepted explicitly; otherwise a nested composition would
        # be skipped and its invertible children would never be reversed.
        reversible = (InvertibleModalityTransform, ComposedModalityTransform)
        for i, transform in enumerate(reversed(self.transforms)):
            if not isinstance(transform, reversible):
                continue
            try:
                data = transform.unapply(data)
            except Exception as e:
                # Convert the reversed position back to the index in self.transforms.
                step = len(self.transforms) - i - 1
                raise ValueError(f"Error unapplying transform {step} to data: {e}") from e
        return data

    def train(self) -> None:
        """Switch this transform and every child transform to training mode."""
        # Keep the container's own flag in sync with its children.
        super().train()
        for transform in self.transforms:
            transform.train()

    def eval(self) -> None:
        """Switch this transform and every child transform to evaluation mode."""
        # Keep the container's own flag in sync with its children.
        super().eval()
        for transform in self.transforms:
            transform.eval()