| import collections |
| import collections.abc |
| import importlib |
| import re |
| import warnings |
| from abc import abstractmethod |
| from functools import cached_property |
| from typing import Dict, List, Optional, Sequence, Tuple, TypeVar |
|
|
| import numpy as np |
| import PIL.Image |
| import roma |
| import torch |
| import torchvision.transforms.v2 |
| import transformers |
| import yaml |
|
|
| try: |
| from .hf_compat import Configurable, Template |
| except Exception: |
| _hf_compat = importlib.import_module("src.hf_compat") |
| Configurable = _hf_compat.Configurable |
| Template = _hf_compat.Template |
|
|
| from .common_vlarm import ( |
| Normalization, |
| ResizeMode, |
| RoboticsControlPlan, |
| RoboticsInput, |
| RoboticsOutput, |
| RoboticsTarget, |
| RotationFormat, |
| expand_dims, |
| ) |
| from .configuration_vlarm import ( |
| ControlDataIOConfig, |
| ControlTokenizerConfig, |
| EmptyTokenizerConfig, |
| ImageSizeConfig, |
| PaliGemmaProcessorConfig, |
| RegressionProcessorConfig, |
| VLAMProcessorConfig, |
| VLARMProcessorConfig, |
| VLMProcessorConfig, |
| ) |
|
|
# Config type parameter for ControlTokenizer subclasses; bound so only
# ControlTokenizerConfig subclasses are accepted as the generic argument.
ControlTokenizerConfigT = TypeVar('ControlTokenizerConfigT', bound=ControlTokenizerConfig)
|
|
|
|
class ControlTokenizer(Configurable[ControlTokenizerConfigT], Template[ControlTokenizerConfigT]):
    """Abstract base for tokenizers that render control targets as text."""

    @abstractmethod
    def __call__(self, *args, **kwargs) -> str:
        """Given GT actions and possibly other information, output text control. Gets appended to the prompt"""
|
|
|
|
class EmptyTokenizer(ControlTokenizer[EmptyTokenizerConfig]):
    """
    Control tokenizer that emits no text control: `__call__` always returns an
    empty string, regardless of its arguments.

    NOTE(review): the previous docstring described concatenating LLM hidden states
    from `llm_layer_indices`, which does not match this implementation.
    """

    def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None:
        super().__init__(config)
        # Stored for interface parity with other control tokenizers; not used by __call__.
        self.tokenizer = tokenizer

    def __call__(self, *_) -> str:
        return ''
|
|
|
|
def np_unique(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Return the unique elements of `data` in order of first occurrence.

    Plain `np.unique` returns values in sorted order even when the source is not
    sorted, so its indices would not line up with the original ordering. This
    helper runs a second `np.unique` pass over first-occurrence indices to recover
    ordering consistent with `data`.

    Returns:
        Tuple of (unique values in first-occurrence order, index of the first
        occurrence of each unique value, inverse indices mapping each element of
        `data` onto the unique values, per-value occurrence counts)
    """
    # First pass: map every element to the index of its first occurrence in `data`.
    (_, sorted_first_idx, sorted_inverse) = np.unique(data, return_index=True, return_inverse=True)
    first_occurrence_per_element = sorted_first_idx[sorted_inverse]
    # Second pass: unique over those indices is naturally sorted by first occurrence,
    # which restores the original (appearance) order.
    (_, first_idx, inverse_idx, counts) = np.unique(
        first_occurrence_per_element, return_index=True, return_inverse=True, return_counts=True
    )
    return data[first_idx], first_idx, inverse_idx, counts
|
|
|
|
def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor:
    """
    Convert 'xyz' Euler angles (radians) to rotation matrices.

    Args:
        angles: Euler angles in radians, 'xyz' convention, shape [..., 3]
    Returns:
        Rotation matrices, shape [..., 3, 3]
    """
    rotmats = roma.euler_to_rotmat(convention='xyz', angles=angles, degrees=False)
    return rotmats
|
|
|
|
def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor:
    """
    Convert 'xyz' Euler angles (radians) to unit quaternions.

    Args:
        angles: Euler angles in radians, 'xyz' convention, shape [..., 3]
    Returns:
        Unit quaternions, shape [..., 4]
    """
    unit_quats = roma.euler_to_unitquat(convention='xyz', angles=angles, degrees=False, normalize=True)
    return unit_quats
|
|
|
|
def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor:
    """
    Scale quaternions to (approximately) unit norm.

    Args:
        quaternion: Unnormalized quaternions, torch.Tensor of shape [..., 4]
        eps: Small constant added to the norm to prevent division by zero
    Returns:
        torch.Tensor of shape [..., 4] of (approximately) unit quaternions
    """
    norms = quaternion.norm(dim=-1, keepdim=True)
    return quaternion / (norms + eps)
|
|
|
|
def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to 'xyz' Euler angles in radians.

    Args:
        quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
            (it is renormalized before conversion)
    Returns:
        torch.Tensor of shape [..., 3] containing Euler angles in radians.
        (Fixed docstring: previously claimed [..., 3, 3] rotation matrices, but
        `roma.unitquat_to_euler` is an Euler-angle conversion.)
    """
    unit_quat = normalize_quaternion(quaternion)
    # Renamed local from `rotmat`: the result is Euler angles, not a matrix.
    euler = roma.unitquat_to_euler(convention='xyz', quat=unit_quat, as_tuple=False, degrees=False)
    return euler
|
|
|
|
def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions (renormalized first) to rotation matrices.

    Args:
        quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized
    Returns:
        torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3)
    """
    return roma.unitquat_to_rotmat(normalize_quaternion(quaternion))
|
|
|
|
def is_quaternion(quaternion: torch.Tensor) -> bool:
    """Shape-only check: True when the trailing dimension has quaternion size (4)."""
    last_dim_size = quaternion.size(-1)
    return last_dim_size == 4
|
|
|
|
def is_unit_quaternion(
    quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = 'none'
) -> torch.Tensor | bool:
    """
    Check if a quaternion is normalized or not.
    Args:
        quaternion: torch.Tensor of shape [..., 4]
        epsilon: Absolute tolerance for the norm comparison
        reduction:
            'none' - returns torch.Tensor of bools with the same batch shape
            'all' - returns a bool, True if ALL quaternions in the batch are normalized
    Returns:
        torch.Tensor with the same batch shape or bool
    Raises:
        ValueError: if `reduction` is neither 'none' nor 'all'
    """
    assert is_quaternion(quaternion)
    # Elementwise |‖q‖ - 1| <= epsilon; keepdim keeps a trailing size-1 dim.
    is_norm = torch.isclose(
        quaternion.norm(dim=-1, keepdim=True),
        torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device),
        atol=epsilon,
    )
    if reduction == 'none':
        return is_norm
    if reduction == 'all':
        return bool(torch.all(is_norm).item())
    raise ValueError(f'Unknown reduction mode {reduction}')
|
|
|
|
def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Flip quaternion signs so the batch covers only half of quaternion space.

    If q_w is negative, flip the quaternion; if q_w is 0, choose the sign so the
    first non-zero of (q_x, q_y, q_z) is positive. Note that geometrically this
    doesn't correspond to a single hemisphere of the unit sphere. Follows
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat
    """
    assert is_quaternion(quaternion), quaternion.shape
    with torch.no_grad():
        # Components in (x, y, z, w) layout; w is last.
        x = quaternion[..., 0:1]
        y = quaternion[..., 1:2]
        z = quaternion[..., 2:3]
        w = quaternion[..., 3:4]
        (x_is_zero, y_is_zero, w_is_zero) = (x == 0, y == 0, w == 0)
        # Lexicographic sign rule starting from w, then x, y, z.
        flip = (
            (w < 0)
            | (w_is_zero & (x < 0))
            | (w_is_zero & x_is_zero & (y < 0))
            | (w_is_zero & x_is_zero & y_is_zero & (z < 0))
        )
        quaternion = torch.where(flip, -quaternion, quaternion)
    return quaternion
|
|
|
|
def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor:
    """Return rotation matrices in [..., 3, 3] layout (accepts [..., 9] or [..., 3, 3])."""
    if tuple(rotmat.shape[-2:]) == (3, 3):
        return rotmat
    if rotmat.size(-1) == 9:
        return rotmat.reshape(*rotmat.shape[:-1], 3, 3)
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
|
|
|
|
def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to 'xyz' Euler angles in radians.

    Args:
        rotmat: Batch of rotation matrices, shape [..., 3, 3] (or [..., 9])
    Returns:
        Batch of Euler angles in radians, shape [..., 3]
    """
    as_3x3 = rotmat_as_3x3(rotmat)
    return roma.rotmat_to_euler(convention='xyz', rotmat=as_3x3, as_tuple=False, degrees=False)
|
|
|
|
def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to unit quaternions.

    Args:
        rotmat: Batch of rotation matrices, shape [..., 3, 3] (or [..., 9])
    Returns:
        Batch of unit quaternions, shape [..., 4]
    """
    as_3x3 = rotmat_as_3x3(rotmat)
    return roma.rotmat_to_unitquat(as_3x3)
|
|
|
|
def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor:
    """
    Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
    - Let SVD(M) = U \Sigma V^T
    - Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T
    - det(UV^T) ensures that det(SVD+(M)) = 1
    - The return value is a rotation matrix (orthonormal) with the least-squares distance to M

    Args:
        x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3]
    Returns:
        torch.Tensor with the same shape as x, where each inner 3x3 matrix is in SO(3)
    """
    with warnings.catch_warnings():
        # The float32 autocast below is deliberate (SVD is unreliable in reduced
        # precision); silence torch's CPU-autocast warning about it.
        warnings.filterwarnings(
            'ignore', message='In CPU autocast, but the target dtype is not supported. Disabling autocast.'
        )
        with torch.autocast(device_type=x.device.type, dtype=torch.float32):
            # Flatten all batch dims and compute the SVD in float32.
            matrices = x.view(-1, 3, 3)
            matrices = matrices.to(dtype=torch.float32)
            (u, s, v) = torch.svd(matrices)
            vt = torch.transpose(v, 1, 2)
            # det(U V^T) is +-1; scaling the last row of V^T by it forces the
            # product's determinant to +1 (proper rotation, not a reflection).
            det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1)
            diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1)
            result = torch.matmul(u, diag_vt)
            # Restore the caller's original shape and dtype.
            result = result.view(*x.shape)
            result = result.to(dtype=x.dtype)
    return result
|
|
|
|
def is_rotmat_3x3(rotmat: torch.Tensor) -> bool:
    """Shape-only check: True when the trailing two dimensions are [3, 3]."""
    return tuple(rotmat.shape[-2:]) == (3, 3)
|
|
|
|
def is_rotmat_9(rotmat: torch.Tensor) -> bool:
    """Shape-only check: True when the trailing dimension is 9 (flattened 3x3 matrix)."""
    last_dim_size = rotmat.size(-1)
    return last_dim_size == 9
|
|
|
|
def is_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Checks if the tensor shape matches that of a rotmat ([..., 3, 3] or [..., 9]).
    It's not guaranteed the data is a valid rotmat; `is_orthonormal_rotmat` performs
    that additional check.
    NOTE: This might incorrectly return True if the underlying data is euler angles and
    accidentally `rotmat.shape[-2:] == [3, 3]`. This would happen very rarely, but use
    with caution.
    """
    if is_rotmat_3x3(rotmat):
        return True
    return is_rotmat_9(rotmat)
|
|
|
|
def is_rotmat_orthonormal(
    rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = 'none'
) -> torch.Tensor | bool:
    """
    Check if a rotation matrix is orthonormal or not.
    Args:
        rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9]
        epsilon: Tolerance for numerical comparisons. Bigger values allow for more freedom. Generally,
            anything smaller than 1e-6 might incorrectly detect some orthonormal matrices as not
        reduction:
            'none' - returns torch.Tensor of bools with the same batch shape
            'all' - returns a bool, True if ALL matrices in the batch are orthonormal
    Returns:
        torch.Tensor with the same batch shape or bool
    Raises:
        ValueError: if `reduction` is neither 'none' nor 'all'
    """
    assert is_rotmat(rotmat)
    # Cast to float32 so the orthonormality test is not dominated by low-precision noise.
    rotmat = rotmat_as_3x3(rotmat.to(dtype=torch.float32))
    is_orthonormal = roma.is_orthonormal_matrix(rotmat, epsilon=epsilon)
    if reduction == 'none':
        return is_orthonormal
    if reduction == 'all':
        return bool(torch.all(is_orthonormal).item())
    raise ValueError(f'Unknown reduction mode {reduction}')
|
|
|
|
def is_orthonormal_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Checks if the tensor shape matches that of a rotmat. If the last dimensions of shape
    are 3x3, also checks if the data is a valid (orthonormal) rotmat. This is to avoid a
    possible clash with euler angles when accidentally `rotmat.shape[-2:] == [3, 3]`.
    """
    if is_rotmat_9(rotmat):
        return True
    # Same as the original `or`/`and` precedence: a [..., 3, 3] tensor must also
    # pass the numeric orthonormality test.
    return is_rotmat_3x3(rotmat) and is_rotmat_orthonormal(rotmat, epsilon=0.02, reduction='all')
|
|
|
|
def is_euler(euler: torch.Tensor) -> bool:
    """True when the tensor plausibly holds Euler angles: trailing dim 3 and not an orthonormal rotmat."""
    if euler.shape[-1] != 3:
        return False
    return not is_orthonormal_rotmat(euler)
|
|
|
|
def rotation_format_from_tensor(rotation: torch.Tensor) -> RotationFormat:
    """
    Infer the rotation encoding of a tensor from its shape (and, for [..., 3, 3]
    candidates, from its values via the orthonormality check).

    The checks run in priority order -- quaternion ([..., 4]), rotation matrix
    ([..., 9] or valid [..., 3, 3]), Euler angles ([..., 3]) -- because they are
    mostly shape-based and not mutually exclusive.

    Raises:
        ValueError: if the shape matches no known rotation encoding
    """
    if is_quaternion(rotation):
        return RotationFormat.QUATERNION
    if is_orthonormal_rotmat(rotation):
        return RotationFormat.ROTMAT
    if is_euler(rotation):
        return RotationFormat.EULER
    raise ValueError(f'Tensor shape {rotation.shape} is not a valid rotation format')
|
|
|
|
def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor:
    """Return rotation matrices flattened to [..., 9] (accepts [..., 9] or [..., 3, 3])."""
    if is_rotmat_3x3(rotmat):
        return rotmat.reshape(*rotmat.shape[:-2], 9)
    if is_rotmat_9(rotmat):
        return rotmat
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
|
|
|
|
def convert_rotation(
    rotation: torch.Tensor | np.ndarray,
    output_format: RotationFormat,
    autonorm: bool = True,
    half_cover: bool = True,
) -> torch.Tensor | np.ndarray:
    """
    Convert rotations between quaternion ([..., 4]), rotation-matrix ([..., 9]) and
    Euler ([..., 3]) encodings.

    The input encoding is inferred from the tensor shape; the branches below are
    checked in a fixed order (quaternion, rotmat, Euler) since the shape checks are
    not mutually exclusive -- do not reorder them.

    Args:
        rotation: Input rotation(s); numpy in -> numpy out, torch in -> torch out
        output_format: Target encoding
        autonorm: Renormalize inputs that are not unit quaternions / not orthonormal matrices
        half_cover: For quaternion output, canonicalize signs via `quaternion_half_cover`
    Returns:
        Converted rotation(s); rotation matrices are returned flattened as [..., 9]
    """
    is_np = isinstance(rotation, np.ndarray)
    if is_np:
        rotation = torch.from_numpy(rotation)
    if is_quaternion(rotation):
        # Renormalize only when needed -- avoids perturbing already-unit quaternions.
        if autonorm and not is_unit_quaternion(rotation, reduction='all'):
            rotation = normalize_quaternion(rotation)
        if output_format == RotationFormat.QUATERNION:
            output = rotation
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(quaternion_to_rotmat(rotation))
        elif output_format == RotationFormat.EULER:
            output = quaternion_to_euler(rotation)
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    elif is_orthonormal_rotmat(rotation):
        # Project near-rotmats back onto SO(3); looser epsilon than the detection check.
        if autonorm and not is_rotmat_orthonormal(rotation, epsilon=0.01, reduction='all'):
            rotation = symmetric_orthogonalization(rotation)
        if output_format == RotationFormat.QUATERNION:
            output = rotmat_to_unit_quaternion(rotation)
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(rotation)
        elif output_format == RotationFormat.EULER:
            output = rotmat_to_euler(rotation)
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    elif is_euler(rotation):
        if output_format == RotationFormat.QUATERNION:
            output = euler_to_unit_quaternion(rotation)
        elif output_format == RotationFormat.ROTMAT:
            output = rotmat_as_9(euler_to_rotmat(rotation))
        elif output_format == RotationFormat.EULER:
            output = rotation
        else:
            raise NotImplementedError(f'Unsupported rotation format: {output_format}')
    else:
        raise ValueError(f'Unknown rotation encoding with shape {rotation.shape}')
    # Sign canonicalization applies to quaternion outputs regardless of input format.
    if output_format == RotationFormat.QUATERNION and half_cover:
        output = quaternion_half_cover(output)
    if is_np:
        output = output.numpy()
    return output
|
|
|
|
def delta_to_world_rotations(rotation_sequence: torch.Tensor, reference_frame: torch.Tensor) -> torch.Tensor:
    """
    Transform a sequence of rotation representations encoded w.r.t. `reference_frame` to encoding w.r.t.
    WORLD frame, where `reference_frame` is provided w.r.t. WORLD frame

    Ex:
        Sequence of points (rotations): R_1, R_2, R_3, R_4
        `rotation_sequence` contains the rotations: R_01, R_02, R_03, R_04, where R_0 is the reference frame
        and R_01 is the pose of R1 frame in reference frame, i.e. R_10 converts from reference frame to R1 frame
        Output: R_W1, R_W2, R_W3, R_W4, where W is the world frame, i.e. the rotation poses of
        R_1, R_2, R_3, R_4 expressed in world frame

    Args:
        rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing
            either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions
        reference_frame: torch.Tensor, shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] and the SAME number of BATCH
            dims as `rotation_sequence`. The reference frame, provided w.r.t. WORLD coordinate frame R_W0,1,2,3
    Returns:
        torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] containing transformed rotations
        (R_W1, R_W2, R_W3, R_W4, ...)
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    # Remember the input encoding so the result can be converted back at the end.
    rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    # Work in matrix form: frame composition becomes a batched matmul.
    reference_frame = rotmat_as_3x3(convert_rotation(reference_frame, RotationFormat.ROTMAT))
    rotation_sequence = rotmat_as_3x3(convert_rotation(rotation_sequence, RotationFormat.ROTMAT))
    if reference_frame.ndim != rotation_sequence.ndim:
        raise ValueError(
            f'Cannot broadcast reference_frame of shape {reference_frame.shape} to rotation_sequence of shape {rotation_sequence.shape}. Provide tensors with the same number of batch dimensions'
        )
    # R_Wi = R_W0 @ R_0i: left-multiplying by the reference frame re-expresses each
    # rotation in world coordinates.
    R_W0 = reference_frame
    world_rotations = torch.matmul(R_W0, rotation_sequence)
    # Flatten to [..., 9] before converting back to the caller's original format.
    world_rotations = world_rotations.view(*world_rotations.shape[:-2], 9)
    world_rotations = convert_rotation(world_rotations, rotation_format)
    return world_rotations
|
|
|
|
def delta_to_relative_rotations(rotation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Transform a sequence of rotation representations encoded w.r.t. the PREVIOUS rotation frame in the
    sequence to the 0-th element preceding the sequence

    Ex:
        `rotation_sequence` contains the rotations: R_01, R_12, R_23, R_34, where R0 is the base frame,
        implicitly encoded in R_01 and R_10 converts from R0 frame to R1 frame
        Output: R_01, R_02, R_03, R_04

    Args:
        rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing
            either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions
    Returns:
        torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] containing transformed rotations
        (R_01, R_02, R_03, R_04, ...)

    TODO: Can you make it work without for loop
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    # Remember the input encoding so the result can be converted back at the end.
    rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    # Work with quaternions so prefixes can be composed with roma.quat_composition.
    rotation_sequence = convert_rotation(rotation_sequence, RotationFormat.QUATERNION)
    batch_dims = np.arange(rotation_sequence.ndim - 2)
    # i-th output composes the first i deltas: R_0i = R_01 * R_12 * ... * R_(i-1)i.
    # The permute moves the sequence axis to the front before composition (presumably
    # the layout roma.quat_composition expects -- verify against roma docs). Note the
    # prefix-by-prefix recomputation makes this O(S^2) in the sequence length.
    delta_rotations = torch.cat(
        [rotation_sequence[..., :1, :]]
        + [
            roma.quat_composition(rotation_sequence[..., :i, :].permute(-2, *batch_dims, -1).unsqueeze(-2))
            for i in range(2, rotation_sequence.shape[-2] + 1)
        ],
        dim=-2,
    )
    delta_rotations = convert_rotation(delta_rotations, rotation_format)
    return delta_rotations
|
|
|
|
def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray:
    """Coerce a PIL image to np.ndarray and assert HW or HWC layout (channels last, <= 4)."""
    if isinstance(image, PIL.Image.Image):
        image = np.asarray(image)
    assert isinstance(image, np.ndarray), type(image)
    assert image.ndim in (2, 3), image.shape
    if image.ndim == 3:
        # Channels-last with at most RGBA; anything larger is probably not HWC.
        assert image.shape[2] <= 4, image.shape
    return image
|
|
|
|
def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]:
    """Return (height, width) for a numpy HW(C) array or a PIL image (whose .size is (w, h))."""
    if isinstance(image, np.ndarray):
        return image.shape[0], image.shape[1]
    (width, height) = image.size
    return height, width
|
|
|
|
def pad_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Pad image with a centered constant border so it reaches exactly `target_size`.

    The border is symmetric up to rounding; when the size difference is odd, the
    extra pixel goes on the bottom/right edge. (Fixed: the previous version used
    `int(diff / 2)` on BOTH sides, which under-padded by one pixel for odd
    differences and never reached the requested target size.)

    Args:
        image: np.ndarray in HW/HWC layout, or a PIL image.
        target_size: Mapping with 'width' and 'height'; each must be >= the image size.
        pad_value: Constant fill value for the border.
    Returns:
        Padded image of the same type as the input.
    """
    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
    (height, width) = hw_from_image(image)
    (target_width, target_height) = (target_size['width'], target_size['height'])
    if width == target_width and height == target_height:
        return image
    assert target_width >= width, f"Can't pad image of width {width} to {target_width}"
    assert target_height >= height, f"Can't pad image of height {height} to {target_height}"
    left_pad = (target_width - width) // 2
    right_pad = target_width - width - left_pad
    top_pad = (target_height - height) // 2
    bottom_pad = target_height - height - top_pad
    if isinstance(image, np.ndarray):
        # Pad only the leading H and W axes; channel (and any further) axes untouched.
        padding = ((top_pad, bottom_pad), (left_pad, right_pad)) + ((0, 0),) * (image.ndim - 2)
        image = np.pad(image, padding, mode='constant', constant_values=pad_value)
    else:
        # torchvision expects padding as (left, top, right, bottom).
        image = torchvision.transforms.v2.functional.pad(
            image, padding=(left_pad, top_pad, right_pad, bottom_pad), fill=pad_value, padding_mode='constant'
        )
    return image
|
|
|
|
def pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """Pad image (centered border) until its width/height ratio matches `target_wh_ratio`."""
    (height, width) = hw_from_image(image)
    current_ratio = width / height
    if target_wh_ratio >= current_ratio:
        # Image is too tall for the target: grow the width.
        pad_target = {'width': round(height * target_wh_ratio), 'height': height}
    else:
        # Image is too wide for the target: grow the height.
        pad_target = {'width': width, 'height': round(width / target_wh_ratio)}
    return pad_image(image, target_size=pad_target, pad_value=pad_value)
|
|
|
|
def crop_image(
    image: np.ndarray | PIL.Image.Image,
    start_height: int,
    start_width: int,
    target_height: int,
    target_width: int,
) -> np.ndarray | PIL.Image.Image:
    """
    Crop a (target_height, target_width) window starting at (start_height, start_width).

    Returns the same image type as the input (PIL in -> PIL out).
    """
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = hw_from_image(image)
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    # Fixed error message: previously said "width" for the height check.
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    (start_height, start_width) = (round(start_height), round(start_width))
    (target_height, target_width) = (round(target_height), round(target_width))
    np_image = np_image[
        start_height : start_height + target_height, start_width : start_width + target_width, ...
    ]
    image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
    return image
|
|
|
|
def crop_image_center(
    image: np.ndarray | PIL.Image.Image, target_size: dict[str, int]
) -> np.ndarray | PIL.Image.Image:
    """
    Center-crop the image to `target_size` ({'height': ..., 'width': ...}).

    Returns the same image type as the input (PIL in -> PIL out).
    """
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = np_image.shape[:2]
    (target_height, target_width) = (target_size['height'], target_size['width'])
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    # Fixed error message: previously said "width" for the height check.
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    top = (height - target_height) // 2
    left = (width - target_width) // 2
    np_image = crop_image(np_image, top, left, target_height, target_width)
    image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
    return image
|
|
|
|
def crop_image_to_ratio(
    image: PIL.Image.Image | np.ndarray, target_wh_ratio: float
) -> PIL.Image.Image | np.ndarray:
    """Center-crop the image to a target width/height aspect ratio."""
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if target_wh_ratio >= wh_ratio:
        # Target is wider than the image: keep full width, reduce the height.
        crop_size = {'width': width, 'height': round(width / target_wh_ratio)}
    else:
        # Target is narrower: keep full height, reduce the width.
        crop_size = {'width': round(height * target_wh_ratio), 'height': height}
    image = crop_image_center(image, target_size=crop_size)
    return image
|
|
|
|
def crop_and_pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    mode: ResizeMode | str,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Crop and/or pad an image towards a target aspect ratio, depending on the mode.
    It's expected that the source image and the target ratio differ.

    Args:
        image: The image to crop and pad.
        target_wh_ratio: Target width/height aspect ratio.
        mode: SMART (limited crop towards the target, then pad), PAD or CROP.
        pad_value: Constant fill value used when padding.
    """
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001):
        return image
    if mode == ResizeMode.SMART:
        # Orientation-independent elongation of the image and the target.
        aspect_ratio = max(width, height) / min(width, height)
        target_ratio = max(target_wh_ratio, 1 / target_wh_ratio)
        # True when the image and the target differ in orientation (landscape vs
        # portrait). Fixed: previously written as the chained comparison
        # `wh_ratio >= 1.0 != (target_wh_ratio >= 1.0)`, which Python parses as
        # `(wh_ratio >= 1.0) and (1.0 != ...)` and thus never triggered for
        # portrait images with a landscape target.
        orientation_mismatch = (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0)
        if aspect_ratio == 1:
            # Square source: crop towards 4:3 (or 3:4) when the target is elongated.
            if target_ratio >= 4 / 3 - 0.01:
                crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
                image = crop_image_to_ratio(image, crop_wh_ratio)
        elif aspect_ratio <= 4 / 3 + 0.01:
            # Mildly elongated source: only crop (to square) on orientation mismatch.
            if orientation_mismatch:
                image = crop_image_to_ratio(image, 1.0)
        elif orientation_mismatch:
            # Strongly elongated source with the wrong orientation: crop to square.
            image = crop_image_to_ratio(image, 1.0)
        elif target_ratio >= 4 / 3 + 0.01:
            # Both strongly elongated in the same orientation: padding alone suffices.
            pass
        else:
            crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
            image = crop_image_to_ratio(image, crop_wh_ratio)
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.PAD:
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.CROP:
        image = crop_image_to_ratio(image, target_wh_ratio)
    else:
        raise ValueError(f'Mode {mode} not supported')
    return image
|
|
|
|
def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool:
    """True for single-channel images (grayscale/palette/binary PIL modes, or HW / HWx1 arrays)."""
    if isinstance(image, np.ndarray):
        if image.ndim == 2:
            return True
        return image.ndim == 3 and image.shape[2] == 1
    if isinstance(image, PIL.Image.Image):
        single_channel_modes = ('1', 'L', 'LA', 'La', 'P', 'PA', 'F', 'I', 'I;16', 'I;16L', 'I;16B', 'I;16N')
        return image.mode in single_channel_modes
    raise ValueError(f'Unsupported image type: {type(image)}')
|
|
|
|
def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool:
    """True when the image is uint8/bool data whose maximum value is exactly 1 (a 0/1 mask)."""
    array = np.asarray(image)
    if array.dtype not in (np.uint8, np.bool_):
        return False
    return np.max(array) == 1
|
|
|
|
def resize_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    mode: ResizeMode | str,
    resample: PIL.Image.Resampling | str = 'auto',
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Resize an image to `target_size` ({'width': ..., 'height': ...}).

    NAIVE resizes directly (possibly changing aspect ratio); SMART first crops/pads
    towards the target aspect ratio, then resizes. Binary masks are temporarily
    converted to 0/255 uint8 so interpolation is meaningful, then re-thresholded
    at 127.

    Args:
        resample: PIL resampling filter, or 'auto' (BILINEAR for single-channel
            images, LANCZOS otherwise).
        pad_value: Border fill used by SMART-mode padding.
    Raises:
        ValueError: for an explicit non-bilinear/bicubic filter on single-channel input.
        NotImplementedError: for unsupported resize modes.
    """
    (target_width, target_height) = (target_size['width'], target_size['height'])
    (height, width) = hw_from_image(image)
    if height == target_height and width == target_width:
        return image
    if resample == 'auto':
        if is_single_channel_image(image):
            resample = PIL.Image.Resampling.BILINEAR
        else:
            resample = PIL.Image.Resampling.LANCZOS
    else:
        assert isinstance(resample, PIL.Image.Resampling), resample
        if is_single_channel_image(image) and resample not in [
            PIL.Image.Resampling.BILINEAR,
            PIL.Image.Resampling.BICUBIC,
        ]:
            raise ValueError(
                f'Single channel images must be resized with bilinear or bicubic, but got {resample}'
            )
    # NOTE(review): for a PIL input that is a binary mask, `image` is rebound to a
    # numpy array here, so the final conversion below returns a numpy array instead
    # of PIL -- confirm callers expect this.
    if is_bin_mask := is_binary_mask(image):
        image = np.asarray(image).astype(np.uint8) * 255
    if mode == ResizeMode.SMART:
        image = crop_and_pad_image_to_ratio(
            image, target_wh_ratio=target_width / target_height, mode=mode, pad_value=pad_value
        )
    pil_image = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image
    if mode in [ResizeMode.NAIVE, ResizeMode.SMART]:
        pil_image = pil_image.resize((target_width, target_height), resample=resample)
    else:
        raise NotImplementedError(f'Mode {mode} not supported')
    image = np.asarray(pil_image) if isinstance(image, np.ndarray) else pil_image
    if is_bin_mask:
        # Threshold back to a boolean mask after interpolation.
        image = image.astype(np.uint8) > 127
    return image
|
|
|
|
def is_global_norm(norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list]) -> bool:
    """Return true if norm is NONE or a single global mapping shared by all datasets."""
    if isinstance(norm, collections.abc.Mapping):
        return True
    return norm == Normalization.NONE
|
|
|
|
def is_mean_norm(norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list]) -> bool:
    """Return true if norm is mean/std-based (enum MEAN, or a mapping with exactly 'mean' and 'std')."""
    if isinstance(norm, collections.abc.Mapping):
        return set(norm.keys()) == {'mean', 'std'}
    return norm == Normalization.MEAN
|
|
|
|
def _broadcast_shapes(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Broadcast shapes for normalization:
    Args:
        value: torch.Tensor of shape [..., num_components]. The entire shape might be:
            - [num_components]: `value` has no batch dimension
            - [num_datasets, num_components]: `value` contains entries *aligned* with the dataset bounds
              contained in `low` and `high`
            - [num_datasets, ..., num_components]: `value` contains entries *aligned* with the dataset bounds
              contained in `low` and `high`
            - [..., num_components]: `value` contains multiple dimensions. In this case, `low` and `high`
              must be for a single dataset, i.e. `num_datasets = 1`

        low: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `low`
            contains normalization bounds for a single dataset
        high: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `high`
            contains normalization bounds for a single dataset
    Returns:
        Tuple of torch.Tensors (low, high), where `low` and `high` have the same number of dimensions as `value`
    """
    assert low.ndim == high.ndim == 2, f'{low.shape} != {high.shape} or ndim != 2'
    assert value.shape[-1] == low.shape[-1] == high.shape[-1], f'{value.shape} != {low.shape} / {high.shape}'
    # Case [num_datasets, num_components]: already aligned with the bounds.
    if value.ndim == low.ndim == high.ndim:
        return low, high
    # Case [num_components]: bounds must be for a single dataset; drop their dataset dim.
    if value.ndim < low.ndim:
        assert low.ndim == high.ndim == 2, f'{low.shape}, {high.shape}'
        assert low.shape[0] == high.shape[0] == 1, f'{low.shape}, {high.shape}'
        (low, high) = (low.view(-1), high.view(-1))
        return low, high
    # Case single dataset, multi-dim value: broadcast [1, C] across value's leading dims.
    if low.shape[0] == high.shape[0] == 1:
        low = expand_dims(low.view(-1), ndim=value.ndim, order=[-1, 1])
        high = expand_dims(high.view(-1), ndim=value.ndim, order=[-1, 1])
    else:
        # Case per-dataset bounds: keep dim 0 aligned with value's leading (dataset)
        # dim and insert singleton dims in between.
        assert value.shape[0] == low.shape[0] == high.shape[0], f'{value.shape} != {low.shape} / {high.shape}'
        low = expand_dims(low, ndim=value.ndim, order=[1, -1, 1])
        high = expand_dims(high, ndim=value.ndim, order=[1, -1, 1])
    return low, high
|
|
|
|
def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Invert `normalize_by_moments`: map standardized values back via value * (std + 1e-8) + mean."""
    (mean, std) = _broadcast_shapes(value, mean, std)
    mean = mean.to(device=value.device)
    std = std.to(device=value.device)
    return value * (std + 1e-08) + mean
|
|
|
|
def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Invert `normalize_by_bounds`: map values from [-1, 1] back into [low, high]."""
    (low, high) = _broadcast_shapes(value, low, high)
    low = low.to(device=value.device)
    high = high.to(device=value.device)
    return low + 0.5 * (value + 1) * (high - low)
|
|
|
|
def normalize_gripper_by_bounds(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True
) -> torch.Tensor:
    """
    Normalize raw gripper values into a fixed range, clamping out-of-bounds inputs.
    If binary, normalize to [0, 1], otherwise normalize to [-1, 1]
    """
    (low, high) = _broadcast_shapes(value, low, high)
    (low, high) = (low.to(device=value.device), high.to(device=value.device))
    # Fraction of the way from `low` to `high`; the clamp on the span avoids 0-division.
    fraction = (value - low) / torch.clamp(high - low, min=1e-08)
    if binary:
        return torch.clamp(fraction, min=0.0, max=1.0)
    return torch.clamp(2 * fraction - 1, min=-1.0, max=1.0)
|
|
|
|
def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Standardize values: (value - mean) / (std + 1e-8), with mean/std broadcast to value's shape."""
    (mean, std) = _broadcast_shapes(value, mean, std)
    (mean, std) = (mean.to(device=value.device), std.to(device=value.device))
    centered = value - mean
    return centered / (std + 1e-08)
|
|
|
|
def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Scale values from [low, high] into [-1, 1], clamping out-of-bounds inputs."""
    (low, high) = _broadcast_shapes(value, low, high)
    (low, high) = (low.to(device=value.device), high.to(device=value.device))
    span = torch.clamp(high - low, min=1e-08)
    return torch.clamp(2 * (value - low) / span - 1, min=-1.0, max=1.0)
|
|
|
|
def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray:
    """
    Flip the open/close direction of a gripper signal within [low, high].

    When the range includes negative values, the signal is negated (and clipped);
    otherwise it is mirrored around the range as `high - clip(gripper)`.
    """
    if low >= 0.0:
        return high - np.clip(gripper, low, high)
    return np.clip(-gripper, low, high)
|
|
|
|
# Per-dataset (low, high) range of the raw gripper observation, matching the
# per-dataset notes in `preprocess_gripper_observation` below (e.g. Franka grippers
# report ~[0, 0.08] meters, jaco_play [0, 1.4]). Presumably consumed by that
# function to normalize gripper values -- its body is not fully visible here.
GRIPPER_BOUNDS = {
    'austin_buds_dataset': (0.0, 0.08),
    'austin_sailor_dataset': (0.0, 0.08),
    'austin_sirius_dataset': (0.0, 0.08),
    'bc_z': (0.0, 1.0),
    'berkeley_autolab_ur5': (0.0, 1.0),
    'berkeley_cable_routing': (0.0, 1.0),
    'berkeley_fanuc_manipulation': (0.0, 1.0),
    'bridge': (0.0, 1.0),
    'bridge_orig': (0.0, 1.0),
    'bridge_action_lang': (0.0, 1.0),
    'cmu_stretch': (-3.0, 3.0),
    'dlr_edan_shared_control': (0.0, 1.0),
    'droid': (0.0, 1.0),
    'fmb': (0.0, 1.0),
    'fractal20220817_data': (0.0, 1.0),
    'furniture_bench_dataset': (0.0, 0.08),
    'iamlab_cmu_pickup_insert': (0.0, 1.0),
    'jaco_play': (0.0, 1.4),
    'kuka': (0.0, 1.0),
    'language_table': (0.0, 1.0),
    'nyu_franka_play_dataset': (0.0, 1.0),
    'roboset': (0.0, 1.0),
    'roboturk': (0.0, 1.0),
    'stanford_hydra_dataset': (0.0, 0.08),
    'taco_play': (0.0, 0.08),
    'toto': (0.0, 1.0),
    'ucsd_kitchen_dataset': (0.0, 1.0),
    'utaustin_mutex': (0.0, 0.08),
    'viola': (0.0, 0.08),
}
|
|
|
|
def preprocess_gripper_observation(
    gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True
) -> np.ndarray:
    """
    Preprocess gripper observation depending on dataset. Input is the raw gripper observation from the dataset
    or from the robot and output is normalized continuous value.
    - if `binary`, output is in [0, 1], with 0 = closed and 1 = open.
    - otherwise, output is in [-1, 1], with -1 = closed and 1 = open.

    Dataset-specific gripper observations:
        austin_buds_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        austin_sailor_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        austin_sirius_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        bc_z: continuous; [0=open; 1=closed]
        berkeley_autolab_ur5: binary; [0=open; 1=closed]
        berkeley_cable_routing: constant (closed)
        berkeley_fanuc_manipulation: binary; [0=open; 1=closed]
        bridge: continuous; ~[0=closed; 1=open]
        bridge_orig: continuous; ~[0=closed; 1=open]
        cmu_stretch: continuous; [-3=closed; 3=open]
        dlr_edan_shared_control: missing
        droid: continuous; [0=open, 1=closed]
        fmb: binary; [0=open; 1=closed]
        fractal20220817_data: continuous; [0=open; 1=closed]
        furniture_bench_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        iamlab_cmu_pickup_insert: binary; [0=closed; 1=open]
        jaco_play: continuous; [0=open; 1.4=closed]
        kuka: binary; [0=open; 1=closed]
        language_table: constant (no gripper)
        nyu_franka_play_dataset: missing
        roboset: continuous; [0=open, 1=closed]
        roboturk: continuous; [0=closed, 0.04=open]
        stanford_hydra_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        taco_play: continuous; ~[0=closed; 0.08=open] (franka gripper)
        toto: constant (closed)
        ucsd_kitchen_dataset: missing
        utaustin_mutex: continuous; ~[0=closed; 0.08=open] (franka gripper)
        viola: continuous; ~[0=closed; 0.08=open] (franka gripper)

    NOTE(review): the listing above says 'roboturk' spans ~[0; 0.04] while
    GRIPPER_BOUNDS['roboturk'] is (0.0, 1.0) — confirm the bound is intentional.
    """
    if isinstance(dataset_name, np.ndarray):
        # A batched name array must be homogeneous; collapse it to a single str.
        assert np.unique(dataset_name).size == 1, dataset_name
        dataset_name = str(dataset_name[0])
    # Datasets whose raw convention is reversed (larger value = more closed): the
    # observation is flipped with `invert_gripper` before bounds normalization.
    inverted_datasets = {
        'bc_z',
        'berkeley_autolab_ur5',
        'berkeley_fanuc_manipulation',
        'droid',
        'fmb',
        'fractal20220817_data',
        'jaco_play',
        'kuka',
        'roboset',
    }
    # Datasets already following the closed -> open increasing convention, plus
    # datasets with constant/missing grippers that only need bounds normalization.
    direct_datasets = {
        'austin_buds_dataset',
        'austin_sailor_dataset',
        'austin_sirius_dataset',
        'berkeley_cable_routing',
        'bridge',
        'bridge_orig',
        'bridge_action_lang',
        'cmu_stretch',
        'dlr_edan_shared_control',
        'furniture_bench_dataset',
        'iamlab_cmu_pickup_insert',
        'language_table',
        'nyu_franka_play_dataset',
        'roboturk',
        'stanford_hydra_dataset',
        'taco_play',
        'toto',
        'ucsd_kitchen_dataset',
        'utaustin_mutex',
        'viola',
    }
    if dataset_name not in inverted_datasets and dataset_name not in direct_datasets:
        raise NotImplementedError(f'Unknown dataset: {dataset_name}')
    (low, high) = GRIPPER_BOUNDS[dataset_name]
    if dataset_name in inverted_datasets:
        # `invert_gripper` clips into [low, high] and preserves the array shape.
        gripper = invert_gripper(gripper, low=low, high=high)
    return normalize_gripper_by_bounds(
        torch.from_numpy(gripper),
        low=torch.full(gripper.shape, low, dtype=torch.float32),
        high=torch.full(gripper.shape, high, dtype=torch.float32),
        binary=binary,
    ).numpy()
|
|
|
|
def rotation_norm_bounds(
    rotation_norm: Normalization,
    rotation_format: RotationFormat,
    stats: Dict[str, Dict[str, Dict[str, List[float]]]],
    dataset_names: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """Per-dataset 'low'/'high' normalization bounds for the rotation component.

    Euler rotations may use stats-derived bounds (min/max or q01/q99); any other
    format — or Normalization.NONE — gets fixed [-1, 1] bounds sized per format.
    """
    stats_driven = rotation_format == RotationFormat.EULER and rotation_norm != Normalization.NONE
    if stats_driven:
        if rotation_norm == Normalization.BOUNDS:
            (low_key, high_key) = ('min', 'max')
        elif rotation_norm == Normalization.BOUNDS_Q99:
            (low_key, high_key) = ('q01', 'q99')
        else:
            raise NotImplementedError(f'Normalization type {rotation_norm} not yet implemented')
        return {
            name: {
                'low': torch.tensor(dataset_stats['euler'][low_key]),
                'high': torch.tensor(dataset_stats['euler'][high_key]),
            }
            for (name, dataset_stats) in stats.items()
        }
    assert rotation_norm == Normalization.NONE, rotation_norm
    # Component count per rotation format: euler=3, quaternion=4, otherwise rotmat=9.
    if rotation_format == RotationFormat.EULER:
        rotation_size = 3
    elif rotation_format == RotationFormat.QUATERNION:
        rotation_size = 4
    else:
        rotation_size = 9
    return {
        name: {
            'low': torch.full((rotation_size,), -1.0, dtype=torch.float32),
            'high': torch.full((rotation_size,), 1.0, dtype=torch.float32),
        }
        for name in dataset_names
    }
|
|
|
|
def translation_norm_bounds(
    translation_norm: Normalization | tuple,
    stats: Dict[str, Dict[str, Dict[str, List[float]]]],
    dataset_names: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """Per-dataset normalization stats for the translation component.

    `translation_norm` is either a Normalization enum (stats-driven, or NONE for
    fixed [-1, 1] bounds) or an explicit mapping shared by every dataset.
    """
    if isinstance(translation_norm, Normalization):
        if translation_norm == Normalization.NONE:
            return {
                name: {
                    'low': torch.full((3,), -1.0, dtype=torch.float32),
                    'high': torch.full((3,), 1.0, dtype=torch.float32),
                }
                for name in dataset_names
            }
        # Output key -> stats key, selected by the normalization flavor.
        key_pairs = {
            Normalization.BOUNDS: (('low', 'min'), ('high', 'max')),
            Normalization.BOUNDS_Q99: (('low', 'q01'), ('high', 'q99')),
            Normalization.MEAN: (('mean', 'mean'), ('std', 'std')),
        }
        if translation_norm not in key_pairs:
            raise NotImplementedError(f'Normalization type {translation_norm} not yet implemented')
        pairs = key_pairs[translation_norm]
        return {
            name: {
                out_key: torch.tensor(dataset_stats['translation'][stat_key])
                for (out_key, stat_key) in pairs
            }
            for (name, dataset_stats) in stats.items()
        }
    # Explicit mapping: either {'low', 'high'} or {'mean', 'std'}, 3 components each,
    # applied identically to every dataset.
    assert isinstance(translation_norm, collections.abc.Mapping), type(translation_norm)
    assert all((len(value) == 3 for value in translation_norm.values())), translation_norm
    assert set(translation_norm.keys()) in ({'low', 'high'}, {'mean', 'std'}), translation_norm
    return {
        name: {
            key: torch.tensor(value, dtype=torch.float32)
            for (key, value) in translation_norm.items()
        }
        for name in dataset_names
    }
|
|
|
|
# Bound TypeVar so each VLMProcessor subclass is parameterized by its own config type.
VLMProcessorConfigT = TypeVar('VLMProcessorConfigT', bound=VLMProcessorConfig)
|
|
|
|
class VLMProcessor(Configurable[VLMProcessorConfigT], Template[VLMProcessorConfigT]):
    """Abstract interface for vision-language-model preprocessing.

    Subclasses turn chat turns plus per-camera images into model-ready tensors
    and expose the underlying tokenizer and expected per-camera image sizes.
    """

    @abstractmethod
    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """Convert chat turns and per-camera image lists into a dict of input tensors."""
        ...


    @property
    @abstractmethod
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Tokenizer used to encode text inputs."""
        pass


    @property
    @abstractmethod
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        """Expected image size configuration per camera name."""
        pass
|
|
|
|
# Bound TypeVar so each VLAMProcessor subclass is parameterized by its own config type.
VLAMProcessorConfigT = TypeVar('VLAMProcessorConfigT', bound=VLAMProcessorConfig)
|
|
|
|
| class VLAMProcessor(Configurable[VLAMProcessorConfigT], Template[VLAMProcessorConfigT]): |
    def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor):
        """Wrap a VLMProcessor and precompute per-component normalization stats."""
        super().__init__(config)
        self.vlm_processor = vlm_processor
        # Control tokenizer is always the no-op EmptyTokenizer here (emits '').
        self.control_tokenizer = EmptyTokenizer(
            config=self.config.control_tokenizer_config, tokenizer=self.tokenizer
        )
        # component key -> {dataset name -> {'low'/'high' or 'mean'/'std' -> tensor}};
        # consumed by `_norm_bounds_from_dataset_name`.
        self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {
            'obs_translation': self.obs_translation_norm_bounds,
            'obs_rotation': self.obs_rotation_norm_bounds,
            'translation': self.translation_norm_bounds,
            'rotation': self.rotation_norm_bounds,
            'joints': self.joints_norm_bounds,
        }
|
|
    @property
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Tokenizer of the wrapped VLM processor."""
        return self.vlm_processor.tokenizer
|
|
    @property
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        """Per-camera image size configuration of the wrapped VLM processor."""
        return self.vlm_processor.image_sizes
|
|
| @property |
| def camera_names(self) -> List[str]: |
| return list(self.vlm_processor.image_sizes.keys()) |
|
|
    @property
    def control_io_config(self) -> ControlDataIOConfig:
        """Control data IO configuration taken from this processor's config."""
        return self.config.control_io_config
|
|
| @cached_property |
| def rotation_components(self) -> int: |
| if self.config.rotation_format == RotationFormat.EULER: |
| return 3 |
| if self.config.rotation_format == RotationFormat.QUATERNION: |
| return 4 |
| if self.config.rotation_format == RotationFormat.ROTMAT: |
| return 9 |
| raise NotImplementedError(self.config.rotation_format) |
|
|
    @abstractmethod
    def policy_control_plan_from_model_target(
        self, target: RoboticsTarget, dataset_name: np.ndarray
    ) -> RoboticsControlPlan:
        """Build an executable control plan from ground-truth targets (subclass-specific)."""
        pass
|
|
    @abstractmethod
    def policy_control_plan_from_model_output(
        self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor
    ) -> RoboticsControlPlan:
        """Build an executable control plan from raw model output (subclass-specific)."""
        pass
|
|
    def resize_image(
        self, camera_name: str, image: PIL.Image.Image | np.ndarray
    ) -> PIL.Image.Image | np.ndarray:
        """Resize `image` to the configured width/height for `camera_name`.

        Delegates to the module-level `resize_image` helper, using the processor's
        configured resize mode and LANCZOS resampling.
        """
        return resize_image(
            image,
            target_size={
                'width': self.image_sizes[camera_name].width,
                'height': self.image_sizes[camera_name].height,
            },
            mode=self.config.image_resize,
            resample=PIL.Image.Resampling.LANCZOS,
        )
|
|
    def preprocess_inputs(
        self,
        chat: List[str],
        images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool,
        control_target: Optional[RoboticsTarget] = None,
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Preprocess the inputs for a single example
        Args:
            chat: Chat turns forwarded to the VLM processor (carries the language instruction)
            images: History of input images per camera with increasing timestamps
            ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3]
            ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9]
            gripper: Raw gripper observation; normalized per dataset via
                `preprocess_gripper_observation`
            joints: np.ndarray, shape [..., num_past_scalars, <= 7]; zero-padded to 7 components
            dataset_name: 1D np.ndarray
            inference_mode: If True, prepare the input for inference (e.g. don't include target
                any tokens in the input if relevant). If control_target is available, it should
                still be preprocessed for test dataset comparison. Unused by this base implementation.
            control_target: RoboticsTarget, each component of shape
                [..., num_control_steps, num_control_components]. Provided only when available, usually
                during training and dataset test. Unused by this base implementation.
        Returns:
            Dict containing torch.Tensor with inputs
        """
        # Both are unused here; deleted explicitly to make that visible.
        del control_target
        del inference_mode
        inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
        images: Dict[str, torch.Tensor] = inputs['images']
        # Truncate token sequences to the tokenizer's maximum context length.
        input_ids: torch.Tensor = inputs['input_ids'][..., : self.tokenizer.model_max_length]
        target_text_tokens_ids: torch.Tensor = inputs['target_ids'][..., : self.tokenizer.model_max_length]
        attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)
        ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32)
        ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32)
        # Convert the observed rotation to the configured format (with autonorm).
        ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True)
        gripper = preprocess_gripper_observation(gripper, dataset_name)
        gripper = torch.tensor(gripper, dtype=torch.float32)
        ee_pose_translation = self.normalize(
            ee_pose_translation, dataset_name=dataset_name, key='obs_translation'
        )
        ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key='obs_rotation')
        joints = torch.tensor(joints, dtype=torch.float32)
        # Zero-pad the joint vector to a fixed width of 7 components.
        if joints.shape[-1] < 7:
            missing_size = 7 - joints.shape[-1]
            joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1)
        joints = self.normalize(joints, dataset_name=dataset_name, key='joints')
        outputs = {
            'images': images,
            'input_ids': input_ids,
            'target_text_tokens_ids': target_text_tokens_ids,
            'attn_mask': attn_mask,
            'ee_pose_translation': ee_pose_translation,
            'ee_pose_rotation': ee_pose_rotation,
            'gripper': gripper,
            'joints': joints,
            # Control token fields are not produced by this base preprocessing.
            'control_tokens_ids': None,
            'target_control_tokens_ids': None,
        }
        return outputs
|
|
| def create_input( |
| self, |
| chat: List[str], |
| images: Dict[str, List[PIL.Image.Image]], |
| ee_pose_translation: np.ndarray, |
| ee_pose_rotation: np.ndarray, |
| gripper: np.ndarray, |
| joints: np.ndarray, |
| dataset_name: np.ndarray, |
| inference_mode: bool, |
| control_target: Optional[RoboticsTarget] = None, |
| ) -> RoboticsInput: |
| inputs = self.preprocess_inputs( |
| chat=chat, |
| images=images, |
| ee_pose_translation=ee_pose_translation, |
| ee_pose_rotation=ee_pose_rotation, |
| gripper=gripper, |
| joints=joints, |
| dataset_name=dataset_name, |
| inference_mode=inference_mode, |
| control_target=control_target, |
| ) |
| inputs.pop('target_text_tokens_ids') |
| inputs.pop('target_control_tokens_ids') |
| return RoboticsInput(**inputs) |
|
|
| def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor: |
| if is_mean_norm(getattr(self.config, f'{key}_norm')): |
| (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) |
| output = normalize_by_moments(value, mean=mean, std=std) |
| else: |
| (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) |
| output = normalize_by_bounds(value, low=low, high=high) |
| return output |
|
|
| def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor: |
| if is_mean_norm(getattr(self.config, f'{key}_norm')): |
| (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) |
| output = unnormalize_by_moments(value, mean=mean, std=std) |
| else: |
| (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) |
| output = unnormalize_by_bounds(value, low=low, high=high) |
| return output |
|
|
    def _norm_bounds_from_dataset_name(
        self, dataset_name: np.ndarray, component_key: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create an array of normalization bounds corresponding to dataset names
        Args:
            dataset_name: Array of shape [B] of dataset names for which to fetch the low and high
                normalization bounds. Note the values can be repeating
            component_key: str. A key of `self.norm_bounds` ('obs_translation', 'obs_rotation',
                'translation', 'rotation', 'joints'). Indicates for which control to
                compute the normalization bounds
        Returns:
            Tuple of low and high bounds or mean and std, each of shape [B, -1]
        """
        norm = getattr(self.config, f'{component_key}_norm')
        # Mean-style normalization stores ('mean', 'std'); bounds-style stores ('low', 'high').
        if is_mean_norm(norm):
            (stats_key_1, stats_key_2) = ('mean', 'std')
        else:
            (stats_key_1, stats_key_2) = ('low', 'high')
        if component_key == 'joints':
            if not isinstance(norm, collections.abc.Mapping):
                raise NotImplementedError()
            # Joint bounds are dataset-independent ('ANY'): tile one row per batch entry.
            # Relies on dict insertion order being ('low', 'high') as built in joints_norm_bounds.
            stats = {
                key: torch.from_numpy(np.tile(np.reshape(value, [1, -1]), [len(dataset_name), 1]))
                for (key, value) in self.joints_norm_bounds['ANY'].items()
            }
            return tuple(stats.values())
        # Number of components, read off the first stored stats tensor for this key.
        component_size = list(list(self.norm_bounds[component_key].values())[0].values())[0].shape[-1]
        if self.dataset_names == ['ANY']:
            # Single shared stats entry: broadcast it across the batch.
            stats_1 = self.norm_bounds[component_key]['ANY'][stats_key_1]
            stats_2 = self.norm_bounds[component_key]['ANY'][stats_key_2]
            stats_1 = np.repeat(np.expand_dims(stats_1, axis=0), len(dataset_name), axis=0)
            stats_2 = np.repeat(np.expand_dims(stats_2, axis=0), len(dataset_name), axis=0)
        else:
            # Gather stats once per unique name, then scatter back to the original
            # batch order via the inverse indices.
            (unique_names, _, inverse_indices, _) = np_unique(dataset_name)
            stats_1 = np.zeros([len(unique_names), component_size], dtype=np.float32)
            stats_2 = np.zeros([len(unique_names), component_size], dtype=np.float32)
            for i, ds_name in enumerate(unique_names):
                stats_1[i] = self.norm_bounds[component_key][ds_name][stats_key_1]
                stats_2[i] = self.norm_bounds[component_key][ds_name][stats_key_2]
            stats_1 = stats_1[inverse_indices]
            stats_2 = stats_2[inverse_indices]
        return torch.from_numpy(stats_1), torch.from_numpy(stats_2)
|
|
    @cached_property
    def obs_rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """Per-dataset normalization bounds for the observed rotation (from observation stats)."""
        return rotation_norm_bounds(
            rotation_norm=self.config.obs_rotation_norm,
            rotation_format=self.config.rotation_format,
            stats=self._observation_stats,
            dataset_names=self.dataset_names,
        )
|
|
    @cached_property
    def obs_translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """Per-dataset normalization stats for the observed translation (from observation stats)."""
        return translation_norm_bounds(
            translation_norm=self.config.obs_translation_norm,
            stats=self._observation_stats,
            dataset_names=self.dataset_names,
        )
|
|
    @cached_property
    def rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """Per-dataset normalization bounds for the rotation control (from control stats)."""
        return rotation_norm_bounds(
            rotation_norm=self.config.rotation_norm,
            rotation_format=self.config.rotation_format,
            stats=self._control_stats,
            dataset_names=self.dataset_names,
        )
|
|
    @cached_property
    def translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """Per-dataset normalization stats for the translation control (from control stats)."""
        return translation_norm_bounds(
            translation_norm=self.config.translation_norm,
            stats=self._control_stats,
            dataset_names=self.dataset_names,
        )
|
|
| @cached_property |
| def joints_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: |
| """ |
| NOTE: |
| - Joint values across all joints and all datasets vary in the range [-2pi; 2pi] |
| - The effective range of a single joint is in practice one of [-2pi; 0], [-pi; pi], [0; 2pi] |
| - It's possible to shift all ranges to [-pi; pi], but it requires careful handling for each joint |
| """ |
| low = torch.tensor(self.config.joints_norm['low'], dtype=torch.float32) |
| high = torch.tensor(self.config.joints_norm['high'], dtype=torch.float32) |
| results = {'ANY': {'low': low, 'high': high}} |
| return results |
|
|
| @cached_property |
| def _observation_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]: |
| return { |
| 'austin_buds_dataset': { |
| 'euler': { |
| 'max': [3.141589641571045, 0.2279592752456665, 0.10982763767242432], |
| 'mean': [-0.433798223733902, 0.02660757675766945, 0.01057341042906046], |
| 'min': [-3.1415905952453613, -0.17229366302490234, -0.08408939838409424], |
| 'q01': [-3.1401021480560303, -0.15158048272132874, -0.07398375868797302], |
| 'q99': [3.1401915550231934, 0.1699378788471222, 0.08044551312923431], |
| 'std': [3.071873664855957, 0.08403228223323822, 0.04623554274439812], |
| }, |
| 'gripper': { |
| 'max': [0.07999841868877411], |
| 'min': [0.00019240332767367363], |
| 'q01': [0.00760263716802001], |
| 'q99': [0.07997412234544754], |
| }, |
| 'joints': { |
| 'max': [ |
| 0.25105541944503784, |
| 1.0239691734313965, |
| 0.25514841079711914, |
| 0.0, |
| 0.05838121101260185, |
| 3.0727620124816895, |
| 1.1911247968673706, |
| ], |
| 'mean': [ |
| -0.02248559147119522, |
| 0.4224241375923157, |
| 0.011533008888363838, |
| -2.040178060531616, |
| -0.004422259051352739, |
| 2.440391778945923, |
| 0.7626844644546509, |
| ], |
| 'min': [ |
| -0.46276336908340454, |
| -0.2620261609554291, |
| -0.37377235293388367, |
| -3.010026693344116, |
| -0.015008972026407719, |
| 0.0, |
| 0.0, |
| ], |
| 'q01': [ |
| -0.3563080132007599, |
| -0.19165010750293732, |
| -0.29941368103027344, |
| -2.766944408416748, |
| -0.012211786583065987, |
| 1.9946393966674805, |
| 0.21924567222595215, |
| ], |
| 'q99': [ |
| 0.22295619547367096, |
| 0.8841447830200195, |
| 0.20571835339069366, |
| -1.339158296585083, |
| 0.05605502426624298, |
| 2.9369189739227295, |
| 1.1475461721420288, |
| ], |
| 'std': [ |
| 0.1684190183877945, |
| 0.2238602489233017, |
| 0.1185624971985817, |
| 0.3182348608970642, |
| 0.012190691195428371, |
| 0.2673698663711548, |
| 0.2611873149871826, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.7684353590011597, 0.22995209693908691, 0.3893454968929291], |
| 'mean': [0.5589713454246521, -0.0022956086322665215, 0.12593945860862732], |
| 'min': [0.0, -0.31078848242759705, 0.0], |
| 'q01': [0.3499317765235901, -0.2854413390159607, 0.010516085661947727], |
| 'q99': [0.7243335843086243, 0.20652863383293152, 0.3218296766281128], |
| 'std': [0.08953548222780228, 0.1462203562259674, 0.0769098550081253], |
| }, |
| }, |
| 'austin_sailor_dataset': { |
| 'euler': { |
| 'max': [3.1415910720825195, 0.26857471466064453, 2.4386940002441406], |
| 'mean': [0.040416110306978226, 0.013173789717257023, -0.044821564108133316], |
| 'min': [-3.141592264175415, -0.21639788150787354, -1.995558738708496], |
| 'q01': [-3.1401662826538086, -0.16171419620513916, -1.5918874740600586], |
| 'q99': [3.1401755809783936, 0.17527776956558228, 1.3560799360275269], |
| 'std': [3.0995969772338867, 0.08136381208896637, 0.6311237812042236], |
| }, |
| 'gripper': { |
| 'max': [0.07783995568752289], |
| 'min': [-8.864999836077914e-05], |
| 'q01': [0.0005174533580429852], |
| 'q99': [0.07778151333332062], |
| }, |
| 'joints': { |
| 'max': [ |
| 0.18366889655590057, |
| 1.188208818435669, |
| 1.1904715299606323, |
| -1.0253318548202515, |
| 0.06765712797641754, |
| 2.98008131980896, |
| 2.7747786045074463, |
| ], |
| 'mean': [ |
| -0.4481923580169678, |
| 0.428782194852829, |
| 0.29826778173446655, |
| -2.1356770992279053, |
| -0.2064618021249771, |
| 2.5082974433898926, |
| 0.8104547262191772, |
| ], |
| 'min': [ |
| -1.4425580501556396, |
| -0.34988880157470703, |
| -0.4436887800693512, |
| -2.9587178230285645, |
| -0.9116092324256897, |
| 1.7689018249511719, |
| -1.4776484966278076, |
| ], |
| 'q01': [ |
| -1.0579811334609985, |
| -0.13357718288898468, |
| -0.2518821358680725, |
| -2.6593017578125, |
| -0.6571194529533386, |
| 2.072176456451416, |
| -0.5861577391624451, |
| ], |
| 'q99': [ |
| 0.08897759765386581, |
| 0.8743473887443542, |
| 0.8909343481063843, |
| -1.4748132228851318, |
| -0.010495365597307682, |
| 2.859259605407715, |
| 2.319192409515381, |
| ], |
| 'std': [ |
| 0.26852160692214966, |
| 0.22502659261226654, |
| 0.2577133774757385, |
| 0.2963367700576782, |
| 0.1629231870174408, |
| 0.18130357563495636, |
| 0.6348837614059448, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.7488242387771606, 0.2829112708568573, 0.3541720509529114], |
| 'mean': [0.5367430448532104, -0.08604215085506439, 0.11135970056056976], |
| 'min': [0.3208981454372406, -0.3730051815509796, 0.020952222868800163], |
| 'q01': [0.387094110250473, -0.3164229393005371, 0.024492919445037842], |
| 'q99': [0.6869593262672424, 0.2086469978094101, 0.2551962733268738], |
| 'std': [0.08000300824642181, 0.13984571397304535, 0.056377921253442764], |
| }, |
| }, |
| 'austin_sirius_dataset': { |
| 'euler': { |
| 'max': [3.141592502593994, 0.09964871406555176, 0.3406205177307129], |
| 'mean': [1.1875473260879517, -0.018573710694909096, -0.4977009892463684], |
| 'min': [-3.141592502593994, -0.14349305629730225, -1.892538070678711], |
| 'q01': [-3.140658378601074, -0.12321051210165024, -1.743293285369873], |
| 'q99': [3.1407761573791504, 0.07348346710205078, 0.062370821833610535], |
| 'std': [2.8533384799957275, 0.03962606564164162, 0.5864803791046143], |
| }, |
| 'gripper': { |
| 'max': [0.07965432107448578], |
| 'min': [0.00016088332631625235], |
| 'q01': [0.0334978811442852], |
| 'q99': [0.07948359102010727], |
| }, |
| 'joints': { |
| 'max': [ |
| 0.5567107200622559, |
| 0.3814372420310974, |
| 0.36868324875831604, |
| 0.0, |
| 0.021617505699396133, |
| 2.825117588043213, |
| 2.8925509452819824, |
| ], |
| 'mean': [ |
| 0.22241303324699402, |
| 0.03750193864107132, |
| -0.013105375692248344, |
| -2.428316593170166, |
| -0.017613587900996208, |
| 2.4755642414093018, |
| 1.4929485321044922, |
| ], |
| 'min': [ |
| -0.23144784569740295, |
| -0.41377919912338257, |
| -0.3536752760410309, |
| -2.9289300441741943, |
| -0.06330326944589615, |
| 0.0, |
| 0.0, |
| ], |
| 'q01': [ |
| -0.10539760440587997, |
| -0.2651134133338928, |
| -0.21157152950763702, |
| -2.831827402114868, |
| -0.04555036500096321, |
| 0.0, |
| 0.0, |
| ], |
| 'q99': [ |
| 0.48007193207740784, |
| 0.3064860999584198, |
| 0.2438022643327713, |
| 0.0, |
| 0.012844694778323174, |
| 2.71537446975708, |
| 2.51297926902771, |
| ], |
| 'std': [ |
| 0.1297803372144699, |
| 0.132253035902977, |
| 0.08940940350294113, |
| 0.345712274312973, |
| 0.015448619611561298, |
| 0.3493262231349945, |
| 0.632145345211029, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.5655431151390076, 0.3374997079372406, 0.3071175515651703], |
| 'mean': [0.4490734338760376, 0.09406533092260361, 0.17131780087947845], |
| 'min': [0.0, -0.16291481256484985, 0.0], |
| 'q01': [0.0, -0.11814527958631516, 0.0], |
| 'q99': [0.532875120639801, 0.26084619760513306, 0.27225059270858765], |
| 'std': [0.07605387270450592, 0.08447220921516418, 0.04798709973692894], |
| }, |
| }, |
| 'bc_z': { |
| 'euler': { |
| 'max': [3.141583088638963, 1.5360414168274534, 3.141576023430307], |
| 'mean': [-0.32332700342924814, -0.1255861641435509, -0.07202428898041417], |
| 'min': [-3.13973587780783, -1.5322501214383415, -3.1415575429114675], |
| 'q01': [-1.0433973645798003, -1.007553367940694, -2.6477418236555788], |
| 'q99': [0.8467956620806508, 0.9111507554055364, 2.226032625129908], |
| 'std': [0.40307241985577485, 0.4832146677789649, 1.0666829542185445], |
| }, |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.20000000298023224], 'q99': [1.0]}, |
| 'joints': { |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| }, |
| 'translation': { |
| 'max': [0.6597589254379272, 0.7259413599967957, 1.1217665672302246], |
| 'mean': [0.012632723897695541, 0.10061875730752945, 0.7831991910934448], |
| 'min': [-0.7181899547576904, -0.3756217360496521, -0.281008243560791], |
| 'q01': [-0.3956047296524048, -0.11924505233764648, 0.601338267326355], |
| 'q99': [0.332028865814209, 0.3088575601577759, 0.98329097032547], |
| 'std': [0.1808762550354004, 0.0960959941148758, 0.08665762096643448], |
| }, |
| }, |
| 'berkeley_autolab_ur5': { |
| 'euler': { |
| 'max': [3.141586857385381, 0.9715026080766722, 1.0041252839605828], |
| 'mean': [0.10885329009674993, -0.005823583245937125, 0.01756067034138813], |
| 'min': [-3.14158698706268, -0.9912158529884818, -0.8782840134752882], |
| 'q01': [-3.1391613317061315, -0.6410961610461836, -0.4513447799408977], |
| 'q99': [3.13947692478744, 0.6738775260636495, 0.535202669480506], |
| 'std': [3.069331637100682, 0.18171227413203145, 0.17512358924250696], |
| }, |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, |
| 'joints': { |
| 'max': [ |
| -2.5115966796875, |
| -0.682390034198761, |
| 2.6830153465270996, |
| -1.1711143255233765, |
| -0.6423047184944153, |
| 4.032349109649658, |
| ], |
| 'mean': [ |
| -3.2676453590393066, |
| -1.3197712898254395, |
| 2.122663736343384, |
| -2.3710596561431885, |
| -1.567111611366272, |
| 3.0015013217926025, |
| ], |
| 'min': [ |
| -4.0775299072265625, |
| -2.2796404361724854, |
| 1.1905627250671387, |
| -3.814586877822876, |
| -2.1223204135894775, |
| 1.5279699563980103, |
| ], |
| 'q01': [ |
| -3.839585781097412, |
| -1.8913453817367554, |
| 1.5813319683074951, |
| -3.3646113872528076, |
| -1.805822730064392, |
| 2.22794246673584, |
| ], |
| 'q99': [ |
| -2.7322237491607666, |
| -0.8280109763145447, |
| 2.5063326358795166, |
| -1.5614854097366333, |
| -1.2497813701629639, |
| 3.7013423442840576, |
| ], |
| 'std': [ |
| 0.2746899724006653, |
| 0.27345725893974304, |
| 0.22793518006801605, |
| 0.34690913558006287, |
| 0.10438559204339981, |
| 0.3388271629810333, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.6610942482948303, 0.38513994216918945, 0.2049914449453354], |
| 'mean': [0.45345133543014526, 0.05642913281917572, -0.0356401689350605], |
| 'min': [0.1970997005701065, -0.27643972635269165, -0.20529526472091675], |
| 'q01': [0.3020566999912262, -0.21297279000282288, -0.18836002051830292], |
| 'q99': [0.6132073998451233, 0.30656182765960693, 0.12212439626455307], |
| 'std': [0.07906801998615265, 0.12497302889823914, 0.08053416013717651], |
| }, |
| }, |
| 'berkeley_cable_routing': { |
| 'euler': { |
| 'max': [3.141592264175415, 0.05923163890838623, 3.1243016719818115], |
| 'mean': [-0.6730493903160095, 0.010155542753636837, 0.9881726503372192], |
| 'min': [-3.141592264175415, -0.10035276412963867, -1.6915032863616943], |
| 'q01': [-3.1412672996520996, -0.02922845631837845, -0.6449583172798157], |
| 'q99': [3.1412830352783203, 0.039870113134384155, 2.4969191551208496], |
| 'std': [3.0588438510894775, 0.01435503363609314, 0.6301024556159973], |
| }, |
| 'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]}, |
| 'joints': { |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| }, |
| 'translation': { |
| 'max': [0.7113381028175354, 0.35645395517349243, 0.23169316351413727], |
| 'mean': [0.5643643736839294, 0.00464651919901371, 0.07635799050331116], |
| 'min': [0.378939151763916, -0.3244381248950958, 0.02837979607284069], |
| 'q01': [0.4641263782978058, -0.2806571424007416, 0.030183622613549232], |
| 'q99': [0.6452807784080505, 0.28204888105392456, 0.1557157188653946], |
| 'std': [0.0380910187959671, 0.16768066585063934, 0.032017387449741364], |
| }, |
| }, |
| 'berkeley_fanuc_manipulation': { |
| 'euler': { |
| 'max': [3.1415822505950928, 1.020048975944519, 1.9211788177490234], |
| 'mean': [0.1944933384656906, -0.03283948823809624, 0.017717985436320305], |
| 'min': [-3.141590118408203, -1.5707060098648071, -3.089141368865967], |
| 'q01': [-3.139866352081299, -1.0165742635726929, -1.6987266540527344], |
| 'q99': [3.140317440032959, 0.4332200288772583, 1.5085773468017578], |
| 'std': [2.9820094108581543, 0.21335174143314362, 0.4836970269680023], |
| }, |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, |
| 'joints': { |
| 'max': [ |
| 0.9483258724212646, |
| 1.4428143501281738, |
| 1.2322325706481934, |
| 2.3726272583007812, |
| 0.0008019324741326272, |
| 3.017007827758789, |
| ], |
| 'mean': [ |
| 0.04932224750518799, |
| 0.36293527483940125, |
| -0.18438231945037842, |
| 0.09242437779903412, |
| -1.0793049335479736, |
| -0.06553139537572861, |
| ], |
| 'min': [ |
| -0.7464086413383484, |
| -0.5774519443511963, |
| -1.1099220514297485, |
| -1.8152672052383423, |
| -2.1454310417175293, |
| -3.168759346008301, |
| ], |
| 'q01': [ |
| -0.5061959028244019, |
| -0.09613651037216187, |
| -0.8047154545783997, |
| -0.9908221364021301, |
| -1.8158636093139648, |
| -2.420135974884033, |
| ], |
| 'q99': [ |
| 0.6388322710990906, |
| 1.0234705209732056, |
| 0.5990921258926392, |
| 1.282780647277832, |
| -0.39091193675994873, |
| 2.340099811553955, |
| ], |
| 'std': [ |
| 0.2452843338251114, |
| 0.2555835247039795, |
| 0.2740932106971741, |
| 0.3769519329071045, |
| 0.3472467362880707, |
| 0.6859157681465149, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.8101964592933655, 0.5117550492286682, 0.8054816722869873], |
| 'mean': [0.5446329116821289, 0.003737417282536626, 0.2439887523651123], |
| 'min': [0.263683944940567, -0.5785022377967834, -0.0025757886469364166], |
| 'q01': [0.3718133866786957, -0.4071895182132721, 0.01847645826637745], |
| 'q99': [0.7200658321380615, 0.3128541111946106, 0.5413243770599365], |
| 'std': [0.07649082690477371, 0.14486227929592133, 0.13899113237857819], |
| }, |
| }, |
| 'bridge': { |
| 'euler': { |
| 'max': [3.141592653589793, 1.570796251296997, 3.141204357147217], |
| 'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691], |
| 'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938], |
| 'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896], |
| 'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236], |
| 'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], |
| }, |
| 'gripper': { |
| 'max': [1.0370277166366577], |
| 'min': [0.04637829214334488], |
| 'q01': [0.05192930996417999], |
| 'q99': [1.0118417739868164], |
| }, |
| 'joints': { |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| }, |
| 'translation': { |
| 'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], |
| 'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424], |
| 'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723], |
| 'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884], |
| 'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424], |
| 'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685], |
| }, |
| }, |
| 'bridge_orig': { |
| 'euler': { |
| 'max': [3.141592653589793, 1.570796251296997, 3.141204357147217], |
| 'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691], |
| 'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938], |
| 'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896], |
| 'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236], |
| 'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], |
| }, |
| 'gripper': { |
| 'max': [1.0370277166366577], |
| 'min': [0.04637829214334488], |
| 'q01': [0.05192930996417999], |
| 'q99': [1.0118417739868164], |
| }, |
| 'joints': { |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| }, |
| 'translation': { |
| 'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], |
| 'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424], |
| 'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723], |
| 'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884], |
| 'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424], |
| 'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685], |
| }, |
| }, |
| 'bridge_action_lang': { |
| 'euler': { |
| 'max': [3.141592653589793, 1.570796251296997, 3.141204357147217], |
| 'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691], |
| 'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938], |
| 'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896], |
| 'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236], |
| 'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], |
| }, |
| 'gripper': { |
| 'max': [1.0370277166366577], |
| 'min': [0.04637829214334488], |
| 'q01': [0.05192930996417999], |
| 'q99': [1.0118417739868164], |
| }, |
| 'joints': { |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| }, |
| 'translation': { |
| 'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], |
| 'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424], |
| 'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723], |
| 'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884], |
| 'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424], |
| 'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685], |
| }, |
| }, |
| 'cmu_stretch': { |
| 'euler': { |
| 'max': [3.141592653589793, 0.0, 0.0], |
| 'mean': [3.1415926535896035, 0.0, 0.0], |
| 'min': [3.141592653589793, 0.0, 0.0], |
| 'q01': [3.141592653589793, 0.0, 0.0], |
| 'q99': [3.141592653589793, 0.0, 0.0], |
| 'std': [1.8962609260597674e-13, 0.0, 0.0], |
| }, |
| 'gripper': { |
| 'max': [3.110913038253784], |
| 'min': [-3.1155149936676025], |
| 'q01': [-3.055689811706543], |
| 'q99': [3.1017091274261475], |
| }, |
| 'joints': { |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| }, |
| 'translation': { |
| 'max': [0.3401322066783905, 0.0, 1.1013386249542236], |
| 'mean': [0.1461893916130066, 0.0, 0.7568270564079285], |
| 'min': [0.005320895928889513, 0.0, 0.45218077301979065], |
| 'q01': [0.017430847510695457, 0.0, 0.46050605177879333], |
| 'q99': [0.33094948530197144, 0.0, 1.0952961444854736], |
| 'std': [0.09959954023361206, 0.0, 0.15819725394248962], |
| }, |
| }, |
| 'dlr_edan_shared_control': { |
| 'euler': { |
| 'max': [2.2228052616119385, 0.13346634805202484, 3.1415085792541504], |
| 'mean': [1.6072595119476318, -0.4295055568218231, 1.1549623012542725], |
| 'min': [0.4599944055080414, -1.4871342182159424, -3.1406784057617188], |
| 'q01': [0.8485284447669983, -1.4454149007797241, -3.1238820552825928], |
| 'q99': [1.8301045894622803, 0.07032366842031479, 3.131430149078369], |
| 'std': [0.2785564363002777, 0.540916919708252, 2.4260833263397217], |
| }, |
| 'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]}, |
| 'joints': { |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| }, |
| 'translation': { |
| 'max': [0.044179707765579224, 0.6297282576560974, 0.9278888702392578], |
| 'mean': [-0.512688934803009, 0.3690241575241089, 0.441307932138443], |
| 'min': [-0.7695853114128113, -0.040099501609802246, 0.2583216428756714], |
| 'q01': [-0.729511022567749, 0.077408567070961, 0.2658006250858307], |
| 'q99': [-0.13719859719276428, 0.5719971060752869, 0.7898909449577332], |
| 'std': [0.15213875472545624, 0.08891375362873077, 0.13743770122528076], |
| }, |
| }, |
| 'droid': { |
| 'euler': { |
| 'max': [3.141592502593994, 1.5705928802490234, 3.1415867805480957], |
| 'mean': [0.3140628098409554, -0.09296274023036387, -0.07227215454779846], |
| 'min': [-3.141592502593994, -1.5691150426864624, -3.1415374279022217], |
| 'q01': [-3.1378602981567383, -1.2125312042236327, -2.1614069032669065], |
| 'q99': [3.137854380607605, 0.9200375998020163, 1.9367506909370364], |
| 'std': [2.926265757944871, 0.363273475703332, 0.7576065217938824], |
| }, |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.9911894202232361]}, |
| 'joints': { |
| 'max': [ |
| 2.668445110321045, |
| 1.5691218376159668, |
| 2.666306734085083, |
| -0.3114914000034332, |
| 2.6624162197113037, |
| 4.28157901763916, |
| 2.752457857131958, |
| ], |
| 'mean': [ |
| 0.023137084334640106, |
| 0.2704989977282293, |
| -0.01451389357228282, |
| -2.018709403792315, |
| -0.042720520800030394, |
| 2.350281188152209, |
| 0.12424663946659845, |
| ], |
| 'min': [ |
| -2.6536705493927, |
| -1.547789216041565, |
| -2.6781487464904785, |
| -2.9409868717193604, |
| -2.6705946922302246, |
| 0.24893812835216522, |
| -2.7615714073181152, |
| ], |
| 'q01': [ |
| -0.9026106441020965, |
| -0.8547340619564057, |
| -0.9028875434398651, |
| -2.7698556280136106, |
| -1.6851656341552732, |
| 1.2335169839859008, |
| -1.9587260699272155, |
| ], |
| 'q99': [ |
| 0.9569852340221403, |
| 1.4148830294609054, |
| 0.7693877756595566, |
| -0.4545914208889008, |
| 1.5623322343826267, |
| 3.475611729621887, |
| 2.263479118347167, |
| ], |
| 'std': [ |
| 0.31695080251469465, |
| 0.49522214687158767, |
| 0.27993538230553827, |
| 0.478161574676113, |
| 0.4969961591445458, |
| 0.45101008525403846, |
| 0.7287264344068457, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.8575563430786133, 0.799155592918396, 1.0043904781341553], |
| 'mean': [0.5283099395864883, 0.005363794653877434, 0.3120132207021294], |
| 'min': [-0.15604186058044434, -0.827903687953949, -0.2347021996974945], |
| 'q01': [0.26669957995414734, -0.43774398624897004, -0.048167889714241026], |
| 'q99': [0.7774086785316463, 0.428325751423835, 0.776091011762619], |
| 'std': [0.1148424841779685, 0.17489566608140428, 0.16541062032731538], |
| }, |
| }, |
| 'fmb': { |
| 'euler': { |
| 'max': [3.1415927410125732, 1.044223666191101, 2.6313371658325195], |
| 'mean': [-0.5128238201141357, -0.0326850451529026, 1.2892225980758667], |
| 'min': [-3.141592502593994, -1.0594873428344727, -1.5775980949401855], |
| 'q01': [-3.1404616832733154, -0.9319325089454651, -1.4586873054504395], |
| 'q99': [3.140399932861328, 0.8346238732337952, 2.3081648349761963], |
| 'std': [3.0557258129119873, 0.33158522844314575, 0.6522650122642517], |
| }, |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, |
| 'joints': { |
| 'max': [ |
| 1.633209228515625, |
| 1.216904640197754, |
| 0.6477310061454773, |
| -1.101138949394226, |
| 1.4866814613342285, |
| 3.2880685329437256, |
| 2.8213212490081787, |
| ], |
| 'mean': [ |
| 0.08979735523462296, |
| 0.24736542999744415, |
| -0.18356309831142426, |
| -2.1948094367980957, |
| 0.09010367840528488, |
| 2.3375964164733887, |
| -0.6424615979194641, |
| ], |
| 'min': [ |
| -0.24819369614124298, |
| -0.6836066246032715, |
| -1.852062463760376, |
| -2.950329542160034, |
| -1.360267162322998, |
| 1.1119775772094727, |
| -2.53395938873291, |
| ], |
| 'q01': [ |
| -0.12363661825656891, |
| -0.2893752157688141, |
| -0.8288514018058777, |
| -2.7587249279022217, |
| -0.9787064790725708, |
| 1.5508010387420654, |
| -1.7525910139083862, |
| ], |
| 'q99': [ |
| 0.527353048324585, |
| 0.8305937647819519, |
| 0.3808687925338745, |
| -1.5172499418258667, |
| 1.1656274795532227, |
| 2.9421844482421875, |
| 2.3595635890960693, |
| ], |
| 'std': [ |
| 0.14010190963745117, |
| 0.24683699011802673, |
| 0.29916438460350037, |
| 0.29822880029678345, |
| 0.40809085965156555, |
| 0.3218824863433838, |
| 0.7698287963867188, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.7458599805831909, 0.2993985116481781, 0.39971601963043213], |
| 'mean': [0.5149032473564148, -0.04377440735697746, 0.17520515620708466], |
| 'min': [0.32314637303352356, -0.34641459584236145, 0.029209019616246223], |
| 'q01': [0.3655048608779907, -0.28729698061943054, 0.033201027661561966], |
| 'q99': [0.6782684326171875, 0.209969624876976, 0.3331448435783386], |
| 'std': [0.07067117840051651, 0.13670970499515533, 0.07637079805135727], |
| }, |
| }, |
| 'fractal20220817_data': { |
| 'euler': { |
| 'max': [3.141592264175415, 1.5703786611557007, 3.1415889263153076], |
| 'mean': [-1.5046496391296387, 0.6226330399513245, -1.621827244758606], |
| 'min': [-3.1415927410125732, -1.569378137588501, -3.1415863037109375], |
| 'q01': [-3.130796194076538, -0.23499104380607605, -2.9648191928863525], |
| 'q99': [3.1300582885742188, 1.4907126426696777, 2.8373801708221436], |
| 'std': [2.380099296569824, 0.41548046469688416, 0.803632915019989], |
| }, |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, |
| 'joints': { |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| }, |
| 'translation': { |
| 'max': [1.0534898042678833, 0.48018959164619446, 1.6896663904190063], |
| 'mean': [0.5594336986541748, -0.08356647938489914, 0.7760705947875977], |
| 'min': [-0.3139062225818634, -0.9970501065254211, -0.006579156965017319], |
| 'q01': [0.3249714970588684, -0.2818704843521118, 0.1410011649131775], |
| 'q99': [0.8754204511642456, 0.21279653906822205, 1.071526288986206], |
| 'std': [0.12431981414556503, 0.11550336331129074, 0.24583852291107178], |
| }, |
| }, |
| 'furniture_bench_dataset': { |
| 'euler': { |
| 'max': [3.141592502593994, 1.545873999595642, 3.1415679454803467], |
| 'mean': [-1.5673161745071411, 0.028489787131547928, -0.0338752306997776], |
| 'min': [-3.1415927410125732, -1.1365290880203247, -3.141582727432251], |
| 'q01': [-3.139331817626953, -0.6121169328689575, -1.9944958686828613], |
| 'q99': [3.1392478942871094, 0.9993562698364258, 1.7793478965759277], |
| 'std': [2.605621337890625, 0.30199435353279114, 0.924261212348938], |
| }, |
| 'gripper': { |
| 'max': [0.07995442301034927], |
| 'min': [-3.4803331800503656e-05], |
| 'q01': [0.003531553316861391], |
| 'q99': [0.07973115146160126], |
| }, |
| 'joints': { |
| 'max': [ |
| 1.6170321702957153, |
| 1.4905058145523071, |
| 0.8054860234260559, |
| -0.8407800197601318, |
| 2.596961736679077, |
| 3.591698169708252, |
| 2.744664192199707, |
| ], |
| 'mean': [ |
| 0.14828824996948242, |
| 0.513389527797699, |
| -0.07023008167743683, |
| -2.116460084915161, |
| 0.17382267117500305, |
| 2.6314659118652344, |
| 0.7604196667671204, |
| ], |
| 'min': [ |
| -1.0936335325241089, |
| -0.28272193670272827, |
| -1.3713332414627075, |
| -2.9083454608917236, |
| -2.664609432220459, |
| 0.49353715777397156, |
| -2.6659374237060547, |
| ], |
| 'q01': [ |
| -0.3774302303791046, |
| 0.16254600882530212, |
| -0.526292085647583, |
| -2.662245512008667, |
| -0.3341670036315918, |
| 1.4710004329681396, |
| -0.997846245765686, |
| ], |
| 'q99': [ |
| 0.6179460287094116, |
| 0.9334877729415894, |
| 0.3241332471370697, |
| -1.5048880577087402, |
| 0.69161057472229, |
| 3.45759916305542, |
| 2.5480763912200928, |
| ], |
| 'std': [ |
| 0.20775794982910156, |
| 0.14619596302509308, |
| 0.1992885023355484, |
| 0.24350883066654205, |
| 0.2203112691640854, |
| 0.3669370114803314, |
| 0.8288815021514893, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.75081467628479, 0.29695066809654236, 0.35806331038475037], |
| 'mean': [0.5622280836105347, 0.049872349947690964, 0.08108603954315186], |
| 'min': [0.2878606617450714, -0.3141690492630005, -0.00460465531796217], |
| 'q01': [0.36915361881256104, -0.180975541472435, 0.0058300793170928955], |
| 'q99': [0.6652880311012268, 0.1772783100605011, 0.18316447734832764], |
| 'std': [0.07692465931177139, 0.07895787060260773, 0.04069080576300621], |
| }, |
| }, |
| 'iamlab_cmu_pickup_insert': { |
| 'euler': { |
| 'max': [3.14159250270088, 0.04584214946783205, 0.0450923308550184], |
| 'mean': [-1.4122567194164972, -0.03209339030753262, 0.006593786875632054], |
| 'min': [-3.141592327000409, -0.13647014666477264, -0.052753928700043584], |
| 'q01': [-3.141100653025272, -0.1142266801993486, -0.026375220367685838], |
| 'q99': [3.1411030904244917, 0.02986631061111326, 0.033458487885352856], |
| 'std': [2.771965176997775, 0.02781015988983883, 0.015259428603879655], |
| }, |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, |
| 'joints': { |
| 'max': [ |
| 0.24314826726913452, |
| 0.7470589280128479, |
| 0.7795426249504089, |
| -1.6825058460235596, |
| 0.13216856122016907, |
| 2.88216233253479, |
| 1.443279504776001, |
| ], |
| 'mean': [ |
| -0.09283880889415741, |
| 0.16589058935642242, |
| 0.13815303146839142, |
| -2.236156940460205, |
| -0.013342339545488358, |
| 2.4330623149871826, |
| 0.8266588449478149, |
| ], |
| 'min': [ |
| -0.6439031958580017, |
| -0.7895585894584656, |
| -0.017819713801145554, |
| -2.9597983360290527, |
| -0.5142408013343811, |
| 1.82623291015625, |
| 0.24880434572696686, |
| ], |
| 'q01': [ |
| -0.4449005424976349, |
| -0.687881588935852, |
| -0.00037891813553869724, |
| -2.9046988487243652, |
| -0.3344403803348541, |
| 1.9697062969207764, |
| 0.4578770101070404, |
| ], |
| 'q99': [ |
| 0.19939860701560974, |
| 0.6767980456352234, |
| 0.5450605750083923, |
| -1.7509592771530151, |
| 0.07327680289745331, |
| 2.758103370666504, |
| 1.2580491304397583, |
| ], |
| 'std': [ |
| 0.1536070704460144, |
| 0.3142663240432739, |
| 0.1251399964094162, |
| 0.28150704503059387, |
| 0.062141284346580505, |
| 0.17315669357776642, |
| 0.19375956058502197, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.6634980998075433, 0.23428471360823594, 0.4308285234773945], |
| 'mean': [0.5266367430643635, 0.02849007901898653, 0.18708021993948834], |
| 'min': [0.3071657210198696, -0.2975911497764855, 0.06578226212031718], |
| 'q01': [0.314498567139741, -0.20315787858517198, 0.06785127291435837], |
| 'q99': [0.6472027612545221, 0.20840713845170097, 0.37003410655170416], |
| 'std': [0.08150424006275543, 0.11136019021346877, 0.0777212838203498], |
| }, |
| }, |
| 'jaco_play': { |
| 'euler': { |
| 'max': [3.141592653589793, 0.0, 0.0], |
| 'mean': [3.14159265358649, 0.0, 0.0], |
| 'min': [3.141592653589793, 0.0, 0.0], |
| 'q01': [3.141592653589793, 0.0, 0.0], |
| 'q99': [3.141592653589793, 0.0, 0.0], |
| 'std': [3.3031355428647657e-12, 0.0, 0.0], |
| }, |
| 'gripper': { |
| 'max': [1.3890767097473145], |
| 'min': [-0.00123199715744704], |
| 'q01': [0.00061599857872352], |
| 'q99': [1.3835327625274658], |
| }, |
| 'joints': { |
| 'max': [ |
| -0.6973918676376343, |
| 4.516506671905518, |
| 2.732408285140991, |
| 0.27657514810562134, |
| 2.0405426025390625, |
| 0.9062198996543884, |
| ], |
| 'mean': [ |
| -1.5362969636917114, |
| 3.9368598461151123, |
| 1.470100998878479, |
| -0.26214125752449036, |
| 0.8781110644340515, |
| -0.4923328757286072, |
| ], |
| 'min': [ |
| -2.394073486328125, |
| 3.4039254188537598, |
| 0.5968497395515442, |
| -1.0070130825042725, |
| -0.28005343675613403, |
| -1.3207207918167114, |
| ], |
| 'q01': [ |
| -2.2932708263397217, |
| 3.5327682495117188, |
| 0.7838058471679688, |
| -0.5554518699645996, |
| 0.002689492190256715, |
| -1.1670725345611572, |
| ], |
| 'q99': [ |
| -0.8449038863182068, |
| 4.363891124725342, |
| 2.367223024368286, |
| 0.06634082645177841, |
| 1.595384120941162, |
| 0.17280587553977966, |
| ], |
| 'std': [ |
| 0.32358425855636597, |
| 0.18587827682495117, |
| 0.3632128834724426, |
| 0.11627942323684692, |
| 0.3359401226043701, |
| 0.3179064989089966, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.2252020239830017, -0.19358234107494354, 0.4066188633441925], |
| 'mean': [-0.07287009060382843, -0.42861491441726685, 0.2764993906021118], |
| 'min': [-0.4429473876953125, -0.6635459661483765, 0.1568669229745865], |
| 'q01': [-0.3789186179637909, -0.6194459795951843, 0.16865813732147217], |
| 'q99': [0.21203258633613586, -0.26914602518081665, 0.38958534598350525], |
| 'std': [0.138992577791214, 0.0908467248082161, 0.052496980875730515], |
| }, |
| }, |
| 'kuka': { |
| 'euler': { |
| 'max': [3.141592651594966, 0.25521246168753553, 3.1415812671947134], |
| 'mean': [-0.8123818266038756, -0.004947911572775501, 0.1622765703197679], |
| 'min': [-3.141592653256577, -0.22240290857095402, -3.1415339219506544], |
| 'q01': [-3.141539356675989, -0.042093398087956785, -0.6103775416503281], |
| 'q99': [3.1415363480584935, 0.013040864331848548, 1.4661773197655914], |
| 'std': [3.031619905792479, 0.010541045603285457, 0.386240278316171], |
| }, |
| 'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]}, |
| 'joints': { |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| }, |
| 'translation': { |
| 'max': [0.7243871092796326, 0.31309840083122253, 0.613274335861206], |
| 'mean': [0.5510859489440918, 0.04787421599030495, 0.12812286615371704], |
| 'min': [0.40573424100875854, -0.2028520256280899, 0.018512273207306862], |
| 'q01': [0.4765772819519043, -0.14815208315849304, 0.06674224138259888], |
| 'q99': [0.6515637040138245, 0.2447487711906433, 0.28018367290496826], |
| 'std': [0.045206550508737564, 0.10222796350717545, 0.058390580117702484], |
| }, |
| }, |
| 'language_table': { |
| 'euler': { |
| 'max': [3.141592653589793, 0.0, 0.0], |
| 'mean': [3.141592653805758, 0.0, 0.0], |
| 'min': [3.141592653589793, 0.0, 0.0], |
| 'q01': [3.141592653589793, 0.0, 0.0], |
| 'q99': [3.141592653589793, 0.0, 0.0], |
| 'std': [2.1596502364831912e-10, 0.0, 0.0], |
| }, |
| 'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]}, |
| 'joints': { |
| 'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| 'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
| }, |
| 'translation': { |
| 'max': [0.6191085577011108, 0.345907062292099, 0.0], |
| 'mean': [0.3994608521461487, 0.0045331921428442, 0.0], |
| 'min': [0.18907572329044342, -0.3051564395427704, 0.0], |
| 'q01': [0.19237099587917328, -0.2962527573108673, 0.0], |
| 'q99': [0.6171894669532776, 0.30645298957824707, 0.0], |
| 'std': [0.10587888211011887, 0.14165113866329193, 0.0], |
| }, |
| }, |
| 'nyu_franka_play_dataset': { |
| 'euler': { |
| 'max': [3.141542687188373, 1.570464569214435, 3.1412348640834002], |
| 'mean': [1.2730823611033315, 0.6681926368695384, 1.141188695183881], |
| 'min': [-3.141417902437913, -0.6159773949821361, -3.1401807914238264], |
| 'q01': [-3.1053165771573177, -0.18849691525500664, -2.8903965648285523], |
| 'q99': [3.1156105211359257, 1.5443502891290082, 2.882089687976354], |
| 'std': [1.519365023922235, 0.5398198304496786, 1.099335476495627], |
| }, |
| 'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]}, |
| 'joints': { |
| 'max': [ |
| 0.873094916343689, |
| 1.244848370552063, |
| 2.6604084968566895, |
| -0.46481436491012573, |
| 1.409853458404541, |
| 3.5058159828186035, |
| 2.2937142848968506, |
| ], |
| 'mean': [ |
| -0.7650318741798401, |
| -0.9744310975074768, |
| 1.4849265813827515, |
| -2.0917000770568848, |
| -0.42907631397247314, |
| 2.4454119205474854, |
| -0.04509185999631882, |
| ], |
| 'min': [ |
| -2.1353442668914795, |
| -1.6306616067886353, |
| -0.0167933851480484, |
| -2.8458828926086426, |
| -2.5256054401397705, |
| 0.5101203322410583, |
| -2.1178665161132812, |
| ], |
| 'q01': [ |
| -1.9259787797927856, |
| -1.537291407585144, |
| 0.4664345383644104, |
| -2.7214717864990234, |
| -2.114818811416626, |
| 1.0703176259994507, |
| -1.8231183290481567, |
| ], |
| 'q99': [ |
| 0.19280724227428436, |
| 0.7752904295921326, |
| 2.6163666248321533, |
| -0.8622008562088013, |
| 0.7001188397407532, |
| 3.4607036113739014, |
| 1.7281100749969482, |
| ], |
| 'std': [ |
| 0.4459172785282135, |
| 0.5925542712211609, |
| 0.45783039927482605, |
| 0.4001038372516632, |
| 0.6094018816947937, |
| 0.5889745950698853, |
| 0.7601358294487, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.6258905727651867, 0.8289460146095029, 0.8867425967452327], |
| 'mean': [0.38785427731316113, 0.39551342608233353, 0.45161485937851], |
| 'min': [0.05690113252698781, -0.06171031940925212, 0.1157533700659999], |
| 'q01': [0.1393695951374117, 0.07645522416447147, 0.19364508601310984], |
| 'q99': [0.5920727118835839, 0.6584802280677562, 0.8056891771281188], |
| 'std': [0.10666290139662009, 0.1217948547812865, 0.1707295181003299], |
| }, |
| }, |
| 'roboset': { |
| 'euler': { |
| 'max': [3.1415449294818236, 1.5705575529715636, 3.141527342124582], |
| 'mean': [-0.0398455755412464, 1.0518070390619125, -0.015345692503002759], |
| 'min': [-3.1415813300509536, -1.5222832468962035, -3.141575300866071], |
| 'q01': [-2.9414386317311187, -0.24976770655101155, -2.985256521212579], |
| 'q99': [2.9380437893235993, 1.5403010739503078, 2.9746912523985025], |
| 'std': [1.7866587696177456, 0.40620530263065, 1.7288511340250616], |
| }, |
| 'gripper': { |
| 'max': [0.83056640625], |
| 'min': [0.0001499652862548828], |
| 'q01': [0.0001499652862548828], |
| 'q99': [0.82666015625], |
| }, |
| 'joints': { |
| 'max': [ |
| 0.96240234375, |
| 1.1162109375, |
| 1.1064453125, |
| -0.98095703125, |
| 2.30859375, |
| 1.576171875, |
| 1.7412109375, |
| ], |
| 'mean': [ |
| 0.005913593806326389, |
| 0.1877261847257614, |
| 0.04653879255056381, |
| -2.0529513359069824, |
| -0.011298442259430885, |
| 0.6185526251792908, |
| -0.01701134257018566, |
| ], |
| 'min': [ |
| -0.8330078125, |
| -0.74658203125, |
| -0.8642578125, |
| -2.892578125, |
| -1.390625, |
| -0.24658203125, |
| -2.953125, |
| ], |
| 'q01': [ |
| -0.41015625, |
| -0.5302734375, |
| -0.6455078125, |
| -2.57421875, |
| -0.76416015625, |
| -0.0386962890625, |
| -1.435546875, |
| ], |
| 'q99': [ |
| 0.66455078125, |
| 0.9501953125, |
| 0.7529296875, |
| -1.251953125, |
| 0.75244140625, |
| 1.2314453125, |
| 1.384765625, |
| ], |
| 'std': [ |
| 0.17915399372577667, |
| 0.32234326004981995, |
| 0.26069700717926025, |
| 0.31767210364341736, |
| 0.205329030752182, |
| 0.33385637402534485, |
| 0.6263682842254639, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.5747738480567932, 0.3972920775413513, 0.7443570494651794], |
| 'mean': [0.3331542909145355, 0.019357483834028244, 0.37330344319343567], |
| 'min': [0.09978063404560089, -0.29593944549560547, 0.10065606236457825], |
| 'q01': [0.18437016010284424, -0.25699371099472046, 0.15134164690971375], |
| 'q99': [0.543661892414093, 0.29646238684654236, 0.6682320833206177], |
| 'std': [0.07849054038524628, 0.12241040915250778, 0.1460595279932022], |
| }, |
| }, |
| 'roboturk': { |
| 'euler': { |
| 'max': [3.1415925701602854, 1.5661525056538665, 3.1412803735012553], |
| 'mean': [0.3946147188969296, -0.09657834247810235, -0.1840393277985236], |
| 'min': [-3.141592645423486, -1.568282806779776, -3.141449561492499], |
| 'q01': [-3.1378020570699525, -0.7907563562327732, -1.7347170410441308], |
| 'q99': [3.137841563084971, 0.5773179844053413, 1.455226678618167], |
| 'std': [2.9453717386817755, 0.2466574231368262, 0.6146181899264712], |
| }, |
| 'gripper': { |
| 'max': [0.041667], |
| 'min': [-1.6931189781743683e-05], |
| 'q01': [0.0], |
| 'q99': [0.040625325000000004], |
| }, |
| 'joints': { |
| 'max': [ |
| 3.0592832565307617, |
| 2.9869844913482666, |
| 0.33399999141693115, |
| 2.0810000896453857, |
| 3.874000072479248, |
| 0.21199999749660492, |
| 32.02399826049805, |
| ], |
| 'mean': [ |
| 0.06695421040058136, |
| 0.4899406135082245, |
| -0.0009486731723882258, |
| -0.00033266679383814335, |
| -0.000789206416811794, |
| 0.017417440190911293, |
| -4.154728412628174, |
| ], |
| 'min': [ |
| -2.995668888092041, |
| -2.9882216453552246, |
| -1.3760000467300415, |
| -1.9279999732971191, |
| -4.10099983215332, |
| -0.5040000081062317, |
| -23.58799934387207, |
| ], |
| 'q01': [ |
| -1.1723066568374634, |
| -1.1159032583236694, |
| -0.00800000037997961, |
| -0.3580000102519989, |
| -0.5680000185966492, |
| -0.10000000149011612, |
| -14.303999900817871, |
| ], |
| 'q99': [ |
| 1.1788837909698486, |
| 1.7529590129852295, |
| 0.007000000216066837, |
| 0.3799999952316284, |
| 0.628000020980835, |
| 0.13199999928474426, |
| 7.5279998779296875, |
| ], |
| 'std': [ |
| 0.46230366826057434, |
| 0.5157809257507324, |
| 0.0024148367810994387, |
| 0.11895754188299179, |
| 0.18901629745960236, |
| 0.050090208649635315, |
| 4.478224754333496, |
| ], |
| }, |
| 'translation': { |
| 'max': [1.0403562763478107, 0.8124313436278997, 0.9670890184966487], |
| 'mean': [0.601758638436076, -0.0011977902645611068, 0.061136336184971524], |
| 'min': [-0.375370271681523, -0.9602276639048171, -0.39440951528330676], |
| 'q01': [0.2845426588179575, -0.3288349528691964, -0.0934955147249334], |
| 'q99': [0.8773894592247552, 0.28575230587640144, 0.3286392635131121], |
| 'std': [0.12719404213832913, 0.1476650170975746, 0.09248325352628932], |
| }, |
| }, |
| 'stanford_hydra_dataset': { |
| 'euler': { |
| 'max': [3.141592357591889, 1.5364991593031068, 3.141222461465462], |
| 'mean': [-0.2789889021491306, -0.19247585545209975, 0.34538177027867784], |
| 'min': [-3.1415919781981962, -1.5573291068141568, -3.1413972494043216], |
| 'q01': [-3.138796116403405, -1.107955818954808, -2.271278880707407], |
| 'q99': [3.1386728135975384, 0.915868569686168, 2.4011245578416336], |
| 'std': [2.883189482433683, 0.43362190146795254, 1.199450318648319], |
| }, |
| 'gripper': { |
| 'max': [0.0811397060751915], |
| 'min': [-0.0005391233135014772], |
| 'q01': [1.3133333141013281e-06], |
| 'q99': [0.08100640028715134], |
| }, |
| 'joints': { |
| 'max': [ |
| 1.9801124334335327, |
| 1.5752310752868652, |
| 2.6661651134490967, |
| -0.6274839639663696, |
| 2.3508076667785645, |
| 3.4612932205200195, |
| 2.747223377227783, |
| ], |
| 'mean': [ |
| -0.052988938987255096, |
| -0.19803331792354584, |
| -0.046565085649490356, |
| -2.395282506942749, |
| 0.4740530848503113, |
| 2.119305372238159, |
| 0.031005585566163063, |
| ], |
| 'min': [ |
| -2.5843772888183594, |
| -1.287919044494629, |
| -1.9306836128234863, |
| -2.9673542976379395, |
| -1.9528191089630127, |
| 0.5680276155471802, |
| -2.74257755279541, |
| ], |
| 'q01': [ |
| -1.0905369520187378, |
| -0.9110571146011353, |
| -0.9165626764297485, |
| -2.876581907272339, |
| -0.9029546976089478, |
| 1.3720064163208008, |
| -2.0172016620635986, |
| ], |
| 'q99': [ |
| 0.6443323493003845, |
| 1.0052685737609863, |
| 0.9653106927871704, |
| -1.4030946493148804, |
| 1.738037109375, |
| 2.943267822265625, |
| 2.451101541519165, |
| ], |
| 'std': [ |
| 0.3502423167228699, |
| 0.41914135217666626, |
| 0.39895766973495483, |
| 0.31860536336898804, |
| 0.5207170248031616, |
| 0.35328012704849243, |
| 1.2519474029541016, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.7598075952006731, 0.36410959179102775, 0.597508369211053], |
| 'mean': [0.4348564561896346, 0.024563536690147474, 0.2938664975066919], |
| 'min': [0.18414022283711137, -0.32311685510034516, 0.017172937739480726], |
| 'q01': [0.2373728557200433, -0.26521679456931635, 0.09069013461938308], |
| 'q99': [0.712423814640525, 0.25299056815993204, 0.49505407602925094], |
| 'std': [0.10465658310507645, 0.12745896408828714, 0.10379447506049931], |
| }, |
| }, |
| 'taco_play': { |
| 'euler': { |
| 'max': [3.1415891647338867, 0.30717557668685913, 2.4624788761138916], |
| 'mean': [-0.06622616201639175, -0.17697572708129883, 0.4614097774028778], |
| 'min': [-3.1415886878967285, -0.8683895468711853, -1.789320707321167], |
| 'q01': [-3.139000654220581, -0.6962019205093384, -1.1729414463043213], |
| 'q99': [3.1392500400543213, 0.12436603754758835, 1.8065968751907349], |
| 'std': [3.0384926795959473, 0.18692463636398315, 0.6806972622871399], |
| }, |
| 'gripper': { |
| 'max': [0.08074767142534256], |
| 'min': [9.652999870013446e-05], |
| 'q01': [0.00015234666352625936], |
| 'q99': [0.08067675679922104], |
| }, |
| 'joints': { |
| 'max': [ |
| 0.9704633951187134, |
| 0.4390600025653839, |
| 0.9892600178718567, |
| -1.116769790649414, |
| 1.2800638675689697, |
| 2.894169330596924, |
| 2.627519130706787, |
| ], |
| 'mean': [ |
| 0.14954668283462524, |
| -0.3514331877231598, |
| 0.1295831948518753, |
| -2.205498456954956, |
| 0.11473798751831055, |
| 2.0389058589935303, |
| 0.5516068339347839, |
| ], |
| 'min': [ |
| -0.6495265960693359, |
| -1.2877551317214966, |
| -1.5125423669815063, |
| -2.956650733947754, |
| -1.0317779779434204, |
| 1.1109415292739868, |
| -1.8966686725616455, |
| ], |
| 'q01': [ |
| -0.4596467614173889, |
| -0.9323632121086121, |
| -0.6153853535652161, |
| -2.8417088985443115, |
| -0.3472142219543457, |
| 1.428541898727417, |
| -0.5756277441978455, |
| ], |
| 'q99': [ |
| 0.7262046933174133, |
| 0.21787810325622559, |
| 0.7387099266052246, |
| -1.3517359495162964, |
| 0.7581852674484253, |
| 2.564530849456787, |
| 1.6355704069137573, |
| ], |
| 'std': [ |
| 0.281919926404953, |
| 0.2702403962612152, |
| 0.3334408402442932, |
| 0.412654846906662, |
| 0.21815766394138336, |
| 0.2808808982372284, |
| 0.4732993543148041, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.7124971151351929, 0.6118948459625244, 0.617118775844574], |
| 'mean': [0.37278515100479126, 0.13003520667552948, 0.38341864943504333], |
| 'min': [0.07693906873464584, -0.4944135844707489, 0.20030911266803741], |
| 'q01': [0.1368357390165329, -0.4297449290752411, 0.20516259968280792], |
| 'q99': [0.6700438857078552, 0.5943909883499146, 0.5966404676437378], |
| 'std': [0.11509375274181366, 0.2694733738899231, 0.11266136914491653], |
| }, |
| }, |
| 'toto': { |
| 'euler': { |
| 'max': [3.1414193360696108, 1.570402878730552, 3.1415549880169924], |
| 'mean': [-1.31124856158284, 0.8107744638901545, 1.042033586691101], |
| 'min': [-3.1415814091974665, -1.3210638993411208, -3.141515215393105], |
| 'q01': [-2.9459294143675696, -0.6449196412662773, -2.9790265961739872], |
| 'q99': [2.824635320375007, 1.4897934376557762, 3.0083314049108925], |
| 'std': [1.2205361480020216, 0.45724967725161403, 1.3644152538580079], |
| }, |
| 'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]}, |
| 'joints': { |
| 'max': [ |
| 1.4926575422286987, |
| 0.7715681195259094, |
| 1.9506583213806152, |
| -1.6330622434616089, |
| 2.804656744003296, |
| 2.0850095748901367, |
| 1.9661481380462646, |
| ], |
| 'mean': [ |
| -0.040506213903427124, |
| -0.2537725269794464, |
| 0.05389903113245964, |
| -2.5135483741760254, |
| -0.161987766623497, |
| 0.8815387487411499, |
| 0.0034975146409124136, |
| ], |
| 'min': [ |
| -1.8345723152160645, |
| -1.7133995294570923, |
| -1.40083646774292, |
| -2.985748767852783, |
| -2.814244270324707, |
| -0.5165338516235352, |
| -3.5424435138702393, |
| ], |
| 'q01': [ |
| -1.4044159650802612, |
| -1.3978556394577026, |
| -0.8029389381408691, |
| -2.97342586517334, |
| -1.5402026176452637, |
| 0.0011001524981111288, |
| -2.9709367752075195, |
| ], |
| 'q99': [ |
| 0.7075999975204468, |
| 0.42536962032318115, |
| 1.3345959186553955, |
| -1.8281668424606323, |
| 2.5240790843963623, |
| 2.0835044384002686, |
| 1.212519884109497, |
| ], |
| 'std': [ |
| 0.3405684530735016, |
| 0.3766334354877472, |
| 0.40128207206726074, |
| 0.3202400505542755, |
| 0.6421889066696167, |
| 0.33138400316238403, |
| 0.5801302194595337, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.7777318005382999, 0.5688841397096768, 0.672235403102292], |
| 'mean': [0.15078069344893, -0.03012185632654897, 0.35519823701585035], |
| 'min': [-0.18532184496861265, -0.6282599632853518, 0.06312318086547261], |
| 'q01': [-0.09177927811484686, -0.3571658956454935, 0.21965464356283362], |
| 'q99': [0.6757593680120374, 0.2889021640318381, 0.501109413161318], |
| 'std': [0.14726563054735392, 0.13472763062259546, 0.059123705378892395], |
| }, |
| }, |
| 'ucsd_kitchen_dataset': { |
| 'euler': { |
| 'max': [3.141592644655991, 0.10471752134072587, 2.0420500160887434], |
| 'mean': [-2.349219123990161, -0.5844573008678524, 0.8632842703198627], |
| 'min': [-3.141592637838787, -1.5533486027935484, -1.676143206926509], |
| 'q01': [-3.1415921825947537, -1.535883315130894, -0.6806766312742327], |
| 'q99': [3.1328677560237566, 0.01745363977705911, 1.8326023938995777], |
| 'std': [1.756197370631558, 0.482005280180061, 0.7151569996653978], |
| }, |
| 'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]}, |
| 'joints': { |
| 'max': [ |
| 2.4945411682128906, |
| 1.0603798627853394, |
| 0.5518655776977539, |
| 2.71895432472229, |
| 1.9269907474517822, |
| 2.2231369018554688, |
| 0.7844778895378113, |
| ], |
| 'mean': [ |
| 0.35797572135925293, |
| -0.23473882675170898, |
| -0.34185951948165894, |
| 0.8982798457145691, |
| 0.5780586004257202, |
| 1.0351765155792236, |
| -1.2190868854522705, |
| ], |
| 'min': [ |
| -0.6943671703338623, |
| -1.6351218223571777, |
| -2.0142624378204346, |
| 0.00768359424546361, |
| -0.17713193595409393, |
| 0.011292487382888794, |
| -2.6592507362365723, |
| ], |
| 'q01': [ |
| -0.43497809767723083, |
| -1.4843493700027466, |
| -1.7824348211288452, |
| 0.0955040231347084, |
| -0.07186625152826309, |
| 0.27948257327079773, |
| -2.4006130695343018, |
| ], |
| 'q99': [ |
| 2.25881028175354, |
| 0.8190488815307617, |
| 0.2994879484176636, |
| 2.153681993484497, |
| 1.6126199960708618, |
| 1.8523634672164917, |
| 0.06944511085748672, |
| ], |
| 'std': [ |
| 0.6131777167320251, |
| 0.6322475075721741, |
| 0.5427954792976379, |
| 0.539089024066925, |
| 0.43294647336006165, |
| 0.3588581085205078, |
| 0.6916991472244263, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.681192194250889, 0.2323359966361606, 0.6159547986373596], |
| 'mean': [0.4067338752307897, 0.027551073758015975, 0.313373054289359], |
| 'min': [0.13453314224525192, -0.2553313745227843, 0.03515888153950328], |
| 'q01': [0.18739915591793913, -0.18234310016370514, 0.048970692427190994], |
| 'q99': [0.6410437631246786, 0.20632224240340083, 0.5983893391276117], |
| 'std': [0.12446715882142037, 0.08252308025761342, 0.11421020385340024], |
| }, |
| }, |
| 'utaustin_mutex': { |
| 'euler': { |
| 'max': [3.1415927410125732, 0.5760129690170288, 0.6430304050445557], |
| 'mean': [1.643947720527649, 0.03388536348938942, -0.31948330998420715], |
| 'min': [-3.141592025756836, -0.9390894770622253, -1.8157784938812256], |
| 'q01': [-3.1398682594299316, -0.7223533391952515, -1.5468647480010986], |
| 'q99': [3.140272855758667, 0.3558599054813385, 0.3583984971046448], |
| 'std': [2.2715108394622803, 0.19258549809455872, 0.5275800824165344], |
| }, |
| 'gripper': { |
| 'max': [0.07569462060928345], |
| 'min': [0.00026726332725957036], |
| 'q01': [0.0019378233700990677], |
| 'q99': [0.07565194368362427], |
| }, |
| 'joints': { |
| 'max': [ |
| 0.5831676125526428, |
| 0.8024097681045532, |
| 1.350082516670227, |
| -1.1145747900009155, |
| 0.5345654487609863, |
| 3.2108523845672607, |
| 2.875584363937378, |
| ], |
| 'mean': [ |
| -0.09772995859384537, |
| 0.05997378006577492, |
| -0.009316973388195038, |
| -2.3428714275360107, |
| -0.48651641607284546, |
| 2.2390215396881104, |
| 1.3341773748397827, |
| ], |
| 'min': [ |
| -0.6990927457809448, |
| -0.8311276435852051, |
| -0.5323343276977539, |
| -2.9861886501312256, |
| -2.060809850692749, |
| 0.826747477054596, |
| 0.05972049757838249, |
| ], |
| 'q01': [ |
| -0.532222330570221, |
| -0.5078368782997131, |
| -0.3437865376472473, |
| -2.784912586212158, |
| -1.8685580492019653, |
| 1.3974121809005737, |
| 0.5157660841941833, |
| ], |
| 'q99': [ |
| 0.35958197712898254, |
| 0.6339818835258484, |
| 0.5833610892295837, |
| -1.612648367881775, |
| 0.2780923545360565, |
| 2.8907132148742676, |
| 2.837320327758789, |
| ], |
| 'std': [ |
| 0.2101736217737198, |
| 0.2880096137523651, |
| 0.17824982106685638, |
| 0.23544654250144958, |
| 0.6387099027633667, |
| 0.33533376455307007, |
| 0.5090957880020142, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.581696093082428, 0.4334338903427124, 0.6800903081893921], |
| 'mean': [0.4331520199775696, -0.10096638649702072, 0.22044037282466888], |
| 'min': [0.2454746514558792, -0.502901017665863, 0.008792983368039131], |
| 'q01': [0.3217194080352783, -0.4733337163925171, 0.014122226275503635], |
| 'q99': [0.5321439504623413, 0.3733823001384735, 0.5785381197929382], |
| 'std': [0.04380597919225693, 0.20559938251972198, 0.12614832818508148], |
| }, |
| }, |
| 'viola': { |
| 'euler': { |
| 'max': [3.141592502593994, 0.9555190801620483, 0.40456175804138184], |
| 'mean': [-0.7171992063522339, 0.07086259871721268, -0.67817223072052], |
| 'min': [-3.141592025756836, -0.4305531978607178, -2.212939739227295], |
| 'q01': [-3.1400232315063477, -0.19947049021720886, -1.8703945875167847], |
| 'q99': [3.1401824951171875, 0.7830920219421387, 0.19578109681606293], |
| 'std': [2.902841567993164, 0.20232577621936798, 0.7306717038154602], |
| }, |
| 'gripper': { |
| 'max': [0.07748929411172867], |
| 'min': [-5.3190000471659005e-05], |
| 'q01': [0.00020356666937004775], |
| 'q99': [0.07747221738100052], |
| }, |
| 'joints': { |
| 'max': [ |
| 0.3366159200668335, |
| 0.8983011841773987, |
| 0.330204576253891, |
| 0.0, |
| 1.3329579830169678, |
| 3.1103708744049072, |
| 2.8318748474121094, |
| ], |
| 'mean': [ |
| 0.027196571230888367, |
| 0.2290346622467041, |
| -0.06033482402563095, |
| -2.060992479324341, |
| 0.25734394788742065, |
| 2.2569527626037598, |
| 1.2675725221633911, |
| ], |
| 'min': [ |
| -0.4073199927806854, |
| -0.23913456499576569, |
| -0.5244941115379333, |
| -2.744713068008423, |
| -0.48249971866607666, |
| 0.0, |
| -0.22237232327461243, |
| ], |
| 'q01': [ |
| -0.26899582147598267, |
| -0.189721941947937, |
| -0.3926161825656891, |
| -2.620305061340332, |
| -0.199931338429451, |
| 1.6743353605270386, |
| 0.1440846174955368, |
| ], |
| 'q99': [ |
| 0.23294022679328918, |
| 0.764011561870575, |
| 0.20191603899002075, |
| -1.5629768371582031, |
| 1.039290189743042, |
| 2.9412121772766113, |
| 2.798232078552246, |
| ], |
| 'std': [ |
| 0.10160040855407715, |
| 0.21415969729423523, |
| 0.14444759488105774, |
| 0.27077943086624146, |
| 0.3251941502094269, |
| 0.338652640581131, |
| 0.7529342174530029, |
| ], |
| }, |
| 'translation': { |
| 'max': [0.6868156790733337, 0.20999883115291595, 0.5037735104560852], |
| 'mean': [0.5464076995849609, 0.006265466101467609, 0.23136523365974426], |
| 'min': [0.0, -0.33860740065574646, 0.0], |
| 'q01': [0.40061360597610474, -0.25196850299835205, 0.010269512422382832], |
| 'q99': [0.6458418369293213, 0.17776551842689514, 0.4456312954425812], |
| 'std': [0.062116123735904694, 0.12932434678077698, 0.13205111026763916], |
| }, |
| }, |
| } |
|
|
@cached_property
def _control_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
    """
    Load per-dataset control statistics from `self.config.control_stats_path` and select
    the entry matching the configured control horizon.

    Returns:
        Mapping of dataset name -> control key (e.g. 'translation') -> statistic name
        (e.g. 'q01') -> per-dimension floats. Empty dict when both rotation and
        translation use global normalization (no per-dataset stats needed).

    Raises:
        ValueError: if the stats file has no entry for the configured horizon.
        NotImplementedError: for multi-step control sequences without a stride.
    """
    if is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.translation_norm):
        return {}
    with open(self.config.control_stats_path, 'r') as file:
        stats = yaml.safe_load(file)
    stride_sec = self.control_io_config.future_controls_sequence_stride_sec
    if self.config.delta_controls:
        # For delta controls only the per-step horizon matters.
        horizon = 0.0 if stride_sec is None else stride_sec
    elif stride_sec is None:
        if self.control_io_config.future_controls_sequence_length == 1:
            horizon = 0.0
        else:
            # A multi-step absolute-control sequence without a stride has no
            # well-defined horizon.
            raise NotImplementedError()
    else:
        horizon = (
            self.control_io_config.future_controls_sequence_length * stride_sec
        )
    key = f'horizon_{round(horizon, 2)}s'
    if key not in stats:
        # NOTE: previously this message read `self.config.control_io_config...`
        # (inconsistent with the body's `self.control_io_config`) and formatted
        # `stats.keys()` as a raw dict_keys repr; both fixed here.
        raise ValueError(
            f'Missing control statistics key {key} for '
            f'future_controls_sequence_length={self.control_io_config.future_controls_sequence_length} '
            f'future_controls_sequence_stride_sec={stride_sec}. '
            f'Available keys: {sorted(stats)}'
        )
    return stats[key]
|
|
@cached_property
def dataset_names(self) -> List[str]:
    """Dataset names that carry per-dataset stats; ['ANY'] when every norm is global."""
    norms = (
        self.config.rotation_norm,
        self.config.obs_rotation_norm,
        self.config.translation_norm,
        self.config.obs_translation_norm,
    )
    if all(is_global_norm(norm) for norm in norms):
        return ['ANY']
    # Union of datasets appearing in either control or observation statistics.
    known = set(self._control_stats) | set(self._observation_stats)
    return list(known)
|
|
|
|
def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Convert per-step (delta) translations into translations relative to the base frame.

    Each input vector is the offset from the PREVIOUS frame in the sequence; the output
    expresses every vector w.r.t. the implicit 0-th (base) frame preceding the sequence.
    Ex:
        Sequence of points: T1, T2, T3, T4
        Input vectors: T0T1, T1T2, T2T3, T3T4, where T0 is the base frame,
        implicitly encoded in T0T1
        Output: T0T1, T0T2, T0T3, T0T4

    Args:
        translation_sequence: torch.Tensor of shape [..., S, 3], containing the
            translation vectors, where S corresponds to the sequence dimension
    Returns:
        torch.Tensor of the same shape, with base-relative translations
    """
    assert translation_sequence.ndim >= 3, translation_sequence.shape
    # Accumulating deltas along the sequence axis yields base-relative vectors.
    return translation_sequence.cumsum(dim=-2)
|
|
|
|
class RegressionProcessor(VLAMProcessor[RegressionProcessorConfig]):
    """Processor that maps regression-head targets/outputs to executable control plans."""

    def policy_control_plan_from_model_target(
        self, target: RoboticsTarget, dataset_name: np.ndarray
    ) -> RoboticsControlPlan:
        """Build a control plan from ground-truth regression targets."""
        translation = self.unnormalize(target.translation, dataset_name=dataset_name, key='translation')
        rotmat = convert_rotation(
            self.unnormalize(target.rotation, dataset_name=dataset_name, key='rotation'),
            RotationFormat.ROTMAT,
        )
        if self.config.delta_controls:
            # Targets are per-step deltas; re-express them w.r.t. the base frame.
            translation = delta_to_relative_translations(translation)
            rotmat = delta_to_relative_rotations(rotmat)
        return RoboticsControlPlan(
            translation_m=translation,
            rotmat=rotmat,
            gripper_prob=target.gripper,
            valid_mask=target.valid_mask,
        )

    def policy_control_plan_from_model_output(
        self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor
    ) -> RoboticsControlPlan:
        """Called during inference to create control plan from model output"""
        translation = self.unnormalize(
            model_output.translation, dataset_name=dataset_name, key='translation'
        )
        rotmat = convert_rotation(
            self.unnormalize(model_output.rotation, dataset_name=dataset_name, key='rotation'),
            RotationFormat.ROTMAT,
            autonorm=True,  # predicted rotations may be slightly off-manifold
        )
        gripper_prob = torch.sigmoid(model_output.gripper)
        if self.config.delta_controls:
            translation = delta_to_relative_translations(translation)
            rotmat = delta_to_relative_rotations(rotmat)
        return RoboticsControlPlan(
            translation_m=translation, rotmat=rotmat, gripper_prob=gripper_prob, valid_mask=valid_mask
        )
|
|
|
|
class VLARMProcessor(RegressionProcessor):
    """
    Regression processor that pairs a VLM processor (text + images) with robotics
    state preprocessing (end-effector pose, gripper, joints).
    """

    def __init__(self, config: VLARMProcessorConfig, vlm_processor: VLMProcessor):
        # NOTE(review): calls Configurable.__init__ directly rather than
        # super().__init__ — presumably to skip RegressionProcessor's own
        # initialization; confirm this is intentional.
        Configurable.__init__(self, config)
        self.vlm_processor = vlm_processor
        # Control is regressed, not tokenized as text, hence the EmptyTokenizer.
        self.control_tokenizer = EmptyTokenizer(
            config=self.config.control_tokenizer_config, tokenizer=self.tokenizer
        )
        # Normalization bounds keyed by the `key=` values used with
        # normalize()/unnormalize() throughout this class.
        self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {
            'obs_translation': self.obs_translation_norm_bounds,
            'obs_rotation': self.obs_rotation_norm_bounds,
            'translation': self.translation_norm_bounds,
            'rotation': self.rotation_norm_bounds,
            'joints': self.joints_norm_bounds,
        }

    @property
    def dataset_np_names(self) -> np.ndarray:
        # Dataset names used to select per-dataset normalization stats at inference.
        return self._dataset_np_names

    def _set_dataset_np_names(self, dataset_np_names: np.ndarray) -> None:
        self._dataset_np_names = dataset_np_names

    def preprocess_inputs(
        self,
        chat: List[str],
        images: Dict[str, List[PIL.Image.Image]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool,
        control_target: Optional[RoboticsTarget] = None,
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Tokenize the chat + images via the wrapped VLM processor and normalize the
        robot state into model inputs.

        Note: `inference_mode` and `control_target` are accepted for interface
        compatibility but unused by this implementation.
        """
        inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
        # From here on `images` refers to preprocessed per-camera tensors.
        images: Dict[str, torch.Tensor] = inputs['images']
        # Truncate token sequences to the tokenizer's maximum context length.
        input_ids: torch.Tensor = inputs['input_ids'][..., : self.tokenizer.model_max_length]
        target_text_tokens_ids: torch.Tensor = inputs['target_ids'][..., : self.tokenizer.model_max_length]
        # All token positions are valid (no padding is introduced at this stage).
        attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)
        ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32)
        ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32)
        # Convert the observed rotation into the configured representation,
        # normalizing it onto the rotation manifold first.
        ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True)
        # Dataset-specific gripper preprocessing (conventions differ per dataset).
        gripper = preprocess_gripper_observation(gripper, dataset_name)
        gripper = torch.tensor(gripper, dtype=torch.float32)
        ee_pose_translation = self.normalize(
            ee_pose_translation, dataset_name=dataset_name, key='obs_translation'
        )
        ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key='obs_rotation')
        joints = torch.tensor(joints, dtype=torch.float32)
        # Zero-pad to 7 joint values so all embodiments share one fixed layout.
        if joints.shape[-1] < 7:
            missing_size = 7 - joints.shape[-1]
            joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1)
        joints = self.normalize(joints, dataset_name=dataset_name, key='joints')
        return {
            'images': images,
            'input_ids': input_ids,
            'target_text_tokens_ids': target_text_tokens_ids,
            'attn_mask': attn_mask,
            'ee_pose_translation': ee_pose_translation,
            'ee_pose_rotation': ee_pose_rotation,
            'gripper': gripper,
            'joints': joints,
            # No text-encoded control for the regression head.
            'control_tokens_ids': None,
            'target_control_tokens_ids': None,
        }

    def model_otoi_but_gt_gripper(
        self,
        translation_output: torch.Tensor,
        rotation_output: torch.Tensor,
        reference_translation: torch.Tensor,
        reference_rotation: torch.Tensor,
        gt_gripper: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Called during inference to create control plan from model output"""
        # Un-normalize the predicted deltas and the reference (observed) pose.
        translation_delta = self.unnormalize(
            translation_output, dataset_name=self.dataset_np_names, key='translation'
        )
        rotation_delta = self.unnormalize(rotation_output, dataset_name=self.dataset_np_names, key='rotation')
        translation_reference = self.unnormalize(
            reference_translation, dataset_name=self.dataset_np_names, key='obs_translation'
        )
        rotation_reference = self.unnormalize(
            reference_rotation, dataset_name=self.dataset_np_names, key='obs_rotation'
        )
        # Compose delta + reference into the world frame, then re-normalize the
        # result as the next observation.
        translation = translation_reference + translation_delta
        rotation = delta_to_world_rotations(rotation_delta, rotation_reference)
        translation_next = self.normalize(
            translation, dataset_name=self.dataset_np_names, key='obs_translation'
        )
        rotation_next = self.normalize(rotation, dataset_name=self.dataset_np_names, key='obs_rotation')
        # Gripper is taken from ground truth rather than the model prediction.
        gripper_next = gt_gripper
        return translation_next, rotation_next, gripper_next
|
|
|
|
def make_causal_mask(shape: Sequence[int]) -> torch.Tensor:
    """
    Build a boolean causal attention mask of shape `shape`
    Args:
        shape: Shape of the output mask, the last two dimensions correspond to [query_seq_len, kv_seq_len]
    Returns:
        torch.Tensor of dtype torch.bool. False values indicate that the row (i.e. query) can't attend
        to the corresponding column (i.e. key)

    Example:
        shape = (3, 5) -> Mask the upper triangular part
        [
            [ 1, 0, 0, 0, 0],
            [ 1, 1, 0, 0, 0],
            [ 1, 1, 1, 0, 0]
        ]
    """
    # A query at position q may attend to keys at positions k <= q.
    query_pos = torch.arange(shape[-2]).unsqueeze(-1)
    key_pos = torch.arange(shape[-1]).unsqueeze(0)
    causal = key_pos <= query_pos
    # Broadcast the 2D pattern over any leading (batch) dimensions.
    return causal.expand(list(shape)).clone()
|
|
|
|
def enable_full_attn_blocks(attn_mask: torch.Tensor, full_attn: torch.Tensor) -> torch.Tensor:
    """
    Enable full bi-directional attention in `attn_mask` inside specific blocks
    Args:
        attn_mask: Existing attention mask of shape [..., query_seq_len, kv_seq_len] and dtype torch.bool
            where False values indicate disabled attention
        full_attn: torch.Tensor of shape [query_seq_len], dtype torch.bool. Blocks of True values indicate
            positions where full bi-directional attention should be enabled
    Returns:
        attn_mask with bi-directional attention enabled inside each contiguous run of
        True values in `full_attn`

    Example:
        1, 0, 0, 0, 0, 0, 0, 0,        1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 0, 0, 0, 0, 0, 0,        1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 0, 0, 0, 0, 0,        1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 0,   ->   1, 1, 1, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 0, 0, 0,        1, 1, 1, 1, 1, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 0, 0,        1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0,        1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1,        1, 1, 1, 1, 1, 1, 1, 1,

    """
    assert full_attn.dtype == torch.bool, full_attn.dtype
    assert full_attn.ndim == 1, full_attn.shape
    assert full_attn.shape[0] == attn_mask.shape[-2], f'{full_attn.shape[0]}, {attn_mask.shape}'
    if attn_mask.shape[-1] != attn_mask.shape[-2]:
        raise NotImplementedError('Only self-attention supported right now.')
    # Pairs (i, j) where BOTH positions are marked for full attention.
    x = full_attn.view(-1, 1) & full_attn.view(1, -1)
    x = x | make_causal_mask([full_attn.shape[0], full_attn.shape[0]])
    # Row-wise prefix product keeps, per row, only the leading contiguous run of
    # True values; intersecting with the transpose isolates symmetric contiguous
    # blocks (every position keeps at least its own diagonal entry).
    x = torch.cumprod(x, dim=1).to(dtype=torch.bool)
    x = x & x.permute(1, 0)
    # Singleton blocks (column sum == 1) at positions NOT requested in
    # `full_attn` only carry their causal diagonal; clear them so they add
    # nothing below. BUGFIX: the original expression
    # `torch.sum(x, dim=0) == 1 & ~full_attn` parsed as
    # `sum == (1 & ~full_attn)` because `&` binds tighter than `==`; it only
    # behaved correctly because column sums are never 0. Parenthesized to state
    # the actual intent.
    mask_positions = (torch.sum(x, dim=0) == 1) & ~full_attn
    mask_indices = torch.where(mask_positions)[0]
    x[mask_indices, mask_indices] = 0
    # Merge the block pattern into the existing mask, broadcasting over any
    # leading batch dimensions of attn_mask.
    attn_mask = attn_mask | expand_dims(x, ndim=attn_mask.ndim, order=[-1, 1, 1])
    return attn_mask
|
|
|
|
| IGNORE_INDEX = -100 |
|
|
|
|
class PaliGemmaProcessor(VLMProcessor[PaliGemmaProcessorConfig]):
    """
    VLM processor wrapping HuggingFace's PaliGemmaProcessor.

    Tokenizes chat turns, rewrites float bounding boxes into <locXXXX> tokens,
    prepends image placeholder tokens, preprocesses camera images, and builds a
    causal attention mask with full bi-directional attention over the
    non-target (prompt/image) blocks.
    """

    def __init__(
        self,
        config: PaliGemmaProcessorConfig,
        hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
        **kwargs,
    ):
        # Extra kwargs are accepted for constructor-interface compatibility and ignored.
        del kwargs
        super().__init__(config)
        self.hf_processor = hf_processor
        # Keep the HF processor's image settings in sync with our config; the
        # 'main' camera drives the HF-side defaults.
        self.hf_processor.image_processor.size = dict(self.config.image_sizes['main'].as_json())
        self.hf_processor.image_seq_length = self.config.num_image_tokens['main']
        self.hf_processor.image_processor.image_seq_length = self.config.num_image_tokens['main']
        self.bos_id: int = self.tokenizer.bos_token_id
        self.eos_id: int = self.tokenizer.eos_token_id
        # Newline separates consecutive chat turns (PaliGemma convention).
        self.sep_token = '\n'
        self.sep_id: int = self.tokenizer(
            self.sep_token, padding=False, add_special_tokens=False, return_attention_mask=False
        )['input_ids'][0]
        self.image_token_id: int = self.tokenizer(
            self.config.image_token, padding=False, add_special_tokens=False, return_attention_mask=False
        )['input_ids'][0]
        # One placeholder id per image token, summed across all cameras.
        self.image_tokens: list[int] = [self.image_token_id] * sum(self.config.num_image_tokens.values())
        # Matches "[x.x, y.y, z.z, w.w]" float bounding boxes embedded in text.
        self.bbox_pattern = re.compile(
            '\\[(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+)\\]'
        )

    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Based on PaliGemma paper https://arxiv.org/pdf/2407.07726 and example code at
        https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#create_model_inputs
        Chat must be always made of separate messages from user and model, always starting with user

        <image><image> ... <bos><instruction><sep><assistant><sep><instruction><sep><assistant>...<eos>

        Args:
            chat: List[str] of even size where each entry corresponds to a different turn in the conversation
            images: Dict[str, List[PIL.Image.Image]] where different cameras correspond to different keys
                in the Dict and the List corresponds to history of images

        Returns:
            Dict with 'input_ids', 'target_ids' (IGNORE_INDEX on non-target
            positions), per-camera 'images' tensors, and a boolean 'attn_mask'.
        """
        for key, value in images.items():
            if not isinstance(value, list):
                raise TypeError(f'Camera {key} contains values of type {type(value)} instead of list')
        (input_ids, target_ids) = ([], [])
        for i, text in enumerate(chat):
            # Strip separator characters and stray <image> placeholders from the
            # raw text so they cannot corrupt the token layout built below.
            text = text.replace(self.sep_token, ' ').replace('<image>', '')
            # Rewrite float bounding boxes into <locXXXX> location tokens.
            text = self.bbox_pattern.sub(self._bbox_to_loc_tokens, text)
            turn_input_ids: List[int] = self.tokenizer(
                text, padding=False, add_special_tokens=False, return_attention_mask=False
            )['input_ids']
            if i % 2 == 0:
                # Even turns are user messages: not prediction targets.
                turn_target_ids = [IGNORE_INDEX] * len(turn_input_ids)
            else:
                # Odd turns are model messages: predicted token-for-token.
                turn_target_ids = turn_input_ids
            if i != len(chat) - 1:
                # Separator between turns (not after the final turn).
                turn_input_ids = turn_input_ids + [self.sep_id]
                turn_target_ids = turn_target_ids + [IGNORE_INDEX]
            input_ids = input_ids + turn_input_ids
            target_ids = target_ids + turn_target_ids
        # Wrap with <bos> ... <eos>; <eos> itself is a prediction target.
        input_ids = [self.bos_id] + input_ids + [self.eos_id]
        target_ids = [IGNORE_INDEX] + target_ids + [self.eos_id]
        # Image placeholder tokens come first and are never prediction targets.
        image_tokens = self.image_tokens
        input_ids = image_tokens + input_ids
        target_ids = [IGNORE_INDEX] * len(image_tokens) + target_ids
        input_ids = torch.tensor(input_ids, dtype=torch.int64)
        target_ids = torch.tensor(target_ids, dtype=torch.int64)
        # Preprocess each camera's image history; keys are suffixed with the
        # vision-tower name ('.siglip').
        image_tensors: Dict[str, torch.Tensor] = {
            f'{camera_name}.siglip': self.hf_processor.image_processor(
                camera_images, size=self.config.image_sizes[camera_name].as_json(), return_tensors='pt'
            )['pixel_values']
            for (camera_name, camera_images) in images.items()
        }
        # Causal mask, then full bi-directional attention inside the contiguous
        # non-target (prompt/image) blocks.
        attn_mask = make_causal_mask([len(input_ids), len(input_ids)])
        attn_mask = enable_full_attn_blocks(attn_mask, full_attn=target_ids == IGNORE_INDEX)
        return {
            'input_ids': input_ids,
            'target_ids': target_ids,
            'images': image_tensors,
            'attn_mask': attn_mask,
        }

    @property
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        # Expose the tokenizer of the wrapped HF processor.
        return self.hf_processor.tokenizer

    @staticmethod
    def _bbox_to_loc_tokens(match: re.Match) -> str:
        """
        Replace a matched "[x.x, y.y, z.z, w.w]" bounding box with <locXXXX> tokens.

        Coordinates are scaled by 1024 and clipped to [0, 1023], per
        https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/
        """
        floats = list(map(float, match.groups()))
        transformed = [f'<loc{np.clip(round(num * 1024), 0, 1023):04d}>' for num in floats]
        return f"[{', '.join(transformed)}]"

    @property
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        # Per-camera image size configuration.
        return self.config.image_sizes
|
|