Instructions to use INSAIT-Institute/arvla-bridge with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use INSAIT-Institute/arvla-bridge with Transformers:
  ```python
  # Load model directly
  from transformers import AutoModel

  model = AutoModel.from_pretrained("INSAIT-Institute/arvla-bridge", trust_remote_code=True, dtype="auto")
  ```
- Notebooks
- Google Colab
- Kaggle
| import collections | |
| import collections.abc | |
| import importlib | |
| import re | |
| import warnings | |
| from abc import abstractmethod | |
| from functools import cached_property | |
| from typing import Dict, List, Optional, Sequence, Tuple, TypeVar | |
| import numpy as np | |
| import PIL.Image | |
| import roma | |
| import torch | |
| import torchvision.transforms.v2 | |
| import transformers | |
| import yaml | |
| try: | |
| from .hf_compat import Configurable, Template | |
| except Exception: | |
| _hf_compat = importlib.import_module("src.hf_compat") | |
| Configurable = _hf_compat.Configurable | |
| Template = _hf_compat.Template | |
| from .common_vlarm import ( | |
| Normalization, | |
| ResizeMode, | |
| RoboticsControlPlan, | |
| RoboticsInput, | |
| RoboticsOutput, | |
| RoboticsTarget, | |
| RotationFormat, | |
| expand_dims, | |
| ) | |
| from .configuration_vlarm import ( | |
| ControlDataIOConfig, | |
| ControlTokenizerConfig, | |
| EmptyTokenizerConfig, | |
| ImageSizeConfig, | |
| PaliGemmaProcessorConfig, | |
| RegressionProcessorConfig, | |
| VLAMProcessorConfig, | |
| VLARMProcessorConfig, | |
| VLMProcessorConfig, | |
| ) | |
ControlTokenizerConfigT = TypeVar('ControlTokenizerConfigT', bound=ControlTokenizerConfig)


class ControlTokenizer(Configurable[ControlTokenizerConfigT], Template[ControlTokenizerConfigT]):
    """Base class for control tokenizers: map ground-truth actions (and other info) to text.

    Subclasses must implement `__call__`; the previous implementation silently returned
    `None` despite the `-> str` annotation, so the method is now marked abstract
    (`abstractmethod` was already imported for this purpose).
    """

    @abstractmethod
    def __call__(self, *args, **kwargs) -> str:
        """Given GT actions and possibly other information, output text control. Gets appended to the prompt."""
class EmptyTokenizer(ControlTokenizer[EmptyTokenizerConfig]):
    """Control tokenizer that emits no control text: `__call__` always returns ''.

    NOTE(review): the previous docstring described concatenating LLM hidden states from
    `llm_layer_indices`, which this class does not do — it appeared to be copied from an
    unrelated class.
    """

    def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None:
        super().__init__(config)
        # Kept for interface parity with other tokenizers; unused by __call__.
        self.tokenizer = tokenizer

    def __call__(self, *_) -> str:
        # No control text is appended to the prompt.
        return ''
def np_unique(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the unique elements of `data` in order of first occurrence.

    np.unique returns values in sorted order, so its indices/inverse refer to the sorted
    values and are invalid for unsorted input. To fix this, map every element to the index
    of its first occurrence and run np.unique a second time on those indices — they are
    already ordered by first appearance.

    Returns:
        (unique values in appearance order, first-occurrence indices, inverse indices, counts)
    """
    (_, first_idx, sorted_inverse) = np.unique(data, return_index=True, return_inverse=True)
    first_occurrence_per_element = first_idx[sorted_inverse]
    (_, occurrence_index, inverse_index, occurrence_counts) = np.unique(
        first_occurrence_per_element, return_index=True, return_inverse=True, return_counts=True
    )
    return data[occurrence_index], occurrence_index, inverse_index, occurrence_counts
def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor:
    """
    Convert 'xyz' Euler angles (radians) to rotation matrices.

    Args:
        angles: Euler angles in radians, shape [..., 3]
    Returns:
        Rotation matrices, shape [..., 3, 3]
    """
    rotmat = roma.euler_to_rotmat(convention='xyz', angles=angles, degrees=False)
    return rotmat
def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor:
    """
    Convert 'xyz' Euler angles (radians) to unit quaternions.

    Args:
        angles: Euler angles in radians, shape [..., 3]
    Returns:
        Unit quaternions, shape [..., 4]
    """
    unit_quat = roma.euler_to_unitquat(convention='xyz', angles=angles, degrees=False, normalize=True)
    return unit_quat
def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor:
    """
    Scale quaternions to (approximately) unit norm.

    Args:
        quaternion: Unnormalized quaternions, shape [..., 4]
        eps: Small constant added to the norm to avoid division by zero
    Returns:
        torch.Tensor of shape [..., 4] of unit quaternions
    """
    norm = quaternion.norm(dim=-1, keepdim=True)
    return quaternion / (norm + eps)
def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to 'xyz' Euler angles.

    Args:
        quaternion: torch.Tensor of shape [..., 4]; can be non-normalized
    Returns:
        torch.Tensor of shape [..., 3] containing Euler angles in radians
        (fixed: the previous docstring incorrectly claimed [..., 3, 3] rotation matrices)
    """
    unit_quat = normalize_quaternion(quaternion)
    # Renamed the old misleading local `rotmat` — roma returns Euler angles here.
    euler = roma.unitquat_to_euler(convention='xyz', quat=unit_quat, as_tuple=False, degrees=False)
    return euler
def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Convert (possibly unnormalized) quaternions to rotation matrices.

    Args:
        quaternion: torch.Tensor of shape [..., 4]; can be non-normalized
    Returns:
        torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3)
    """
    return roma.unitquat_to_rotmat(normalize_quaternion(quaternion))
def is_quaternion(quaternion: torch.Tensor) -> bool:
    """Return True when the trailing dimension has size 4 (quaternion layout)."""
    return 4 == quaternion.shape[-1]
def is_unit_quaternion(
    quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = 'none'
) -> torch.Tensor | bool:
    """
    Check if a quaternion is normalized or not.
    Args:
        quaternion: torch.Tensor of shape [..., 4]
        epsilon: Tolerance for numerical comparisons
        reduction:
            'none' - returns torch.Tensor of bools with the same batch shape
            'all' - returns a bool, True if ALL quaternions in the batch are normalized
    Returns:
        torch.Tensor with the same batch shape or bool
    Raises:
        ValueError: if `reduction` is not 'none' or 'all'
    """
    assert is_quaternion(quaternion)
    # Compare the norm against 1.0 in the input's own dtype/device.
    is_norm = torch.isclose(
        quaternion.norm(dim=-1, keepdim=True),
        torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device),
        atol=epsilon,
    )
    if reduction == 'none':
        return is_norm
    if reduction == 'all':
        return bool(torch.all(is_norm).item())
    raise ValueError(f'Unknown reduction mode {reduction}')
def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Flip quaternions so they cover only a half the space. If the q_w is negative, flip the quaternion.
    If q_w is 0, then choose such that the first non-zero component is positive. Note that geometrically,
    this doesn't correspond to a single hemisphere of the unit sphere. Follows
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat
    """
    assert is_quaternion(quaternion), quaternion.shape
    # The flip mask is pure boolean logic; compute it without tracking gradients.
    with torch.no_grad():
        is_zero = quaternion == 0
        # Layout is (x, y, z, w): flip when w < 0, or w == 0 and the first non-zero
        # component among (x, y, z) is negative. `&` binds tighter than `|`, so each
        # `|` operand is one "all previous components are zero" case.
        flip_condition = (
            (quaternion[..., -1:] < 0)
            | is_zero[..., -1:] & (quaternion[..., 0:1] < 0)
            | is_zero[..., -1:] & is_zero[..., 0:1] & (quaternion[..., 1:2] < 0)
            | is_zero[..., -1:] & is_zero[..., 0:1] & is_zero[..., 1:2] & (quaternion[..., 2:3] < 0)
        )
    # NOTE(review): indentation was ambiguous in the reviewed copy; the flip itself is kept
    # outside no_grad so gradients can flow through torch.where — confirm against history.
    quaternion = torch.where(flip_condition, -quaternion, quaternion)
    return quaternion
def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor:
    """Return `rotmat` reshaped to [..., 3, 3]; accepts [..., 9] or [..., 3, 3] input."""
    if rotmat.shape[-2:] == torch.Size([3, 3]):
        return rotmat
    if rotmat.shape[-1] == 9:
        batch_shape = rotmat.shape[:-1]
        return rotmat.reshape(*batch_shape, 3, 3)
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Args:
        rotmat: Batch of rotation matrices, shape [..., 3, 3] or flattened [..., 9]
    Returns:
        Batch of 'xyz' Euler angles in radians, shape [..., 3]
    """
    # Accept flattened input by reshaping to [..., 3, 3] first.
    rotmat = rotmat_as_3x3(rotmat)
    return roma.rotmat_to_euler(convention='xyz', rotmat=rotmat, as_tuple=False, degrees=False)
def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to unit quaternions.

    Args:
        rotmat: Batch of rotation matrices, shape [..., 3, 3] or flattened [..., 9]
    Returns:
        Batch of unit quaternions, shape [..., 4]
    """
    return roma.rotmat_to_unitquat(rotmat_as_3x3(rotmat))
| def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Maps 9D input vectors onto SO(3) via symmetric orthogonalization. | |
| - Let SVD(M) = U \Sigma V^T | |
| - Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T | |
| - det(UV^T) ensures that det(SVD+(M)) = 1 | |
| - The return value is a rotation matrix (ortonormal) with the least-squares distance to M | |
| Args: | |
| x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3] | |
| Returns: | |
| torch.Tensor with the same shape as x, where each inner 3x3 matrix is in SO(3) | |
| """ | |
| with warnings.catch_warnings(): | |
| warnings.filterwarnings( | |
| 'ignore', message='In CPU autocast, but the target dtype is not supported. Disabling autocast.' | |
| ) | |
| with torch.autocast(device_type=x.device.type, dtype=torch.float32): | |
| matrices = x.view(-1, 3, 3) | |
| matrices = matrices.to(dtype=torch.float32) | |
| (u, s, v) = torch.svd(matrices) | |
| vt = torch.transpose(v, 1, 2) | |
| det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1) | |
| diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1) | |
| result = torch.matmul(u, diag_vt) | |
| result = result.view(*x.shape) | |
| result = result.to(dtype=x.dtype) | |
| return result | |
def is_rotmat_3x3(rotmat: torch.Tensor) -> bool:
    """True when the last two dimensions are (3, 3) — matrix-shaped rotmat layout."""
    return tuple(rotmat.shape[-2:]) == (3, 3)
def is_rotmat_9(rotmat: torch.Tensor) -> bool:
    """True when the trailing dimension has size 9 — flattened rotmat layout."""
    return 9 == rotmat.shape[-1]
def is_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Checks if the tensor shape matches that of a rotmat ([..., 3, 3] or [..., 9]). It's not
    guaranteed the data is a valid rotmat; `is_orthonormal_rotmat` performs this additional check.
    NOTE: This might incorrectly return True if the underlying data is euler angles and accidentally
    `rotmat.shape[-2:] == [3, 3]`. This would happen very rarely, but use with caution
    """
    # Shape checks inlined from is_rotmat_3x3 / is_rotmat_9.
    return rotmat.shape[-2:] == torch.Size([3, 3]) or rotmat.shape[-1] == 9
def is_rotmat_orthonormal(
    rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = 'none'
) -> torch.Tensor | bool:
    """
    Check if a rotation matrix is orthonormal or not.
    Args:
        rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9]
        epsilon: Tolerance for numerical comparisons. Bigger values allow for more freedom. Generally,
            anything smaller than 1e-6 might incorrectly detect some orthonormal matrices as not
        reduction:
            'none' - returns torch.Tensor of bools with the same batch shape
            'all' - returns a bool, True if ALL matrices in the batch are orthonormal
    Returns:
        torch.Tensor with the same batch shape or bool
    Raises:
        ValueError: if `reduction` is not 'none' or 'all'
    """
    assert is_rotmat(rotmat)
    # Run the check in float32: lower precisions are too coarse for the default epsilon.
    rotmat = rotmat_as_3x3(rotmat.to(dtype=torch.float32))
    is_orthonormal = roma.is_orthonormal_matrix(rotmat, epsilon=epsilon)
    if reduction == 'none':
        return is_orthonormal
    if reduction == 'all':
        return bool(torch.all(is_orthonormal).item())
    raise ValueError(f'Unknown reduction mode {reduction}')
def is_orthonormal_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Checks if the tensor shape matches that of a rotmat. If the last dimensions of shape are 3x3,
    also checks if the data is a valid rotmat. This is to avoid a possible clash with euler angles
    when accidentally `rotmat.shape[-2:] == [3, 3]`
    """
    if is_rotmat_9(rotmat):
        # Flattened [..., 9] cannot clash with euler angles; the shape alone is enough.
        return True
    return is_rotmat_3x3(rotmat) and is_rotmat_orthonormal(rotmat, epsilon=0.02, reduction='all')
def is_euler(euler: torch.Tensor) -> bool:
    """True when the trailing dim is 3 and the data does not look like an orthonormal rotmat."""
    if euler.shape[-1] != 3:
        return False
    return not is_orthonormal_rotmat(euler)
def rotation_format_from_tensor(rotation) -> RotationFormat:
    """Infer the rotation format from the tensor shape (checked in order: quat, rotmat, euler)."""
    checks = (
        (is_quaternion, RotationFormat.QUATERNION),
        (is_orthonormal_rotmat, RotationFormat.ROTMAT),
        (is_euler, RotationFormat.EULER),
    )
    for predicate, rotation_format in checks:
        if predicate(rotation):
            return rotation_format
    raise ValueError(f'Tensor shape {rotation.shape} is not a valid rotation format')
def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor:
    """Convert any rotmat input ([..., 9] or [..., 3, 3]) to flattened [..., 9] shape."""
    if is_rotmat_9(rotmat):
        return rotmat
    if is_rotmat_3x3(rotmat):
        return rotmat.reshape(*rotmat.shape[:-2], 9)
    # Fixed message: this helper flattens to [..., 9]; the old text said "3x3 rotation matrix".
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a flattened [..., 9] rotation matrix")
def convert_rotation(
    rotation: torch.Tensor | np.ndarray,
    output_format: RotationFormat,
    autonorm: bool = True,
    half_cover: bool = True,
) -> torch.Tensor | np.ndarray:
    """
    Convert a rotation between quaternion / rotmat / euler representations.

    The source format is inferred from the trailing shape: quaternion [..., 4],
    orthonormal rotmat [..., 9] or [..., 3, 3], euler angles [..., 3].

    Args:
        rotation: Input rotation(s); np.ndarray in -> np.ndarray out, tensor in -> tensor out.
        output_format: Target representation; rotmat output is flattened to [..., 9].
        autonorm: Re-normalize non-unit quaternions / non-orthonormal matrices before converting.
        half_cover: For quaternion output, flip to the scipy-style half cover
            (see `quaternion_half_cover`).
    Returns:
        Converted rotation with the same container type as the input.
    Raises:
        ValueError: when the input shape matches no known rotation format.
        NotImplementedError: for an unsupported `output_format`.
    """
    is_np = isinstance(rotation, np.ndarray)
    if is_np:
        rotation = torch.from_numpy(rotation)
    if is_quaternion(rotation):
        # Quaternion source: normalize only when actually needed.
        if autonorm and not is_unit_quaternion(rotation, reduction='all'):
            rotation = normalize_quaternion(rotation)
        if output_format == RotationFormat.QUATERNION:
            output = rotation
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(quaternion_to_rotmat(rotation))
        elif output_format == RotationFormat.EULER:
            output = quaternion_to_euler(rotation)
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    elif is_orthonormal_rotmat(rotation):
        # Rotmat source: a looser epsilon decides whether re-orthogonalization is needed.
        if autonorm and not is_rotmat_orthonormal(rotation, epsilon=0.01, reduction='all'):
            rotation = symmetric_orthogonalization(rotation)
        if output_format == RotationFormat.QUATERNION:
            output = rotmat_to_unit_quaternion(rotation)
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(rotation)
        elif output_format == RotationFormat.EULER:
            output = rotmat_to_euler(rotation)
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    elif is_euler(rotation):
        if output_format == RotationFormat.QUATERNION:
            output = euler_to_unit_quaternion(rotation)
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(euler_to_rotmat(rotation))
        elif output_format == RotationFormat.EULER:
            output = rotation
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    else:
        raise ValueError(f'Unknown rotation encoding with shape {rotation.shape}')
    if output_format == RotationFormat.QUATERNION and half_cover:
        output = quaternion_half_cover(output)
    if is_np:
        output = output.numpy()
    return output
def delta_to_world_rotations(rotation_sequence: torch.Tensor, reference_frame: torch.Tensor) -> torch.Tensor:
    """
    Transform a sequence of rotation representations encoded w.r.t. `reference_frame` to encoding w.r.t.
    WORLD frame, where `reference_frame` is provided w.r.t. WORLD frame
    Ex:
        Sequence of points (rotations): R_1, R_2, R_3, R_4
        `rotation_sequence` contains the rotations: R_01, R_02, R_03, R_04, where R_0 is the reference frame
        and R_01 is the pose of R1 frame in reference frame, i.e. R_10 converts from reference frame to R1 frame
        Output: R_W1, R_W2, R_W3, R_W4, where W is the world frame, i.e. the rotation poses of
        R_1, R_2, R_3, R_4 expressed in world frame
    Args:
        rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing
            either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions
        reference_frame: torch.Tensor, shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] and the SAME number of BATCH
            dims as `rotation_sequence`. The reference frame, provided w.r.t. WORLD coordinate frame R_W0,1,2,3
    Returns:
        torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] containing transformed rotations
        (R_W1, R_W2, R_W3, R_W4, ...)
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    # Remember the input representation so the result can be returned in the same format.
    rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    # Work in matrix form: composing rotations is then a plain matmul.
    reference_frame = rotmat_as_3x3(convert_rotation(reference_frame, RotationFormat.ROTMAT))
    rotation_sequence = rotmat_as_3x3(convert_rotation(rotation_sequence, RotationFormat.ROTMAT))
    if reference_frame.ndim != rotation_sequence.ndim:
        raise ValueError(
            f'Cannot broadcast reference_frame of shape {reference_frame.shape} to rotation_sequence of shape {rotation_sequence.shape}. Provide tensors with the same number of batch dimensions'
        )
    R_W0 = reference_frame
    # R_Wi = R_W0 @ R_0i for every step i of the sequence.
    world_rotations = torch.matmul(R_W0, rotation_sequence)
    # Flatten to [..., 9] before converting back to the caller's format.
    world_rotations = world_rotations.view(*world_rotations.shape[:-2], 9)
    world_rotations = convert_rotation(world_rotations, rotation_format)
    return world_rotations
def delta_to_relative_rotations(rotation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Transform a sequence of rotation representations encoded w.r.t. the PREVIOUS rotation frame in the
    sequence to the 0-th element preceding the sequence
    Ex:
        `rotation_sequence` contains the rotations: R_01, R_12, R_23, R_34, where R0 is the base frame,
        implicitly encoded in R_01 and R_10 converts from R0 frame to R1 frame
        Output: R_01, R_02, R_03, R_04
    Args:
        rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing
            either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions
    Returns:
        torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] containing transformed rotations
        (R_01, R_02, R_03, R_04, ...)
    TODO: Can you make it work without for loop
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    # Quaternions make the running composition cheap via roma.quat_composition.
    rotation_sequence = convert_rotation(rotation_sequence, RotationFormat.QUATERNION)
    batch_dims = np.arange(rotation_sequence.ndim - 2)
    # Output element i is the composition R_01 * R_12 * ... * R_(i-1)i. The permute moves the
    # sequence axis to the front because quat_composition composes along the leading dimension;
    # element 0 (R_01) is passed through unchanged.
    delta_rotations = torch.cat(
        [rotation_sequence[..., :1, :]]
        + [
            roma.quat_composition(rotation_sequence[..., :i, :].permute(-2, *batch_dims, -1).unsqueeze(-2))
            for i in range(2, rotation_sequence.shape[-2] + 1)
        ],
        dim=-2,
    )
    delta_rotations = convert_rotation(delta_rotations, rotation_format)
    return delta_rotations
def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray:
    """Make sure image is of type np.ndarray and HWC (or plain HW) format."""
    np_image = np.asarray(image) if isinstance(image, PIL.Image.Image) else image
    assert isinstance(np_image, np.ndarray), type(np_image)
    assert np_image.ndim in [2, 3], np_image.shape
    # A 3-dim image must be channels-last with at most 4 (RGBA) channels.
    assert np_image.ndim == 2 or np_image.shape[-1] <= 4, np_image.shape
    return np_image
def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]:
    """Return (height, width) for an ndarray in HW[C] layout or a PIL image."""
    if isinstance(image, np.ndarray):
        return image.shape[0], image.shape[1]
    # PIL reports size as (width, height).
    (width, height) = image.size
    return height, width
def pad_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """Pad image adding a (near-)symmetric border around the height/width.

    Args:
        image: PIL image or ndarray in HW[C] layout.
        target_size: {'width': ..., 'height': ...}; each must be >= the input size.
        pad_value: Constant fill value for the border.
    Returns:
        Padded image of exactly `target_size`, same type as the input.
    """
    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
    (height, width) = hw_from_image(image)
    (target_width, target_height) = (target_size['width'], target_size['height'])
    if width == target_width and height == target_height:
        return image
    assert target_width >= width, f"Can't pad image of width {width} to {target_width}"
    assert target_height >= height, f"Can't pad image of height {height} to {target_height}"
    # Split the border evenly; when the difference is odd, the extra pixel goes to the
    # right/bottom so the output matches `target_size` exactly. (The old code floored both
    # sides, leaving odd-difference outputs one pixel short of the target.)
    left = (target_width - width) // 2
    right = target_width - width - left
    top = (target_height - height) // 2
    bottom = target_height - height - top
    if isinstance(image, np.ndarray):
        padding = ((top, bottom), (left, right)) + ((0, 0),) * (image.ndim - 2)
        image = np.pad(image, padding, mode='constant', constant_values=pad_value)
    else:
        # torchvision pad convention: (left, top, right, bottom).
        image = torchvision.transforms.v2.functional.pad(
            image, padding=(left, top, right, bottom), fill=pad_value, padding_mode='constant'
        )
    return image
def pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """Pad image (growing either width or height) to reach a target width/height ratio."""
    (height, width) = hw_from_image(image)
    if target_wh_ratio >= width / height:
        # Source too tall: widen to the target ratio.
        pad_size = {'width': round(height * target_wh_ratio), 'height': height}
    else:
        # Source too wide: grow the height instead.
        pad_size = {'width': width, 'height': round(width / target_wh_ratio)}
    return pad_image(image, target_size=pad_size, pad_value=pad_value)
def crop_image(
    image: np.ndarray | PIL.Image.Image,
    start_height: int,
    start_width: int,
    target_height: int,
    target_width: int,
) -> np.ndarray | PIL.Image.Image:
    """Crop a (target_height, target_width) window starting at (start_height, start_width).

    Returns the same type as the input (PIL in -> PIL out, ndarray in -> ndarray out).
    """
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = hw_from_image(image)
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    # Fixed message: this check is about the height, not the width.
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    (start_height, start_width) = (round(start_height), round(start_width))
    (target_height, target_width) = (round(target_height), round(target_width))
    np_image = np_image[
        start_height : start_height + target_height, start_width : start_width + target_width, ...
    ]
    image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
    return image
def crop_image_center(
    image: np.ndarray | PIL.Image.Image, target_size: dict[str, int]
) -> np.ndarray | PIL.Image.Image:
    """Center-crop image to `target_size` ({'width', 'height'}); preserves the input type."""
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = np_image.shape[:2]
    (target_height, target_width) = (target_size['height'], target_size['width'])
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    # Fixed message: this check is about the height, not the width.
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    top = (height - target_height) // 2
    left = (width - target_width) // 2
    np_image = crop_image(np_image, top, left, target_height, target_width)
    image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
    return image
def crop_image_to_ratio(
    image: PIL.Image.Image | np.ndarray, target_wh_ratio: float
) -> PIL.Image.Image | np.ndarray:
    """Center-crop image to a target width/height aspect ratio."""
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if target_wh_ratio >= wh_ratio:
        # Target is wider than the source: keep the width, reduce the height.
        crop_size = {'width': width, 'height': round(width / target_wh_ratio)}
    else:
        # Target is taller than the source: keep the height, reduce the width.
        crop_size = {'width': round(height * target_wh_ratio), 'height': height}
    image = crop_image_center(image, target_size=crop_size)
    return image
def crop_and_pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    mode: ResizeMode | str,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Crop and pad an image to a target aspect ratio depending on the mode.
    It's expected that the source image and target size have different aspect ratios.
    Args:
        image: The image to crop and pad.
        target_wh_ratio: The target width/height ratio to crop and pad the image to.
        mode: The mode to use for cropping and padding (SMART, PAD or CROP).
        pad_value: Constant fill value used for padded borders.
    """
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001):
        # Already (nearly) at the target ratio; nothing to do.
        return image
    if mode == ResizeMode.SMART:
        # Orientation-independent ratios (always >= 1).
        aspect_ratio = max(width, height) / min(width, height)
        target_ratio = max(target_wh_ratio, 1 / target_wh_ratio)
        if aspect_ratio == 1:
            if target_ratio >= 4 / 3 - 0.01:
                # Square source, strongly rectangular target: crop toward 4:3 first.
                crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
                image = crop_image_to_ratio(image, crop_wh_ratio)
            else:
                pass
        elif aspect_ratio <= 4 / 3 + 0.01:
            # NOTE(review): `wh_ratio >= 1.0 != (target_wh_ratio >= 1.0)` is a CHAINED
            # comparison: it means `(wh_ratio >= 1.0) and (1.0 != (target_wh_ratio >= 1.0))`,
            # i.e. it only fires for landscape source + portrait target, not the symmetric
            # case. Possibly `(wh_ratio >= 1.0) != (target_wh_ratio >= 1.0)` was intended —
            # confirm before changing.
            if wh_ratio >= 1.0 != (target_wh_ratio >= 1.0):
                image = crop_image_to_ratio(image, 1.0)
        elif wh_ratio >= 1.0 != (target_wh_ratio >= 1.0):
            # Strongly rectangular source whose orientation mismatches the target
            # (same chained-comparison caveat as above): square it first.
            image = crop_image_to_ratio(image, 1.0)
        elif target_ratio >= 4 / 3 + 0.01:
            pass
        else:
            crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
            image = crop_image_to_ratio(image, crop_wh_ratio)
        # After any cropping, padding brings the image exactly to the target ratio.
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.PAD:
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.CROP:
        image = crop_image_to_ratio(image, target_wh_ratio)
    else:
        raise ValueError(f'Mode {mode} not supported')
    return image
def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool:
    """True for single-channel images: grayscale/paletted/integer PIL modes, or HW / HW1 arrays."""
    single_channel_modes = ['1', 'L', 'LA', 'La', 'P', 'PA', 'F', 'I', 'I;16', 'I;16L', 'I;16B', 'I;16N']
    if isinstance(image, PIL.Image.Image):
        return image.mode in single_channel_modes
    if isinstance(image, np.ndarray):
        if image.ndim == 2:
            return True
        return image.ndim == 3 and image.shape[2] == 1
    raise ValueError(f'Unsupported image type: {type(image)}')
def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool:
    """True when the image is uint8/bool data whose maximum value is exactly 1 (a 0/1 mask)."""
    array = np.asarray(image)
    if array.dtype not in [np.uint8, np.bool_]:
        return False
    return np.max(array) == 1
def resize_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    mode: ResizeMode | str,
    resample: PIL.Image.Resampling | str = 'auto',
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Resize `image` to `target_size` ({'width', 'height'}).

    NAIVE resizes directly (ignoring aspect ratio); SMART first crops/pads toward the target
    ratio, then resizes. Binary 0/1 masks are temporarily scaled to 0/255 for resampling and
    thresholded back at 127.

    Args:
        image: PIL image or HW[C] ndarray.
        target_size: Target size as {'width': ..., 'height': ...}.
        mode: ResizeMode.NAIVE or ResizeMode.SMART.
        resample: PIL resampling filter, or 'auto' (BILINEAR for single-channel images,
            LANCZOS otherwise).
        pad_value: Fill value for any padding introduced in SMART mode.
    Raises:
        ValueError: single-channel image with a filter other than bilinear/bicubic.
        NotImplementedError: unsupported resize mode.
    """
    (target_width, target_height) = (target_size['width'], target_size['height'])
    (height, width) = hw_from_image(image)
    if height == target_height and width == target_width:
        return image
    if resample == 'auto':
        if is_single_channel_image(image):
            resample = PIL.Image.Resampling.BILINEAR
        else:
            resample = PIL.Image.Resampling.LANCZOS
    else:
        assert isinstance(resample, PIL.Image.Resampling), resample
        if is_single_channel_image(image) and resample not in [
            PIL.Image.Resampling.BILINEAR,
            PIL.Image.Resampling.BICUBIC,
        ]:
            raise ValueError(
                f'Single channel images must be resized with bilinear or bicubic, but got {resample}'
            )
    if is_bin_mask := is_binary_mask(image):
        # Scale {0, 1} to {0, 255} so interpolation has usable dynamic range.
        # NOTE(review): a PIL mask input becomes an ndarray here, so the function then
        # returns an ndarray rather than a PIL image — confirm this is intended.
        image = np.asarray(image).astype(np.uint8) * 255
    if mode == ResizeMode.SMART:
        image = crop_and_pad_image_to_ratio(
            image, target_wh_ratio=target_width / target_height, mode=mode, pad_value=pad_value
        )
    pil_image = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image
    if mode in [ResizeMode.NAIVE, ResizeMode.SMART]:
        pil_image = pil_image.resize((target_width, target_height), resample=resample)
    else:
        raise NotImplementedError(f'Mode {mode} not supported')
    # Convert back to the container type the resize pipeline was working with.
    image = np.asarray(pil_image) if isinstance(image, np.ndarray) else pil_image
    if is_bin_mask:
        # Threshold back to a boolean mask.
        image = image.astype(np.uint8) > 127
    return image
def is_global_norm(norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list]) -> bool:
    """Return true if norm is NONE or global for all datasets"""
    if norm == Normalization.NONE:
        return True
    # A mapping of stats (rather than a Normalization enum) is a single global norm.
    return isinstance(norm, collections.abc.Mapping)
def is_mean_norm(norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list]) -> bool:
    """Return true if norm is based on mean and std"""
    if norm == Normalization.MEAN:
        return True
    # A stats mapping counts as mean-norm exactly when it holds the {'mean', 'std'} keys.
    return isinstance(norm, collections.abc.Mapping) and set(norm.keys()) == {'mean', 'std'}
def _broadcast_shapes(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Broadcast shapes for normalization:
    Args:
        value: torch.Tensor of shape [..., num_components]. The entire shape might be:
            - [num_components]: `value` has no batch dimension
            - [num_datasets, num_components]: `value` contains entries *aligned* with the dataset bounds
              contained in `low` and `high`
            - [num_datasets, ..., num_components]: `value` contains entries *aligned* with the dataset bounds
              contained in `low` and `high`
            - [..., num_components]: `value` contains multiple dimensions. In this case, `low` and `high`
              must be for a single dataset, i.e. `num_datasets = 1`
        low: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `low`
            contains normalization bounds for a single dataset
        high: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `high`
            contains normalization bounds for a single dataset
    Returns:
        Tuple of torch.Tensors (low, high), where `low` and `high` have the same number of dimensions as `value`
    """
    assert low.ndim == high.ndim == 2, f'{low.shape} != {high.shape} or ndim != 2'
    assert value.shape[-1] == low.shape[-1] == high.shape[-1], f'{value.shape} != {low.shape} / {high.shape}'
    if value.ndim == low.ndim == high.ndim:
        # Shapes already line up; nothing to broadcast.
        return low, high
    if value.ndim < low.ndim:
        # Unbatched `value`: bounds must describe a single dataset; drop their batch dim.
        assert low.ndim == high.ndim == 2, f'{low.shape}, {high.shape}'
        assert low.shape[0] == high.shape[0] == 1, f'{low.shape}, {high.shape}'
        (low, high) = (low.view(-1), high.view(-1))
        return low, high
    if low.shape[0] == high.shape[0] == 1:
        # Single dataset: insert singleton dims so the bounds broadcast over all batch dims.
        # NOTE(review): `order` semantics come from the project helper `expand_dims`
        # (not visible here) — confirm against its definition.
        low = expand_dims(low.view(-1), ndim=value.ndim, order=[-1, 1])
        high = expand_dims(high.view(-1), ndim=value.ndim, order=[-1, 1])
    else:
        # Multiple datasets: the leading dim of `value` must align with the dataset dim.
        assert value.shape[0] == low.shape[0] == high.shape[0], f'{value.shape} != {low.shape} / {high.shape}'
        low = expand_dims(low, ndim=value.ndim, order=[1, -1, 1])
        high = expand_dims(high, ndim=value.ndim, order=[1, -1, 1])
    return low, high
def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Invert mean/std normalization: value * (std + eps) + mean."""
    (mean, std) = _broadcast_shapes(value, mean, std)
    mean = mean.to(device=value.device)
    std = std.to(device=value.device)
    # The epsilon mirrors the one added during normalization.
    return value * (std + 1e-08) + mean
def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Map a [-1, 1]-normalized value back onto the [low, high] range."""
    (low, high) = _broadcast_shapes(value, low, high)
    low = low.to(device=value.device)
    high = high.to(device=value.device)
    return 0.5 * (value + 1) * (high - low) + low
def normalize_gripper_by_bounds(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True
) -> torch.Tensor:
    """
    If binary, normalize to [0, 1], otherwise normalize to [-1, 1]
    """
    (low, high) = _broadcast_shapes(value, low, high)
    low = low.to(device=value.device)
    high = high.to(device=value.device)
    # Guard against zero-width bounds.
    span = torch.clamp(high - low, min=1e-08)
    if binary:
        return torch.clamp((value - low) / span, min=0.0, max=1.0)
    return torch.clamp(2 * (value - low) / span - 1, min=-1.0, max=1.0)
def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Standardize `value` with per-component mean/std (epsilon guards zero std)."""
    (mean, std) = _broadcast_shapes(value, mean, std)
    mean = mean.to(device=value.device)
    std = std.to(device=value.device)
    return (value - mean) / (std + 1e-08)
def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Map `value` from [low, high] to [-1, 1], clamping out-of-range inputs."""
    low, high = _broadcast_shapes(value, low, high)
    low = low.to(device=value.device)
    high = high.to(device=value.device)
    span = torch.clamp(high - low, min=1e-08)
    scaled = 2 * (value - low) / span - 1
    return torch.clamp(scaled, min=-1.0, max=1.0)
def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray:
    """Flip the open/close direction of a raw gripper signal bounded by [low, high].

    Symmetric ranges (low < 0) are mirrored around zero; non-negative ranges
    are reflected as `high - value`. The result always lies inside the bounds.
    """
    if low < 0.0:
        mirrored = -gripper
        return np.clip(mirrored, low, high)
    clipped = np.clip(gripper, low, high)
    return high - clipped
# Per-dataset (low, high) bounds of the raw gripper observation, consumed by
# preprocess_gripper_observation for normalization. The open/closed direction
# of each dataset's raw signal is documented in that function's docstring.
GRIPPER_BOUNDS: Dict[str, Tuple[float, float]] = {
    'austin_buds_dataset': (0.0, 0.08),
    'austin_sailor_dataset': (0.0, 0.08),
    'austin_sirius_dataset': (0.0, 0.08),
    'bc_z': (0.0, 1.0),
    'berkeley_autolab_ur5': (0.0, 1.0),
    'berkeley_cable_routing': (0.0, 1.0),
    'berkeley_fanuc_manipulation': (0.0, 1.0),
    'bridge': (0.0, 1.0),
    'bridge_orig': (0.0, 1.0),
    'bridge_action_lang': (0.0, 1.0),
    'cmu_stretch': (-3.0, 3.0),
    'dlr_edan_shared_control': (0.0, 1.0),
    'droid': (0.0, 1.0),
    'fmb': (0.0, 1.0),
    'fractal20220817_data': (0.0, 1.0),
    'furniture_bench_dataset': (0.0, 0.08),
    'iamlab_cmu_pickup_insert': (0.0, 1.0),
    'jaco_play': (0.0, 1.4),
    'kuka': (0.0, 1.0),
    'language_table': (0.0, 1.0),
    'nyu_franka_play_dataset': (0.0, 1.0),
    'roboset': (0.0, 1.0),
    'roboturk': (0.0, 1.0),
    'stanford_hydra_dataset': (0.0, 0.08),
    'taco_play': (0.0, 0.08),
    'toto': (0.0, 1.0),
    'ucsd_kitchen_dataset': (0.0, 1.0),
    'utaustin_mutex': (0.0, 0.08),
    'viola': (0.0, 0.08),
}
def preprocess_gripper_observation(
    gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True
) -> np.ndarray:
    """
    Preprocess gripper observation depending on dataset. Input is the raw gripper observation from the dataset
    or from the robot and output is normalized continuous value.
    - if `binary`, output is in [0, 1], with 0 = closed and 1 = open.
    - otherwise, output is in [-1, 1], with -1 = closed and 1 = open.
    Dataset-specific gripper observations:
        austin_buds_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        austin_sailor_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        austin_sirius_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        bc_z: continuous; [0=open; 1=closed]
        berkeley_autolab_ur5: binary; [0=open; 1=closed]
        berkeley_cable_routing: constant (closed)
        berkeley_fanuc_manipulation: binary; [0=open; 1=closed]
        bridge: continuous; ~[0=closed; 1=open]
        bridge_orig: continuous; ~[0=closed; 1=open]
        cmu_stretch: continuous; [-3=closed; 3=open]
        dlr_edan_shared_control: missing
        droid: continuous; [0=open, 1=closed]
        fmb: binary; [0=open; 1=closed]
        fractal20220817_data: continuous; [0=open; 1=closed]
        furniture_bench_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        iamlab_cmu_pickup_insert: binary; [0=closed; 1=open]
        jaco_play: continuous; [0=open; 1.4=closed]
        kuka: binary; [0=open; 1=closed]
        language_table: constant (no gripper)
        nyu_franka_play_dataset: missing
        roboset: continuous; [0=open, 1=closed]
        roboturk: continuous; [0=closed, 0.04=open]
        stanford_hydra_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        taco_play: continuous; ~[0=closed; 0.08=open] (franka gripper)
        toto: constant (closed)
        ucsd_kitchen_dataset: missing
        utaustin_mutex: continuous; ~[0=closed; 0.08=open] (franka gripper)
        viola: continuous; ~[0=closed; 0.08=open] (franka gripper)

    NOTE(review): the list above says roboturk opens at ~0.04, but
    GRIPPER_BOUNDS['roboturk'] is (0.0, 1.0) — verify which is correct.
    """
    # A batched array of dataset names must be homogeneous; collapse to one str.
    if isinstance(dataset_name, np.ndarray):
        assert np.unique(dataset_name).size == 1, dataset_name
        dataset_name = str(dataset_name[0])
    # Datasets with constant / missing gripper signals: normalize as-is.
    if dataset_name in [
        'berkeley_cable_routing',
        'dlr_edan_shared_control',
        'language_table',
        'nyu_franka_play_dataset',
        'toto',
        'ucsd_kitchen_dataset',
    ]:
        gripper = normalize_gripper_by_bounds(
            torch.from_numpy(gripper),
            low=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][0], dtype=torch.float32),
            high=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][1], dtype=torch.float32),
            binary=binary,
        ).numpy()
    # Datasets whose raw convention is inverted (low=open, high=closed):
    # flip with invert_gripper before normalizing.
    elif dataset_name in [
        'bc_z',
        'berkeley_autolab_ur5',
        'berkeley_fanuc_manipulation',
        'droid',
        'fmb',
        'fractal20220817_data',
        'jaco_play',
        'kuka',
        'roboset',
    ]:
        (low, high) = GRIPPER_BOUNDS[dataset_name]
        gripper = normalize_gripper_by_bounds(
            torch.from_numpy(invert_gripper(gripper, low=low, high=high)),
            low=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][0], dtype=torch.float32),
            high=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][1], dtype=torch.float32),
            binary=binary,
        ).numpy()
    # Datasets already following the target convention (low=closed, high=open).
    elif dataset_name in [
        'austin_buds_dataset',
        'austin_sailor_dataset',
        'austin_sirius_dataset',
        'bridge',
        'bridge_orig',
        'bridge_action_lang',
        'cmu_stretch',
        'furniture_bench_dataset',
        'iamlab_cmu_pickup_insert',
        'roboturk',
        'stanford_hydra_dataset',
        'taco_play',
        'utaustin_mutex',
        'viola',
    ]:
        (low, high) = GRIPPER_BOUNDS[dataset_name]
        gripper = normalize_gripper_by_bounds(
            torch.from_numpy(gripper),
            low=torch.full(gripper.shape, low, dtype=torch.float32),
            high=torch.full(gripper.shape, high, dtype=torch.float32),
            binary=binary,
        ).numpy()
    else:
        raise NotImplementedError(f'Unknown dataset: {dataset_name}')
    return gripper
def rotation_norm_bounds(
    rotation_norm: Normalization,
    rotation_format: RotationFormat,
    stats: Dict[str, Dict[str, Dict[str, List[float]]]],
    dataset_names: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Build per-dataset 'low'/'high' rotation-normalization tensors.

    Euler rotations may be normalized from dataset statistics (min/max or
    q01/q99). Every other configuration requires Normalization.NONE and uses
    constant [-1, 1] bounds sized to the rotation format.
    """
    use_dataset_stats = rotation_format == RotationFormat.EULER and rotation_norm != Normalization.NONE
    if use_dataset_stats:
        if rotation_norm == Normalization.BOUNDS:
            (lo_key, hi_key) = ('min', 'max')
        elif rotation_norm == Normalization.BOUNDS_Q99:
            (lo_key, hi_key) = ('q01', 'q99')
        else:
            raise NotImplementedError(f'Normalization type {rotation_norm} not yet implemented')
        return {
            name: {
                'low': torch.tensor(dataset_stats['euler'][lo_key]),
                'high': torch.tensor(dataset_stats['euler'][hi_key]),
            }
            for (name, dataset_stats) in stats.items()
        }
    assert rotation_norm == Normalization.NONE, rotation_norm
    if rotation_format == RotationFormat.EULER:
        rotation_size = 3
    elif rotation_format == RotationFormat.QUATERNION:
        rotation_size = 4
    else:
        rotation_size = 9
    return {
        name: {
            'low': torch.full((rotation_size,), -1.0, dtype=torch.float32),
            'high': torch.full((rotation_size,), 1.0, dtype=torch.float32),
        }
        for name in dataset_names
    }
def translation_norm_bounds(
    translation_norm: Normalization | tuple,
    stats: Dict[str, Dict[str, Dict[str, List[float]]]],
    dataset_names: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Build per-dataset translation-normalization tensors.

    - Normalization enum (not NONE): derive 'low'/'high' (or 'mean'/'std')
      from each dataset's 'translation' statistics.
    - Normalization.NONE: constant [-1, 1] bounds.
    - Mapping: explicit shared stats ({'low','high'} or {'mean','std'}, each
      of length 3) replicated for every dataset.
    """
    if isinstance(translation_norm, Normalization):
        if translation_norm == Normalization.NONE:
            return {
                name: {
                    'low': torch.full((3,), -1.0, dtype=torch.float32),
                    'high': torch.full((3,), 1.0, dtype=torch.float32),
                }
                for name in dataset_names
            }
        if translation_norm == Normalization.BOUNDS:
            key_pairs = (('low', 'min'), ('high', 'max'))
        elif translation_norm == Normalization.BOUNDS_Q99:
            key_pairs = (('low', 'q01'), ('high', 'q99'))
        elif translation_norm == Normalization.MEAN:
            key_pairs = (('mean', 'mean'), ('std', 'std'))
        else:
            raise NotImplementedError(f'Normalization type {translation_norm} not yet implemented')
        return {
            name: {
                out_key: torch.tensor(dataset_stats['translation'][stat_key])
                for (out_key, stat_key) in key_pairs
            }
            for (name, dataset_stats) in stats.items()
        }
    assert isinstance(translation_norm, collections.abc.Mapping), type(translation_norm)
    assert all((len(value) == 3 for value in translation_norm.values())), translation_norm
    assert set(translation_norm.keys()) in ({'low', 'high'}, {'mean', 'std'}), translation_norm
    return {
        name: {key: torch.tensor(value, dtype=torch.float32) for (key, value) in translation_norm.items()}
        for name in dataset_names
    }
| VLMProcessorConfigT = TypeVar('VLMProcessorConfigT', bound=VLMProcessorConfig) | |
class VLMProcessor(Configurable[VLMProcessorConfigT], Template[VLMProcessorConfigT]):
    """Interface of a vision-language-model processor wrapped by VLAMProcessor.

    Concrete implementations tokenize chats and preprocess images for a
    specific VLM backbone.
    """

    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """Tokenize `chat` and preprocess `images` into model-ready tensors."""
        ...

    @property
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Underlying text tokenizer.

        Declared as a property (not a plain method): VLAMProcessor reads it as
        an attribute (`self.vlm_processor.tokenizer`) and dereferences
        `.model_max_length` on the result, which would fail on a bound method.
        """
        pass

    @property
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        """Per-camera ImageSizeConfig mapping.

        Declared as a property: downstream code subscripts
        `self.image_sizes[camera_name]`.
        """
        pass
| VLAMProcessorConfigT = TypeVar('VLAMProcessorConfigT', bound=VLAMProcessorConfig) | |
| class VLAMProcessor(Configurable[VLAMProcessorConfigT], Template[VLAMProcessorConfigT]): | |
    def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor):
        """Wrap a VLM processor with robotics-specific pre/post-processing.

        Args:
            config: Processor configuration (normalization modes, rotation
                format, control tokenizer config, image resize mode, ...).
            vlm_processor: Underlying vision-language processor used for text
                tokenization and image preprocessing.
        """
        super().__init__(config)
        self.vlm_processor = vlm_processor
        # Tokenizer that renders ground-truth controls as text appended to the prompt.
        self.control_tokenizer = EmptyTokenizer(
            config=self.config.control_tokenizer_config, tokenizer=self.tokenizer
        )
        # Per-component normalization stats. The *_norm_bounds attributes are
        # read without a call here and their values are subscripted later (see
        # _norm_bounds_from_dataset_name), so each must evaluate to a dict —
        # NOTE(review): presumably properties on this class; confirm.
        self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {
            'obs_translation': self.obs_translation_norm_bounds,
            'obs_rotation': self.obs_rotation_norm_bounds,
            'translation': self.translation_norm_bounds,
            'rotation': self.rotation_norm_bounds,
            'joints': self.joints_norm_bounds,
        }
| def tokenizer(self) -> transformers.PreTrainedTokenizerBase: | |
| return self.vlm_processor.tokenizer | |
| def image_sizes(self) -> Dict[str, ImageSizeConfig]: | |
| return self.vlm_processor.image_sizes | |
    def camera_names(self) -> List[str]:
        # Names of the cameras the VLM processor expects images for.
        # NOTE(review): sibling accessors (tokenizer/image_sizes) are read as
        # attributes elsewhere in this class, suggesting stripped @property
        # decorators; no call site of camera_names is visible here, so it is
        # left as a plain method — confirm against the original source.
        return list(self.vlm_processor.image_sizes.keys())
    def control_io_config(self) -> ControlDataIOConfig:
        # Convenience accessor for the control-I/O section of the config.
        # NOTE(review): possibly a stripped @property (no call site visible
        # in this chunk) — confirm against the original source.
        return self.config.control_io_config
| def rotation_components(self) -> int: | |
| if self.config.rotation_format == RotationFormat.EULER: | |
| return 3 | |
| if self.config.rotation_format == RotationFormat.QUATERNION: | |
| return 4 | |
| if self.config.rotation_format == RotationFormat.ROTMAT: | |
| return 9 | |
| raise NotImplementedError(self.config.rotation_format) | |
    def policy_control_plan_from_model_target(
        self, target: RoboticsTarget, dataset_name: np.ndarray
    ) -> RoboticsControlPlan:
        # Convert a ground-truth target into an executable control plan.
        # NOTE(review): body is empty (returns None); presumably an abstract
        # hook overridden by concrete processors — confirm.
        pass
    def policy_control_plan_from_model_output(
        self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor
    ) -> RoboticsControlPlan:
        # Convert raw model output into an executable control plan.
        # NOTE(review): body is empty (returns None); presumably an abstract
        # hook overridden by concrete processors — confirm.
        pass
| def resize_image( | |
| self, camera_name: str, image: PIL.Image.Image | np.ndarray | |
| ) -> PIL.Image.Image | np.ndarray: | |
| return resize_image( | |
| image, | |
| target_size={ | |
| 'width': self.image_sizes[camera_name].width, | |
| 'height': self.image_sizes[camera_name].height, | |
| }, | |
| mode=self.config.image_resize, | |
| resample=PIL.Image.Resampling.LANCZOS, | |
| ) | |
    def preprocess_inputs(
        self,
        chat: List[str],
        images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool,
        control_target: Optional[RoboticsTarget] = None,
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Preprocess the inputs for a single example
        Args:
            chat: Chat turns (text) forwarded to the wrapped VLM processor
            images: History of input images with increasing timestamps, per camera
            ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3]
            ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9]
            gripper: Raw gripper observation; normalized per-dataset below
            joints: np.ndarray, shape [..., num_past_scalars, <= 7]
            dataset_name: 1D np.ndarray
            inference_mode: If True, prepare the input for inference (e.g. don't include target
                any tokens in the input if relevant). If control_target is available, it should
                still be preprocessed for test dataset comparison
            control_target: RoboticsTarget, each component of shape
                [..., num_control_steps, num_control_components]. Provided only when available, usually
                during training and dataset test
        Returns:
            Dict containing torch.Tensor with inputs
        """
        # Unused in this implementation (no control tokens are produced);
        # deleted explicitly to document that.
        del control_target
        del inference_mode
        inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
        # NOTE: `images` is rebound from PIL inputs to preprocessed tensors here.
        images: Dict[str, torch.Tensor] = inputs['images']
        # Truncate token sequences to the tokenizer's maximum context length.
        input_ids: torch.Tensor = inputs['input_ids'][..., : self.tokenizer.model_max_length]
        target_text_tokens_ids: torch.Tensor = inputs['target_ids'][..., : self.tokenizer.model_max_length]
        # No padding at this stage: attend to every token.
        attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)
        ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32)
        ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32)
        # Convert the observed rotation to the configured representation.
        ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True)
        # Dataset-specific raw gripper convention -> normalized open/closed value.
        gripper = preprocess_gripper_observation(gripper, dataset_name)
        gripper = torch.tensor(gripper, dtype=torch.float32)
        ee_pose_translation = self.normalize(
            ee_pose_translation, dataset_name=dataset_name, key='obs_translation'
        )
        ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key='obs_rotation')
        joints = torch.tensor(joints, dtype=torch.float32)
        # Zero-pad the joint vector to a fixed width of 7.
        if joints.shape[-1] < 7:
            missing_size = 7 - joints.shape[-1]
            joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1)
        joints = self.normalize(joints, dataset_name=dataset_name, key='joints')
        outputs = {
            'images': images,
            'input_ids': input_ids,
            'target_text_tokens_ids': target_text_tokens_ids,
            'attn_mask': attn_mask,
            'ee_pose_translation': ee_pose_translation,
            'ee_pose_rotation': ee_pose_rotation,
            'gripper': gripper,
            'joints': joints,
            # This processor variant emits no control tokens.
            'control_tokens_ids': None,
            'target_control_tokens_ids': None,
        }
        return outputs
| def create_input( | |
| self, | |
| chat: List[str], | |
| images: Dict[str, List[PIL.Image.Image]], | |
| ee_pose_translation: np.ndarray, | |
| ee_pose_rotation: np.ndarray, | |
| gripper: np.ndarray, | |
| joints: np.ndarray, | |
| dataset_name: np.ndarray, | |
| inference_mode: bool, | |
| control_target: Optional[RoboticsTarget] = None, | |
| ) -> RoboticsInput: | |
| inputs = self.preprocess_inputs( | |
| chat=chat, | |
| images=images, | |
| ee_pose_translation=ee_pose_translation, | |
| ee_pose_rotation=ee_pose_rotation, | |
| gripper=gripper, | |
| joints=joints, | |
| dataset_name=dataset_name, | |
| inference_mode=inference_mode, | |
| control_target=control_target, | |
| ) | |
| inputs.pop('target_text_tokens_ids') | |
| inputs.pop('target_control_tokens_ids') | |
| return RoboticsInput(**inputs) | |
| def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor: | |
| if is_mean_norm(getattr(self.config, f'{key}_norm')): | |
| (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) | |
| output = normalize_by_moments(value, mean=mean, std=std) | |
| else: | |
| (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) | |
| output = normalize_by_bounds(value, low=low, high=high) | |
| return output | |
| def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor: | |
| if is_mean_norm(getattr(self.config, f'{key}_norm')): | |
| (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) | |
| output = unnormalize_by_moments(value, mean=mean, std=std) | |
| else: | |
| (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) | |
| output = unnormalize_by_bounds(value, low=low, high=high) | |
| return output | |
    def _norm_bounds_from_dataset_name(
        self, dataset_name: np.ndarray, component_key: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create an array of normalization bounds corresponding to dataset names
        Args:
            dataset_name: Array of shape [B] of dataset names for which to fetch the low and high
                normalization bounds. Note the values can be repeating
            component_key: str. One of the keys of `self.norm_bounds` (e.g. 'obs_translation',
                'obs_rotation', 'translation', 'rotation', 'joints'). Indicates for which control to
                compute the normalization bounds
        Returns:
            Tuple of low and high bounds or mean and std, each of shape [B, -1]
        """
        norm = getattr(self.config, f'{component_key}_norm')
        # Mean/std stats vs low/high bounds, depending on the normalization mode.
        if is_mean_norm(norm):
            (stats_key_1, stats_key_2) = ('mean', 'std')
        else:
            (stats_key_1, stats_key_2) = ('low', 'high')
        if component_key == 'joints':
            # Joints only support explicitly-configured (Mapping) bounds that
            # are shared across datasets under the 'ANY' sentinel.
            if not isinstance(norm, collections.abc.Mapping):
                raise NotImplementedError()
            # Tile the shared per-joint stats to one row per batch element.
            # NOTE(review): the values are torch tensors (see joints_norm_bounds);
            # np.reshape/np.tile accept CPU tensors — verify for GPU tensors.
            stats = {
                key: torch.from_numpy(np.tile(np.reshape(value, [1, -1]), [len(dataset_name), 1]))
                for (key, value) in self.joints_norm_bounds['ANY'].items()
            }
            return tuple(stats.values())
        # Number of components, inferred from any stored stats entry.
        component_size = list(list(self.norm_bounds[component_key].values())[0].values())[0].shape[-1]
        if self.dataset_names == ['ANY']:
            # Single shared stats set: broadcast it across the batch.
            stats_1 = self.norm_bounds[component_key]['ANY'][stats_key_1]
            stats_2 = self.norm_bounds[component_key]['ANY'][stats_key_2]
            stats_1 = np.repeat(np.expand_dims(stats_1, axis=0), len(dataset_name), axis=0)
            stats_2 = np.repeat(np.expand_dims(stats_2, axis=0), len(dataset_name), axis=0)
        else:
            # Look up stats once per unique dataset, then scatter them back to
            # the original batch order via the inverse indices.
            (unique_names, _, inverse_indices, _) = np_unique(dataset_name)
            stats_1 = np.zeros([len(unique_names), component_size], dtype=np.float32)
            stats_2 = np.zeros([len(unique_names), component_size], dtype=np.float32)
            for i, ds_name in enumerate(unique_names):
                stats_1[i] = self.norm_bounds[component_key][ds_name][stats_key_1]
                stats_2[i] = self.norm_bounds[component_key][ds_name][stats_key_2]
            stats_1 = stats_1[inverse_indices]
            stats_2 = stats_2[inverse_indices]
        return torch.from_numpy(stats_1), torch.from_numpy(stats_2)
| def obs_rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: | |
| return rotation_norm_bounds( | |
| rotation_norm=self.config.obs_rotation_norm, | |
| rotation_format=self.config.rotation_format, | |
| stats=self._observation_stats, | |
| dataset_names=self.dataset_names, | |
| ) | |
| def obs_translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: | |
| return translation_norm_bounds( | |
| translation_norm=self.config.obs_translation_norm, | |
| stats=self._observation_stats, | |
| dataset_names=self.dataset_names, | |
| ) | |
| def rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: | |
| return rotation_norm_bounds( | |
| rotation_norm=self.config.rotation_norm, | |
| rotation_format=self.config.rotation_format, | |
| stats=self._control_stats, | |
| dataset_names=self.dataset_names, | |
| ) | |
| def translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: | |
| return translation_norm_bounds( | |
| translation_norm=self.config.translation_norm, | |
| stats=self._control_stats, | |
| dataset_names=self.dataset_names, | |
| ) | |
| def joints_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: | |
| """ | |
| NOTE: | |
| - Joint values across all joints and all datasets vary in the range [-2pi; 2pi] | |
| - The effective range of a single joint is in practice one of [-2pi; 0], [-pi; pi], [0; 2pi] | |
| - It's possible to shift all ranges to [-pi; pi], but it requires careful handling for each joint | |
| """ | |
| low = torch.tensor(self.config.joints_norm['low'], dtype=torch.float32) | |
| high = torch.tensor(self.config.joints_norm['high'], dtype=torch.float32) | |
| results = {'ANY': {'low': low, 'high': high}} | |
| return results | |
| def _observation_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]: | |
| return { | |
| 'austin_buds_dataset': { | |
| 'euler': { | |
| 'max': [3.141589641571045, 0.2279592752456665, 0.10982763767242432], | |
| 'mean': [-0.433798223733902, 0.02660757675766945, 0.01057341042906046], | |
| 'min': [-3.1415905952453613, -0.17229366302490234, -0.08408939838409424], | |
| 'q01': [-3.1401021480560303, -0.15158048272132874, -0.07398375868797302], | |
| 'q99': [3.1401915550231934, 0.1699378788471222, 0.08044551312923431], | |
| 'std': [3.071873664855957, 0.08403228223323822, 0.04623554274439812], | |
| }, | |
| 'gripper': { | |
| 'max': [0.07999841868877411], | |
| 'min': [0.00019240332767367363], | |
| 'q01': [0.00760263716802001], | |
| 'q99': [0.07997412234544754], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| 0.25105541944503784, | |
| 1.0239691734313965, | |
| 0.25514841079711914, | |
| 0.0, | |
| 0.05838121101260185, | |
| 3.0727620124816895, | |
| 1.1911247968673706, | |
| ], | |
| 'mean': [ | |
| -0.02248559147119522, | |
| 0.4224241375923157, | |
| 0.011533008888363838, | |
| -2.040178060531616, | |
| -0.004422259051352739, | |
| 2.440391778945923, | |
| 0.7626844644546509, | |
| ], | |
| 'min': [ | |
| -0.46276336908340454, | |
| -0.2620261609554291, | |
| -0.37377235293388367, | |
| -3.010026693344116, | |
| -0.015008972026407719, | |
| 0.0, | |
| 0.0, | |
| ], | |
| 'q01': [ | |
| -0.3563080132007599, | |
| -0.19165010750293732, | |
| -0.29941368103027344, | |
| -2.766944408416748, | |
| -0.012211786583065987, | |
| 1.9946393966674805, | |
| 0.21924567222595215, | |
| ], | |
| 'q99': [ | |
| 0.22295619547367096, | |
| 0.8841447830200195, | |
| 0.20571835339069366, | |
| -1.339158296585083, | |
| 0.05605502426624298, | |
| 2.9369189739227295, | |
| 1.1475461721420288, | |
| ], | |
| 'std': [ | |
| 0.1684190183877945, | |
| 0.2238602489233017, | |
| 0.1185624971985817, | |
| 0.3182348608970642, | |
| 0.012190691195428371, | |
| 0.2673698663711548, | |
| 0.2611873149871826, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.7684353590011597, 0.22995209693908691, 0.3893454968929291], | |
| 'mean': [0.5589713454246521, -0.0022956086322665215, 0.12593945860862732], | |
| 'min': [0.0, -0.31078848242759705, 0.0], | |
| 'q01': [0.3499317765235901, -0.2854413390159607, 0.010516085661947727], | |
| 'q99': [0.7243335843086243, 0.20652863383293152, 0.3218296766281128], | |
| 'std': [0.08953548222780228, 0.1462203562259674, 0.0769098550081253], | |
| }, | |
| }, | |
| 'austin_sailor_dataset': { | |
| 'euler': { | |
| 'max': [3.1415910720825195, 0.26857471466064453, 2.4386940002441406], | |
| 'mean': [0.040416110306978226, 0.013173789717257023, -0.044821564108133316], | |
| 'min': [-3.141592264175415, -0.21639788150787354, -1.995558738708496], | |
| 'q01': [-3.1401662826538086, -0.16171419620513916, -1.5918874740600586], | |
| 'q99': [3.1401755809783936, 0.17527776956558228, 1.3560799360275269], | |
| 'std': [3.0995969772338867, 0.08136381208896637, 0.6311237812042236], | |
| }, | |
| 'gripper': { | |
| 'max': [0.07783995568752289], | |
| 'min': [-8.864999836077914e-05], | |
| 'q01': [0.0005174533580429852], | |
| 'q99': [0.07778151333332062], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| 0.18366889655590057, | |
| 1.188208818435669, | |
| 1.1904715299606323, | |
| -1.0253318548202515, | |
| 0.06765712797641754, | |
| 2.98008131980896, | |
| 2.7747786045074463, | |
| ], | |
| 'mean': [ | |
| -0.4481923580169678, | |
| 0.428782194852829, | |
| 0.29826778173446655, | |
| -2.1356770992279053, | |
| -0.2064618021249771, | |
| 2.5082974433898926, | |
| 0.8104547262191772, | |
| ], | |
| 'min': [ | |
| -1.4425580501556396, | |
| -0.34988880157470703, | |
| -0.4436887800693512, | |
| -2.9587178230285645, | |
| -0.9116092324256897, | |
| 1.7689018249511719, | |
| -1.4776484966278076, | |
| ], | |
| 'q01': [ | |
| -1.0579811334609985, | |
| -0.13357718288898468, | |
| -0.2518821358680725, | |
| -2.6593017578125, | |
| -0.6571194529533386, | |
| 2.072176456451416, | |
| -0.5861577391624451, | |
| ], | |
| 'q99': [ | |
| 0.08897759765386581, | |
| 0.8743473887443542, | |
| 0.8909343481063843, | |
| -1.4748132228851318, | |
| -0.010495365597307682, | |
| 2.859259605407715, | |
| 2.319192409515381, | |
| ], | |
| 'std': [ | |
| 0.26852160692214966, | |
| 0.22502659261226654, | |
| 0.2577133774757385, | |
| 0.2963367700576782, | |
| 0.1629231870174408, | |
| 0.18130357563495636, | |
| 0.6348837614059448, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.7488242387771606, 0.2829112708568573, 0.3541720509529114], | |
| 'mean': [0.5367430448532104, -0.08604215085506439, 0.11135970056056976], | |
| 'min': [0.3208981454372406, -0.3730051815509796, 0.020952222868800163], | |
| 'q01': [0.387094110250473, -0.3164229393005371, 0.024492919445037842], | |
| 'q99': [0.6869593262672424, 0.2086469978094101, 0.2551962733268738], | |
| 'std': [0.08000300824642181, 0.13984571397304535, 0.056377921253442764], | |
| }, | |
| }, | |
| 'austin_sirius_dataset': { | |
| 'euler': { | |
| 'max': [3.141592502593994, 0.09964871406555176, 0.3406205177307129], | |
| 'mean': [1.1875473260879517, -0.018573710694909096, -0.4977009892463684], | |
| 'min': [-3.141592502593994, -0.14349305629730225, -1.892538070678711], | |
| 'q01': [-3.140658378601074, -0.12321051210165024, -1.743293285369873], | |
| 'q99': [3.1407761573791504, 0.07348346710205078, 0.062370821833610535], | |
| 'std': [2.8533384799957275, 0.03962606564164162, 0.5864803791046143], | |
| }, | |
| 'gripper': { | |
| 'max': [0.07965432107448578], | |
| 'min': [0.00016088332631625235], | |
| 'q01': [0.0334978811442852], | |
| 'q99': [0.07948359102010727], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| 0.5567107200622559, | |
| 0.3814372420310974, | |
| 0.36868324875831604, | |
| 0.0, | |
| 0.021617505699396133, | |
| 2.825117588043213, | |
| 2.8925509452819824, | |
| ], | |
| 'mean': [ | |
| 0.22241303324699402, | |
| 0.03750193864107132, | |
| -0.013105375692248344, | |
| -2.428316593170166, | |
| -0.017613587900996208, | |
| 2.4755642414093018, | |
| 1.4929485321044922, | |
| ], | |
| 'min': [ | |
| -0.23144784569740295, | |
| -0.41377919912338257, | |
| -0.3536752760410309, | |
| -2.9289300441741943, | |
| -0.06330326944589615, | |
| 0.0, | |
| 0.0, | |
| ], | |
| 'q01': [ | |
| -0.10539760440587997, | |
| -0.2651134133338928, | |
| -0.21157152950763702, | |
| -2.831827402114868, | |
| -0.04555036500096321, | |
| 0.0, | |
| 0.0, | |
| ], | |
| 'q99': [ | |
| 0.48007193207740784, | |
| 0.3064860999584198, | |
| 0.2438022643327713, | |
| 0.0, | |
| 0.012844694778323174, | |
| 2.71537446975708, | |
| 2.51297926902771, | |
| ], | |
| 'std': [ | |
| 0.1297803372144699, | |
| 0.132253035902977, | |
| 0.08940940350294113, | |
| 0.345712274312973, | |
| 0.015448619611561298, | |
| 0.3493262231349945, | |
| 0.632145345211029, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.5655431151390076, 0.3374997079372406, 0.3071175515651703], | |
| 'mean': [0.4490734338760376, 0.09406533092260361, 0.17131780087947845], | |
| 'min': [0.0, -0.16291481256484985, 0.0], | |
| 'q01': [0.0, -0.11814527958631516, 0.0], | |
| 'q99': [0.532875120639801, 0.26084619760513306, 0.27225059270858765], | |
| 'std': [0.07605387270450592, 0.08447220921516418, 0.04798709973692894], | |
| }, | |
| }, | |
| 'bc_z': { | |
| 'euler': { | |
| 'max': [3.141583088638963, 1.5360414168274534, 3.141576023430307], | |
| 'mean': [-0.32332700342924814, -0.1255861641435509, -0.07202428898041417], | |
| 'min': [-3.13973587780783, -1.5322501214383415, -3.1415575429114675], | |
| 'q01': [-1.0433973645798003, -1.007553367940694, -2.6477418236555788], | |
| 'q99': [0.8467956620806508, 0.9111507554055364, 2.226032625129908], | |
| 'std': [0.40307241985577485, 0.4832146677789649, 1.0666829542185445], | |
| }, | |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.20000000298023224], 'q99': [1.0]}, | |
| 'joints': { | |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| }, | |
| 'translation': { | |
| 'max': [0.6597589254379272, 0.7259413599967957, 1.1217665672302246], | |
| 'mean': [0.012632723897695541, 0.10061875730752945, 0.7831991910934448], | |
| 'min': [-0.7181899547576904, -0.3756217360496521, -0.281008243560791], | |
| 'q01': [-0.3956047296524048, -0.11924505233764648, 0.601338267326355], | |
| 'q99': [0.332028865814209, 0.3088575601577759, 0.98329097032547], | |
| 'std': [0.1808762550354004, 0.0960959941148758, 0.08665762096643448], | |
| }, | |
| }, | |
| 'berkeley_autolab_ur5': { | |
| 'euler': { | |
| 'max': [3.141586857385381, 0.9715026080766722, 1.0041252839605828], | |
| 'mean': [0.10885329009674993, -0.005823583245937125, 0.01756067034138813], | |
| 'min': [-3.14158698706268, -0.9912158529884818, -0.8782840134752882], | |
| 'q01': [-3.1391613317061315, -0.6410961610461836, -0.4513447799408977], | |
| 'q99': [3.13947692478744, 0.6738775260636495, 0.535202669480506], | |
| 'std': [3.069331637100682, 0.18171227413203145, 0.17512358924250696], | |
| }, | |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, | |
| 'joints': { | |
| 'max': [ | |
| -2.5115966796875, | |
| -0.682390034198761, | |
| 2.6830153465270996, | |
| -1.1711143255233765, | |
| -0.6423047184944153, | |
| 4.032349109649658, | |
| ], | |
| 'mean': [ | |
| -3.2676453590393066, | |
| -1.3197712898254395, | |
| 2.122663736343384, | |
| -2.3710596561431885, | |
| -1.567111611366272, | |
| 3.0015013217926025, | |
| ], | |
| 'min': [ | |
| -4.0775299072265625, | |
| -2.2796404361724854, | |
| 1.1905627250671387, | |
| -3.814586877822876, | |
| -2.1223204135894775, | |
| 1.5279699563980103, | |
| ], | |
| 'q01': [ | |
| -3.839585781097412, | |
| -1.8913453817367554, | |
| 1.5813319683074951, | |
| -3.3646113872528076, | |
| -1.805822730064392, | |
| 2.22794246673584, | |
| ], | |
| 'q99': [ | |
| -2.7322237491607666, | |
| -0.8280109763145447, | |
| 2.5063326358795166, | |
| -1.5614854097366333, | |
| -1.2497813701629639, | |
| 3.7013423442840576, | |
| ], | |
| 'std': [ | |
| 0.2746899724006653, | |
| 0.27345725893974304, | |
| 0.22793518006801605, | |
| 0.34690913558006287, | |
| 0.10438559204339981, | |
| 0.3388271629810333, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.6610942482948303, 0.38513994216918945, 0.2049914449453354], | |
| 'mean': [0.45345133543014526, 0.05642913281917572, -0.0356401689350605], | |
| 'min': [0.1970997005701065, -0.27643972635269165, -0.20529526472091675], | |
| 'q01': [0.3020566999912262, -0.21297279000282288, -0.18836002051830292], | |
| 'q99': [0.6132073998451233, 0.30656182765960693, 0.12212439626455307], | |
| 'std': [0.07906801998615265, 0.12497302889823914, 0.08053416013717651], | |
| }, | |
| }, | |
| 'berkeley_cable_routing': { | |
| 'euler': { | |
| 'max': [3.141592264175415, 0.05923163890838623, 3.1243016719818115], | |
| 'mean': [-0.6730493903160095, 0.010155542753636837, 0.9881726503372192], | |
| 'min': [-3.141592264175415, -0.10035276412963867, -1.6915032863616943], | |
| 'q01': [-3.1412672996520996, -0.02922845631837845, -0.6449583172798157], | |
| 'q99': [3.1412830352783203, 0.039870113134384155, 2.4969191551208496], | |
| 'std': [3.0588438510894775, 0.01435503363609314, 0.6301024556159973], | |
| }, | |
| 'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]}, | |
| 'joints': { | |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| }, | |
| 'translation': { | |
| 'max': [0.7113381028175354, 0.35645395517349243, 0.23169316351413727], | |
| 'mean': [0.5643643736839294, 0.00464651919901371, 0.07635799050331116], | |
| 'min': [0.378939151763916, -0.3244381248950958, 0.02837979607284069], | |
| 'q01': [0.4641263782978058, -0.2806571424007416, 0.030183622613549232], | |
| 'q99': [0.6452807784080505, 0.28204888105392456, 0.1557157188653946], | |
| 'std': [0.0380910187959671, 0.16768066585063934, 0.032017387449741364], | |
| }, | |
| }, | |
| 'berkeley_fanuc_manipulation': { | |
| 'euler': { | |
| 'max': [3.1415822505950928, 1.020048975944519, 1.9211788177490234], | |
| 'mean': [0.1944933384656906, -0.03283948823809624, 0.017717985436320305], | |
| 'min': [-3.141590118408203, -1.5707060098648071, -3.089141368865967], | |
| 'q01': [-3.139866352081299, -1.0165742635726929, -1.6987266540527344], | |
| 'q99': [3.140317440032959, 0.4332200288772583, 1.5085773468017578], | |
| 'std': [2.9820094108581543, 0.21335174143314362, 0.4836970269680023], | |
| }, | |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, | |
| 'joints': { | |
| 'max': [ | |
| 0.9483258724212646, | |
| 1.4428143501281738, | |
| 1.2322325706481934, | |
| 2.3726272583007812, | |
| 0.0008019324741326272, | |
| 3.017007827758789, | |
| ], | |
| 'mean': [ | |
| 0.04932224750518799, | |
| 0.36293527483940125, | |
| -0.18438231945037842, | |
| 0.09242437779903412, | |
| -1.0793049335479736, | |
| -0.06553139537572861, | |
| ], | |
| 'min': [ | |
| -0.7464086413383484, | |
| -0.5774519443511963, | |
| -1.1099220514297485, | |
| -1.8152672052383423, | |
| -2.1454310417175293, | |
| -3.168759346008301, | |
| ], | |
| 'q01': [ | |
| -0.5061959028244019, | |
| -0.09613651037216187, | |
| -0.8047154545783997, | |
| -0.9908221364021301, | |
| -1.8158636093139648, | |
| -2.420135974884033, | |
| ], | |
| 'q99': [ | |
| 0.6388322710990906, | |
| 1.0234705209732056, | |
| 0.5990921258926392, | |
| 1.282780647277832, | |
| -0.39091193675994873, | |
| 2.340099811553955, | |
| ], | |
| 'std': [ | |
| 0.2452843338251114, | |
| 0.2555835247039795, | |
| 0.2740932106971741, | |
| 0.3769519329071045, | |
| 0.3472467362880707, | |
| 0.6859157681465149, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.8101964592933655, 0.5117550492286682, 0.8054816722869873], | |
| 'mean': [0.5446329116821289, 0.003737417282536626, 0.2439887523651123], | |
| 'min': [0.263683944940567, -0.5785022377967834, -0.0025757886469364166], | |
| 'q01': [0.3718133866786957, -0.4071895182132721, 0.01847645826637745], | |
| 'q99': [0.7200658321380615, 0.3128541111946106, 0.5413243770599365], | |
| 'std': [0.07649082690477371, 0.14486227929592133, 0.13899113237857819], | |
| }, | |
| }, | |
| 'bridge': { | |
| 'euler': { | |
| 'max': [3.141592653589793, 1.570796251296997, 3.141204357147217], | |
| 'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691], | |
| 'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938], | |
| 'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896], | |
| 'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236], | |
| 'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], | |
| }, | |
| 'gripper': { | |
| 'max': [1.0370277166366577], | |
| 'min': [0.04637829214334488], | |
| 'q01': [0.05192930996417999], | |
| 'q99': [1.0118417739868164], | |
| }, | |
| 'joints': { | |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| }, | |
| 'translation': { | |
| 'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], | |
| 'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424], | |
| 'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723], | |
| 'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884], | |
| 'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424], | |
| 'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685], | |
| }, | |
| }, | |
| 'bridge_orig': { | |
| 'euler': { | |
| 'max': [3.141592653589793, 1.570796251296997, 3.141204357147217], | |
| 'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691], | |
| 'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938], | |
| 'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896], | |
| 'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236], | |
| 'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], | |
| }, | |
| 'gripper': { | |
| 'max': [1.0370277166366577], | |
| 'min': [0.04637829214334488], | |
| 'q01': [0.05192930996417999], | |
| 'q99': [1.0118417739868164], | |
| }, | |
| 'joints': { | |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| }, | |
| 'translation': { | |
| 'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], | |
| 'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424], | |
| 'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723], | |
| 'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884], | |
| 'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424], | |
| 'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685], | |
| }, | |
| }, | |
| 'bridge_action_lang': { | |
| 'euler': { | |
| 'max': [3.141592653589793, 1.570796251296997, 3.141204357147217], | |
| 'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691], | |
| 'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938], | |
| 'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896], | |
| 'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236], | |
| 'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], | |
| }, | |
| 'gripper': { | |
| 'max': [1.0370277166366577], | |
| 'min': [0.04637829214334488], | |
| 'q01': [0.05192930996417999], | |
| 'q99': [1.0118417739868164], | |
| }, | |
| 'joints': { | |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| }, | |
| 'translation': { | |
| 'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], | |
| 'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424], | |
| 'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723], | |
| 'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884], | |
| 'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424], | |
| 'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685], | |
| }, | |
| }, | |
| 'cmu_stretch': { | |
| 'euler': { | |
| 'max': [3.141592653589793, 0.0, 0.0], | |
| 'mean': [3.1415926535896035, 0.0, 0.0], | |
| 'min': [3.141592653589793, 0.0, 0.0], | |
| 'q01': [3.141592653589793, 0.0, 0.0], | |
| 'q99': [3.141592653589793, 0.0, 0.0], | |
| 'std': [1.8962609260597674e-13, 0.0, 0.0], | |
| }, | |
| 'gripper': { | |
| 'max': [3.110913038253784], | |
| 'min': [-3.1155149936676025], | |
| 'q01': [-3.055689811706543], | |
| 'q99': [3.1017091274261475], | |
| }, | |
| 'joints': { | |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| }, | |
| 'translation': { | |
| 'max': [0.3401322066783905, 0.0, 1.1013386249542236], | |
| 'mean': [0.1461893916130066, 0.0, 0.7568270564079285], | |
| 'min': [0.005320895928889513, 0.0, 0.45218077301979065], | |
| 'q01': [0.017430847510695457, 0.0, 0.46050605177879333], | |
| 'q99': [0.33094948530197144, 0.0, 1.0952961444854736], | |
| 'std': [0.09959954023361206, 0.0, 0.15819725394248962], | |
| }, | |
| }, | |
| 'dlr_edan_shared_control': { | |
| 'euler': { | |
| 'max': [2.2228052616119385, 0.13346634805202484, 3.1415085792541504], | |
| 'mean': [1.6072595119476318, -0.4295055568218231, 1.1549623012542725], | |
| 'min': [0.4599944055080414, -1.4871342182159424, -3.1406784057617188], | |
| 'q01': [0.8485284447669983, -1.4454149007797241, -3.1238820552825928], | |
| 'q99': [1.8301045894622803, 0.07032366842031479, 3.131430149078369], | |
| 'std': [0.2785564363002777, 0.540916919708252, 2.4260833263397217], | |
| }, | |
| 'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]}, | |
| 'joints': { | |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| }, | |
| 'translation': { | |
| 'max': [0.044179707765579224, 0.6297282576560974, 0.9278888702392578], | |
| 'mean': [-0.512688934803009, 0.3690241575241089, 0.441307932138443], | |
| 'min': [-0.7695853114128113, -0.040099501609802246, 0.2583216428756714], | |
| 'q01': [-0.729511022567749, 0.077408567070961, 0.2658006250858307], | |
| 'q99': [-0.13719859719276428, 0.5719971060752869, 0.7898909449577332], | |
| 'std': [0.15213875472545624, 0.08891375362873077, 0.13743770122528076], | |
| }, | |
| }, | |
| 'droid': { | |
| 'euler': { | |
| 'max': [3.141592502593994, 1.5705928802490234, 3.1415867805480957], | |
| 'mean': [0.3140628098409554, -0.09296274023036387, -0.07227215454779846], | |
| 'min': [-3.141592502593994, -1.5691150426864624, -3.1415374279022217], | |
| 'q01': [-3.1378602981567383, -1.2125312042236327, -2.1614069032669065], | |
| 'q99': [3.137854380607605, 0.9200375998020163, 1.9367506909370364], | |
| 'std': [2.926265757944871, 0.363273475703332, 0.7576065217938824], | |
| }, | |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.9911894202232361]}, | |
| 'joints': { | |
| 'max': [ | |
| 2.668445110321045, | |
| 1.5691218376159668, | |
| 2.666306734085083, | |
| -0.3114914000034332, | |
| 2.6624162197113037, | |
| 4.28157901763916, | |
| 2.752457857131958, | |
| ], | |
| 'mean': [ | |
| 0.023137084334640106, | |
| 0.2704989977282293, | |
| -0.01451389357228282, | |
| -2.018709403792315, | |
| -0.042720520800030394, | |
| 2.350281188152209, | |
| 0.12424663946659845, | |
| ], | |
| 'min': [ | |
| -2.6536705493927, | |
| -1.547789216041565, | |
| -2.6781487464904785, | |
| -2.9409868717193604, | |
| -2.6705946922302246, | |
| 0.24893812835216522, | |
| -2.7615714073181152, | |
| ], | |
| 'q01': [ | |
| -0.9026106441020965, | |
| -0.8547340619564057, | |
| -0.9028875434398651, | |
| -2.7698556280136106, | |
| -1.6851656341552732, | |
| 1.2335169839859008, | |
| -1.9587260699272155, | |
| ], | |
| 'q99': [ | |
| 0.9569852340221403, | |
| 1.4148830294609054, | |
| 0.7693877756595566, | |
| -0.4545914208889008, | |
| 1.5623322343826267, | |
| 3.475611729621887, | |
| 2.263479118347167, | |
| ], | |
| 'std': [ | |
| 0.31695080251469465, | |
| 0.49522214687158767, | |
| 0.27993538230553827, | |
| 0.478161574676113, | |
| 0.4969961591445458, | |
| 0.45101008525403846, | |
| 0.7287264344068457, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.8575563430786133, 0.799155592918396, 1.0043904781341553], | |
| 'mean': [0.5283099395864883, 0.005363794653877434, 0.3120132207021294], | |
| 'min': [-0.15604186058044434, -0.827903687953949, -0.2347021996974945], | |
| 'q01': [0.26669957995414734, -0.43774398624897004, -0.048167889714241026], | |
| 'q99': [0.7774086785316463, 0.428325751423835, 0.776091011762619], | |
| 'std': [0.1148424841779685, 0.17489566608140428, 0.16541062032731538], | |
| }, | |
| }, | |
| 'fmb': { | |
| 'euler': { | |
| 'max': [3.1415927410125732, 1.044223666191101, 2.6313371658325195], | |
| 'mean': [-0.5128238201141357, -0.0326850451529026, 1.2892225980758667], | |
| 'min': [-3.141592502593994, -1.0594873428344727, -1.5775980949401855], | |
| 'q01': [-3.1404616832733154, -0.9319325089454651, -1.4586873054504395], | |
| 'q99': [3.140399932861328, 0.8346238732337952, 2.3081648349761963], | |
| 'std': [3.0557258129119873, 0.33158522844314575, 0.6522650122642517], | |
| }, | |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, | |
| 'joints': { | |
| 'max': [ | |
| 1.633209228515625, | |
| 1.216904640197754, | |
| 0.6477310061454773, | |
| -1.101138949394226, | |
| 1.4866814613342285, | |
| 3.2880685329437256, | |
| 2.8213212490081787, | |
| ], | |
| 'mean': [ | |
| 0.08979735523462296, | |
| 0.24736542999744415, | |
| -0.18356309831142426, | |
| -2.1948094367980957, | |
| 0.09010367840528488, | |
| 2.3375964164733887, | |
| -0.6424615979194641, | |
| ], | |
| 'min': [ | |
| -0.24819369614124298, | |
| -0.6836066246032715, | |
| -1.852062463760376, | |
| -2.950329542160034, | |
| -1.360267162322998, | |
| 1.1119775772094727, | |
| -2.53395938873291, | |
| ], | |
| 'q01': [ | |
| -0.12363661825656891, | |
| -0.2893752157688141, | |
| -0.8288514018058777, | |
| -2.7587249279022217, | |
| -0.9787064790725708, | |
| 1.5508010387420654, | |
| -1.7525910139083862, | |
| ], | |
| 'q99': [ | |
| 0.527353048324585, | |
| 0.8305937647819519, | |
| 0.3808687925338745, | |
| -1.5172499418258667, | |
| 1.1656274795532227, | |
| 2.9421844482421875, | |
| 2.3595635890960693, | |
| ], | |
| 'std': [ | |
| 0.14010190963745117, | |
| 0.24683699011802673, | |
| 0.29916438460350037, | |
| 0.29822880029678345, | |
| 0.40809085965156555, | |
| 0.3218824863433838, | |
| 0.7698287963867188, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.7458599805831909, 0.2993985116481781, 0.39971601963043213], | |
| 'mean': [0.5149032473564148, -0.04377440735697746, 0.17520515620708466], | |
| 'min': [0.32314637303352356, -0.34641459584236145, 0.029209019616246223], | |
| 'q01': [0.3655048608779907, -0.28729698061943054, 0.033201027661561966], | |
| 'q99': [0.6782684326171875, 0.209969624876976, 0.3331448435783386], | |
| 'std': [0.07067117840051651, 0.13670970499515533, 0.07637079805135727], | |
| }, | |
| }, | |
| 'fractal20220817_data': { | |
| 'euler': { | |
| 'max': [3.141592264175415, 1.5703786611557007, 3.1415889263153076], | |
| 'mean': [-1.5046496391296387, 0.6226330399513245, -1.621827244758606], | |
| 'min': [-3.1415927410125732, -1.569378137588501, -3.1415863037109375], | |
| 'q01': [-3.130796194076538, -0.23499104380607605, -2.9648191928863525], | |
| 'q99': [3.1300582885742188, 1.4907126426696777, 2.8373801708221436], | |
| 'std': [2.380099296569824, 0.41548046469688416, 0.803632915019989], | |
| }, | |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, | |
| 'joints': { | |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| }, | |
| 'translation': { | |
| 'max': [1.0534898042678833, 0.48018959164619446, 1.6896663904190063], | |
| 'mean': [0.5594336986541748, -0.08356647938489914, 0.7760705947875977], | |
| 'min': [-0.3139062225818634, -0.9970501065254211, -0.006579156965017319], | |
| 'q01': [0.3249714970588684, -0.2818704843521118, 0.1410011649131775], | |
| 'q99': [0.8754204511642456, 0.21279653906822205, 1.071526288986206], | |
| 'std': [0.12431981414556503, 0.11550336331129074, 0.24583852291107178], | |
| }, | |
| }, | |
| 'furniture_bench_dataset': { | |
| 'euler': { | |
| 'max': [3.141592502593994, 1.545873999595642, 3.1415679454803467], | |
| 'mean': [-1.5673161745071411, 0.028489787131547928, -0.0338752306997776], | |
| 'min': [-3.1415927410125732, -1.1365290880203247, -3.141582727432251], | |
| 'q01': [-3.139331817626953, -0.6121169328689575, -1.9944958686828613], | |
| 'q99': [3.1392478942871094, 0.9993562698364258, 1.7793478965759277], | |
| 'std': [2.605621337890625, 0.30199435353279114, 0.924261212348938], | |
| }, | |
| 'gripper': { | |
| 'max': [0.07995442301034927], | |
| 'min': [-3.4803331800503656e-05], | |
| 'q01': [0.003531553316861391], | |
| 'q99': [0.07973115146160126], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| 1.6170321702957153, | |
| 1.4905058145523071, | |
| 0.8054860234260559, | |
| -0.8407800197601318, | |
| 2.596961736679077, | |
| 3.591698169708252, | |
| 2.744664192199707, | |
| ], | |
| 'mean': [ | |
| 0.14828824996948242, | |
| 0.513389527797699, | |
| -0.07023008167743683, | |
| -2.116460084915161, | |
| 0.17382267117500305, | |
| 2.6314659118652344, | |
| 0.7604196667671204, | |
| ], | |
| 'min': [ | |
| -1.0936335325241089, | |
| -0.28272193670272827, | |
| -1.3713332414627075, | |
| -2.9083454608917236, | |
| -2.664609432220459, | |
| 0.49353715777397156, | |
| -2.6659374237060547, | |
| ], | |
| 'q01': [ | |
| -0.3774302303791046, | |
| 0.16254600882530212, | |
| -0.526292085647583, | |
| -2.662245512008667, | |
| -0.3341670036315918, | |
| 1.4710004329681396, | |
| -0.997846245765686, | |
| ], | |
| 'q99': [ | |
| 0.6179460287094116, | |
| 0.9334877729415894, | |
| 0.3241332471370697, | |
| -1.5048880577087402, | |
| 0.69161057472229, | |
| 3.45759916305542, | |
| 2.5480763912200928, | |
| ], | |
| 'std': [ | |
| 0.20775794982910156, | |
| 0.14619596302509308, | |
| 0.1992885023355484, | |
| 0.24350883066654205, | |
| 0.2203112691640854, | |
| 0.3669370114803314, | |
| 0.8288815021514893, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.75081467628479, 0.29695066809654236, 0.35806331038475037], | |
| 'mean': [0.5622280836105347, 0.049872349947690964, 0.08108603954315186], | |
| 'min': [0.2878606617450714, -0.3141690492630005, -0.00460465531796217], | |
| 'q01': [0.36915361881256104, -0.180975541472435, 0.0058300793170928955], | |
| 'q99': [0.6652880311012268, 0.1772783100605011, 0.18316447734832764], | |
| 'std': [0.07692465931177139, 0.07895787060260773, 0.04069080576300621], | |
| }, | |
| }, | |
| 'iamlab_cmu_pickup_insert': { | |
| 'euler': { | |
| 'max': [3.14159250270088, 0.04584214946783205, 0.0450923308550184], | |
| 'mean': [-1.4122567194164972, -0.03209339030753262, 0.006593786875632054], | |
| 'min': [-3.141592327000409, -0.13647014666477264, -0.052753928700043584], | |
| 'q01': [-3.141100653025272, -0.1142266801993486, -0.026375220367685838], | |
| 'q99': [3.1411030904244917, 0.02986631061111326, 0.033458487885352856], | |
| 'std': [2.771965176997775, 0.02781015988983883, 0.015259428603879655], | |
| }, | |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, | |
| 'joints': { | |
| 'max': [ | |
| 0.24314826726913452, | |
| 0.7470589280128479, | |
| 0.7795426249504089, | |
| -1.6825058460235596, | |
| 0.13216856122016907, | |
| 2.88216233253479, | |
| 1.443279504776001, | |
| ], | |
| 'mean': [ | |
| -0.09283880889415741, | |
| 0.16589058935642242, | |
| 0.13815303146839142, | |
| -2.236156940460205, | |
| -0.013342339545488358, | |
| 2.4330623149871826, | |
| 0.8266588449478149, | |
| ], | |
| 'min': [ | |
| -0.6439031958580017, | |
| -0.7895585894584656, | |
| -0.017819713801145554, | |
| -2.9597983360290527, | |
| -0.5142408013343811, | |
| 1.82623291015625, | |
| 0.24880434572696686, | |
| ], | |
| 'q01': [ | |
| -0.4449005424976349, | |
| -0.687881588935852, | |
| -0.00037891813553869724, | |
| -2.9046988487243652, | |
| -0.3344403803348541, | |
| 1.9697062969207764, | |
| 0.4578770101070404, | |
| ], | |
| 'q99': [ | |
| 0.19939860701560974, | |
| 0.6767980456352234, | |
| 0.5450605750083923, | |
| -1.7509592771530151, | |
| 0.07327680289745331, | |
| 2.758103370666504, | |
| 1.2580491304397583, | |
| ], | |
| 'std': [ | |
| 0.1536070704460144, | |
| 0.3142663240432739, | |
| 0.1251399964094162, | |
| 0.28150704503059387, | |
| 0.062141284346580505, | |
| 0.17315669357776642, | |
| 0.19375956058502197, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.6634980998075433, 0.23428471360823594, 0.4308285234773945], | |
| 'mean': [0.5266367430643635, 0.02849007901898653, 0.18708021993948834], | |
| 'min': [0.3071657210198696, -0.2975911497764855, 0.06578226212031718], | |
| 'q01': [0.314498567139741, -0.20315787858517198, 0.06785127291435837], | |
| 'q99': [0.6472027612545221, 0.20840713845170097, 0.37003410655170416], | |
| 'std': [0.08150424006275543, 0.11136019021346877, 0.0777212838203498], | |
| }, | |
| }, | |
| 'jaco_play': { | |
| 'euler': { | |
| 'max': [3.141592653589793, 0.0, 0.0], | |
| 'mean': [3.14159265358649, 0.0, 0.0], | |
| 'min': [3.141592653589793, 0.0, 0.0], | |
| 'q01': [3.141592653589793, 0.0, 0.0], | |
| 'q99': [3.141592653589793, 0.0, 0.0], | |
| 'std': [3.3031355428647657e-12, 0.0, 0.0], | |
| }, | |
| 'gripper': { | |
| 'max': [1.3890767097473145], | |
| 'min': [-0.00123199715744704], | |
| 'q01': [0.00061599857872352], | |
| 'q99': [1.3835327625274658], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| -0.6973918676376343, | |
| 4.516506671905518, | |
| 2.732408285140991, | |
| 0.27657514810562134, | |
| 2.0405426025390625, | |
| 0.9062198996543884, | |
| ], | |
| 'mean': [ | |
| -1.5362969636917114, | |
| 3.9368598461151123, | |
| 1.470100998878479, | |
| -0.26214125752449036, | |
| 0.8781110644340515, | |
| -0.4923328757286072, | |
| ], | |
| 'min': [ | |
| -2.394073486328125, | |
| 3.4039254188537598, | |
| 0.5968497395515442, | |
| -1.0070130825042725, | |
| -0.28005343675613403, | |
| -1.3207207918167114, | |
| ], | |
| 'q01': [ | |
| -2.2932708263397217, | |
| 3.5327682495117188, | |
| 0.7838058471679688, | |
| -0.5554518699645996, | |
| 0.002689492190256715, | |
| -1.1670725345611572, | |
| ], | |
| 'q99': [ | |
| -0.8449038863182068, | |
| 4.363891124725342, | |
| 2.367223024368286, | |
| 0.06634082645177841, | |
| 1.595384120941162, | |
| 0.17280587553977966, | |
| ], | |
| 'std': [ | |
| 0.32358425855636597, | |
| 0.18587827682495117, | |
| 0.3632128834724426, | |
| 0.11627942323684692, | |
| 0.3359401226043701, | |
| 0.3179064989089966, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.2252020239830017, -0.19358234107494354, 0.4066188633441925], | |
| 'mean': [-0.07287009060382843, -0.42861491441726685, 0.2764993906021118], | |
| 'min': [-0.4429473876953125, -0.6635459661483765, 0.1568669229745865], | |
| 'q01': [-0.3789186179637909, -0.6194459795951843, 0.16865813732147217], | |
| 'q99': [0.21203258633613586, -0.26914602518081665, 0.38958534598350525], | |
| 'std': [0.138992577791214, 0.0908467248082161, 0.052496980875730515], | |
| }, | |
| }, | |
| 'kuka': { | |
| 'euler': { | |
| 'max': [3.141592651594966, 0.25521246168753553, 3.1415812671947134], | |
| 'mean': [-0.8123818266038756, -0.004947911572775501, 0.1622765703197679], | |
| 'min': [-3.141592653256577, -0.22240290857095402, -3.1415339219506544], | |
| 'q01': [-3.141539356675989, -0.042093398087956785, -0.6103775416503281], | |
| 'q99': [3.1415363480584935, 0.013040864331848548, 1.4661773197655914], | |
| 'std': [3.031619905792479, 0.010541045603285457, 0.386240278316171], | |
| }, | |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, | |
| 'joints': { | |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| }, | |
| 'translation': { | |
| 'max': [0.7243871092796326, 0.31309840083122253, 0.613274335861206], | |
| 'mean': [0.5510859489440918, 0.04787421599030495, 0.12812286615371704], | |
| 'min': [0.40573424100875854, -0.2028520256280899, 0.018512273207306862], | |
| 'q01': [0.4765772819519043, -0.14815208315849304, 0.06674224138259888], | |
| 'q99': [0.6515637040138245, 0.2447487711906433, 0.28018367290496826], | |
| 'std': [0.045206550508737564, 0.10222796350717545, 0.058390580117702484], | |
| }, | |
| }, | |
| 'language_table': { | |
| 'euler': { | |
| 'max': [3.141592653589793, 0.0, 0.0], | |
| 'mean': [3.141592653805758, 0.0, 0.0], | |
| 'min': [3.141592653589793, 0.0, 0.0], | |
| 'q01': [3.141592653589793, 0.0, 0.0], | |
| 'q99': [3.141592653589793, 0.0, 0.0], | |
| 'std': [2.1596502364831912e-10, 0.0, 0.0], | |
| }, | |
| 'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]}, | |
| 'joints': { | |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
| }, | |
| 'translation': { | |
| 'max': [0.6191085577011108, 0.345907062292099, 0.0], | |
| 'mean': [0.3994608521461487, 0.0045331921428442, 0.0], | |
| 'min': [0.18907572329044342, -0.3051564395427704, 0.0], | |
| 'q01': [0.19237099587917328, -0.2962527573108673, 0.0], | |
| 'q99': [0.6171894669532776, 0.30645298957824707, 0.0], | |
| 'std': [0.10587888211011887, 0.14165113866329193, 0.0], | |
| }, | |
| }, | |
| 'nyu_franka_play_dataset': { | |
| 'euler': { | |
| 'max': [3.141542687188373, 1.570464569214435, 3.1412348640834002], | |
| 'mean': [1.2730823611033315, 0.6681926368695384, 1.141188695183881], | |
| 'min': [-3.141417902437913, -0.6159773949821361, -3.1401807914238264], | |
| 'q01': [-3.1053165771573177, -0.18849691525500664, -2.8903965648285523], | |
| 'q99': [3.1156105211359257, 1.5443502891290082, 2.882089687976354], | |
| 'std': [1.519365023922235, 0.5398198304496786, 1.099335476495627], | |
| }, | |
| 'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]}, | |
| 'joints': { | |
| 'max': [ | |
| 0.873094916343689, | |
| 1.244848370552063, | |
| 2.6604084968566895, | |
| -0.46481436491012573, | |
| 1.409853458404541, | |
| 3.5058159828186035, | |
| 2.2937142848968506, | |
| ], | |
| 'mean': [ | |
| -0.7650318741798401, | |
| -0.9744310975074768, | |
| 1.4849265813827515, | |
| -2.0917000770568848, | |
| -0.42907631397247314, | |
| 2.4454119205474854, | |
| -0.04509185999631882, | |
| ], | |
| 'min': [ | |
| -2.1353442668914795, | |
| -1.6306616067886353, | |
| -0.0167933851480484, | |
| -2.8458828926086426, | |
| -2.5256054401397705, | |
| 0.5101203322410583, | |
| -2.1178665161132812, | |
| ], | |
| 'q01': [ | |
| -1.9259787797927856, | |
| -1.537291407585144, | |
| 0.4664345383644104, | |
| -2.7214717864990234, | |
| -2.114818811416626, | |
| 1.0703176259994507, | |
| -1.8231183290481567, | |
| ], | |
| 'q99': [ | |
| 0.19280724227428436, | |
| 0.7752904295921326, | |
| 2.6163666248321533, | |
| -0.8622008562088013, | |
| 0.7001188397407532, | |
| 3.4607036113739014, | |
| 1.7281100749969482, | |
| ], | |
| 'std': [ | |
| 0.4459172785282135, | |
| 0.5925542712211609, | |
| 0.45783039927482605, | |
| 0.4001038372516632, | |
| 0.6094018816947937, | |
| 0.5889745950698853, | |
| 0.7601358294487, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.6258905727651867, 0.8289460146095029, 0.8867425967452327], | |
| 'mean': [0.38785427731316113, 0.39551342608233353, 0.45161485937851], | |
| 'min': [0.05690113252698781, -0.06171031940925212, 0.1157533700659999], | |
| 'q01': [0.1393695951374117, 0.07645522416447147, 0.19364508601310984], | |
| 'q99': [0.5920727118835839, 0.6584802280677562, 0.8056891771281188], | |
| 'std': [0.10666290139662009, 0.1217948547812865, 0.1707295181003299], | |
| }, | |
| }, | |
| 'roboset': { | |
| 'euler': { | |
| 'max': [3.1415449294818236, 1.5705575529715636, 3.141527342124582], | |
| 'mean': [-0.0398455755412464, 1.0518070390619125, -0.015345692503002759], | |
| 'min': [-3.1415813300509536, -1.5222832468962035, -3.141575300866071], | |
| 'q01': [-2.9414386317311187, -0.24976770655101155, -2.985256521212579], | |
| 'q99': [2.9380437893235993, 1.5403010739503078, 2.9746912523985025], | |
| 'std': [1.7866587696177456, 0.40620530263065, 1.7288511340250616], | |
| }, | |
| 'gripper': { | |
| 'max': [0.83056640625], | |
| 'min': [0.0001499652862548828], | |
| 'q01': [0.0001499652862548828], | |
| 'q99': [0.82666015625], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| 0.96240234375, | |
| 1.1162109375, | |
| 1.1064453125, | |
| -0.98095703125, | |
| 2.30859375, | |
| 1.576171875, | |
| 1.7412109375, | |
| ], | |
| 'mean': [ | |
| 0.005913593806326389, | |
| 0.1877261847257614, | |
| 0.04653879255056381, | |
| -2.0529513359069824, | |
| -0.011298442259430885, | |
| 0.6185526251792908, | |
| -0.01701134257018566, | |
| ], | |
| 'min': [ | |
| -0.8330078125, | |
| -0.74658203125, | |
| -0.8642578125, | |
| -2.892578125, | |
| -1.390625, | |
| -0.24658203125, | |
| -2.953125, | |
| ], | |
| 'q01': [ | |
| -0.41015625, | |
| -0.5302734375, | |
| -0.6455078125, | |
| -2.57421875, | |
| -0.76416015625, | |
| -0.0386962890625, | |
| -1.435546875, | |
| ], | |
| 'q99': [ | |
| 0.66455078125, | |
| 0.9501953125, | |
| 0.7529296875, | |
| -1.251953125, | |
| 0.75244140625, | |
| 1.2314453125, | |
| 1.384765625, | |
| ], | |
| 'std': [ | |
| 0.17915399372577667, | |
| 0.32234326004981995, | |
| 0.26069700717926025, | |
| 0.31767210364341736, | |
| 0.205329030752182, | |
| 0.33385637402534485, | |
| 0.6263682842254639, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.5747738480567932, 0.3972920775413513, 0.7443570494651794], | |
| 'mean': [0.3331542909145355, 0.019357483834028244, 0.37330344319343567], | |
| 'min': [0.09978063404560089, -0.29593944549560547, 0.10065606236457825], | |
| 'q01': [0.18437016010284424, -0.25699371099472046, 0.15134164690971375], | |
| 'q99': [0.543661892414093, 0.29646238684654236, 0.6682320833206177], | |
| 'std': [0.07849054038524628, 0.12241040915250778, 0.1460595279932022], | |
| }, | |
| }, | |
| 'roboturk': { | |
| 'euler': { | |
| 'max': [3.1415925701602854, 1.5661525056538665, 3.1412803735012553], | |
| 'mean': [0.3946147188969296, -0.09657834247810235, -0.1840393277985236], | |
| 'min': [-3.141592645423486, -1.568282806779776, -3.141449561492499], | |
| 'q01': [-3.1378020570699525, -0.7907563562327732, -1.7347170410441308], | |
| 'q99': [3.137841563084971, 0.5773179844053413, 1.455226678618167], | |
| 'std': [2.9453717386817755, 0.2466574231368262, 0.6146181899264712], | |
| }, | |
| 'gripper': { | |
| 'max': [0.041667], | |
| 'min': [-1.6931189781743683e-05], | |
| 'q01': [0.0], | |
| 'q99': [0.040625325000000004], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| 3.0592832565307617, | |
| 2.9869844913482666, | |
| 0.33399999141693115, | |
| 2.0810000896453857, | |
| 3.874000072479248, | |
| 0.21199999749660492, | |
| 32.02399826049805, | |
| ], | |
| 'mean': [ | |
| 0.06695421040058136, | |
| 0.4899406135082245, | |
| -0.0009486731723882258, | |
| -0.00033266679383814335, | |
| -0.000789206416811794, | |
| 0.017417440190911293, | |
| -4.154728412628174, | |
| ], | |
| 'min': [ | |
| -2.995668888092041, | |
| -2.9882216453552246, | |
| -1.3760000467300415, | |
| -1.9279999732971191, | |
| -4.10099983215332, | |
| -0.5040000081062317, | |
| -23.58799934387207, | |
| ], | |
| 'q01': [ | |
| -1.1723066568374634, | |
| -1.1159032583236694, | |
| -0.00800000037997961, | |
| -0.3580000102519989, | |
| -0.5680000185966492, | |
| -0.10000000149011612, | |
| -14.303999900817871, | |
| ], | |
| 'q99': [ | |
| 1.1788837909698486, | |
| 1.7529590129852295, | |
| 0.007000000216066837, | |
| 0.3799999952316284, | |
| 0.628000020980835, | |
| 0.13199999928474426, | |
| 7.5279998779296875, | |
| ], | |
| 'std': [ | |
| 0.46230366826057434, | |
| 0.5157809257507324, | |
| 0.0024148367810994387, | |
| 0.11895754188299179, | |
| 0.18901629745960236, | |
| 0.050090208649635315, | |
| 4.478224754333496, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [1.0403562763478107, 0.8124313436278997, 0.9670890184966487], | |
| 'mean': [0.601758638436076, -0.0011977902645611068, 0.061136336184971524], | |
| 'min': [-0.375370271681523, -0.9602276639048171, -0.39440951528330676], | |
| 'q01': [0.2845426588179575, -0.3288349528691964, -0.0934955147249334], | |
| 'q99': [0.8773894592247552, 0.28575230587640144, 0.3286392635131121], | |
| 'std': [0.12719404213832913, 0.1476650170975746, 0.09248325352628932], | |
| }, | |
| }, | |
| 'stanford_hydra_dataset': { | |
| 'euler': { | |
| 'max': [3.141592357591889, 1.5364991593031068, 3.141222461465462], | |
| 'mean': [-0.2789889021491306, -0.19247585545209975, 0.34538177027867784], | |
| 'min': [-3.1415919781981962, -1.5573291068141568, -3.1413972494043216], | |
| 'q01': [-3.138796116403405, -1.107955818954808, -2.271278880707407], | |
| 'q99': [3.1386728135975384, 0.915868569686168, 2.4011245578416336], | |
| 'std': [2.883189482433683, 0.43362190146795254, 1.199450318648319], | |
| }, | |
| 'gripper': { | |
| 'max': [0.0811397060751915], | |
| 'min': [-0.0005391233135014772], | |
| 'q01': [1.3133333141013281e-06], | |
| 'q99': [0.08100640028715134], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| 1.9801124334335327, | |
| 1.5752310752868652, | |
| 2.6661651134490967, | |
| -0.6274839639663696, | |
| 2.3508076667785645, | |
| 3.4612932205200195, | |
| 2.747223377227783, | |
| ], | |
| 'mean': [ | |
| -0.052988938987255096, | |
| -0.19803331792354584, | |
| -0.046565085649490356, | |
| -2.395282506942749, | |
| 0.4740530848503113, | |
| 2.119305372238159, | |
| 0.031005585566163063, | |
| ], | |
| 'min': [ | |
| -2.5843772888183594, | |
| -1.287919044494629, | |
| -1.9306836128234863, | |
| -2.9673542976379395, | |
| -1.9528191089630127, | |
| 0.5680276155471802, | |
| -2.74257755279541, | |
| ], | |
| 'q01': [ | |
| -1.0905369520187378, | |
| -0.9110571146011353, | |
| -0.9165626764297485, | |
| -2.876581907272339, | |
| -0.9029546976089478, | |
| 1.3720064163208008, | |
| -2.0172016620635986, | |
| ], | |
| 'q99': [ | |
| 0.6443323493003845, | |
| 1.0052685737609863, | |
| 0.9653106927871704, | |
| -1.4030946493148804, | |
| 1.738037109375, | |
| 2.943267822265625, | |
| 2.451101541519165, | |
| ], | |
| 'std': [ | |
| 0.3502423167228699, | |
| 0.41914135217666626, | |
| 0.39895766973495483, | |
| 0.31860536336898804, | |
| 0.5207170248031616, | |
| 0.35328012704849243, | |
| 1.2519474029541016, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.7598075952006731, 0.36410959179102775, 0.597508369211053], | |
| 'mean': [0.4348564561896346, 0.024563536690147474, 0.2938664975066919], | |
| 'min': [0.18414022283711137, -0.32311685510034516, 0.017172937739480726], | |
| 'q01': [0.2373728557200433, -0.26521679456931635, 0.09069013461938308], | |
| 'q99': [0.712423814640525, 0.25299056815993204, 0.49505407602925094], | |
| 'std': [0.10465658310507645, 0.12745896408828714, 0.10379447506049931], | |
| }, | |
| }, | |
| 'taco_play': { | |
| 'euler': { | |
| 'max': [3.1415891647338867, 0.30717557668685913, 2.4624788761138916], | |
| 'mean': [-0.06622616201639175, -0.17697572708129883, 0.4614097774028778], | |
| 'min': [-3.1415886878967285, -0.8683895468711853, -1.789320707321167], | |
| 'q01': [-3.139000654220581, -0.6962019205093384, -1.1729414463043213], | |
| 'q99': [3.1392500400543213, 0.12436603754758835, 1.8065968751907349], | |
| 'std': [3.0384926795959473, 0.18692463636398315, 0.6806972622871399], | |
| }, | |
| 'gripper': { | |
| 'max': [0.08074767142534256], | |
| 'min': [9.652999870013446e-05], | |
| 'q01': [0.00015234666352625936], | |
| 'q99': [0.08067675679922104], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| 0.9704633951187134, | |
| 0.4390600025653839, | |
| 0.9892600178718567, | |
| -1.116769790649414, | |
| 1.2800638675689697, | |
| 2.894169330596924, | |
| 2.627519130706787, | |
| ], | |
| 'mean': [ | |
| 0.14954668283462524, | |
| -0.3514331877231598, | |
| 0.1295831948518753, | |
| -2.205498456954956, | |
| 0.11473798751831055, | |
| 2.0389058589935303, | |
| 0.5516068339347839, | |
| ], | |
| 'min': [ | |
| -0.6495265960693359, | |
| -1.2877551317214966, | |
| -1.5125423669815063, | |
| -2.956650733947754, | |
| -1.0317779779434204, | |
| 1.1109415292739868, | |
| -1.8966686725616455, | |
| ], | |
| 'q01': [ | |
| -0.4596467614173889, | |
| -0.9323632121086121, | |
| -0.6153853535652161, | |
| -2.8417088985443115, | |
| -0.3472142219543457, | |
| 1.428541898727417, | |
| -0.5756277441978455, | |
| ], | |
| 'q99': [ | |
| 0.7262046933174133, | |
| 0.21787810325622559, | |
| 0.7387099266052246, | |
| -1.3517359495162964, | |
| 0.7581852674484253, | |
| 2.564530849456787, | |
| 1.6355704069137573, | |
| ], | |
| 'std': [ | |
| 0.281919926404953, | |
| 0.2702403962612152, | |
| 0.3334408402442932, | |
| 0.412654846906662, | |
| 0.21815766394138336, | |
| 0.2808808982372284, | |
| 0.4732993543148041, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.7124971151351929, 0.6118948459625244, 0.617118775844574], | |
| 'mean': [0.37278515100479126, 0.13003520667552948, 0.38341864943504333], | |
| 'min': [0.07693906873464584, -0.4944135844707489, 0.20030911266803741], | |
| 'q01': [0.1368357390165329, -0.4297449290752411, 0.20516259968280792], | |
| 'q99': [0.6700438857078552, 0.5943909883499146, 0.5966404676437378], | |
| 'std': [0.11509375274181366, 0.2694733738899231, 0.11266136914491653], | |
| }, | |
| }, | |
| 'toto': { | |
| 'euler': { | |
| 'max': [3.1414193360696108, 1.570402878730552, 3.1415549880169924], | |
| 'mean': [-1.31124856158284, 0.8107744638901545, 1.042033586691101], | |
| 'min': [-3.1415814091974665, -1.3210638993411208, -3.141515215393105], | |
| 'q01': [-2.9459294143675696, -0.6449196412662773, -2.9790265961739872], | |
| 'q99': [2.824635320375007, 1.4897934376557762, 3.0083314049108925], | |
| 'std': [1.2205361480020216, 0.45724967725161403, 1.3644152538580079], | |
| }, | |
| 'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]}, | |
| 'joints': { | |
| 'max': [ | |
| 1.4926575422286987, | |
| 0.7715681195259094, | |
| 1.9506583213806152, | |
| -1.6330622434616089, | |
| 2.804656744003296, | |
| 2.0850095748901367, | |
| 1.9661481380462646, | |
| ], | |
| 'mean': [ | |
| -0.040506213903427124, | |
| -0.2537725269794464, | |
| 0.05389903113245964, | |
| -2.5135483741760254, | |
| -0.161987766623497, | |
| 0.8815387487411499, | |
| 0.0034975146409124136, | |
| ], | |
| 'min': [ | |
| -1.8345723152160645, | |
| -1.7133995294570923, | |
| -1.40083646774292, | |
| -2.985748767852783, | |
| -2.814244270324707, | |
| -0.5165338516235352, | |
| -3.5424435138702393, | |
| ], | |
| 'q01': [ | |
| -1.4044159650802612, | |
| -1.3978556394577026, | |
| -0.8029389381408691, | |
| -2.97342586517334, | |
| -1.5402026176452637, | |
| 0.0011001524981111288, | |
| -2.9709367752075195, | |
| ], | |
| 'q99': [ | |
| 0.7075999975204468, | |
| 0.42536962032318115, | |
| 1.3345959186553955, | |
| -1.8281668424606323, | |
| 2.5240790843963623, | |
| 2.0835044384002686, | |
| 1.212519884109497, | |
| ], | |
| 'std': [ | |
| 0.3405684530735016, | |
| 0.3766334354877472, | |
| 0.40128207206726074, | |
| 0.3202400505542755, | |
| 0.6421889066696167, | |
| 0.33138400316238403, | |
| 0.5801302194595337, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.7777318005382999, 0.5688841397096768, 0.672235403102292], | |
| 'mean': [0.15078069344893, -0.03012185632654897, 0.35519823701585035], | |
| 'min': [-0.18532184496861265, -0.6282599632853518, 0.06312318086547261], | |
| 'q01': [-0.09177927811484686, -0.3571658956454935, 0.21965464356283362], | |
| 'q99': [0.6757593680120374, 0.2889021640318381, 0.501109413161318], | |
| 'std': [0.14726563054735392, 0.13472763062259546, 0.059123705378892395], | |
| }, | |
| }, | |
| 'ucsd_kitchen_dataset': { | |
| 'euler': { | |
| 'max': [3.141592644655991, 0.10471752134072587, 2.0420500160887434], | |
| 'mean': [-2.349219123990161, -0.5844573008678524, 0.8632842703198627], | |
| 'min': [-3.141592637838787, -1.5533486027935484, -1.676143206926509], | |
| 'q01': [-3.1415921825947537, -1.535883315130894, -0.6806766312742327], | |
| 'q99': [3.1328677560237566, 0.01745363977705911, 1.8326023938995777], | |
| 'std': [1.756197370631558, 0.482005280180061, 0.7151569996653978], | |
| }, | |
| 'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]}, | |
| 'joints': { | |
| 'max': [ | |
| 2.4945411682128906, | |
| 1.0603798627853394, | |
| 0.5518655776977539, | |
| 2.71895432472229, | |
| 1.9269907474517822, | |
| 2.2231369018554688, | |
| 0.7844778895378113, | |
| ], | |
| 'mean': [ | |
| 0.35797572135925293, | |
| -0.23473882675170898, | |
| -0.34185951948165894, | |
| 0.8982798457145691, | |
| 0.5780586004257202, | |
| 1.0351765155792236, | |
| -1.2190868854522705, | |
| ], | |
| 'min': [ | |
| -0.6943671703338623, | |
| -1.6351218223571777, | |
| -2.0142624378204346, | |
| 0.00768359424546361, | |
| -0.17713193595409393, | |
| 0.011292487382888794, | |
| -2.6592507362365723, | |
| ], | |
| 'q01': [ | |
| -0.43497809767723083, | |
| -1.4843493700027466, | |
| -1.7824348211288452, | |
| 0.0955040231347084, | |
| -0.07186625152826309, | |
| 0.27948257327079773, | |
| -2.4006130695343018, | |
| ], | |
| 'q99': [ | |
| 2.25881028175354, | |
| 0.8190488815307617, | |
| 0.2994879484176636, | |
| 2.153681993484497, | |
| 1.6126199960708618, | |
| 1.8523634672164917, | |
| 0.06944511085748672, | |
| ], | |
| 'std': [ | |
| 0.6131777167320251, | |
| 0.6322475075721741, | |
| 0.5427954792976379, | |
| 0.539089024066925, | |
| 0.43294647336006165, | |
| 0.3588581085205078, | |
| 0.6916991472244263, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.681192194250889, 0.2323359966361606, 0.6159547986373596], | |
| 'mean': [0.4067338752307897, 0.027551073758015975, 0.313373054289359], | |
| 'min': [0.13453314224525192, -0.2553313745227843, 0.03515888153950328], | |
| 'q01': [0.18739915591793913, -0.18234310016370514, 0.048970692427190994], | |
| 'q99': [0.6410437631246786, 0.20632224240340083, 0.5983893391276117], | |
| 'std': [0.12446715882142037, 0.08252308025761342, 0.11421020385340024], | |
| }, | |
| }, | |
| 'utaustin_mutex': { | |
| 'euler': { | |
| 'max': [3.1415927410125732, 0.5760129690170288, 0.6430304050445557], | |
| 'mean': [1.643947720527649, 0.03388536348938942, -0.31948330998420715], | |
| 'min': [-3.141592025756836, -0.9390894770622253, -1.8157784938812256], | |
| 'q01': [-3.1398682594299316, -0.7223533391952515, -1.5468647480010986], | |
| 'q99': [3.140272855758667, 0.3558599054813385, 0.3583984971046448], | |
| 'std': [2.2715108394622803, 0.19258549809455872, 0.5275800824165344], | |
| }, | |
| 'gripper': { | |
| 'max': [0.07569462060928345], | |
| 'min': [0.00026726332725957036], | |
| 'q01': [0.0019378233700990677], | |
| 'q99': [0.07565194368362427], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| 0.5831676125526428, | |
| 0.8024097681045532, | |
| 1.350082516670227, | |
| -1.1145747900009155, | |
| 0.5345654487609863, | |
| 3.2108523845672607, | |
| 2.875584363937378, | |
| ], | |
| 'mean': [ | |
| -0.09772995859384537, | |
| 0.05997378006577492, | |
| -0.009316973388195038, | |
| -2.3428714275360107, | |
| -0.48651641607284546, | |
| 2.2390215396881104, | |
| 1.3341773748397827, | |
| ], | |
| 'min': [ | |
| -0.6990927457809448, | |
| -0.8311276435852051, | |
| -0.5323343276977539, | |
| -2.9861886501312256, | |
| -2.060809850692749, | |
| 0.826747477054596, | |
| 0.05972049757838249, | |
| ], | |
| 'q01': [ | |
| -0.532222330570221, | |
| -0.5078368782997131, | |
| -0.3437865376472473, | |
| -2.784912586212158, | |
| -1.8685580492019653, | |
| 1.3974121809005737, | |
| 0.5157660841941833, | |
| ], | |
| 'q99': [ | |
| 0.35958197712898254, | |
| 0.6339818835258484, | |
| 0.5833610892295837, | |
| -1.612648367881775, | |
| 0.2780923545360565, | |
| 2.8907132148742676, | |
| 2.837320327758789, | |
| ], | |
| 'std': [ | |
| 0.2101736217737198, | |
| 0.2880096137523651, | |
| 0.17824982106685638, | |
| 0.23544654250144958, | |
| 0.6387099027633667, | |
| 0.33533376455307007, | |
| 0.5090957880020142, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.581696093082428, 0.4334338903427124, 0.6800903081893921], | |
| 'mean': [0.4331520199775696, -0.10096638649702072, 0.22044037282466888], | |
| 'min': [0.2454746514558792, -0.502901017665863, 0.008792983368039131], | |
| 'q01': [0.3217194080352783, -0.4733337163925171, 0.014122226275503635], | |
| 'q99': [0.5321439504623413, 0.3733823001384735, 0.5785381197929382], | |
| 'std': [0.04380597919225693, 0.20559938251972198, 0.12614832818508148], | |
| }, | |
| }, | |
| 'viola': { | |
| 'euler': { | |
| 'max': [3.141592502593994, 0.9555190801620483, 0.40456175804138184], | |
| 'mean': [-0.7171992063522339, 0.07086259871721268, -0.67817223072052], | |
| 'min': [-3.141592025756836, -0.4305531978607178, -2.212939739227295], | |
| 'q01': [-3.1400232315063477, -0.19947049021720886, -1.8703945875167847], | |
| 'q99': [3.1401824951171875, 0.7830920219421387, 0.19578109681606293], | |
| 'std': [2.902841567993164, 0.20232577621936798, 0.7306717038154602], | |
| }, | |
| 'gripper': { | |
| 'max': [0.07748929411172867], | |
| 'min': [-5.3190000471659005e-05], | |
| 'q01': [0.00020356666937004775], | |
| 'q99': [0.07747221738100052], | |
| }, | |
| 'joints': { | |
| 'max': [ | |
| 0.3366159200668335, | |
| 0.8983011841773987, | |
| 0.330204576253891, | |
| 0.0, | |
| 1.3329579830169678, | |
| 3.1103708744049072, | |
| 2.8318748474121094, | |
| ], | |
| 'mean': [ | |
| 0.027196571230888367, | |
| 0.2290346622467041, | |
| -0.06033482402563095, | |
| -2.060992479324341, | |
| 0.25734394788742065, | |
| 2.2569527626037598, | |
| 1.2675725221633911, | |
| ], | |
| 'min': [ | |
| -0.4073199927806854, | |
| -0.23913456499576569, | |
| -0.5244941115379333, | |
| -2.744713068008423, | |
| -0.48249971866607666, | |
| 0.0, | |
| -0.22237232327461243, | |
| ], | |
| 'q01': [ | |
| -0.26899582147598267, | |
| -0.189721941947937, | |
| -0.3926161825656891, | |
| -2.620305061340332, | |
| -0.199931338429451, | |
| 1.6743353605270386, | |
| 0.1440846174955368, | |
| ], | |
| 'q99': [ | |
| 0.23294022679328918, | |
| 0.764011561870575, | |
| 0.20191603899002075, | |
| -1.5629768371582031, | |
| 1.039290189743042, | |
| 2.9412121772766113, | |
| 2.798232078552246, | |
| ], | |
| 'std': [ | |
| 0.10160040855407715, | |
| 0.21415969729423523, | |
| 0.14444759488105774, | |
| 0.27077943086624146, | |
| 0.3251941502094269, | |
| 0.338652640581131, | |
| 0.7529342174530029, | |
| ], | |
| }, | |
| 'translation': { | |
| 'max': [0.6868156790733337, 0.20999883115291595, 0.5037735104560852], | |
| 'mean': [0.5464076995849609, 0.006265466101467609, 0.23136523365974426], | |
| 'min': [0.0, -0.33860740065574646, 0.0], | |
| 'q01': [0.40061360597610474, -0.25196850299835205, 0.010269512422382832], | |
| 'q99': [0.6458418369293213, 0.17776551842689514, 0.4456312954425812], | |
| 'std': [0.062116123735904694, 0.12932434678077698, 0.13205111026763916], | |
| }, | |
| }, | |
| } | |
def _control_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
    """
    Load per-dataset control statistics matching the configured prediction horizon.

    Returns an empty dict when both rotation and translation use global
    normalization, since no per-dataset statistics are needed in that case.

    Raises:
        NotImplementedError: non-delta controls with no stride and a sequence
            length other than 1.
        ValueError: the stats file has no entry for the computed horizon key.
    """
    if is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.translation_norm):
        return {}

    with open(self.config.control_stats_path, 'r') as file:
        all_stats = yaml.safe_load(file)

    io_cfg = self.control_io_config
    stride = io_cfg.future_controls_sequence_stride_sec
    if self.config.delta_controls:
        # Per-step deltas: the effective horizon is one stride (0 if unset).
        horizon = 0.0 if stride is None else stride
    elif stride is None:
        if io_cfg.future_controls_sequence_length != 1:
            raise NotImplementedError()
        horizon = 0.0
    else:
        # Absolute controls over the whole future sequence.
        horizon = io_cfg.future_controls_sequence_length * stride

    key = f'horizon_{round(horizon, 2)}s'
    if key not in all_stats:
        cfg_io = self.config.control_io_config
        raise ValueError(
            f'Missing control statistics key {key} for future_controls_sequence_length={cfg_io.future_controls_sequence_length} future_controls_sequence_stride_sec={cfg_io.future_controls_sequence_stride_sec}. Available keys: [{all_stats.keys()}]'
        )
    return all_stats[key]
def dataset_names(self) -> List[str]:
    """
    Names of datasets for which per-dataset statistics exist.

    Returns the sentinel ['ANY'] when every quantity uses global
    normalization (so stats are dataset-independent).
    """
    norm_settings = (
        self.config.rotation_norm,
        self.config.obs_rotation_norm,
        self.config.translation_norm,
        self.config.obs_translation_norm,
    )
    if all(is_global_norm(setting) for setting in norm_settings):
        return ['ANY']
    # Union of datasets covered by either control or observation stats.
    combined = set(self._control_stats) | set(self._observation_stats)
    return list(combined)
def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Re-express per-step translation deltas relative to the base frame T0.

    The input holds vectors encoded w.r.t. the PREVIOUS frame in the
    sequence (T0T1, T1T2, T2T3, ...); the output holds the same points
    encoded w.r.t. the 0-th frame preceding the sequence
    (T0T1, T0T2, T0T3, ...). This is a running sum over the sequence axis.

    Args:
        translation_sequence: torch.Tensor of shape [..., S, 3] where S is
            the sequence dimension.

    Returns:
        torch.Tensor of the same shape, with translations accumulated
        w.r.t. the base frame T0.
    """
    assert translation_sequence.ndim >= 3, translation_sequence.shape
    # Prefix-sum along the sequence axis turns per-step deltas into
    # base-frame-relative translations.
    return torch.cumsum(translation_sequence, dim=-2)
class RegressionProcessor(VLAMProcessor[RegressionProcessorConfig]):
    """Processor variant whose controls are continuous regression values."""

    def _deltas_to_base_frame(
        self, translation_m: torch.Tensor, rotmat: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # When controls are encoded as per-step deltas, accumulate them so the
        # resulting plan is expressed w.r.t. the base frame.
        if not self.config.delta_controls:
            return translation_m, rotmat
        return delta_to_relative_translations(translation_m), delta_to_relative_rotations(rotmat)

    def policy_control_plan_from_model_target(
        self, target: RoboticsTarget, dataset_name: np.ndarray
    ) -> RoboticsControlPlan:
        """Build a control plan from ground-truth (target) controls."""
        unnorm_translation = self.unnormalize(
            target.translation, dataset_name=dataset_name, key='translation'
        )
        unnorm_rotation = self.unnormalize(target.rotation, dataset_name=dataset_name, key='rotation')
        rotmat = convert_rotation(unnorm_rotation, RotationFormat.ROTMAT)
        unnorm_translation, rotmat = self._deltas_to_base_frame(unnorm_translation, rotmat)
        return RoboticsControlPlan(
            translation_m=unnorm_translation,
            rotmat=rotmat,
            gripper_prob=target.gripper,
            valid_mask=target.valid_mask,
        )

    def policy_control_plan_from_model_output(
        self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor
    ) -> RoboticsControlPlan:
        """Called during inference to create control plan from model output"""
        unnorm_translation = self.unnormalize(
            model_output.translation, dataset_name=dataset_name, key='translation'
        )
        unnorm_rotation = self.unnormalize(
            model_output.rotation, dataset_name=dataset_name, key='rotation'
        )
        # Model rotation outputs are unconstrained; autonorm projects them onto
        # valid rotations before conversion.
        rotmat = convert_rotation(unnorm_rotation, RotationFormat.ROTMAT, autonorm=True)
        unnorm_translation, rotmat = self._deltas_to_base_frame(unnorm_translation, rotmat)
        return RoboticsControlPlan(
            translation_m=unnorm_translation,
            rotmat=rotmat,
            # Raw gripper logits -> probabilities.
            gripper_prob=torch.sigmoid(model_output.gripper),
            valid_mask=valid_mask,
        )
class VLARMProcessor(RegressionProcessor):
    """Regression processor combining a wrapped VLM processor (text/images) with
    normalization of proprioceptive robot state (pose, gripper, joints)."""

    def __init__(self, config: VLARMProcessorConfig, vlm_processor: VLMProcessor):
        # Deliberately calls Configurable.__init__ directly instead of
        # super().__init__: only the config wiring is wanted here.
        Configurable.__init__(self, config)
        self.vlm_processor = vlm_processor
        # Controls are regressed, not emitted as text, so the control
        # tokenizer is the empty one.
        self.control_tokenizer = EmptyTokenizer(
            config=self.config.control_tokenizer_config, tokenizer=self.tokenizer
        )
        # Normalization bounds per quantity, keyed by the same `key` strings
        # passed to normalize()/unnormalize() below.
        self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {
            'obs_translation': self.obs_translation_norm_bounds,
            'obs_rotation': self.obs_rotation_norm_bounds,
            'translation': self.translation_norm_bounds,
            'rotation': self.rotation_norm_bounds,
            'joints': self.joints_norm_bounds,
        }

    def dataset_np_names(self) -> np.ndarray:
        # NOTE(review): accessed elsewhere as `self.dataset_np_names` (no call),
        # so this is presumably decorated as a @property in the full file —
        # confirm against the original source.
        return self._dataset_np_names

    def _set_dataset_np_names(self, dataset_np_names: np.ndarray) -> None:
        # Stores the dataset names later used to select per-dataset
        # normalization statistics during inference.
        self._dataset_np_names = dataset_np_names

    def preprocess_inputs(
        self,
        chat: List[str],
        images: Dict[str, List[PIL.Image.Image]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool,
        control_target: Optional[RoboticsTarget] = None,
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Tokenize the chat/images via the wrapped VLM processor and normalize
        the proprioceptive state.

        Args:
            chat: Conversation turns, forwarded to the VLM processor.
            images: Per-camera image histories, forwarded to the VLM processor.
            ee_pose_translation: End-effector translation observation.
            ee_pose_rotation: End-effector rotation observation; converted to
                the configured rotation format.
            gripper: Gripper observation (dataset-specific preprocessing applied).
            joints: Joint positions; zero-padded on the last axis to 7 values.
            dataset_name: Dataset identifier(s) selecting normalization stats.
            inference_mode: Not read by this body — TODO confirm whether
                subclasses/callers rely on it.
            control_target: Not read by this body — TODO confirm.

        Returns:
            Dict of model inputs; control token entries are None since this
            processor regresses controls instead of tokenizing them.
        """
        inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
        images: Dict[str, torch.Tensor] = inputs['images']
        # Truncate token sequences to the tokenizer's maximum model length.
        input_ids: torch.Tensor = inputs['input_ids'][..., : self.tokenizer.model_max_length]
        target_text_tokens_ids: torch.Tensor = inputs['target_ids'][..., : self.tokenizer.model_max_length]
        # All-True mask (no masked positions) over the truncated sequence.
        attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)
        ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32)
        ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32)
        ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True)
        gripper = preprocess_gripper_observation(gripper, dataset_name)
        gripper = torch.tensor(gripper, dtype=torch.float32)
        ee_pose_translation = self.normalize(
            ee_pose_translation, dataset_name=dataset_name, key='obs_translation'
        )
        ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key='obs_rotation')
        joints = torch.tensor(joints, dtype=torch.float32)
        if joints.shape[-1] < 7:
            # Pad robots with fewer than 7 joints with zeros so the joint
            # vector always has a fixed width of 7.
            missing_size = 7 - joints.shape[-1]
            joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1)
        joints = self.normalize(joints, dataset_name=dataset_name, key='joints')
        return {
            'images': images,
            'input_ids': input_ids,
            'target_text_tokens_ids': target_text_tokens_ids,
            'attn_mask': attn_mask,
            'ee_pose_translation': ee_pose_translation,
            'ee_pose_rotation': ee_pose_rotation,
            'gripper': gripper,
            'joints': joints,
            'control_tokens_ids': None,
            'target_control_tokens_ids': None,
        }

    def model_otoi_but_gt_gripper(
        self,
        translation_output: torch.Tensor,
        rotation_output: torch.Tensor,
        reference_translation: torch.Tensor,
        reference_rotation: torch.Tensor,
        gt_gripper: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Called during inference to create control plan from model output"""
        # Predicted deltas: unnormalized with the *control* stats...
        translation_delta = self.unnormalize(
            translation_output, dataset_name=self.dataset_np_names, key='translation'
        )
        rotation_delta = self.unnormalize(rotation_output, dataset_name=self.dataset_np_names, key='rotation')
        # ...while the reference pose uses the *observation* stats.
        translation_reference = self.unnormalize(
            reference_translation, dataset_name=self.dataset_np_names, key='obs_translation'
        )
        rotation_reference = self.unnormalize(
            reference_rotation, dataset_name=self.dataset_np_names, key='obs_rotation'
        )
        # Compose delta on top of the reference to get the next world-frame pose.
        translation = translation_reference + translation_delta
        rotation = delta_to_world_rotations(rotation_delta, rotation_reference)
        # Re-normalize as an observation so it can be fed back to the model.
        translation_next = self.normalize(
            translation, dataset_name=self.dataset_np_names, key='obs_translation'
        )
        rotation_next = self.normalize(rotation, dataset_name=self.dataset_np_names, key='obs_rotation')
        # Gripper is taken from ground truth rather than the model output.
        gripper_next = gt_gripper
        return translation_next, rotation_next, gripper_next
def make_causal_mask(shape: Sequence[int]) -> torch.Tensor:
    """
    Build a boolean causal (lower-triangular) attention mask.

    Args:
        shape: Shape of the output mask; the trailing two dimensions are
            [query_seq_len, kv_seq_len].

    Returns:
        torch.Tensor of dtype torch.bool where entry [..., q, k] is True
        iff query q may attend to key k, i.e. k <= q.

    Example:
        shape = (3, 5) ->
        [
            [ 1, 0, 0, 0, 0],
            [ 1, 1, 0, 0, 0],
            [ 1, 1, 1, 0, 0]
        ]
    """
    allow_all = torch.ones(shape, dtype=torch.bool)
    # Keep the main diagonal and everything below it (applied to the last
    # two dimensions, batch-wise for higher-rank shapes).
    return allow_all.tril(diagonal=0)
def enable_full_attn_blocks(attn_mask: torch.Tensor, full_attn: torch.Tensor) -> torch.Tensor:
    """
    Enable full bi-directional attention in `attn_mask` inside specific blocks
    Args:
        attn_mask: Existing attention mask of shape [..., query_seq_len, kv_seq_len] and dtype torch.bool
            where False values indicate disabled attention
        full_attn: torch.Tensor of shape [query_seq_len], dtype torch.bool. Blocks of True values indicate
            positions where full bi-directional attention should be enabled
    Returns:
        torch.Tensor of the same shape as attn_mask with the block positions enabled (OR-ed in).
    Example:
        1, 0, 0, 0, 0, 0, 0, 0,      1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 0, 0, 0, 0, 0, 0,      1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 0, 0, 0, 0, 0,      1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 0,  ->  1, 1, 1, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 0, 0, 0,      1, 1, 1, 1, 1, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 0, 0,      1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0,      1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1,      1, 1, 1, 1, 1, 1, 1, 1,
    """
    assert full_attn.dtype == torch.bool, full_attn.dtype
    assert full_attn.ndim == 1, full_attn.shape
    assert full_attn.shape[0] == attn_mask.shape[-2], f'{full_attn.shape[0]}, {attn_mask.shape}'
    if attn_mask.shape[-1] != attn_mask.shape[-2]:
        raise NotImplementedError('Only self-attention supported right now.')
    # Outer product: True where both the query and key position are flagged
    # for full attention.
    x = full_attn.view(-1, 1) & full_attn.view(1, -1)
    # Union with the causal mask so each row starts with a True prefix.
    x = x | make_causal_mask([full_attn.shape[0], full_attn.shape[0]])
    # cumprod along each row keeps only the leading run of Trues, i.e. it
    # restricts full attention to contiguous blocks from the start of the row.
    x = torch.cumprod(x, dim=1).to(dtype=torch.bool)
    # Symmetrize: q<->k must both be allowed for bi-directional attention.
    x = x & x.permute(1, 0)
    # NOTE(review): `&` binds tighter than `==`, so this parses as
    # `torch.sum(x, dim=0) == (1 & ~full_attn)`, not `(sum == 1) & ~full_attn`.
    # The two appear to coincide here because every column retains its diagonal
    # True (column sums are never 0) — confirm this was intended.
    mask_positions = torch.sum(x, dim=0) == 1 & ~full_attn
    mask_indices = torch.where(mask_positions)[0]
    # Drop isolated diagonal entries outside the full-attention blocks so they
    # do not add anything beyond the existing causal mask.
    x[mask_indices, mask_indices] = 0
    # OR the block pattern into the caller's mask, broadcasting over leading dims.
    attn_mask = attn_mask | expand_dims(x, ndim=attn_mask.ndim, order=[-1, 1, 1])
    return attn_mask
# Label value excluded from the loss; -100 matches the default `ignore_index`
# of torch.nn.CrossEntropyLoss (and the HF transformers label-masking convention).
IGNORE_INDEX = -100
| class PaliGemmaProcessor(VLMProcessor[PaliGemmaProcessorConfig]): | |
def __init__(
    self,
    config: PaliGemmaProcessorConfig,
    hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
    **kwargs,
):
    """
    Wrap a HF PaliGemma processor and cache the special token ids used when
    assembling prompts.

    Args:
        config: Processor configuration (image sizes, image token counts, ...).
        hf_processor: The underlying HF PaliGemma processor; its image size
            and image sequence length are overwritten from `config`.
        **kwargs: Accepted for interface compatibility; ignored.
    """
    del kwargs
    super().__init__(config)
    self.hf_processor = hf_processor

    # Force the HF processor to use the configured main-camera geometry.
    main_size = dict(self.config.image_sizes['main'].as_json())
    main_token_count = self.config.num_image_tokens['main']
    hf_processor.image_processor.size = main_size
    hf_processor.image_seq_length = main_token_count
    hf_processor.image_processor.image_seq_length = main_token_count

    tokenizer = self.tokenizer

    def _single_token_id(text: str) -> int:
        # Encode `text` without special tokens and take its first token id.
        encoded = tokenizer(
            text, padding=False, add_special_tokens=False, return_attention_mask=False
        )
        return encoded['input_ids'][0]

    self.bos_id: int = tokenizer.bos_token_id
    self.eos_id: int = tokenizer.eos_token_id
    self.sep_token = '\n'
    self.sep_id: int = _single_token_id(self.sep_token)
    self.image_token_id: int = _single_token_id(self.config.image_token)
    # One placeholder token per image token, across all cameras.
    total_image_tokens = sum(self.config.num_image_tokens.values())
    self.image_tokens: list[int] = [self.image_token_id] * total_image_tokens
    # Matches bounding boxes written as [x.x, y.y, x.x, y.y] with decimals.
    self.bbox_pattern = re.compile(
        r'\[(\d+\.\d+),\s*(\d+\.\d+),\s*(\d+\.\d+),\s*(\d+\.\d+)\]'
    )
def preprocess_inputs(
    self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
    """
    Assemble PaliGemma model inputs from a chat and per-camera images.

    Based on the PaliGemma paper https://arxiv.org/pdf/2407.07726 and example
    code at https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#create_model_inputs
    The chat alternates user/model turns, starting with the user; the token
    layout is:
    <image><image> ... <bos><instruction><sep><assistant><sep>...<eos>
    User turns (and the image/bos prefix) are masked with IGNORE_INDEX in
    the targets; model turns are supervised.

    Args:
        chat: List[str] of even size, one entry per conversation turn.
        images: Dict mapping camera name to a list (history) of PIL images.

    Returns:
        Dict with 'input_ids', 'target_ids', per-camera 'images' tensors, and
        a causal 'attn_mask' with full attention over the masked prefix.
    """
    for camera_name, camera_images in images.items():
        if not isinstance(camera_images, list):
            raise TypeError(f'Camera {camera_name} contains values of type {type(camera_images)} instead of list')

    input_ids: List[int] = []
    target_ids: List[int] = []
    last_turn = len(chat) - 1
    for turn_idx, turn_text in enumerate(chat):
        # Separator and image placeholders must not appear inside turn text;
        # bounding boxes are rewritten as <loc####> tokens.
        cleaned = turn_text.replace(self.sep_token, ' ').replace('<image>', '')
        cleaned = self.bbox_pattern.sub(self._bbox_to_loc_tokens, cleaned)
        token_ids: List[int] = self.tokenizer(
            cleaned, padding=False, add_special_tokens=False, return_attention_mask=False
        )['input_ids']
        # Even turns are the user's: excluded from the loss.
        if turn_idx % 2 == 0:
            labels = [IGNORE_INDEX] * len(token_ids)
        else:
            labels = list(token_ids)
        input_ids.extend(token_ids)
        target_ids.extend(labels)
        if turn_idx != last_turn:
            # Separator between turns, never supervised.
            input_ids.append(self.sep_id)
            target_ids.append(IGNORE_INDEX)

    # <bos> ... <eos>; only <eos> is supervised.
    input_ids = [self.bos_id] + input_ids + [self.eos_id]
    target_ids = [IGNORE_INDEX] + target_ids + [self.eos_id]
    # Image placeholder tokens prefix the whole text sequence.
    prefix = self.image_tokens
    input_ids = prefix + input_ids
    target_ids = [IGNORE_INDEX] * len(prefix) + target_ids
    input_ids = torch.tensor(input_ids, dtype=torch.int64)
    target_ids = torch.tensor(target_ids, dtype=torch.int64)

    image_tensors: Dict[str, torch.Tensor] = {}
    for camera_name, camera_images in images.items():
        pixel_values = self.hf_processor.image_processor(
            camera_images, size=self.config.image_sizes[camera_name].as_json(), return_tensors='pt'
        )['pixel_values']
        image_tensors[f'{camera_name}.siglip'] = pixel_values

    seq_len = len(input_ids)
    attn_mask = make_causal_mask([seq_len, seq_len])
    # Prefix-LM: positions masked out of the loss (the prompt) attend
    # bi-directionally among themselves.
    attn_mask = enable_full_attn_blocks(attn_mask, full_attn=target_ids == IGNORE_INDEX)
    return {
        'input_ids': input_ids,
        'target_ids': target_ids,
        'images': image_tensors,
        'attn_mask': attn_mask,
    }
def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
    """Tokenizer owned by the wrapped HF PaliGemma processor."""
    wrapped = self.hf_processor
    return wrapped.tokenizer
def _bbox_to_loc_tokens(match: 're.Match[str]') -> str:
    """
    Rewrite a matched float bounding box '[x, y, x, y]' into PaliGemma
    <locDDDD> tokens (each coordinate scaled to [0, 1023]).
    https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/
    """
    # NOTE(review): no `self` parameter yet called as `self._bbox_to_loc_tokens`
    # in preprocess_inputs — presumably a @staticmethod in the full file; confirm.
    floats = list(map(float, match.groups()))
    # Scale each normalized coordinate into the 1024-bin <loc> vocabulary,
    # clamped to the valid [0, 1023] range, zero-padded to 4 digits.
    transformed = [f'<loc{np.clip(round(num * 1024), 0, 1023):04d}>' for num in floats]
    return f"[{', '.join(transformed)}]"
def image_sizes(self) -> Dict[str, ImageSizeConfig]:
    """Per-camera image size configuration, straight from the config."""
    return self.config.image_sizes