# spear1-franka / processing_spear.py
# Last commit by giu-alb: "Super-squash branch 'main' using huggingface_hub" (a8bf2f3, verified)
import collections
import collections.abc
import re
import warnings
from abc import abstractmethod
from functools import cached_property
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar
import numpy as np
import PIL.Image
import roma
import torch
import torchvision.transforms.v2
import transformers
import yaml
from .common_spear import (
Configurable,
FlowInput,
Normalization,
ResizeMode,
RoboticsControlPlan,
RoboticsFlowInput,
RoboticsInput,
RoboticsOutput,
RoboticsTarget,
RotationFormat,
expand_dims,
is_quaternion,
is_rotmat,
is_rotmat_3x3,
is_rotmat_9,
quaternion_half_cover,
rotmat_as_3x3,
rotmat_as_9,
)
from .configuration_spear import (
ControlDataIOConfig,
ImageSizeConfig,
PaliGemmaProcessorConfig,
)
class VLMProcessor(Configurable):
    """
    Abstract vision-language preprocessing interface wrapped by `VLAMProcessor`.

    Concrete subclasses turn a chat and per-camera images into model-ready tensors and expose the
    tokenizer plus the target image size for each camera.
    """

    @abstractmethod
    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Convert a chat and per-camera image lists into a dict of tensors.

        NOTE(review): VLAMProcessor reads the keys "images", "input_ids" and "target_ids" from the
        result; confirm against concrete implementations.
        """

    @property
    @abstractmethod
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Tokenizer used for the text part of the inputs."""

    @property
    @abstractmethod
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        """Mapping from camera name to its target image size."""
class EmptyTokenizer(Configurable):
    """
    No-op control tokenizer that never emits any text.

    It keeps a reference to the wrapped `tokenizer` (exposed as `self.tokenizer`), but calling an
    instance with any positional arguments always returns the empty string.
    """

    def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None:
        super().__init__(config)
        self.tokenizer = tokenizer

    def __call__(self, *_) -> str:
        # All inputs are intentionally ignored.
        return ""
def np_unique(
    data: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Like `np.unique`, but orders the unique values by first appearance instead of sorting them.

    Plain `np.unique` sorts its output, so its indices do not follow the order of unsorted input.
    Here every element is first replaced by the position of its value's first occurrence. Those
    positions sort in order of appearance, and a second `np.unique` pass yields consistent indices.

    Returns:
        (unique values in order of first appearance,
         index of each value's first occurrence in `data`,
         inverse indices mapping every element of `data` to its unique value,
         number of occurrences of each unique value)
    """
    _, first_index_sorted, inverse_sorted = np.unique(data, return_index=True, return_inverse=True)
    # For each element: where its value first appears in `data`
    first_position_per_element = first_index_sorted[inverse_sorted]
    _, first_positions, inverse, counts = np.unique(
        first_position_per_element, return_index=True, return_inverse=True, return_counts=True
    )
    return data[first_positions], first_positions, inverse, counts
def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor:
    """
    Convert 'xyz' Euler angles to rotation matrices.

    Args:
        angles: Euler angles in radians, 'xyz' convention, shape [..., 3]
    Returns:
        torch.Tensor of shape [..., 3, 3] with the corresponding rotation matrices
    """
    rotation_matrices = roma.euler_to_rotmat(angles=angles, convention="xyz", degrees=False)
    return rotation_matrices
def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor:
    """
    Convert 'xyz' Euler angles to unit quaternions (roma's quaternion convention).

    Args:
        angles: Euler angles in radians, 'xyz' convention, shape [..., 3]
    Returns:
        torch.Tensor of shape [..., 4] with normalized quaternions
    """
    unit_quaternions = roma.euler_to_unitquat(
        angles=angles, convention="xyz", normalize=True, degrees=False
    )
    return unit_quaternions
def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor:
    """
    Scale quaternions to unit length.

    The norm is detached from the autograd graph, so gradients only flow through the numerator.

    Args:
        quaternion: torch.Tensor of shape [..., 4], not necessarily normalized
        eps: added to the norm to avoid division by zero
    Returns:
        torch.Tensor of shape [..., 4] with (approximately) unit-norm quaternions
    """
    magnitude = quaternion.norm(dim=-1, keepdim=True).detach()
    denominator = magnitude + eps
    return quaternion / denominator
def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to 'xyz' Euler angles in radians.

    Args:
        quaternion: torch.Tensor of shape [..., 4]; may be non-normalized, it is normalized first
            with `normalize_quaternion`
    Returns:
        torch.Tensor of shape [..., 3] containing 'xyz' Euler angles in radians
    """
    unit_quat = normalize_quaternion(quaternion)
    # as_tuple=False -> a single stacked tensor rather than a tuple of three tensors
    euler = roma.unitquat_to_euler(convention="xyz", quat=unit_quat, as_tuple=False, degrees=False)
    return euler
def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to rotation matrices.

    Args:
        quaternion: torch.Tensor of shape [..., 4]; normalized with `normalize_quaternion` first,
            so it does not need unit length
    Returns:
        torch.Tensor of shape [..., 3, 3] with rotation matrices in SO(3)
    """
    return roma.unitquat_to_rotmat(normalize_quaternion(quaternion))
def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to unit quaternions.

    Args:
        rotmat: rotation matrices, shape [..., 3, 3] or flattened [..., 9]
    Returns:
        Unit quaternions, shape [..., 4]
    """
    square_rotmat = rotmat_as_3x3(rotmat)
    return roma.rotmat_to_unitquat(square_rotmat)
def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to 'xyz' Euler angles.

    Args:
        rotmat: rotation matrices, shape [..., 3, 3] or flattened [..., 9]
    Returns:
        Euler angles in radians, shape [..., 3]
    """
    square_rotmat = rotmat_as_3x3(rotmat)
    return roma.rotmat_to_euler(
        rotmat=square_rotmat, convention="xyz", degrees=False, as_tuple=False
    )
def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor:
"""
Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
- Let SVD(M) = U \Sigma V^T
- Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T
- det(UV^T) ensures that det(SVD+(M)) = 1
- The return value is a rotation matrix (ortonormal) with the least-squares distance to M
Args:
x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3]
Returns:
torch.Tensor with the same shape as x, where each inner 3x3 matrix is in SO(3)
"""
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="In CPU autocast, but the target dtype is not supported. Disabling autocast.",
)
with torch.autocast(device_type=x.device.type, dtype=torch.float32):
matrices = x.view(-1, 3, 3)
matrices = matrices.to(dtype=torch.float32)
(u, s, v) = torch.svd(matrices)
vt = torch.transpose(v, 1, 2)
det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1)
diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1)
result = torch.matmul(u, diag_vt)
result = result.view(*x.shape)
result = result.to(dtype=x.dtype)
return result
def is_rotmat_orthonormal(
    rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = "none"
) -> torch.Tensor | bool:
    """
    Check whether rotation matrices are orthonormal, using `roma.is_orthonormal_matrix`.

    Args:
        rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9]; cast to float32 before the check
        epsilon: tolerance for the numerical comparison. Larger values are more permissive; values
            below about 1e-6 may wrongly reject some orthonormal matrices.
        reduction:
            'none' - return roma's boolean result as a tensor
            'all' - return a Python bool, True if ALL matrices in the batch are orthonormal
    Returns:
        torch.Tensor of bools or bool, depending on `reduction`
    Raises:
        ValueError: for an unknown `reduction`
    """
    assert is_rotmat(rotmat)
    square = rotmat_as_3x3(rotmat.to(dtype=torch.float32))
    orthonormal_mask = roma.is_orthonormal_matrix(square, epsilon=epsilon)
    if reduction == "all":
        return bool(torch.all(orthonormal_mask).item())
    if reduction != "none":
        raise ValueError(f"Unknown reduction mode {reduction}")
    return orthonormal_mask
def is_orthonormal_rotmat(rotmat: torch.Tensor) -> bool:
    """
    Decide whether a tensor should be treated as rotation matrices.

    Flattened [..., 9] tensors are accepted on shape alone. Tensors ending in [3, 3] must also pass
    an orthonormality check (epsilon=0.01). This avoids mistaking a batch of three Euler-angle
    triples (shape [..., 3, 3]) for rotation matrices.
    """
    flat_match = is_rotmat_9(rotmat)
    if flat_match:
        return flat_match
    square_match = is_rotmat_3x3(rotmat)
    if not square_match:
        return square_match
    return is_rotmat_orthonormal(rotmat, epsilon=0.01, reduction="all")
def is_euler(euler: torch.Tensor) -> bool:
    """Return True if the last dimension has 3 components and the tensor is not a rotation matrix."""
    if euler.shape[-1] != 3:
        return False
    return not is_orthonormal_rotmat(euler)
def normalize_rotation(rotation: torch.Tensor) -> torch.Tensor:
    """
    Bring a rotation tensor back to a valid representation.

    - quaternions are scaled to unit length
    - Euler angles are returned unchanged
    - rotation matrices ([..., 3, 3] or [..., 9]) are re-orthonormalized with
      `roma.special_gramschmidt`, keeping the input layout

    Raises:
        ValueError: if the tensor matches none of the formats
    """
    if is_quaternion(rotation):
        return normalize_quaternion(rotation)
    if is_euler(rotation):
        return rotation
    if not is_rotmat(rotation):
        raise ValueError(f"Unknown rotation format: {rotation.shape}")
    flat_layout = is_rotmat_9(rotation)
    square = rotmat_as_3x3(rotation) if flat_layout else rotation
    orthonormal = roma.special_gramschmidt(square)
    return rotmat_as_9(orthonormal) if flat_layout else orthonormal
def rotation_format_from_tensor(rotation) -> RotationFormat:
    """
    Infer the RotationFormat of a tensor from its shape (and, for 3x3 matrices, its content).

    Checks run in the order quaternion, rotation matrix, Euler angles; the first match wins.

    Raises:
        ValueError: if no format matches
    """
    for matches, rotation_format in (
        (is_quaternion, RotationFormat.QUATERNION),
        (is_orthonormal_rotmat, RotationFormat.ROTMAT),
        (is_euler, RotationFormat.EULER),
    ):
        if matches(rotation):
            return rotation_format
    raise ValueError(f"Tensor shape {rotation.shape} is not a valid rotation format")
def is_unit_quaternion(
    quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = "none"
) -> torch.Tensor | bool:
    """
    Check whether quaternions have unit norm.

    Args:
        quaternion: torch.Tensor of shape [..., 4]
        epsilon: absolute tolerance passed to `torch.isclose` (its default relative tolerance
            also applies)
        reduction:
            'none' - return a boolean tensor of shape [..., 1]
            'all' - return a Python bool, True if ALL quaternions in the batch are normalized
    Returns:
        torch.Tensor of bools or bool, depending on `reduction`
    Raises:
        ValueError: for an unknown `reduction`
    """
    assert is_quaternion(quaternion)
    lengths = quaternion.norm(dim=-1, keepdim=True)
    one = torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device)
    unit_mask = torch.isclose(lengths, one, atol=epsilon)
    if reduction == "all":
        return bool(torch.all(unit_mask).item())
    if reduction != "none":
        raise ValueError(f"Unknown reduction mode {reduction}")
    return unit_mask
def convert_rotation(
rotation: torch.Tensor | np.ndarray,
output_format: RotationFormat,
autonorm: bool = True,
half_cover: bool = True,
) -> torch.Tensor | np.ndarray:
is_np = isinstance(rotation, np.ndarray)
if is_np:
rotation = torch.from_numpy(rotation)
if is_quaternion(rotation):
if autonorm and not is_unit_quaternion(rotation, reduction="all"):
rotation = normalize_quaternion(rotation)
if output_format == RotationFormat.QUATERNION:
output = rotation
elif output_format == RotationFormat.ROTMAT:
output = rotmat_as_9(quaternion_to_rotmat(rotation))
elif output_format == RotationFormat.EULER:
output = quaternion_to_euler(rotation)
else:
raise NotImplementedError(f"Unsupported rotation format: {output_format}")
elif is_orthonormal_rotmat(rotation):
if autonorm and not is_rotmat_orthonormal(rotation, epsilon=0.01, reduction="all"):
rotation = symmetric_orthogonalization(rotation)
if output_format == RotationFormat.QUATERNION:
output = rotmat_to_unit_quaternion(rotation)
elif output_format == RotationFormat.ROTMAT:
output = rotmat_as_9(rotation)
elif output_format == RotationFormat.EULER:
output = rotmat_to_euler(rotation)
else:
raise NotImplementedError(f"Unsupported rotation format: {output_format}")
elif is_euler(rotation):
if output_format == RotationFormat.QUATERNION:
output = euler_to_unit_quaternion(rotation)
elif output_format == RotationFormat.ROTMAT:
output = rotmat_as_9(euler_to_rotmat(rotation))
elif output_format == RotationFormat.EULER:
output = rotation
else:
raise NotImplementedError(f"Unsupported rotation format: {output_format}")
else:
raise ValueError(f"Unknown rotation encoding with shape {rotation.shape}")
if output_format == RotationFormat.QUATERNION and half_cover:
output = quaternion_half_cover(output)
if is_np:
output = output.numpy()
return output
def delta_to_relative_rotations(rotation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Transform a sequence of rotations, each expressed w.r.t. the PREVIOUS frame in the sequence,
    into rotations expressed w.r.t. the 0-th frame preceding the sequence.
    Ex:
        `rotation_sequence` contains the rotations R_01, R_12, R_23, R_34, where R0 is the base
        frame, implicitly encoded in R_01, and R_01 converts from the R0 frame to the R1 frame.
        Output: R_01, R_02, R_03, R_04
    The composition is done incrementally (R_0i = R_0(i-1) * R_(i-1)i) with `roma.quat_product`.
    This is the same left-to-right fold `roma.quat_composition` performs on each prefix, but costs
    O(S) products instead of O(S^2).
    Args:
        rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4],
            containing either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions
    Returns:
        torch.Tensor containing the transformed rotations (R_01, R_02, R_03, R_04, ...) in the
        input's rotation format, converted back via `convert_rotation`. Rotation-matrix results are
        therefore flattened to [..., S, 9] even for [..., S, 3, 3] inputs.
    """
    assert rotation_sequence.ndim >= 3, rotation_sequence.shape
    rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence)
    quaternions = convert_rotation(rotation_sequence, RotationFormat.QUATERNION)
    # composed[k] has shape [..., 1, 4] and holds R_0(k+1)
    composed = [quaternions[..., :1, :]]
    for step in range(1, quaternions.shape[-2]):
        composed.append(roma.quat_product(composed[-1], quaternions[..., step : step + 1, :]))
    delta_rotations = torch.cat(composed, dim=-2)
    delta_rotations = convert_rotation(delta_rotations, rotation_format)
    return delta_rotations
def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray:
    """
    Return the image as an np.ndarray in HW or HWC layout, asserting the layout is plausible.

    PIL images are converted with `np.asarray`. HWC arrays may have at most 4 channels.
    """
    array = np.asarray(image) if isinstance(image, PIL.Image.Image) else image
    assert isinstance(array, np.ndarray), type(array)
    assert array.ndim in [2, 3], array.shape
    assert array.ndim != 3 or array.shape[-1] <= 4, array.shape
    return array
def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]:
    """Return (height, width) of a NumPy (H, W[, C]) array or a PIL image (whose `.size` is (W, H))."""
    if not isinstance(image, np.ndarray):
        width, height = image.size
        return height, width
    height, width = image.shape[:2]
    return height, width
def pad_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Pad an image with a centered border so that it becomes exactly `target_size`.

    When the size difference along an axis is odd, the extra pixel goes to the right / bottom, so
    the output always matches the target.

    Args:
        image: PIL image or np.ndarray in HW / HWC layout
        target_size: dict with integer "width" and "height", each >= the image's size
        pad_value: fill value, either a scalar or one value per channel
    Returns:
        The padded image, of the same type as `image`. Returned unchanged if already at target size.
    Raises:
        AssertionError: if the image is larger than the target along either axis
    """
    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)
    (height, width) = hw_from_image(image)
    (target_width, target_height) = (target_size["width"], target_size["height"])
    if width == target_width and height == target_height:
        return image
    assert target_width >= width, f"Can't pad image of width {width} to {target_width}"
    assert target_height >= height, f"Can't pad image of height {height} to {target_height}"
    # floor on the left/top, remainder on the right/bottom -> exact target size for odd differences
    left = (target_width - width) // 2
    top = (target_height - height) // 2
    right = target_width - width - left
    bottom = target_height - height - top
    if isinstance(image, np.ndarray):
        # Fill-then-paste instead of np.pad: assignment broadcasts a per-channel tuple over the
        # channel axis, whereas np.pad's constant_values interprets sequences per axis.
        padded = np.empty((target_height, target_width) + image.shape[2:], dtype=image.dtype)
        padded[...] = pad_value
        padded[top : top + height, left : left + width, ...] = image
        return padded
    # torchvision expects (left, top, right, bottom)
    return torchvision.transforms.v2.functional.pad(
        image, padding=(left, top, right, bottom), fill=pad_value, padding_mode="constant"
    )
def pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Pad an image (via `pad_image`) until width / height matches `target_wh_ratio`.

    Only one axis grows: the width if the image is too narrow for the target, otherwise the height.
    """
    (height, width) = hw_from_image(image)
    current_ratio = width / height
    if target_wh_ratio >= current_ratio:
        padded_size = {"width": round(height * target_wh_ratio), "height": height}
    else:
        padded_size = {"width": width, "height": round(width / target_wh_ratio)}
    return pad_image(image, target_size=padded_size, pad_value=pad_value)
def crop_image(
    image: np.ndarray | PIL.Image.Image,
    start_height: int,
    start_width: int,
    target_height: int,
    target_width: int,
) -> np.ndarray | PIL.Image.Image:
    """
    Crop a `target_height` x `target_width` window whose top-left corner is at
    (start_height, start_width).

    Offsets and sizes are rounded to the nearest integer. Only the window size is checked against
    the image size; a window that starts too close to the border is silently truncated by NumPy
    slicing.

    Returns:
        The cropped image. PIL inputs are converted back with `PIL.Image.fromarray`.
    Raises:
        AssertionError: if the requested size exceeds the image size
    """
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = hw_from_image(image)
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    (start_height, start_width) = (round(start_height), round(start_width))
    (target_height, target_width) = (round(target_height), round(target_width))
    np_image = np_image[
        start_height : start_height + target_height,
        start_width : start_width + target_width,
        ...,
    ]
    return PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
def crop_image_center(
    image: np.ndarray | PIL.Image.Image, target_size: dict[str, int]
) -> np.ndarray | PIL.Image.Image:
    """
    Center-crop an image to `target_size` ({"width", "height"}).

    When the size difference is odd, the extra row/column is removed from the bottom/right.

    Returns:
        The cropped image, of the same type as `image`.
    Raises:
        AssertionError: if the target is larger than the image along either axis
    """
    np_image = assert_np_hwc_or_hw_image(image)
    (height, width) = np_image.shape[:2]
    (target_height, target_width) = (target_size["height"], target_size["width"])
    assert target_width <= width, f"Can't crop image of width {width} to {target_width}"
    assert target_height <= height, f"Can't crop image of height {height} to {target_height}"
    top = (height - target_height) // 2
    left = (width - target_width) // 2
    np_image = crop_image(np_image, top, left, target_height, target_width)
    return PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image
def crop_image_to_ratio(
    image: PIL.Image.Image | np.ndarray, target_wh_ratio: float
) -> PIL.Image.Image | np.ndarray:
    """
    Center-crop an image to a target width / height aspect ratio.

    Only one axis shrinks: the height if the image is too narrow for the target ratio,
    otherwise the width.
    """
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if target_wh_ratio >= wh_ratio:
        crop_size = {"width": width, "height": round(width / target_wh_ratio)}
    else:
        crop_size = {"width": round(height * target_wh_ratio), "height": height}
    image = crop_image_center(image, target_size=crop_size)
    return image
def crop_and_pad_image_to_ratio(
    image: PIL.Image.Image | np.ndarray,
    target_wh_ratio: float,
    mode: ResizeMode | str,
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Crop and/or pad an image to a target width / height ratio, depending on the mode.

    Images whose ratio already matches (np.isclose with rtol=0.01, atol=1e-4) are returned unchanged.

    Args:
        image: The image to crop and pad.
        target_wh_ratio: Target width / height ratio.
        mode:
            ResizeMode.PAD - only pad (`pad_image_to_ratio`)
            ResizeMode.CROP - only center-crop (`crop_image_to_ratio`)
            ResizeMode.SMART - optionally center-crop first, then pad to the target ratio:
                * square source: crop to 4:3 / 3:4 (following the target's orientation) when the
                  target's ratio is at least ~4:3
                * source and target with different orientations (landscape vs portrait): crop
                  to 1:1
                * source more elongated than ~4:3 with a target ratio below ~4:3: crop to
                  4:3 / 3:4
        pad_value: Fill value used for padding.
    Returns:
        The transformed image, of the same type as `image`.
    Raises:
        ValueError: for unsupported modes
    """
    (height, width) = hw_from_image(image)
    wh_ratio = width / height
    if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001):
        return image
    if mode == ResizeMode.SMART:
        aspect_ratio = max(width, height) / min(width, height)
        target_ratio = max(target_wh_ratio, 1 / target_wh_ratio)
        # Parenthesized on purpose: `a >= 1.0 != b` would be a chained comparison, not an XOR
        orientation_mismatch = (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0)
        if aspect_ratio == 1:
            if target_ratio >= 4 / 3 - 0.01:
                crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
                image = crop_image_to_ratio(image, crop_wh_ratio)
        elif orientation_mismatch:
            image = crop_image_to_ratio(image, 1.0)
        elif aspect_ratio > 4 / 3 + 0.01 and target_ratio < 4 / 3 + 0.01:
            crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4
            image = crop_image_to_ratio(image, crop_wh_ratio)
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.PAD:
        image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value)
    elif mode == ResizeMode.CROP:
        image = crop_image_to_ratio(image, target_wh_ratio)
    else:
        raise ValueError(f"Mode {mode} not supported")
    return image
def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool:
    """
    Return True for images treated as single-channel.

    PIL images are classified by mode. NumPy arrays count when they are 2D or HWC with one channel.
    NOTE(review): the PIL list includes "LA", "La" and "PA", which carry an extra alpha band, while
    a 2-channel NumPy array is not considered single-channel; confirm this asymmetry is intended.

    Raises:
        ValueError: for inputs that are neither PIL images nor np.ndarrays
    """
    if isinstance(image, PIL.Image.Image):
        single_channel_modes = (
            "1", "L", "LA", "La", "P", "PA", "F", "I", "I;16", "I;16L", "I;16B", "I;16N",
        )
        return image.mode in single_channel_modes
    if not isinstance(image, np.ndarray):
        raise ValueError(f"Unsupported image type: {type(image)}")
    if image.ndim == 2:
        return True
    return image.ndim == 3 and image.shape[2] == 1
def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool:
    """
    Return True if the image looks like a 0/1 mask: dtype uint8 or bool, with maximum exactly 1.

    All-zero masks are not detected because their maximum is 0. Empty images return False
    instead of letting `np.max` raise on a zero-size array.
    """
    image = np.asarray(image)
    if image.size == 0:
        return False
    return image.dtype in [np.uint8, np.bool_] and np.max(image) == 1
def resize_image(
    image: PIL.Image.Image | np.ndarray,
    target_size: dict[str, int],
    mode: ResizeMode | str,
    resample: PIL.Image.Resampling | str = "auto",
    pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0,
) -> PIL.Image.Image | np.ndarray:
    """
    Resize an image to exactly `target_size` ({"width", "height"}).

    Args:
        image: PIL image or np.ndarray (HW / HWC)
        target_size: dict with "width" and "height"
        mode:
            ResizeMode.NAIVE - resize directly, ignoring the aspect ratio
            ResizeMode.SMART / PAD / CROP - first match the target aspect ratio with
                `crop_and_pad_image_to_ratio` (same mode), then resize
        resample: a PIL.Image.Resampling filter, or "auto" to choose BILINEAR for single-channel
            images and LANCZOS otherwise. Single-channel images only accept BILINEAR or BICUBIC.
        pad_value: fill value used when padding
    Returns:
        The resized image. NumPy inputs give NumPy outputs. Binary masks (see `is_binary_mask`)
        are resized as 0/255 uint8 images and thresholded at 127, so they come back as boolean
        NumPy arrays. NOTE(review): this also applies to PIL mask inputs.
    Raises:
        ValueError: for single-channel images with an unsupported `resample`
        NotImplementedError: for unsupported modes
    """
    (target_width, target_height) = (target_size["width"], target_size["height"])
    (height, width) = hw_from_image(image)
    if height == target_height and width == target_width:
        return image
    if resample == "auto":
        if is_single_channel_image(image):
            resample = PIL.Image.Resampling.BILINEAR
        else:
            resample = PIL.Image.Resampling.LANCZOS
    else:
        assert isinstance(resample, PIL.Image.Resampling), resample
        if is_single_channel_image(image) and resample not in [
            PIL.Image.Resampling.BILINEAR,
            PIL.Image.Resampling.BICUBIC,
        ]:
            raise ValueError(
                f"Single channel images must be resized with bilinear or bicubic, but got {resample}"
            )
    if is_bin_mask := is_binary_mask(image):
        # Scale masks to 0/255 so interpolated values can be thresholded back to {0, 1}
        image = np.asarray(image).astype(np.uint8) * 255
    if mode in (ResizeMode.SMART, ResizeMode.PAD, ResizeMode.CROP):
        image = crop_and_pad_image_to_ratio(
            image,
            target_wh_ratio=target_width / target_height,
            mode=mode,
            pad_value=pad_value,
        )
    elif mode != ResizeMode.NAIVE:
        raise NotImplementedError(f"Mode {mode} not supported")
    pil_image = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image
    pil_image = pil_image.resize((target_width, target_height), resample=resample)
    image = np.asarray(pil_image) if isinstance(image, np.ndarray) else pil_image
    if is_bin_mask:
        image = image.astype(np.uint8) > 127
    return image
def is_global_norm(
    norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list],
) -> bool:
    """
    Return True if `norm` is Normalization.NONE or an explicit mapping of values, i.e. the same
    normalization applies to every dataset rather than per-dataset statistics.
    """
    if isinstance(norm, collections.abc.Mapping):
        return True
    return norm == Normalization.NONE
def is_mean_norm(
    norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list],
) -> bool:
    """
    Return True if `norm` is mean/std based: Normalization.MEAN, or a mapping whose keys are
    exactly {"mean", "std"}.
    """
    if norm == Normalization.MEAN:
        return True
    return isinstance(norm, collections.abc.Mapping) and set(norm.keys()) == {"mean", "std"}
def _broadcast_shapes(
value: torch.Tensor, low: torch.Tensor, high: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Broadcast shapes for normalization:
Args:
value: torch.Tensor of shape [..., num_components]. The entire shape might be:
- [num_components]: `value` has no batch dimension
- [num_datasets, num_components]: `value` contains entries *aligned* with the dataset bounds
contained in `low` and `high`
- [num_datasets, ..., num_components]: `value` contains entries *aligned* with the dataset bounds
contained in `low` and `high`
- [..., num_components]: `value` contains multiple dimensions. In this case, `low` and `high`
must be for a single dataset, i.e. `num_datasets = 1`
low: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `low`
contains normalization bounds for a single dataset
high: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `high`
contains normalization bounds for a single dataset
Returns:
Tuple of torch.Tensors (low, high), where `low` and `high` have the same number of dimensions as `value`
"""
assert low.ndim == high.ndim == 2, f"{low.shape} != {high.shape} or ndim != 2"
assert value.shape[-1] == low.shape[-1] == high.shape[-1], f"{value.shape} != {low.shape} / {high.shape}"
if value.ndim == low.ndim == high.ndim:
return low, high
if value.ndim < low.ndim:
assert low.ndim == high.ndim == 2, f"{low.shape}, {high.shape}"
assert low.shape[0] == high.shape[0] == 1, f"{low.shape}, {high.shape}"
(low, high) = (low.view(-1), high.view(-1))
return low, high
if low.shape[0] == high.shape[0] == 1:
low = expand_dims(low.view(-1), ndim=value.ndim, order=[-1, 1])
high = expand_dims(high.view(-1), ndim=value.ndim, order=[-1, 1])
else:
assert value.shape[0] == low.shape[0] == high.shape[0], f"{value.shape} != {low.shape} / {high.shape}"
low = expand_dims(low, ndim=value.ndim, order=[1, -1, 1])
high = expand_dims(high, ndim=value.ndim, order=[1, -1, 1])
return low, high
def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Invert `normalize_by_moments`: value * (std + 1e-8) + mean, statistics moved to value's device."""
    mean, std = _broadcast_shapes(value, mean, std)
    target_device = value.device
    scale = std.to(device=target_device) + 1e-08
    return value * scale + mean.to(device=target_device)
def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Map values from [-1, 1] back to [low, high] (no clamping)."""
    low, high = _broadcast_shapes(value, low, high)
    low = low.to(device=value.device)
    high = high.to(device=value.device)
    span = high - low
    return 0.5 * (value + 1) * span + low
def normalize_gripper_by_bounds(
    value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True
) -> torch.Tensor:
    """
    Normalize gripper values with [low, high] bounds and clamp the result.

    If `binary`, the output lies in [0, 1]; otherwise in [-1, 1]. The bound range is clamped to at
    least 1e-8 to avoid division by zero.
    """
    low, high = _broadcast_shapes(value, low, high)
    low, high = low.to(device=value.device), high.to(device=value.device)
    span = torch.clamp(high - low, min=1e-08)
    if not binary:
        return torch.clamp(2 * (value - low) / span - 1, min=-1.0, max=1.0)
    return torch.clamp((value - low) / span, min=0.0, max=1.0)
def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Standardize values: (value - mean) / (std + 1e-8), statistics moved to value's device."""
    mean, std = _broadcast_shapes(value, mean, std)
    target_device = value.device
    centered = value - mean.to(device=target_device)
    return centered / (std.to(device=target_device) + 1e-08)
def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    """Map values from [low, high] to [-1, 1], clamping out-of-range results."""
    low, high = _broadcast_shapes(value, low, high)
    low = low.to(device=value.device)
    high = high.to(device=value.device)
    span = torch.clamp(high - low, min=1e-08)
    return torch.clamp(2 * (value - low) / span - 1, min=-1.0, max=1.0)
def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray:
    """
    Reflect gripper values within [low, high], so that low <-> high (e.g. open <-> closed).

    Values are clipped to [low, high] first, so the result always stays inside the bounds. For
    (0, 1) this is `1 - g`; for symmetric bounds (-b, b) it is `-g`. Both match the previous
    special-cased behavior. Other bounds are now reflected correctly instead of being mapped
    to [0, high - low].
    """
    return low + high - np.clip(gripper, low, high)
# (low, high) range of the raw gripper observation per dataset. Used by
# `preprocess_gripper_observation`, which inverts "droid"/"roboset" readings with these bounds
# before normalizing, and normalizes "bridge"/"bridge_orig" readings directly.
GRIPPER_BOUNDS = {
    "bridge": (0.0, 1.0),
    "bridge_orig": (0.0, 1.0),
    "droid": (0.0, 1.0),
    "roboset": (0.0, 1.0),
}
def preprocess_gripper_observation(
    gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True
) -> np.ndarray:
    """
    Map a raw gripper observation (from a dataset or the robot) to a normalized continuous value.

    - if `binary`, the output is in [0, 1], with 0 = closed and 1 = open
    - otherwise, the output is in [-1, 1], with -1 = closed and 1 = open

    Dataset-specific raw conventions:
        bridge: continuous; ~[0=closed; 1=open]
        bridge_orig: continuous; ~[0=closed; 1=open]
        droid: continuous; [0=open, 1=closed] (inverted before normalization)
        roboset: continuous; [0=open, 1=closed] (inverted before normalization)

    `dataset_name` may be an array, as long as all its entries are identical.
    NOTE(review): the bounds are built with `gripper.shape` and go through `_broadcast_shapes`,
    which requires 2D bounds, so `gripper` presumably has shape [T, 1]; confirm with callers.

    Raises:
        NotImplementedError: for datasets not listed above
    """
    if isinstance(dataset_name, np.ndarray):
        assert np.unique(dataset_name).size == 1, dataset_name
        dataset_name = str(dataset_name[0])
    needs_inversion = dataset_name in ["droid", "roboset"]
    if not needs_inversion and dataset_name not in ["bridge", "bridge_orig"]:
        raise NotImplementedError(f"Unknown dataset: {dataset_name}")
    (low, high) = GRIPPER_BOUNDS[dataset_name]
    raw = invert_gripper(gripper, low=low, high=high) if needs_inversion else gripper
    normalized = normalize_gripper_by_bounds(
        torch.from_numpy(raw),
        low=torch.full(gripper.shape, low, dtype=torch.float32),
        high=torch.full(gripper.shape, high, dtype=torch.float32),
        binary=binary,
    )
    return normalized.numpy()
def rotation_norm_bounds(
    rotation_norm: Normalization,
    rotation_format: RotationFormat,
    stats: Dict[str, Dict[str, Dict[str, List[float]]]],
    dataset_names: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Build per-dataset {"low", "high"} rotation bounds.

    Only Euler rotations can be normalized. BOUNDS uses stats["euler"] min/max and BOUNDS_Q99
    uses q01/q99, for every dataset in `stats`. Otherwise `rotation_norm` must be
    Normalization.NONE, and every dataset in `dataset_names` gets bounds of -1/+1 with
    3 (Euler), 4 (quaternion) or 9 (rotation matrix) components.

    Raises:
        NotImplementedError: for other normalization types on Euler rotations
    """
    if rotation_format != RotationFormat.EULER or rotation_norm == Normalization.NONE:
        assert rotation_norm == Normalization.NONE, rotation_norm
        if rotation_format == RotationFormat.EULER:
            num_components = 3
        elif rotation_format == RotationFormat.QUATERNION:
            num_components = 4
        else:
            num_components = 9
        return {
            name: {
                "low": -1 * torch.ones(num_components, dtype=torch.float32),
                "high": 1 * torch.ones(num_components, dtype=torch.float32),
            }
            for name in dataset_names
        }
    if rotation_norm == Normalization.BOUNDS:
        low_key, high_key = "min", "max"
    elif rotation_norm == Normalization.BOUNDS_Q99:
        low_key, high_key = "q01", "q99"
    else:
        raise NotImplementedError(f"Normalization type {rotation_norm} not yet implemented")
    return {
        name: {
            "low": torch.tensor(per_dataset["euler"][low_key]),
            "high": torch.tensor(per_dataset["euler"][high_key]),
        }
        for name, per_dataset in stats.items()
    }
def translation_norm_bounds(
    translation_norm: Normalization | str | Dict[str, Sequence[float]],
    stats: Dict[str, Dict[str, Dict[str, List[float]]]],
    dataset_names: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Build per-dataset translation normalization tensors.

    Args:
        translation_norm: one of
            - Normalization.BOUNDS / Normalization.BOUNDS_Q99: {"low", "high"} from
              stats["translation"] min/max or q01/q99, for every dataset in `stats`
            - Normalization.MEAN: {"mean", "std"} from stats["translation"], for every dataset
              in `stats`
            - Normalization.NONE: {"low": -1, "high": 1} (3 components) for every dataset in
              `dataset_names`
            - a mapping with keys {"low", "high"} or {"mean", "std"}, each value of length 3,
              applied unchanged to every dataset in `dataset_names`
            Plain strings are handled like the enum they compare equal to.
        stats: per-dataset statistics, indexed as stats[dataset]["translation"][statistic]
        dataset_names: datasets used for NONE and mapping normalization
    Returns:
        {dataset_name: {key: torch.Tensor}}
    Raises:
        NotImplementedError: for an unsupported Normalization / string
    """
    if isinstance(translation_norm, (Normalization, str)):
        if translation_norm == Normalization.NONE:
            # Accept strings equal to NONE here too; previously such a string fell through to the
            # Mapping assertion below even though other strings were accepted.
            return {
                dataset_name: {
                    "low": -1 * torch.ones(3, dtype=torch.float32),
                    "high": 1 * torch.ones(3, dtype=torch.float32),
                }
                for dataset_name in dataset_names
            }
        # output key -> key in stats[dataset]["translation"]
        if translation_norm == Normalization.BOUNDS:
            stat_keys = {"low": "min", "high": "max"}
        elif translation_norm == Normalization.BOUNDS_Q99:
            stat_keys = {"low": "q01", "high": "q99"}
        elif translation_norm == Normalization.MEAN:
            stat_keys = {"mean": "mean", "std": "std"}
        else:
            raise NotImplementedError(f"Normalization type {translation_norm} not yet implemented")
        return {
            dataset_name: {
                out_key: torch.tensor(dataset_stats["translation"][stat_key])
                for (out_key, stat_key) in stat_keys.items()
            }
            for (dataset_name, dataset_stats) in stats.items()
        }
    assert isinstance(translation_norm, collections.abc.Mapping), type(translation_norm)
    assert all((len(value) == 3 for value in translation_norm.values())), translation_norm
    assert set(translation_norm.keys()) in (
        {"low", "high"},
        {"mean", "std"},
    ), translation_norm
    return {
        dataset_name: {
            key: torch.tensor(value, dtype=torch.float32) for (key, value) in translation_norm.items()
        }
        for dataset_name in dataset_names
    }
# Type variable annotating the `config` argument of VLAMProcessor.__init__. It is unbound, so it
# places no constraint on the config type.
VLAMProcessorConfigT = TypeVar("VLAMProcessorConfigT")
class VLAMProcessor(Configurable):
    """
    Base processor for vision-language-action models.

    Wraps a `VLMProcessor` (which handles the chat text and camera images) and adds preprocessing of the
    robot state (end-effector pose, gripper, joints), including per-dataset normalization of pose and
    control components. Subclasses implement the conversion of model targets / outputs into a
    `RoboticsControlPlan`.
    """

    def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor):
        super().__init__(config)
        self.vlm_processor = vlm_processor
        self.control_tokenizer = EmptyTokenizer(
            config=self.config.control_tokenizer_config, tokenizer=self.tokenizer
        )
        # Built eagerly: this evaluates every *_norm_bounds cached property; the control ones read
        # `_control_stats`, which may load the control statistics file.
        self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {
            "obs_translation": self.obs_translation_norm_bounds,
            "obs_rotation": self.obs_rotation_norm_bounds,
            "translation": self.translation_norm_bounds,
            "rotation": self.rotation_norm_bounds,
            "joints": self.joints_norm_bounds,
        }

    @property
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Tokenizer of the wrapped VLM processor."""
        return self.vlm_processor.tokenizer

    @property
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        """Camera name -> configured image size, taken from the wrapped VLM processor."""
        return self.vlm_processor.image_sizes

    @property
    def camera_names(self) -> List[str]:
        """Names of the cameras known to the wrapped VLM processor."""
        return list(self.vlm_processor.image_sizes.keys())

    @property
    def control_io_config(self) -> ControlDataIOConfig:
        return self.config.control_io_config

    @cached_property
    def rotation_components(self) -> int:
        """Number of scalars per rotation for `config.rotation_format` (euler 3, quaternion 4, rotmat 9)."""
        if self.config.rotation_format == RotationFormat.EULER:
            return 3
        if self.config.rotation_format == RotationFormat.QUATERNION:
            return 4
        if self.config.rotation_format == RotationFormat.ROTMAT:
            return 9
        raise NotImplementedError(self.config.rotation_format)

    @abstractmethod
    def policy_control_plan_from_model_target(
        self, target: RoboticsTarget, dataset_name: np.ndarray
    ) -> RoboticsControlPlan:
        pass

    @abstractmethod
    def policy_control_plan_from_model_output(
        self,
        model_output: RoboticsOutput,
        dataset_name: np.ndarray,
        valid_mask: torch.Tensor,
    ) -> RoboticsControlPlan:
        pass

    def resize_image(
        self, camera_name: str, image: PIL.Image.Image | np.ndarray
    ) -> PIL.Image.Image | np.ndarray:
        """
        Resize `image` to the configured size of `camera_name`, using `config.image_resize` mode and
        Lanczos resampling. Delegates to the module-level `resize_image` function.
        """
        return resize_image(
            image,
            target_size={
                "width": self.image_sizes[camera_name].width,
                "height": self.image_sizes[camera_name].height,
            },
            mode=self.config.image_resize,
            resample=PIL.Image.Resampling.LANCZOS,
        )

    def preprocess_inputs(
        self,
        chat: List[str],
        images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool,
        control_target: Optional[RoboticsTarget] = None,
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Preprocess the inputs for a single example.

        Args:
            chat: Chat turns, forwarded unchanged to the VLM processor
            images: Camera name -> history of input images with increasing timestamps, forwarded
                unchanged to the VLM processor
            ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3]
            ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9]. Converted to
                `config.rotation_format` before normalization
            gripper: np.ndarray gripper observation, preprocessed with `preprocess_gripper_observation`
                using `dataset_name`
            joints: np.ndarray, shape [..., num_past_scalars, <= 7]. Zero-padded to 7 components before
                normalization
            dataset_name: 1D np.ndarray of dataset names, used to select normalization statistics
            inference_mode: Accepted for interface compatibility; not used by this implementation
            control_target: Accepted for interface compatibility; not used by this implementation
        Returns:
            Dict with "images", "input_ids" and "target_text_tokens_ids" (both truncated to
            `tokenizer.model_max_length`), "attn_mask" (all True, same shape as "input_ids"), the
            normalized float32 tensors "ee_pose_translation", "ee_pose_rotation", "gripper", "joints",
            and "control_tokens_ids" / "target_control_tokens_ids" set to None.
        """
        del control_target
        del inference_mode
        inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images)
        images: Dict[str, torch.Tensor] = inputs["images"]
        input_ids: torch.Tensor = inputs["input_ids"][..., : self.tokenizer.model_max_length]
        target_text_tokens_ids: torch.Tensor = inputs["target_ids"][..., : self.tokenizer.model_max_length]
        # Any attention mask produced by the VLM processor is not used here
        attn_mask = torch.ones(input_ids.shape, dtype=torch.bool)
        ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32)
        ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32)
        ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True)
        gripper = preprocess_gripper_observation(gripper, dataset_name)
        gripper = torch.tensor(gripper, dtype=torch.float32)
        ee_pose_translation = self.normalize(
            ee_pose_translation, dataset_name=dataset_name, key="obs_translation"
        )
        ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key="obs_rotation")
        joints = torch.tensor(joints, dtype=torch.float32)
        if joints.shape[-1] < 7:
            # Pad robots with fewer joints to the fixed 7-component layout
            missing_size = 7 - joints.shape[-1]
            joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1)
        joints = self.normalize(joints, dataset_name=dataset_name, key="joints")
        outputs = {
            "images": images,
            "input_ids": input_ids,
            "target_text_tokens_ids": target_text_tokens_ids,
            "attn_mask": attn_mask,
            "ee_pose_translation": ee_pose_translation,
            "ee_pose_rotation": ee_pose_rotation,
            "gripper": gripper,
            "joints": joints,
            "control_tokens_ids": None,
            "target_control_tokens_ids": None,
        }
        return outputs

    def create_input(
        self,
        chat: List[str],
        images: Dict[str, List[PIL.Image.Image]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool,
        control_target: Optional[RoboticsTarget] = None,
    ) -> RoboticsInput:
        """Run `preprocess_inputs` and wrap the result, minus the target entries, in a `RoboticsInput`."""
        inputs = self.preprocess_inputs(
            chat=chat,
            images=images,
            ee_pose_translation=ee_pose_translation,
            ee_pose_rotation=ee_pose_rotation,
            gripper=gripper,
            joints=joints,
            dataset_name=dataset_name,
            inference_mode=inference_mode,
            control_target=control_target,
        )
        inputs.pop("target_text_tokens_ids")
        inputs.pop("target_control_tokens_ids")
        return RoboticsInput(**inputs)

    def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
        """
        Normalize `value` with the statistics of component `key` for each entry of `dataset_name`.

        Uses (mean, std) moments when `config.<key>_norm` is a mean normalization, otherwise (low, high)
        bounds.
        """
        if is_mean_norm(getattr(self.config, f"{key}_norm")):
            (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
            output = normalize_by_moments(value, mean=mean, std=std)
        else:
            (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
            output = normalize_by_bounds(value, low=low, high=high)
        return output

    def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor:
        """Inverse of `normalize`, using the same statistics selection."""
        if is_mean_norm(getattr(self.config, f"{key}_norm")):
            (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
            output = unnormalize_by_moments(value, mean=mean, std=std)
        else:
            (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key)
            output = unnormalize_by_bounds(value, low=low, high=high)
        return output

    def _norm_bounds_from_dataset_name(
        self, dataset_name: np.ndarray, component_key: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create arrays of normalization statistics matching a batch of dataset names.

        Args:
            dataset_name: Array of shape [B] of dataset names for which to fetch the statistics. Note the
                values can be repeating
            component_key: str. One of the keys of `self.norm_bounds`: 'obs_translation', 'obs_rotation',
                'translation', 'rotation' or 'joints'. Indicates for which component to fetch the
                statistics
        Returns:
            Tuple (low, high) for bounds normalization or (mean, std) for moments normalization, each of
            shape [B, component_size]
        Raises:
            NotImplementedError: for 'joints' when `config.joints_norm` is not a mapping
            KeyError: when a dataset name has no statistics for `component_key`
        """
        norm = getattr(self.config, f"{component_key}_norm")
        if is_mean_norm(norm):
            (stats_key_1, stats_key_2) = ("mean", "std")
        else:
            (stats_key_1, stats_key_2) = ("low", "high")
        if component_key == "joints":
            # Joints always use the single global ("ANY") low/high bounds, whatever the dataset
            if not isinstance(norm, collections.abc.Mapping):
                raise NotImplementedError()
            stats = {
                key: torch.from_numpy(np.tile(np.reshape(value, [1, -1]), [len(dataset_name), 1]))
                for (key, value) in self.joints_norm_bounds["ANY"].items()
            }
            return tuple(stats.values())
        component_bounds = self.norm_bounds[component_key]
        component_size = list(list(component_bounds.values())[0].values())[0].shape[-1]
        if self.dataset_names == ["ANY"]:
            # Global statistics: repeat the single entry for every example of the batch
            stats_1 = component_bounds["ANY"][stats_key_1]
            stats_2 = component_bounds["ANY"][stats_key_2]
            stats_1 = np.repeat(np.expand_dims(stats_1, axis=0), len(dataset_name), axis=0)
            stats_2 = np.repeat(np.expand_dims(stats_2, axis=0), len(dataset_name), axis=0)
        else:
            # Look up each distinct dataset once, then scatter back to batch order
            (unique_names, _, inverse_indices, _) = np_unique(dataset_name)
            stats_1 = np.zeros([len(unique_names), component_size], dtype=np.float32)
            stats_2 = np.zeros([len(unique_names), component_size], dtype=np.float32)
            for i, ds_name in enumerate(unique_names):
                if ds_name not in component_bounds:
                    raise KeyError(
                        f"No '{component_key}' normalization statistics for dataset {ds_name!r}. "
                        f"Available datasets: {sorted(component_bounds, key=str)}"
                    )
                stats_1[i] = component_bounds[ds_name][stats_key_1].numpy()
                stats_2[i] = component_bounds[ds_name][stats_key_2].numpy()
            stats_1 = stats_1[inverse_indices]
            stats_2 = stats_2[inverse_indices]
        return torch.from_numpy(stats_1), torch.from_numpy(stats_2)

    @cached_property
    def obs_rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """Normalization statistics for observed rotations, computed from `_observation_stats`."""
        return rotation_norm_bounds(
            rotation_norm=self.config.obs_rotation_norm,
            rotation_format=self.config.rotation_format,
            stats=self._observation_stats,
            dataset_names=self.dataset_names,
        )

    @cached_property
    def obs_translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """Normalization statistics for observed translations, computed from `_observation_stats`."""
        return translation_norm_bounds(
            translation_norm=self.config.obs_translation_norm,
            stats=self._observation_stats,
            dataset_names=self.dataset_names,
        )

    @cached_property
    def rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """Normalization statistics for rotation controls, computed from `_control_stats`."""
        return rotation_norm_bounds(
            rotation_norm=self.config.rotation_norm,
            rotation_format=self.config.rotation_format,
            stats=self._control_stats,
            dataset_names=self.dataset_names,
        )

    @cached_property
    def translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """Normalization statistics for translation controls, computed from `_control_stats`."""
        return translation_norm_bounds(
            translation_norm=self.config.translation_norm,
            stats=self._control_stats,
            dataset_names=self.dataset_names,
        )

    @cached_property
    def joints_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """
        Global ("ANY") low/high joint bounds taken from `config.joints_norm`.

        NOTE:
        - Joint values across all joints and all datasets vary in the range [-2pi; 2pi]
        - The effective range of a single joint is in practice one of [-2pi; 0], [-pi; pi], [0; 2pi]
        - It's possible to shift all ranges to [-pi; pi], but it requires careful handling for each joint
        """
        low = torch.tensor(self.config.joints_norm["low"], dtype=torch.float32)
        high = torch.tensor(self.config.joints_norm["high"], dtype=torch.float32)
        results = {"ANY": {"low": low, "high": high}}
        return results

    @cached_property
    def _observation_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
        """
        Hard-coded per-dataset observation statistics.

        Dataset -> component ("euler", "gripper", "joints", "translation") -> statistic
        ("max", "mean", "min", "q01", "q99", "std"; gripper has no mean/std) -> per-dimension values.
        "bridge_orig" repeats the "bridge" values. The "bridge" joint statistics are all zeros.
        NOTE(review): presumably joint data is unavailable for bridge; confirm.
        """
        return {
            "bridge": {
                "euler": {
                    "max": [3.141592653589793, 1.570796251296997, 3.141204357147217],
                    "mean": [
                        -0.25754162314671525,
                        -0.12370228389510128,
                        0.1620053749182691,
                    ],
                    "min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
                    "q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
                    "q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
                    "std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
                },
                "gripper": {
                    "max": [1.0370277166366577],
                    "min": [0.04637829214334488],
                    "q01": [0.05192930996417999],
                    "q99": [1.0118417739868164],
                },
                "joints": {
                    "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                },
                "translation": {
                    "max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
                    "mean": [
                        0.309032678604126,
                        0.03403777256608009,
                        0.061277542263269424,
                    ],
                    "min": [
                        -0.04167502000927925,
                        -0.2889411449432373,
                        -0.13934996724128723,
                    ],
                    "q01": [
                        0.1711955964565277,
                        -0.15639324486255646,
                        -0.048255354166030884,
                    ],
                    "q99": [
                        0.4604376256465912,
                        0.24112474918365479,
                        0.18886254727840424,
                    ],
                    "std": [
                        0.0635896623134613,
                        0.09153717756271362,
                        0.049334850162267685,
                    ],
                },
            },
            "bridge_orig": {
                "euler": {
                    "max": [3.141592653589793, 1.570796251296997, 3.141204357147217],
                    "mean": [
                        -0.25754162314671525,
                        -0.12370228389510128,
                        0.1620053749182691,
                    ],
                    "min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938],
                    "q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896],
                    "q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236],
                    "std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315],
                },
                "gripper": {
                    "max": [1.0370277166366577],
                    "min": [0.04637829214334488],
                    "q01": [0.05192930996417999],
                    "q99": [1.0118417739868164],
                },
                "joints": {
                    "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                },
                "translation": {
                    "max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043],
                    "mean": [
                        0.309032678604126,
                        0.03403777256608009,
                        0.061277542263269424,
                    ],
                    "min": [
                        -0.04167502000927925,
                        -0.2889411449432373,
                        -0.13934996724128723,
                    ],
                    "q01": [
                        0.1711955964565277,
                        -0.15639324486255646,
                        -0.048255354166030884,
                    ],
                    "q99": [
                        0.4604376256465912,
                        0.24112474918365479,
                        0.18886254727840424,
                    ],
                    "std": [
                        0.0635896623134613,
                        0.09153717756271362,
                        0.049334850162267685,
                    ],
                },
            },
            "droid": {
                "euler": {
                    "max": [3.141592502593994, 1.5705928802490234, 3.1415867805480957],
                    "mean": [
                        0.3140628098409554,
                        -0.09296274023036387,
                        -0.07227215454779846,
                    ],
                    "min": [
                        -3.141592502593994,
                        -1.5691150426864624,
                        -3.1415374279022217,
                    ],
                    "q01": [
                        -3.1378602981567383,
                        -1.2125312042236327,
                        -2.1614069032669065,
                    ],
                    "q99": [3.137854380607605, 0.9200375998020163, 1.9367506909370364],
                    "std": [2.926265757944871, 0.363273475703332, 0.7576065217938824],
                },
                "gripper": {
                    "max": [1.0],
                    "min": [0.0],
                    "q01": [0.0],
                    "q99": [0.9911894202232361],
                },
                "joints": {
                    "max": [
                        2.668445110321045,
                        1.5691218376159668,
                        2.666306734085083,
                        -0.3114914000034332,
                        2.6624162197113037,
                        4.28157901763916,
                        2.752457857131958,
                    ],
                    "mean": [
                        0.023137084334640106,
                        0.2704989977282293,
                        -0.01451389357228282,
                        -2.018709403792315,
                        -0.042720520800030394,
                        2.350281188152209,
                        0.12424663946659845,
                    ],
                    "min": [
                        -2.6536705493927,
                        -1.547789216041565,
                        -2.6781487464904785,
                        -2.9409868717193604,
                        -2.6705946922302246,
                        0.24893812835216522,
                        -2.7615714073181152,
                    ],
                    "q01": [
                        -0.9026106441020965,
                        -0.8547340619564057,
                        -0.9028875434398651,
                        -2.7698556280136106,
                        -1.6851656341552732,
                        1.2335169839859008,
                        -1.9587260699272155,
                    ],
                    "q99": [
                        0.9569852340221403,
                        1.4148830294609054,
                        0.7693877756595566,
                        -0.4545914208889008,
                        1.5623322343826267,
                        3.475611729621887,
                        2.263479118347167,
                    ],
                    "std": [
                        0.31695080251469465,
                        0.49522214687158767,
                        0.27993538230553827,
                        0.478161574676113,
                        0.4969961591445458,
                        0.45101008525403846,
                        0.7287264344068457,
                    ],
                },
                "translation": {
                    "max": [0.8575563430786133, 0.799155592918396, 1.0043904781341553],
                    "mean": [
                        0.5283099395864883,
                        0.005363794653877434,
                        0.3120132207021294,
                    ],
                    "min": [
                        -0.15604186058044434,
                        -0.827903687953949,
                        -0.2347021996974945,
                    ],
                    "q01": [
                        0.26669957995414734,
                        -0.43774398624897004,
                        -0.048167889714241026,
                    ],
                    "q99": [0.7774086785316463, 0.428325751423835, 0.776091011762619],
                    "std": [
                        0.1148424841779685,
                        0.17489566608140428,
                        0.16541062032731538,
                    ],
                },
            },
            "roboset": {
                "euler": {
                    "max": [3.1415449294818236, 1.5705575529715636, 3.141527342124582],
                    "mean": [
                        -0.0398455755412464,
                        1.0518070390619125,
                        -0.015345692503002759,
                    ],
                    "min": [
                        -3.1415813300509536,
                        -1.5222832468962035,
                        -3.141575300866071,
                    ],
                    "q01": [
                        -2.9414386317311187,
                        -0.24976770655101155,
                        -2.985256521212579,
                    ],
                    "q99": [2.9380437893235993, 1.5403010739503078, 2.9746912523985025],
                    "std": [1.7866587696177456, 0.40620530263065, 1.7288511340250616],
                },
                "gripper": {
                    "max": [0.83056640625],
                    "min": [0.0001499652862548828],
                    "q01": [0.0001499652862548828],
                    "q99": [0.82666015625],
                },
                "joints": {
                    "max": [
                        0.96240234375,
                        1.1162109375,
                        1.1064453125,
                        -0.98095703125,
                        2.30859375,
                        1.576171875,
                        1.7412109375,
                    ],
                    "mean": [
                        0.005913593806326389,
                        0.1877261847257614,
                        0.04653879255056381,
                        -2.0529513359069824,
                        -0.011298442259430885,
                        0.6185526251792908,
                        -0.01701134257018566,
                    ],
                    "min": [
                        -0.8330078125,
                        -0.74658203125,
                        -0.8642578125,
                        -2.892578125,
                        -1.390625,
                        -0.24658203125,
                        -2.953125,
                    ],
                    "q01": [
                        -0.41015625,
                        -0.5302734375,
                        -0.6455078125,
                        -2.57421875,
                        -0.76416015625,
                        -0.0386962890625,
                        -1.435546875,
                    ],
                    "q99": [
                        0.66455078125,
                        0.9501953125,
                        0.7529296875,
                        -1.251953125,
                        0.75244140625,
                        1.2314453125,
                        1.384765625,
                    ],
                    "std": [
                        0.17915399372577667,
                        0.32234326004981995,
                        0.26069700717926025,
                        0.31767210364341736,
                        0.205329030752182,
                        0.33385637402534485,
                        0.6263682842254639,
                    ],
                },
                "translation": {
                    "max": [0.5747738480567932, 0.3972920775413513, 0.7443570494651794],
                    "mean": [
                        0.3331542909145355,
                        0.019357483834028244,
                        0.37330344319343567,
                    ],
                    "min": [
                        0.09978063404560089,
                        -0.29593944549560547,
                        0.10065606236457825,
                    ],
                    "q01": [
                        0.18437016010284424,
                        -0.25699371099472046,
                        0.15134164690971375,
                    ],
                    "q99": [0.543661892414093, 0.29646238684654236, 0.6682320833206177],
                    "std": [
                        0.07849054038524628,
                        0.12241040915250778,
                        0.1460595279932022,
                    ],
                },
            },
        }

    @cached_property
    def _control_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]:
        """
        Control statistics loaded from `config.control_stats_path` for the configured control horizon.

        Returns {} without reading the file when both rotation and translation normalizations are global.
        The YAML file is keyed by "horizon_<seconds>s". The horizon is:
        - the sequence stride for delta controls (0.0 when the stride is None);
        - 0.0 for a single absolute control without stride;
        - sequence length * stride for absolute control sequences.

        Raises:
            NotImplementedError: absolute control sequences longer than 1 step without a stride
            ValueError: when the horizon key is missing from the file
        """
        if is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.translation_norm):
            return {}
        with open(self.config.control_stats_path, "r", encoding="utf-8") as file:
            stats = yaml.safe_load(file)
        if self.config.delta_controls:
            if self.control_io_config.future_controls_sequence_stride_sec is None:
                horizon = 0.0
            else:
                horizon = self.control_io_config.future_controls_sequence_stride_sec
        elif self.control_io_config.future_controls_sequence_stride_sec is None:
            if self.control_io_config.future_controls_sequence_length == 1:
                horizon = 0.0
            else:
                raise NotImplementedError()
        else:
            horizon = (
                self.control_io_config.future_controls_sequence_length
                * self.control_io_config.future_controls_sequence_stride_sec
            )
        key = f"horizon_{round(horizon, 2)}s"
        if key not in stats:
            raise ValueError(
                f"Missing control statistics key {key} for future_controls_sequence_length={self.config.control_io_config.future_controls_sequence_length} future_controls_sequence_stride_sec={self.config.control_io_config.future_controls_sequence_stride_sec}. Available keys: {list(stats.keys())}"
            )
        return stats[key]

    @cached_property
    def dataset_names(self) -> List[str]:
        """
        Dataset names with available statistics, or ["ANY"] when all pose normalizations are global.

        Sorted so that the order is reproducible across runs (set iteration order depends on hash
        randomization).
        """
        if (
            is_global_norm(self.config.rotation_norm)
            and is_global_norm(self.config.obs_rotation_norm)
            and is_global_norm(self.config.translation_norm)
            and is_global_norm(self.config.obs_translation_norm)
        ):
            return ["ANY"]
        return sorted(set(self._control_stats.keys()) | set(self._observation_stats.keys()), key=str)
def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor:
    """
    Re-express step-wise translation deltas relative to the base frame of the sequence.

    Each input vector is expressed w.r.t. the PREVIOUS point of the sequence. Each output vector is
    expressed w.r.t. the base frame T0 that precedes the sequence, i.e. it is the running sum of the
    deltas.
    Ex:
        Points: T1, T2, T3, T4 (T0 is the base frame, implicitly encoded in T0T1)
        Input:  T0T1, T1T2, T2T3, T3T4
        Output: T0T1, T0T2, T0T3, T0T4
    Args:
        translation_sequence: torch.Tensor of shape [..., S, 3] with at least 3 dimensions, where S is
            the sequence dimension
    Returns:
        torch.Tensor of the same shape as `translation_sequence`, holding the translations relative to
        T0 (cumulative sum along the sequence dimension)
    Raises:
        AssertionError: if `translation_sequence` has fewer than 3 dimensions (when asserts are enabled)
    """
    assert translation_sequence.ndim >= 3, translation_sequence.shape
    return translation_sequence.cumsum(dim=-2)
class RegressionProcessor(VLAMProcessor):
    """
    Processor for models that regress the controls directly.

    Both conversions undo the translation / rotation normalization, convert rotations to rotation
    matrices and, when `config.delta_controls` is set, accumulate the per-step deltas so that controls
    are expressed relative to the frame preceding the sequence.
    """

    def policy_control_plan_from_model_target(
        self, target: RoboticsTarget, dataset_name: np.ndarray
    ) -> RoboticsControlPlan:
        """Build a control plan from a normalized target; the gripper target is passed through as-is."""
        translation = self.unnormalize(target.translation, dataset_name=dataset_name, key="translation")
        rotmat = convert_rotation(
            self.unnormalize(target.rotation, dataset_name=dataset_name, key="rotation"),
            RotationFormat.ROTMAT,
        )
        if self.config.delta_controls:
            translation, rotmat = (
                delta_to_relative_translations(translation),
                delta_to_relative_rotations(rotmat),
            )
        return RoboticsControlPlan(
            translation_m=translation,
            rotmat=rotmat,
            gripper_prob=target.gripper,
            valid_mask=target.valid_mask,
        )

    def policy_control_plan_from_model_output(
        self,
        model_output: RoboticsOutput,
        dataset_name: np.ndarray,
        valid_mask: torch.Tensor,
    ) -> RoboticsControlPlan:
        """
        Build a control plan from raw model output during inference.

        Rotations are converted with `autonorm=True`, and the gripper output is mapped through a sigmoid.
        """
        translation = self.unnormalize(
            model_output.translation, dataset_name=dataset_name, key="translation"
        )
        rotmat = convert_rotation(
            self.unnormalize(model_output.rotation, dataset_name=dataset_name, key="rotation"),
            RotationFormat.ROTMAT,
            autonorm=True,
        )
        probability = torch.sigmoid(model_output.gripper)
        if self.config.delta_controls:
            translation, rotmat = (
                delta_to_relative_translations(translation),
                delta_to_relative_rotations(rotmat),
            )
        return RoboticsControlPlan(
            translation_m=translation,
            rotmat=rotmat,
            gripper_prob=probability,
            valid_mask=valid_mask,
        )
class PiZeroFlowMatchingProcessor(RegressionProcessor):
    """
    Processor for flow-matching control prediction.

    Samples flow timesteps and t=0 samples for the control components (translation, rotation, gripper),
    defines the interpolation path `_psi_t` and its time derivative, and post-processes the model output
    into a control plan.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # RNG used by the uniform branch of `sample_timestep` and by the `torch.randn` calls in
        # `sample_t0_input`. The Beta branch and `roma.random_unitquat` do not draw from it.
        self.generator: torch.Generator = torch.Generator()

    @cached_property
    def beta_distribution(self) -> torch.distributions.Beta:
        """Beta distribution from `config.distribution_hyperparams` ("alpha" defaults to 1.5, "beta" to 1.0)."""
        return torch.distributions.Beta(
            self.config.distribution_hyperparams.get("alpha", 1.5),
            self.config.distribution_hyperparams.get("beta", 1.0),
        )

    def create_input(self, *args, **kwargs) -> RoboticsFlowInput:
        """
        In practice used only during inference.

        Builds the regular `RoboticsInput` and attaches a freshly sampled t=0 flow input for a single
        example (sampled on CPU).
        """
        inputs = super().create_input(*args, **kwargs)
        flow_input: FlowInput = self.sample_t0_input(batch_size=1, device=torch.device("cpu"))
        inputs = RoboticsFlowInput(**inputs.as_json(), flow_input=flow_input[0, ...])
        return inputs

    def sample_timestep(self, batch_size: int) -> torch.Tensor:
        """
        Sample one flow timestep per example, returned with shape [batch_size, 1, 1].

        - "uniform": stratified sampling, i.e. one shared uniform offset plus evenly spaced
          `arange(batch_size) / batch_size`, wrapped modulo `1 - eps`
        - "beta": `(1 - sig_min) * (1 - s)` with `s` drawn from `beta_distribution`
        Raises NotImplementedError for any other `config.timestep_distribution`.
        """
        if self.config.timestep_distribution.lower() == "uniform":
            eps = 1e-05
            sample = (torch.rand(1, generator=self.generator) + torch.arange(batch_size) / batch_size) % (
                1 - eps
            )
        elif self.config.timestep_distribution.lower() == "beta":
            sample = self.beta_distribution.sample([batch_size, 1, 1])
            sample = (1 - self.config.sig_min) * (1 - sample)
        else:
            raise NotImplementedError(self.config.timestep_distribution)
        sample = sample.view(batch_size, 1, 1)
        return sample

    def _psi_t(self, timestep: torch.Tensor, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
        """Point at `timestep` on the path from x_0 (t=0) to x_1 (t=1): (1 - (1 - sig_min) t) x_0 + t x_1."""
        return (1 - (1 - self.config.sig_min) * timestep) * x_0 + timestep * x_1

    def _dpsi_dt(self, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
        """Time derivative of `_psi_t`; it does not depend on the timestep."""
        return x_1 - (1 - self.config.sig_min) * x_0

    def sample_t0_input(self, batch_size: int, device: torch.device) -> FlowInput:
        """
        Sample the t=0 controls for `batch_size` examples.

        - "normal": translation, rotation and gripper all drawn from a standard normal; the rotation part
          is then passed through `normalize_rotation`
        - "uniform": translation and gripper still drawn from a standard normal; rotations are random unit
          quaternions converted to `config.rotation_format`
        For the quaternion format, `quaternion_half_cover` is applied to the sampled rotations.
        The returned timestep is all zeros with shape [batch_size, 1, 1]; the *_t fields are None.
        """
        if self.config.r0_distribution == "normal":
            controls_t0 = torch.randn(
                [
                    batch_size,
                    self.config.control_io_config.future_controls_sequence_length,
                    3 + self.rotation_components + 1,
                ],
                generator=self.generator,
            ).to(device=device)
            (translation_t0, rotation_t0, gripper_t0) = torch.split(
                controls_t0, [3, self.rotation_components, 1], dim=-1
            )
            rotation_t0 = normalize_rotation(rotation_t0)
        elif self.config.r0_distribution == "uniform":
            # Only the rotation is uniform; the 4 scalars here are 3 translation + 1 gripper
            controls_t0 = torch.randn(
                [
                    batch_size,
                    self.config.control_io_config.future_controls_sequence_length,
                    4,
                ],
                generator=self.generator,
            ).to(device=device)
            (translation_t0, gripper_t0) = torch.split(controls_t0, [3, 1], dim=-1)
            rotation_t0 = convert_rotation(
                roma.random_unitquat(
                    (
                        batch_size,
                        self.config.control_io_config.future_controls_sequence_length,
                    ),
                    device=device,
                ),
                self.config.rotation_format,
            )
        else:
            raise NotImplementedError(self.config.r0_distribution)
        if self.config.rotation_format == RotationFormat.QUATERNION:
            rotation_t0 = quaternion_half_cover(rotation_t0)
        timestep = torch.zeros([batch_size, 1, 1], device=device)
        return FlowInput(
            timestep=timestep,
            translation_t0=translation_t0,
            rotation_t0=rotation_t0,
            gripper_t0=gripper_t0,
            translation_t=None,
            rotation_t=None,
            gripper_t=None,
        )

    def policy_control_plan_from_model_output(
        self,
        model_output: RoboticsOutput,
        dataset_name: np.ndarray,
        valid_mask: torch.Tensor,
    ) -> RoboticsControlPlan:
        """
        Clamp the model output, then build the control plan with the parent implementation.

        Unlike the parent, the gripper probability is the model gripper output clamped to [0, 1]
        (no sigmoid).
        """
        # NOTE(review): the [-1, 1] clamp is applied for NONE / mean-std normalization and skipped for
        # bounds normalization. Confirm this condition is intended and not inverted.
        if self.config.translation_norm == Normalization.NONE or is_mean_norm(self.config.translation_norm):
            model_output = model_output.replace(translation=torch.clamp(model_output.translation, -1, 1))
        if self.config.rotation_norm == Normalization.NONE or is_mean_norm(self.config.rotation_norm):
            model_output = model_output.replace(rotation=torch.clamp(model_output.rotation, -1, 1))
        control_plan = super().policy_control_plan_from_model_output(model_output, dataset_name, valid_mask)
        # Overrides the sigmoid applied by RegressionProcessor
        control_plan = control_plan.replace(gripper_prob=torch.clamp(model_output.gripper, 0, 1))
        return control_plan
def make_causal_mask(shape: Sequence[int]) -> torch.Tensor:
    """
    Build a boolean lower-triangular (causal) attention mask.

    Args:
        shape: Output shape; the trailing two dimensions are [query_seq_len, kv_seq_len]. Any leading
            dimensions receive the same mask.
    Returns:
        torch.bool tensor where entry (i, j) is True iff j <= i, i.e. query row i may attend to key
        column j. False entries block attention.
    Example:
        shape = (3, 5):
        [
            [ 1, 0, 0, 0, 0],
            [ 1, 1, 0, 0, 0],
            [ 1, 1, 1, 0, 0]
        ]
    """
    all_allowed = torch.ones(shape, dtype=torch.bool)
    return all_allowed.tril()
def enable_full_attn_blocks(attn_mask: torch.Tensor, full_attn: torch.Tensor) -> torch.Tensor:
    """
    Enable full bi-directional attention in `attn_mask` inside specific blocks.

    Every maximal run of consecutive True values in `full_attn` becomes a block in which all positions
    attend to each other. The result is OR-ed into `attn_mask`, so existing True entries are never
    removed.

    Args:
        attn_mask: Existing attention mask of shape [..., query_seq_len, kv_seq_len] and dtype torch.bool,
            where False values indicate disabled attention. Only self-attention
            (query_seq_len == kv_seq_len) is supported.
        full_attn: torch.Tensor of shape [query_seq_len], dtype torch.bool. Blocks of True values indicate
            positions where full bi-directional attention should be enabled
    Returns:
        The updated attention mask
    Raises:
        NotImplementedError: if `attn_mask` is not square in its last two dimensions
    Example:
        full_attn = [1, 1, 1, 0, 0, 1, 1, 1]
        1, 0, 0, 0, 0, 0, 0, 0,       1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 0, 0, 0, 0, 0, 0,       1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 0, 0, 0, 0, 0,       1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 0,  ->   1, 1, 1, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 0, 0, 0,       1, 1, 1, 1, 1, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 0, 0,       1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0,       1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1,       1, 1, 1, 1, 1, 1, 1, 1,
    """
    assert full_attn.dtype == torch.bool, full_attn.dtype
    assert full_attn.ndim == 1, full_attn.shape
    assert full_attn.shape[0] == attn_mask.shape[-2], f"{full_attn.shape[0]}, {attn_mask.shape}"
    if attn_mask.shape[-1] != attn_mask.shape[-2]:
        raise NotImplementedError("Only self-attention supported right now.")
    seq_len = full_attn.shape[0]
    # Pairs (i, j) where both positions request full attention
    x = full_attn.view(-1, 1) & full_attn.view(1, -1)
    # Add the causal part. It is built on full_attn's device so the OR also works for non-CPU inputs
    x = x | make_causal_mask([seq_len, seq_len]).to(device=full_attn.device)
    # Along each row keep entries only up to the first False: past the diagonal, row i then reaches only
    # the columns of the contiguous full-attention run that contains i
    x = torch.cumprod(x, dim=1).to(dtype=torch.bool)
    # Symmetrize: keeps pairs within the same contiguous run, plus the diagonal
    x = x & x.permute(1, 0)
    # Explicit parentheses: `&` binds tighter than `==`, so without them the sum would be compared with
    # `1 & ~full_attn`. The diagonal is always set, so both forms pick the same positions
    mask_positions = (torch.sum(x, dim=0) == 1) & ~full_attn
    mask_indices = torch.where(mask_positions)[0]
    x[mask_indices, mask_indices] = 0
    attn_mask = attn_mask | expand_dims(x, ndim=attn_mask.ndim, order=[-1, 1, 1])
    return attn_mask
# Target id marking positions that are not supervised (image tokens, bos, user turns, separators in
# PaliGemmaProcessor). -100 is also the default `ignore_index` of torch's cross-entropy loss.
IGNORE_INDEX: int = -100
class PaliGemmaProcessor(VLMProcessor):
    """
    `VLMProcessor` for PaliGemma.

    Tokenizes a multi-turn chat with the Hugging Face PaliGemma tokenizer (building input ids, target ids
    and a prefix attention mask) and preprocesses camera images with the Hugging Face image processor.
    """

    def __init__(
        self,
        config: PaliGemmaProcessorConfig,
        hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
        **kwargs,
    ):
        del kwargs
        super().__init__(config)
        self.hf_processor = hf_processor
        # Configure the HF processor with the image size and image-token count of the "main" camera
        self.hf_processor.image_processor.size = dict(self.config.image_sizes["main"].as_json())
        self.hf_processor.image_seq_length = self.config.num_image_tokens["main"]
        self.hf_processor.image_processor.image_seq_length = self.config.num_image_tokens["main"]
        self.bos_id: int = self.tokenizer.bos_token_id
        self.eos_id: int = self.tokenizer.eos_token_id
        # Newlines separate chat turns; they are replaced inside turn texts in `preprocess_inputs`
        self.sep_token = "\n"
        self.sep_id: int = self._tokenize(self.sep_token)[0]
        self.image_token_id: int = self._tokenize(self.config.image_token)[0]
        # Image-token ids prepended to every sequence: `num_image_tokens` summed over all cameras
        self.image_tokens: list[int] = [self.image_token_id] * sum(self.config.num_image_tokens.values())
        # Bounding boxes written as "[a, b, c, d]" with four decimal numbers
        self.bbox_pattern = re.compile(r"\[(\d+\.\d+),\s*(\d+\.\d+),\s*(\d+\.\d+),\s*(\d+\.\d+)\]")

    def _tokenize(self, text: str) -> List[int]:
        """Token ids of `text`, without special tokens, padding or attention mask."""
        return self.tokenizer(
            text,
            padding=False,
            add_special_tokens=False,
            return_attention_mask=False,
        )["input_ids"]

    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Based on PaliGemma paper https://arxiv.org/pdf/2407.07726 and example code at
        https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#create_model_inputs

        Chat must be always made of separate messages from user and model, always starting with user
        <image><image> ... <bos><instruction><sep><assistant><sep><instruction><sep><assistant>...<eos>

        Args:
            chat: List[str] of even size where each entry corresponds to a different turn in the
                conversation (even indices: user, odd indices: model)
            images: Dict[str, List[PIL.Image.Image]] where different cameras correspond to different keys
                in the Dict and the List corresponds to history of images
        Returns:
            Dict with
            - "input_ids": int64 tensor [image tokens, bos, turns joined by sep, eos]. If
              `config.max_language_tokens > 0`, the language part (bos..eos) is truncated to that length
              before the image tokens are prepended
            - "target_ids": int64 tensor of the same length. IGNORE_INDEX for image tokens, bos, user
              turns and separators; token ids for model turns and eos
            - "images": Dict "<camera_name>.siglip" -> "pixel_values" from the HF image processor
            - "attn_mask": bool [L, L] causal mask with full attention inside each contiguous run of
              IGNORE_INDEX targets
        Raises:
            TypeError: if the value for any camera in `images` is not a list
        """
        for key, value in images.items():
            if not isinstance(value, list):
                raise TypeError(f"Camera {key} contains values of type {type(value)} instead of list")
        input_ids: List[int] = []
        target_ids: List[int] = []
        last_turn = len(chat) - 1
        for i, text in enumerate(chat):
            # The separator is reserved for turn boundaries and image tokens are added separately
            text = text.replace(self.sep_token, " ").replace("<image>", "")
            text = self.bbox_pattern.sub(self._bbox_to_loc_tokens, text)
            turn_input_ids = self._tokenize(text)
            # User turns are not supervised; model turns are
            if i % 2 == 0:
                turn_target_ids = [IGNORE_INDEX] * len(turn_input_ids)
            else:
                turn_target_ids = list(turn_input_ids)
            if i != last_turn:
                turn_input_ids = turn_input_ids + [self.sep_id]
                turn_target_ids = turn_target_ids + [IGNORE_INDEX]
            input_ids.extend(turn_input_ids)
            target_ids.extend(turn_target_ids)
        input_ids = [self.bos_id, *input_ids, self.eos_id]
        target_ids = [IGNORE_INDEX, *target_ids, self.eos_id]
        image_tokens = self.image_tokens
        if self.config.max_language_tokens > 0:
            input_ids = input_ids[: self.config.max_language_tokens]
            target_ids = target_ids[: self.config.max_language_tokens]
        input_ids = image_tokens + input_ids
        target_ids = [IGNORE_INDEX] * len(image_tokens) + target_ids
        input_ids = torch.tensor(input_ids, dtype=torch.int64)
        target_ids = torch.tensor(target_ids, dtype=torch.int64)
        image_tensors: Dict[str, torch.Tensor] = {
            f"{camera_name}.siglip": self.hf_processor.image_processor(
                camera_images,
                size=self.config.image_sizes[camera_name].as_json(),
                return_tensors="pt",
            )["pixel_values"]
            for (camera_name, camera_images) in images.items()
        }
        # Prefix attention: causal overall, bi-directional within runs of unsupervised tokens
        attn_mask = make_causal_mask([len(input_ids), len(input_ids)])
        attn_mask = enable_full_attn_blocks(attn_mask, full_attn=target_ids == IGNORE_INDEX)
        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "images": image_tensors,
            "attn_mask": attn_mask,
        }

    @property
    def tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Tokenizer of the wrapped Hugging Face processor."""
        return self.hf_processor.tokenizer

    @staticmethod
    def _bbox_to_loc_tokens(match: re.Match) -> str:
        """
        Replace a matched "[a, b, c, d]" bounding box with PaliGemma `<locXXXX>` tokens.

        Each coordinate is scaled by 1024, rounded and clipped to [0, 1023]. Coordinates are presumably
        normalized to [0, 1]; confirm with the data source.
        https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/
        """
        floats = list(map(float, match.groups()))
        transformed = [f"<loc{min(max(round(num * 1024), 0), 1023):04d}>" for num in floats]
        return f"[{', '.join(transformed)}]"

    @property
    def image_sizes(self) -> Dict[str, ImageSizeConfig]:
        """Camera name -> configured image size."""
        return self.config.image_sizes
class PaliGemmaDepthProcessor(PaliGemmaProcessor):
    """
    `PaliGemmaProcessor` that additionally returns per-camera image tensors under "<camera_name>.depth".

    The last `depth_tokens` ids of the tokenizer vocabulary are recorded as depth token ids.
    """

    def __init__(
        self,
        config: PaliGemmaProcessorConfig,
        hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor,
        depth_tokens: int,
    ):
        super().__init__(config, hf_processor)
        vocab_size = len(self.tokenizer)
        self.depth_token_ids = np.arange(vocab_size - depth_tokens, vocab_size)
        # Per-camera pipeline: bicubic resize to the camera's configured size, conversion to a float
        # tensor, normalization with the commonly used ImageNet mean/std
        self.depth_input_transforms = {
            camera_name: torchvision.transforms.v2.Compose(
                [
                    torchvision.transforms.v2.Resize(
                        size=(camera_image_size.height, camera_image_size.width),
                        interpolation=torchvision.transforms.v2.InterpolationMode.BICUBIC,
                        max_size=None,
                        antialias=True,
                    ),
                    torchvision.transforms.v2.ToTensor(),
                    torchvision.transforms.v2.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            )
            for (camera_name, camera_image_size) in self.config.image_sizes.items()
        }

    def preprocess_inputs(
        self, chat: List[str], images: Dict[str, List[PIL.Image.Image]]
    ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]:
        """
        Run `PaliGemmaProcessor.preprocess_inputs` and add "<camera_name>.depth" entries to "images".

        Each new entry stacks the transformed history images of that camera along a new first dimension.
        """
        inputs = super().preprocess_inputs(chat=chat, images=images)
        # Transform each frame on its own. When a v2 transform receives a list with no PIL / tv_tensors
        # image in it (the case after ToTensor), it transforms only the FIRST plain tensor. Normalize
        # applied to the whole history would therefore leave every frame but the first unnormalized.
        depth_images: Dict[str, torch.Tensor] = {
            f"{camera_name}.depth": torch.stack(
                [self.depth_input_transforms[camera_name](image) for image in camera_images], dim=0
            )
            for (camera_name, camera_images) in images.items()
        }
        inputs["images"] = {**inputs["images"], **depth_images}
        return inputs

    @property
    def num_depth_tokens(self) -> int:
        """Number of depth token ids (`depth_tokens` passed to the constructor)."""
        return len(self.depth_token_ids)