import collections
import collections.abc
import importlib
import re
import warnings
from abc import abstractmethod
from functools import cached_property
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar

import numpy as np
import PIL.Image
import roma
import torch
import torchvision.transforms.v2
import transformers
import yaml

try:
    from .hf_compat import Configurable, Template
except Exception:
    # Fallback used when the package-relative import is not possible (e.g. module imported
    # outside its package): load the same helpers from `src.hf_compat` instead.
    _hf_compat = importlib.import_module("src.hf_compat")
    Configurable = _hf_compat.Configurable
    Template = _hf_compat.Template

from .common_vlarm import (
    Normalization,
    ResizeMode,
    RoboticsControlPlan,
    RoboticsInput,
    RoboticsOutput,
    RoboticsTarget,
    RotationFormat,
    expand_dims,
)
from .configuration_vlarm import (
    ControlDataIOConfig,
    ControlTokenizerConfig,
    EmptyTokenizerConfig,
    ImageSizeConfig,
    PaliGemmaProcessorConfig,
    RegressionProcessorConfig,
    VLAMProcessorConfig,
    VLARMProcessorConfig,
    VLMProcessorConfig,
)

ControlTokenizerConfigT = TypeVar('ControlTokenizerConfigT', bound=ControlTokenizerConfig)


class ControlTokenizer(Configurable[ControlTokenizerConfigT], Template[ControlTokenizerConfigT]):
    """Abstract base for components that turn control information into prompt text."""

    @abstractmethod
    def __call__(self, *args, **kwargs) -> str:
        """Given GT actions and possibly other information, output text control. Gets appended to the prompt"""


class EmptyTokenizer(ControlTokenizer[EmptyTokenizerConfig]):
    """
    Control tokenizer that produces no control text.

    `__call__` ignores all of its arguments and always returns an empty string, so nothing is
    appended to the prompt. The text tokenizer is stored as `self.tokenizer`.
    """

    def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None:
        super().__init__(config)
        self.tokenizer = tokenizer

    def __call__(self, *_) -> str:
        return ''


def np_unique(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute unique elements in data and corresponding indices, ordered by first occurrence.

    np.unique returns the values in a sorted order, even if the source is not sorted.
    Thus, if you simply run np.unique on unsorted data, the indices you will get will be invalid.

    Returns:
        unique_ids: unique values of `data`, in order of first occurrence
        indices_of_first_occurence: index into `data` of each unique value's first occurrence
        inverse_indices: for every element of `data`, the position of its value in `unique_ids`
        counts: number of occurrences of each unique value
    """
    (_, indices, inverse) = np.unique(data, return_index=True, return_inverse=True)
    # `indices[inverse]` replaces every element by the position of its value's first occurrence;
    # running np.unique on those positions orders the groups by first occurrence instead of value.
    (_, indices_of_first_occurence, inverse_indices, counts) = np.unique(
        indices[inverse], return_index=True, return_inverse=True, return_counts=True
    )
    unique_ids = data[indices_of_first_occurence]
    return unique_ids, indices_of_first_occurence, inverse_indices, counts


def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor:
    """
    Args:
        angles: Euler angles in radians in the format 'xyz', shape [..., 3]
    Returns:
        torch.Tensor of shape [..., 3, 3] containing rotation matrices
    """
    return roma.euler_to_rotmat(convention='xyz', angles=angles, degrees=False)


def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor:
    """
    Args:
        angles: Euler angles in radians in the format 'xyz', shape [..., 3]
    Returns:
        torch.Tensor of shape [..., 4] containing unit quaternions
    """
    return roma.euler_to_unitquat(convention='xyz', angles=angles, degrees=False, normalize=True)


def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor:
    """
    Args:
        quaternion: Unnormalized quaternion, torch.Tensor of shape [..., 4]
        eps: Small constant to prevent division by zero
    Returns:
        torch.Tensor of shape [..., 4] of unit quaternions
    """
    return quaternion / (quaternion.norm(dim=-1, keepdim=True) + eps)


def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Args:
        quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
    Returns:
        torch.Tensor of shape [..., 3] containing Euler angles in radians in the format 'xyz'
    """
    unit_quat = normalize_quaternion(quaternion)
    euler = roma.unitquat_to_euler(convention='xyz', quat=unit_quat, as_tuple=False, degrees=False)
    return euler


def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Args:
        quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
    Returns:
        torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3)
    """
    unit_quat = normalize_quaternion(quaternion)
    rotmat = roma.unitquat_to_rotmat(unit_quat)
    return rotmat


def is_quaternion(quaternion: torch.Tensor) -> bool:
    """Return True if the last dimension has size 4 (shape check only)."""
    return quaternion.shape[-1] == 4


def is_unit_quaternion(
    quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = 'none'
) -> torch.Tensor | bool:
    """
    Check if a quaternion is normalized or not.

    Args:
        quaternion: torch.Tensor of shape [..., 4]
        epsilon: Absolute tolerance when comparing the norm to 1
        reduction:
            'none' - returns torch.Tensor of bools of shape [..., 1] (batch shape with a kept last dim)
            'all' - returns a bool, True if ALL quaternions in the batch are normalized
    Returns:
        torch.Tensor of bools or bool, depending on `reduction`
    Raises:
        ValueError: for an unknown `reduction`
    """
    assert is_quaternion(quaternion)
    is_norm = torch.isclose(
        quaternion.norm(dim=-1, keepdim=True),
        torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device),
        atol=epsilon,
    )
    if reduction == 'none':
        return is_norm
    if reduction == 'all':
        return bool(torch.all(is_norm).item())
    raise ValueError(f'Unknown reduction mode {reduction}')


def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Flip quaternions so they cover only a half the space.

    If the q_w is negative, flip the quaternion. If q_w is 0, then choose such that the first
    non-zero component is positive. Note that geometrically, this doesn't correspond to a single
    hemisphere of the unit sphere. The scalar part q_w is read from the LAST component, i.e. a
    scalar-last (x, y, z, w) layout.
    Follows https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat
    """
    assert is_quaternion(quaternion), quaternion.shape
    with torch.no_grad():
        is_zero = quaternion == 0
        w_zero = is_zero[..., -1:]
        flip_condition = (
            (quaternion[..., -1:] < 0)
            | (w_zero & (quaternion[..., 0:1] < 0))
            | (w_zero & is_zero[..., 0:1] & (quaternion[..., 1:2] < 0))
            | (w_zero & is_zero[..., 0:1] & is_zero[..., 1:2] & (quaternion[..., 2:3] < 0))
        )
    quaternion = torch.where(flip_condition, -quaternion, quaternion)
    return quaternion


def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor:
    """Convert any rotmat input to [..., 3, 3] shape"""
    if rotmat.shape[-1] == 9:
        return rotmat.reshape(*rotmat.shape[:-1], 3, 3)
    if rotmat.shape[-2:] == torch.Size([3, 3]):
        return rotmat
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")


def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Args:
        rotmat: Batch of rotation matrices, shape [..., 3, 3] or [..., 9]
    Returns:
        Batch of Euler angles in radians ('xyz'), shape [..., 3]
    """
    rotmat = rotmat_as_3x3(rotmat)
    return roma.rotmat_to_euler(convention='xyz', rotmat=rotmat, as_tuple=False, degrees=False)


def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Args:
        rotmat: Batch of rotation matrices, shape [..., 3, 3] or [..., 9]
    Returns:
        Batch of unit quaternions, shape [..., 4]
    """
    rotmat = rotmat_as_3x3(rotmat)
    return roma.rotmat_to_unitquat(rotmat)


def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor:
    r"""
    Maps 9D input vectors onto SO(3) via symmetric orthogonalization.

    - Let SVD(M) = U \Sigma V^T
    - Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T
    - det(UV^T) ensures that det(SVD+(M)) = 1
    - The return value is a rotation matrix (orthonormal) with the least-squares distance to M

    Args:
        x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3]
    Returns:
        torch.Tensor with the same shape and dtype as x, where each inner 3x3 matrix is in SO(3)
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore', message='In CPU autocast, but the target dtype is not supported. Disabling autocast.'
        )
        # Compute in float32 regardless of any surrounding mixed-precision context.
        with torch.autocast(device_type=x.device.type, dtype=torch.float32):
            # reshape (not view) so non-contiguous inputs are accepted as well
            matrices = x.reshape(-1, 3, 3).to(dtype=torch.float32)
            # torch.linalg.svd returns V^T directly (equivalent to the deprecated torch.svd's V.mT)
            (u, _, vt) = torch.linalg.svd(matrices, full_matrices=False)
            det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1)
            # U diag(1, 1, det) V^T == U @ (V^T with its last row scaled by det)
            diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1)
            result = torch.matmul(u, diag_vt)
            result = result.reshape(x.shape)
            result = result.to(dtype=x.dtype)
    return result


def is_rotmat_3x3(rotmat: torch.Tensor) -> bool:
    """Return True if the trailing dimensions are [3, 3] (shape check only)."""
    return rotmat.shape[-2:] == torch.Size([3, 3])


def is_rotmat_9(rotmat: torch.Tensor) -> bool:
    """Return True if the last dimension has size 9 (shape check only)."""
    return rotmat.shape[-1] == 9


def is_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Checks if the tensor shape matches that of a rotmat.

    However, it's not guaranteed the data is a valid rotmat. `is_orthonormal_rotmat` performs
    this additional check.
    NOTE: This might incorrectly return True if the underlying data is euler angles and
    accidentally `rotmat.shape[-2:] == [3, 3]`. This would happen very rarely, but use with caution
    """
    return is_rotmat_3x3(rotmat) or is_rotmat_9(rotmat)


def is_rotmat_orthonormal(
    rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = 'none'
) -> torch.Tensor | bool:
    """
    Check if a rotation matrix is orthonormal or not.

    Args:
        rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9]
        epsilon: Tolerance for numerical comparisons. Bigger values allow for more freedom.
            Generally, anything smaller than 1e-6 might incorrectly detect some orthonormal
            matrices as not orthonormal
        reduction:
            'none' - returns torch.Tensor of bools with the same batch shape
            'all' - returns a bool, True if ALL matrices in the batch are orthonormal
    Returns:
        torch.Tensor with the same batch shape or bool
    Raises:
        ValueError: for an unknown `reduction`
    """
    assert is_rotmat(rotmat)
    rotmat = rotmat_as_3x3(rotmat.to(dtype=torch.float32))
    is_orthonormal = roma.is_orthonormal_matrix(rotmat, epsilon=epsilon)
    if reduction == 'none':
        return is_orthonormal
    if reduction == 'all':
        return bool(torch.all(is_orthonormal).item())
    raise ValueError(f'Unknown reduction mode {reduction}')


def is_orthonormal_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Checks if the tensor shape matches that of a rotmat.

    If the last dimensions of shape are 3x3, also checks if the data is a valid rotmat.
    This is to avoid a possible clash with euler angles when accidentally `rotmat.shape[-2:] == [3, 3]`
    """
    return is_rotmat_9(rotmat) or (
        is_rotmat_3x3(rotmat) and is_rotmat_orthonormal(rotmat, epsilon=0.02, reduction='all')
    )


def is_euler(euler: torch.Tensor) -> bool:
    """Return True if the last dimension has size 3 and the data is not a 3x3 rotation matrix."""
    return euler.shape[-1] == 3 and not is_orthonormal_rotmat(euler)


def rotation_format_from_tensor(rotation) -> RotationFormat:
    """
    Infer the rotation encoding from the tensor shape (and, for [..., 3, 3], orthonormality).

    Raises:
        ValueError: if no supported format matches
    """
    if is_quaternion(rotation):
        return RotationFormat.QUATERNION
    if is_orthonormal_rotmat(rotation):
        return RotationFormat.ROTMAT
    if is_euler(rotation):
        return RotationFormat.EULER
    raise ValueError(f'Tensor shape {rotation.shape} is not a valid rotation format')


def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor:
    """Convert any rotmat input to [..., 9] shape"""
    if is_rotmat_9(rotmat):
        return rotmat
    if is_rotmat_3x3(rotmat):
        return rotmat.reshape(*rotmat.shape[:-2], 9)
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a flattened 9D rotation matrix")


def convert_rotation(
    rotation: torch.Tensor | np.ndarray,
    output_format: RotationFormat,
    autonorm: bool = True,
    half_cover: bool = True,
) -> torch.Tensor | np.ndarray:
    """
    Convert a batch of rotations to `output_format`; the input format is detected from the shape.

    Args:
        rotation: quaternions [..., 4], rotation matrices [..., 9] / [..., 3, 3], or 'xyz' Euler
            angles [..., 3], as a torch.Tensor or np.ndarray
        output_format: requested encoding. Rotation matrices are always returned as [..., 9]
        autonorm: if True, re-normalize quaternions that are not unit-norm and re-orthogonalize
            rotation matrices that fail `is_rotmat_orthonormal(epsilon=0.01)`
        half_cover: if True and the output is a quaternion, apply `quaternion_half_cover`
    Returns:
        Converted rotations; a np.ndarray if the input was a np.ndarray, otherwise a torch.Tensor
    Raises:
        ValueError: if the input encoding cannot be identified
        NotImplementedError: for an unsupported `output_format`
    """
    is_np = isinstance(rotation, np.ndarray)
    if is_np:
        rotation = torch.from_numpy(rotation)
    if is_quaternion(rotation):
        if autonorm and not is_unit_quaternion(rotation, reduction='all'):
            rotation = normalize_quaternion(rotation)
        if output_format == RotationFormat.QUATERNION:
            output = rotation
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(quaternion_to_rotmat(rotation))
        elif output_format == RotationFormat.EULER:
            output = quaternion_to_euler(rotation)
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    elif is_orthonormal_rotmat(rotation):
        if autonorm and not is_rotmat_orthonormal(rotation, epsilon=0.01, reduction='all'):
            rotation = symmetric_orthogonalization(rotation)
        if output_format == RotationFormat.QUATERNION:
            output = rotmat_to_unit_quaternion(rotation)
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(rotation)
        elif output_format == RotationFormat.EULER:
            output = rotmat_to_euler(rotation)
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    elif is_euler(rotation):
        if output_format == RotationFormat.QUATERNION:
            output = euler_to_unit_quaternion(rotation)
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(euler_to_rotmat(rotation))
        elif output_format == RotationFormat.EULER:
            output = rotation
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    else:
        raise ValueError(f'Unknown rotation encoding with shape {rotation.shape}')
    if output_format == RotationFormat.QUATERNION and half_cover:
        output = quaternion_half_cover(output)
    if is_np:
        output = output.numpy()
    return output


def delta_to_world_rotations(rotation_sequence: torch.Tensor, reference_frame: torch.Tensor) -> torch.Tensor:
    """
    Transform a sequence of rotations encoded w.r.t. `reference_frame` to encoding w.r.t. WORLD frame.

    `reference_frame` itself is provided w.r.t. the WORLD frame.
    Ex:
        Sequence of points (rotations): R_1, R_2, R_3, R_4
        `rotation_sequence` contains the rotations: R_01, R_02, R_03, R_04, where R_0 is the reference
        frame and R_01 is the pose of R1 frame in reference frame, i.e. R_10 converts from reference
        frame to R1 frame
        Output: R_W1, R_W2, R_W3, R_W4, where W is the world frame, i.e. the rotation poses of
        R_1, R_2, R_3, R_4 expressed in world frame

    Args:
        rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4],
            containing either rotation matrices (R_01, R_02, R_03, R_04, ...) or quaternions
        reference_frame: torch.Tensor, shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] and the SAME
            number of BATCH dims as `rotation_sequence`. The reference frame, provided w.r.t. WORLD
            coordinate frame R_W0
    Returns:
        torch.Tensor containing transformed rotations (R_W1, R_W2, R_W3, R_W4, ...) in the same
        encoding as `rotation_sequence`; rotation matrices are returned as [..., S, 9]
    Raises:
        ValueError: if the two inputs do not have the same number of dimensions
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    reference_frame = rotmat_as_3x3(convert_rotation(reference_frame, RotationFormat.ROTMAT))
    rotation_sequence = rotmat_as_3x3(convert_rotation(rotation_sequence, RotationFormat.ROTMAT))
    if reference_frame.ndim != rotation_sequence.ndim:
        raise ValueError(
            f'Cannot broadcast reference_frame of shape {reference_frame.shape} to rotation_sequence of shape {rotation_sequence.shape}. Provide tensors with the same number of batch dimensions'
        )
    R_W0 = reference_frame
    world_rotations = torch.matmul(R_W0, rotation_sequence)
    world_rotations = world_rotations.view(*world_rotations.shape[:-2], 9)
    world_rotations = convert_rotation(world_rotations, rotation_format)
    return world_rotations


def delta_to_relative_rotations(rotation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Transform a sequence of rotations each encoded w.r.t. the PREVIOUS frame into rotations
    w.r.t. the 0-th frame preceding the sequence.

    Ex:
        `rotation_sequence` contains the rotations: R_01, R_12, R_23, R_34, where R0 is the base
        frame, implicitly encoded in R_01 and R_10 converts from R0 frame to R1 frame
        Output: R_01, R_02, R_03, R_04

    Args:
        rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4],
            containing either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions
    Returns:
        torch.Tensor containing transformed rotations (R_01, R_02, R_03, R_04, ...) in the same
        encoding as the input; rotation matrices are returned as [..., S, 9]

    TODO: The composition still loops over S in Python (linear, no longer quadratic)
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    quaternions = convert_rotation(rotation_sequence, RotationFormat.QUATERNION)
    # Running left-to-right product R_0i = R_0(i-1) * R_(i-1)i. This performs the same left fold
    # as roma.quat_composition over the first i elements, but reuses the previous prefix so the
    # total work is O(S) quaternion products instead of O(S^2).
    accumulated = quaternions[..., :1, :]
    composed = [accumulated]
    for i in range(1, quaternions.shape[-2]):
        accumulated = roma.quat_product(accumulated, quaternions[..., i : i + 1, :])
        composed.append(accumulated)
    delta_rotations = torch.cat(composed, dim=-2)
    delta_rotations = convert_rotation(delta_rotations, rotation_format)
    return delta_rotations


def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray:
    """Make sure image is of type np.ndarray and HWC format"""
    if isinstance(image, PIL.Image.Image):
        image = np.asarray(image)
    assert isinstance(image, np.ndarray), type(image)
    assert image.ndim in [2, 3], image.shape
    if image.ndim == 3:
        assert image.shape[-1] <= 4, image.shape
    return image


def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]:
    """Return (height, width) of a PIL image or an HW/HWC np.ndarray."""
    if isinstance(image, np.ndarray):
        (height, width) = image.shape[:2]
    else:
        (width, height) = image.size
    return height, width


def pad_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Pad image with a constant border around the height/width so it exactly reaches `target_size`.

    The border is split as evenly as possible between opposite sides; when the size difference
    is odd, the extra pixel goes to the right/bottom.

    Args:
        image: PIL image or HW/HWC np.ndarray
        target_size: dict with 'height' and 'width', each >= the current size
        pad_value: fill value; a scalar or a per-channel tuple (broadcast over the channel axis
            for np.ndarray inputs)
    Returns:
        Padded image of the same type as the input
    """
    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
    (height, width) = hw_from_image(image)
    (target_width, target_height) = (target_size['width'], target_size['height'])
    if width == target_width and height == target_height:
        return image
    assert target_width >= width, f"Can't pad image of width {width} to {target_width}"
    assert target_height >= height, f"Can't pad image of height {height} to {target_height}"
    left = (target_width - width) // 2
    right = target_width - width - left
    top = (target_height - height) // 2
    bottom = target_height - height - top
    if isinstance(image, np.ndarray):
        padded = np.empty((target_height, target_width) + image.shape[2:], dtype=image.dtype)
        # Assignment broadcasts a scalar or per-channel `pad_value` over the canvas
        # (np.pad's `constant_values` cannot express a per-channel fill).
        padded[...] = pad_value
        padded[top : top + height, left : left + width, ...] = image
        image = padded
    else:
        # torchvision's 4-element padding order is (left, top, right, bottom)
        padding = (left, top, right, bottom)
        image = torchvision.transforms.v2.functional.pad(
            image, padding=padding, fill=pad_value, padding_mode='constant'
        )
    return image


def pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """Pad image to a target aspect ratio (width / height), growing only one dimension."""
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if target_wh_ratio >= wh_ratio:
        pad_size = {'width': round(height * target_wh_ratio), 'height': height}
    else:
        pad_size = {'width': width, 'height': round(width / target_wh_ratio)}
    image = pad_image(image, target_size=pad_size, pad_value=pad_value)
    return image


def crop_image(
    image: np.ndarray | PIL.Image.Image,
    start_height: int,
    start_width: int,
    target_height: int,
    target_width: int,
) -> np.ndarray | PIL.Image.Image:
    """
    Crop a `target_height` x `target_width` window whose top-left corner is (start_height, start_width).

    Coordinates and sizes are rounded to integers. Returns the same type as the input.
    """
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = hw_from_image(image)
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    (start_height, start_width) = (round(start_height), round(start_width))
    (target_height, target_width) = (round(target_height), round(target_width))
    np_image = np_image[
        start_height : start_height + target_height, start_width : start_width + target_width, ...
    ]
    image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
    return image


def crop_image_center(
    image: np.ndarray | PIL.Image.Image, target_size: dict[str, int]
) -> np.ndarray | PIL.Image.Image:
    """Center-crop image to `target_size` ('height'/'width'). Returns the same type as the input."""
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = np_image.shape[:2]
    (target_height, target_width) = (target_size['height'], target_size['width'])
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    top = (height - target_height) // 2
    left = (width - target_width) // 2
    np_image = crop_image(np_image, top, left, target_height, target_width)
    image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
    return image


def crop_image_to_ratio(
    image: PIL.Image.Image | np.ndarray, target_wh_ratio: float
) -> PIL.Image.Image | np.ndarray:
    """Center-crop image to a target aspect ratio (width / height), shrinking only one dimension."""
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if target_wh_ratio >= wh_ratio:
        crop_size = {'width': width, 'height': round(width / target_wh_ratio)}
    else:
        crop_size = {'width': round(height * target_wh_ratio), 'height': height}
    image = crop_image_center(image, target_size=crop_size)
    return image


def crop_and_pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    mode: ResizeMode | str,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Crop and/or pad an image to a target aspect ratio depending on the mode.

    It's expected that the source image and target size have different aspect ratios; if the
    ratios already match (within 1% relative tolerance) the image is returned unchanged.

    Args:
        image: The image to crop and pad.
        target_wh_ratio: The target aspect ratio (width / height).
        mode: SMART crops toward a square or 4:3/3:4 intermediate depending on the source and
            target shapes and then pads; PAD only pads; CROP only crops.
        pad_value: Fill value used for padding.
    Raises:
        ValueError: for an unsupported mode
    """
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001):
        return image
    if mode == ResizeMode.SMART:
        aspect_ratio = max(width, height) / min(width, height)
        target_ratio = max(target_wh_ratio, 1 / target_wh_ratio)
        # True when source and target differ in orientation (landscape vs. portrait). The
        # parentheses matter: `a >= 1.0 != b` would be a chained comparison.
        orientation_mismatch = (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0)
        if aspect_ratio == 1:
            # Square source: crop to 4:3 (or 3:4) first when the target is at least that elongated
            if target_ratio >= 4 / 3 - 0.01:
                crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
                image = crop_image_to_ratio(image, crop_wh_ratio)
        elif aspect_ratio <= 4 / 3 + 0.01:
            # Mildly elongated source: crop to a square only when the orientation flips
            if orientation_mismatch:
                image = crop_image_to_ratio(image, 1.0)
        elif orientation_mismatch:
            image = crop_image_to_ratio(image, 1.0)
        elif target_ratio >= 4 / 3 + 0.01:
            pass
        else:
            crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
            image = crop_image_to_ratio(image, crop_wh_ratio)
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.PAD:
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.CROP:
        image = crop_image_to_ratio(image, target_wh_ratio)
    else:
        raise ValueError(f'Mode {mode} not supported')
    return image


def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool:
    """Return True for single-channel PIL modes or HW / HW1 arrays."""
    if isinstance(image, PIL.Image.Image):
        return image.mode in ['1', 'L', 'LA', 'La', 'P', 'PA', 'F', 'I', 'I;16', 'I;16L', 'I;16B', 'I;16N']
    if isinstance(image, np.ndarray):
        return image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1)
    raise ValueError(f'Unsupported image type: {type(image)}')


def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool:
    """Return True for non-empty uint8/bool images whose maximum value is exactly 1."""
    image = np.asarray(image)
    # `image.size > 0` guards np.max, which raises on empty arrays
    return image.dtype in [np.uint8, np.bool_] and image.size > 0 and np.max(image) == 1


def resize_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    mode: ResizeMode | str,
    resample: PIL.Image.Resampling | str = 'auto',
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Resize image to `target_size` ('height'/'width').

    Args:
        image: PIL image or np.ndarray
        target_size: dict with 'height' and 'width'
        mode: NAIVE stretches directly; SMART first crops/pads to the target aspect ratio
        resample: PIL resampling filter, or 'auto' (BILINEAR for single-channel, LANCZOS otherwise)
        pad_value: Fill value used by SMART padding
    Returns:
        Resized image. Binary masks (max value 1) are resized as 0/255 and re-thresholded, and
        are returned as boolean np.ndarray.
    Raises:
        ValueError: single-channel images with a resample filter other than bilinear/bicubic
        NotImplementedError: modes other than NAIVE / SMART
    """
    (target_width, target_height) = (target_size['width'], target_size['height'])
    (height, width) = hw_from_image(image)
    if height == target_height and width == target_width:
        return image
    if resample == 'auto':
        if is_single_channel_image(image):
            resample = PIL.Image.Resampling.BILINEAR
        else:
            resample = PIL.Image.Resampling.LANCZOS
    else:
        assert isinstance(resample, PIL.Image.Resampling), resample
        if is_single_channel_image(image) and resample not in [
            PIL.Image.Resampling.BILINEAR,
            PIL.Image.Resampling.BICUBIC,
        ]:
            raise ValueError(
                f'Single channel images must be resized with bilinear or bicubic, but got {resample}'
            )
    if is_bin_mask := is_binary_mask(image):
        # Scale {0, 1} to {0, 255} so interpolation produces usable intermediate values
        image = np.asarray(image).astype(np.uint8) * 255
    if mode == ResizeMode.SMART:
        image = crop_and_pad_image_to_ratio(
            image, target_wh_ratio=target_width / target_height, mode=mode, pad_value=pad_value
        )
    pil_image = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image
    if mode in [ResizeMode.NAIVE, ResizeMode.SMART]:
        pil_image = pil_image.resize((target_width, target_height), resample=resample)
    else:
        raise NotImplementedError(f'Mode {mode} not supported')
    image = np.asarray(pil_image) if isinstance(image, np.ndarray) else pil_image
    if is_bin_mask:
        image = image.astype(np.uint8) > 127
    return image


def is_global_norm(norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list]) -> bool:
    """Return true if norm is NONE or global for all datasets"""
    return norm == Normalization.NONE or isinstance(norm, collections.abc.Mapping)


def is_mean_norm(norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list]) -> bool:
    """Return true if norm is based on mean and std"""
    return norm == Normalization.MEAN or (
        isinstance(norm, collections.abc.Mapping) and set(norm.keys()) == {'mean', 'std'}
    )


def _broadcast_shapes(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Broadcast shapes for normalization.

    Args:
        value: torch.Tensor of shape [..., num_components]. The entire shape might be:
            - [num_components]: `value` has no batch dimension
            - [num_datasets, num_components]: `value` contains entries *aligned* with the dataset
              bounds contained in `low` and `high`
            - [num_datasets, ..., num_components]: `value` contains entries *aligned* with the
              dataset bounds contained in `low` and `high`
            - [..., num_components]: `value` contains multiple dimensions. In this case, `low` and
              `high` must be for a single dataset, i.e. `num_datasets = 1`
        low: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when
            `low` contains normalization bounds for a single dataset
        high: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when
            `high` contains normalization bounds for a single dataset
    Returns:
        Tuple of torch.Tensors (low, high), where `low` and `high` have the same number of
        dimensions as `value`
    """
    assert low.ndim == high.ndim == 2, f'{low.shape} != {high.shape} or ndim != 2'
    assert value.shape[-1] == low.shape[-1] == high.shape[-1], f'{value.shape} != {low.shape} / {high.shape}'
    if value.ndim == low.ndim == high.ndim:
        return low, high
    if value.ndim < low.ndim:
        assert low.ndim == high.ndim == 2, f'{low.shape}, {high.shape}'
        assert low.shape[0] == high.shape[0] == 1, f'{low.shape}, {high.shape}'
        (low, high) = (low.view(-1), high.view(-1))
        return low, high
    if low.shape[0] == high.shape[0] == 1:
        low = expand_dims(low.view(-1), ndim=value.ndim, order=[-1, 1])
        high = expand_dims(high.view(-1), ndim=value.ndim, order=[-1, 1])
    else:
        assert value.shape[0] == low.shape[0] == high.shape[0], f'{value.shape} != {low.shape} / {high.shape}'
        low = expand_dims(low, ndim=value.ndim, order=[1, -1, 1])
        high = expand_dims(high, ndim=value.ndim, order=[1, -1, 1])
    return low, high


def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Inverse of `normalize_by_moments`: value * (std + 1e-8) + mean."""
    (mean, std) = _broadcast_shapes(value, mean, std)
    (mean, std) = (mean.to(device=value.device), std.to(device=value.device))
    return value * (std + 1e-08) + mean


def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Map values from [-1, 1] back to [low, high] (inverse of `normalize_by_bounds` without clamping)."""
    (low, high) = _broadcast_shapes(value, low, high)
    (low, high) = (low.to(device=value.device), high.to(device=value.device))
    return 0.5 * (value + 1) * (high - low) + low


def normalize_gripper_by_bounds(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True
) -> torch.Tensor:
    """
    If binary, normalize to [0, 1], otherwise normalize to [-1, 1]

    Results are clamped to the output range; `high - low` is clamped to >= 1e-8 to avoid
    division by zero.
    """
    (low, high) = _broadcast_shapes(value, low, high)
    (low, high) = (low.to(device=value.device), high.to(device=value.device))
    if binary:
        return torch.clamp((value - low) / torch.clamp(high - low, min=1e-08), min=0.0, max=1.0)
    return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0)


def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Standardize: (value - mean) / (std + 1e-8)."""
    (mean, std) = _broadcast_shapes(value, mean, std)
    (mean, std) = (mean.to(device=value.device), std.to(device=value.device))
    return (value - mean) / (std + 1e-08)


def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Map [low, high] to [-1, 1], clamping values outside the bounds."""
    (low, high) = _broadcast_shapes(value, low, high)
    (low, high) = (low.to(device=value.device), high.to(device=value.device))
    return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0)


def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray:
    """
    Flip the gripper open/closed direction within [low, high].

    For ranges symmetric around zero (low < 0) the value is negated; otherwise it is mirrored
    as `high - value`. The input is clipped to [low, high].
    """
    if low < 0.0:
        return np.clip(-gripper, low, high)
    return high - np.clip(gripper, low, high)


GRIPPER_BOUNDS = {
    'austin_buds_dataset': (0.0, 0.08),
    'austin_sailor_dataset': (0.0, 0.08),
    'austin_sirius_dataset': (0.0, 0.08),
    'bc_z': (0.0, 1.0),
    'berkeley_autolab_ur5': (0.0, 1.0),
    'berkeley_cable_routing': (0.0, 1.0),
    'berkeley_fanuc_manipulation': (0.0, 1.0),
    'bridge': (0.0, 1.0),
    'bridge_orig': (0.0, 1.0),
    'bridge_action_lang': (0.0, 1.0),
    'cmu_stretch': (-3.0, 3.0),
    'dlr_edan_shared_control': (0.0, 1.0),
    'droid': (0.0, 1.0),
    'fmb': (0.0, 1.0),
    'fractal20220817_data': (0.0, 1.0),
    'furniture_bench_dataset': (0.0, 0.08),
    'iamlab_cmu_pickup_insert': (0.0, 1.0),
    'jaco_play': (0.0, 1.4),
    'kuka': (0.0, 1.0),
    'language_table': (0.0, 1.0),
    'nyu_franka_play_dataset': (0.0, 1.0),
    'roboset': (0.0, 1.0),
    'roboturk': (0.0, 1.0),
    'stanford_hydra_dataset': (0.0, 0.08),
    'taco_play': (0.0, 0.08),
    'toto': (0.0, 1.0),
    'ucsd_kitchen_dataset': (0.0, 1.0),
    'utaustin_mutex': (0.0, 0.08),
    'viola': (0.0, 0.08),
}

# Datasets whose raw gripper signal is [open -> closed] and must be inverted before normalization.
# Every other dataset in GRIPPER_BOUNDS is normalized directly (including the constant / missing ones:
# berkeley_cable_routing, dlr_edan_shared_control, language_table, nyu_franka_play_dataset, toto,
# ucsd_kitchen_dataset).
_INVERTED_GRIPPER_DATASETS = frozenset(
    {
        'bc_z',
        'berkeley_autolab_ur5',
        'berkeley_fanuc_manipulation',
        'droid',
        'fmb',
        'fractal20220817_data',
        'jaco_play',
        'kuka',
        'roboset',
    }
)


def preprocess_gripper_observation(
    gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True
) -> np.ndarray:
    """
    Preprocess gripper observation depending on dataset.

    Input is the raw gripper observation from the dataset or from the robot and output is
    normalized continuous value.
    - if `binary`, output is in [0, 1], with 0 = closed and 1 = open.
    - otherwise, output is in [-1, 1], with -1 = closed and 1 = open.

    `gripper` must be 2-D (e.g. [T, 1]): the bounds are built with `gripper.shape` and
    `_broadcast_shapes` requires 2-D bounds.

    Dataset-specific gripper observations:
        austin_buds_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        austin_sailor_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        austin_sirius_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        bc_z: continuous; [0=open; 1=closed]
        berkeley_autolab_ur5: binary; [0=open; 1=closed]
        berkeley_cable_routing: constant (closed)
        berkeley_fanuc_manipulation: binary; [0=open; 1=closed]
        bridge: continuous; ~[0=closed; 1=open]
        bridge_orig: continuous; ~[0=closed; 1=open]
        bridge_action_lang: normalized like bridge (bounds (0, 1), not inverted)
        cmu_stretch: continuous; [-3=closed; 3=open]
        dlr_edan_shared_control: missing
        droid: continuous; [0=open, 1=closed]
        fmb: binary; [0=open; 1=closed]
        fractal20220817_data: continuous; [0=open; 1=closed]
        furniture_bench_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        iamlab_cmu_pickup_insert: binary; [0=closed; 1=open]
        jaco_play: continuous; [0=open; 1.4=closed]
        kuka: binary; [0=open; 1=closed]
        language_table: constant (no gripper)
        nyu_franka_play_dataset: missing
        roboset: continuous; [0=open, 1=closed]
        roboturk: continuous; [0=closed, 0.04=open]
            (NOTE(review): GRIPPER_BOUNDS uses (0.0, 1.0) for roboturk — confirm which is right)
        stanford_hydra_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        taco_play: continuous; ~[0=closed; 0.08=open] (franka gripper)
        toto: constant (closed)
        ucsd_kitchen_dataset: missing
        utaustin_mutex: continuous; ~[0=closed; 0.08=open] (franka gripper)
        viola: continuous; ~[0=closed; 0.08=open] (franka gripper)

    Raises:
        NotImplementedError: for datasets not listed in GRIPPER_BOUNDS
    """
    if isinstance(dataset_name, np.ndarray):
        assert np.unique(dataset_name).size == 1, dataset_name
        dataset_name = str(dataset_name[0])
    if dataset_name not in GRIPPER_BOUNDS:
        raise NotImplementedError(f'Unknown dataset: {dataset_name}')
    (low, high) = GRIPPER_BOUNDS[dataset_name]
    if dataset_name in _INVERTED_GRIPPER_DATASETS:
        raw = invert_gripper(gripper, low=low, high=high)
    else:
        raw = gripper
    gripper = normalize_gripper_by_bounds(
        torch.from_numpy(raw),
        low=torch.full(gripper.shape, low, dtype=torch.float32),
        high=torch.full(gripper.shape, high, dtype=torch.float32),
        binary=binary,
    ).numpy()
    return gripper


def rotation_norm_bounds(
    rotation_norm: Normalization,
    rotation_format: RotationFormat,
    stats: Dict[str, Dict[str, Dict[str, List[float]]]],
    dataset_names: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Build per-dataset {'low', 'high'} rotation normalization bounds.

    Only Euler rotations can be normalized from dataset statistics (`stats[name]['euler']`,
    min/max or q01/q99); for every other case `rotation_norm` must be NONE and the bounds are
    [-1, 1] for each of `dataset_names`, sized 3 / 4 / 9 for Euler / quaternion / rotmat.

    Raises:
        NotImplementedError: for an unsupported Euler normalization type
    """
    if rotation_format == RotationFormat.EULER and rotation_norm != Normalization.NONE:
        if rotation_norm == Normalization.BOUNDS:
            results = {
                dataset_name: {
                    'low': torch.tensor(dataset_stats['euler']['min']),
                    'high': torch.tensor(dataset_stats['euler']['max']),
                }
                for (dataset_name, dataset_stats) in stats.items()
            }
        elif rotation_norm == Normalization.BOUNDS_Q99:
            results = {
                dataset_name: {
                    'low': torch.tensor(dataset_stats['euler']['q01']),
                    'high': torch.tensor(dataset_stats['euler']['q99']),
                }
                for (dataset_name, dataset_stats) in stats.items()
            }
        else:
            raise NotImplementedError(f'Normalization type {rotation_norm} not yet implemented')
    else:
        assert rotation_norm == Normalization.NONE, rotation_norm
        if rotation_format == RotationFormat.EULER:
            rotation_size = 3
        elif rotation_format == RotationFormat.QUATERNION:
            rotation_size = 4
        else:
            rotation_size = 9
        results = {
            dataset_name: {
                'low': -1 * torch.ones(rotation_size, dtype=torch.float32),
                'high': 1 * torch.ones(rotation_size, dtype=torch.float32),
            }
            for dataset_name in dataset_names
        }
    return results


def translation_norm_bounds(
    translation_norm: Normalization | tuple,
    stats: Dict[str, Dict[str, Dict[str, List[float]]]],
    dataset_names: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Build per-dataset translation normalization parameters.

    - Normalization.BOUNDS / BOUNDS_Q99 / MEAN: read from `stats[name]['translation']`, one entry
      per dataset in `stats`, as {'low', 'high'} or {'mean', 'std'}.
    - Normalization.NONE: [-1, 1] bounds for each of `dataset_names`.
    - A mapping with keys {'low', 'high'} or {'mean', 'std'} (3 values each): the same global
      values are used for each of `dataset_names`.

    Raises:
        NotImplementedError: for an unsupported Normalization type
    """
    if isinstance(translation_norm, Normalization) and translation_norm != Normalization.NONE:
        if translation_norm == Normalization.BOUNDS:
            results = {
                dataset_name: {
                    'low': torch.tensor(dataset_stats['translation']['min']),
                    'high': torch.tensor(dataset_stats['translation']['max']),
                }
                for (dataset_name, dataset_stats) in stats.items()
            }
        elif translation_norm == Normalization.BOUNDS_Q99:
            results = {
                dataset_name: {
                    'low': torch.tensor(dataset_stats['translation']['q01']),
                    'high': torch.tensor(dataset_stats['translation']['q99']),
                }
                for (dataset_name, dataset_stats) in stats.items()
            }
        elif translation_norm == Normalization.MEAN:
            results = {
                dataset_name: {
                    'mean': torch.tensor(dataset_stats['translation']['mean']),
                    'std': torch.tensor(dataset_stats['translation']['std']),
                }
                for (dataset_name, dataset_stats) in stats.items()
            }
        else:
            raise NotImplementedError(f'Normalization type {translation_norm} not yet implemented')
    elif isinstance(translation_norm, Normalization) and translation_norm == Normalization.NONE:
        results = {
            dataset_name: {
                'low': -1 * torch.ones(3, dtype=torch.float32),
                'high': 1 * torch.ones(3, dtype=torch.float32),
            }
            for dataset_name in dataset_names
        }
    else:
        assert isinstance(translation_norm, collections.abc.Mapping), type(translation_norm)
        assert all((len(value) == 3 for value in translation_norm.values())), translation_norm
        assert set(translation_norm.keys()) in ({'low', 'high'}, {'mean', 'std'}), translation_norm
        results = {
            dataset_name: {
                key: torch.tensor(value, dtype=torch.float32) for (key, value) in translation_norm.items()
            }
            for dataset_name in dataset_names
        }
    return results


VLMProcessorConfigT = TypeVar('VLMProcessorConfigT', bound=VLMProcessorConfig)


class VLMProcessor(Configurable[VLMProcessorConfigT], Template[VLMProcessorConfigT]):
    """Abstract interface for the vision-language model input processor."""

    @abstractmethod
    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """Convert a chat and per-camera images into model input tensors."""
        ...

    @property
    @abstractmethod
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Text tokenizer used by this processor."""
        pass

    @property
    @abstractmethod
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        """Image size configuration per camera name."""
        pass


VLAMProcessorConfigT = TypeVar('VLAMProcessorConfigT', bound=VLAMProcessorConfig)


class VLAMProcessor(Configurable[VLAMProcessorConfigT], Template[VLAMProcessorConfigT]):
    def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor):
        super().__init__(config)
        self.vlm_processor = vlm_processor
        self.control_tokenizer = EmptyTokenizer(
            config=self.config.control_tokenizer_config, tokenizer=self.tokenizer
        )
        self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {
            'obs_translation': self.obs_translation_norm_bounds,
            'obs_rotation': self.obs_rotation_norm_bounds,
            'translation': self.translation_norm_bounds,
            'rotation': self.rotation_norm_bounds,
            'joints': self.joints_norm_bounds,
        }

    @property
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        return self.vlm_processor.tokenizer

    @property
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        return self.vlm_processor.image_sizes

    @property
    def camera_names(self) -> List[str]:
        return list(self.vlm_processor.image_sizes.keys())

    @property
    def control_io_config(self) -> ControlDataIOConfig:
        return self.config.control_io_config

    @cached_property
    def rotation_components(self) -> int:
        if self.config.rotation_format == RotationFormat.EULER:
            return 3
        if self.config.rotation_format == RotationFormat.QUATERNION:
            return 4
if self.config.rotation_format == RotationFormat.ROTMAT: return 9 raise NotImplementedError(self.config.rotation_format) @abstractmethod def policy_control_plan_from_model_target( self, target: RoboticsTarget, dataset_name: np.ndarray ) -> RoboticsControlPlan: pass @abstractmethod def policy_control_plan_from_model_output( self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor ) -> RoboticsControlPlan: pass def resize_image( self, camera_name: str, image: PIL.Image.Image | np.ndarray ) -> PIL.Image.Image | np.ndarray: return resize_image( image, target_size={ 'width': self.image_sizes[camera_name].width, 'height': self.image_sizes[camera_name].height, }, mode=self.config.image_resize, resample=PIL.Image.Resampling.LANCZOS, ) def preprocess_inputs( self, chat: List[str], images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]], ee_pose_translation: np.ndarray, ee_pose_rotation: np.ndarray, gripper: np.ndarray, joints: np.ndarray, dataset_name: np.ndarray, inference_mode: bool, control_target: Optional[RoboticsTarget] = None, ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: """ Preprocess the inputs for a single example Args: instruction: Language instruction images: History of input images with increasing timestamps ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3] ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9] joints: np.ndarray, shape [..., num_past_scalars, <= 7] dataset_name: 1D np.ndarray inference_mode: If True, prepare the input for inference (e.g. don't include target any tokens in the input if relevant). If control_target is available, it should still be preprocessed for test dataset comparison control_target: RoboticsTarget, each component of shape [..., num_control_steps, num_control_components]. 
Provided only when available, usually during training and dataset test Returns: Dict containing torch.Tensor with inputs """ del control_target del inference_mode inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images) images: Dict[str, torch.Tensor] = inputs['images'] input_ids: torch.Tensor = inputs['input_ids'][..., : self.tokenizer.model_max_length] target_text_tokens_ids: torch.Tensor = inputs['target_ids'][..., : self.tokenizer.model_max_length] attn_mask = torch.ones(input_ids.shape, dtype=torch.bool) ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32) ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32) ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True) gripper = preprocess_gripper_observation(gripper, dataset_name) gripper = torch.tensor(gripper, dtype=torch.float32) ee_pose_translation = self.normalize( ee_pose_translation, dataset_name=dataset_name, key='obs_translation' ) ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key='obs_rotation') joints = torch.tensor(joints, dtype=torch.float32) if joints.shape[-1] < 7: missing_size = 7 - joints.shape[-1] joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1) joints = self.normalize(joints, dataset_name=dataset_name, key='joints') outputs = { 'images': images, 'input_ids': input_ids, 'target_text_tokens_ids': target_text_tokens_ids, 'attn_mask': attn_mask, 'ee_pose_translation': ee_pose_translation, 'ee_pose_rotation': ee_pose_rotation, 'gripper': gripper, 'joints': joints, 'control_tokens_ids': None, 'target_control_tokens_ids': None, } return outputs def create_input( self, chat: List[str], images: Dict[str, List[PIL.Image.Image]], ee_pose_translation: np.ndarray, ee_pose_rotation: np.ndarray, gripper: np.ndarray, joints: np.ndarray, dataset_name: np.ndarray, inference_mode: bool, control_target: Optional[RoboticsTarget] = None, ) -> 
RoboticsInput: inputs = self.preprocess_inputs( chat=chat, images=images, ee_pose_translation=ee_pose_translation, ee_pose_rotation=ee_pose_rotation, gripper=gripper, joints=joints, dataset_name=dataset_name, inference_mode=inference_mode, control_target=control_target, ) inputs.pop('target_text_tokens_ids') inputs.pop('target_control_tokens_ids') return RoboticsInput(**inputs) def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor: if is_mean_norm(getattr(self.config, f'{key}_norm')): (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) output = normalize_by_moments(value, mean=mean, std=std) else: (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) output = normalize_by_bounds(value, low=low, high=high) return output def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor: if is_mean_norm(getattr(self.config, f'{key}_norm')): (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) output = unnormalize_by_moments(value, mean=mean, std=std) else: (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) output = unnormalize_by_bounds(value, low=low, high=high) return output def _norm_bounds_from_dataset_name( self, dataset_name: np.ndarray, component_key: str ) -> Tuple[torch.Tensor, torch.Tensor]: """ Create an array of normalization bounds corresponding to dataset names Args: dataset_name: Array of shape [B] of dataset names for which to fetch the low and high normalization bounds. Note the values can be repeating component_key: str. One of 'action', 'translation', 'rotation'. 
Indicates for which control to compute the normalization bounds Returns: Tuple of low and high bounds or norm and std, each of shape [B, -1] """ norm = getattr(self.config, f'{component_key}_norm') if is_mean_norm(norm): (stats_key_1, stats_key_2) = ('mean', 'std') else: (stats_key_1, stats_key_2) = ('low', 'high') if component_key == 'joints': if not isinstance(norm, collections.abc.Mapping): raise NotImplementedError() stats = { key: torch.from_numpy(np.tile(np.reshape(value, [1, -1]), [len(dataset_name), 1])) for (key, value) in self.joints_norm_bounds['ANY'].items() } return tuple(stats.values()) component_size = list(list(self.norm_bounds[component_key].values())[0].values())[0].shape[-1] if self.dataset_names == ['ANY']: stats_1 = self.norm_bounds[component_key]['ANY'][stats_key_1] stats_2 = self.norm_bounds[component_key]['ANY'][stats_key_2] stats_1 = np.repeat(np.expand_dims(stats_1, axis=0), len(dataset_name), axis=0) stats_2 = np.repeat(np.expand_dims(stats_2, axis=0), len(dataset_name), axis=0) else: (unique_names, _, inverse_indices, _) = np_unique(dataset_name) stats_1 = np.zeros([len(unique_names), component_size], dtype=np.float32) stats_2 = np.zeros([len(unique_names), component_size], dtype=np.float32) for i, ds_name in enumerate(unique_names): stats_1[i] = self.norm_bounds[component_key][ds_name][stats_key_1] stats_2[i] = self.norm_bounds[component_key][ds_name][stats_key_2] stats_1 = stats_1[inverse_indices] stats_2 = stats_2[inverse_indices] return torch.from_numpy(stats_1), torch.from_numpy(stats_2) @cached_property def obs_rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: return rotation_norm_bounds( rotation_norm=self.config.obs_rotation_norm, rotation_format=self.config.rotation_format, stats=self._observation_stats, dataset_names=self.dataset_names, ) @cached_property def obs_translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: return translation_norm_bounds( translation_norm=self.config.obs_translation_norm, 
stats=self._observation_stats, dataset_names=self.dataset_names, ) @cached_property def rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: return rotation_norm_bounds( rotation_norm=self.config.rotation_norm, rotation_format=self.config.rotation_format, stats=self._control_stats, dataset_names=self.dataset_names, ) @cached_property def translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: return translation_norm_bounds( translation_norm=self.config.translation_norm, stats=self._control_stats, dataset_names=self.dataset_names, ) @cached_property def joints_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: """ NOTE: - Joint values across all joints and all datasets vary in the range [-2pi; 2pi] - The effective range of a single joint is in practice one of [-2pi; 0], [-pi; pi], [0; 2pi] - It's possible to shift all ranges to [-pi; pi], but it requires careful handling for each joint """ low = torch.tensor(self.config.joints_norm['low'], dtype=torch.float32) high = torch.tensor(self.config.joints_norm['high'], dtype=torch.float32) results = {'ANY': {'low': low, 'high': high}} return results @cached_property def _observation_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]: return { 'austin_buds_dataset': { 'euler': { 'max': [3.141589641571045, 0.2279592752456665, 0.10982763767242432], 'mean': [-0.433798223733902, 0.02660757675766945, 0.01057341042906046], 'min': [-3.1415905952453613, -0.17229366302490234, -0.08408939838409424], 'q01': [-3.1401021480560303, -0.15158048272132874, -0.07398375868797302], 'q99': [3.1401915550231934, 0.1699378788471222, 0.08044551312923431], 'std': [3.071873664855957, 0.08403228223323822, 0.04623554274439812], }, 'gripper': { 'max': [0.07999841868877411], 'min': [0.00019240332767367363], 'q01': [0.00760263716802001], 'q99': [0.07997412234544754], }, 'joints': { 'max': [ 0.25105541944503784, 1.0239691734313965, 0.25514841079711914, 0.0, 0.05838121101260185, 3.0727620124816895, 
1.1911247968673706, ], 'mean': [ -0.02248559147119522, 0.4224241375923157, 0.011533008888363838, -2.040178060531616, -0.004422259051352739, 2.440391778945923, 0.7626844644546509, ], 'min': [ -0.46276336908340454, -0.2620261609554291, -0.37377235293388367, -3.010026693344116, -0.015008972026407719, 0.0, 0.0, ], 'q01': [ -0.3563080132007599, -0.19165010750293732, -0.29941368103027344, -2.766944408416748, -0.012211786583065987, 1.9946393966674805, 0.21924567222595215, ], 'q99': [ 0.22295619547367096, 0.8841447830200195, 0.20571835339069366, -1.339158296585083, 0.05605502426624298, 2.9369189739227295, 1.1475461721420288, ], 'std': [ 0.1684190183877945, 0.2238602489233017, 0.1185624971985817, 0.3182348608970642, 0.012190691195428371, 0.2673698663711548, 0.2611873149871826, ], }, 'translation': { 'max': [0.7684353590011597, 0.22995209693908691, 0.3893454968929291], 'mean': [0.5589713454246521, -0.0022956086322665215, 0.12593945860862732], 'min': [0.0, -0.31078848242759705, 0.0], 'q01': [0.3499317765235901, -0.2854413390159607, 0.010516085661947727], 'q99': [0.7243335843086243, 0.20652863383293152, 0.3218296766281128], 'std': [0.08953548222780228, 0.1462203562259674, 0.0769098550081253], }, }, 'austin_sailor_dataset': { 'euler': { 'max': [3.1415910720825195, 0.26857471466064453, 2.4386940002441406], 'mean': [0.040416110306978226, 0.013173789717257023, -0.044821564108133316], 'min': [-3.141592264175415, -0.21639788150787354, -1.995558738708496], 'q01': [-3.1401662826538086, -0.16171419620513916, -1.5918874740600586], 'q99': [3.1401755809783936, 0.17527776956558228, 1.3560799360275269], 'std': [3.0995969772338867, 0.08136381208896637, 0.6311237812042236], }, 'gripper': { 'max': [0.07783995568752289], 'min': [-8.864999836077914e-05], 'q01': [0.0005174533580429852], 'q99': [0.07778151333332062], }, 'joints': { 'max': [ 0.18366889655590057, 1.188208818435669, 1.1904715299606323, -1.0253318548202515, 0.06765712797641754, 2.98008131980896, 2.7747786045074463, ], 'mean': [ 
-0.4481923580169678, 0.428782194852829, 0.29826778173446655, -2.1356770992279053, -0.2064618021249771, 2.5082974433898926, 0.8104547262191772, ], 'min': [ -1.4425580501556396, -0.34988880157470703, -0.4436887800693512, -2.9587178230285645, -0.9116092324256897, 1.7689018249511719, -1.4776484966278076, ], 'q01': [ -1.0579811334609985, -0.13357718288898468, -0.2518821358680725, -2.6593017578125, -0.6571194529533386, 2.072176456451416, -0.5861577391624451, ], 'q99': [ 0.08897759765386581, 0.8743473887443542, 0.8909343481063843, -1.4748132228851318, -0.010495365597307682, 2.859259605407715, 2.319192409515381, ], 'std': [ 0.26852160692214966, 0.22502659261226654, 0.2577133774757385, 0.2963367700576782, 0.1629231870174408, 0.18130357563495636, 0.6348837614059448, ], }, 'translation': { 'max': [0.7488242387771606, 0.2829112708568573, 0.3541720509529114], 'mean': [0.5367430448532104, -0.08604215085506439, 0.11135970056056976], 'min': [0.3208981454372406, -0.3730051815509796, 0.020952222868800163], 'q01': [0.387094110250473, -0.3164229393005371, 0.024492919445037842], 'q99': [0.6869593262672424, 0.2086469978094101, 0.2551962733268738], 'std': [0.08000300824642181, 0.13984571397304535, 0.056377921253442764], }, }, 'austin_sirius_dataset': { 'euler': { 'max': [3.141592502593994, 0.09964871406555176, 0.3406205177307129], 'mean': [1.1875473260879517, -0.018573710694909096, -0.4977009892463684], 'min': [-3.141592502593994, -0.14349305629730225, -1.892538070678711], 'q01': [-3.140658378601074, -0.12321051210165024, -1.743293285369873], 'q99': [3.1407761573791504, 0.07348346710205078, 0.062370821833610535], 'std': [2.8533384799957275, 0.03962606564164162, 0.5864803791046143], }, 'gripper': { 'max': [0.07965432107448578], 'min': [0.00016088332631625235], 'q01': [0.0334978811442852], 'q99': [0.07948359102010727], }, 'joints': { 'max': [ 0.5567107200622559, 0.3814372420310974, 0.36868324875831604, 0.0, 0.021617505699396133, 2.825117588043213, 2.8925509452819824, ], 'mean': [ 
0.22241303324699402, 0.03750193864107132, -0.013105375692248344, -2.428316593170166, -0.017613587900996208, 2.4755642414093018, 1.4929485321044922, ], 'min': [ -0.23144784569740295, -0.41377919912338257, -0.3536752760410309, -2.9289300441741943, -0.06330326944589615, 0.0, 0.0, ], 'q01': [ -0.10539760440587997, -0.2651134133338928, -0.21157152950763702, -2.831827402114868, -0.04555036500096321, 0.0, 0.0, ], 'q99': [ 0.48007193207740784, 0.3064860999584198, 0.2438022643327713, 0.0, 0.012844694778323174, 2.71537446975708, 2.51297926902771, ], 'std': [ 0.1297803372144699, 0.132253035902977, 0.08940940350294113, 0.345712274312973, 0.015448619611561298, 0.3493262231349945, 0.632145345211029, ], }, 'translation': { 'max': [0.5655431151390076, 0.3374997079372406, 0.3071175515651703], 'mean': [0.4490734338760376, 0.09406533092260361, 0.17131780087947845], 'min': [0.0, -0.16291481256484985, 0.0], 'q01': [0.0, -0.11814527958631516, 0.0], 'q99': [0.532875120639801, 0.26084619760513306, 0.27225059270858765], 'std': [0.07605387270450592, 0.08447220921516418, 0.04798709973692894], }, }, 'bc_z': { 'euler': { 'max': [3.141583088638963, 1.5360414168274534, 3.141576023430307], 'mean': [-0.32332700342924814, -0.1255861641435509, -0.07202428898041417], 'min': [-3.13973587780783, -1.5322501214383415, -3.1415575429114675], 'q01': [-1.0433973645798003, -1.007553367940694, -2.6477418236555788], 'q99': [0.8467956620806508, 0.9111507554055364, 2.226032625129908], 'std': [0.40307241985577485, 0.4832146677789649, 1.0666829542185445], }, 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.20000000298023224], 'q99': [1.0]}, 'joints': { 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, 'translation': { 'max': [0.6597589254379272, 0.7259413599967957, 1.1217665672302246], 'mean': 
[0.012632723897695541, 0.10061875730752945, 0.7831991910934448], 'min': [-0.7181899547576904, -0.3756217360496521, -0.281008243560791], 'q01': [-0.3956047296524048, -0.11924505233764648, 0.601338267326355], 'q99': [0.332028865814209, 0.3088575601577759, 0.98329097032547], 'std': [0.1808762550354004, 0.0960959941148758, 0.08665762096643448], }, }, 'berkeley_autolab_ur5': { 'euler': { 'max': [3.141586857385381, 0.9715026080766722, 1.0041252839605828], 'mean': [0.10885329009674993, -0.005823583245937125, 0.01756067034138813], 'min': [-3.14158698706268, -0.9912158529884818, -0.8782840134752882], 'q01': [-3.1391613317061315, -0.6410961610461836, -0.4513447799408977], 'q99': [3.13947692478744, 0.6738775260636495, 0.535202669480506], 'std': [3.069331637100682, 0.18171227413203145, 0.17512358924250696], }, 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, 'joints': { 'max': [ -2.5115966796875, -0.682390034198761, 2.6830153465270996, -1.1711143255233765, -0.6423047184944153, 4.032349109649658, ], 'mean': [ -3.2676453590393066, -1.3197712898254395, 2.122663736343384, -2.3710596561431885, -1.567111611366272, 3.0015013217926025, ], 'min': [ -4.0775299072265625, -2.2796404361724854, 1.1905627250671387, -3.814586877822876, -2.1223204135894775, 1.5279699563980103, ], 'q01': [ -3.839585781097412, -1.8913453817367554, 1.5813319683074951, -3.3646113872528076, -1.805822730064392, 2.22794246673584, ], 'q99': [ -2.7322237491607666, -0.8280109763145447, 2.5063326358795166, -1.5614854097366333, -1.2497813701629639, 3.7013423442840576, ], 'std': [ 0.2746899724006653, 0.27345725893974304, 0.22793518006801605, 0.34690913558006287, 0.10438559204339981, 0.3388271629810333, ], }, 'translation': { 'max': [0.6610942482948303, 0.38513994216918945, 0.2049914449453354], 'mean': [0.45345133543014526, 0.05642913281917572, -0.0356401689350605], 'min': [0.1970997005701065, -0.27643972635269165, -0.20529526472091675], 'q01': [0.3020566999912262, -0.21297279000282288, 
-0.18836002051830292], 'q99': [0.6132073998451233, 0.30656182765960693, 0.12212439626455307], 'std': [0.07906801998615265, 0.12497302889823914, 0.08053416013717651], }, }, 'berkeley_cable_routing': { 'euler': { 'max': [3.141592264175415, 0.05923163890838623, 3.1243016719818115], 'mean': [-0.6730493903160095, 0.010155542753636837, 0.9881726503372192], 'min': [-3.141592264175415, -0.10035276412963867, -1.6915032863616943], 'q01': [-3.1412672996520996, -0.02922845631837845, -0.6449583172798157], 'q99': [3.1412830352783203, 0.039870113134384155, 2.4969191551208496], 'std': [3.0588438510894775, 0.01435503363609314, 0.6301024556159973], }, 'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]}, 'joints': { 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, 'translation': { 'max': [0.7113381028175354, 0.35645395517349243, 0.23169316351413727], 'mean': [0.5643643736839294, 0.00464651919901371, 0.07635799050331116], 'min': [0.378939151763916, -0.3244381248950958, 0.02837979607284069], 'q01': [0.4641263782978058, -0.2806571424007416, 0.030183622613549232], 'q99': [0.6452807784080505, 0.28204888105392456, 0.1557157188653946], 'std': [0.0380910187959671, 0.16768066585063934, 0.032017387449741364], }, }, 'berkeley_fanuc_manipulation': { 'euler': { 'max': [3.1415822505950928, 1.020048975944519, 1.9211788177490234], 'mean': [0.1944933384656906, -0.03283948823809624, 0.017717985436320305], 'min': [-3.141590118408203, -1.5707060098648071, -3.089141368865967], 'q01': [-3.139866352081299, -1.0165742635726929, -1.6987266540527344], 'q99': [3.140317440032959, 0.4332200288772583, 1.5085773468017578], 'std': [2.9820094108581543, 0.21335174143314362, 0.4836970269680023], }, 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, 'joints': { 
'max': [ 0.9483258724212646, 1.4428143501281738, 1.2322325706481934, 2.3726272583007812, 0.0008019324741326272, 3.017007827758789, ], 'mean': [ 0.04932224750518799, 0.36293527483940125, -0.18438231945037842, 0.09242437779903412, -1.0793049335479736, -0.06553139537572861, ], 'min': [ -0.7464086413383484, -0.5774519443511963, -1.1099220514297485, -1.8152672052383423, -2.1454310417175293, -3.168759346008301, ], 'q01': [ -0.5061959028244019, -0.09613651037216187, -0.8047154545783997, -0.9908221364021301, -1.8158636093139648, -2.420135974884033, ], 'q99': [ 0.6388322710990906, 1.0234705209732056, 0.5990921258926392, 1.282780647277832, -0.39091193675994873, 2.340099811553955, ], 'std': [ 0.2452843338251114, 0.2555835247039795, 0.2740932106971741, 0.3769519329071045, 0.3472467362880707, 0.6859157681465149, ], }, 'translation': { 'max': [0.8101964592933655, 0.5117550492286682, 0.8054816722869873], 'mean': [0.5446329116821289, 0.003737417282536626, 0.2439887523651123], 'min': [0.263683944940567, -0.5785022377967834, -0.0025757886469364166], 'q01': [0.3718133866786957, -0.4071895182132721, 0.01847645826637745], 'q99': [0.7200658321380615, 0.3128541111946106, 0.5413243770599365], 'std': [0.07649082690477371, 0.14486227929592133, 0.13899113237857819], }, }, 'bridge': { 'euler': { 'max': [3.141592653589793, 1.570796251296997, 3.141204357147217], 'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691], 'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938], 'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896], 'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236], 'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], }, 'gripper': { 'max': [1.0370277166366577], 'min': [0.04637829214334488], 'q01': [0.05192930996417999], 'q99': [1.0118417739868164], }, 'joints': { 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q01': [0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, 'translation': { 'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], 'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424], 'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723], 'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884], 'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424], 'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685], }, }, 'bridge_orig': { 'euler': { 'max': [3.141592653589793, 1.570796251296997, 3.141204357147217], 'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691], 'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938], 'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896], 'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236], 'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], }, 'gripper': { 'max': [1.0370277166366577], 'min': [0.04637829214334488], 'q01': [0.05192930996417999], 'q99': [1.0118417739868164], }, 'joints': { 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, 'translation': { 'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], 'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424], 'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723], 'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884], 'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424], 'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685], }, }, 'bridge_action_lang': { 'euler': { 'max': [3.141592653589793, 1.570796251296997, 3.141204357147217], 
'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691], 'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938], 'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896], 'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236], 'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], }, 'gripper': { 'max': [1.0370277166366577], 'min': [0.04637829214334488], 'q01': [0.05192930996417999], 'q99': [1.0118417739868164], }, 'joints': { 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, 'translation': { 'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], 'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424], 'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723], 'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884], 'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424], 'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685], }, }, 'cmu_stretch': { 'euler': { 'max': [3.141592653589793, 0.0, 0.0], 'mean': [3.1415926535896035, 0.0, 0.0], 'min': [3.141592653589793, 0.0, 0.0], 'q01': [3.141592653589793, 0.0, 0.0], 'q99': [3.141592653589793, 0.0, 0.0], 'std': [1.8962609260597674e-13, 0.0, 0.0], }, 'gripper': { 'max': [3.110913038253784], 'min': [-3.1155149936676025], 'q01': [-3.055689811706543], 'q99': [3.1017091274261475], }, 'joints': { 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, 'translation': { 'max': [0.3401322066783905, 0.0, 1.1013386249542236], 'mean': [0.1461893916130066, 
0.0, 0.7568270564079285], 'min': [0.005320895928889513, 0.0, 0.45218077301979065], 'q01': [0.017430847510695457, 0.0, 0.46050605177879333], 'q99': [0.33094948530197144, 0.0, 1.0952961444854736], 'std': [0.09959954023361206, 0.0, 0.15819725394248962], }, }, 'dlr_edan_shared_control': { 'euler': { 'max': [2.2228052616119385, 0.13346634805202484, 3.1415085792541504], 'mean': [1.6072595119476318, -0.4295055568218231, 1.1549623012542725], 'min': [0.4599944055080414, -1.4871342182159424, -3.1406784057617188], 'q01': [0.8485284447669983, -1.4454149007797241, -3.1238820552825928], 'q99': [1.8301045894622803, 0.07032366842031479, 3.131430149078369], 'std': [0.2785564363002777, 0.540916919708252, 2.4260833263397217], }, 'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]}, 'joints': { 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, 'translation': { 'max': [0.044179707765579224, 0.6297282576560974, 0.9278888702392578], 'mean': [-0.512688934803009, 0.3690241575241089, 0.441307932138443], 'min': [-0.7695853114128113, -0.040099501609802246, 0.2583216428756714], 'q01': [-0.729511022567749, 0.077408567070961, 0.2658006250858307], 'q99': [-0.13719859719276428, 0.5719971060752869, 0.7898909449577332], 'std': [0.15213875472545624, 0.08891375362873077, 0.13743770122528076], }, }, 'droid': { 'euler': { 'max': [3.141592502593994, 1.5705928802490234, 3.1415867805480957], 'mean': [0.3140628098409554, -0.09296274023036387, -0.07227215454779846], 'min': [-3.141592502593994, -1.5691150426864624, -3.1415374279022217], 'q01': [-3.1378602981567383, -1.2125312042236327, -2.1614069032669065], 'q99': [3.137854380607605, 0.9200375998020163, 1.9367506909370364], 'std': [2.926265757944871, 0.363273475703332, 0.7576065217938824], }, 'gripper': {'max': [1.0], 'min': 
[0.0], 'q01': [0.0], 'q99': [0.9911894202232361]}, 'joints': { 'max': [ 2.668445110321045, 1.5691218376159668, 2.666306734085083, -0.3114914000034332, 2.6624162197113037, 4.28157901763916, 2.752457857131958, ], 'mean': [ 0.023137084334640106, 0.2704989977282293, -0.01451389357228282, -2.018709403792315, -0.042720520800030394, 2.350281188152209, 0.12424663946659845, ], 'min': [ -2.6536705493927, -1.547789216041565, -2.6781487464904785, -2.9409868717193604, -2.6705946922302246, 0.24893812835216522, -2.7615714073181152, ], 'q01': [ -0.9026106441020965, -0.8547340619564057, -0.9028875434398651, -2.7698556280136106, -1.6851656341552732, 1.2335169839859008, -1.9587260699272155, ], 'q99': [ 0.9569852340221403, 1.4148830294609054, 0.7693877756595566, -0.4545914208889008, 1.5623322343826267, 3.475611729621887, 2.263479118347167, ], 'std': [ 0.31695080251469465, 0.49522214687158767, 0.27993538230553827, 0.478161574676113, 0.4969961591445458, 0.45101008525403846, 0.7287264344068457, ], }, 'translation': { 'max': [0.8575563430786133, 0.799155592918396, 1.0043904781341553], 'mean': [0.5283099395864883, 0.005363794653877434, 0.3120132207021294], 'min': [-0.15604186058044434, -0.827903687953949, -0.2347021996974945], 'q01': [0.26669957995414734, -0.43774398624897004, -0.048167889714241026], 'q99': [0.7774086785316463, 0.428325751423835, 0.776091011762619], 'std': [0.1148424841779685, 0.17489566608140428, 0.16541062032731538], }, }, 'fmb': { 'euler': { 'max': [3.1415927410125732, 1.044223666191101, 2.6313371658325195], 'mean': [-0.5128238201141357, -0.0326850451529026, 1.2892225980758667], 'min': [-3.141592502593994, -1.0594873428344727, -1.5775980949401855], 'q01': [-3.1404616832733154, -0.9319325089454651, -1.4586873054504395], 'q99': [3.140399932861328, 0.8346238732337952, 2.3081648349761963], 'std': [3.0557258129119873, 0.33158522844314575, 0.6522650122642517], }, 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, 'joints': { 'max': [ 1.633209228515625, 
1.216904640197754, 0.6477310061454773, -1.101138949394226, 1.4866814613342285, 3.2880685329437256, 2.8213212490081787, ], 'mean': [ 0.08979735523462296, 0.24736542999744415, -0.18356309831142426, -2.1948094367980957, 0.09010367840528488, 2.3375964164733887, -0.6424615979194641, ], 'min': [ -0.24819369614124298, -0.6836066246032715, -1.852062463760376, -2.950329542160034, -1.360267162322998, 1.1119775772094727, -2.53395938873291, ], 'q01': [ -0.12363661825656891, -0.2893752157688141, -0.8288514018058777, -2.7587249279022217, -0.9787064790725708, 1.5508010387420654, -1.7525910139083862, ], 'q99': [ 0.527353048324585, 0.8305937647819519, 0.3808687925338745, -1.5172499418258667, 1.1656274795532227, 2.9421844482421875, 2.3595635890960693, ], 'std': [ 0.14010190963745117, 0.24683699011802673, 0.29916438460350037, 0.29822880029678345, 0.40809085965156555, 0.3218824863433838, 0.7698287963867188, ], }, 'translation': { 'max': [0.7458599805831909, 0.2993985116481781, 0.39971601963043213], 'mean': [0.5149032473564148, -0.04377440735697746, 0.17520515620708466], 'min': [0.32314637303352356, -0.34641459584236145, 0.029209019616246223], 'q01': [0.3655048608779907, -0.28729698061943054, 0.033201027661561966], 'q99': [0.6782684326171875, 0.209969624876976, 0.3331448435783386], 'std': [0.07067117840051651, 0.13670970499515533, 0.07637079805135727], }, }, 'fractal20220817_data': { 'euler': { 'max': [3.141592264175415, 1.5703786611557007, 3.1415889263153076], 'mean': [-1.5046496391296387, 0.6226330399513245, -1.621827244758606], 'min': [-3.1415927410125732, -1.569378137588501, -3.1415863037109375], 'q01': [-3.130796194076538, -0.23499104380607605, -2.9648191928863525], 'q99': [3.1300582885742188, 1.4907126426696777, 2.8373801708221436], 'std': [2.380099296569824, 0.41548046469688416, 0.803632915019989], }, 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, 'joints': { 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'min': 
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, 'translation': { 'max': [1.0534898042678833, 0.48018959164619446, 1.6896663904190063], 'mean': [0.5594336986541748, -0.08356647938489914, 0.7760705947875977], 'min': [-0.3139062225818634, -0.9970501065254211, -0.006579156965017319], 'q01': [0.3249714970588684, -0.2818704843521118, 0.1410011649131775], 'q99': [0.8754204511642456, 0.21279653906822205, 1.071526288986206], 'std': [0.12431981414556503, 0.11550336331129074, 0.24583852291107178], }, }, 'furniture_bench_dataset': { 'euler': { 'max': [3.141592502593994, 1.545873999595642, 3.1415679454803467], 'mean': [-1.5673161745071411, 0.028489787131547928, -0.0338752306997776], 'min': [-3.1415927410125732, -1.1365290880203247, -3.141582727432251], 'q01': [-3.139331817626953, -0.6121169328689575, -1.9944958686828613], 'q99': [3.1392478942871094, 0.9993562698364258, 1.7793478965759277], 'std': [2.605621337890625, 0.30199435353279114, 0.924261212348938], }, 'gripper': { 'max': [0.07995442301034927], 'min': [-3.4803331800503656e-05], 'q01': [0.003531553316861391], 'q99': [0.07973115146160126], }, 'joints': { 'max': [ 1.6170321702957153, 1.4905058145523071, 0.8054860234260559, -0.8407800197601318, 2.596961736679077, 3.591698169708252, 2.744664192199707, ], 'mean': [ 0.14828824996948242, 0.513389527797699, -0.07023008167743683, -2.116460084915161, 0.17382267117500305, 2.6314659118652344, 0.7604196667671204, ], 'min': [ -1.0936335325241089, -0.28272193670272827, -1.3713332414627075, -2.9083454608917236, -2.664609432220459, 0.49353715777397156, -2.6659374237060547, ], 'q01': [ -0.3774302303791046, 0.16254600882530212, -0.526292085647583, -2.662245512008667, -0.3341670036315918, 1.4710004329681396, -0.997846245765686, ], 'q99': [ 0.6179460287094116, 0.9334877729415894, 0.3241332471370697, -1.5048880577087402, 0.69161057472229, 3.45759916305542, 
2.5480763912200928, ], 'std': [ 0.20775794982910156, 0.14619596302509308, 0.1992885023355484, 0.24350883066654205, 0.2203112691640854, 0.3669370114803314, 0.8288815021514893, ], }, 'translation': { 'max': [0.75081467628479, 0.29695066809654236, 0.35806331038475037], 'mean': [0.5622280836105347, 0.049872349947690964, 0.08108603954315186], 'min': [0.2878606617450714, -0.3141690492630005, -0.00460465531796217], 'q01': [0.36915361881256104, -0.180975541472435, 0.0058300793170928955], 'q99': [0.6652880311012268, 0.1772783100605011, 0.18316447734832764], 'std': [0.07692465931177139, 0.07895787060260773, 0.04069080576300621], }, }, 'iamlab_cmu_pickup_insert': { 'euler': { 'max': [3.14159250270088, 0.04584214946783205, 0.0450923308550184], 'mean': [-1.4122567194164972, -0.03209339030753262, 0.006593786875632054], 'min': [-3.141592327000409, -0.13647014666477264, -0.052753928700043584], 'q01': [-3.141100653025272, -0.1142266801993486, -0.026375220367685838], 'q99': [3.1411030904244917, 0.02986631061111326, 0.033458487885352856], 'std': [2.771965176997775, 0.02781015988983883, 0.015259428603879655], }, 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, 'joints': { 'max': [ 0.24314826726913452, 0.7470589280128479, 0.7795426249504089, -1.6825058460235596, 0.13216856122016907, 2.88216233253479, 1.443279504776001, ], 'mean': [ -0.09283880889415741, 0.16589058935642242, 0.13815303146839142, -2.236156940460205, -0.013342339545488358, 2.4330623149871826, 0.8266588449478149, ], 'min': [ -0.6439031958580017, -0.7895585894584656, -0.017819713801145554, -2.9597983360290527, -0.5142408013343811, 1.82623291015625, 0.24880434572696686, ], 'q01': [ -0.4449005424976349, -0.687881588935852, -0.00037891813553869724, -2.9046988487243652, -0.3344403803348541, 1.9697062969207764, 0.4578770101070404, ], 'q99': [ 0.19939860701560974, 0.6767980456352234, 0.5450605750083923, -1.7509592771530151, 0.07327680289745331, 2.758103370666504, 1.2580491304397583, ], 'std': [ 
0.1536070704460144, 0.3142663240432739, 0.1251399964094162, 0.28150704503059387, 0.062141284346580505, 0.17315669357776642, 0.19375956058502197, ], }, 'translation': { 'max': [0.6634980998075433, 0.23428471360823594, 0.4308285234773945], 'mean': [0.5266367430643635, 0.02849007901898653, 0.18708021993948834], 'min': [0.3071657210198696, -0.2975911497764855, 0.06578226212031718], 'q01': [0.314498567139741, -0.20315787858517198, 0.06785127291435837], 'q99': [0.6472027612545221, 0.20840713845170097, 0.37003410655170416], 'std': [0.08150424006275543, 0.11136019021346877, 0.0777212838203498], }, }, 'jaco_play': { 'euler': { 'max': [3.141592653589793, 0.0, 0.0], 'mean': [3.14159265358649, 0.0, 0.0], 'min': [3.141592653589793, 0.0, 0.0], 'q01': [3.141592653589793, 0.0, 0.0], 'q99': [3.141592653589793, 0.0, 0.0], 'std': [3.3031355428647657e-12, 0.0, 0.0], }, 'gripper': { 'max': [1.3890767097473145], 'min': [-0.00123199715744704], 'q01': [0.00061599857872352], 'q99': [1.3835327625274658], }, 'joints': { 'max': [ -0.6973918676376343, 4.516506671905518, 2.732408285140991, 0.27657514810562134, 2.0405426025390625, 0.9062198996543884, ], 'mean': [ -1.5362969636917114, 3.9368598461151123, 1.470100998878479, -0.26214125752449036, 0.8781110644340515, -0.4923328757286072, ], 'min': [ -2.394073486328125, 3.4039254188537598, 0.5968497395515442, -1.0070130825042725, -0.28005343675613403, -1.3207207918167114, ], 'q01': [ -2.2932708263397217, 3.5327682495117188, 0.7838058471679688, -0.5554518699645996, 0.002689492190256715, -1.1670725345611572, ], 'q99': [ -0.8449038863182068, 4.363891124725342, 2.367223024368286, 0.06634082645177841, 1.595384120941162, 0.17280587553977966, ], 'std': [ 0.32358425855636597, 0.18587827682495117, 0.3632128834724426, 0.11627942323684692, 0.3359401226043701, 0.3179064989089966, ], }, 'translation': { 'max': [0.2252020239830017, -0.19358234107494354, 0.4066188633441925], 'mean': [-0.07287009060382843, -0.42861491441726685, 0.2764993906021118], 'min': 
[-0.4429473876953125, -0.6635459661483765, 0.1568669229745865], 'q01': [-0.3789186179637909, -0.6194459795951843, 0.16865813732147217], 'q99': [0.21203258633613586, -0.26914602518081665, 0.38958534598350525], 'std': [0.138992577791214, 0.0908467248082161, 0.052496980875730515], }, }, 'kuka': { 'euler': { 'max': [3.141592651594966, 0.25521246168753553, 3.1415812671947134], 'mean': [-0.8123818266038756, -0.004947911572775501, 0.1622765703197679], 'min': [-3.141592653256577, -0.22240290857095402, -3.1415339219506544], 'q01': [-3.141539356675989, -0.042093398087956785, -0.6103775416503281], 'q99': [3.1415363480584935, 0.013040864331848548, 1.4661773197655914], 'std': [3.031619905792479, 0.010541045603285457, 0.386240278316171], }, 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, 'joints': { 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, 'translation': { 'max': [0.7243871092796326, 0.31309840083122253, 0.613274335861206], 'mean': [0.5510859489440918, 0.04787421599030495, 0.12812286615371704], 'min': [0.40573424100875854, -0.2028520256280899, 0.018512273207306862], 'q01': [0.4765772819519043, -0.14815208315849304, 0.06674224138259888], 'q99': [0.6515637040138245, 0.2447487711906433, 0.28018367290496826], 'std': [0.045206550508737564, 0.10222796350717545, 0.058390580117702484], }, }, 'language_table': { 'euler': { 'max': [3.141592653589793, 0.0, 0.0], 'mean': [3.141592653805758, 0.0, 0.0], 'min': [3.141592653589793, 0.0, 0.0], 'q01': [3.141592653589793, 0.0, 0.0], 'q99': [3.141592653589793, 0.0, 0.0], 'std': [2.1596502364831912e-10, 0.0, 0.0], }, 'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]}, 'joints': { 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'min': [0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0], 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, 'translation': { 'max': [0.6191085577011108, 0.345907062292099, 0.0], 'mean': [0.3994608521461487, 0.0045331921428442, 0.0], 'min': [0.18907572329044342, -0.3051564395427704, 0.0], 'q01': [0.19237099587917328, -0.2962527573108673, 0.0], 'q99': [0.6171894669532776, 0.30645298957824707, 0.0], 'std': [0.10587888211011887, 0.14165113866329193, 0.0], }, }, 'nyu_franka_play_dataset': { 'euler': { 'max': [3.141542687188373, 1.570464569214435, 3.1412348640834002], 'mean': [1.2730823611033315, 0.6681926368695384, 1.141188695183881], 'min': [-3.141417902437913, -0.6159773949821361, -3.1401807914238264], 'q01': [-3.1053165771573177, -0.18849691525500664, -2.8903965648285523], 'q99': [3.1156105211359257, 1.5443502891290082, 2.882089687976354], 'std': [1.519365023922235, 0.5398198304496786, 1.099335476495627], }, 'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]}, 'joints': { 'max': [ 0.873094916343689, 1.244848370552063, 2.6604084968566895, -0.46481436491012573, 1.409853458404541, 3.5058159828186035, 2.2937142848968506, ], 'mean': [ -0.7650318741798401, -0.9744310975074768, 1.4849265813827515, -2.0917000770568848, -0.42907631397247314, 2.4454119205474854, -0.04509185999631882, ], 'min': [ -2.1353442668914795, -1.6306616067886353, -0.0167933851480484, -2.8458828926086426, -2.5256054401397705, 0.5101203322410583, -2.1178665161132812, ], 'q01': [ -1.9259787797927856, -1.537291407585144, 0.4664345383644104, -2.7214717864990234, -2.114818811416626, 1.0703176259994507, -1.8231183290481567, ], 'q99': [ 0.19280724227428436, 0.7752904295921326, 2.6163666248321533, -0.8622008562088013, 0.7001188397407532, 3.4607036113739014, 1.7281100749969482, ], 'std': [ 0.4459172785282135, 0.5925542712211609, 0.45783039927482605, 0.4001038372516632, 0.6094018816947937, 0.5889745950698853, 0.7601358294487, ], }, 'translation': 
{ 'max': [0.6258905727651867, 0.8289460146095029, 0.8867425967452327], 'mean': [0.38785427731316113, 0.39551342608233353, 0.45161485937851], 'min': [0.05690113252698781, -0.06171031940925212, 0.1157533700659999], 'q01': [0.1393695951374117, 0.07645522416447147, 0.19364508601310984], 'q99': [0.5920727118835839, 0.6584802280677562, 0.8056891771281188], 'std': [0.10666290139662009, 0.1217948547812865, 0.1707295181003299], }, }, 'roboset': { 'euler': { 'max': [3.1415449294818236, 1.5705575529715636, 3.141527342124582], 'mean': [-0.0398455755412464, 1.0518070390619125, -0.015345692503002759], 'min': [-3.1415813300509536, -1.5222832468962035, -3.141575300866071], 'q01': [-2.9414386317311187, -0.24976770655101155, -2.985256521212579], 'q99': [2.9380437893235993, 1.5403010739503078, 2.9746912523985025], 'std': [1.7866587696177456, 0.40620530263065, 1.7288511340250616], }, 'gripper': { 'max': [0.83056640625], 'min': [0.0001499652862548828], 'q01': [0.0001499652862548828], 'q99': [0.82666015625], }, 'joints': { 'max': [ 0.96240234375, 1.1162109375, 1.1064453125, -0.98095703125, 2.30859375, 1.576171875, 1.7412109375, ], 'mean': [ 0.005913593806326389, 0.1877261847257614, 0.04653879255056381, -2.0529513359069824, -0.011298442259430885, 0.6185526251792908, -0.01701134257018566, ], 'min': [ -0.8330078125, -0.74658203125, -0.8642578125, -2.892578125, -1.390625, -0.24658203125, -2.953125, ], 'q01': [ -0.41015625, -0.5302734375, -0.6455078125, -2.57421875, -0.76416015625, -0.0386962890625, -1.435546875, ], 'q99': [ 0.66455078125, 0.9501953125, 0.7529296875, -1.251953125, 0.75244140625, 1.2314453125, 1.384765625, ], 'std': [ 0.17915399372577667, 0.32234326004981995, 0.26069700717926025, 0.31767210364341736, 0.205329030752182, 0.33385637402534485, 0.6263682842254639, ], }, 'translation': { 'max': [0.5747738480567932, 0.3972920775413513, 0.7443570494651794], 'mean': [0.3331542909145355, 0.019357483834028244, 0.37330344319343567], 'min': [0.09978063404560089, -0.29593944549560547, 
0.10065606236457825], 'q01': [0.18437016010284424, -0.25699371099472046, 0.15134164690971375], 'q99': [0.543661892414093, 0.29646238684654236, 0.6682320833206177], 'std': [0.07849054038524628, 0.12241040915250778, 0.1460595279932022], }, }, 'roboturk': { 'euler': { 'max': [3.1415925701602854, 1.5661525056538665, 3.1412803735012553], 'mean': [0.3946147188969296, -0.09657834247810235, -0.1840393277985236], 'min': [-3.141592645423486, -1.568282806779776, -3.141449561492499], 'q01': [-3.1378020570699525, -0.7907563562327732, -1.7347170410441308], 'q99': [3.137841563084971, 0.5773179844053413, 1.455226678618167], 'std': [2.9453717386817755, 0.2466574231368262, 0.6146181899264712], }, 'gripper': { 'max': [0.041667], 'min': [-1.6931189781743683e-05], 'q01': [0.0], 'q99': [0.040625325000000004], }, 'joints': { 'max': [ 3.0592832565307617, 2.9869844913482666, 0.33399999141693115, 2.0810000896453857, 3.874000072479248, 0.21199999749660492, 32.02399826049805, ], 'mean': [ 0.06695421040058136, 0.4899406135082245, -0.0009486731723882258, -0.00033266679383814335, -0.000789206416811794, 0.017417440190911293, -4.154728412628174, ], 'min': [ -2.995668888092041, -2.9882216453552246, -1.3760000467300415, -1.9279999732971191, -4.10099983215332, -0.5040000081062317, -23.58799934387207, ], 'q01': [ -1.1723066568374634, -1.1159032583236694, -0.00800000037997961, -0.3580000102519989, -0.5680000185966492, -0.10000000149011612, -14.303999900817871, ], 'q99': [ 1.1788837909698486, 1.7529590129852295, 0.007000000216066837, 0.3799999952316284, 0.628000020980835, 0.13199999928474426, 7.5279998779296875, ], 'std': [ 0.46230366826057434, 0.5157809257507324, 0.0024148367810994387, 0.11895754188299179, 0.18901629745960236, 0.050090208649635315, 4.478224754333496, ], }, 'translation': { 'max': [1.0403562763478107, 0.8124313436278997, 0.9670890184966487], 'mean': [0.601758638436076, -0.0011977902645611068, 0.061136336184971524], 'min': [-0.375370271681523, -0.9602276639048171, -0.39440951528330676], 
'q01': [0.2845426588179575, -0.3288349528691964, -0.0934955147249334], 'q99': [0.8773894592247552, 0.28575230587640144, 0.3286392635131121], 'std': [0.12719404213832913, 0.1476650170975746, 0.09248325352628932], }, }, 'stanford_hydra_dataset': { 'euler': { 'max': [3.141592357591889, 1.5364991593031068, 3.141222461465462], 'mean': [-0.2789889021491306, -0.19247585545209975, 0.34538177027867784], 'min': [-3.1415919781981962, -1.5573291068141568, -3.1413972494043216], 'q01': [-3.138796116403405, -1.107955818954808, -2.271278880707407], 'q99': [3.1386728135975384, 0.915868569686168, 2.4011245578416336], 'std': [2.883189482433683, 0.43362190146795254, 1.199450318648319], }, 'gripper': { 'max': [0.0811397060751915], 'min': [-0.0005391233135014772], 'q01': [1.3133333141013281e-06], 'q99': [0.08100640028715134], }, 'joints': { 'max': [ 1.9801124334335327, 1.5752310752868652, 2.6661651134490967, -0.6274839639663696, 2.3508076667785645, 3.4612932205200195, 2.747223377227783, ], 'mean': [ -0.052988938987255096, -0.19803331792354584, -0.046565085649490356, -2.395282506942749, 0.4740530848503113, 2.119305372238159, 0.031005585566163063, ], 'min': [ -2.5843772888183594, -1.287919044494629, -1.9306836128234863, -2.9673542976379395, -1.9528191089630127, 0.5680276155471802, -2.74257755279541, ], 'q01': [ -1.0905369520187378, -0.9110571146011353, -0.9165626764297485, -2.876581907272339, -0.9029546976089478, 1.3720064163208008, -2.0172016620635986, ], 'q99': [ 0.6443323493003845, 1.0052685737609863, 0.9653106927871704, -1.4030946493148804, 1.738037109375, 2.943267822265625, 2.451101541519165, ], 'std': [ 0.3502423167228699, 0.41914135217666626, 0.39895766973495483, 0.31860536336898804, 0.5207170248031616, 0.35328012704849243, 1.2519474029541016, ], }, 'translation': { 'max': [0.7598075952006731, 0.36410959179102775, 0.597508369211053], 'mean': [0.4348564561896346, 0.024563536690147474, 0.2938664975066919], 'min': [0.18414022283711137, -0.32311685510034516, 0.017172937739480726], 
'q01': [0.2373728557200433, -0.26521679456931635, 0.09069013461938308], 'q99': [0.712423814640525, 0.25299056815993204, 0.49505407602925094], 'std': [0.10465658310507645, 0.12745896408828714, 0.10379447506049931], }, }, 'taco_play': { 'euler': { 'max': [3.1415891647338867, 0.30717557668685913, 2.4624788761138916], 'mean': [-0.06622616201639175, -0.17697572708129883, 0.4614097774028778], 'min': [-3.1415886878967285, -0.8683895468711853, -1.789320707321167], 'q01': [-3.139000654220581, -0.6962019205093384, -1.1729414463043213], 'q99': [3.1392500400543213, 0.12436603754758835, 1.8065968751907349], 'std': [3.0384926795959473, 0.18692463636398315, 0.6806972622871399], }, 'gripper': { 'max': [0.08074767142534256], 'min': [9.652999870013446e-05], 'q01': [0.00015234666352625936], 'q99': [0.08067675679922104], }, 'joints': { 'max': [ 0.9704633951187134, 0.4390600025653839, 0.9892600178718567, -1.116769790649414, 1.2800638675689697, 2.894169330596924, 2.627519130706787, ], 'mean': [ 0.14954668283462524, -0.3514331877231598, 0.1295831948518753, -2.205498456954956, 0.11473798751831055, 2.0389058589935303, 0.5516068339347839, ], 'min': [ -0.6495265960693359, -1.2877551317214966, -1.5125423669815063, -2.956650733947754, -1.0317779779434204, 1.1109415292739868, -1.8966686725616455, ], 'q01': [ -0.4596467614173889, -0.9323632121086121, -0.6153853535652161, -2.8417088985443115, -0.3472142219543457, 1.428541898727417, -0.5756277441978455, ], 'q99': [ 0.7262046933174133, 0.21787810325622559, 0.7387099266052246, -1.3517359495162964, 0.7581852674484253, 2.564530849456787, 1.6355704069137573, ], 'std': [ 0.281919926404953, 0.2702403962612152, 0.3334408402442932, 0.412654846906662, 0.21815766394138336, 0.2808808982372284, 0.4732993543148041, ], }, 'translation': { 'max': [0.7124971151351929, 0.6118948459625244, 0.617118775844574], 'mean': [0.37278515100479126, 0.13003520667552948, 0.38341864943504333], 'min': [0.07693906873464584, -0.4944135844707489, 0.20030911266803741], 'q01': 
[0.1368357390165329, -0.4297449290752411, 0.20516259968280792], 'q99': [0.6700438857078552, 0.5943909883499146, 0.5966404676437378], 'std': [0.11509375274181366, 0.2694733738899231, 0.11266136914491653], }, }, 'toto': { 'euler': { 'max': [3.1414193360696108, 1.570402878730552, 3.1415549880169924], 'mean': [-1.31124856158284, 0.8107744638901545, 1.042033586691101], 'min': [-3.1415814091974665, -1.3210638993411208, -3.141515215393105], 'q01': [-2.9459294143675696, -0.6449196412662773, -2.9790265961739872], 'q99': [2.824635320375007, 1.4897934376557762, 3.0083314049108925], 'std': [1.2205361480020216, 0.45724967725161403, 1.3644152538580079], }, 'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]}, 'joints': { 'max': [ 1.4926575422286987, 0.7715681195259094, 1.9506583213806152, -1.6330622434616089, 2.804656744003296, 2.0850095748901367, 1.9661481380462646, ], 'mean': [ -0.040506213903427124, -0.2537725269794464, 0.05389903113245964, -2.5135483741760254, -0.161987766623497, 0.8815387487411499, 0.0034975146409124136, ], 'min': [ -1.8345723152160645, -1.7133995294570923, -1.40083646774292, -2.985748767852783, -2.814244270324707, -0.5165338516235352, -3.5424435138702393, ], 'q01': [ -1.4044159650802612, -1.3978556394577026, -0.8029389381408691, -2.97342586517334, -1.5402026176452637, 0.0011001524981111288, -2.9709367752075195, ], 'q99': [ 0.7075999975204468, 0.42536962032318115, 1.3345959186553955, -1.8281668424606323, 2.5240790843963623, 2.0835044384002686, 1.212519884109497, ], 'std': [ 0.3405684530735016, 0.3766334354877472, 0.40128207206726074, 0.3202400505542755, 0.6421889066696167, 0.33138400316238403, 0.5801302194595337, ], }, 'translation': { 'max': [0.7777318005382999, 0.5688841397096768, 0.672235403102292], 'mean': [0.15078069344893, -0.03012185632654897, 0.35519823701585035], 'min': [-0.18532184496861265, -0.6282599632853518, 0.06312318086547261], 'q01': [-0.09177927811484686, -0.3571658956454935, 0.21965464356283362], 'q99': 
[0.6757593680120374, 0.2889021640318381, 0.501109413161318], 'std': [0.14726563054735392, 0.13472763062259546, 0.059123705378892395], }, }, 'ucsd_kitchen_dataset': { 'euler': { 'max': [3.141592644655991, 0.10471752134072587, 2.0420500160887434], 'mean': [-2.349219123990161, -0.5844573008678524, 0.8632842703198627], 'min': [-3.141592637838787, -1.5533486027935484, -1.676143206926509], 'q01': [-3.1415921825947537, -1.535883315130894, -0.6806766312742327], 'q99': [3.1328677560237566, 0.01745363977705911, 1.8326023938995777], 'std': [1.756197370631558, 0.482005280180061, 0.7151569996653978], }, 'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]}, 'joints': { 'max': [ 2.4945411682128906, 1.0603798627853394, 0.5518655776977539, 2.71895432472229, 1.9269907474517822, 2.2231369018554688, 0.7844778895378113, ], 'mean': [ 0.35797572135925293, -0.23473882675170898, -0.34185951948165894, 0.8982798457145691, 0.5780586004257202, 1.0351765155792236, -1.2190868854522705, ], 'min': [ -0.6943671703338623, -1.6351218223571777, -2.0142624378204346, 0.00768359424546361, -0.17713193595409393, 0.011292487382888794, -2.6592507362365723, ], 'q01': [ -0.43497809767723083, -1.4843493700027466, -1.7824348211288452, 0.0955040231347084, -0.07186625152826309, 0.27948257327079773, -2.4006130695343018, ], 'q99': [ 2.25881028175354, 0.8190488815307617, 0.2994879484176636, 2.153681993484497, 1.6126199960708618, 1.8523634672164917, 0.06944511085748672, ], 'std': [ 0.6131777167320251, 0.6322475075721741, 0.5427954792976379, 0.539089024066925, 0.43294647336006165, 0.3588581085205078, 0.6916991472244263, ], }, 'translation': { 'max': [0.681192194250889, 0.2323359966361606, 0.6159547986373596], 'mean': [0.4067338752307897, 0.027551073758015975, 0.313373054289359], 'min': [0.13453314224525192, -0.2553313745227843, 0.03515888153950328], 'q01': [0.18739915591793913, -0.18234310016370514, 0.048970692427190994], 'q99': [0.6410437631246786, 0.20632224240340083, 0.5983893391276117], 'std': 
[0.12446715882142037, 0.08252308025761342, 0.11421020385340024], }, }, 'utaustin_mutex': { 'euler': { 'max': [3.1415927410125732, 0.5760129690170288, 0.6430304050445557], 'mean': [1.643947720527649, 0.03388536348938942, -0.31948330998420715], 'min': [-3.141592025756836, -0.9390894770622253, -1.8157784938812256], 'q01': [-3.1398682594299316, -0.7223533391952515, -1.5468647480010986], 'q99': [3.140272855758667, 0.3558599054813385, 0.3583984971046448], 'std': [2.2715108394622803, 0.19258549809455872, 0.5275800824165344], }, 'gripper': { 'max': [0.07569462060928345], 'min': [0.00026726332725957036], 'q01': [0.0019378233700990677], 'q99': [0.07565194368362427], }, 'joints': { 'max': [ 0.5831676125526428, 0.8024097681045532, 1.350082516670227, -1.1145747900009155, 0.5345654487609863, 3.2108523845672607, 2.875584363937378, ], 'mean': [ -0.09772995859384537, 0.05997378006577492, -0.009316973388195038, -2.3428714275360107, -0.48651641607284546, 2.2390215396881104, 1.3341773748397827, ], 'min': [ -0.6990927457809448, -0.8311276435852051, -0.5323343276977539, -2.9861886501312256, -2.060809850692749, 0.826747477054596, 0.05972049757838249, ], 'q01': [ -0.532222330570221, -0.5078368782997131, -0.3437865376472473, -2.784912586212158, -1.8685580492019653, 1.3974121809005737, 0.5157660841941833, ], 'q99': [ 0.35958197712898254, 0.6339818835258484, 0.5833610892295837, -1.612648367881775, 0.2780923545360565, 2.8907132148742676, 2.837320327758789, ], 'std': [ 0.2101736217737198, 0.2880096137523651, 0.17824982106685638, 0.23544654250144958, 0.6387099027633667, 0.33533376455307007, 0.5090957880020142, ], }, 'translation': { 'max': [0.581696093082428, 0.4334338903427124, 0.6800903081893921], 'mean': [0.4331520199775696, -0.10096638649702072, 0.22044037282466888], 'min': [0.2454746514558792, -0.502901017665863, 0.008792983368039131], 'q01': [0.3217194080352783, -0.4733337163925171, 0.014122226275503635], 'q99': [0.5321439504623413, 0.3733823001384735, 0.5785381197929382], 'std': 
[0.04380597919225693, 0.20559938251972198, 0.12614832818508148], }, }, 'viola': { 'euler': { 'max': [3.141592502593994, 0.9555190801620483, 0.40456175804138184], 'mean': [-0.7171992063522339, 0.07086259871721268, -0.67817223072052], 'min': [-3.141592025756836, -0.4305531978607178, -2.212939739227295], 'q01': [-3.1400232315063477, -0.19947049021720886, -1.8703945875167847], 'q99': [3.1401824951171875, 0.7830920219421387, 0.19578109681606293], 'std': [2.902841567993164, 0.20232577621936798, 0.7306717038154602], }, 'gripper': { 'max': [0.07748929411172867], 'min': [-5.3190000471659005e-05], 'q01': [0.00020356666937004775], 'q99': [0.07747221738100052], }, 'joints': { 'max': [ 0.3366159200668335, 0.8983011841773987, 0.330204576253891, 0.0, 1.3329579830169678, 3.1103708744049072, 2.8318748474121094, ], 'mean': [ 0.027196571230888367, 0.2290346622467041, -0.06033482402563095, -2.060992479324341, 0.25734394788742065, 2.2569527626037598, 1.2675725221633911, ], 'min': [ -0.4073199927806854, -0.23913456499576569, -0.5244941115379333, -2.744713068008423, -0.48249971866607666, 0.0, -0.22237232327461243, ], 'q01': [ -0.26899582147598267, -0.189721941947937, -0.3926161825656891, -2.620305061340332, -0.199931338429451, 1.6743353605270386, 0.1440846174955368, ], 'q99': [ 0.23294022679328918, 0.764011561870575, 0.20191603899002075, -1.5629768371582031, 1.039290189743042, 2.9412121772766113, 2.798232078552246, ], 'std': [ 0.10160040855407715, 0.21415969729423523, 0.14444759488105774, 0.27077943086624146, 0.3251941502094269, 0.338652640581131, 0.7529342174530029, ], }, 'translation': { 'max': [0.6868156790733337, 0.20999883115291595, 0.5037735104560852], 'mean': [0.5464076995849609, 0.006265466101467609, 0.23136523365974426], 'min': [0.0, -0.33860740065574646, 0.0], 'q01': [0.40061360597610474, -0.25196850299835205, 0.010269512422382832], 'q99': [0.6458418369293213, 0.17776551842689514, 0.4456312954425812], 'std': [0.062116123735904694, 0.12932434678077698, 0.13205111026763916], }, 
}, } @cached_property def _control_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]: if is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.translation_norm): return {} with open(self.config.control_stats_path, 'r') as file: stats = yaml.safe_load(file) if self.config.delta_controls: if self.control_io_config.future_controls_sequence_stride_sec is None: horizon = 0.0 else: horizon = self.control_io_config.future_controls_sequence_stride_sec elif self.control_io_config.future_controls_sequence_stride_sec is None: if self.control_io_config.future_controls_sequence_length == 1: horizon = 0.0 else: raise NotImplementedError() else: horizon = ( self.control_io_config.future_controls_sequence_length * self.control_io_config.future_controls_sequence_stride_sec ) key = f'horizon_{round(horizon, 2)}s' if key in stats: stats = stats[key] else: raise ValueError( f'Missing control statistics key {key} for future_controls_sequence_length={self.config.control_io_config.future_controls_sequence_length} future_controls_sequence_stride_sec={self.config.control_io_config.future_controls_sequence_stride_sec}. Available keys: [{stats.keys()}]' ) return stats @cached_property def dataset_names(self) -> List[str]: if ( is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.obs_rotation_norm) and is_global_norm(self.config.translation_norm) and is_global_norm(self.config.obs_translation_norm) ): return ['ANY'] return list(set(self._control_stats.keys()) | set(self._observation_stats.keys())) def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor: """ Transform a sequence of translation vectors encoded w.r.t. PREVIOUS frame in the sequence to encoding w.r.t. 
the 0-th element preceding the sequence Ex: Sequence of points: T1, T2, T3, T4 `translation_sequence` contains the vectors: T0T1, T1T2, T2T3, T3T4, where T0 is the base frame, implicitly encoded in T0T1 Output: T0T1, T0T2, T0T3, T0T4 Args: translation_sequence: torch.Tensor of shape [..., S, 3], containing the translation vectors, where S corresponds to the sequence dimension Returns: torch.Tensor of the same shape as translation_sequence, containing delta translations """ assert translation_sequence.ndim >= 3, translation_sequence.shape delta_translations = torch.cumsum(translation_sequence, dim=-2) return delta_translations class RegressionProcessor(VLAMProcessor[RegressionProcessorConfig]): def policy_control_plan_from_model_target( self, target: RoboticsTarget, dataset_name: np.ndarray ) -> RoboticsControlPlan: translation_m = self.unnormalize(target.translation, dataset_name=dataset_name, key='translation') rotation = self.unnormalize(target.rotation, dataset_name=dataset_name, key='rotation') rotmat = convert_rotation(rotation, RotationFormat.ROTMAT) gripper_prob = target.gripper if self.config.delta_controls: translation_m = delta_to_relative_translations(translation_m) rotmat = delta_to_relative_rotations(rotmat) return RoboticsControlPlan( translation_m=translation_m, rotmat=rotmat, gripper_prob=gripper_prob, valid_mask=target.valid_mask, ) def policy_control_plan_from_model_output( self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor ) -> RoboticsControlPlan: """Called during inference to create control plan from model output""" translation_m = self.unnormalize( model_output.translation, dataset_name=dataset_name, key='translation' ) rotation = self.unnormalize(model_output.rotation, dataset_name=dataset_name, key='rotation') rotmat = convert_rotation(rotation, RotationFormat.ROTMAT, autonorm=True) gripper_prob = torch.sigmoid(model_output.gripper) if self.config.delta_controls: translation_m = 
delta_to_relative_translations(translation_m) rotmat = delta_to_relative_rotations(rotmat) return RoboticsControlPlan( translation_m=translation_m, rotmat=rotmat, gripper_prob=gripper_prob, valid_mask=valid_mask ) class VLARMProcessor(RegressionProcessor): def __init__(self, config: VLARMProcessorConfig, vlm_processor: VLMProcessor): Configurable.__init__(self, config) self.vlm_processor = vlm_processor self.control_tokenizer = EmptyTokenizer( config=self.config.control_tokenizer_config, tokenizer=self.tokenizer ) self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = { 'obs_translation': self.obs_translation_norm_bounds, 'obs_rotation': self.obs_rotation_norm_bounds, 'translation': self.translation_norm_bounds, 'rotation': self.rotation_norm_bounds, 'joints': self.joints_norm_bounds, } @property def dataset_np_names(self) -> np.ndarray: return self._dataset_np_names def _set_dataset_np_names(self, dataset_np_names: np.ndarray) -> None: self._dataset_np_names = dataset_np_names def preprocess_inputs( self, chat: List[str], images: Dict[str, List[PIL.Image.Image]], ee_pose_translation: np.ndarray, ee_pose_rotation: np.ndarray, gripper: np.ndarray, joints: np.ndarray, dataset_name: np.ndarray, inference_mode: bool, control_target: Optional[RoboticsTarget] = None, ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images) images: Dict[str, torch.Tensor] = inputs['images'] input_ids: torch.Tensor = inputs['input_ids'][..., : self.tokenizer.model_max_length] target_text_tokens_ids: torch.Tensor = inputs['target_ids'][..., : self.tokenizer.model_max_length] attn_mask = torch.ones(input_ids.shape, dtype=torch.bool) ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32) ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32) ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True) gripper = 
preprocess_gripper_observation(gripper, dataset_name) gripper = torch.tensor(gripper, dtype=torch.float32) ee_pose_translation = self.normalize( ee_pose_translation, dataset_name=dataset_name, key='obs_translation' ) ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key='obs_rotation') joints = torch.tensor(joints, dtype=torch.float32) if joints.shape[-1] < 7: missing_size = 7 - joints.shape[-1] joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1) joints = self.normalize(joints, dataset_name=dataset_name, key='joints') return { 'images': images, 'input_ids': input_ids, 'target_text_tokens_ids': target_text_tokens_ids, 'attn_mask': attn_mask, 'ee_pose_translation': ee_pose_translation, 'ee_pose_rotation': ee_pose_rotation, 'gripper': gripper, 'joints': joints, 'control_tokens_ids': None, 'target_control_tokens_ids': None, } def model_otoi_but_gt_gripper( self, translation_output: torch.Tensor, rotation_output: torch.Tensor, reference_translation: torch.Tensor, reference_rotation: torch.Tensor, gt_gripper: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Called during inference to create control plan from model output""" translation_delta = self.unnormalize( translation_output, dataset_name=self.dataset_np_names, key='translation' ) rotation_delta = self.unnormalize(rotation_output, dataset_name=self.dataset_np_names, key='rotation') translation_reference = self.unnormalize( reference_translation, dataset_name=self.dataset_np_names, key='obs_translation' ) rotation_reference = self.unnormalize( reference_rotation, dataset_name=self.dataset_np_names, key='obs_rotation' ) translation = translation_reference + translation_delta rotation = delta_to_world_rotations(rotation_delta, rotation_reference) translation_next = self.normalize( translation, dataset_name=self.dataset_np_names, key='obs_translation' ) rotation_next = self.normalize(rotation, dataset_name=self.dataset_np_names, 
key='obs_rotation') gripper_next = gt_gripper return translation_next, rotation_next, gripper_next def make_causal_mask(shape: Sequence[int]) -> torch.Tensor: """ Create a causal attention mask of shape `shape` Args: shape: Shape of the output mask, the last two dimensions correspond to [query_seq_len, kv_seq_len] Returns: torch.Tensor of dtype torch.bool. False values indicate that the row (i.e. query) can't attend to the corresponding column (i.e. key) Example: shape = (3, 5) -> Mask the upper triangular part [ [ 1, 0, 0, 0, 0], [ 1, 1, 0, 0, 0], [ 1, 1, 1, 0, 0] ] """ return torch.tril(torch.ones(shape, dtype=torch.bool), diagonal=0) def enable_full_attn_blocks(attn_mask: torch.Tensor, full_attn: torch.Tensor) -> torch.Tensor: """ Enable full bi-directional attention in `attn_mask` inside specific blocks Args: attn_mask: Existing attention mask of shape [..., query_seq_len, kv_seq_len] and dtype torch.bool where False values indicate disabled attention full_attn: torch.Tensor of shape [query_seq_len], dtype torch.bool. 
Blocks of True values indicate positions where full bi-directional attention should be enabled Example: 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, -> 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, """ assert full_attn.dtype == torch.bool, full_attn.dtype assert full_attn.ndim == 1, full_attn.shape assert full_attn.shape[0] == attn_mask.shape[-2], f'{full_attn.shape[0]}, {attn_mask.shape}' if attn_mask.shape[-1] != attn_mask.shape[-2]: raise NotImplementedError('Only self-attention supported right now.') x = full_attn.view(-1, 1) & full_attn.view(1, -1) x = x | make_causal_mask([full_attn.shape[0], full_attn.shape[0]]) x = torch.cumprod(x, dim=1).to(dtype=torch.bool) x = x & x.permute(1, 0) mask_positions = torch.sum(x, dim=0) == 1 & ~full_attn mask_indices = torch.where(mask_positions)[0] x[mask_indices, mask_indices] = 0 attn_mask = attn_mask | expand_dims(x, ndim=attn_mask.ndim, order=[-1, 1, 1]) return attn_mask IGNORE_INDEX = -100 class PaliGemmaProcessor(VLMProcessor[PaliGemmaProcessorConfig]): def __init__( self, config: PaliGemmaProcessorConfig, hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor, **kwargs, ): del kwargs super().__init__(config) self.hf_processor = hf_processor self.hf_processor.image_processor.size = dict(self.config.image_sizes['main'].as_json()) self.hf_processor.image_seq_length = self.config.num_image_tokens['main'] self.hf_processor.image_processor.image_seq_length = self.config.num_image_tokens['main'] self.bos_id: int = self.tokenizer.bos_token_id self.eos_id: int = self.tokenizer.eos_token_id self.sep_token = '\n' self.sep_id: int = self.tokenizer( self.sep_token, padding=False, add_special_tokens=False, 
return_attention_mask=False )['input_ids'][0] self.image_token_id: int = self.tokenizer( self.config.image_token, padding=False, add_special_tokens=False, return_attention_mask=False )['input_ids'][0] self.image_tokens: list[int] = [self.image_token_id] * sum(self.config.num_image_tokens.values()) self.bbox_pattern = re.compile( '\\[(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+)\\]' ) def preprocess_inputs( self, chat: List[str], images: Dict[str, List[PIL.Image.Image]] ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: """ Based on PaliGemma paper https://arxiv.org/pdf/2407.07726 and example code at https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#create_model_inputs Chat must be always made of separate messages from user and model, always starting with user ... ... Args: chat: List[str] of even size where each entry corresponds to a different turn in the conversation images: Dict[str, List[PIL.Image.Image]] where different cameras correspond to different keys in the Dict and the List corresponds to history of images """ for key, value in images.items(): if not isinstance(value, list): raise TypeError(f'Camera {key} contains values of type {type(value)} instead of list') (input_ids, target_ids) = ([], []) for i, text in enumerate(chat): text = text.replace(self.sep_token, ' ').replace('', '') text = self.bbox_pattern.sub(self._bbox_to_loc_tokens, text) turn_input_ids: List[int] = self.tokenizer( text, padding=False, add_special_tokens=False, return_attention_mask=False )['input_ids'] if i % 2 == 0: turn_target_ids = [IGNORE_INDEX] * len(turn_input_ids) else: turn_target_ids = turn_input_ids if i != len(chat) - 1: turn_input_ids = turn_input_ids + [self.sep_id] turn_target_ids = turn_target_ids + [IGNORE_INDEX] input_ids = input_ids + turn_input_ids target_ids = target_ids + turn_target_ids input_ids = [self.bos_id] + input_ids + [self.eos_id] target_ids = [IGNORE_INDEX] + target_ids + [self.eos_id] image_tokens = 
self.image_tokens input_ids = image_tokens + input_ids target_ids = [IGNORE_INDEX] * len(image_tokens) + target_ids input_ids = torch.tensor(input_ids, dtype=torch.int64) target_ids = torch.tensor(target_ids, dtype=torch.int64) image_tensors: Dict[str, torch.Tensor] = { f'{camera_name}.siglip': self.hf_processor.image_processor( camera_images, size=self.config.image_sizes[camera_name].as_json(), return_tensors='pt' )['pixel_values'] for (camera_name, camera_images) in images.items() } attn_mask = make_causal_mask([len(input_ids), len(input_ids)]) attn_mask = enable_full_attn_blocks(attn_mask, full_attn=target_ids == IGNORE_INDEX) return { 'input_ids': input_ids, 'target_ids': target_ids, 'images': image_tensors, 'attn_mask': attn_mask, } @property def tokenizer(self) -> transformers.PreTrainedTokenizerBase: return self.hf_processor.tokenizer @staticmethod def _bbox_to_loc_tokens(match: str) -> str: """ https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/ """ floats = list(map(float, match.groups())) transformed = [f'' for num in floats] return f"[{', '.join(transformed)}]" @property def image_sizes(self) -> Dict[str, ImageSizeConfig]: return self.config.image_sizes