# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import random
from typing import Any, ClassVar
import numpy as np
import pytorch3d.transforms as pt
import torch
from pydantic import Field, PrivateAttr, field_validator, model_validator
from ..schema import DatasetMetadata, RotationType, StateActionMetadata
from .base import InvertibleModalityTransform, ModalityTransform
class RotationTransform:
    """Convert rotations between representations, pivoting through matrices.

    Adapted from https://github.com/real-stanford/diffusion_policy/blob/548a52bbb105518058e27bf34dcf90bf6f73681a/diffusion_policy/model/common/rotation_transformer.py

    The forward direction maps ``from_rep`` -> matrix -> ``to_rep``; the inverse
    direction applies the opposite conversions in reverse order. Conversion
    functions are looked up by name on the ``pytorch3d.transforms`` module
    (imported as ``pt``).
    """

    valid_reps = ["axis_angle", "euler_angles", "quaternion", "rotation_6d", "matrix"]

    def __init__(self, from_rep="axis_angle", to_rep="rotation_6d"):
        """
        Valid representations
        Always use matrix as intermediate representation.

        Args:
            from_rep: Source representation. One of ``valid_reps``; Euler angles
                carry their convention as a suffix, e.g. ``"euler_angles_rpy"``.
            to_rep: Target representation, same format as ``from_rep``.

        Raises:
            AssertionError: If the two representations are identical (after the
                Euler convention is stripped) or either one is not in
                ``valid_reps``.
            ValueError: If an Euler representation does not end in a usable
                three-axis convention.
        """
        requested_from, requested_to = from_rep, to_rep
        from_rep, from_convention = self._split_euler_convention(from_rep)
        to_rep, to_convention = self._split_euler_convention(to_rep)

        assert from_rep != to_rep, f"from_rep and to_rep cannot be the same: {from_rep}"
        assert from_rep in self.valid_reps, f"Invalid from_rep: {from_rep}"
        assert to_rep in self.valid_reps, f"Invalid to_rep: {to_rep}"

        # Fail here rather than deep inside the first conversion call.
        self._check_euler_convention(from_convention, requested_from)
        self._check_euler_convention(to_convention, requested_to)

        forward_funcs = []
        inverse_funcs = []

        if from_rep != "matrix":
            funcs = [getattr(pt, f"{from_rep}_to_matrix"), getattr(pt, f"matrix_to_{from_rep}")]
            if from_convention is not None:
                funcs = [functools.partial(func, convention=from_convention) for func in funcs]
            forward_funcs.append(funcs[0])
            inverse_funcs.append(funcs[1])

        if to_rep != "matrix":
            funcs = [getattr(pt, f"matrix_to_{to_rep}"), getattr(pt, f"{to_rep}_to_matrix")]
            if to_convention is not None:
                funcs = [functools.partial(func, convention=to_convention) for func in funcs]
            forward_funcs.append(funcs[0])
            inverse_funcs.append(funcs[1])

        # The inverse must undo the forward steps last-to-first.
        inverse_funcs = inverse_funcs[::-1]

        self.forward_funcs = forward_funcs
        self.inverse_funcs = inverse_funcs

    @staticmethod
    def _split_euler_convention(rep: str):
        """Split ``"euler_angles_<axes>"`` into ``("euler_angles", convention)``.

        The letters ``r``/``p``/``y`` of the suffix are rewritten to
        ``X``/``Y``/``Z``. Any representation that does not start with
        ``"euler_angles"`` is returned unchanged with a ``None`` convention.
        """
        if not rep.startswith("euler_angles"):
            return rep, None
        suffix = rep.split("_")[-1]
        convention = suffix.replace("r", "X").replace("p", "Y").replace("y", "Z")
        return "euler_angles", convention

    @staticmethod
    def _check_euler_convention(convention, rep: str) -> None:
        """Raise ``ValueError`` unless ``convention`` is ``None`` or three of X/Y/Z."""
        if convention is None:
            return
        if len(convention) != 3 or any(axis not in "XYZ" for axis in convention):
            raise ValueError(
                f"Invalid euler angle convention {convention!r} parsed from {rep!r}; "
                "expected a three-axis suffix such as 'euler_angles_rpy'"
            )

    @staticmethod
    def _apply_funcs(x: torch.Tensor, funcs: list) -> torch.Tensor:
        """Feed ``x`` through ``funcs`` in order and return the final result."""
        assert isinstance(x, torch.Tensor)
        for func in funcs:
            x = func(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert ``x`` from the source representation to the target one."""
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"
        return self._apply_funcs(x, self.forward_funcs)

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """Convert ``x`` from the target representation back to the source one."""
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"
        return self._apply_funcs(x, self.inverse_funcs)
class Normalizer:
    """Normalize tensors along their trailing dimension(s) with fixed statistics.

    Supported modes and the statistics keys each one reads:

    * ``"q99"`` (``q01``, ``q99``): affine map sending q01 -> -1 and q99 -> 1,
      then clamped to [-1, 1]. Dimensions with q01 == q99 keep their input
      value (before the clamp).
    * ``"mean_std"`` (``mean``, ``std``): ``(x - mean) / std``. Dimensions with
      std == 0 keep their input value.
    * ``"min_max"`` (``min``, ``max``): affine map sending min -> -1 and
      max -> 1, not clamped. Dimensions with min == max become 0.
    * ``"scale"`` (``min``, ``max``): ``x / max(|min|, |max|)``. Dimensions
      whose scale is 0 become 0.
    * ``"binary"``: thresholds at 0.5 and returns 0/1 in the input dtype.
    """

    valid_modes = ["q99", "mean_std", "min_max", "scale", "binary"]

    def __init__(self, mode: str, statistics: dict):
        """
        Args:
            mode: One of ``valid_modes``. It is checked lazily: an unknown mode
                raises ``ValueError`` on the first ``forward``/``inverse`` call.
            statistics: Mapping from statistic name to array-like values. The
                mapping is not modified; tensor copies are stored instead.
        """
        self.mode = mode
        # Build a fresh dict so the caller's mapping (possibly a shared
        # constant) is never rewritten with tensors.
        self.statistics = {key: self._as_tensor(value) for key, value in statistics.items()}

    @staticmethod
    def _as_tensor(value) -> torch.Tensor:
        """Return an independent tensor copy of ``value``."""
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value)

    def _stat(self, name: str, like: torch.Tensor) -> torch.Tensor:
        """Fetch statistic ``name`` matching the dtype and device of ``like``."""
        return self.statistics[name].to(device=like.device, dtype=like.dtype)

    @staticmethod
    def _rescale_to_unit_range(x: torch.Tensor, low: torch.Tensor, high: torch.Tensor):
        """Map ``x`` so that low -> -1 and high -> 1 where ``low != high``.

        Returns the rescaled tensor (zeros in degenerate dimensions) together
        with the boolean mask of non-degenerate dimensions.
        """
        valid = low != high
        rescaled = torch.zeros_like(x)
        # Formula: 2 * (x - low) / (high - low) - 1
        rescaled[..., valid] = (
            2 * ((x[..., valid] - low[..., valid]) / (high[..., valid] - low[..., valid])) - 1
        )
        return rescaled, valid

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` according to ``self.mode``.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"

        if self.mode == "q99":
            q01 = self._stat("q01", x)
            q99 = self._stat("q99", x)
            normalized, valid = self._rescale_to_unit_range(x, q01, q99)
            # Normalization is undefined where q01 == q99: pass the input through.
            normalized[..., ~valid] = x[..., ~valid]
            # Clip the normalized values to be between -1 and 1
            return torch.clamp(normalized, -1, 1)

        if self.mode == "mean_std":
            mean = self._stat("mean", x)
            std = self._stat("std", x)
            valid = std != 0
            normalized = torch.zeros_like(x)
            # Formula: (x - mean) / std
            normalized[..., valid] = (x[..., valid] - mean[..., valid]) / std[..., valid]
            # Normalization is undefined where std == 0: pass the input through.
            normalized[..., ~valid] = x[..., ~valid]
            return normalized

        if self.mode == "min_max":
            low = self._stat("min", x)
            high = self._stat("max", x)
            # Dimensions with min == max are left at 0 by the helper.
            normalized, _ = self._rescale_to_unit_range(x, low, high)
            return normalized

        if self.mode == "scale":
            low = self._stat("min", x)
            high = self._stat("max", x)
            abs_max = torch.max(torch.abs(low), torch.abs(high))
            valid = abs_max != 0
            normalized = torch.zeros_like(x)
            normalized[..., valid] = x[..., valid] / abs_max[..., valid]
            return normalized

        if self.mode == "binary":
            # Range of binary is [0, 1]
            return (x > 0.5).to(x.dtype)

        raise ValueError(f"Invalid normalization mode: {self.mode}")

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """Map normalized values back to the original scale.

        The mapping is the algebraic inverse of ``forward``; information lost
        to clamping (``q99``), thresholding (``binary``) or degenerate
        dimensions is not recovered.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        assert isinstance(
            x, torch.Tensor
        ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}"

        if self.mode == "q99":
            q01 = self._stat("q01", x)
            q99 = self._stat("q99", x)
            return (x + 1) / 2 * (q99 - q01) + q01

        if self.mode == "mean_std":
            mean = self._stat("mean", x)
            std = self._stat("std", x)
            return x * std + mean

        if self.mode == "min_max":
            low = self._stat("min", x)
            high = self._stat("max", x)
            return (x + 1) / 2 * (high - low) + low

        if self.mode == "scale":
            low = self._stat("min", x)
            high = self._stat("max", x)
            return x * torch.max(torch.abs(low), torch.abs(high))

        if self.mode == "binary":
            return (x > 0.5).to(x.dtype)

        raise ValueError(f"Invalid normalization mode: {self.mode}")
class StateActionToTensor(InvertibleModalityTransform):
    """
    Transforms states and actions to tensors.

    ``apply`` converts NumPy arrays to torch tensors (optionally casting to the
    dtype configured in ``output_dtypes``); ``unapply`` converts tensors back
    to NumPy arrays (optionally casting to the dtype in ``input_dtypes``).
    Keys listed in ``apply_to`` but absent from the data are skipped.
    """

    input_dtypes: dict[str, np.dtype] = Field(
        default_factory=dict, description="The input dtypes for each state key."
    )
    output_dtypes: dict[str, torch.dtype] = Field(
        default_factory=dict, description="The output dtypes for each state key."
    )

    def model_dump(self, *args, **kwargs):
        """Dump the model; in JSON mode only ``apply_to`` is included.

        A caller-supplied ``include`` is always removed from ``kwargs`` so it
        is never passed twice to the parent implementation. In JSON mode it is
        overridden by ``{"apply_to"}``.
        """
        include = kwargs.pop("include", None)
        if kwargs.get("mode", "python") == "json":
            include = {"apply_to"}
        return super().model_dump(*args, include=include, **kwargs)

    @field_validator("input_dtypes", "output_dtypes", mode="before")
    @classmethod
    def validate_dtypes(cls, v):
        """Resolve dtype strings such as ``"torch.float32"`` or ``"np.float64"``.

        Non-string values are passed through unchanged.

        Raises:
            ValueError: If a string does not start with ``torch.``, ``np.`` or
                ``numpy.``, or if a ``torch.`` name does not resolve to a
                ``torch.dtype``.
        """
        for key, dtype in v.items():
            if not isinstance(dtype, str):
                continue
            dtype_name = dtype.split(".")[-1]
            if dtype.startswith("torch."):
                resolved = getattr(torch, dtype_name, None)
                # Guard against names like "torch.load" that exist but are not dtypes.
                if not isinstance(resolved, torch.dtype):
                    raise ValueError(f"Invalid dtype: {dtype}")
                v[key] = resolved
            elif dtype.startswith(("np.", "numpy.")):
                v[key] = np.dtype(dtype_name)
            else:
                raise ValueError(f"Invalid dtype: {dtype}")
        return v

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Convert each present ``apply_to`` entry from ``np.ndarray`` to tensor."""
        for key in self.apply_to:
            if key not in data:
                continue
            value = data[key]
            assert isinstance(
                value, np.ndarray
            ), f"Unexpected input type: {type(value)}. Expected type: {np.ndarray}"
            tensor = torch.from_numpy(value)
            if key in self.output_dtypes:
                tensor = tensor.to(self.output_dtypes[key])
            data[key] = tensor
        return data

    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Convert each present ``apply_to`` entry from tensor to ``np.ndarray``."""
        for key in self.apply_to:
            if key not in data:
                continue
            value = data[key]
            assert isinstance(
                value, torch.Tensor
            ), f"Unexpected input type: {type(value)}. Expected type: {torch.Tensor}"
            # detach/cpu are no-ops for plain CPU tensors but keep the
            # conversion working for grad-tracking or accelerator tensors.
            array = value.detach().cpu().numpy()
            if key in self.input_dtypes:
                array = array.astype(self.input_dtypes[key])
            data[key] = array
        return data
class StateActionTransform(InvertibleModalityTransform):
    """
    Class for state or action transform.

    ``apply`` first converts rotations to their target representation and then
    normalizes; ``unapply`` undoes both steps in the opposite order. The
    rotation transformers and normalizers are built by ``set_metadata``.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
        normalization_modes (dict[str, str]): The normalization modes for each state key.
            If a state key in apply_to is not present in the dictionary, it will not be normalized.
        target_rotations (dict[str, str]): The target representations for each state key.
            If a state key in apply_to is not present in the dictionary, it will not be rotated.
    """

    # Configurable attributes
    apply_to: list[str] = Field(..., description="The keys in the modality to load and transform.")
    normalization_modes: dict[str, str] = Field(
        default_factory=dict, description="The normalization modes for each state key."
    )
    target_rotations: dict[str, str] = Field(
        default_factory=dict, description="The target representations for each state key."
    )
    normalization_statistics: dict[str, dict] = Field(
        default_factory=dict, description="The statistics for each state key."
    )
    modality_metadata: dict[str, StateActionMetadata] = Field(
        default_factory=dict, description="The modality metadata for each state key."
    )

    # Model variables
    _rotation_transformers: dict[str, RotationTransform] = PrivateAttr(default_factory=dict)
    _normalizers: dict[str, Normalizer] = PrivateAttr(default_factory=dict)
    _input_dtypes: dict[str, np.dtype | torch.dtype] = PrivateAttr(default_factory=dict)

    # Model constants
    # Per-representation bounds used as min/max statistics for absolute
    # rotations that were converted to another representation.
    _DEFAULT_MIN_MAX_STATISTICS: ClassVar[dict] = {
        "rotation_6d": {
            "min": [-1, -1, -1, -1, -1, -1],
            "max": [1, 1, 1, 1, 1, 1],
        },
        "euler_angles": {
            "min": [-np.pi, -np.pi, -np.pi],
            "max": [np.pi, np.pi, np.pi],
        },
        "quaternion": {
            "min": [-1, -1, -1, -1],
            "max": [1, 1, 1, 1],
        },
        "axis_angle": {
            "min": [-np.pi, -np.pi, -np.pi],
            "max": [np.pi, np.pi, np.pi],
        },
    }

    def model_dump(self, *args, **kwargs):
        """Dump the model; JSON mode keeps only the user-facing configuration.

        A caller-supplied ``include`` is always removed from ``kwargs`` so it
        is never passed twice to the parent implementation. In JSON mode it is
        overridden by the fixed set of configuration fields.
        """
        include = kwargs.pop("include", None)
        if kwargs.get("mode", "python") == "json":
            include = {"apply_to", "normalization_modes", "target_rotations"}
        return super().model_dump(*args, include=include, **kwargs)

    @field_validator("modality_metadata", mode="before")
    @classmethod
    def validate_modality_metadata(cls, v):
        """Coerce plain-dict entries into ``StateActionMetadata`` instances."""
        for modality_key, config in v.items():
            if isinstance(config, dict):
                config = StateActionMetadata.model_validate(config)
            else:
                assert isinstance(
                    config, StateActionMetadata
                ), f"Invalid source rotation config: {config}"
            v[modality_key] = config
        return v

    @model_validator(mode="after")
    def validate_normalization_statistics(self):
        """Check that provided statistics contain what their mode requires.

        Only keys present in both ``normalization_statistics`` and
        ``normalization_modes`` are checked.

        Raises:
            AssertionError: If required statistics are missing or mismatched.
            ValueError: If a normalization mode is not recognised.
        """
        for modality_key, normalization_statistics in self.normalization_statistics.items():
            if modality_key not in self.normalization_modes:
                continue
            normalization_mode = self.normalization_modes[modality_key]
            if normalization_mode == "min_max":
                assert (
                    "min" in normalization_statistics and "max" in normalization_statistics
                ), f"Min and max statistics are required for min_max normalization, but got {normalization_statistics}"
                assert len(normalization_statistics["min"]) == len(
                    normalization_statistics["max"]
                ), f"Min and max statistics must have the same length, but got {normalization_statistics['min']} and {normalization_statistics['max']}"
            elif normalization_mode == "mean_std":
                assert (
                    "mean" in normalization_statistics and "std" in normalization_statistics
                ), f"Mean and std statistics are required for mean_std normalization, but got {normalization_statistics}"
                assert len(normalization_statistics["mean"]) == len(
                    normalization_statistics["std"]
                ), f"Mean and std statistics must have the same length, but got {normalization_statistics['mean']} and {normalization_statistics['std']}"
            elif normalization_mode == "q99":
                assert (
                    "q01" in normalization_statistics and "q99" in normalization_statistics
                ), f"q01 and q99 statistics are required for q99 normalization, but got {normalization_statistics}"
                assert len(normalization_statistics["q01"]) == len(
                    normalization_statistics["q99"]
                ), f"q01 and q99 statistics must have the same length, but got {normalization_statistics['q01']} and {normalization_statistics['q99']}"
            elif normalization_mode == "binary":
                # NOTE(review): the field is declared as dict[str, dict], so the
                # integer lookup below only succeeds if the statistics dict has
                # a key ``0`` — this branch looks written for a sequence.
                # Intended check is unclear; behaviour kept as-is. TODO confirm.
                assert (
                    len(normalization_statistics) == 1
                ), f"Binary normalization should only have one value, but got {normalization_statistics}"
                assert normalization_statistics[0] in [
                    0,
                    1,
                ], f"Binary normalization should only have 0 or 1, but got {normalization_statistics[0]}"
            else:
                raise ValueError(f"Invalid normalization mode: {normalization_mode}")
        return self

    def set_metadata(self, dataset_metadata: DatasetMetadata):
        """Populate metadata and statistics, then build the per-key transformers.

        Reads ``dataset_metadata.statistics`` and ``dataset_metadata.modalities``,
        fills ``modality_metadata`` and ``normalization_statistics``, and
        creates one ``RotationTransform`` / ``Normalizer`` per configured key.

        Raises:
            AssertionError: If a key is not of the form ``modality.key``, if
                its metadata or statistics are missing, if its shape is not
                one-dimensional, or if an absolute rotation is not normalized
                with ``min_max``.
            ValueError: If a relative rotation is asked to be normalized, or a
                non-continuous key uses a mode other than ``binary``.
        """
        dataset_statistics = dataset_metadata.statistics
        modality_metadata = dataset_metadata.modalities

        # Check that all state keys specified in apply_to have their modality_metadata
        for key in self.apply_to:
            split_key = key.split(".", 1)
            assert len(split_key) == 2, "State keys should have two parts: 'modality.key'"
            if key not in self.modality_metadata:
                modality, state_key = split_key
                assert hasattr(modality_metadata, modality), f"{modality} config not found"
                assert state_key in getattr(
                    modality_metadata, modality
                ), f"{state_key} config not found"
                self.modality_metadata[key] = getattr(modality_metadata, modality)[state_key]

        # Check that all state keys specified in normalization_modes have their statistics in state_statistics
        for key in self.normalization_modes:
            split_key = key.split(".", 1)
            assert len(split_key) == 2, "State keys should have two parts: 'modality.key'"
            modality, state_key = split_key
            assert hasattr(dataset_statistics, modality), f"{modality} statistics not found"
            assert state_key in getattr(
                dataset_statistics, modality
            ), f"{state_key} statistics not found"
            assert (
                len(getattr(modality_metadata, modality)[state_key].shape) == 1
            ), f"{getattr(modality_metadata, modality)[state_key].shape=}"
            self.normalization_statistics[key] = getattr(dataset_statistics, modality)[
                state_key
            ].model_dump()

        # Initialize the rotation transformers
        for key in self.target_rotations:
            # Get the original representation of the state
            from_rep = self.modality_metadata[key].rotation_type
            assert from_rep is not None, f"Source rotation type not found for {key}"

            # Get the target representation of the state, will raise an error if the target representation is not valid
            to_rep = RotationType(self.target_rotations[key])

            # If the original representation is not the same as the target representation, initialize the rotation transformer
            if from_rep != to_rep:
                self._rotation_transformers[key] = RotationTransform(
                    from_rep=from_rep.value, to_rep=to_rep.value
                )

        # Initialize the normalizers
        for key in self.normalization_modes:
            # If the state has a nontrivial rotation, we need to handle it more carefully
            # For absolute rotations, we need to convert them to the target representation and normalize them using min_max mode,
            # since we can infer the bounds by the representation
            # For relative rotations, we cannot normalize them as we don't know the bounds
            if key in self._rotation_transformers:
                # Case 1: Absolute rotation
                if self.modality_metadata[key].absolute:
                    # Check that the normalization mode is valid
                    assert (
                        self.normalization_modes[key] == "min_max"
                    ), "Absolute rotations that are converted to other formats must be normalized using `min_max` mode"
                    rotation_type = RotationType(self.target_rotations[key]).value
                    # If the target representation is euler angles, we need to parse the convention
                    if rotation_type.startswith("euler_angles"):
                        rotation_type = "euler_angles"
                    # Get the statistics for the target representation
                    statistics = self._DEFAULT_MIN_MAX_STATISTICS[rotation_type]
                # Case 2: Relative rotation
                else:
                    raise ValueError(
                        f"Cannot normalize relative rotations: {key} that's converted to {self.target_rotations[key]}"
                    )
            # If the state is not continuous, we should not use normalization modes other than binary
            elif (
                not self.modality_metadata[key].continuous
                and self.normalization_modes[key] != "binary"
            ):
                raise ValueError(
                    f"{key} is not continuous, so it should be normalized using `binary` mode"
                )
            # Initialize the normalizer
            else:
                statistics = self.normalization_statistics[key]
            # Hand over a copy: Normalizer stores tensors under the keys of the
            # dict it receives, which must not leak into the class-level
            # defaults or into the ``normalization_statistics`` field.
            self._normalizers[key] = Normalizer(
                mode=self.normalization_modes[key], statistics=dict(statistics)
            )

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Rotate, then normalize, every ``apply_to`` key present in ``data``.

        The dtype seen for each key on its first call is remembered so that
        ``unapply`` can cast back to it; later calls must use the same dtype.
        """
        for key in self.apply_to:
            if key not in data:
                # We allow some keys to be missing in the data, and only process the keys that are present
                continue
            if key not in self._input_dtypes:
                input_dtype = data[key].dtype
                assert isinstance(
                    input_dtype, torch.dtype
                ), f"Unexpected input dtype: {input_dtype}. Expected type: {torch.dtype}"
                self._input_dtypes[key] = input_dtype
            else:
                assert (
                    data[key].dtype == self._input_dtypes[key]
                ), f"All states corresponding to the same key must be of the same dtype, input dtype: {data[key].dtype}, expected dtype: {self._input_dtypes[key]}"
            # Rotate the state
            state = data[key]
            if key in self._rotation_transformers:
                state = self._rotation_transformers[key].forward(state)
            # Normalize the state
            if key in self._normalizers:
                state = self._normalizers[key].forward(state)
            data[key] = state
        return data

    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Unnormalize, then rotate back, every ``apply_to`` key present in ``data``."""
        for key in self.apply_to:
            if key not in data:
                continue
            state = data[key]
            assert isinstance(
                state, torch.Tensor
            ), f"Unexpected state type: {type(state)}. Expected type: {torch.Tensor}"
            # Unnormalize the state
            if key in self._normalizers:
                state = self._normalizers[key].inverse(state)
            # Change the state back to its original representation
            if key in self._rotation_transformers:
                state = self._rotation_transformers[key].inverse(state)
            assert isinstance(
                state, torch.Tensor
            ), f"State should be tensor after unapplying transformations, but got {type(state)}"
            # Only convert back to the original dtype if it's known, i.e. `apply` was called before
            # If not, we don't know the original dtype, so we don't convert
            if key in self._input_dtypes:
                original_dtype = self._input_dtypes[key]
                if isinstance(original_dtype, np.dtype):
                    state = state.numpy().astype(original_dtype)
                elif isinstance(original_dtype, torch.dtype):
                    state = state.to(original_dtype)
                else:
                    raise ValueError(f"Invalid input dtype: {original_dtype}")
            data[key] = state
        return data
class StateActionPerturbation(ModalityTransform):
    """
    Class for state or action perturbation.

    In training mode, Gaussian noise with standard deviation ``std`` is added
    to every key in ``apply_to`` and the result is clamped to the minimum and
    maximum of the unperturbed tensor. The input tensors are not modified.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
        std (float): Standard deviation of the noise to be added to the state or action.
    """

    # Configurable attributes
    std: float = Field(
        ..., description="Standard deviation of the noise to be added to the state or action."
    )

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Return ``data`` with noisy, range-clamped tensors for ``apply_to`` keys."""
        if not self.training:
            # Don't perturb the data in eval mode
            return data
        if self.std < 0:
            # If the std is negative, we don't add any noise
            return data
        for key in self.apply_to:
            state = data[key]
            assert isinstance(state, torch.Tensor)
            lower_bound = torch.min(state)
            upper_bound = torch.max(state)
            noise = torch.randn_like(state) * self.std
            # Add out of place: an in-place update would also rewrite the
            # caller's tensor and any buffer it shares memory with (for
            # example an array wrapped via torch.from_numpy).
            perturbed = state + noise
            # Clip to the original range
            data[key] = torch.clamp(perturbed, lower_bound, upper_bound)
        return data
class StateActionDropout(ModalityTransform):
    """
    Randomly zero out states or actions during training.

    With probability ``dropout_prob`` all keys in ``apply_to`` are replaced by
    zero tensors of the same shape and dtype; a single random draw decides for
    every key at once.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
        dropout_prob (float): Probability of dropping out a state or action.
    """

    # Configurable attributes
    dropout_prob: float = Field(..., description="Probability of dropping out a state or action.")

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Return ``data``, with the ``apply_to`` entries zeroed if dropout fires."""
        # Eval mode and negative probabilities both disable dropout entirely.
        if not self.training or self.dropout_prob < 0:
            return data

        # The random draw is only made for a non-negligible probability.
        drop = self.dropout_prob > 1e-9 and random.random() < self.dropout_prob
        if not drop:
            return data

        for name in self.apply_to:
            original = data[name]
            assert isinstance(original, torch.Tensor)
            data[name] = torch.zeros_like(original)
        return data
class StateActionSinCosTransform(ModalityTransform):
    """
    Replace states or actions by their sine and cosine.

    Each tensor is mapped to ``[sin(x), cos(x)]`` concatenated along the last
    dimension, which therefore doubles in size.

    Args:
        apply_to (list[str]): The keys in the modality to load and transform.
    """

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Return ``data`` with every ``apply_to`` entry sin-cos encoded."""
        for name in self.apply_to:
            angles = data[name]
            assert isinstance(angles, torch.Tensor)
            encoded = (torch.sin(angles), torch.cos(angles))
            data[name] = torch.cat(encoded, dim=-1)
        return data
|