arvla-bridge / src /processing_vlarm.py
you2who's picture
Duplicate from you2who/paligemma-arvla-bridge
2672775
import collections
import collections.abc
import importlib
import re
import warnings
from abc import abstractmethod
from functools import cached_property
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar
import numpy as np
import PIL.Image
import roma
import torch
import torchvision.transforms.v2
import transformers
import yaml
try:
from .hf_compat import Configurable, Template
except Exception:
_hf_compat = importlib.import_module("src.hf_compat")
Configurable = _hf_compat.Configurable
Template = _hf_compat.Template
from .common_vlarm import (
Normalization,
ResizeMode,
RoboticsControlPlan,
RoboticsInput,
RoboticsOutput,
RoboticsTarget,
RotationFormat,
expand_dims,
)
from .configuration_vlarm import (
ControlDataIOConfig,
ControlTokenizerConfig,
EmptyTokenizerConfig,
ImageSizeConfig,
PaliGemmaProcessorConfig,
RegressionProcessorConfig,
VLAMProcessorConfig,
VLARMProcessorConfig,
VLMProcessorConfig,
)
ControlTokenizerConfigT = TypeVar('ControlTokenizerConfigT', bound=ControlTokenizerConfig)


class ControlTokenizer(Configurable[ControlTokenizerConfigT], Template[ControlTokenizerConfigT]):
    """Abstract base for objects that render control information as text for the prompt."""

    @abstractmethod
    def __call__(self, *args, **kwargs) -> str:
        """
        Given GT actions and possibly other information, produce the control text that gets
        appended to the prompt.
        """
        ...
class EmptyTokenizer(ControlTokenizer[EmptyTokenizerConfig]):
    """
    Control tokenizer that produces no control text: calling it always returns an empty string,
    so nothing is appended to the prompt.
    """

    def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None:
        super().__init__(config)
        # Stored on the instance but not used by this class.
        self.tokenizer = tokenizer

    def __call__(self, *args, **kwargs) -> str:
        """Ignore all positional and keyword arguments and return ''."""
        return ''
def np_unique(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Like `np.unique`, but keeps the unique values in order of FIRST APPEARANCE instead of sorted order.

    Plain `np.unique` sorts its output, so its index arrays refer to the sorted values, not to the
    original ordering of `data`.

    Returns:
        (unique_ids, indices_of_first_occurrence, inverse_indices, counts), all expressed in
        order-of-appearance terms.
    """
    # Pass 1: for every element, find the position at which its value first occurs in `data`.
    _, sorted_first_positions, sorted_inverse = np.unique(data, return_index=True, return_inverse=True)
    first_position_of_each_element = sorted_first_positions[sorted_inverse]
    # Pass 2: uniquing those positions sorts them, i.e. orders the values by first appearance.
    _, first_positions, appearance_inverse, occurrence_counts = np.unique(
        first_position_of_each_element,
        return_index=True,
        return_inverse=True,
        return_counts=True,
    )
    return data[first_positions], first_positions, appearance_inverse, occurrence_counts
def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor:
    """
    Convert Euler angles to rotation matrices.

    Args:
        angles: Euler angles in radians, 'xyz' convention (as interpreted by roma), shape [..., 3]
    Returns:
        torch.Tensor of rotation matrices, shape [..., 3, 3]
    """
    rotation_matrices = roma.euler_to_rotmat(convention='xyz', angles=angles, degrees=False)
    return rotation_matrices
def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor:
    """
    Convert Euler angles to unit quaternions.

    Args:
        angles: Euler angles in radians, 'xyz' convention (as interpreted by roma), shape [..., 3]
    Returns:
        torch.Tensor of normalized quaternions, shape [..., 4]
    """
    unit_quaternions = roma.euler_to_unitquat(
        convention='xyz', angles=angles, degrees=False, normalize=True
    )
    return unit_quaternions
def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor:
    """
    Scale quaternions to unit length.

    Args:
        quaternion: Possibly unnormalized quaternions, torch.Tensor of shape [..., 4]
        eps: Added to the norm so an all-zero quaternion yields zeros instead of NaN
    Returns:
        torch.Tensor of shape [..., 4]
    """
    denominator = quaternion.norm(dim=-1, keepdim=True) + eps
    return quaternion / denominator
def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to Euler angles.

    Args:
        quaternion: torch.Tensor of shape [..., 4]; can be non-normalized (it is normalized first)
    Returns:
        torch.Tensor of shape [..., 3] containing Euler angles in radians, 'xyz' convention
        (as interpreted by roma)
    """
    unit_quat = normalize_quaternion(quaternion)
    euler = roma.unitquat_to_euler(convention='xyz', quat=unit_quat, as_tuple=False, degrees=False)
    return euler
def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to rotation matrices.

    Args:
        quaternion: torch.Tensor of shape [..., 4]; can be non-normalized (it is normalized first)
    Returns:
        torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3)
    """
    return roma.unitquat_to_rotmat(normalize_quaternion(quaternion))
def is_quaternion(quaternion: torch.Tensor) -> bool:
    """Shape-only check: True when the last dimension has size 4."""
    last_dim = quaternion.shape[-1]
    return last_dim == 4
def is_unit_quaternion(
    quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = 'none'
) -> torch.Tensor | bool:
    """
    Check whether quaternions are normalized.

    Args:
        quaternion: torch.Tensor of shape [..., 4]
        epsilon: Absolute tolerance used when comparing each quaternion's norm to 1
        reduction:
            'none' - returns a torch.Tensor of bools of shape [..., 1] (the batch shape plus a
                     trailing singleton dimension, because the norm is taken with keepdim=True)
            'all' - returns a bool, True if ALL quaternions in the batch are normalized
    Returns:
        torch.Tensor of shape [..., 1] or bool
    Raises:
        ValueError: if `reduction` is neither 'none' nor 'all'
    """
    assert is_quaternion(quaternion)
    is_norm = torch.isclose(
        quaternion.norm(dim=-1, keepdim=True),
        torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device),
        atol=epsilon,
    )
    if reduction == 'none':
        return is_norm
    if reduction == 'all':
        return bool(torch.all(is_norm).item())
    raise ValueError(f'Unknown reduction mode {reduction}')
def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Flip quaternions so they cover only half of the quaternion space (q and -q encode the same
    rotation). The layout is [x, y, z, w] with the scalar part w last.

    If q_w is negative, the quaternion is negated. If q_w is 0, the sign is chosen so that the first
    non-zero component among x, y, z is positive. Note that geometrically, this doesn't correspond to
    a single hemisphere of the unit sphere. Follows
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat

    Args:
        quaternion: torch.Tensor of shape [..., 4]
    Returns:
        torch.Tensor of shape [..., 4]. Gradients flow through the returned tensor (each entry is
        either the input or its negation).
    """
    assert is_quaternion(quaternion), quaternion.shape
    # The flip decision is a boolean mask and needs no autograd tracking...
    with torch.no_grad():
        is_zero = quaternion == 0
        is_negative = quaternion < 0
        w_zero = is_zero[..., -1:]
        wx_zero = w_zero & is_zero[..., 0:1]
        wxy_zero = wx_zero & is_zero[..., 1:2]
        flip_condition = (
            is_negative[..., -1:]
            | (w_zero & is_negative[..., 0:1])
            | (wx_zero & is_negative[..., 1:2])
            | (wxy_zero & is_negative[..., 2:3])
        )
    # ...but the flip itself is applied outside no_grad so the output stays attached to the graph.
    return torch.where(flip_condition, -quaternion, quaternion)
def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Return rotation matrices in [..., 3, 3] layout.

    Flattened [..., 9] inputs are reshaped; [..., 3, 3] inputs are returned as-is.
    Raises ValueError for any other shape.
    """
    shape = rotmat.shape
    if shape[-1] == 9:
        return rotmat.reshape(*shape[:-1], 3, 3)
    if shape[-2:] != torch.Size([3, 3]):
        raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")
    return rotmat
def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to Euler angles.

    Args:
        rotmat: Batch of rotation matrices, shape [..., 3, 3] (flattened [..., 9] is also accepted)
    Returns:
        Batch of Euler angles in radians, 'xyz' convention (as interpreted by roma), shape [..., 3]
    """
    matrices = rotmat_as_3x3(rotmat)
    angles = roma.rotmat_to_euler(convention='xyz', rotmat=matrices, as_tuple=False, degrees=False)
    return angles
def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to unit quaternions.

    Args:
        rotmat: Batch of rotation matrices, shape [..., 3, 3] (flattened [..., 9] is also accepted)
    Returns:
        Batch of unit quaternions, shape [..., 4]
    """
    matrices = rotmat_as_3x3(rotmat)
    unit_quaternions = roma.rotmat_to_unitquat(matrices)
    return unit_quaternions
def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor:
"""
Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
- Let SVD(M) = U \Sigma V^T
- Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T
- det(UV^T) ensures that det(SVD+(M)) = 1
- The return value is a rotation matrix (ortonormal) with the least-squares distance to M
Args:
x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3]
Returns:
torch.Tensor with the same shape as x, where each inner 3x3 matrix is in SO(3)
"""
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message='In CPU autocast, but the target dtype is not supported. Disabling autocast.'
)
with torch.autocast(device_type=x.device.type, dtype=torch.float32):
matrices = x.view(-1, 3, 3)
matrices = matrices.to(dtype=torch.float32)
(u, s, v) = torch.svd(matrices)
vt = torch.transpose(v, 1, 2)
det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1)
diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1)
result = torch.matmul(u, diag_vt)
result = result.view(*x.shape)
result = result.to(dtype=x.dtype)
return result
def is_rotmat_3x3(rotmat: torch.Tensor) -> bool:
    """Shape-only check: True when the two trailing dimensions are 3 x 3."""
    return tuple(rotmat.shape[-2:]) == (3, 3)
def is_rotmat_9(rotmat: torch.Tensor) -> bool:
    """Shape-only check: True when the last dimension holds a flattened 3x3 matrix (size 9)."""
    trailing = rotmat.shape[-1]
    return trailing == 9
def is_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Shape-only check for a rotation matrix: trailing dims [3, 3] or a last dim of size 9.

    The data itself is NOT validated; `is_orthonormal_rotmat` performs that additional check.
    NOTE: Euler angles of shape [..., 3, 3] would also pass this test. This rarely happens, but use
    with caution.
    """
    shape = rotmat.shape
    if shape[-2:] == torch.Size([3, 3]):
        return True
    return shape[-1] == 9
def is_rotmat_orthonormal(
    rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = 'none'
) -> torch.Tensor | bool:
    """
    Check whether rotation matrices are orthonormal.

    Args:
        rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9]; converted to float32 before checking
        epsilon: Tolerance for numerical comparisons. Bigger values allow for more freedom. Generally,
            anything smaller than 1e-6 might incorrectly report some orthonormal matrices as not
            orthonormal
        reduction:
            'none' - returns a torch.Tensor of bools with the same batch shape
            'all' - returns a bool, True if ALL matrices in the batch are orthonormal
    Returns:
        torch.Tensor with the same batch shape, or bool
    Raises:
        ValueError: if `reduction` is neither 'none' nor 'all'
    """
    assert is_rotmat(rotmat)
    matrices = rotmat_as_3x3(rotmat.to(dtype=torch.float32))
    orthonormal_mask = roma.is_orthonormal_matrix(matrices, epsilon=epsilon)
    if reduction == 'all':
        return bool(torch.all(orthonormal_mask).item())
    if reduction == 'none':
        return orthonormal_mask
    raise ValueError(f'Unknown reduction mode {reduction}')
def is_orthonormal_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Shape check for a rotation matrix that also validates the data for [..., 3, 3] inputs.

    [..., 9] tensors are accepted on shape alone. [..., 3, 3] tensors must additionally be
    orthonormal (tolerance 0.02, all matrices in the batch). The extra check avoids mistaking Euler
    angles of shape [..., 3, 3] for rotation matrices.
    """
    if is_rotmat_9(rotmat):
        return True
    if not is_rotmat_3x3(rotmat):
        return False
    return is_rotmat_orthonormal(rotmat, epsilon=0.02, reduction='all')
def is_euler(euler: torch.Tensor) -> bool:
    """True when the last dim is 3 and the tensor is not a valid [..., 3, 3] rotation matrix."""
    if euler.shape[-1] != 3:
        return False
    return not is_orthonormal_rotmat(euler)
def rotation_format_from_tensor(rotation) -> RotationFormat:
    """
    Infer the RotationFormat of a tensor from its shape (and, for [..., 3, 3], its contents).

    Checks are tried in order: quaternion, rotation matrix, Euler angles.
    Raises ValueError when none matches.
    """
    candidates = (
        (is_quaternion, RotationFormat.QUATERNION),
        (is_orthonormal_rotmat, RotationFormat.ROTMAT),
        (is_euler, RotationFormat.EULER),
    )
    for matches, rotation_format in candidates:
        if matches(rotation):
            return rotation_format
    raise ValueError(f'Tensor shape {rotation.shape} is not a valid rotation format')
def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Return rotation matrices in flattened [..., 9] layout.

    [..., 9] inputs are returned as-is; [..., 3, 3] inputs are reshaped.

    Raises:
        ValueError: for any other shape.
    """
    # Same checks as is_rotmat_9 / is_rotmat_3x3 (and in the same order), inlined like rotmat_as_3x3 does
    if rotmat.shape[-1] == 9:
        return rotmat
    if rotmat.shape[-2:] == torch.Size([3, 3]):
        return rotmat.reshape(*rotmat.shape[:-2], 9)
    raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a [..., 9] rotation matrix")
def convert_rotation(
    rotation: torch.Tensor | np.ndarray,
    output_format: RotationFormat,
    autonorm: bool = True,
    half_cover: bool = True,
) -> torch.Tensor | np.ndarray:
    """
    Convert a batch of rotations to `output_format`.

    The input encoding is detected from the shape (and, for [..., 3, 3], from an orthonormality check):
    [..., 4] quaternions, [..., 9] / [..., 3, 3] rotation matrices, or [..., 3] Euler angles.

    Args:
        rotation: torch.Tensor or np.ndarray of rotations
        output_format: Target RotationFormat (QUATERNION, ROTMAT or EULER)
        autonorm: If True, re-normalize quaternions that are not unit length, and re-orthogonalize
            rotation matrices that are not orthonormal (tolerance 0.01), before converting
        half_cover: If True and the output is a quaternion, map it to the canonical half cover
            (see `quaternion_half_cover`)
    Returns:
        Converted rotations. Numpy input gives numpy output, tensor input gives tensor output.
        ROTMAT outputs are flattened to [..., 9].
    Raises:
        ValueError: if the input encoding cannot be detected
        NotImplementedError: if `output_format` is not supported
    """
    from_numpy = isinstance(rotation, np.ndarray)
    tensor = torch.from_numpy(rotation) if from_numpy else rotation

    # Each input branch prepares `tensor` and lists (format, converter) pairs, tested in order.
    if is_quaternion(tensor):
        if autonorm and not is_unit_quaternion(tensor, reduction='all'):
            tensor = normalize_quaternion(tensor)
        converters = (
            (RotationFormat.QUATERNION, lambda quat: quat),
            (RotationFormat.ROTMAT, lambda quat: rotmat_as_9(quaternion_to_rotmat(quat))),
            (RotationFormat.EULER, quaternion_to_euler),
        )
    elif is_orthonormal_rotmat(tensor):
        if autonorm and not is_rotmat_orthonormal(tensor, epsilon=0.01, reduction='all'):
            tensor = symmetric_orthogonalization(tensor)
        converters = (
            (RotationFormat.QUATERNION, rotmat_to_unit_quaternion),
            (RotationFormat.ROTMAT, rotmat_as_9),
            (RotationFormat.EULER, rotmat_to_euler),
        )
    elif is_euler(tensor):
        converters = (
            (RotationFormat.QUATERNION, euler_to_unit_quaternion),
            (RotationFormat.ROTMAT, lambda angles: rotmat_as_9(euler_to_rotmat(angles))),
            (RotationFormat.EULER, lambda angles: angles),
        )
    else:
        raise ValueError(f'Unknown rotation encoding with shape {tensor.shape}')

    for candidate_format, converter in converters:
        if output_format == candidate_format:
            converted = converter(tensor)
            break
    else:
        raise NotImplementedError(f'Unsupported rotation format: {output_format}')

    if output_format == RotationFormat.QUATERNION and half_cover:
        converted = quaternion_half_cover(converted)
    return converted.numpy() if from_numpy else converted
def delta_to_world_rotations(rotation_sequence: torch.Tensor, reference_frame: torch.Tensor) -> torch.Tensor:
    """
    Transform a sequence of rotations encoded w.r.t. `reference_frame` into rotations encoded w.r.t.
    the WORLD frame, where `reference_frame` itself is given w.r.t. the WORLD frame.
    Ex:
        Sequence of frames (rotations): R_1, R_2, R_3, R_4
        `rotation_sequence` contains R_01, R_02, R_03, R_04, where R_0 is the reference frame and
        R_0i is the pose of frame R_i expressed in the reference frame (every element is relative to
        the SAME reference frame, not to the previous element).
        `reference_frame` contains R_W0.
        Output: R_W1, R_W2, R_W3, R_W4 with R_Wi = R_W0 @ R_0i, i.e. the rotation poses of
        R_1, R_2, R_3, R_4 expressed in the world frame
    Args:
        rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing
            either rotation matrices or quaternions (R_01, R_02, R_03, R_04, ...)
        reference_frame: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] (S may be 1 to
            broadcast over the sequence). After conversion to 3x3 matrices it must have the SAME number of
            dims as `rotation_sequence`. The reference frame R_W0, w.r.t. the WORLD coordinate frame.
    Returns:
        Rotations (R_W1, R_W2, R_W3, R_W4, ...) in the rotation format detected for `rotation_sequence`.
        Rotation matrices are returned flattened, [..., S, 9].
    Raises:
        ValueError: if the two inputs differ in number of dims after conversion to 3x3 matrices.
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    reference_frame = rotmat_as_3x3(convert_rotation(reference_frame, RotationFormat.ROTMAT))
    rotation_sequence = rotmat_as_3x3(convert_rotation(rotation_sequence, RotationFormat.ROTMAT))
    if reference_frame.ndim != rotation_sequence.ndim:
        raise ValueError(
            f'Cannot broadcast reference_frame of shape {reference_frame.shape} to rotation_sequence of shape {rotation_sequence.shape}. Provide tensors with the same number of batch dimensions'
        )
    R_W0 = reference_frame
    # R_Wi = R_W0 @ R_0i, broadcast over batch and sequence dims
    world_rotations = torch.matmul(R_W0, rotation_sequence)
    world_rotations = world_rotations.view(*world_rotations.shape[:-2], 9)
    world_rotations = convert_rotation(world_rotations, rotation_format)
    return world_rotations
def delta_to_relative_rotations(rotation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Transform a sequence of rotations, each encoded w.r.t. the PREVIOUS rotation frame in the sequence,
    into rotations encoded w.r.t. the frame preceding the sequence (the 0-th frame).
    Ex:
        `rotation_sequence` contains the rotations R_01, R_12, R_23, R_34, where R_0 is the base frame,
        implicitly encoded in R_01
        Output: R_01, R_02, R_03, R_04, i.e. the running compositions
        R_01, R_01 R_12, R_01 R_12 R_23, ...
    Args:
        rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing
            either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions
    Returns:
        Rotations (R_01, R_02, R_03, R_04, ...) in the rotation format detected for the input.
        Rotation matrices are returned flattened, [..., S, 9].
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    quaternions = convert_rotation(rotation_sequence, RotationFormat.QUATERNION)
    # Running product prefix_i = prefix_{i-1} * q_i. This is exactly the left-to-right fold that
    # roma.quat_composition applies to each prefix, but each prefix is reused, so the cost is O(S)
    # instead of O(S^2).
    prefixes = [quaternions[..., :1, :]]
    for i in range(1, quaternions.shape[-2]):
        prefixes.append(roma.quat_product(prefixes[-1], quaternions[..., i : i + 1, :]))
    relative_rotations = torch.cat(prefixes, dim=-2)
    return convert_rotation(relative_rotations, rotation_format)
def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray:
    """
    Return `image` as a numpy array and assert it is laid out as HW or HWC.

    PIL images are converted with np.asarray. For 3-D arrays the channel (last) dimension must be
    at most 4.
    """
    array = np.asarray(image) if isinstance(image, PIL.Image.Image) else image
    assert isinstance(array, np.ndarray), type(array)
    assert array.ndim in (2, 3), array.shape
    if array.ndim == 3:
        assert array.shape[-1] <= 4, array.shape
    return array
def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]:
    """Return (height, width) of a numpy HW/HWC array or of a PIL image."""
    if not isinstance(image, np.ndarray):
        # PIL reports its size as (width, height)
        width, height = image.size
        return height, width
    height, width = image.shape[:2]
    return height, width
def pad_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Pad an image with a constant border so that it has exactly the target size.

    The border is split as evenly as possible between the two sides. When the size difference is odd,
    the extra pixel goes to the bottom / right side.

    Args:
        image: PIL image or numpy HW/HWC array.
        target_size: {'width': ..., 'height': ...}; must be >= the image size in both dimensions.
        pad_value: Fill value. A per-channel tuple is broadcast over the channel (last) axis of numpy
            images and passed as `fill` for PIL images.
    Returns:
        Padded image of the same type as `image` (returned unchanged if already the target size).
    """
    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
    (height, width) = hw_from_image(image)
    (target_width, target_height) = (target_size['width'], target_size['height'])
    if width == target_width and height == target_height:
        return image
    assert target_width >= width, f"Can't pad image of width {width} to {target_width}"
    assert target_height >= height, f"Can't pad image of height {height} to {target_height}"
    # Split the total padding so that before + after == target - current, even for odd differences
    total_horizontal = int(target_width - width)
    total_vertical = int(target_height - height)
    (left, top) = (total_horizontal // 2, total_vertical // 2)
    (right, bottom) = (total_horizontal - left, total_vertical - top)
    if isinstance(image, np.ndarray):
        # A filled canvas (instead of np.pad) lets a per-channel pad_value broadcast over the last axis
        padded = np.empty(
            (height + top + bottom, width + left + right) + image.shape[2:], dtype=image.dtype
        )
        padded[...] = pad_value
        padded[top : top + height, left : left + width, ...] = image
        return padded
    # torchvision expects a 4-sequence as (left, top, right, bottom)
    return torchvision.transforms.v2.functional.pad(
        image, padding=(left, top, right, bottom), fill=pad_value, padding_mode='constant'
    )
def pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Pad an image (never crop) so that its width / height ratio becomes `target_wh_ratio`.

    Only one dimension grows: the width if the target is wider than the image, otherwise the height.
    The new size is rounded to the nearest pixel.
    """
    height, width = hw_from_image(image)
    if target_wh_ratio >= width / height:
        padded_size = {'width': round(height * target_wh_ratio), 'height': height}
    else:
        padded_size = {'width': width, 'height': round(width / target_wh_ratio)}
    return pad_image(image, target_size=padded_size, pad_value=pad_value)
def crop_image(
    image: np.ndarray | PIL.Image.Image,
    start_height: int,
    start_width: int,
    target_height: int,
    target_width: int,
) -> np.ndarray | PIL.Image.Image:
    """
    Crop a `target_height` x `target_width` window whose top-left corner is at
    (`start_height`, `start_width`).

    All four values are rounded to the nearest integer before slicing. The result has the same type
    as `image`; PIL inputs go through a numpy round trip.
    NOTE: only the window size is checked against the image size. A window that extends past the
    border is silently truncated by numpy slicing.
    """
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = hw_from_image(image)
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    (start_height, start_width) = (round(start_height), round(start_width))
    (target_height, target_width) = (round(target_height), round(target_width))
    np_image = np_image[
        start_height : start_height + target_height, start_width : start_width + target_width, ...
    ]
    image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
    return image
def crop_image_center(
    image: np.ndarray | PIL.Image.Image, target_size: dict[str, int]
) -> np.ndarray | PIL.Image.Image:
    """
    Center-crop an image to `target_size` ({'width': ..., 'height': ...}).

    When the size difference is odd, the extra row/column is removed from the bottom/right
    (the offsets are floored). The result has the same type as `image`.
    """
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = np_image.shape[:2]
    (target_height, target_width) = (target_size['height'], target_size['width'])
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    top = (height - target_height) // 2
    left = (width - target_width) // 2
    np_image = crop_image(np_image, top, left, target_height, target_width)
    image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
    return image
def crop_image_to_ratio(
    image: PIL.Image.Image | np.ndarray, target_wh_ratio: float
) -> PIL.Image.Image | np.ndarray:
    """
    Center-crop an image (never pad) so that its width / height ratio becomes `target_wh_ratio`.

    Only one dimension shrinks: the height if the target is wider than the image, otherwise the
    width. The new size is rounded to the nearest pixel.
    """
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if target_wh_ratio >= wh_ratio:
        crop_size = {'width': width, 'height': round(width / target_wh_ratio)}
    else:
        crop_size = {'width': round(height * target_wh_ratio), 'height': height}
    image = crop_image_center(image, target_size=crop_size)
    return image
def crop_and_pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    mode: ResizeMode | str,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Crop and/or pad an image to a target width / height aspect ratio, depending on the mode.

    Args:
        image: The image to crop and pad.
        target_wh_ratio: Target aspect ratio, width / height.
        mode: ResizeMode.SMART, ResizeMode.PAD or ResizeMode.CROP.
            - PAD: pad only (`pad_image_to_ratio`).
            - CROP: center-crop only (`crop_image_to_ratio`).
            - SMART: an optional moderate center crop, then padding to the target ratio:
                * square source: crop to 4:3 (or 3:4 for a portrait target) if the target is at
                  least ~4:3 elongated;
                * source and target with different orientation (landscape vs. portrait): crop to
                  a square;
                * same orientation, source elongated beyond ~4:3, target within ~4:3: crop to 4:3
                  (or 3:4);
                * otherwise no crop.
        pad_value: Fill value used for padded pixels.
    Returns:
        Image of the same type as `image`. It is returned unchanged when its aspect ratio is already
        within ~1% of the target.
    Raises:
        ValueError: if `mode` is not supported.
    """
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001):
        return image
    if mode == ResizeMode.SMART:
        aspect_ratio = max(width, height) / min(width, height)
        target_ratio = max(target_wh_ratio, 1 / target_wh_ratio)
        # Explicit parentheses: `a >= 1.0 != b` would be a chained comparison, which only detected
        # the landscape-source / portrait-target case.
        orientation_mismatch = (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0)
        if aspect_ratio == 1:
            if target_ratio >= 4 / 3 - 0.01:
                crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
                image = crop_image_to_ratio(image, crop_wh_ratio)
        elif aspect_ratio <= 4 / 3 + 0.01:
            if orientation_mismatch:
                image = crop_image_to_ratio(image, 1.0)
        elif orientation_mismatch:
            image = crop_image_to_ratio(image, 1.0)
        elif target_ratio < 4 / 3 + 0.01:
            crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
            image = crop_image_to_ratio(image, crop_wh_ratio)
        # Whatever was cropped above, padding brings the image to the exact target ratio
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.PAD:
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.CROP:
        image = crop_image_to_ratio(image, target_wh_ratio)
    else:
        raise ValueError(f'Mode {mode} not supported')
    return image
def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool:
    """
    True for single-channel images.

    PIL images are judged by their mode. Numpy arrays are single-channel when they are HW, or HWC
    with one channel. Raises ValueError for other input types.
    """
    if isinstance(image, PIL.Image.Image):
        single_channel_modes = ('1', 'L', 'LA', 'La', 'P', 'PA', 'F', 'I', 'I;16', 'I;16L', 'I;16B', 'I;16N')
        return image.mode in single_channel_modes
    if not isinstance(image, np.ndarray):
        raise ValueError(f'Unsupported image type: {type(image)}')
    if image.ndim == 2:
        return True
    return image.ndim == 3 and image.shape[2] == 1
def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool:
    """
    Heuristic binary-mask check: uint8 or bool data whose maximum value is exactly 1.

    The maximum is only computed for uint8/bool data. An all-zero mask is NOT reported as binary.
    """
    array = np.asarray(image)
    if array.dtype not in (np.uint8, np.bool_):
        return False
    return np.max(array) == 1
def resize_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    mode: ResizeMode | str,
    resample: PIL.Image.Resampling | str = 'auto',
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Resize an image to `target_size` ({'width': ..., 'height': ...}).

    Args:
        image: PIL image or numpy HW/HWC array.
        target_size: Output size.
        mode: ResizeMode.NAIVE resizes directly. ResizeMode.SMART / PAD / CROP first adjust the aspect
            ratio with `crop_and_pad_image_to_ratio`, then resize.
        resample: A PIL.Image.Resampling filter, or 'auto' (BILINEAR for single-channel images,
            LANCZOS otherwise). Single-channel images only accept BILINEAR or BICUBIC.
        pad_value: Fill value for any padding.
    Returns:
        Image of the same type as the input. Binary masks (see `is_binary_mask`) are resized as 0/255
        grayscale and re-thresholded at 127. The result is a bool numpy array, or a PIL image built
        from that bool array when the input was a PIL image.
    Raises:
        ValueError: for a resampling filter not allowed on single-channel images.
        NotImplementedError: for an unsupported mode.
    """
    (target_width, target_height) = (target_size['width'], target_size['height'])
    (height, width) = hw_from_image(image)
    if height == target_height and width == target_width:
        return image
    # Remember the input type: binary masks are converted to numpy below
    is_pil_input = isinstance(image, PIL.Image.Image)
    if resample == 'auto':
        if is_single_channel_image(image):
            resample = PIL.Image.Resampling.BILINEAR
        else:
            resample = PIL.Image.Resampling.LANCZOS
    else:
        assert isinstance(resample, PIL.Image.Resampling), resample
        if is_single_channel_image(image) and resample not in [
            PIL.Image.Resampling.BILINEAR,
            PIL.Image.Resampling.BICUBIC,
        ]:
            raise ValueError(
                f'Single channel images must be resized with bilinear or bicubic, but got {resample}'
            )
    ratio_modes = [ResizeMode.SMART, ResizeMode.PAD, ResizeMode.CROP]
    if mode != ResizeMode.NAIVE and mode not in ratio_modes:
        raise NotImplementedError(f'Mode {mode} not supported')
    if is_bin_mask := is_binary_mask(image):
        # Interpolate the mask as 0/255 grayscale and threshold after resizing
        image = np.asarray(image).astype(np.uint8) * 255
    if mode in ratio_modes:
        image = crop_and_pad_image_to_ratio(
            image, target_wh_ratio=target_width / target_height, mode=mode, pad_value=pad_value
        )
    pil_image = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image
    pil_image = pil_image.resize((target_width, target_height), resample=resample)
    if is_bin_mask:
        mask = np.asarray(pil_image).astype(np.uint8) > 127
        return PIL.Image.fromarray(mask) if is_pil_input else mask
    return pil_image if is_pil_input else np.asarray(pil_image)
def is_global_norm(norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list]) -> bool:
    """
    True if `norm` is Normalization.NONE, or an explicit mapping of statistics (which applies
    globally, to all datasets).
    """
    if norm == Normalization.NONE:
        return True
    return isinstance(norm, collections.abc.Mapping)
def is_mean_norm(norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list]) -> bool:
    """
    True if `norm` is moment-based: Normalization.MEAN, or a mapping whose keys are exactly
    {'mean', 'std'}.
    """
    if norm == Normalization.MEAN:
        return True
    if not isinstance(norm, collections.abc.Mapping):
        return False
    return set(norm.keys()) == {'mean', 'std'}
def _broadcast_shapes(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reshape normalization bounds so they broadcast against `value`.

    Args:
        value: torch.Tensor of shape [..., num_components]. The entire shape might be:
            - [num_components]: `value` has no batch dimension
            - [num_datasets, num_components]: `value` contains entries *aligned* with the dataset bounds
              contained in `low` and `high`
            - [num_datasets, ..., num_components]: `value` contains entries *aligned* with the dataset
              bounds contained in `low` and `high`
            - [..., num_components]: `value` contains multiple dimensions. In this case, `low` and `high`
              must be for a single dataset, i.e. `num_datasets = 1`
        low: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when
            `low` contains normalization bounds for a single dataset
        high: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when
            `high` contains normalization bounds for a single dataset
    Returns:
        Tuple of torch.Tensors (low, high), where `low` and `high` have the same number of dimensions
        as `value`
    """
    assert low.ndim == high.ndim == 2, f'{low.shape} != {high.shape} or ndim != 2'
    assert value.shape[-1] == low.shape[-1] == high.shape[-1], f'{value.shape} != {low.shape} / {high.shape}'
    # Same rank as the bounds: ordinary broadcasting already lines everything up
    if value.ndim == low.ndim == high.ndim:
        return low, high
    # Unbatched value: only meaningful for single-dataset bounds, flattened to [num_components]
    if value.ndim < low.ndim:
        assert low.shape[0] == high.shape[0] == 1, f'{low.shape}, {high.shape}'
        return low.view(-1), high.view(-1)
    single_dataset = low.shape[0] == high.shape[0] == 1
    if not single_dataset:
        # Per-dataset bounds must line up with the leading dim of `value`
        assert value.shape[0] == low.shape[0] == high.shape[0], f'{value.shape} != {low.shape} / {high.shape}'
        return (
            expand_dims(low, ndim=value.ndim, order=[1, -1, 1]),
            expand_dims(high, ndim=value.ndim, order=[1, -1, 1]),
        )
    return (
        expand_dims(low.view(-1), ndim=value.ndim, order=[-1, 1]),
        expand_dims(high.view(-1), ndim=value.ndim, order=[-1, 1]),
    )
def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Inverse of `normalize_by_moments`: value * (std + 1e-8) + mean, with bounds broadcast and moved to value's device."""
    mean, std = _broadcast_shapes(value, mean, std)
    device = value.device
    mean = mean.to(device=device)
    std = std.to(device=device)
    scale = std + 1e-08
    return value * scale + mean
def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Map `value` from [-1, 1] back to [low, high] (no clamping); bounds are broadcast and moved to value's device."""
    low, high = _broadcast_shapes(value, low, high)
    device = value.device
    low = low.to(device=device)
    high = high.to(device=device)
    half_shifted = 0.5 * (value + 1)
    return half_shifted * (high - low) + low
def normalize_gripper_by_bounds(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True
) -> torch.Tensor:
    """
    Normalize gripper values from [low, high] and clamp the result.

    If binary, the output range is [0, 1]; otherwise it is [-1, 1]. The range width is floored at
    1e-8 to avoid division by zero.
    """
    low, high = _broadcast_shapes(value, low, high)
    device = value.device
    low = low.to(device=device)
    high = high.to(device=device)
    span = torch.clamp(high - low, min=1e-08)
    if not binary:
        return torch.clamp(2 * (value - low) / span - 1, min=-1.0, max=1.0)
    return torch.clamp((value - low) / span, min=0.0, max=1.0)
def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Standardize `value`: (value - mean) / (std + 1e-8), with moments broadcast and moved to value's device."""
    mean, std = _broadcast_shapes(value, mean, std)
    device = value.device
    mean = mean.to(device=device)
    std = std.to(device=device)
    centered = value - mean
    return centered / (std + 1e-08)
def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Map `value` from [low, high] to [-1, 1] and clamp; the range width is floored at 1e-8."""
    low, high = _broadcast_shapes(value, low, high)
    device = value.device
    low = low.to(device=device)
    high = high.to(device=device)
    span = torch.clamp(high - low, min=1e-08)
    scaled = 2 * (value - low) / span - 1
    return torch.clamp(scaled, min=-1.0, max=1.0)
def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray:
    """
    Mirror gripper values inside [low, high], so that `low` maps to `high` and vice versa.

    Values outside the range are clipped first. The result is `low + high - clip(gripper)`. This
    matches the previous special cases for [0, high] bounds (high - g) and symmetric [-h, h] bounds
    (-g). It is also correct for bounds with low > 0, where `high - g` would leave the range.
    """
    return (low + high) - np.clip(gripper, low, high)
# Raw gripper observation range (low, high) per dataset. `preprocess_gripper_observation` uses these
# bounds to normalize gripper values, inverting them first for datasets where the raw value grows
# towards "closed". Which end means open/closed is dataset-specific; see that function's docstring.
GRIPPER_BOUNDS = {
    'austin_buds_dataset': (0.0, 0.08),
    'austin_sailor_dataset': (0.0, 0.08),
    'austin_sirius_dataset': (0.0, 0.08),
    'bc_z': (0.0, 1.0),
    'berkeley_autolab_ur5': (0.0, 1.0),
    'berkeley_cable_routing': (0.0, 1.0),
    'berkeley_fanuc_manipulation': (0.0, 1.0),
    'bridge': (0.0, 1.0),
    'bridge_orig': (0.0, 1.0),
    'bridge_action_lang': (0.0, 1.0),
    'cmu_stretch': (-3.0, 3.0),
    'dlr_edan_shared_control': (0.0, 1.0),
    'droid': (0.0, 1.0),
    'fmb': (0.0, 1.0),
    'fractal20220817_data': (0.0, 1.0),
    'furniture_bench_dataset': (0.0, 0.08),
    'iamlab_cmu_pickup_insert': (0.0, 1.0),
    'jaco_play': (0.0, 1.4),
    'kuka': (0.0, 1.0),
    'language_table': (0.0, 1.0),
    'nyu_franka_play_dataset': (0.0, 1.0),
    'roboset': (0.0, 1.0),
    # NOTE(review): preprocess_gripper_observation's docstring describes roboturk as
    # [0=closed, 0.04=open], which disagrees with (0.0, 1.0) here — confirm which is correct.
    'roboturk': (0.0, 1.0),
    'stanford_hydra_dataset': (0.0, 0.08),
    'taco_play': (0.0, 0.08),
    'toto': (0.0, 1.0),
    'ucsd_kitchen_dataset': (0.0, 1.0),
    'utaustin_mutex': (0.0, 0.08),
    'viola': (0.0, 0.08),
}
def preprocess_gripper_observation(
    gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True
) -> np.ndarray:
    """
    Preprocess a gripper observation depending on the dataset. The input is the raw gripper
    observation from the dataset or from the robot; the output is a normalized continuous value.
    - if `binary`, the output is in [0, 1], with 0 = closed and 1 = open.
    - otherwise, the output is in [-1, 1], with -1 = closed and 1 = open.
    `gripper` may have any rank >= 1; the last dimension holds the gripper components.
    `dataset_name` may be a str or a numpy array holding a single repeated name (str or bytes).
    Dataset-specific gripper observations:
        austin_buds_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        austin_sailor_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        austin_sirius_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        bc_z: continuous; [0=open; 1=closed]
        berkeley_autolab_ur5: binary; [0=open; 1=closed]
        berkeley_cable_routing: constant (closed)
        berkeley_fanuc_manipulation: binary; [0=open; 1=closed]
        bridge: continuous; ~[0=closed; 1=open]
        bridge_orig: continuous; ~[0=closed; 1=open]
        bridge_action_lang: normalized with bounds [0, 1], not inverted (same handling as bridge)
        cmu_stretch: continuous; [-3=closed; 3=open]
        dlr_edan_shared_control: missing
        droid: continuous; [0=open, 1=closed]
        fmb: binary; [0=open; 1=closed]
        fractal20220817_data: continuous; [0=open; 1=closed]
        furniture_bench_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        iamlab_cmu_pickup_insert: binary; [0=closed; 1=open]
        jaco_play: continuous; [0=open; 1.4=closed]
        kuka: binary; [0=open; 1=closed]
        language_table: constant (no gripper)
        nyu_franka_play_dataset: missing
        roboset: continuous; [0=open, 1=closed]
        roboturk: continuous; [0=closed, 0.04=open]
        stanford_hydra_dataset: continuous; ~[0=closed; 0.08=open] (franka gripper)
        taco_play: continuous; ~[0=closed; 0.08=open] (franka gripper)
        toto: constant (closed)
        ucsd_kitchen_dataset: missing
        utaustin_mutex: continuous; ~[0=closed; 0.08=open] (franka gripper)
        viola: continuous; ~[0=closed; 0.08=open] (franka gripper)
    Raises:
        NotImplementedError: for datasets not listed above.
    """
    if isinstance(dataset_name, np.ndarray):
        assert np.unique(dataset_name).size == 1, dataset_name
        name = dataset_name.reshape(-1)[0]
        # Names read from serialized datasets may be bytes; str(b'x') would give "b'x'"
        if isinstance(name, bytes):
            name = name.decode('utf-8')
        dataset_name = str(name)
    # Normalized directly: raw value already grows towards "open", or the signal is constant/missing.
    direct_datasets = {
        'berkeley_cable_routing',
        'dlr_edan_shared_control',
        'language_table',
        'nyu_franka_play_dataset',
        'toto',
        'ucsd_kitchen_dataset',
        'austin_buds_dataset',
        'austin_sailor_dataset',
        'austin_sirius_dataset',
        'bridge',
        'bridge_orig',
        'bridge_action_lang',
        'cmu_stretch',
        'furniture_bench_dataset',
        'iamlab_cmu_pickup_insert',
        'roboturk',
        'stanford_hydra_dataset',
        'taco_play',
        'utaustin_mutex',
        'viola',
    }
    # Raw value grows towards "closed": mirrored inside its bounds before normalizing.
    inverted_datasets = {
        'bc_z',
        'berkeley_autolab_ur5',
        'berkeley_fanuc_manipulation',
        'droid',
        'fmb',
        'fractal20220817_data',
        'jaco_play',
        'kuka',
        'roboset',
    }
    if dataset_name not in direct_datasets and dataset_name not in inverted_datasets:
        raise NotImplementedError(f'Unknown dataset: {dataset_name}')
    (low, high) = GRIPPER_BOUNDS[dataset_name]
    # NOTE(review): the roboturk docstring range (0.04=open) disagrees with GRIPPER_BOUNDS — confirm.
    if dataset_name in inverted_datasets:
        gripper = invert_gripper(gripper, low=low, high=high)
    # Single-dataset bounds of shape [1, num_components]; `_broadcast_shapes` (used by
    # normalize_gripper_by_bounds) broadcasts these against grippers of any rank. Bounds built with
    # the full gripper shape only satisfied its 2-D requirement for 2-D grippers.
    num_components = gripper.shape[-1]
    normalized = normalize_gripper_by_bounds(
        torch.from_numpy(gripper),
        low=torch.full((1, num_components), low, dtype=torch.float32),
        high=torch.full((1, num_components), high, dtype=torch.float32),
        binary=binary,
    )
    return normalized.numpy()
def rotation_norm_bounds(
    rotation_norm: Normalization,
    rotation_format: RotationFormat,
    stats: Dict[str, Dict[str, Dict[str, List[float]]]],
    dataset_names: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Build per-dataset normalization bounds for rotation components.

    Args:
        rotation_norm: Normalization mode. For Euler angles, `Normalization.BOUNDS` uses the
            'min'/'max' statistics and `Normalization.BOUNDS_Q99` uses 'q01'/'q99'. Any other
            rotation format must use `Normalization.NONE`.
        rotation_format: Format of the rotation vectors (Euler, quaternion or other).
        stats: Mapping dataset name -> component -> statistic name -> values. Only the
            'euler' component is read.
        dataset_names: Dataset names used as keys for the fixed [-1, 1] bounds returned when
            `rotation_norm` is `Normalization.NONE`.
    Returns:
        Mapping dataset name -> {'low': tensor, 'high': tensor}, all float32.
        Data-dependent bounds are keyed by every dataset present in `stats`. The
        `Normalization.NONE` bounds are keyed by `dataset_names`.
    Raises:
        NotImplementedError: unsupported normalization mode for Euler angles.
        AssertionError: a normalization other than NONE was requested for a non-Euler format.
    """
    if rotation_format == RotationFormat.EULER and rotation_norm != Normalization.NONE:
        if rotation_norm == Normalization.BOUNDS:
            (low_key, high_key) = ('min', 'max')
        elif rotation_norm == Normalization.BOUNDS_Q99:
            (low_key, high_key) = ('q01', 'q99')
        else:
            raise NotImplementedError(f'Normalization type {rotation_norm} not yet implemented')
        # Pin float32 so the bounds do not depend on torch's global default dtype and match
        # the float32 tensors produced for Normalization.NONE below.
        return {
            dataset_name: {
                'low': torch.tensor(dataset_stats['euler'][low_key], dtype=torch.float32),
                'high': torch.tensor(dataset_stats['euler'][high_key], dtype=torch.float32),
            }
            for (dataset_name, dataset_stats) in stats.items()
        }

    assert rotation_norm == Normalization.NONE, rotation_norm
    if rotation_format == RotationFormat.EULER:
        rotation_size = 3
    elif rotation_format == RotationFormat.QUATERNION:
        rotation_size = 4
    else:
        # Any other format is treated as a flattened 3x3 rotation matrix.
        rotation_size = 9
    # Fixed [-1, 1] bounds for every requested dataset.
    return {
        dataset_name: {
            'low': -1 * torch.ones(rotation_size, dtype=torch.float32),
            'high': 1 * torch.ones(rotation_size, dtype=torch.float32),
        }
        for dataset_name in dataset_names
    }
def translation_norm_bounds(
    translation_norm: Normalization | collections.abc.Mapping,
    stats: Dict[str, Dict[str, Dict[str, List[float]]]],
    dataset_names: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Build per-dataset normalization statistics for translation components.

    Args:
        translation_norm: Either a `Normalization` mode or an explicit mapping of statistics.
            - `Normalization.BOUNDS` uses the 'min'/'max' statistics as 'low'/'high'.
            - `Normalization.BOUNDS_Q99` uses 'q01'/'q99' as 'low'/'high'.
            - `Normalization.MEAN` uses 'mean'/'std'.
            - `Normalization.NONE` yields fixed [-1, 1] bounds.
            - A mapping with keys exactly {'low', 'high'} or {'mean', 'std'}, each value
              holding 3 numbers, is applied to every dataset in `dataset_names`.
        stats: Mapping dataset name -> component -> statistic name -> values. Only the
            'translation' component is read.
        dataset_names: Dataset names used as keys for `Normalization.NONE` and for an
            explicit mapping.
    Returns:
        Mapping dataset name -> {'low', 'high'} or {'mean', 'std'} -> float32 tensor.
        Stats-derived results are keyed by every dataset present in `stats`.
    Raises:
        NotImplementedError: unsupported `Normalization` mode.
        AssertionError: `translation_norm` is neither a `Normalization` nor a valid mapping.
    """
    if isinstance(translation_norm, Normalization) and translation_norm != Normalization.NONE:
        # Output key -> statistic name in `stats[...]['translation']`.
        if translation_norm == Normalization.BOUNDS:
            key_map = {'low': 'min', 'high': 'max'}
        elif translation_norm == Normalization.BOUNDS_Q99:
            key_map = {'low': 'q01', 'high': 'q99'}
        elif translation_norm == Normalization.MEAN:
            key_map = {'mean': 'mean', 'std': 'std'}
        else:
            raise NotImplementedError(f'Normalization type {translation_norm} not yet implemented')
        # Pin float32 so the statistics do not depend on torch's global default dtype.
        return {
            dataset_name: {
                out_key: torch.tensor(dataset_stats['translation'][stats_key], dtype=torch.float32)
                for (out_key, stats_key) in key_map.items()
            }
            for (dataset_name, dataset_stats) in stats.items()
        }

    if isinstance(translation_norm, Normalization):
        # Only Normalization.NONE reaches this point: fixed [-1, 1] bounds.
        return {
            dataset_name: {
                'low': -1 * torch.ones(3, dtype=torch.float32),
                'high': 1 * torch.ones(3, dtype=torch.float32),
            }
            for dataset_name in dataset_names
        }

    # Explicit user-provided statistics, shared by every dataset.
    assert isinstance(translation_norm, collections.abc.Mapping), type(translation_norm)
    assert all((len(value) == 3 for value in translation_norm.values())), translation_norm
    assert set(translation_norm.keys()) in ({'low', 'high'}, {'mean', 'std'}), translation_norm
    return {
        dataset_name: {
            key: torch.tensor(value, dtype=torch.float32) for (key, value) in translation_norm.items()
        }
        for dataset_name in dataset_names
    }
VLMProcessorConfigT = TypeVar('VLMProcessorConfigT', bound=VLMProcessorConfig)


class VLMProcessor(Configurable[VLMProcessorConfigT], Template[VLMProcessorConfigT]):
    """
    Abstract vision-language processor interface.

    Implementations turn a chat and per-key image lists into tensors, and expose the
    tokenizer and per-key target image sizes.
    """

    @abstractmethod
    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """Convert `chat` and `images` into a dict of tensors (or nested dicts of tensors)."""

    @property
    @abstractmethod
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Tokenizer used by this processor."""

    @property
    @abstractmethod
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        """Target image size configuration for each image key."""
# Type variable for the VLAMProcessor config class, bounded by VLAMProcessorConfig.
VLAMProcessorConfigT = TypeVar(
    'VLAMProcessorConfigT',
    bound=VLAMProcessorConfig,
)
class VLAMProcessor(Configurable[VLAMProcessorConfigT], Template[VLAMProcessorConfigT]):
def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor):
    """
    Args:
        config: Processor configuration.
        vlm_processor: Wrapped vision-language processor. The `tokenizer` and
            `image_sizes` properties delegate to it.
    """
    super().__init__(config)
    # Assigned before `self.tokenizer` is read below, since that property delegates here.
    self.vlm_processor = vlm_processor
    self.control_tokenizer = EmptyTokenizer(
        config=self.config.control_tokenizer_config,
        tokenizer=self.tokenizer,
    )
    # Eagerly evaluate each cached normalization-bounds property, keyed by component name.
    bounds_by_component: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {}
    bounds_by_component['obs_translation'] = self.obs_translation_norm_bounds
    bounds_by_component['obs_rotation'] = self.obs_rotation_norm_bounds
    bounds_by_component['translation'] = self.translation_norm_bounds
    bounds_by_component['rotation'] = self.rotation_norm_bounds
    bounds_by_component['joints'] = self.joints_norm_bounds
    self.norm_bounds = bounds_by_component
@property
def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
    """Tokenizer of the wrapped VLM processor."""
    processor = self.vlm_processor
    return processor.tokenizer
@property
def image_sizes(self) -> Dict[str, ImageSizeConfig]:
    """Per-key target image sizes of the wrapped VLM processor."""
    processor = self.vlm_processor
    return processor.image_sizes
@property
def camera_names(self) -> List[str]:
    """Camera names, in the key order of the wrapped processor's `image_sizes`."""
    return [name for name in self.vlm_processor.image_sizes]
@property
def control_io_config(self) -> ControlDataIOConfig:
    """Control data I/O configuration stored on the processor config."""
    config = self.config
    return config.control_io_config
@cached_property
def rotation_components(self) -> int:
    """
    Number of scalars per rotation for the configured rotation format:
    3 for Euler angles, 4 for quaternions, 9 for rotation matrices.

    Raises:
        NotImplementedError: for any other rotation format.
    """
    rotation_format = self.config.rotation_format
    for (candidate, size) in (
        (RotationFormat.EULER, 3),
        (RotationFormat.QUATERNION, 4),
        (RotationFormat.ROTMAT, 9),
    ):
        if rotation_format == candidate:
            return size
    raise NotImplementedError(rotation_format)
@abstractmethod
def policy_control_plan_from_model_target(
    self, target: RoboticsTarget, dataset_name: np.ndarray
) -> RoboticsControlPlan:
    """Build a `RoboticsControlPlan` from a `RoboticsTarget`; implemented by subclasses."""
@abstractmethod
def policy_control_plan_from_model_output(
    self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor
) -> RoboticsControlPlan:
    """Build a `RoboticsControlPlan` from a `RoboticsOutput`; implemented by subclasses."""
def resize_image(
    self, camera_name: str, image: PIL.Image.Image | np.ndarray
) -> PIL.Image.Image | np.ndarray:
    """
    Resize `image` to the configured size for `camera_name`, using the config's resize
    mode and Lanczos resampling.
    """
    size_config = self.image_sizes[camera_name]
    target_size = {'width': size_config.width, 'height': size_config.height}
    # Resolves to the module-level `resize_image` helper, not to this method.
    return resize_image(
        image,
        target_size=target_size,
        mode=self.config.image_resize,
        resample=PIL.Image.Resampling.LANCZOS,
    )
def preprocess_inputs(
    self,
    chat: List[str],
    images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]],
    ee_pose_translation: np.ndarray,
    ee_pose_rotation: np.ndarray,
    gripper: np.ndarray,
    joints: np.ndarray,
    dataset_name: np.ndarray,
    inference_mode: bool,
    control_target: Optional[RoboticsTarget] = None,
) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
    """
    Preprocess the inputs for a single example.

    Args:
        chat: Chat/prompt text forwarded to `self.vlm_processor.preprocess_inputs`.
        images: Mapping of image key to a single image or a history of images with
            increasing timestamps, forwarded to the VLM processor.
        ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3].
        ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9]. Converted to
            `self.config.rotation_format` before normalization.
        gripper: np.ndarray of raw gripper observations, mapped per dataset by
            `preprocess_gripper_observation`.
        joints: np.ndarray, shape [..., num_past_scalars, <= 7]. Zero-padded to 7 components.
        dataset_name: 1D np.ndarray of dataset names, used to select normalization statistics.
        inference_mode: Ignored by this implementation. Contract for implementations that use
            it: if True, prepare the input for inference (e.g. don't include target tokens in
            the input). If control_target is available, it should still be preprocessed for
            test dataset comparison.
        control_target: Ignored by this implementation. RoboticsTarget, each component of shape
            [..., num_control_steps, num_control_components]. Provided only when available,
            usually during training and dataset test.
    Returns:
        Dict with keys 'images', 'input_ids', 'target_text_tokens_ids', 'attn_mask',
        'ee_pose_translation', 'ee_pose_rotation', 'gripper', 'joints',
        'control_tokens_ids' and 'target_control_tokens_ids'. The last two are None here.
    """
    del control_target  # not used by this implementation
    del inference_mode  # not used by this implementation
    num_joints = 7

    vlm_inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
    image_tensors: Dict[str, torch.Tensor] = vlm_inputs['images']
    # Truncate token sequences to the tokenizer's maximum length.
    max_length = self.tokenizer.model_max_length
    input_ids: torch.Tensor = vlm_inputs['input_ids'][..., :max_length]
    target_text_tokens_ids: torch.Tensor = vlm_inputs['target_ids'][..., :max_length]
    attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)

    ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32)
    ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32)
    ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True)
    gripper = preprocess_gripper_observation(gripper, dataset_name)
    gripper = torch.tensor(gripper, dtype=torch.float32)
    ee_pose_translation = self.normalize(
        ee_pose_translation, dataset_name=dataset_name, key='obs_translation'
    )
    ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key='obs_rotation')

    joints = torch.tensor(joints, dtype=torch.float32)
    missing_size = num_joints - joints.shape[-1]
    if missing_size > 0:
        # Zero-pad robots with fewer joints so every example has `num_joints` components.
        # Match the joints' dtype rather than relying on torch's global default dtype.
        padding = torch.zeros([*joints.shape[:-1], missing_size], dtype=joints.dtype)
        joints = torch.cat([joints, padding], dim=-1)
    joints = self.normalize(joints, dataset_name=dataset_name, key='joints')

    return {
        'images': image_tensors,
        'input_ids': input_ids,
        'target_text_tokens_ids': target_text_tokens_ids,
        'attn_mask': attn_mask,
        'ee_pose_translation': ee_pose_translation,
        'ee_pose_rotation': ee_pose_rotation,
        'gripper': gripper,
        'joints': joints,
        'control_tokens_ids': None,
        'target_control_tokens_ids': None,
    }
def create_input(
    self,
    chat: List[str],
    images: Dict[str, List[PIL.Image.Image]],
    ee_pose_translation: np.ndarray,
    ee_pose_rotation: np.ndarray,
    gripper: np.ndarray,
    joints: np.ndarray,
    dataset_name: np.ndarray,
    inference_mode: bool,
    control_target: Optional[RoboticsTarget] = None,
) -> RoboticsInput:
    """
    Run `preprocess_inputs` and wrap the result in a `RoboticsInput`.

    The 'target_text_tokens_ids' and 'target_control_tokens_ids' entries are removed before
    construction.
    """
    processed = self.preprocess_inputs(
        chat=chat,
        images=images,
        ee_pose_translation=ee_pose_translation,
        ee_pose_rotation=ee_pose_rotation,
        gripper=gripper,
        joints=joints,
        dataset_name=dataset_name,
        inference_mode=inference_mode,
        control_target=control_target,
    )
    # `pop` rather than a filtered copy, so a missing target key still raises KeyError.
    for target_key in ('target_text_tokens_ids', 'target_control_tokens_ids'):
        processed.pop(target_key)
    return RoboticsInput(**processed)
def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
    """
    Normalize `value` using the per-example statistics of component `key`.

    If `self.config.<key>_norm` is a mean normalization, `normalize_by_moments` is applied
    with (mean, std). Otherwise `normalize_by_bounds` is applied with (low, high).
    """
    use_moments = is_mean_norm(getattr(self.config, f'{key}_norm'))
    (first, second) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
    if use_moments:
        return normalize_by_moments(value, mean=first, std=second)
    return normalize_by_bounds(value, low=first, high=second)
def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
    """
    Invert `normalize` for component `key`.

    If `self.config.<key>_norm` is a mean normalization, `unnormalize_by_moments` is applied
    with (mean, std). Otherwise `unnormalize_by_bounds` is applied with (low, high).
    """
    use_moments = is_mean_norm(getattr(self.config, f'{key}_norm'))
    (first, second) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
    if use_moments:
        return unnormalize_by_moments(value, mean=first, std=second)
    return unnormalize_by_bounds(value, low=first, high=second)
def _norm_bounds_from_dataset_name(
    self, dataset_name: np.ndarray, component_key: str
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create per-example normalization statistics corresponding to dataset names.

    Args:
        dataset_name: Array of shape [B] of dataset names for which to fetch the
            normalization statistics. Note the values can be repeating.
        component_key: One of the keys of `self.norm_bounds`: 'obs_translation',
            'obs_rotation', 'translation', 'rotation' or 'joints'.
            `self.config.<component_key>_norm` selects the normalization type.
    Returns:
        Tuple (low, high) for bounds normalization, or (mean, std) for mean normalization.
        Each is a float32 tensor of shape [B, component_size].
    Raises:
        NotImplementedError: for 'joints' when the configured joints norm is not a mapping.
        KeyError: if a dataset name has no statistics for this component.
    """
    norm = getattr(self.config, f'{component_key}_norm')
    if is_mean_norm(norm):
        (stats_key_1, stats_key_2) = ('mean', 'std')
    else:
        (stats_key_1, stats_key_2) = ('low', 'high')

    if component_key == 'joints':
        if not isinstance(norm, collections.abc.Mapping):
            raise NotImplementedError()
        # Joint bounds are shared by all datasets under 'ANY'.
        # Values come back in dict order: (low, high).
        batch_size = len(dataset_name)
        return tuple(
            torch.from_numpy(
                np.tile(np.asarray(value, dtype=np.float32).reshape(1, -1), [batch_size, 1])
            )
            for value in self.joints_norm_bounds['ANY'].values()
        )

    component_bounds = self.norm_bounds[component_key]
    if self.dataset_names == ['ANY']:
        # One shared set of statistics, broadcast to every example. Cast to float32 so this
        # path returns the same dtype as the per-dataset path below.
        shared = component_bounds['ANY']
        batch_size = len(dataset_name)
        stats_1 = np.repeat(np.asarray(shared[stats_key_1], dtype=np.float32)[None], batch_size, axis=0)
        stats_2 = np.repeat(np.asarray(shared[stats_key_2], dtype=np.float32)[None], batch_size, axis=0)
    else:
        # Width of the component's statistic vectors, read from the first stored entry.
        component_size = list(list(component_bounds.values())[0].values())[0].shape[-1]
        # Look up each distinct dataset once, then scatter back into batch order.
        (unique_names, _, inverse_indices, _) = np_unique(dataset_name)
        stats_1 = np.zeros([len(unique_names), component_size], dtype=np.float32)
        stats_2 = np.zeros([len(unique_names), component_size], dtype=np.float32)
        for (i, ds_name) in enumerate(unique_names):
            stats_1[i] = component_bounds[ds_name][stats_key_1]
            stats_2[i] = component_bounds[ds_name][stats_key_2]
        stats_1 = stats_1[inverse_indices]
        stats_2 = stats_2[inverse_indices]
    return torch.from_numpy(stats_1), torch.from_numpy(stats_2)
@cached_property
def obs_rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
    """Cached rotation normalization bounds for observations, computed from observation stats."""
    config = self.config
    # Module-level `rotation_norm_bounds` helper (class attributes are not in method scope).
    return rotation_norm_bounds(
        rotation_norm=config.obs_rotation_norm,
        rotation_format=config.rotation_format,
        stats=self._observation_stats,
        dataset_names=self.dataset_names,
    )
@cached_property
def obs_translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
    """Cached translation normalization statistics for observations, from observation stats."""
    config = self.config
    # Module-level `translation_norm_bounds` helper (class attributes are not in method scope).
    return translation_norm_bounds(
        translation_norm=config.obs_translation_norm,
        stats=self._observation_stats,
        dataset_names=self.dataset_names,
    )
@cached_property
def rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
    """Cached rotation normalization bounds for controls, computed from control stats."""
    config = self.config
    # Module-level `rotation_norm_bounds` helper (class attributes are not in method scope).
    return rotation_norm_bounds(
        rotation_norm=config.rotation_norm,
        rotation_format=config.rotation_format,
        stats=self._control_stats,
        dataset_names=self.dataset_names,
    )
@cached_property
def translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
    """Cached translation normalization statistics for controls, computed from control stats."""
    config = self.config
    # Module-level `translation_norm_bounds` helper (class attributes are not in method scope).
    return translation_norm_bounds(
        translation_norm=config.translation_norm,
        stats=self._control_stats,
        dataset_names=self.dataset_names,
    )
@cached_property
def joints_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Joint normalization bounds shared by all datasets, stored under the 'ANY' key.

    Built from `self.config.joints_norm['low']` and `['high']` as float32 tensors, in that
    key order.

    NOTE:
    - Joint values across all joints and all datasets vary in the range [-2pi; 2pi]
    - The effective range of a single joint is in practice one of [-2pi; 0], [-pi; pi], [0; 2pi]
    - It's possible to shift all ranges to [-pi; pi], but it requires careful handling for each joint
    """
    joints_norm = self.config.joints_norm
    shared_bounds = {
        bound_key: torch.tensor(joints_norm[bound_key], dtype=torch.float32)
        for bound_key in ('low', 'high')
    }
    return {'ANY': shared_bounds}
@cached_property
def _observation_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
return {
'austin_buds_dataset': {
'euler': {
'max': [3.141589641571045, 0.2279592752456665, 0.10982763767242432],
'mean': [-0.433798223733902, 0.02660757675766945, 0.01057341042906046],
'min': [-3.1415905952453613, -0.17229366302490234, -0.08408939838409424],
'q01': [-3.1401021480560303, -0.15158048272132874, -0.07398375868797302],
'q99': [3.1401915550231934, 0.1699378788471222, 0.08044551312923431],
'std': [3.071873664855957, 0.08403228223323822, 0.04623554274439812],
},
'gripper': {
'max': [0.07999841868877411],
'min': [0.00019240332767367363],
'q01': [0.00760263716802001],
'q99': [0.07997412234544754],
},
'joints': {
'max': [
0.25105541944503784,
1.0239691734313965,
0.25514841079711914,
0.0,
0.05838121101260185,
3.0727620124816895,
1.1911247968673706,
],
'mean': [
-0.02248559147119522,
0.4224241375923157,
0.011533008888363838,
-2.040178060531616,
-0.004422259051352739,
2.440391778945923,
0.7626844644546509,
],
'min': [
-0.46276336908340454,
-0.2620261609554291,
-0.37377235293388367,
-3.010026693344116,
-0.015008972026407719,
0.0,
0.0,
],
'q01': [
-0.3563080132007599,
-0.19165010750293732,
-0.29941368103027344,
-2.766944408416748,
-0.012211786583065987,
1.9946393966674805,
0.21924567222595215,
],
'q99': [
0.22295619547367096,
0.8841447830200195,
0.20571835339069366,
-1.339158296585083,
0.05605502426624298,
2.9369189739227295,
1.1475461721420288,
],
'std': [
0.1684190183877945,
0.2238602489233017,
0.1185624971985817,
0.3182348608970642,
0.012190691195428371,
0.2673698663711548,
0.2611873149871826,
],
},
'translation': {
'max': [0.7684353590011597, 0.22995209693908691, 0.3893454968929291],
'mean': [0.5589713454246521, -0.0022956086322665215, 0.12593945860862732],
'min': [0.0, -0.31078848242759705, 0.0],
'q01': [0.3499317765235901, -0.2854413390159607, 0.010516085661947727],
'q99': [0.7243335843086243, 0.20652863383293152, 0.3218296766281128],
'std': [0.08953548222780228, 0.1462203562259674, 0.0769098550081253],
},
},
'austin_sailor_dataset': {
'euler': {
'max': [3.1415910720825195, 0.26857471466064453, 2.4386940002441406],
'mean': [0.040416110306978226, 0.013173789717257023, -0.044821564108133316],
'min': [-3.141592264175415, -0.21639788150787354, -1.995558738708496],
'q01': [-3.1401662826538086, -0.16171419620513916, -1.5918874740600586],
'q99': [3.1401755809783936, 0.17527776956558228, 1.3560799360275269],
'std': [3.0995969772338867, 0.08136381208896637, 0.6311237812042236],
},
'gripper': {
'max': [0.07783995568752289],
'min': [-8.864999836077914e-05],
'q01': [0.0005174533580429852],
'q99': [0.07778151333332062],
},
'joints': {
'max': [
0.18366889655590057,
1.188208818435669,
1.1904715299606323,
-1.0253318548202515,
0.06765712797641754,
2.98008131980896,
2.7747786045074463,
],
'mean': [
-0.4481923580169678,
0.428782194852829,
0.29826778173446655,
-2.1356770992279053,
-0.2064618021249771,
2.5082974433898926,
0.8104547262191772,
],
'min': [
-1.4425580501556396,
-0.34988880157470703,
-0.4436887800693512,
-2.9587178230285645,
-0.9116092324256897,
1.7689018249511719,
-1.4776484966278076,
],
'q01': [
-1.0579811334609985,
-0.13357718288898468,
-0.2518821358680725,
-2.6593017578125,
-0.6571194529533386,
2.072176456451416,
-0.5861577391624451,
],
'q99': [
0.08897759765386581,
0.8743473887443542,
0.8909343481063843,
-1.4748132228851318,
-0.010495365597307682,
2.859259605407715,
2.319192409515381,
],
'std': [
0.26852160692214966,
0.22502659261226654,
0.2577133774757385,
0.2963367700576782,
0.1629231870174408,
0.18130357563495636,
0.6348837614059448,
],
},
'translation': {
'max': [0.7488242387771606, 0.2829112708568573, 0.3541720509529114],
'mean': [0.5367430448532104, -0.08604215085506439, 0.11135970056056976],
'min': [0.3208981454372406, -0.3730051815509796, 0.020952222868800163],
'q01': [0.387094110250473, -0.3164229393005371, 0.024492919445037842],
'q99': [0.6869593262672424, 0.2086469978094101, 0.2551962733268738],
'std': [0.08000300824642181, 0.13984571397304535, 0.056377921253442764],
},
},
'austin_sirius_dataset': {
'euler': {
'max': [3.141592502593994, 0.09964871406555176, 0.3406205177307129],
'mean': [1.1875473260879517, -0.018573710694909096, -0.4977009892463684],
'min': [-3.141592502593994, -0.14349305629730225, -1.892538070678711],
'q01': [-3.140658378601074, -0.12321051210165024, -1.743293285369873],
'q99': [3.1407761573791504, 0.07348346710205078, 0.062370821833610535],
'std': [2.8533384799957275, 0.03962606564164162, 0.5864803791046143],
},
'gripper': {
'max': [0.07965432107448578],
'min': [0.00016088332631625235],
'q01': [0.0334978811442852],
'q99': [0.07948359102010727],
},
'joints': {
'max': [
0.5567107200622559,
0.3814372420310974,
0.36868324875831604,
0.0,
0.021617505699396133,
2.825117588043213,
2.8925509452819824,
],
'mean': [
0.22241303324699402,
0.03750193864107132,
-0.013105375692248344,
-2.428316593170166,
-0.017613587900996208,
2.4755642414093018,
1.4929485321044922,
],
'min': [
-0.23144784569740295,
-0.41377919912338257,
-0.3536752760410309,
-2.9289300441741943,
-0.06330326944589615,
0.0,
0.0,
],
'q01': [
-0.10539760440587997,
-0.2651134133338928,
-0.21157152950763702,
-2.831827402114868,
-0.04555036500096321,
0.0,
0.0,
],
'q99': [
0.48007193207740784,
0.3064860999584198,
0.2438022643327713,
0.0,
0.012844694778323174,
2.71537446975708,
2.51297926902771,
],
'std': [
0.1297803372144699,
0.132253035902977,
0.08940940350294113,
0.345712274312973,
0.015448619611561298,
0.3493262231349945,
0.632145345211029,
],
},
'translation': {
'max': [0.5655431151390076, 0.3374997079372406, 0.3071175515651703],
'mean': [0.4490734338760376, 0.09406533092260361, 0.17131780087947845],
'min': [0.0, -0.16291481256484985, 0.0],
'q01': [0.0, -0.11814527958631516, 0.0],
'q99': [0.532875120639801, 0.26084619760513306, 0.27225059270858765],
'std': [0.07605387270450592, 0.08447220921516418, 0.04798709973692894],
},
},
'bc_z': {
'euler': {
'max': [3.141583088638963, 1.5360414168274534, 3.141576023430307],
'mean': [-0.32332700342924814, -0.1255861641435509, -0.07202428898041417],
'min': [-3.13973587780783, -1.5322501214383415, -3.1415575429114675],
'q01': [-1.0433973645798003, -1.007553367940694, -2.6477418236555788],
'q99': [0.8467956620806508, 0.9111507554055364, 2.226032625129908],
'std': [0.40307241985577485, 0.4832146677789649, 1.0666829542185445],
},
'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.20000000298023224], 'q99': [1.0]},
'joints': {
'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
},
'translation': {
'max': [0.6597589254379272, 0.7259413599967957, 1.1217665672302246],
'mean': [0.012632723897695541, 0.10061875730752945, 0.7831991910934448],
'min': [-0.7181899547576904, -0.3756217360496521, -0.281008243560791],
'q01': [-0.3956047296524048, -0.11924505233764648, 0.601338267326355],
'q99': [0.332028865814209, 0.3088575601577759, 0.98329097032547],
'std': [0.1808762550354004, 0.0960959941148758, 0.08665762096643448],
},
},
'berkeley_autolab_ur5': {
'euler': {
'max': [3.141586857385381, 0.9715026080766722, 1.0041252839605828],
'mean': [0.10885329009674993, -0.005823583245937125, 0.01756067034138813],
'min': [-3.14158698706268, -0.9912158529884818, -0.8782840134752882],
'q01': [-3.1391613317061315, -0.6410961610461836, -0.4513447799408977],
'q99': [3.13947692478744, 0.6738775260636495, 0.535202669480506],
'std': [3.069331637100682, 0.18171227413203145, 0.17512358924250696],
},
'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]},
'joints': {
'max': [
-2.5115966796875,
-0.682390034198761,
2.6830153465270996,
-1.1711143255233765,
-0.6423047184944153,
4.032349109649658,
],
'mean': [
-3.2676453590393066,
-1.3197712898254395,
2.122663736343384,
-2.3710596561431885,
-1.567111611366272,
3.0015013217926025,
],
'min': [
-4.0775299072265625,
-2.2796404361724854,
1.1905627250671387,
-3.814586877822876,
-2.1223204135894775,
1.5279699563980103,
],
'q01': [
-3.839585781097412,
-1.8913453817367554,
1.5813319683074951,
-3.3646113872528076,
-1.805822730064392,
2.22794246673584,
],
'q99': [
-2.7322237491607666,
-0.8280109763145447,
2.5063326358795166,
-1.5614854097366333,
-1.2497813701629639,
3.7013423442840576,
],
'std': [
0.2746899724006653,
0.27345725893974304,
0.22793518006801605,
0.34690913558006287,
0.10438559204339981,
0.3388271629810333,
],
},
'translation': {
'max': [0.6610942482948303, 0.38513994216918945, 0.2049914449453354],
'mean': [0.45345133543014526, 0.05642913281917572, -0.0356401689350605],
'min': [0.1970997005701065, -0.27643972635269165, -0.20529526472091675],
'q01': [0.3020566999912262, -0.21297279000282288, -0.18836002051830292],
'q99': [0.6132073998451233, 0.30656182765960693, 0.12212439626455307],
'std': [0.07906801998615265, 0.12497302889823914, 0.08053416013717651],
},
},
'berkeley_cable_routing': {
'euler': {
'max': [3.141592264175415, 0.05923163890838623, 3.1243016719818115],
'mean': [-0.6730493903160095, 0.010155542753636837, 0.9881726503372192],
'min': [-3.141592264175415, -0.10035276412963867, -1.6915032863616943],
'q01': [-3.1412672996520996, -0.02922845631837845, -0.6449583172798157],
'q99': [3.1412830352783203, 0.039870113134384155, 2.4969191551208496],
'std': [3.0588438510894775, 0.01435503363609314, 0.6301024556159973],
},
'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]},
'joints': {
'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
},
'translation': {
'max': [0.7113381028175354, 0.35645395517349243, 0.23169316351413727],
'mean': [0.5643643736839294, 0.00464651919901371, 0.07635799050331116],
'min': [0.378939151763916, -0.3244381248950958, 0.02837979607284069],
'q01': [0.4641263782978058, -0.2806571424007416, 0.030183622613549232],
'q99': [0.6452807784080505, 0.28204888105392456, 0.1557157188653946],
'std': [0.0380910187959671, 0.16768066585063934, 0.032017387449741364],
},
},
'berkeley_fanuc_manipulation': {
'euler': {
'max': [3.1415822505950928, 1.020048975944519, 1.9211788177490234],
'mean': [0.1944933384656906, -0.03283948823809624, 0.017717985436320305],
'min': [-3.141590118408203, -1.5707060098648071, -3.089141368865967],
'q01': [-3.139866352081299, -1.0165742635726929, -1.6987266540527344],
'q99': [3.140317440032959, 0.4332200288772583, 1.5085773468017578],
'std': [2.9820094108581543, 0.21335174143314362, 0.4836970269680023],
},
'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]},
'joints': {
'max': [
0.9483258724212646,
1.4428143501281738,
1.2322325706481934,
2.3726272583007812,
0.0008019324741326272,
3.017007827758789,
],
'mean': [
0.04932224750518799,
0.36293527483940125,
-0.18438231945037842,
0.09242437779903412,
-1.0793049335479736,
-0.06553139537572861,
],
'min': [
-0.7464086413383484,
-0.5774519443511963,
-1.1099220514297485,
-1.8152672052383423,
-2.1454310417175293,
-3.168759346008301,
],
'q01': [
-0.5061959028244019,
-0.09613651037216187,
-0.8047154545783997,
-0.9908221364021301,
-1.8158636093139648,
-2.420135974884033,
],
'q99': [
0.6388322710990906,
1.0234705209732056,
0.5990921258926392,
1.282780647277832,
-0.39091193675994873,
2.340099811553955,
],
'std': [
0.2452843338251114,
0.2555835247039795,
0.2740932106971741,
0.3769519329071045,
0.3472467362880707,
0.6859157681465149,
],
},
'translation': {
'max': [0.8101964592933655, 0.5117550492286682, 0.8054816722869873],
'mean': [0.5446329116821289, 0.003737417282536626, 0.2439887523651123],
'min': [0.263683944940567, -0.5785022377967834, -0.0025757886469364166],
'q01': [0.3718133866786957, -0.4071895182132721, 0.01847645826637745],
'q99': [0.7200658321380615, 0.3128541111946106, 0.5413243770599365],
'std': [0.07649082690477371, 0.14486227929592133, 0.13899113237857819],
},
},
'bridge': {
'euler': {
'max': [3.141592653589793, 1.570796251296997, 3.141204357147217],
'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691],
'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
},
'gripper': {
'max': [1.0370277166366577],
'min': [0.04637829214334488],
'q01': [0.05192930996417999],
'q99': [1.0118417739868164],
},
'joints': {
'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
},
'translation': {
'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424],
'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723],
'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884],
'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424],
'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685],
},
},
'bridge_orig': {
'euler': {
'max': [3.141592653589793, 1.570796251296997, 3.141204357147217],
'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691],
'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
},
'gripper': {
'max': [1.0370277166366577],
'min': [0.04637829214334488],
'q01': [0.05192930996417999],
'q99': [1.0118417739868164],
},
'joints': {
'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
},
'translation': {
'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424],
'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723],
'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884],
'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424],
'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685],
},
},
'bridge_action_lang': {
'euler': {
'max': [3.141592653589793, 1.570796251296997, 3.141204357147217],
'mean': [-0.25754162314671525, -0.12370228389510128, 0.1620053749182691],
'min': [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
'q01': [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
'q99': [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
'std': [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
},
'gripper': {
'max': [1.0370277166366577],
'min': [0.04637829214334488],
'q01': [0.05192930996417999],
'q99': [1.0118417739868164],
},
'joints': {
'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
},
'translation': {
'max': [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
'mean': [0.309032678604126, 0.03403777256608009, 0.061277542263269424],
'min': [-0.04167502000927925, -0.2889411449432373, -0.13934996724128723],
'q01': [0.1711955964565277, -0.15639324486255646, -0.048255354166030884],
'q99': [0.4604376256465912, 0.24112474918365479, 0.18886254727840424],
'std': [0.0635896623134613, 0.09153717756271362, 0.049334850162267685],
},
},
'cmu_stretch': {
'euler': {
'max': [3.141592653589793, 0.0, 0.0],
'mean': [3.1415926535896035, 0.0, 0.0],
'min': [3.141592653589793, 0.0, 0.0],
'q01': [3.141592653589793, 0.0, 0.0],
'q99': [3.141592653589793, 0.0, 0.0],
'std': [1.8962609260597674e-13, 0.0, 0.0],
},
'gripper': {
'max': [3.110913038253784],
'min': [-3.1155149936676025],
'q01': [-3.055689811706543],
'q99': [3.1017091274261475],
},
'joints': {
'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
},
'translation': {
'max': [0.3401322066783905, 0.0, 1.1013386249542236],
'mean': [0.1461893916130066, 0.0, 0.7568270564079285],
'min': [0.005320895928889513, 0.0, 0.45218077301979065],
'q01': [0.017430847510695457, 0.0, 0.46050605177879333],
'q99': [0.33094948530197144, 0.0, 1.0952961444854736],
'std': [0.09959954023361206, 0.0, 0.15819725394248962],
},
},
'dlr_edan_shared_control': {
'euler': {
'max': [2.2228052616119385, 0.13346634805202484, 3.1415085792541504],
'mean': [1.6072595119476318, -0.4295055568218231, 1.1549623012542725],
'min': [0.4599944055080414, -1.4871342182159424, -3.1406784057617188],
'q01': [0.8485284447669983, -1.4454149007797241, -3.1238820552825928],
'q99': [1.8301045894622803, 0.07032366842031479, 3.131430149078369],
'std': [0.2785564363002777, 0.540916919708252, 2.4260833263397217],
},
'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]},
'joints': {
'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
},
'translation': {
'max': [0.044179707765579224, 0.6297282576560974, 0.9278888702392578],
'mean': [-0.512688934803009, 0.3690241575241089, 0.441307932138443],
'min': [-0.7695853114128113, -0.040099501609802246, 0.2583216428756714],
'q01': [-0.729511022567749, 0.077408567070961, 0.2658006250858307],
'q99': [-0.13719859719276428, 0.5719971060752869, 0.7898909449577332],
'std': [0.15213875472545624, 0.08891375362873077, 0.13743770122528076],
},
},
'droid': {
'euler': {
'max': [3.141592502593994, 1.5705928802490234, 3.1415867805480957],
'mean': [0.3140628098409554, -0.09296274023036387, -0.07227215454779846],
'min': [-3.141592502593994, -1.5691150426864624, -3.1415374279022217],
'q01': [-3.1378602981567383, -1.2125312042236327, -2.1614069032669065],
'q99': [3.137854380607605, 0.9200375998020163, 1.9367506909370364],
'std': [2.926265757944871, 0.363273475703332, 0.7576065217938824],
},
'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.9911894202232361]},
'joints': {
'max': [
2.668445110321045,
1.5691218376159668,
2.666306734085083,
-0.3114914000034332,
2.6624162197113037,
4.28157901763916,
2.752457857131958,
],
'mean': [
0.023137084334640106,
0.2704989977282293,
-0.01451389357228282,
-2.018709403792315,
-0.042720520800030394,
2.350281188152209,
0.12424663946659845,
],
'min': [
-2.6536705493927,
-1.547789216041565,
-2.6781487464904785,
-2.9409868717193604,
-2.6705946922302246,
0.24893812835216522,
-2.7615714073181152,
],
'q01': [
-0.9026106441020965,
-0.8547340619564057,
-0.9028875434398651,
-2.7698556280136106,
-1.6851656341552732,
1.2335169839859008,
-1.9587260699272155,
],
'q99': [
0.9569852340221403,
1.4148830294609054,
0.7693877756595566,
-0.4545914208889008,
1.5623322343826267,
3.475611729621887,
2.263479118347167,
],
'std': [
0.31695080251469465,
0.49522214687158767,
0.27993538230553827,
0.478161574676113,
0.4969961591445458,
0.45101008525403846,
0.7287264344068457,
],
},
'translation': {
'max': [0.8575563430786133, 0.799155592918396, 1.0043904781341553],
'mean': [0.5283099395864883, 0.005363794653877434, 0.3120132207021294],
'min': [-0.15604186058044434, -0.827903687953949, -0.2347021996974945],
'q01': [0.26669957995414734, -0.43774398624897004, -0.048167889714241026],
'q99': [0.7774086785316463, 0.428325751423835, 0.776091011762619],
'std': [0.1148424841779685, 0.17489566608140428, 0.16541062032731538],
},
},
'fmb': {
'euler': {
'max': [3.1415927410125732, 1.044223666191101, 2.6313371658325195],
'mean': [-0.5128238201141357, -0.0326850451529026, 1.2892225980758667],
'min': [-3.141592502593994, -1.0594873428344727, -1.5775980949401855],
'q01': [-3.1404616832733154, -0.9319325089454651, -1.4586873054504395],
'q99': [3.140399932861328, 0.8346238732337952, 2.3081648349761963],
'std': [3.0557258129119873, 0.33158522844314575, 0.6522650122642517],
},
'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]},
'joints': {
'max': [
1.633209228515625,
1.216904640197754,
0.6477310061454773,
-1.101138949394226,
1.4866814613342285,
3.2880685329437256,
2.8213212490081787,
],
'mean': [
0.08979735523462296,
0.24736542999744415,
-0.18356309831142426,
-2.1948094367980957,
0.09010367840528488,
2.3375964164733887,
-0.6424615979194641,
],
'min': [
-0.24819369614124298,
-0.6836066246032715,
-1.852062463760376,
-2.950329542160034,
-1.360267162322998,
1.1119775772094727,
-2.53395938873291,
],
'q01': [
-0.12363661825656891,
-0.2893752157688141,
-0.8288514018058777,
-2.7587249279022217,
-0.9787064790725708,
1.5508010387420654,
-1.7525910139083862,
],
'q99': [
0.527353048324585,
0.8305937647819519,
0.3808687925338745,
-1.5172499418258667,
1.1656274795532227,
2.9421844482421875,
2.3595635890960693,
],
'std': [
0.14010190963745117,
0.24683699011802673,
0.29916438460350037,
0.29822880029678345,
0.40809085965156555,
0.3218824863433838,
0.7698287963867188,
],
},
'translation': {
'max': [0.7458599805831909, 0.2993985116481781, 0.39971601963043213],
'mean': [0.5149032473564148, -0.04377440735697746, 0.17520515620708466],
'min': [0.32314637303352356, -0.34641459584236145, 0.029209019616246223],
'q01': [0.3655048608779907, -0.28729698061943054, 0.033201027661561966],
'q99': [0.6782684326171875, 0.209969624876976, 0.3331448435783386],
'std': [0.07067117840051651, 0.13670970499515533, 0.07637079805135727],
},
},
'fractal20220817_data': {
'euler': {
'max': [3.141592264175415, 1.5703786611557007, 3.1415889263153076],
'mean': [-1.5046496391296387, 0.6226330399513245, -1.621827244758606],
'min': [-3.1415927410125732, -1.569378137588501, -3.1415863037109375],
'q01': [-3.130796194076538, -0.23499104380607605, -2.9648191928863525],
'q99': [3.1300582885742188, 1.4907126426696777, 2.8373801708221436],
'std': [2.380099296569824, 0.41548046469688416, 0.803632915019989],
},
'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]},
'joints': {
'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
},
'translation': {
'max': [1.0534898042678833, 0.48018959164619446, 1.6896663904190063],
'mean': [0.5594336986541748, -0.08356647938489914, 0.7760705947875977],
'min': [-0.3139062225818634, -0.9970501065254211, -0.006579156965017319],
'q01': [0.3249714970588684, -0.2818704843521118, 0.1410011649131775],
'q99': [0.8754204511642456, 0.21279653906822205, 1.071526288986206],
'std': [0.12431981414556503, 0.11550336331129074, 0.24583852291107178],
},
},
'furniture_bench_dataset': {
'euler': {
'max': [3.141592502593994, 1.545873999595642, 3.1415679454803467],
'mean': [-1.5673161745071411, 0.028489787131547928, -0.0338752306997776],
'min': [-3.1415927410125732, -1.1365290880203247, -3.141582727432251],
'q01': [-3.139331817626953, -0.6121169328689575, -1.9944958686828613],
'q99': [3.1392478942871094, 0.9993562698364258, 1.7793478965759277],
'std': [2.605621337890625, 0.30199435353279114, 0.924261212348938],
},
'gripper': {
'max': [0.07995442301034927],
'min': [-3.4803331800503656e-05],
'q01': [0.003531553316861391],
'q99': [0.07973115146160126],
},
'joints': {
'max': [
1.6170321702957153,
1.4905058145523071,
0.8054860234260559,
-0.8407800197601318,
2.596961736679077,
3.591698169708252,
2.744664192199707,
],
'mean': [
0.14828824996948242,
0.513389527797699,
-0.07023008167743683,
-2.116460084915161,
0.17382267117500305,
2.6314659118652344,
0.7604196667671204,
],
'min': [
-1.0936335325241089,
-0.28272193670272827,
-1.3713332414627075,
-2.9083454608917236,
-2.664609432220459,
0.49353715777397156,
-2.6659374237060547,
],
'q01': [
-0.3774302303791046,
0.16254600882530212,
-0.526292085647583,
-2.662245512008667,
-0.3341670036315918,
1.4710004329681396,
-0.997846245765686,
],
'q99': [
0.6179460287094116,
0.9334877729415894,
0.3241332471370697,
-1.5048880577087402,
0.69161057472229,
3.45759916305542,
2.5480763912200928,
],
'std': [
0.20775794982910156,
0.14619596302509308,
0.1992885023355484,
0.24350883066654205,
0.2203112691640854,
0.3669370114803314,
0.8288815021514893,
],
},
'translation': {
'max': [0.75081467628479, 0.29695066809654236, 0.35806331038475037],
'mean': [0.5622280836105347, 0.049872349947690964, 0.08108603954315186],
'min': [0.2878606617450714, -0.3141690492630005, -0.00460465531796217],
'q01': [0.36915361881256104, -0.180975541472435, 0.0058300793170928955],
'q99': [0.6652880311012268, 0.1772783100605011, 0.18316447734832764],
'std': [0.07692465931177139, 0.07895787060260773, 0.04069080576300621],
},
},
'iamlab_cmu_pickup_insert': {
'euler': {
'max': [3.14159250270088, 0.04584214946783205, 0.0450923308550184],
'mean': [-1.4122567194164972, -0.03209339030753262, 0.006593786875632054],
'min': [-3.141592327000409, -0.13647014666477264, -0.052753928700043584],
'q01': [-3.141100653025272, -0.1142266801993486, -0.026375220367685838],
'q99': [3.1411030904244917, 0.02986631061111326, 0.033458487885352856],
'std': [2.771965176997775, 0.02781015988983883, 0.015259428603879655],
},
'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]},
'joints': {
'max': [
0.24314826726913452,
0.7470589280128479,
0.7795426249504089,
-1.6825058460235596,
0.13216856122016907,
2.88216233253479,
1.443279504776001,
],
'mean': [
-0.09283880889415741,
0.16589058935642242,
0.13815303146839142,
-2.236156940460205,
-0.013342339545488358,
2.4330623149871826,
0.8266588449478149,
],
'min': [
-0.6439031958580017,
-0.7895585894584656,
-0.017819713801145554,
-2.9597983360290527,
-0.5142408013343811,
1.82623291015625,
0.24880434572696686,
],
'q01': [
-0.4449005424976349,
-0.687881588935852,
-0.00037891813553869724,
-2.9046988487243652,
-0.3344403803348541,
1.9697062969207764,
0.4578770101070404,
],
'q99': [
0.19939860701560974,
0.6767980456352234,
0.5450605750083923,
-1.7509592771530151,
0.07327680289745331,
2.758103370666504,
1.2580491304397583,
],
'std': [
0.1536070704460144,
0.3142663240432739,
0.1251399964094162,
0.28150704503059387,
0.062141284346580505,
0.17315669357776642,
0.19375956058502197,
],
},
'translation': {
'max': [0.6634980998075433, 0.23428471360823594, 0.4308285234773945],
'mean': [0.5266367430643635, 0.02849007901898653, 0.18708021993948834],
'min': [0.3071657210198696, -0.2975911497764855, 0.06578226212031718],
'q01': [0.314498567139741, -0.20315787858517198, 0.06785127291435837],
'q99': [0.6472027612545221, 0.20840713845170097, 0.37003410655170416],
'std': [0.08150424006275543, 0.11136019021346877, 0.0777212838203498],
},
},
'jaco_play': {
'euler': {
'max': [3.141592653589793, 0.0, 0.0],
'mean': [3.14159265358649, 0.0, 0.0],
'min': [3.141592653589793, 0.0, 0.0],
'q01': [3.141592653589793, 0.0, 0.0],
'q99': [3.141592653589793, 0.0, 0.0],
'std': [3.3031355428647657e-12, 0.0, 0.0],
},
'gripper': {
'max': [1.3890767097473145],
'min': [-0.00123199715744704],
'q01': [0.00061599857872352],
'q99': [1.3835327625274658],
},
'joints': {
'max': [
-0.6973918676376343,
4.516506671905518,
2.732408285140991,
0.27657514810562134,
2.0405426025390625,
0.9062198996543884,
],
'mean': [
-1.5362969636917114,
3.9368598461151123,
1.470100998878479,
-0.26214125752449036,
0.8781110644340515,
-0.4923328757286072,
],
'min': [
-2.394073486328125,
3.4039254188537598,
0.5968497395515442,
-1.0070130825042725,
-0.28005343675613403,
-1.3207207918167114,
],
'q01': [
-2.2932708263397217,
3.5327682495117188,
0.7838058471679688,
-0.5554518699645996,
0.002689492190256715,
-1.1670725345611572,
],
'q99': [
-0.8449038863182068,
4.363891124725342,
2.367223024368286,
0.06634082645177841,
1.595384120941162,
0.17280587553977966,
],
'std': [
0.32358425855636597,
0.18587827682495117,
0.3632128834724426,
0.11627942323684692,
0.3359401226043701,
0.3179064989089966,
],
},
'translation': {
'max': [0.2252020239830017, -0.19358234107494354, 0.4066188633441925],
'mean': [-0.07287009060382843, -0.42861491441726685, 0.2764993906021118],
'min': [-0.4429473876953125, -0.6635459661483765, 0.1568669229745865],
'q01': [-0.3789186179637909, -0.6194459795951843, 0.16865813732147217],
'q99': [0.21203258633613586, -0.26914602518081665, 0.38958534598350525],
'std': [0.138992577791214, 0.0908467248082161, 0.052496980875730515],
},
},
'kuka': {
'euler': {
'max': [3.141592651594966, 0.25521246168753553, 3.1415812671947134],
'mean': [-0.8123818266038756, -0.004947911572775501, 0.1622765703197679],
'min': [-3.141592653256577, -0.22240290857095402, -3.1415339219506544],
'q01': [-3.141539356675989, -0.042093398087956785, -0.6103775416503281],
'q99': [3.1415363480584935, 0.013040864331848548, 1.4661773197655914],
'std': [3.031619905792479, 0.010541045603285457, 0.386240278316171],
},
'gripper': {'max': [1.0], 'min': [0.0], 'q01': [0.0], 'q99': [1.0]},
'joints': {
'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
},
'translation': {
'max': [0.7243871092796326, 0.31309840083122253, 0.613274335861206],
'mean': [0.5510859489440918, 0.04787421599030495, 0.12812286615371704],
'min': [0.40573424100875854, -0.2028520256280899, 0.018512273207306862],
'q01': [0.4765772819519043, -0.14815208315849304, 0.06674224138259888],
'q99': [0.6515637040138245, 0.2447487711906433, 0.28018367290496826],
'std': [0.045206550508737564, 0.10222796350717545, 0.058390580117702484],
},
},
'language_table': {
'euler': {
'max': [3.141592653589793, 0.0, 0.0],
'mean': [3.141592653805758, 0.0, 0.0],
'min': [3.141592653589793, 0.0, 0.0],
'q01': [3.141592653589793, 0.0, 0.0],
'q99': [3.141592653589793, 0.0, 0.0],
'std': [2.1596502364831912e-10, 0.0, 0.0],
},
'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]},
'joints': {
'max': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'mean': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'min': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q01': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'q99': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
'std': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
},
'translation': {
'max': [0.6191085577011108, 0.345907062292099, 0.0],
'mean': [0.3994608521461487, 0.0045331921428442, 0.0],
'min': [0.18907572329044342, -0.3051564395427704, 0.0],
'q01': [0.19237099587917328, -0.2962527573108673, 0.0],
'q99': [0.6171894669532776, 0.30645298957824707, 0.0],
'std': [0.10587888211011887, 0.14165113866329193, 0.0],
},
},
'nyu_franka_play_dataset': {
'euler': {
'max': [3.141542687188373, 1.570464569214435, 3.1412348640834002],
'mean': [1.2730823611033315, 0.6681926368695384, 1.141188695183881],
'min': [-3.141417902437913, -0.6159773949821361, -3.1401807914238264],
'q01': [-3.1053165771573177, -0.18849691525500664, -2.8903965648285523],
'q99': [3.1156105211359257, 1.5443502891290082, 2.882089687976354],
'std': [1.519365023922235, 0.5398198304496786, 1.099335476495627],
},
'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]},
'joints': {
'max': [
0.873094916343689,
1.244848370552063,
2.6604084968566895,
-0.46481436491012573,
1.409853458404541,
3.5058159828186035,
2.2937142848968506,
],
'mean': [
-0.7650318741798401,
-0.9744310975074768,
1.4849265813827515,
-2.0917000770568848,
-0.42907631397247314,
2.4454119205474854,
-0.04509185999631882,
],
'min': [
-2.1353442668914795,
-1.6306616067886353,
-0.0167933851480484,
-2.8458828926086426,
-2.5256054401397705,
0.5101203322410583,
-2.1178665161132812,
],
'q01': [
-1.9259787797927856,
-1.537291407585144,
0.4664345383644104,
-2.7214717864990234,
-2.114818811416626,
1.0703176259994507,
-1.8231183290481567,
],
'q99': [
0.19280724227428436,
0.7752904295921326,
2.6163666248321533,
-0.8622008562088013,
0.7001188397407532,
3.4607036113739014,
1.7281100749969482,
],
'std': [
0.4459172785282135,
0.5925542712211609,
0.45783039927482605,
0.4001038372516632,
0.6094018816947937,
0.5889745950698853,
0.7601358294487,
],
},
'translation': {
'max': [0.6258905727651867, 0.8289460146095029, 0.8867425967452327],
'mean': [0.38785427731316113, 0.39551342608233353, 0.45161485937851],
'min': [0.05690113252698781, -0.06171031940925212, 0.1157533700659999],
'q01': [0.1393695951374117, 0.07645522416447147, 0.19364508601310984],
'q99': [0.5920727118835839, 0.6584802280677562, 0.8056891771281188],
'std': [0.10666290139662009, 0.1217948547812865, 0.1707295181003299],
},
},
'roboset': {
'euler': {
'max': [3.1415449294818236, 1.5705575529715636, 3.141527342124582],
'mean': [-0.0398455755412464, 1.0518070390619125, -0.015345692503002759],
'min': [-3.1415813300509536, -1.5222832468962035, -3.141575300866071],
'q01': [-2.9414386317311187, -0.24976770655101155, -2.985256521212579],
'q99': [2.9380437893235993, 1.5403010739503078, 2.9746912523985025],
'std': [1.7866587696177456, 0.40620530263065, 1.7288511340250616],
},
'gripper': {
'max': [0.83056640625],
'min': [0.0001499652862548828],
'q01': [0.0001499652862548828],
'q99': [0.82666015625],
},
'joints': {
'max': [
0.96240234375,
1.1162109375,
1.1064453125,
-0.98095703125,
2.30859375,
1.576171875,
1.7412109375,
],
'mean': [
0.005913593806326389,
0.1877261847257614,
0.04653879255056381,
-2.0529513359069824,
-0.011298442259430885,
0.6185526251792908,
-0.01701134257018566,
],
'min': [
-0.8330078125,
-0.74658203125,
-0.8642578125,
-2.892578125,
-1.390625,
-0.24658203125,
-2.953125,
],
'q01': [
-0.41015625,
-0.5302734375,
-0.6455078125,
-2.57421875,
-0.76416015625,
-0.0386962890625,
-1.435546875,
],
'q99': [
0.66455078125,
0.9501953125,
0.7529296875,
-1.251953125,
0.75244140625,
1.2314453125,
1.384765625,
],
'std': [
0.17915399372577667,
0.32234326004981995,
0.26069700717926025,
0.31767210364341736,
0.205329030752182,
0.33385637402534485,
0.6263682842254639,
],
},
'translation': {
'max': [0.5747738480567932, 0.3972920775413513, 0.7443570494651794],
'mean': [0.3331542909145355, 0.019357483834028244, 0.37330344319343567],
'min': [0.09978063404560089, -0.29593944549560547, 0.10065606236457825],
'q01': [0.18437016010284424, -0.25699371099472046, 0.15134164690971375],
'q99': [0.543661892414093, 0.29646238684654236, 0.6682320833206177],
'std': [0.07849054038524628, 0.12241040915250778, 0.1460595279932022],
},
},
'roboturk': {
'euler': {
'max': [3.1415925701602854, 1.5661525056538665, 3.1412803735012553],
'mean': [0.3946147188969296, -0.09657834247810235, -0.1840393277985236],
'min': [-3.141592645423486, -1.568282806779776, -3.141449561492499],
'q01': [-3.1378020570699525, -0.7907563562327732, -1.7347170410441308],
'q99': [3.137841563084971, 0.5773179844053413, 1.455226678618167],
'std': [2.9453717386817755, 0.2466574231368262, 0.6146181899264712],
},
'gripper': {
'max': [0.041667],
'min': [-1.6931189781743683e-05],
'q01': [0.0],
'q99': [0.040625325000000004],
},
'joints': {
'max': [
3.0592832565307617,
2.9869844913482666,
0.33399999141693115,
2.0810000896453857,
3.874000072479248,
0.21199999749660492,
32.02399826049805,
],
'mean': [
0.06695421040058136,
0.4899406135082245,
-0.0009486731723882258,
-0.00033266679383814335,
-0.000789206416811794,
0.017417440190911293,
-4.154728412628174,
],
'min': [
-2.995668888092041,
-2.9882216453552246,
-1.3760000467300415,
-1.9279999732971191,
-4.10099983215332,
-0.5040000081062317,
-23.58799934387207,
],
'q01': [
-1.1723066568374634,
-1.1159032583236694,
-0.00800000037997961,
-0.3580000102519989,
-0.5680000185966492,
-0.10000000149011612,
-14.303999900817871,
],
'q99': [
1.1788837909698486,
1.7529590129852295,
0.007000000216066837,
0.3799999952316284,
0.628000020980835,
0.13199999928474426,
7.5279998779296875,
],
'std': [
0.46230366826057434,
0.5157809257507324,
0.0024148367810994387,
0.11895754188299179,
0.18901629745960236,
0.050090208649635315,
4.478224754333496,
],
},
'translation': {
'max': [1.0403562763478107, 0.8124313436278997, 0.9670890184966487],
'mean': [0.601758638436076, -0.0011977902645611068, 0.061136336184971524],
'min': [-0.375370271681523, -0.9602276639048171, -0.39440951528330676],
'q01': [0.2845426588179575, -0.3288349528691964, -0.0934955147249334],
'q99': [0.8773894592247552, 0.28575230587640144, 0.3286392635131121],
'std': [0.12719404213832913, 0.1476650170975746, 0.09248325352628932],
},
},
'stanford_hydra_dataset': {
'euler': {
'max': [3.141592357591889, 1.5364991593031068, 3.141222461465462],
'mean': [-0.2789889021491306, -0.19247585545209975, 0.34538177027867784],
'min': [-3.1415919781981962, -1.5573291068141568, -3.1413972494043216],
'q01': [-3.138796116403405, -1.107955818954808, -2.271278880707407],
'q99': [3.1386728135975384, 0.915868569686168, 2.4011245578416336],
'std': [2.883189482433683, 0.43362190146795254, 1.199450318648319],
},
'gripper': {
'max': [0.0811397060751915],
'min': [-0.0005391233135014772],
'q01': [1.3133333141013281e-06],
'q99': [0.08100640028715134],
},
'joints': {
'max': [
1.9801124334335327,
1.5752310752868652,
2.6661651134490967,
-0.6274839639663696,
2.3508076667785645,
3.4612932205200195,
2.747223377227783,
],
'mean': [
-0.052988938987255096,
-0.19803331792354584,
-0.046565085649490356,
-2.395282506942749,
0.4740530848503113,
2.119305372238159,
0.031005585566163063,
],
'min': [
-2.5843772888183594,
-1.287919044494629,
-1.9306836128234863,
-2.9673542976379395,
-1.9528191089630127,
0.5680276155471802,
-2.74257755279541,
],
'q01': [
-1.0905369520187378,
-0.9110571146011353,
-0.9165626764297485,
-2.876581907272339,
-0.9029546976089478,
1.3720064163208008,
-2.0172016620635986,
],
'q99': [
0.6443323493003845,
1.0052685737609863,
0.9653106927871704,
-1.4030946493148804,
1.738037109375,
2.943267822265625,
2.451101541519165,
],
'std': [
0.3502423167228699,
0.41914135217666626,
0.39895766973495483,
0.31860536336898804,
0.5207170248031616,
0.35328012704849243,
1.2519474029541016,
],
},
'translation': {
'max': [0.7598075952006731, 0.36410959179102775, 0.597508369211053],
'mean': [0.4348564561896346, 0.024563536690147474, 0.2938664975066919],
'min': [0.18414022283711137, -0.32311685510034516, 0.017172937739480726],
'q01': [0.2373728557200433, -0.26521679456931635, 0.09069013461938308],
'q99': [0.712423814640525, 0.25299056815993204, 0.49505407602925094],
'std': [0.10465658310507645, 0.12745896408828714, 0.10379447506049931],
},
},
'taco_play': {
'euler': {
'max': [3.1415891647338867, 0.30717557668685913, 2.4624788761138916],
'mean': [-0.06622616201639175, -0.17697572708129883, 0.4614097774028778],
'min': [-3.1415886878967285, -0.8683895468711853, -1.789320707321167],
'q01': [-3.139000654220581, -0.6962019205093384, -1.1729414463043213],
'q99': [3.1392500400543213, 0.12436603754758835, 1.8065968751907349],
'std': [3.0384926795959473, 0.18692463636398315, 0.6806972622871399],
},
'gripper': {
'max': [0.08074767142534256],
'min': [9.652999870013446e-05],
'q01': [0.00015234666352625936],
'q99': [0.08067675679922104],
},
'joints': {
'max': [
0.9704633951187134,
0.4390600025653839,
0.9892600178718567,
-1.116769790649414,
1.2800638675689697,
2.894169330596924,
2.627519130706787,
],
'mean': [
0.14954668283462524,
-0.3514331877231598,
0.1295831948518753,
-2.205498456954956,
0.11473798751831055,
2.0389058589935303,
0.5516068339347839,
],
'min': [
-0.6495265960693359,
-1.2877551317214966,
-1.5125423669815063,
-2.956650733947754,
-1.0317779779434204,
1.1109415292739868,
-1.8966686725616455,
],
'q01': [
-0.4596467614173889,
-0.9323632121086121,
-0.6153853535652161,
-2.8417088985443115,
-0.3472142219543457,
1.428541898727417,
-0.5756277441978455,
],
'q99': [
0.7262046933174133,
0.21787810325622559,
0.7387099266052246,
-1.3517359495162964,
0.7581852674484253,
2.564530849456787,
1.6355704069137573,
],
'std': [
0.281919926404953,
0.2702403962612152,
0.3334408402442932,
0.412654846906662,
0.21815766394138336,
0.2808808982372284,
0.4732993543148041,
],
},
'translation': {
'max': [0.7124971151351929, 0.6118948459625244, 0.617118775844574],
'mean': [0.37278515100479126, 0.13003520667552948, 0.38341864943504333],
'min': [0.07693906873464584, -0.4944135844707489, 0.20030911266803741],
'q01': [0.1368357390165329, -0.4297449290752411, 0.20516259968280792],
'q99': [0.6700438857078552, 0.5943909883499146, 0.5966404676437378],
'std': [0.11509375274181366, 0.2694733738899231, 0.11266136914491653],
},
},
'toto': {
'euler': {
'max': [3.1414193360696108, 1.570402878730552, 3.1415549880169924],
'mean': [-1.31124856158284, 0.8107744638901545, 1.042033586691101],
'min': [-3.1415814091974665, -1.3210638993411208, -3.141515215393105],
'q01': [-2.9459294143675696, -0.6449196412662773, -2.9790265961739872],
'q99': [2.824635320375007, 1.4897934376557762, 3.0083314049108925],
'std': [1.2205361480020216, 0.45724967725161403, 1.3644152538580079],
},
'gripper': {'max': [-1.0], 'min': [-1.0], 'q01': [-1.0], 'q99': [-1.0]},
'joints': {
'max': [
1.4926575422286987,
0.7715681195259094,
1.9506583213806152,
-1.6330622434616089,
2.804656744003296,
2.0850095748901367,
1.9661481380462646,
],
'mean': [
-0.040506213903427124,
-0.2537725269794464,
0.05389903113245964,
-2.5135483741760254,
-0.161987766623497,
0.8815387487411499,
0.0034975146409124136,
],
'min': [
-1.8345723152160645,
-1.7133995294570923,
-1.40083646774292,
-2.985748767852783,
-2.814244270324707,
-0.5165338516235352,
-3.5424435138702393,
],
'q01': [
-1.4044159650802612,
-1.3978556394577026,
-0.8029389381408691,
-2.97342586517334,
-1.5402026176452637,
0.0011001524981111288,
-2.9709367752075195,
],
'q99': [
0.7075999975204468,
0.42536962032318115,
1.3345959186553955,
-1.8281668424606323,
2.5240790843963623,
2.0835044384002686,
1.212519884109497,
],
'std': [
0.3405684530735016,
0.3766334354877472,
0.40128207206726074,
0.3202400505542755,
0.6421889066696167,
0.33138400316238403,
0.5801302194595337,
],
},
'translation': {
'max': [0.7777318005382999, 0.5688841397096768, 0.672235403102292],
'mean': [0.15078069344893, -0.03012185632654897, 0.35519823701585035],
'min': [-0.18532184496861265, -0.6282599632853518, 0.06312318086547261],
'q01': [-0.09177927811484686, -0.3571658956454935, 0.21965464356283362],
'q99': [0.6757593680120374, 0.2889021640318381, 0.501109413161318],
'std': [0.14726563054735392, 0.13472763062259546, 0.059123705378892395],
},
},
'ucsd_kitchen_dataset': {
'euler': {
'max': [3.141592644655991, 0.10471752134072587, 2.0420500160887434],
'mean': [-2.349219123990161, -0.5844573008678524, 0.8632842703198627],
'min': [-3.141592637838787, -1.5533486027935484, -1.676143206926509],
'q01': [-3.1415921825947537, -1.535883315130894, -0.6806766312742327],
'q99': [3.1328677560237566, 0.01745363977705911, 1.8326023938995777],
'std': [1.756197370631558, 0.482005280180061, 0.7151569996653978],
},
'gripper': {'max': [0.0], 'min': [0.0], 'q01': [0.0], 'q99': [0.0]},
'joints': {
'max': [
2.4945411682128906,
1.0603798627853394,
0.5518655776977539,
2.71895432472229,
1.9269907474517822,
2.2231369018554688,
0.7844778895378113,
],
'mean': [
0.35797572135925293,
-0.23473882675170898,
-0.34185951948165894,
0.8982798457145691,
0.5780586004257202,
1.0351765155792236,
-1.2190868854522705,
],
'min': [
-0.6943671703338623,
-1.6351218223571777,
-2.0142624378204346,
0.00768359424546361,
-0.17713193595409393,
0.011292487382888794,
-2.6592507362365723,
],
'q01': [
-0.43497809767723083,
-1.4843493700027466,
-1.7824348211288452,
0.0955040231347084,
-0.07186625152826309,
0.27948257327079773,
-2.4006130695343018,
],
'q99': [
2.25881028175354,
0.8190488815307617,
0.2994879484176636,
2.153681993484497,
1.6126199960708618,
1.8523634672164917,
0.06944511085748672,
],
'std': [
0.6131777167320251,
0.6322475075721741,
0.5427954792976379,
0.539089024066925,
0.43294647336006165,
0.3588581085205078,
0.6916991472244263,
],
},
'translation': {
'max': [0.681192194250889, 0.2323359966361606, 0.6159547986373596],
'mean': [0.4067338752307897, 0.027551073758015975, 0.313373054289359],
'min': [0.13453314224525192, -0.2553313745227843, 0.03515888153950328],
'q01': [0.18739915591793913, -0.18234310016370514, 0.048970692427190994],
'q99': [0.6410437631246786, 0.20632224240340083, 0.5983893391276117],
'std': [0.12446715882142037, 0.08252308025761342, 0.11421020385340024],
},
},
'utaustin_mutex': {
'euler': {
'max': [3.1415927410125732, 0.5760129690170288, 0.6430304050445557],
'mean': [1.643947720527649, 0.03388536348938942, -0.31948330998420715],
'min': [-3.141592025756836, -0.9390894770622253, -1.8157784938812256],
'q01': [-3.1398682594299316, -0.7223533391952515, -1.5468647480010986],
'q99': [3.140272855758667, 0.3558599054813385, 0.3583984971046448],
'std': [2.2715108394622803, 0.19258549809455872, 0.5275800824165344],
},
'gripper': {
'max': [0.07569462060928345],
'min': [0.00026726332725957036],
'q01': [0.0019378233700990677],
'q99': [0.07565194368362427],
},
'joints': {
'max': [
0.5831676125526428,
0.8024097681045532,
1.350082516670227,
-1.1145747900009155,
0.5345654487609863,
3.2108523845672607,
2.875584363937378,
],
'mean': [
-0.09772995859384537,
0.05997378006577492,
-0.009316973388195038,
-2.3428714275360107,
-0.48651641607284546,
2.2390215396881104,
1.3341773748397827,
],
'min': [
-0.6990927457809448,
-0.8311276435852051,
-0.5323343276977539,
-2.9861886501312256,
-2.060809850692749,
0.826747477054596,
0.05972049757838249,
],
'q01': [
-0.532222330570221,
-0.5078368782997131,
-0.3437865376472473,
-2.784912586212158,
-1.8685580492019653,
1.3974121809005737,
0.5157660841941833,
],
'q99': [
0.35958197712898254,
0.6339818835258484,
0.5833610892295837,
-1.612648367881775,
0.2780923545360565,
2.8907132148742676,
2.837320327758789,
],
'std': [
0.2101736217737198,
0.2880096137523651,
0.17824982106685638,
0.23544654250144958,
0.6387099027633667,
0.33533376455307007,
0.5090957880020142,
],
},
'translation': {
'max': [0.581696093082428, 0.4334338903427124, 0.6800903081893921],
'mean': [0.4331520199775696, -0.10096638649702072, 0.22044037282466888],
'min': [0.2454746514558792, -0.502901017665863, 0.008792983368039131],
'q01': [0.3217194080352783, -0.4733337163925171, 0.014122226275503635],
'q99': [0.5321439504623413, 0.3733823001384735, 0.5785381197929382],
'std': [0.04380597919225693, 0.20559938251972198, 0.12614832818508148],
},
},
'viola': {
'euler': {
'max': [3.141592502593994, 0.9555190801620483, 0.40456175804138184],
'mean': [-0.7171992063522339, 0.07086259871721268, -0.67817223072052],
'min': [-3.141592025756836, -0.4305531978607178, -2.212939739227295],
'q01': [-3.1400232315063477, -0.19947049021720886, -1.8703945875167847],
'q99': [3.1401824951171875, 0.7830920219421387, 0.19578109681606293],
'std': [2.902841567993164, 0.20232577621936798, 0.7306717038154602],
},
'gripper': {
'max': [0.07748929411172867],
'min': [-5.3190000471659005e-05],
'q01': [0.00020356666937004775],
'q99': [0.07747221738100052],
},
'joints': {
'max': [
0.3366159200668335,
0.8983011841773987,
0.330204576253891,
0.0,
1.3329579830169678,
3.1103708744049072,
2.8318748474121094,
],
'mean': [
0.027196571230888367,
0.2290346622467041,
-0.06033482402563095,
-2.060992479324341,
0.25734394788742065,
2.2569527626037598,
1.2675725221633911,
],
'min': [
-0.4073199927806854,
-0.23913456499576569,
-0.5244941115379333,
-2.744713068008423,
-0.48249971866607666,
0.0,
-0.22237232327461243,
],
'q01': [
-0.26899582147598267,
-0.189721941947937,
-0.3926161825656891,
-2.620305061340332,
-0.199931338429451,
1.6743353605270386,
0.1440846174955368,
],
'q99': [
0.23294022679328918,
0.764011561870575,
0.20191603899002075,
-1.5629768371582031,
1.039290189743042,
2.9412121772766113,
2.798232078552246,
],
'std': [
0.10160040855407715,
0.21415969729423523,
0.14444759488105774,
0.27077943086624146,
0.3251941502094269,
0.338652640581131,
0.7529342174530029,
],
},
'translation': {
'max': [0.6868156790733337, 0.20999883115291595, 0.5037735104560852],
'mean': [0.5464076995849609, 0.006265466101467609, 0.23136523365974426],
'min': [0.0, -0.33860740065574646, 0.0],
'q01': [0.40061360597610474, -0.25196850299835205, 0.010269512422382832],
'q99': [0.6458418369293213, 0.17776551842689514, 0.4456312954425812],
'std': [0.062116123735904694, 0.12932434678077698, 0.13205111026763916],
},
},
}
@cached_property
def _control_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
    """
    Load the control normalization statistics for the configured prediction horizon.

    The YAML file at `config.control_stats_path` is keyed by horizon strings of the form
    ``horizon_{seconds}s`` (seconds rounded to 2 decimals). The entry matching the horizon derived
    from the control IO config is returned.

    Returns:
        The statistics stored under the selected horizon key, or an empty dict when both the
        rotation and the translation normalizations are global (no file is read in that case).

    Raises:
        NotImplementedError: non-delta controls with `future_controls_sequence_length != 1` and no
            `future_controls_sequence_stride_sec` configured.
        ValueError: the statistics file has no entry for the computed horizon key.
    """
    if is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.translation_norm):
        return {}
    with open(self.config.control_stats_path, 'r') as file:
        # safe_load returns None for an empty document; treat that as "no keys" so the
        # ValueError below is raised instead of a TypeError on the membership test.
        stats = yaml.safe_load(file) or {}
    io_config = self.control_io_config
    stride_sec = io_config.future_controls_sequence_stride_sec
    sequence_length = io_config.future_controls_sequence_length
    if self.config.delta_controls:
        # Delta controls: the horizon is a single stride (0 when no stride is configured).
        horizon = 0.0 if stride_sec is None else stride_sec
    elif stride_sec is None:
        if sequence_length == 1:
            horizon = 0.0
        else:
            raise NotImplementedError(
                'future_controls_sequence_stride_sec must be set for non-delta controls when '
                f'future_controls_sequence_length={sequence_length} (only 1 is supported without a stride)'
            )
    else:
        # Non-delta controls: the horizon covers the whole predicted sequence.
        horizon = sequence_length * stride_sec
    key = f'horizon_{round(horizon, 2)}s'
    if key not in stats:
        raise ValueError(
            f'Missing control statistics key {key} for '
            f'future_controls_sequence_length={sequence_length} '
            f'future_controls_sequence_stride_sec={stride_sec}. '
            f'Available keys: {list(stats)}'
        )
    return stats[key]
@cached_property
def dataset_names(self) -> List[str]:
    """
    Names of the datasets for which normalization statistics are available.

    Returns:
        ``['ANY']`` when the rotation and translation normalizations for both controls and
        observations are all global. Otherwise, the sorted union of the dataset keys of the
        control and observation statistics.
    """
    if (
        is_global_norm(self.config.rotation_norm)
        and is_global_norm(self.config.obs_rotation_norm)
        and is_global_norm(self.config.translation_norm)
        and is_global_norm(self.config.obs_translation_norm)
    ):
        return ['ANY']
    # Iterating a set of str follows hash order, which is randomized per process
    # (PYTHONHASHSEED); sort so the name order is reproducible across runs.
    return sorted(set(self._control_stats) | set(self._observation_stats))
def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Re-express a sequence of step-wise translations relative to the base frame of the sequence.

    Each input vector is expressed w.r.t. the PREVIOUS element of the sequence. Each output vector
    is expressed w.r.t. the base frame T0 that precedes the sequence. T0 is implicit in the first
    vector.

    Ex:
        Points: T1, T2, T3, T4
        Input vectors:  T0T1, T1T2, T2T3, T3T4
        Output vectors: T0T1, T0T2, T0T3, T0T4

    Args:
        translation_sequence: torch.Tensor of shape [..., S, 3] with at least 3 dimensions, where S
            is the sequence dimension.

    Returns:
        torch.Tensor with the same shape as `translation_sequence`. Each output vector is the
        running sum of the input vectors along the sequence dimension.
    """
    rank = translation_sequence.ndim
    assert rank >= 3, translation_sequence.shape
    return translation_sequence.cumsum(dim=-2)
class RegressionProcessor(VLAMProcessor[RegressionProcessorConfig]):
    """
    Processor that converts directly regressed controls into `RoboticsControlPlan`s.

    Translation and rotation are un-normalized with this processor's statistics, and the rotation
    is converted to rotation matrices. When `config.delta_controls` is set, the step-wise deltas
    are accumulated into motions relative to the base frame.
    """

    def _regression_control_plan(
        self,
        translation_m: torch.Tensor,
        rotmat: torch.Tensor,
        gripper_prob: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> RoboticsControlPlan:
        """Accumulate deltas if `config.delta_controls` is set, then pack everything into a plan."""
        if self.config.delta_controls:
            translation_m = delta_to_relative_translations(translation_m)
            rotmat = delta_to_relative_rotations(rotmat)
        return RoboticsControlPlan(
            translation_m=translation_m,
            rotmat=rotmat,
            gripper_prob=gripper_prob,
            valid_mask=valid_mask,
        )

    def policy_control_plan_from_model_target(
        self, target: RoboticsTarget, dataset_name: np.ndarray
    ) -> RoboticsControlPlan:
        """Build a control plan from ground-truth targets. Gripper values are used unchanged."""
        unnorm_translation = self.unnormalize(
            target.translation, dataset_name=dataset_name, key='translation'
        )
        unnorm_rotation = self.unnormalize(target.rotation, dataset_name=dataset_name, key='rotation')
        return self._regression_control_plan(
            translation_m=unnorm_translation,
            rotmat=convert_rotation(unnorm_rotation, RotationFormat.ROTMAT),
            gripper_prob=target.gripper,
            valid_mask=target.valid_mask,
        )

    def policy_control_plan_from_model_output(
        self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor
    ) -> RoboticsControlPlan:
        """
        Called during inference to create a control plan from model output.

        The predicted rotation is re-normalized while being converted to rotation matrices
        (`autonorm=True`). The gripper output is passed through a sigmoid.
        """
        unnorm_translation = self.unnormalize(
            model_output.translation, dataset_name=dataset_name, key='translation'
        )
        unnorm_rotation = self.unnormalize(
            model_output.rotation, dataset_name=dataset_name, key='rotation'
        )
        return self._regression_control_plan(
            translation_m=unnorm_translation,
            rotmat=convert_rotation(unnorm_rotation, RotationFormat.ROTMAT, autonorm=True),
            gripper_prob=torch.sigmoid(model_output.gripper),
            valid_mask=valid_mask,
        )
class VLARMProcessor(RegressionProcessor):
    """
    Regression processor that combines a wrapped VLM processor with robot-state normalization.

    Text and images go through `vlm_processor`. End-effector pose, gripper and joint observations
    are normalized with this processor's statistics. No control tokens are produced; the control
    tokenizer is an `EmptyTokenizer`.
    """

    def __init__(self, config: VLARMProcessorConfig, vlm_processor: VLMProcessor):
        # Deliberately calls only Configurable.__init__ rather than the parent processors' __init__
        Configurable.__init__(self, config)
        self.vlm_processor = vlm_processor
        self.control_tokenizer = EmptyTokenizer(
            config=self.config.control_tokenizer_config, tokenizer=self.tokenizer
        )
        self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {
            'obs_translation': self.obs_translation_norm_bounds,
            'obs_rotation': self.obs_rotation_norm_bounds,
            'translation': self.translation_norm_bounds,
            'rotation': self.rotation_norm_bounds,
            'joints': self.joints_norm_bounds,
        }

    @property
    def dataset_np_names(self) -> np.ndarray:
        """
        Dataset names used by `model_otoi_but_gt_gripper`.

        Must be set with `_set_dataset_np_names` first.
        """
        try:
            return self._dataset_np_names
        except AttributeError:
            # Same exception type as before, but say what the caller has to do
            raise AttributeError(
                f'{type(self).__name__}.dataset_np_names is not set; '
                'call _set_dataset_np_names() first'
            ) from None

    def _set_dataset_np_names(self, dataset_np_names: np.ndarray) -> None:
        self._dataset_np_names = dataset_np_names

    def preprocess_inputs(
        self,
        chat: List[str],
        images: Dict[str, List[PIL.Image.Image]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool,
        control_target: Optional[RoboticsTarget] = None,
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Build model inputs from a chat, camera images and the robot state.

        `inference_mode` and `control_target` are accepted for interface compatibility but are
        not used by this implementation.

        Returns:
            Dict containing:
              * 'images': image tensors from the VLM processor;
              * 'input_ids' and 'target_text_tokens_ids': truncated to
                `tokenizer.model_max_length`;
              * 'attn_mask': an all-True bool tensor with the shape of 'input_ids';
              * normalized 'ee_pose_translation', 'ee_pose_rotation' and 'joints', plus the
                preprocessed 'gripper';
              * 'control_tokens_ids' and 'target_control_tokens_ids': both None.
        """
        num_joints = 7  # joint observations are zero-padded up to this many values
        vlm_inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
        max_len = self.tokenizer.model_max_length
        image_tensors: Dict[str, torch.Tensor] = vlm_inputs['images']
        input_ids: torch.Tensor = vlm_inputs['input_ids'][..., :max_len]
        target_text_tokens_ids: torch.Tensor = vlm_inputs['target_ids'][..., :max_len]
        attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)

        ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32)
        ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32)
        ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True)
        gripper = preprocess_gripper_observation(gripper, dataset_name)
        gripper = torch.tensor(gripper, dtype=torch.float32)
        ee_pose_translation = self.normalize(
            ee_pose_translation, dataset_name=dataset_name, key='obs_translation'
        )
        ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key='obs_rotation')

        joints = torch.tensor(joints, dtype=torch.float32)
        if joints.shape[-1] < num_joints:
            # new_zeros keeps the padding on the same dtype/device as `joints`
            padding = joints.new_zeros([*joints.shape[:-1], num_joints - joints.shape[-1]])
            joints = torch.cat([joints, padding], dim=-1)
        joints = self.normalize(joints, dataset_name=dataset_name, key='joints')

        return {
            'images': image_tensors,
            'input_ids': input_ids,
            'target_text_tokens_ids': target_text_tokens_ids,
            'attn_mask': attn_mask,
            'ee_pose_translation': ee_pose_translation,
            'ee_pose_rotation': ee_pose_rotation,
            'gripper': gripper,
            'joints': joints,
            'control_tokens_ids': None,
            'target_control_tokens_ids': None,
        }

    def model_otoi_but_gt_gripper(
        self,
        translation_output: torch.Tensor,
        rotation_output: torch.Tensor,
        reference_translation: torch.Tensor,
        reference_rotation: torch.Tensor,
        gt_gripper: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Turn normalized model outputs into the next normalized observation pose.

        The predicted translation/rotation deltas and the reference observation pose are
        un-normalized with `self.dataset_np_names`. The deltas are then composed with the
        reference: translation is added, and rotation goes through `delta_to_world_rotations`.
        The result is re-normalized with the observation statistics. The gripper value is not
        predicted; `gt_gripper` is passed through unchanged.

        Returns:
            (next_translation, next_rotation, gt_gripper). Translation and rotation are in
            normalized observation space.
        """
        translation_delta = self.unnormalize(
            translation_output, dataset_name=self.dataset_np_names, key='translation'
        )
        rotation_delta = self.unnormalize(rotation_output, dataset_name=self.dataset_np_names, key='rotation')
        translation_reference = self.unnormalize(
            reference_translation, dataset_name=self.dataset_np_names, key='obs_translation'
        )
        rotation_reference = self.unnormalize(
            reference_rotation, dataset_name=self.dataset_np_names, key='obs_rotation'
        )
        translation = translation_reference + translation_delta
        rotation = delta_to_world_rotations(rotation_delta, rotation_reference)
        translation_next = self.normalize(
            translation, dataset_name=self.dataset_np_names, key='obs_translation'
        )
        rotation_next = self.normalize(rotation, dataset_name=self.dataset_np_names, key='obs_rotation')
        gripper_next = gt_gripper
        return translation_next, rotation_next, gripper_next
def make_causal_mask(shape: Sequence[int]) -> torch.Tensor:
    """
    Build a boolean causal (lower-triangular) attention mask.

    Args:
        shape: Output shape. The trailing two entries are [query_seq_len, kv_seq_len]. Any leading
            entries act as batch dimensions, and every batch slice gets the same mask.

    Returns:
        torch.Tensor of dtype torch.bool. Entry [..., q, k] is True iff k <= q, i.e. query q may
        attend to key k. False entries are blocked.

    Example:
        shape = (3, 5) ->
        [
          [ 1, 0, 0, 0, 0],
          [ 1, 1, 0, 0, 0],
          [ 1, 1, 1, 0, 0]
        ]
    """
    all_allowed = torch.ones(shape, dtype=torch.bool)
    return all_allowed.tril(diagonal=0)
def enable_full_attn_blocks(attn_mask: torch.Tensor, full_attn: torch.Tensor) -> torch.Tensor:
    """
    Enable full bi-directional attention in `attn_mask` inside specific blocks.

    Every maximal run of consecutive True values in `full_attn` becomes a block. Inside a block,
    each position may attend to every other position of the same block. Entries outside the blocks
    are left as they are in `attn_mask`; the result is OR-ed into it.

    Args:
        attn_mask: Existing attention mask of shape [..., query_seq_len, kv_seq_len] and dtype
            torch.bool, where False values indicate disabled attention. Only square
            (self-attention) masks are supported.
        full_attn: torch.Tensor of shape [query_seq_len], dtype torch.bool. Runs of True values
            mark the blocks that should get full bi-directional attention.

    Returns:
        The updated boolean attention mask.

    Raises:
        AssertionError: `full_attn` has the wrong dtype, rank or length.
        NotImplementedError: `attn_mask` is not square in its last two dimensions.

    Example (causal input, full_attn = [1, 1, 1, 0, 0, 1, 1, 1]):
        1, 0, 0, 0, 0, 0, 0, 0,         1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 0, 0, 0, 0, 0, 0,         1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 0, 0, 0, 0, 0,         1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 0,   ->    1, 1, 1, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 0, 0, 0,         1, 1, 1, 1, 1, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 0, 0,         1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0,         1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1,         1, 1, 1, 1, 1, 1, 1, 1,
    """
    assert full_attn.dtype == torch.bool, full_attn.dtype
    assert full_attn.ndim == 1, full_attn.shape
    assert full_attn.shape[0] == attn_mask.shape[-2], f'{full_attn.shape[0]}, {attn_mask.shape}'
    if attn_mask.shape[-1] != attn_mask.shape[-2]:
        raise NotImplementedError('Only self-attention supported right now.')
    seq_len = full_attn.shape[0]
    # Pairs where both positions are "full", plus the causal lower triangle (built on
    # full_attn's device so non-CPU inputs work)
    x = full_attn.view(-1, 1) & full_attn.view(1, -1)
    x = x | make_causal_mask([seq_len, seq_len]).to(full_attn.device)
    # A running product along each row cuts the row off at its first False. A full row therefore
    # extends only to the end of its own contiguous block.
    x = torch.cumprod(x, dim=1).to(dtype=torch.bool)
    # Symmetrize: keep (i, j) only when both directions survive. This yields per-block full
    # attention plus the diagonal.
    x = x & x.permute(1, 0)
    # Positions outside any block have only their diagonal entry; drop it so the OR below leaves
    # those positions of `attn_mask` untouched. Parentheses needed: `&` binds tighter than `==`.
    mask_positions = (torch.sum(x, dim=0) == 1) & ~full_attn
    mask_indices = torch.where(mask_positions)[0]
    x[mask_indices, mask_indices] = 0
    attn_mask = attn_mask | expand_dims(x, ndim=attn_mask.ndim, order=[-1, 1, 1])
    return attn_mask
# Target-label value for token positions that carry no supervision. PaliGemmaProcessor uses it
# for image tokens, <bos>, user turns and turn separators, and also to select the positions
# that get full bi-directional attention. -100 is the default `ignore_index` of
# torch.nn.CrossEntropyLoss, so presumably the training loss skips these positions (confirm
# against the loss code).
IGNORE_INDEX = -100
class PaliGemmaProcessor(VLMProcessor[PaliGemmaProcessorConfig]):
    """
    Builds PaliGemma-style inputs (token ids, training targets, pixel values, attention mask) from
    a chat and per-camera image histories, by wrapping a Hugging Face PaliGemma processor.
    """

    def __init__(
        self,
        config: PaliGemmaProcessorConfig,
        hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
        **kwargs,
    ):
        del kwargs  # accepted for interface compatibility, unused
        super().__init__(config)
        self.hf_processor = hf_processor
        # Override the HF processor's image size and image-token count with the configured ones
        self.hf_processor.image_processor.size = dict(self.config.image_sizes['main'].as_json())
        self.hf_processor.image_seq_length = self.config.num_image_tokens['main']
        self.hf_processor.image_processor.image_seq_length = self.config.num_image_tokens['main']
        self.bos_id: int = self.tokenizer.bos_token_id
        self.eos_id: int = self.tokenizer.eos_token_id
        self.sep_token = '\n'
        # Only the first token id is kept for the separator and the image placeholder
        self.sep_id: int = self.tokenizer(
            self.sep_token, padding=False, add_special_tokens=False, return_attention_mask=False
        )['input_ids'][0]
        self.image_token_id: int = self.tokenizer(
            self.config.image_token, padding=False, add_special_tokens=False, return_attention_mask=False
        )['input_ids'][0]
        # One placeholder token per configured image token, summed over all cameras
        self.image_tokens: list[int] = [self.image_token_id] * sum(self.config.num_image_tokens.values())
        # Matches "[x0, y0, x1, y1]" boxes of decimal numbers, rewritten into <locXXXX> tokens
        self.bbox_pattern = re.compile(
            r'\[(\d+\.\d+),\s*(\d+\.\d+),\s*(\d+\.\d+),\s*(\d+\.\d+)\]'
        )

    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Based on PaliGemma paper https://arxiv.org/pdf/2407.07726 and example code at
        https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#create_model_inputs

        The chat must alternate user and model messages, always starting with the user. The
        sequence is laid out as:
        <image><image> ... <bos><instruction><sep><assistant><sep><instruction><sep><assistant>...<eos>

        Args:
            chat: List[str] of even size. Each entry is one turn of the conversation.
            images: Dict[str, List[PIL.Image.Image]]. Each key is a camera and each list is that
                camera's image history.

        Returns:
            Dict with 'input_ids' and 'target_ids' (int64), 'images' (pixel values keyed by
            '<camera>.siglip'), and 'attn_mask'. 'attn_mask' is causal, with full attention inside
            each run of IGNORE_INDEX targets.

        Raises:
            TypeError: a camera's images are not given as a list.
        """
        for key, value in images.items():
            if not isinstance(value, list):
                raise TypeError(f'Camera {key} contains values of type {type(value)} instead of list')

        input_ids: List[int] = [self.bos_id]
        target_ids: List[int] = [IGNORE_INDEX]
        last_turn = len(chat) - 1
        for i, text in enumerate(chat):
            # Newlines are reserved as the turn separator; image placeholders are inserted separately
            text = text.replace(self.sep_token, ' ').replace('<image>', '')
            text = self.bbox_pattern.sub(self._bbox_to_loc_tokens, text)
            turn_input_ids: List[int] = self.tokenizer(
                text, padding=False, add_special_tokens=False, return_attention_mask=False
            )['input_ids']
            input_ids.extend(turn_input_ids)
            if i % 2 == 0:
                # User turn: not a training target
                target_ids.extend([IGNORE_INDEX] * len(turn_input_ids))
            else:
                # Model turn: predict its own tokens
                target_ids.extend(turn_input_ids)
            if i != last_turn:
                input_ids.append(self.sep_id)
                target_ids.append(IGNORE_INDEX)
        input_ids.append(self.eos_id)
        target_ids.append(self.eos_id)

        input_ids = self.image_tokens + input_ids
        target_ids = [IGNORE_INDEX] * len(self.image_tokens) + target_ids
        input_ids = torch.tensor(input_ids, dtype=torch.int64)
        target_ids = torch.tensor(target_ids, dtype=torch.int64)

        image_tensors: Dict[str, torch.Tensor] = {
            f'{camera_name}.siglip': self.hf_processor.image_processor(
                camera_images, size=self.config.image_sizes[camera_name].as_json(), return_tensors='pt'
            )['pixel_values']
            for (camera_name, camera_images) in images.items()
        }

        seq_len = len(input_ids)
        attn_mask = make_causal_mask([seq_len, seq_len])
        attn_mask = enable_full_attn_blocks(attn_mask, full_attn=target_ids == IGNORE_INDEX)
        return {
            'input_ids': input_ids,
            'target_ids': target_ids,
            'images': image_tensors,
            'attn_mask': attn_mask,
        }

    @property
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        return self.hf_processor.tokenizer

    @staticmethod
    def _bbox_to_loc_tokens(match: re.Match) -> str:
        """
        Replace a matched "[x0, y0, x1, y1]" float box with PaliGemma <locXXXX> tokens.

        Each coordinate is scaled by 1024, rounded and clipped to [0, 1023].
        https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/
        """
        coords = [float(value) for value in match.groups()]
        transformed = [f'<loc{int(np.clip(round(num * 1024), 0, 1023)):04d}>' for num in coords]
        return f"[{', '.join(transformed)}]"

    @property
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        return self.config.image_sizes