|
|
""" |
|
|
BEAST: B-Spline Encoded Action Sequences Tokenizer |
|
|
|
|
|
A tokenizer for encoding/decoding robot trajectories using B-splines. |
|
|
Converts continuous trajectories to discrete tokens and vice versa. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import einops |
|
|
from typing import Optional, ClassVar |
|
|
from functools import wraps |
|
|
from transformers.processing_utils import ProcessorMixin |
|
|
|
|
|
|
|
|
def autocast_float32(fn):
    """Decorator forcing the wrapped call to execute under float32 autocast.

    Picks the modern ``torch.amp.autocast`` API when available and falls back
    to the legacy ``torch.cuda.amp.autocast`` on older torch releases.
    """
    @wraps(fn)
    def wrapped(*args, **kwargs):
        # Newer torch exposes a device-agnostic torch.amp.autocast; older
        # versions only ship the CUDA-specific variant.
        if hasattr(torch.amp, 'autocast'):
            ctx = torch.amp.autocast('cuda', dtype=torch.float32)
        else:
            ctx = torch.cuda.amp.autocast(dtype=torch.float32)
        with ctx:
            return fn(*args, **kwargs)
    return wrapped
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def continuous_to_discrete(tensor: torch.Tensor, min_val: Optional[torch.Tensor] = None,
                           max_val: Optional[torch.Tensor] = None, num_bins: int = 256) -> torch.Tensor:
    """
    Convert continuous tensor values to discrete tokens.

    Args:
        tensor: Input tensor with continuous values
        min_val: Minimum value for normalization (uses tensor.min() if None)
        max_val: Maximum value for normalization (uses tensor.max() if None)
        num_bins: Number of discrete bins (default 256 for 0-255 range)

    Returns:
        Discretized long tensor with integer values in [0, num_bins-1]
    """
    if min_val is None:
        min_val = tensor.min()
    if max_val is None:
        max_val = tensor.max()

    # Small tolerance absorbs float round-off from upstream clamping/normalization.
    assert torch.all(tensor >= min_val - 1e-3), "Input tensor has values below min_val"
    assert torch.all(tensor <= max_val + 1e-3), "Input tensor has values above max_val"

    # Guard against a degenerate range (max_val == min_val, e.g. a constant
    # input with inferred bounds): the plain division would yield NaN and the
    # subsequent round-to-long would be undefined. Map those entries to bin 0.
    span = max_val - min_val
    normalized = torch.where(
        span == 0,
        torch.zeros_like(tensor),
        (tensor - min_val) / span,
    )
    normalized = torch.clamp(normalized, 0, 1)
    discrete = torch.round(normalized * (num_bins - 1)).to(torch.long)
    return discrete
|
|
|
|
|
|
|
|
def discrete_to_continuous(discrete_tensor: torch.Tensor, min_val: torch.Tensor = 0,
                           max_val: torch.Tensor = 1, num_bins: int = 256) -> torch.Tensor:
    """
    Convert discrete tokens back to continuous values.

    Inverse of :func:`continuous_to_discrete`: bin index 0 maps to ``min_val``
    and bin index ``num_bins - 1`` maps to ``max_val``.

    Args:
        discrete_tensor: Input tensor with discrete values in [0, num_bins-1]
        min_val: Minimum value of target continuous range
        max_val: Maximum value of target continuous range
        num_bins: Number of discrete bins

    Returns:
        Continuous tensor with values in [min_val, max_val]
    """
    # Bin index -> fraction of the unit interval.
    fraction = discrete_tensor.float() / (num_bins - 1)
    # Stretch/shift onto the target range; clamp defends against
    # out-of-range token indices.
    rescaled = min_val + fraction * (max_val - min_val)
    return torch.clamp(rescaled, min_val, max_val)
|
|
|
|
|
|
|
|
def normalize_tensor(tensor: torch.Tensor, w_min: torch.Tensor, w_max: torch.Tensor,
                     norm_min: float = -1.0, norm_max: float = 1.0) -> torch.Tensor:
    """
    Rescale a tensor from [w_min, w_max] into [norm_min, norm_max].

    Args:
        tensor: Input tensor to normalize
        w_min: Minimum bound of original range
        w_max: Maximum bound of original range
        norm_min: Target minimum (default -1.0)
        norm_max: Target maximum (default 1.0)

    Returns:
        Normalized tensor in [norm_min, norm_max]
    """
    # Clip first so out-of-range inputs land exactly on the target bounds.
    bounded = torch.clamp(tensor, w_min, w_max)
    # Map to the unit interval, then stretch/shift to the target range.
    unit = (bounded - w_min) / (w_max - w_min)
    return norm_min + unit * (norm_max - norm_min)
|
|
|
|
|
|
|
|
def denormalize_tensor(normalized_tensor: torch.Tensor, w_min: torch.Tensor, w_max: torch.Tensor,
                       norm_min: float = -1.0, norm_max: float = 1.0) -> torch.Tensor:
    """
    Map a tensor from [norm_min, norm_max] back to [w_min, w_max].

    Inverse of :func:`normalize_tensor`.

    Args:
        normalized_tensor: Normalized input tensor
        w_min: Target minimum bound
        w_max: Target maximum bound
        norm_min: Source minimum (default -1.0)
        norm_max: Source maximum (default 1.0)

    Returns:
        Denormalized tensor in [w_min, w_max]
    """
    # Clip first so out-of-range inputs land exactly on the target bounds.
    bounded = torch.clamp(normalized_tensor, norm_min, norm_max)
    # Map to the unit interval, then stretch/shift to the original range.
    unit = (bounded - norm_min) / (norm_max - norm_min)
    return w_min + unit * (w_max - w_min)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BSpline(torch.nn.Module):
    """
    Uniform clamped B-Spline for trajectory representation and fitting.

    Combines B-spline basis function computation with trajectory fitting and
    reconstruction. Supports position, velocity, and acceleration computation.

    Args:
        num_basis: Number of B-spline basis functions (control points for free params)
        degree: B-spline degree (3=cubic, 4=quartic, 0=piecewise constant)
        num_dof: Degrees of freedom (e.g., 7 for robot arm)
        tau: Time duration of trajectory (default 1.0, normalized time)
        init_cond_order: Order of initial conditions (0=none, 1=pos, 2=pos+vel)
        end_cond_order: Order of end conditions (0=none, 1=pos, 2=pos+vel)
        dtype: Torch data type (default float32)
        device: Torch device ('cuda' or 'cpu')

    Example:
        >>> bspline = BSpline(num_basis=10, degree=4, num_dof=7, device='cuda')
        >>> result = bspline.learn_mp_params_from_trajs(times, trajectories)
        >>> reconstructed = bspline.get_traj_pos(times, result['params'])
    """

    def __init__(self, num_basis: int = 10, degree: int = 3, num_dof: int = 1,
                 tau: float = 1.0, init_cond_order: int = 0, end_cond_order: int = 0,
                 dtype: torch.dtype = torch.float32, device: str = 'cpu'):
        super().__init__()

        self.num_basis = num_basis
        self.degree = degree
        self.num_dof = num_dof
        self.init_cond_order = init_cond_order
        self.end_cond_order = end_cond_order
        self._dtype = dtype
        self._device = device

        # Boundary-condition control points are fixed (not learned) and are
        # prepended/appended to the free parameters.
        self.num_ctrlp = num_basis + init_cond_order + end_cond_order

        # Clamped uniform knot vector: first/last knot value repeated `degree`
        # extra times so the curve interpolates its first/last control point.
        num_knots = self.degree + 1 + self.num_ctrlp
        num_internal = num_knots - 2 * self.degree
        knots = torch.linspace(0, 1, num_internal, dtype=dtype, device=device)
        knots = torch.cat([
            torch.zeros(self.degree, dtype=dtype, device=device),
            knots,
            torch.ones(self.degree, dtype=dtype, device=device)
        ])
        self.register_buffer("knots", knots, persistent=False)
        self.register_buffer("tau", torch.tensor(tau, dtype=dtype, device=device), persistent=False)

        # Mutable per-call state (set via set_times / set_params / set_*_conditions).
        self.times = None           # evaluation time grid [*batch, num_times]
        self.params = None          # free parameters [*batch, num_params]
        self.init_pos = None
        self.init_vel = None
        self.end_pos = None
        self.end_vel = None
        self.params_init = None     # boundary control points from init conditions
        self.params_end = None      # boundary control points from end conditions
        self._pos_cache = None
        self._vel_cache = None
        self.add_dim = []           # batch dims of the current params

    @property
    def device(self):
        return self.knots.device

    @property
    def dtype(self):
        return self.knots.dtype

    @property
    def num_params(self) -> int:
        """Total number of learnable parameters."""
        return self.num_basis * self.num_dof

    def _clear_cache(self):
        """Clear cached computation results."""
        self._pos_cache = None
        self._vel_cache = None

    def _time_to_phase(self, times: torch.Tensor) -> torch.Tensor:
        """Convert times to normalized phase [0, 1].

        NOTE: assumes the flattened time grid ends at the trajectory duration;
        that last value is taken as tau.
        """
        tau = times.reshape(-1)[-1]
        self.tau.copy_(tau)
        return torch.clip(times / self.tau[..., None], 0, 1)

    def _basis_function(self, i: int, k: int, knots: torch.Tensor,
                        u: torch.Tensor, num_ctrlp: int = None) -> torch.Tensor:
        """
        Compute B-spline basis using de Boor's recursive algorithm.

        Args:
            i: Basis function index
            k: Current degree level
            knots: Knot vector
            u: Evaluation points (phase values)
            num_ctrlp: Number of control points (for boundary handling)

        Returns:
            Basis function values at evaluation points
        """
        if num_ctrlp is None:
            num_ctrlp = self.num_ctrlp

        if k == 0:
            if i == num_ctrlp - 1:
                # Last basis: closed interval so u == 1 is included.
                return torch.where((u >= knots[i]) & (u <= knots[i + 1]),
                                   1.0, 0.0).to(dtype=self.dtype, device=self.device)
            else:
                # Half-open interval avoids double-counting at shared knots.
                return torch.where((u >= knots[i]) & (u < knots[i + 1]),
                                   1.0, 0.0).to(dtype=self.dtype, device=self.device)
        else:
            # Cox-de Boor recursion; a zero denominator means a repeated knot
            # and the corresponding term is defined as 0.
            denom1 = knots[i + k] - knots[i]
            term1 = 0.0 if denom1 == 0 else (u - knots[i]) / denom1 * \
                self._basis_function(i, k - 1, knots, u, num_ctrlp)

            denom2 = knots[i + k + 1] - knots[i + 1]
            term2 = 0.0 if denom2 == 0 else (knots[i + k + 1] - u) / denom2 * \
                self._basis_function(i + 1, k - 1, knots, u, num_ctrlp)

            return term1 + term2

    def basis(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute B-spline basis values at given time points.

        Args:
            times: Time points tensor of shape [*batch, num_times]

        Returns:
            Basis values of shape [*batch, num_times, num_ctrlp]
        """
        phase = self._time_to_phase(times)
        basis = [self._basis_function(i, self.degree, self.knots, phase)
                 for i in range(self.num_ctrlp)]
        return torch.stack(basis, dim=-1)

    def vel_basis(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute velocity B-spline basis (derivative of position basis).

        The derivative of a degree-p spline is a degree-(p-1) spline over the
        knot vector with the first and last knot dropped.

        Args:
            times: Time points tensor

        Returns:
            Velocity basis values of shape [*batch, num_times, num_ctrlp-1]
        """
        phase = self._time_to_phase(times)
        vel_knots = self.knots[1:-1]
        basis = [self._basis_function(i, self.degree - 1, vel_knots, phase,
                                      num_ctrlp=self.num_ctrlp - 1)
                 for i in range(self.num_ctrlp - 1)]
        return torch.stack(basis, dim=-1)

    def acc_basis(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute acceleration B-spline basis (second derivative).

        Args:
            times: Time points tensor

        Returns:
            Acceleration basis values of shape [*batch, num_times, num_ctrlp-2]
        """
        phase = self._time_to_phase(times)
        acc_knots = self.knots[2:-2]
        basis = [self._basis_function(i, self.degree - 2, acc_knots, phase,
                                      num_ctrlp=self.num_ctrlp - 2)
                 for i in range(self.num_ctrlp - 2)]
        return torch.stack(basis, dim=-1)

    def velocity_control_points(self, ctrl_pts: torch.Tensor) -> torch.Tensor:
        """
        Compute velocity control points from position control points.

        Uses the standard derivative formula
        Q_i = degree * (P_{i+1} - P_i) / (t_{i+degree+1} - t_{i+1}).

        Args:
            ctrl_pts: Position control points [*batch, num_dof, num_ctrlp]

        Returns:
            Velocity control points [*batch, num_dof, num_ctrlp-1]
        """
        diff = ctrl_pts[..., 1:] - ctrl_pts[..., :-1]
        delta = self.knots[1 + self.degree:self.num_ctrlp + self.degree] - \
            self.knots[1:self.num_ctrlp]
        return diff * (self.degree / delta)

    def _compute_init_params(self, init_pos: torch.Tensor,
                             init_vel: torch.Tensor = None) -> Optional[torch.Tensor]:
        """Compute initial boundary condition control points.

        For a clamped spline P_0 = init_pos; for order 2, P_1 is placed so the
        start velocity equals init_vel:
        P_1 = P_0 + v * tau * (t_{1+degree} - t_1) / degree.
        """
        if self.init_cond_order == 0:
            return None

        params = init_pos[..., None]
        if self.init_cond_order == 2 and init_vel is not None:
            p1 = init_vel * self.tau * (self.knots[1 + self.degree] - self.knots[1]) / self.degree + init_pos
            params = torch.cat([params, p1[..., None]], dim=-1)
        return params

    def _compute_end_params(self, end_pos: torch.Tensor,
                            end_vel: torch.Tensor = None) -> Optional[torch.Tensor]:
        """Compute end boundary condition control points.

        Mirror image of _compute_init_params: P_n = end_pos and, for order 2,
        P_{n-1} = P_n - v * tau * (t_{n+degree} - t_n) / degree.
        """
        if self.end_cond_order == 0:
            return None

        params = end_pos[..., None]
        if self.end_cond_order == 2 and end_vel is not None:
            # Bug fix: the knot-span term must be divided by the degree (as in
            # _compute_init_params), not multiplied by it. The old code scaled
            # the requested end velocity by degree**2.
            pn = end_pos - end_vel * self.tau * \
                (self.knots[self.num_ctrlp - 1 + self.degree] - self.knots[self.num_ctrlp - 1]) / self.degree
            params = torch.cat([pn[..., None], params], dim=-1)
        return params

    def set_times(self, times: torch.Tensor):
        """
        Set evaluation time points.

        Args:
            times: Time points [*batch, num_times]
        """
        self.times = torch.as_tensor(times, dtype=self.dtype, device=self.device)
        # Read tau from the converted tensor so list/ndarray inputs also work.
        tau = self.times.reshape(-1)[-1]
        self.tau.copy_(tau)
        self._clear_cache()

    def set_params(self, params: torch.Tensor) -> torch.Tensor:
        """
        Set B-spline parameters (control point weights).

        Args:
            params: Parameters [*batch, num_params]

        Returns:
            Any unused parameters (for chaining)
        """
        params = torch.as_tensor(params, dtype=self.dtype, device=self.device)
        assert params.shape[-1] == self.num_params
        self.add_dim = list(params.shape[:-1])
        self.params = params[..., :self.num_params]
        self._clear_cache()
        return params[..., self.num_params:]

    def set_duration(self, duration: float, dt: float):
        """
        Set trajectory duration and generate time grid.

        Args:
            duration: Total trajectory duration
            dt: Time step (control frequency)
        """
        times = torch.linspace(0, duration, int(round(duration / dt)) + 1,
                               dtype=self.dtype, device=self.device)

        # Broadcast the 1-D grid over the current batch dims.
        for _ in self.add_dim:
            times = times.unsqueeze(0)
        times = times.expand(*self.add_dim, -1)
        self.set_times(times)

    def set_initial_conditions(self, init_pos: torch.Tensor, init_vel: torch.Tensor = None):
        """
        Set initial position and velocity conditions.

        Args:
            init_pos: Initial position [*batch, num_dof]
            init_vel: Initial velocity [*batch, num_dof] (optional)
        """
        self.init_pos = torch.as_tensor(init_pos, dtype=self.dtype, device=self.device)
        self.init_vel = torch.as_tensor(init_vel, dtype=self.dtype, device=self.device) if init_vel is not None else None
        self.params_init = self._compute_init_params(self.init_pos, self.init_vel)
        self._clear_cache()

    def set_end_conditions(self, end_pos: torch.Tensor, end_vel: torch.Tensor = None):
        """
        Set end position and velocity conditions.

        Args:
            end_pos: End position [*batch, num_dof]
            end_vel: End velocity [*batch, num_dof] (optional)
        """
        self.end_pos = torch.as_tensor(end_pos, dtype=self.dtype, device=self.device) if end_pos is not None else None
        self.end_vel = torch.as_tensor(end_vel, dtype=self.dtype, device=self.device) if end_vel is not None else None
        self.params_end = self._compute_end_params(self.end_pos, self.end_vel)
        self._clear_cache()

    def update_inputs(self, times: torch.Tensor = None, params: torch.Tensor = None,
                      init_pos: torch.Tensor = None, init_vel: torch.Tensor = None, **kwargs):
        """
        Update multiple inputs at once; None arguments leave state unchanged.

        Args:
            times: Time points
            params: B-spline parameters
            init_pos: Initial position
            init_vel: Initial velocity
            **kwargs: Additional args (end_pos, end_vel)
        """
        if params is not None:
            self.set_params(params)
        if times is not None:
            self.set_times(times)
        if init_pos is not None:
            self.set_initial_conditions(init_pos, init_vel)
        if kwargs.get('end_pos') is not None or kwargs.get('end_vel') is not None:
            self.set_end_conditions(kwargs.get('end_pos'), kwargs.get('end_vel'))

    def _get_full_params(self) -> torch.Tensor:
        """Get full control points including boundary conditions.

        Returns [*batch, num_dof, num_ctrlp]: fixed boundary control points
        surround the free parameters.
        """
        params = self.params.reshape(*self.add_dim, self.num_dof, -1)
        if self.params_init is not None:
            params = torch.cat([self.params_init, params], dim=-1)
        if self.params_end is not None:
            params = torch.cat([params, self.params_end], dim=-1)
        return params

    def get_traj_pos(self, times: torch.Tensor = None, params: torch.Tensor = None,
                     init_pos: torch.Tensor = None, init_vel: torch.Tensor = None,
                     flat_shape: bool = False, **kwargs) -> torch.Tensor:
        """
        Compute trajectory positions from B-spline parameters.

        Args:
            times: Time points [*batch, num_times]
            params: B-spline parameters [*batch, num_params]
            init_pos: Initial position (optional)
            init_vel: Initial velocity (optional)
            flat_shape: If True, return flattened [*batch, num_dof*num_times]

        Returns:
            Position trajectory [*batch, num_times, num_dof] or flattened
        """
        self.update_inputs(times, params, init_pos, init_vel, **kwargs)

        if self._pos_cache is not None:
            pos = self._pos_cache
        else:
            assert self.params is not None
            full_params = self._get_full_params()
            basis = self.basis(self.times)
            # [*b, T, C] x [*b, D, C] -> [*b, T, D]
            pos = torch.einsum('...ik,...jk->...ij', basis, full_params)
            self._pos_cache = pos

        if flat_shape:
            pos = torch.einsum('...ji->...ij', pos).reshape(*self.add_dim, -1)
        return pos

    def get_traj_vel(self, times: torch.Tensor = None, params: torch.Tensor = None,
                     init_pos: torch.Tensor = None, init_vel: torch.Tensor = None,
                     flat_shape: bool = False, **kwargs) -> torch.Tensor:
        """
        Compute trajectory velocities from B-spline parameters.

        Args:
            times: Time points [*batch, num_times]
            params: B-spline parameters [*batch, num_params]
            init_pos: Initial position (optional)
            init_vel: Initial velocity (optional)
            flat_shape: If True, return flattened [*batch, num_dof*num_times]

        Returns:
            Velocity trajectory [*batch, num_times, num_dof] or flattened
        """
        self.update_inputs(times, params, init_pos, init_vel, **kwargs)

        if self._vel_cache is not None:
            vel = self._vel_cache
        else:
            assert self.params is not None
            full_params = self._get_full_params()
            # Divide by tau: derivative is taken w.r.t. phase, not time.
            vel_ctrlp = self.velocity_control_points(full_params) / self.tau
            vel_basis = self.vel_basis(self.times)
            vel = torch.einsum('...ik,...jk->...ij', vel_basis, vel_ctrlp)
            self._vel_cache = vel

        if flat_shape:
            vel = torch.einsum('...ji->...ij', vel).reshape(*self.add_dim, -1)
        return vel

    def _basis_multi_dofs(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute multi-DOF basis matrix for least squares fitting.

        Only the free (non-boundary) basis columns are included; DOFs are laid
        out block-diagonally so one linear solve covers all of them.

        Args:
            times: Time points [*batch, num_times]

        Returns:
            Block-diagonal basis [*batch, num_dof*num_times, num_dof*num_basis]
        """
        add_dim = list(times.shape[:-1])
        num_times = times.shape[-1]
        basis_single = self.basis(times)[..., self.init_cond_order:self.num_ctrlp - self.end_cond_order]

        basis_multi = torch.zeros(*add_dim, self.num_dof * num_times, self.num_dof * self.num_basis,
                                  dtype=self.dtype, device=self.device)
        for i in range(self.num_dof):
            row_slice = slice(i * num_times, (i + 1) * num_times)
            col_slice = slice(i * self.num_basis, (i + 1) * self.num_basis)
            basis_multi[..., row_slice, col_slice] = basis_single

        return basis_multi

    def learn_mp_params_from_trajs(self, times: torch.Tensor, trajs: torch.Tensor,
                                   reg: float = 1e-4, **kwargs) -> dict:
        """
        Learn B-spline parameters from trajectory data via regularized least squares.

        Args:
            times: Time points [*batch, num_times]
            trajs: Trajectory data [*batch, num_times, num_dof]
            reg: Regularization coefficient (default 1e-4)
            **kwargs: Optional init_pos, init_vel, end_pos, end_vel

        Returns:
            Dict with 'params' and boundary conditions
        """
        assert trajs.shape[:-1] == times.shape
        assert trajs.shape[-1] == self.num_dof

        times = torch.as_tensor(times, dtype=self.dtype, device=self.device)
        trajs = torch.as_tensor(trajs, dtype=self.dtype, device=self.device)

        self.add_dim = list(trajs.shape[:-2])
        self.set_times(times)

        # Zero placeholder for the free parameters; boundary control points
        # get concatenated around it to evaluate the boundary contribution.
        dummy_params = torch.zeros(*self.add_dim, self.num_dof, self.num_basis,
                                   device=self.device, dtype=self.dtype)

        # Initial conditions default to the trajectory's first sample and a
        # forward finite-difference velocity estimate.
        if self.init_cond_order != 0:
            init_pos = kwargs.get("init_pos", trajs[..., 0, :])
            dt = times[..., 1] - times[..., 0]
            init_vel = kwargs.get("init_vel", torch.diff(trajs, dim=-2)[..., 0, :] / dt[..., None])
            self.set_initial_conditions(init_pos, init_vel)
            if self.params_init is not None:
                dummy_params = torch.cat([self.params_init, dummy_params], dim=-1)

        # End conditions default to the last sample / backward difference.
        if self.end_cond_order != 0:
            end_pos = kwargs.get("end_pos", trajs[..., -1, :])
            dt = times[..., 1] - times[..., 0]
            end_vel = kwargs.get("end_vel", torch.diff(trajs, dim=-2)[..., -1, :] / dt[..., None])
            self.set_end_conditions(end_pos, end_vel)
            if self.params_end is not None:
                dummy_params = torch.cat([dummy_params, self.params_end], dim=-1)

        # Positions produced by the boundary control points alone; subtracted
        # from the targets so only the free parameters are solved for.
        basis_single = self.basis(times)
        pos_boundary = torch.einsum('...ik,...jk->...ij', basis_single, dummy_params)
        pos_boundary = torch.einsum('...ij->...ji', pos_boundary).reshape(*self.add_dim, -1)

        # Normal equations A x = B with Tikhonov regularization.
        basis_multi = self._basis_multi_dofs(self.times)
        A = torch.einsum('...ki,...kj->...ij', basis_multi, basis_multi)
        A += torch.eye(self.num_params, dtype=self.dtype, device=self.device) * reg

        trajs_flat = torch.einsum("...ij->...ji", trajs).reshape(*self.add_dim, -1)
        pos_residual = trajs_flat - pos_boundary
        B = torch.einsum('...ki,...k->...i', basis_multi, pos_residual)

        params = torch.linalg.solve(A, B)
        self.set_params(params)

        return {
            "params": params,
            "init_pos": self.init_pos,
            "init_vel": self.init_vel,
            "end_pos": self.end_pos,
            "end_vel": self.end_vel,
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BeastTokenizer(torch.nn.Module, ProcessorMixin):
    """
    B-spline based tokenizer for trajectory encoding/decoding.

    Converts continuous robot trajectories to discrete tokens and vice versa
    using B-splines. Supports separate handling for continuous actions (joints)
    and discrete states (e.g., binary gripper).

    Args:
        num_dof: Total degrees of freedom (joints + gripper)
        num_basis: Number of B-spline basis functions
        seq_len: Trajectory sequence length
        vocab_size: Discrete token vocabulary size (default 256)
        degree_p: B-spline degree (default 4 = quartic)
        gripper_zero_order: Use zero-order splines for gripper (piecewise constant)
        gripper_dof: Number of gripper DOFs (only used if gripper_zero_order=True)
        init_cond_order: Initial condition order (0=none, 1=pos, 2=pos+vel)
        end_cond_order: End condition order
        enforce_init_pos: Enforce initial position constraint in decoding
        device: Torch device ('cuda' or 'cpu')

    Example:
        >>> tokenizer = BeastTokenizer(num_dof=7, num_basis=10, seq_len=50)
        >>> tokens = tokenizer.encode_discrete(trajectories)
        >>> reconstructed = tokenizer.decode_discrete(tokens)
    """

    # Default control time step used when building duration-based grids.
    DEFAULT_DT = 0.01
    # Required by ProcessorMixin; this processor has no sub-processors.
    attributes: ClassVar[list[str]] = []

    def __init__(self, num_dof: int = 1, num_basis: int = 10, seq_len: int = 50,
                 vocab_size: int = 256, degree_p: int = 4, gripper_zero_order: bool = False,
                 gripper_dof: int = 1, init_cond_order: int = 0, end_cond_order: int = 0,
                 enforce_init_pos: bool = False, device: str = "cuda"):
        torch.nn.Module.__init__(self)
        ProcessorMixin.__init__(self)

        self.device = device
        self.seq_length = seq_len
        self.vocab_size = vocab_size
        self.num_basis = num_basis
        self.enforce_init_pos = enforce_init_pos
        self.init_cond_order = init_cond_order
        self.end_cond_order = end_cond_order
        self.dt = self.DEFAULT_DT
        self.init_pos = None

        # Split DOFs: gripper channels get zero-order (piecewise-constant)
        # splines when requested; everything else is a smooth joint spline.
        self.gripper_dof = gripper_dof if gripper_zero_order else 0
        self.joint_dof = num_dof - self.gripper_dof
        self.num_dof = self.joint_dof + self.gripper_dof

        self.bsp = BSpline(
            num_basis=num_basis, degree=degree_p, num_dof=self.joint_dof,
            init_cond_order=init_cond_order, end_cond_order=end_cond_order,
            device=device
        )
        self.gripper_bsp = BSpline(
            num_basis=num_basis, degree=0, num_dof=self.gripper_dof, device=device
        ) if gripper_zero_order else None

        # Normalized time grid over [0, 1] with seq_len samples.
        self.times = torch.linspace(0, 1.0, seq_len, device=device)
        self._initialize_weight_bounds()
        self.to(self.device)

    def _initialize_weight_bounds(self):
        """Initialize weight bounds for normalization.

        Starts with a small symmetric range; update_weights_bounds_per_batch
        widens it as real data is observed.
        """
        total_params = self.num_dof * self.num_basis
        self.register_buffer("w_min", -0.02 * torch.ones(total_params))
        self.register_buffer("w_max", 0.02 * torch.ones(total_params))

    def _get_repeated_times(self, batch_size: int) -> torch.Tensor:
        """Repeat time grid for batch processing."""
        return einops.repeat(self.times, 't -> b t', b=batch_size)

    @autocast_float32
    def _learn_trajectory_params(self, times: torch.Tensor, trajs: torch.Tensor) -> dict:
        """Learn B-spline parameters from trajectories.

        Joint and gripper channels are fitted separately; their parameters are
        concatenated as [joint_params, gripper_params].
        """
        joint_params = self.bsp.learn_mp_params_from_trajs(times, trajs[..., :self.joint_dof])

        if self.gripper_bsp is not None:
            gripper_params = self.gripper_bsp.learn_mp_params_from_trajs(
                times, trajs[..., -self.gripper_dof:]
            )
            joint_params['params'] = torch.cat(
                [joint_params['params'], gripper_params['params']], dim=-1
            )

        return joint_params

    @autocast_float32
    def _reconstruct_trajectory(self, params: torch.Tensor, times: torch.Tensor) -> torch.Tensor:
        """Reconstruct trajectory from B-spline parameters.

        Inverse of _learn_trajectory_params' parameter layout.
        """
        joint_params = params[..., :self.joint_dof * self.num_basis]
        self.bsp.update_inputs(times=times, params=joint_params)
        position = self.bsp.get_traj_pos()

        if self.gripper_bsp is not None:
            gripper_params = params[..., -self.gripper_dof * self.num_basis:]
            self.gripper_bsp.update_inputs(times=times, params=gripper_params)
            position = torch.cat([position, self.gripper_bsp.get_traj_pos()], dim=-1)

        return position

    def _apply_initial_position_constraint(self, params: torch.Tensor,
                                           init_pos: torch.Tensor) -> torch.Tensor:
        """Overwrite the first basis coefficient of each joint DOF with init_pos.

        No-op unless the tokenizer was constructed with enforce_init_pos=True
        and a constraint position is supplied.
        """
        # Bug fix: this used to test `self.init_pos` (initialized to None and
        # never assigned), so the enforce_init_pos option silently did nothing.
        if not self.enforce_init_pos or init_pos is None:
            return params

        reshaped = einops.rearrange(params, "b (d t) -> b t d", t=self.num_basis, d=self.num_dof)
        reshaped[:, 0, :self.joint_dof] = init_pos[:, :self.joint_dof]
        return einops.rearrange(reshaped, "b t d -> b (d t)")

    @autocast_float32
    def compute_weights(self, demos: torch.Tensor) -> torch.Tensor:
        """
        Compute B-spline weights from demonstration trajectories.

        Args:
            demos: Demonstration trajectories [batch, seq_len, num_dof]

        Returns:
            B-spline weights [batch, num_params]
        """
        times = self._get_repeated_times(demos.shape[0])
        return self.bsp.learn_mp_params_from_trajs(times, demos)['params']

    def update_weights_bounds_per_batch(self, weights: torch.Tensor):
        """
        Update weight bounds based on batch statistics.

        Bounds only ever widen (per-element min/max), so previously encoded
        tokens stay within range.

        Args:
            weights: Weights to analyze for bounds update
        """
        weights = weights.reshape(-1, self.num_dof * self.num_basis)
        batch_min = weights.min(dim=0)[0]
        batch_max = weights.max(dim=0)[0]

        tolerance = 1e-4
        smaller_mask = batch_min < (self.w_min - tolerance)
        larger_mask = batch_max > (self.w_max + tolerance)

        if torch.any(smaller_mask):
            self.w_min[smaller_mask] = batch_min[smaller_mask]
        if torch.any(larger_mask):
            self.w_max[larger_mask] = batch_max[larger_mask]

    def update_times(self, times: torch.Tensor):
        """Update the time grid."""
        self.times = times

    @torch.no_grad()
    @autocast_float32
    def encode_discrete(self, trajs: torch.Tensor, update_bounds: bool = True) -> torch.Tensor:
        """
        Encode trajectories to discrete tokens.

        Args:
            trajs: Input trajectories [batch, seq_len, num_dof]
            update_bounds: Update weight bounds from this batch

        Returns:
            Discrete tokens [batch, num_basis * num_dof] in range [0, vocab_size-1]
        """
        times = self._get_repeated_times(trajs.shape[0])
        params_dict = self._learn_trajectory_params(times, trajs)

        if update_bounds:
            self.update_weights_bounds_per_batch(params_dict['params'])

        params = torch.clamp(params_dict['params'], min=self.w_min, max=self.w_max)
        tokens = continuous_to_discrete(params, self.w_min, self.w_max, self.vocab_size)
        # Reorder to basis-major layout so tokens for one time-step group together.
        return einops.rearrange(tokens, 'b (d t) -> b (t d)', t=self.num_basis, d=self.num_dof)

    @torch.no_grad()
    @autocast_float32
    def decode_discrete(self, tokens: torch.Tensor, times: torch.Tensor = None,
                        init_pos: torch.Tensor = None) -> torch.Tensor:
        """
        Decode discrete tokens to trajectories.

        Args:
            tokens: Discrete tokens [batch, num_basis * num_dof]
            times: Custom time points (optional)
            init_pos: Initial position constraint (optional)

        Returns:
            Reconstructed trajectories [batch, seq_len, num_dof]
        """
        # Undo the basis-major reordering applied in encode_discrete.
        tokens = einops.rearrange(tokens, 'b (t d) -> b (d t)', t=self.num_basis, d=self.num_dof)
        params = discrete_to_continuous(tokens, self.w_min, self.w_max, self.vocab_size)

        if times is None:
            times = self._get_repeated_times(params.shape[0])

        params = self._apply_initial_position_constraint(params, init_pos)
        return self._reconstruct_trajectory(params, times)

    @torch.no_grad()
    @autocast_float32
    def encode_continuous(self, trajs: torch.Tensor, update_bounds: bool = True) -> torch.Tensor:
        """
        Encode trajectories to continuous normalized parameters.

        Args:
            trajs: Input trajectories [batch, seq_len, num_dof]
            update_bounds: Update weight bounds from this batch

        Returns:
            Normalized parameters [batch, num_params] in range [-1, 1]
        """
        times = self._get_repeated_times(trajs.shape[0])
        params_dict = self._learn_trajectory_params(times, trajs)

        if update_bounds:
            self.update_weights_bounds_per_batch(params_dict['params'])

        return normalize_tensor(params_dict['params'], self.w_min, self.w_max)

    @torch.no_grad()
    @autocast_float32
    def decode_continuous(self, params: torch.Tensor, times: torch.Tensor = None,
                          init_pos: torch.Tensor = None) -> torch.Tensor:
        """
        Decode continuous normalized parameters to trajectories.

        Args:
            params: Normalized parameters [batch, num_params] in range [-1, 1]
            times: Custom time points (optional)
            init_pos: Initial position constraint (optional)

        Returns:
            Reconstructed trajectories [batch, seq_len, num_dof]
        """
        params = denormalize_tensor(params, self.w_min, self.w_max)

        if times is None:
            times = self._get_repeated_times(params.shape[0])

        params = self._apply_initial_position_constraint(params, init_pos)
        return self._reconstruct_trajectory(params, times)

    @autocast_float32
    def compute_reconstruction_error(self, raw_traj: torch.Tensor) -> torch.Tensor:
        """
        Compute mean squared reconstruction error of the discrete round-trip.

        Args:
            raw_traj: Original trajectory; 2-D input is treated as
                [batch, seq_len] for a single DOF and expanded to 3-D

        Returns:
            MSE between original and reconstructed trajectory
        """
        if len(raw_traj.shape) == 2:
            raw_traj = raw_traj.unsqueeze(-1)

        tokens = self.encode_discrete(raw_traj)
        reconstructed = self.decode_discrete(tokens)
        return torch.mean((raw_traj - reconstructed) ** 2)

    def _plot_trajectory_comparison(self, original: torch.Tensor, reconstructed: torch.Tensor,
                                    title_prefix: str = ""):
        """Plot comparison between original and reconstructed trajectories."""
        original = original.detach().cpu().numpy()
        reconstructed = reconstructed.detach().cpu().numpy()
        x_vals = np.linspace(0, 1.0, original.shape[1])

        batch_size, _, dof = original.shape

        # One figure per batch sample, one subplot per DOF.
        for sample_idx in range(batch_size):
            _, axes = plt.subplots(dof, 1, figsize=(8, 2 * dof), sharex=True)
            if dof == 1:
                axes = [axes]  # plt.subplots returns a bare Axes for dof == 1

            for i in range(dof):
                axes[i].plot(x_vals, reconstructed[sample_idx, :, i],
                             marker='o', label='Reconstructed', linestyle='-', color='b')
                axes[i].plot(x_vals, original[sample_idx, :, i],
                             marker='*', label='Ground Truth', linestyle='--', color='r')
                axes[i].set_ylabel(f"DOF {i + 1}")
                axes[i].grid(True)
                axes[i].legend(loc="best")

            axes[-1].set_xlabel("Time (s)")
            plt.suptitle(f"{title_prefix}Trajectory Comparison - Sample {sample_idx}")
            plt.tight_layout(rect=[0, 0, 1, 0.96])
            plt.show()

    def visualize_reconstruction_error_discrete(self, raw_traj: torch.Tensor):
        """Visualize discrete encoding reconstruction error."""
        tokens = self.encode_discrete(raw_traj, update_bounds=True)
        reconstructed = self.decode_discrete(tokens)
        self._plot_trajectory_comparison(raw_traj, reconstructed, "Discrete ")

    def visualize_reconstruction_error_continuous(self, raw_traj: torch.Tensor):
        """Visualize continuous encoding reconstruction error."""
        raw_traj = raw_traj.to(torch.float32)
        if len(raw_traj.shape) == 2:
            raw_traj = raw_traj.unsqueeze(0)

        continuous_tokens = self.encode_continuous(raw_traj, update_bounds=True)
        reconstructed = self.decode_continuous(continuous_tokens)
        self._plot_trajectory_comparison(raw_traj, reconstructed, "Continuous ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build a default 7-DOF tokenizer and publish it to the Hugging Face Hub.
    tokenizer = BeastTokenizer(num_dof=7, vocab_size=256)
    tokenizer.push_to_hub("zhouhongyi/beast", use_auth_token=True)
    print("Processor pushed to the hub.")
|
|
|