# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional import numpy as np import torch from pydantic import Field from ..schema import DatasetMetadata, StateActionMetadata from .base import InvertibleModalityTransform class ConcatTransform(InvertibleModalityTransform): """ Concatenate the keys according to specified order. """ # -- We inherit from ModalityTransform, so we keep apply_to as well -- apply_to: list[str] = Field( default_factory=list, description="Not used in this transform, kept for compatibility." ) video_concat_order: list[str] = Field( ..., description="Concatenation order for each video modality. " "Format: ['video.ego_view_pad_res224_freq20', ...]", ) state_concat_order: Optional[list[str]] = Field( default=None, description="Concatenation order for each state modality. " "Format: ['state.position', 'state.velocity', ...].", ) action_concat_order: Optional[list[str]] = Field( default=None, description="Concatenation order for each action modality. " "Format: ['action.position', 'action.velocity', ...].", ) action_dims: dict[str, int] = Field( default_factory=dict, description="The dimensions of the action keys.", ) state_dims: dict[str, int] = Field( default_factory=dict, description="The dimensions of the state keys.", ) def model_dump(self, *args, **kwargs): if kwargs.get("mode", "python") == "json": include = { "apply_to", "video_concat_order", "state_concat_order", "action_concat_order", } else: include = kwargs.pop("include", None) return super().model_dump(*args, include=include, **kwargs) def apply(self, data: dict) -> dict: grouped_keys = {} for key in data.keys(): try: modality, _ = key.split(".") except: # noqa: E722 ### Handle language annotation special case if "annotation" in key: modality = "language" else: modality = "others" if modality not in grouped_keys: grouped_keys[modality] = [] grouped_keys[modality].append(key) if "video" in grouped_keys: # Check if keys in video_concat_order, state_concat_order, action_concat_order are # ineed contained in the data. If not, then the keys are misspecified video_keys = grouped_keys["video"] assert self.video_concat_order is not None, f"{self.video_concat_order=}, {video_keys=}" assert all( item in video_keys for item in self.video_concat_order ), f"keys in video_concat_order are misspecified, \n{video_keys=}, \n{self.video_concat_order=}" # Process each video view unsqueezed_videos = [] for video_key in self.video_concat_order: video_data = data.pop(video_key) unsqueezed_video = np.expand_dims( video_data, axis=-4 ) # [..., H, W, C] -> [..., 1, H, W, C] unsqueezed_videos.append(unsqueezed_video) # Concatenate along the new axis unsqueezed_video = np.concatenate(unsqueezed_videos, axis=-4) # [..., V, H, W, C] # Video data["video"] = unsqueezed_video # "state" if "state" in grouped_keys: state_keys = grouped_keys["state"] assert self.state_concat_order is not None, f"{self.state_concat_order=}" assert all( item in state_keys for item in self.state_concat_order ), f"keys in state_concat_order are misspecified, \n{state_keys=}, \n{self.state_concat_order=}" # Check the state dims for key in self.state_concat_order: target_shapes = [self.state_dims[key]] if self.is_rotation_key(key): target_shapes.append(6) # Allow for rotation_6d # if key in ["state.right_arm", "state.right_hand"]: target_shapes.append(self.state_dims[key] * 2) # Allow for sin-cos transform assert ( data[key].shape[-1] in target_shapes ), f"State dim mismatch for {key=}, {data[key].shape[-1]=}, {target_shapes=}" # Concatenate the state keys # We'll have StateActionToTensor before this transform, so here we use torch.cat data["state"] = torch.cat( [data.pop(key) for key in self.state_concat_order], dim=-1 ) # [T, D_state] if "action" in grouped_keys: action_keys = grouped_keys["action"] assert self.action_concat_order is not None, f"{self.action_concat_order=}" # Check if all keys in concat_order are present assert set(self.action_concat_order) == set( action_keys ), f"{set(self.action_concat_order)=}, {set(action_keys)=}" # Record the action dims for key in self.action_concat_order: target_shapes = [self.action_dims[key]] if self.is_rotation_key(key): target_shapes.append(3) # Allow for axis angle assert ( self.action_dims[key] == data[key].shape[-1] ), f"Action dim mismatch for {key=}, {self.action_dims[key]=}, {data[key].shape[-1]=}" # Concatenate the action keys # We'll have StateActionToTensor before this transform, so here we use torch.cat data["action"] = torch.cat( [data.pop(key) for key in self.action_concat_order], dim=-1 ) # [T, D_action] return data def unapply(self, data: dict) -> dict: start_dim = 0 assert "action" in data, f"{data.keys()=}" # For those dataset without actions (LAPA), we'll never run unapply assert self.action_concat_order is not None, f"{self.action_concat_order=}" action_tensor = data.pop("action") for key in self.action_concat_order: if key not in self.action_dims: raise ValueError(f"Action dim {key} not found in action_dims.") end_dim = start_dim + self.action_dims[key] data[key] = action_tensor[..., start_dim:end_dim] start_dim = end_dim if "state" in data: assert self.state_concat_order is not None, f"{self.state_concat_order=}" start_dim = 0 state_tensor = data.pop("state") for key in self.state_concat_order: end_dim = start_dim + self.state_dims[key] data[key] = state_tensor[..., start_dim:end_dim] start_dim = end_dim return data def __call__(self, data: dict) -> dict: return self.apply(data) def get_modality_metadata(self, key: str) -> StateActionMetadata: modality, subkey = key.split(".") assert self.dataset_metadata is not None, "Metadata not set" modality_config = getattr(self.dataset_metadata.modalities, modality) assert subkey in modality_config, f"{subkey=} not found in {modality_config=}" assert isinstance( modality_config[subkey], StateActionMetadata ), f"Expected {StateActionMetadata} for {subkey=}, got {type(modality_config[subkey])=}" return modality_config[subkey] def get_state_action_dims(self, key: str) -> int: """Get the dimension of a state or action key from the dataset metadata.""" modality_config = self.get_modality_metadata(key) shape = modality_config.shape assert len(shape) == 1, f"{shape=}" return shape[0] def is_rotation_key(self, key: str) -> bool: modality_config = self.get_modality_metadata(key) return modality_config.rotation_type is not None def set_metadata(self, dataset_metadata: DatasetMetadata): """Set the metadata and compute the dimensions of the state and action keys.""" super().set_metadata(dataset_metadata) # Pre-compute the dimensions of the state and action keys if self.action_concat_order is not None: for key in self.action_concat_order: self.action_dims[key] = self.get_state_action_dims(key) if self.state_concat_order is not None: for key in self.state_concat_order: self.state_dims[key] = self.get_state_action_dims(key)