File size: 5,154 Bytes
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from ..schema import DatasetMetadata
class ModalityTransform(BaseModel, ABC):
    """Abstract base class for transforms applied to data modalities.

    Typical examples are video frame augmentation or action normalization.
    Subclasses implement :meth:`apply`; calling an instance delegates to it.
    """

    apply_to: list[str] = Field(..., description="The keys to apply the transform to.")
    training: bool = Field(
        default=True, description="Whether to apply the transform in training mode."
    )
    # Private (non-field) storage for the dataset metadata; it stays ``None``
    # until it is assigned through ``set_metadata()`` or the property setter.
    _dataset_metadata: DatasetMetadata | None = PrivateAttr(default=None)
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @property
    def dataset_metadata(self) -> DatasetMetadata:
        """Return the dataset metadata previously attached to this transform.

        Raises:
            AssertionError: If no metadata has been set yet.
        """
        if self._dataset_metadata is None:
            # Raised explicitly instead of via an ``assert`` statement so the
            # guard is not stripped when Python runs with ``-O``. The exception
            # type is kept as AssertionError for backward compatibility.
            raise AssertionError(
                "Dataset metadata is not set. Please call set_metadata() before calling apply()."
            )
        return self._dataset_metadata

    @dataset_metadata.setter
    def dataset_metadata(self, value: DatasetMetadata):
        """Store ``value`` as this transform's dataset metadata."""
        self._dataset_metadata = value

    def set_metadata(self, dataset_metadata: DatasetMetadata):
        """Set the dataset metadata.

        This is useful for transforms that need to know the dataset metadata,
        e.g. to normalize actions. Subclasses can override this method if they
        need to do something more complex.

        Args:
            dataset_metadata: The metadata to attach to this transform.
        """
        self.dataset_metadata = dataset_metadata

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply the transformation by delegating to :meth:`apply`.

        Args:
            data (dict[str, Any]): The data to transform.
                example: data = {
                    "video.image_side_0": np.ndarray,
                    "action.eef_position": np.ndarray,
                    ...
                }

        Returns:
            dict[str, Any]: The transformed data.
                example: transformed_data = {
                    "video.image_side_0": np.ndarray,
                    "action.eef_position": torch.Tensor,  # Normalized and converted to tensor
                    ...
                }
        """
        return self.apply(data)

    @abstractmethod
    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply the transformation to the entries of ``data`` selected by `apply_to`.

        How the `apply_to` entries are matched against the keys of ``data`` is
        defined by the concrete subclass.

        Args:
            data: The data to transform.

        Returns:
            The processed data.
        """

    def train(self):
        """Switch the transform to training mode (``training = True``)."""
        self.training = True

    def eval(self):
        """Switch the transform to evaluation mode (``training = False``)."""
        self.training = False
class InvertibleModalityTransform(ModalityTransform):
    """A modality transform that can also be reversed through :meth:`unapply`."""

    @abstractmethod
    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Reverse the transformation on the entries of ``data`` selected by `apply_to`.

        Args:
            data: The transformed data to revert.

        Returns:
            The processed data with this transform reversed.
        """
class ComposedModalityTransform(ModalityTransform):
    """Compose multiple modality transforms into one sequential transform.

    ``apply`` runs the child transforms in list order. ``unapply`` walks them
    in reverse order and reverts every child that supports reversal.
    """

    transforms: list[ModalityTransform] = Field(..., description="The transforms to compose.")
    apply_to: list[str] = Field(
        default_factory=list, description="Will be ignored for composed transforms."
    )
    training: bool = Field(
        default=True, description="Whether to apply the transform in training mode."
    )
    model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True)

    def set_metadata(self, dataset_metadata: DatasetMetadata):
        """Attach ``dataset_metadata`` to this composition and to every child.

        Args:
            dataset_metadata: The metadata to forward to each composed transform.
        """
        # Record the metadata on the composition itself as well, so the
        # inherited ``dataset_metadata`` property is usable on this object.
        super().set_metadata(dataset_metadata)
        for transform in self.transforms:
            transform.set_metadata(dataset_metadata)

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Run every child transform on ``data`` in list order.

        Args:
            data: The data to transform.

        Returns:
            The output of the last child transform (``data`` itself when
            ``transforms`` is empty).

        Raises:
            ValueError: If a child raises; the message carries the index of
                the failing transform and the original error is chained.
        """
        for i, transform in enumerate(self.transforms):
            try:
                data = transform(data)
            except Exception as e:
                raise ValueError(f"Error applying transform {i} to data: {e}") from e
        return data

    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Revert the reversible child transforms in reverse list order.

        Children that are neither an ``InvertibleModalityTransform`` nor a
        nested ``ComposedModalityTransform`` are skipped.

        Args:
            data: The transformed data to revert.

        Returns:
            The data after every reversible child has been unapplied.

        Raises:
            ValueError: If a child raises; the message carries the index of
                the failing transform (its position in ``transforms``) and the
                original error is chained.
        """
        # Nested compositions are not InvertibleModalityTransform subclasses
        # but do define ``unapply``; include them so their reversible children
        # are not silently skipped.
        reversible = (InvertibleModalityTransform, ComposedModalityTransform)
        for step, transform in reversed(list(enumerate(self.transforms))):
            if not isinstance(transform, reversible):
                continue
            try:
                data = transform.unapply(data)
            except Exception as e:
                raise ValueError(f"Error unapplying transform {step} to data: {e}") from e
        return data

    def train(self):
        """Switch this composition and all of its children to training mode."""
        # Keep the composition's own ``training`` flag in sync with its children.
        super().train()
        for transform in self.transforms:
            transform.train()

    def eval(self):
        """Switch this composition and all of its children to evaluation mode."""
        # Keep the composition's own ``training`` flag in sync with its children.
        super().eval()
        for transform in self.transforms:
            transform.eval()
|