"""
BEAST: B-Spline Encoded Action Sequences Tokenizer

A tokenizer for encoding/decoding robot trajectories using B-splines.
Converts continuous trajectories to discrete tokens and vice versa.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import einops
from typing import Optional, ClassVar
from functools import wraps
from transformers.processing_utils import ProcessorMixin


def autocast_float32(fn):
    """Decorator that forces the wrapped function to run under float32 autocast.

    Keeps the B-spline least-squares fitting numerically stable even when the
    caller is inside a lower-precision (fp16/bf16) autocast region.
    """
    @wraps(fn)
    def wrapped(*args, **kwargs):
        # torch.amp.autocast is the current API; older torch releases only
        # expose the deprecated torch.cuda.amp.autocast variant.
        if hasattr(torch.amp, 'autocast'):
            with torch.amp.autocast('cuda', dtype=torch.float32):
                return fn(*args, **kwargs)
        else:
            with torch.cuda.amp.autocast(dtype=torch.float32):
                return fn(*args, **kwargs)
    return wrapped


# =============================================================================
# Utility Functions
# =============================================================================

def continuous_to_discrete(tensor: torch.Tensor,
                           min_val: torch.Tensor = None,
                           max_val: torch.Tensor = None,
                           num_bins: int = 256) -> torch.Tensor:
    """
    Convert continuous tensor values to discrete tokens.

    Args:
        tensor: Input tensor with continuous values
        min_val: Minimum value for normalization (uses tensor.min() if None)
        max_val: Maximum value for normalization (uses tensor.max() if None)
        num_bins: Number of discrete bins (default 256 for 0-255 range)

    Returns:
        Discretized tensor with integer values in [0, num_bins-1]
    """
    if min_val is None:
        min_val = tensor.min()
    if max_val is None:
        max_val = tensor.max()

    assert torch.all(tensor >= min_val - 1e-3), "Input tensor has values below min_val"
    assert torch.all(tensor <= max_val + 1e-3), "Input tensor has values above max_val"

    # Guard against a degenerate (constant) range: a zero-width span would
    # yield 0/0 = NaN, and NaN -> long is undefined. Map such entries to bin 0.
    span = max_val - min_val
    if torch.is_tensor(span):
        span = torch.where(span == 0, torch.ones_like(span), span)
    elif span == 0:
        span = 1.0

    normalized = (tensor - min_val) / span
    normalized = torch.clamp(normalized, 0, 1)
    discrete = torch.round(normalized * (num_bins - 1)).to(torch.long)
    return discrete


def discrete_to_continuous(discrete_tensor: torch.Tensor,
                           min_val: torch.Tensor = 0,
                           max_val: torch.Tensor = 1,
                           num_bins: int = 256) -> torch.Tensor:
    """
    Convert discrete tokens back to continuous values.

    Args:
        discrete_tensor: Input tensor with discrete values in [0, num_bins-1]
        min_val: Minimum value of target continuous range
        max_val: Maximum value of target continuous range
        num_bins: Number of discrete bins

    Returns:
        Continuous tensor with values in [min_val, max_val]
    """
    normalized = discrete_tensor.float() / (num_bins - 1)
    continuous = normalized * (max_val - min_val) + min_val
    return torch.clamp(continuous, min_val, max_val)


def normalize_tensor(tensor: torch.Tensor,
                     w_min: torch.Tensor,
                     w_max: torch.Tensor,
                     norm_min: float = -1.0,
                     norm_max: float = 1.0) -> torch.Tensor:
    """
    Normalize tensor from [w_min, w_max] to [norm_min, norm_max].

    Args:
        tensor: Input tensor to normalize
        w_min: Minimum bound of original range
        w_max: Maximum bound of original range
        norm_min: Target minimum (default -1.0)
        norm_max: Target maximum (default 1.0)

    Returns:
        Normalized tensor in [norm_min, norm_max]
    """
    clipped = torch.clamp(tensor, w_min, w_max)
    normalized = (clipped - w_min) / (w_max - w_min)
    return normalized * (norm_max - norm_min) + norm_min


def denormalize_tensor(normalized_tensor: torch.Tensor,
                       w_min: torch.Tensor,
                       w_max: torch.Tensor,
                       norm_min: float = -1.0,
                       norm_max: float = 1.0) -> torch.Tensor:
    """
    Denormalize tensor from [norm_min, norm_max] back to [w_min, w_max].

    Args:
        normalized_tensor: Normalized input tensor
        w_min: Target minimum bound
        w_max: Target maximum bound
        norm_min: Source minimum (default -1.0)
        norm_max: Source maximum (default 1.0)

    Returns:
        Denormalized tensor in [w_min, w_max]
    """
    clipped = torch.clamp(normalized_tensor, norm_min, norm_max)
    denormalized = (clipped - norm_min) / (norm_max - norm_min)
    return denormalized * (w_max - w_min) + w_min
# =============================================================================
# BSpline Class (Merged from UniBSplineBasis + UniformBSpline)
# =============================================================================

class BSpline(torch.nn.Module):
    """
    Uniform B-Spline for trajectory representation and fitting.

    Combines B-spline basis function computation with trajectory fitting and
    reconstruction. Supports position, velocity, and acceleration computation.

    Args:
        num_basis: Number of B-spline basis functions (control points for free params)
        degree: B-spline degree (3=cubic, 4=quartic, 0=piecewise constant)
        num_dof: Degrees of freedom (e.g., 7 for robot arm)
        tau: Time duration of trajectory (default 1.0, normalized time)
        init_cond_order: Order of initial conditions (0=none, 1=pos, 2=pos+vel)
        end_cond_order: Order of end conditions (0=none, 1=pos, 2=pos+vel)
        dtype: Torch data type (default float32)
        device: Torch device ('cuda' or 'cpu')

    Example:
        >>> bspline = BSpline(num_basis=10, degree=4, num_dof=7, device='cuda')
        >>> result = bspline.learn_mp_params_from_trajs(times, trajectories)
        >>> reconstructed = bspline.get_traj_pos(times, result['params'])
    """

    def __init__(self,
                 num_basis: int = 10,
                 degree: int = 3,
                 num_dof: int = 1,
                 tau: float = 1.0,
                 init_cond_order: int = 0,
                 end_cond_order: int = 0,
                 dtype: torch.dtype = torch.float32,
                 device: str = 'cpu'):
        super().__init__()
        self.num_basis = num_basis
        self.degree = degree
        self.num_dof = num_dof
        self.init_cond_order = init_cond_order
        self.end_cond_order = end_cond_order
        self._dtype = dtype
        self._device = device

        # Number of control points = free basis + boundary-condition points
        self.num_ctrlp = num_basis + init_cond_order + end_cond_order

        # Clamped uniform knot vector: `degree` repeated knots at each end so
        # the curve interpolates the first/last control point.
        num_knots = self.degree + 1 + self.num_ctrlp
        num_internal = num_knots - 2 * self.degree
        knots = torch.linspace(0, 1, num_internal, dtype=dtype, device=device)
        knots = torch.cat([
            torch.zeros(self.degree, dtype=dtype, device=device),
            knots,
            torch.ones(self.degree, dtype=dtype, device=device)
        ])
        self.register_buffer("knots", knots, persistent=False)
        self.register_buffer("tau", torch.tensor(tau, dtype=dtype, device=device),
                             persistent=False)

        # Runtime state (set via set_* methods)
        self.times = None
        self.params = None
        self.init_pos = None
        self.init_vel = None
        self.end_pos = None
        self.end_vel = None
        self.params_init = None
        self.params_end = None
        self._pos_cache = None
        self._vel_cache = None
        self.add_dim = []

    @property
    def device(self):
        """Device of the spline buffers."""
        return self.knots.device

    @property
    def dtype(self):
        """Dtype of the spline buffers."""
        return self.knots.dtype

    @property
    def num_params(self) -> int:
        """Total number of learnable parameters (free control points x DOFs)."""
        return self.num_basis * self.num_dof

    def _clear_cache(self):
        """Clear cached position/velocity computation results."""
        self._pos_cache = None
        self._vel_cache = None

    def _time_to_phase(self, times: torch.Tensor) -> torch.Tensor:
        """Convert times to normalized phase [0, 1].

        NOTE: updates the `tau` buffer in-place from the last time value,
        so the trajectory duration tracks the most recent time grid.
        """
        tau = times.reshape(-1)[-1]
        self.tau.copy_(tau)
        return torch.clip(times / self.tau[..., None], 0, 1)

    def _basis_function(self, i: int, k: int, knots: torch.Tensor,
                        u: torch.Tensor, num_ctrlp: int = None) -> torch.Tensor:
        """
        Compute B-spline basis using de Boor's recursive algorithm.

        Args:
            i: Basis function index
            k: Current degree level
            knots: Knot vector
            u: Evaluation points (phase values)
            num_ctrlp: Number of control points (for boundary handling)

        Returns:
            Basis function values at evaluation points
        """
        if num_ctrlp is None:
            num_ctrlp = self.num_ctrlp
        if k == 0:
            # Base case: piecewise constant. The last basis function uses a
            # closed interval so the right endpoint u == 1 is covered.
            if i == num_ctrlp - 1:
                return torch.where((u >= knots[i]) & (u <= knots[i + 1]),
                                   1.0, 0.0).to(dtype=self.dtype, device=self.device)
            return torch.where((u >= knots[i]) & (u < knots[i + 1]),
                               1.0, 0.0).to(dtype=self.dtype, device=self.device)
        # Recursive case; zero denominators (repeated knots) contribute 0.
        denom1 = knots[i + k] - knots[i]
        term1 = 0.0 if denom1 == 0 else (u - knots[i]) / denom1 * \
            self._basis_function(i, k - 1, knots, u, num_ctrlp)
        denom2 = knots[i + k + 1] - knots[i + 1]
        term2 = 0.0 if denom2 == 0 else (knots[i + k + 1] - u) / denom2 * \
            self._basis_function(i + 1, k - 1, knots, u, num_ctrlp)
        return term1 + term2

    def basis(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute B-spline basis values at given time points.

        Args:
            times: Time points tensor of shape [*batch, num_times]

        Returns:
            Basis values of shape [*batch, num_times, num_ctrlp]
        """
        phase = self._time_to_phase(times)
        basis = [self._basis_function(i, self.degree, self.knots, phase)
                 for i in range(self.num_ctrlp)]
        return torch.stack(basis, dim=-1)

    def vel_basis(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute velocity B-spline basis (derivative of position basis).

        Args:
            times: Time points tensor

        Returns:
            Velocity basis values of shape [*batch, num_times, num_ctrlp-1]
        """
        phase = self._time_to_phase(times)
        # Derivative spline: degree-1 over the knot vector with ends trimmed.
        vel_knots = self.knots[1:-1]
        basis = [self._basis_function(i, self.degree - 1, vel_knots, phase,
                                      num_ctrlp=self.num_ctrlp - 1)
                 for i in range(self.num_ctrlp - 1)]
        return torch.stack(basis, dim=-1)

    def acc_basis(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute acceleration B-spline basis (second derivative).

        Args:
            times: Time points tensor

        Returns:
            Acceleration basis values of shape [*batch, num_times, num_ctrlp-2]
        """
        phase = self._time_to_phase(times)
        acc_knots = self.knots[2:-2]
        basis = [self._basis_function(i, self.degree - 2, acc_knots, phase,
                                      num_ctrlp=self.num_ctrlp - 2)
                 for i in range(self.num_ctrlp - 2)]
        return torch.stack(basis, dim=-1)

    def velocity_control_points(self, ctrl_pts: torch.Tensor) -> torch.Tensor:
        """
        Compute velocity control points from position control points.

        Standard B-spline derivative formula:
        V_i = degree * (P_{i+1} - P_i) / (t_{i+1+degree} - t_{i+1}).

        Args:
            ctrl_pts: Position control points [*batch, num_dof, num_ctrlp]

        Returns:
            Velocity control points [*batch, num_dof, num_ctrlp-1]
        """
        diff = ctrl_pts[..., 1:] - ctrl_pts[..., :-1]
        delta = self.knots[1 + self.degree:self.num_ctrlp + self.degree] - \
            self.knots[1:self.num_ctrlp]
        return diff * (self.degree / delta)

    def _compute_init_params(self, init_pos: torch.Tensor,
                             init_vel: torch.Tensor = None) -> Optional[torch.Tensor]:
        """Compute initial boundary condition control points.

        P_0 = init_pos; for order 2, P_1 is placed so the spline's start
        velocity equals init_vel (inverse of the derivative formula).
        """
        if self.init_cond_order == 0:
            return None
        params = init_pos[..., None]
        if self.init_cond_order == 2 and init_vel is not None:
            p1 = init_vel * self.tau * (self.knots[1 + self.degree] - self.knots[1]) \
                / self.degree + init_pos
            params = torch.cat([params, p1[..., None]], dim=-1)
        return params

    def _compute_end_params(self, end_pos: torch.Tensor,
                            end_vel: torch.Tensor = None) -> Optional[torch.Tensor]:
        """Compute end boundary condition control points.

        P_{n-1} = end_pos; for order 2, P_{n-2} is placed so the spline's end
        velocity equals end_vel.
        """
        if self.end_cond_order == 0:
            return None
        params = end_pos[..., None]
        if self.end_cond_order == 2 and end_vel is not None:
            # Mirror of _compute_init_params: from
            # V = degree * (P_{n-1} - P_{n-2}) / knot_span, solve
            # P_{n-2} = P_{n-1} - end_vel * tau * knot_span / degree.
            # BUG FIX: previous code multiplied by degree instead of dividing,
            # which made the reconstructed end velocity degree^2 times too large.
            pn = end_pos - end_vel * self.tau * \
                (self.knots[self.num_ctrlp - 1 + self.degree] -
                 self.knots[self.num_ctrlp - 1]) / self.degree
            params = torch.cat([pn[..., None], params], dim=-1)
        return params

    def set_times(self, times: torch.Tensor):
        """
        Set evaluation time points.

        Args:
            times: Time points [*batch, num_times]
        """
        self.times = torch.as_tensor(times, dtype=self.dtype, device=self.device)
        tau = times.reshape(-1)[-1]
        self.tau.copy_(tau)
        self._clear_cache()

    def set_params(self, params: torch.Tensor) -> torch.Tensor:
        """
        Set B-spline parameters (control point weights).

        Args:
            params: Parameters [*batch, num_params]

        Returns:
            Any unused parameters (for chaining)
        """
        params = torch.as_tensor(params, dtype=self.dtype, device=self.device)
        assert params.shape[-1] == self.num_params
        self.add_dim = list(params.shape[:-1])
        self.params = params[..., :self.num_params]
        self._clear_cache()
        return params[..., self.num_params:]

    def set_duration(self, duration: float, dt: float):
        """
        Set trajectory duration and generate time grid.

        Args:
            duration: Total trajectory duration
            dt: Time step (control frequency)
        """
        times = torch.linspace(0, duration, int(round(duration / dt)) + 1,
                               dtype=self.dtype, device=self.device)
        # Expand to match the batch dimensions of the current parameters.
        for _ in self.add_dim:
            times = times.unsqueeze(0)
        times = times.expand(*self.add_dim, -1)
        self.set_times(times)

    def set_initial_conditions(self, init_pos: torch.Tensor,
                               init_vel: torch.Tensor = None):
        """
        Set initial position and velocity conditions.

        Args:
            init_pos: Initial position [*batch, num_dof]
            init_vel: Initial velocity [*batch, num_dof] (optional)
        """
        self.init_pos = torch.as_tensor(init_pos, dtype=self.dtype, device=self.device)
        self.init_vel = torch.as_tensor(init_vel, dtype=self.dtype, device=self.device) \
            if init_vel is not None else None
        self.params_init = self._compute_init_params(self.init_pos, self.init_vel)
        self._clear_cache()

    def set_end_conditions(self, end_pos: torch.Tensor,
                           end_vel: torch.Tensor = None):
        """
        Set end position and velocity conditions.

        Args:
            end_pos: End position [*batch, num_dof]
            end_vel: End velocity [*batch, num_dof] (optional)
        """
        self.end_pos = torch.as_tensor(end_pos, dtype=self.dtype, device=self.device) \
            if end_pos is not None else None
        self.end_vel = torch.as_tensor(end_vel, dtype=self.dtype, device=self.device) \
            if end_vel is not None else None
        self.params_end = self._compute_end_params(self.end_pos, self.end_vel)
        self._clear_cache()

    def update_inputs(self, times: torch.Tensor = None, params: torch.Tensor = None,
                      init_pos: torch.Tensor = None, init_vel: torch.Tensor = None,
                      **kwargs):
        """
        Update multiple inputs at once; None arguments are left unchanged.

        Args:
            times: Time points
            params: B-spline parameters
            init_pos: Initial position
            init_vel: Initial velocity
            **kwargs: Additional args (end_pos, end_vel)
        """
        if params is not None:
            self.set_params(params)
        if times is not None:
            self.set_times(times)
        if init_pos is not None:
            self.set_initial_conditions(init_pos, init_vel)
        if kwargs.get('end_pos') is not None or kwargs.get('end_vel') is not None:
            self.set_end_conditions(kwargs.get('end_pos'), kwargs.get('end_vel'))

    def _get_full_params(self) -> torch.Tensor:
        """Get full control points: boundary-condition points + free params."""
        params = self.params.reshape(*self.add_dim, self.num_dof, -1)
        if self.params_init is not None:
            params = torch.cat([self.params_init, params], dim=-1)
        if self.params_end is not None:
            params = torch.cat([params, self.params_end], dim=-1)
        return params

    def get_traj_pos(self, times: torch.Tensor = None, params: torch.Tensor = None,
                     init_pos: torch.Tensor = None, init_vel: torch.Tensor = None,
                     flat_shape: bool = False, **kwargs) -> torch.Tensor:
        """
        Compute trajectory positions from B-spline parameters.

        Args:
            times: Time points [*batch, num_times]
            params: B-spline parameters [*batch, num_params]
            init_pos: Initial position (optional)
            init_vel: Initial velocity (optional)
            flat_shape: If True, return flattened [*batch, num_dof*num_times]

        Returns:
            Position trajectory [*batch, num_times, num_dof] or flattened
        """
        self.update_inputs(times, params, init_pos, init_vel, **kwargs)
        if self._pos_cache is not None:
            pos = self._pos_cache
        else:
            assert self.params is not None
            full_params = self._get_full_params()
            basis = self.basis(self.times)
            # [*batch, num_times, num_ctrlp] x [*batch, num_dof, num_ctrlp]
            # -> [*batch, num_times, num_dof]
            pos = torch.einsum('...ik,...jk->...ij', basis, full_params)
            self._pos_cache = pos
        if flat_shape:
            pos = torch.einsum('...ji->...ij', pos).reshape(*self.add_dim, -1)
        return pos

    def get_traj_vel(self, times: torch.Tensor = None, params: torch.Tensor = None,
                     init_pos: torch.Tensor = None, init_vel: torch.Tensor = None,
                     flat_shape: bool = False, **kwargs) -> torch.Tensor:
        """
        Compute trajectory velocities from B-spline parameters.

        Args:
            times: Time points [*batch, num_times]
            params: B-spline parameters [*batch, num_params]
            init_pos: Initial position (optional)
            init_vel: Initial velocity (optional)
            flat_shape: If True, return flattened [*batch, num_dof*num_times]

        Returns:
            Velocity trajectory [*batch, num_times, num_dof] or flattened
        """
        self.update_inputs(times, params, init_pos, init_vel, **kwargs)
        if self._vel_cache is not None:
            vel = self._vel_cache
        else:
            assert self.params is not None
            full_params = self._get_full_params()
            # Divide by tau to convert from phase derivative to time derivative.
            vel_ctrlp = self.velocity_control_points(full_params) / self.tau
            vel_basis = self.vel_basis(self.times)
            vel = torch.einsum('...ik,...jk->...ij', vel_basis, vel_ctrlp)
            self._vel_cache = vel
        if flat_shape:
            vel = torch.einsum('...ji->...ij', vel).reshape(*self.add_dim, -1)
        return vel

    def _basis_multi_dofs(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute multi-DOF basis matrix for least squares fitting.

        Args:
            times: Time points [*batch, num_times]

        Returns:
            Block-diagonal basis [*batch, num_dof*num_times, num_dof*num_basis]
        """
        add_dim = list(times.shape[:-1])
        num_times = times.shape[-1]
        # Keep only the free control points (strip boundary-condition columns).
        basis_single = self.basis(times)[..., self.init_cond_order:
                                         self.num_ctrlp - self.end_cond_order]
        basis_multi = torch.zeros(*add_dim, self.num_dof * num_times,
                                  self.num_dof * self.num_basis,
                                  dtype=self.dtype, device=self.device)
        for i in range(self.num_dof):
            row_slice = slice(i * num_times, (i + 1) * num_times)
            col_slice = slice(i * self.num_basis, (i + 1) * self.num_basis)
            basis_multi[..., row_slice, col_slice] = basis_single
        return basis_multi

    def learn_mp_params_from_trajs(self, times: torch.Tensor, trajs: torch.Tensor,
                                   reg: float = 1e-4, **kwargs) -> dict:
        """
        Learn B-spline parameters from trajectory data via regularized least squares.

        Args:
            times: Time points [*batch, num_times]
            trajs: Trajectory data [*batch, num_times, num_dof]
            reg: Regularization coefficient (default 1e-4)
            **kwargs: Optional init_pos, init_vel, end_pos, end_vel overrides

        Returns:
            Dict with 'params' and boundary conditions
        """
        assert trajs.shape[:-1] == times.shape
        assert trajs.shape[-1] == self.num_dof
        times = torch.as_tensor(times, dtype=self.dtype, device=self.device)
        trajs = torch.as_tensor(trajs, dtype=self.dtype, device=self.device)
        self.add_dim = list(trajs.shape[:-2])
        self.set_times(times)

        # Control points from boundary conditions only (free params zeroed) —
        # their contribution is subtracted from the fit target below.
        dummy_params = torch.zeros(*self.add_dim, self.num_dof, self.num_basis,
                                   device=self.device, dtype=self.dtype)

        # Initial conditions: default to the trajectory's first sample and
        # its first finite difference.
        if self.init_cond_order != 0:
            init_pos = kwargs.get("init_pos", trajs[..., 0, :])
            dt = times[..., 1] - times[..., 0]
            init_vel = kwargs.get("init_vel",
                                  torch.diff(trajs, dim=-2)[..., 0, :] / dt[..., None])
            self.set_initial_conditions(init_pos, init_vel)
            if self.params_init is not None:
                dummy_params = torch.cat([self.params_init, dummy_params], dim=-1)

        # End conditions: default to the trajectory's last sample and
        # its last finite difference.
        if self.end_cond_order != 0:
            end_pos = kwargs.get("end_pos", trajs[..., -1, :])
            dt = times[..., 1] - times[..., 0]
            end_vel = kwargs.get("end_vel",
                                 torch.diff(trajs, dim=-2)[..., -1, :] / dt[..., None])
            self.set_end_conditions(end_pos, end_vel)
            if self.params_end is not None:
                dummy_params = torch.cat([dummy_params, self.params_end], dim=-1)

        # Position contribution of the boundary control points alone.
        basis_single = self.basis(times)
        pos_boundary = torch.einsum('...ik,...jk->...ij', basis_single, dummy_params)
        pos_boundary = torch.einsum('...ij->...ji', pos_boundary).reshape(*self.add_dim, -1)

        # Build regularized normal equations: (Phi^T Phi + reg*I) w = Phi^T r
        basis_multi = self._basis_multi_dofs(self.times)
        A = torch.einsum('...ki,...kj->...ij', basis_multi, basis_multi)
        A += torch.eye(self.num_params, dtype=self.dtype, device=self.device) * reg

        # Flatten trajectories (DOF-major, matching _basis_multi_dofs layout)
        # and subtract the boundary contribution.
        trajs_flat = torch.einsum("...ij->...ji", trajs).reshape(*self.add_dim, -1)
        pos_residual = trajs_flat - pos_boundary
        B = torch.einsum('...ki,...k->...i', basis_multi, pos_residual)

        params = torch.linalg.solve(A, B)
        self.set_params(params)

        return {
            "params": params,
            "init_pos": self.init_pos,
            "init_vel": self.init_vel,
            "end_pos": self.end_pos,
            "end_vel": self.end_vel,
        }
# =============================================================================
# BeastTokenizer Class
# =============================================================================

class BeastTokenizer(torch.nn.Module, ProcessorMixin):
    """
    B-spline based tokenizer for trajectory encoding/decoding.

    Converts continuous robot trajectories to discrete tokens and vice versa
    using B-splines. Supports separate handling for continuous actions (joints)
    and discrete states (e.g., binary gripper).

    Args:
        num_dof: Total degrees of freedom (joints + gripper)
        num_basis: Number of B-spline basis functions
        seq_len: Trajectory sequence length
        vocab_size: Discrete token vocabulary size (default 256)
        degree_p: B-spline degree (default 4 = quartic)
        gripper_zero_order: Use zero-order splines for gripper (piecewise constant)
        gripper_dof: Number of gripper DOFs (only used if gripper_zero_order=True)
        init_cond_order: Initial condition order (0=none, 1=pos, 2=pos+vel)
        end_cond_order: End condition order
        enforce_init_pos: Enforce initial position constraint in decoding
        device: Torch device ('cuda' or 'cpu')

    Example:
        >>> tokenizer = BeastTokenizer(num_dof=7, num_basis=10, seq_len=50)
        >>> tokens = tokenizer.encode_discrete(trajectories)
        >>> reconstructed = tokenizer.decode_discrete(tokens)
    """

    DEFAULT_DT = 0.01  # 100 Hz sampling rate

    attributes: ClassVar[list[str]] = []

    def __init__(self,
                 num_dof: int = 1,
                 num_basis: int = 10,
                 seq_len: int = 50,
                 vocab_size: int = 256,
                 degree_p: int = 4,
                 gripper_zero_order: bool = False,
                 gripper_dof: int = 1,
                 init_cond_order: int = 0,
                 end_cond_order: int = 0,
                 enforce_init_pos: bool = False,
                 device: str = "cuda"):
        torch.nn.Module.__init__(self)
        ProcessorMixin.__init__(self)
        self.device = device
        self.seq_length = seq_len
        self.vocab_size = vocab_size
        self.num_basis = num_basis
        self.enforce_init_pos = enforce_init_pos
        self.init_cond_order = init_cond_order
        self.end_cond_order = end_cond_order
        self.dt = self.DEFAULT_DT
        self.init_pos = None

        # DOF split: gripper DOFs get their own zero-order (piecewise-constant)
        # spline when gripper_zero_order is enabled; otherwise all DOFs are joints.
        self.gripper_dof = gripper_dof if gripper_zero_order else 0
        self.joint_dof = num_dof - self.gripper_dof
        self.num_dof = self.joint_dof + self.gripper_dof

        # Create B-spline components
        self.bsp = BSpline(
            num_basis=num_basis,
            degree=degree_p,
            num_dof=self.joint_dof,
            init_cond_order=init_cond_order,
            end_cond_order=end_cond_order,
            device=device
        )
        self.gripper_bsp = BSpline(
            num_basis=num_basis,
            degree=0,
            num_dof=self.gripper_dof,
            device=device
        ) if gripper_zero_order else None

        # Time grid (normalized [0, 1])
        self.times = torch.linspace(0, 1.0, seq_len, device=device)
        self._initialize_weight_bounds()
        self.to(self.device)

    def _initialize_weight_bounds(self):
        """Initialize weight bounds for token normalization (widened per batch)."""
        total_params = self.num_dof * self.num_basis
        self.register_buffer("w_min", -0.02 * torch.ones(total_params))
        self.register_buffer("w_max", 0.02 * torch.ones(total_params))

    def _get_repeated_times(self, batch_size: int) -> torch.Tensor:
        """Repeat the time grid along a new batch dimension."""
        return einops.repeat(self.times, 't -> b t', b=batch_size)

    @autocast_float32
    def _learn_trajectory_params(self, times: torch.Tensor,
                                 trajs: torch.Tensor) -> dict:
        """Learn B-spline parameters from trajectories.

        Fits the joint DOFs with the main spline and, if configured, the
        gripper DOFs with the zero-order spline; concatenates the parameters.
        """
        joint_params = self.bsp.learn_mp_params_from_trajs(times, trajs[..., :self.joint_dof])
        if self.gripper_bsp is not None:
            gripper_params = self.gripper_bsp.learn_mp_params_from_trajs(
                times, trajs[..., -self.gripper_dof:]
            )
            joint_params['params'] = torch.cat(
                [joint_params['params'], gripper_params['params']], dim=-1
            )
        return joint_params

    @autocast_float32
    def _reconstruct_trajectory(self, params: torch.Tensor,
                                times: torch.Tensor) -> torch.Tensor:
        """Reconstruct trajectory from B-spline parameters (joints + gripper)."""
        joint_params = params[..., :self.joint_dof * self.num_basis]
        self.bsp.update_inputs(times=times, params=joint_params)
        position = self.bsp.get_traj_pos()
        if self.gripper_bsp is not None:
            gripper_params = params[..., -self.gripper_dof * self.num_basis:]
            self.gripper_bsp.update_inputs(times=times, params=gripper_params)
            position = torch.cat([position, self.gripper_bsp.get_traj_pos()], dim=-1)
        return position

    def _apply_initial_position_constraint(self, params: torch.Tensor,
                                           init_pos: torch.Tensor) -> torch.Tensor:
        """Overwrite the first joint control point with the given initial position.

        No-op unless the tokenizer was configured with enforce_init_pos=True
        and an init_pos is supplied.
        """
        # BUG FIX: previously tested `not self.init_pos` (an attribute that is
        # always None here, and tensor truthiness would raise for multi-element
        # tensors) instead of the enforce_init_pos config flag.
        if not self.enforce_init_pos or init_pos is None:
            return params
        reshaped = einops.rearrange(params, "b (d t) -> b t d",
                                    t=self.num_basis, d=self.num_dof)
        reshaped[:, 0, :self.joint_dof] = init_pos[:, :self.joint_dof]
        return einops.rearrange(reshaped, "b t d -> b (d t)")

    @autocast_float32
    def compute_weights(self, demos: torch.Tensor) -> torch.Tensor:
        """
        Compute B-spline weights from demonstration trajectories.

        Args:
            demos: Demonstration trajectories [batch, seq_len, num_dof]

        Returns:
            B-spline weights [batch, num_params]
        """
        times = self._get_repeated_times(demos.shape[0])
        # BUG FIX: route through _learn_trajectory_params so gripper DOFs are
        # fitted too; calling self.bsp directly with all DOFs tripped its
        # num_dof assertion whenever gripper_zero_order was enabled.
        return self._learn_trajectory_params(times, demos)['params']

    def update_weights_bounds_per_batch(self, weights: torch.Tensor):
        """
        Widen the stored weight bounds to cover this batch's weights.

        Args:
            weights: Weights to analyze for bounds update
        """
        weights = weights.reshape(-1, self.num_dof * self.num_basis)
        batch_min = weights.min(dim=0)[0]
        batch_max = weights.max(dim=0)[0]
        tolerance = 1e-4
        smaller_mask = batch_min < (self.w_min - tolerance)
        larger_mask = batch_max > (self.w_max + tolerance)
        if torch.any(smaller_mask):
            self.w_min[smaller_mask] = batch_min[smaller_mask]
        if torch.any(larger_mask):
            self.w_max[larger_mask] = batch_max[larger_mask]

    def update_times(self, times: torch.Tensor):
        """Update the time grid used for encoding/decoding."""
        self.times = times

    @torch.no_grad()
    @autocast_float32
    def encode_discrete(self, trajs: torch.Tensor,
                        update_bounds: bool = True) -> torch.Tensor:
        """
        Encode trajectories to discrete tokens.

        Args:
            trajs: Input trajectories [batch, seq_len, num_dof]
            update_bounds: Update weight bounds from this batch

        Returns:
            Discrete tokens [batch, num_basis * num_dof] in range [0, vocab_size-1]
        """
        times = self._get_repeated_times(trajs.shape[0])
        params_dict = self._learn_trajectory_params(times, trajs)
        if update_bounds:
            self.update_weights_bounds_per_batch(params_dict['params'])
        params = torch.clamp(params_dict['params'], min=self.w_min, max=self.w_max)
        tokens = continuous_to_discrete(params, self.w_min, self.w_max, self.vocab_size)
        # Reorder from DOF-major to basis-major token layout.
        return einops.rearrange(tokens, 'b (d t) -> b (t d)',
                                t=self.num_basis, d=self.num_dof)

    @torch.no_grad()
    @autocast_float32
    def decode_discrete(self, tokens: torch.Tensor, times: torch.Tensor = None,
                        init_pos: torch.Tensor = None) -> torch.Tensor:
        """
        Decode discrete tokens to trajectories.

        Args:
            tokens: Discrete tokens [batch, num_basis * num_dof]
            times: Custom time points (optional)
            init_pos: Initial position constraint (optional)

        Returns:
            Reconstructed trajectories [batch, seq_len, num_dof]
        """
        tokens = einops.rearrange(tokens, 'b (t d) -> b (d t)',
                                  t=self.num_basis, d=self.num_dof)
        params = discrete_to_continuous(tokens, self.w_min, self.w_max, self.vocab_size)
        if times is None:
            times = self._get_repeated_times(params.shape[0])
        params = self._apply_initial_position_constraint(params, init_pos)
        return self._reconstruct_trajectory(params, times)

    @torch.no_grad()
    @autocast_float32
    def encode_continuous(self, trajs: torch.Tensor,
                          update_bounds: bool = True) -> torch.Tensor:
        """
        Encode trajectories to continuous normalized parameters.

        Args:
            trajs: Input trajectories [batch, seq_len, num_dof]
            update_bounds: Update weight bounds from this batch

        Returns:
            Normalized parameters [batch, num_params] in range [-1, 1]
        """
        times = self._get_repeated_times(trajs.shape[0])
        params_dict = self._learn_trajectory_params(times, trajs)
        if update_bounds:
            self.update_weights_bounds_per_batch(params_dict['params'])
        return normalize_tensor(params_dict['params'], self.w_min, self.w_max)

    @torch.no_grad()
    @autocast_float32
    def decode_continuous(self, params: torch.Tensor, times: torch.Tensor = None,
                          init_pos: torch.Tensor = None) -> torch.Tensor:
        """
        Decode continuous normalized parameters to trajectories.

        Args:
            params: Normalized parameters [batch, num_params] in range [-1, 1]
            times: Custom time points (optional)
            init_pos: Initial position constraint (optional)

        Returns:
            Reconstructed trajectories [batch, seq_len, num_dof]
        """
        params = denormalize_tensor(params, self.w_min, self.w_max)
        if times is None:
            times = self._get_repeated_times(params.shape[0])
        params = self._apply_initial_position_constraint(params, init_pos)
        return self._reconstruct_trajectory(params, times)

    @autocast_float32
    def compute_reconstruction_error(self, raw_traj: torch.Tensor) -> torch.Tensor:
        """
        Compute mean squared reconstruction error of the discrete round trip.

        Args:
            raw_traj: Original trajectory

        Returns:
            MSE between original and reconstructed trajectory
        """
        if len(raw_traj.shape) == 2:
            # NOTE(review): 2D input is treated as [batch, seq] with a single
            # DOF (unsqueeze(-1)); the continuous visualizer instead treats 2D
            # as unbatched [seq, dof] — confirm which convention callers use.
            raw_traj = raw_traj.unsqueeze(-1)
        tokens = self.encode_discrete(raw_traj)
        reconstructed = self.decode_discrete(tokens)
        return torch.mean((raw_traj - reconstructed) ** 2)

    def _plot_trajectory_comparison(self, original: torch.Tensor,
                                    reconstructed: torch.Tensor,
                                    title_prefix: str = ""):
        """Plot per-DOF comparison between original and reconstructed trajectories."""
        original = original.detach().cpu().numpy()
        reconstructed = reconstructed.detach().cpu().numpy()
        x_vals = np.linspace(0, 1.0, original.shape[1])
        batch_size, _, dof = original.shape
        for sample_idx in range(batch_size):
            _, axes = plt.subplots(dof, 1, figsize=(8, 2 * dof), sharex=True)
            if dof == 1:
                axes = [axes]
            for i in range(dof):
                axes[i].plot(x_vals, reconstructed[sample_idx, :, i], marker='o',
                             label='Reconstructed', linestyle='-', color='b')
                axes[i].plot(x_vals, original[sample_idx, :, i], marker='*',
                             label='Ground Truth', linestyle='--', color='r')
                axes[i].set_ylabel(f"DOF {i + 1}")
                axes[i].grid(True)
                axes[i].legend(loc="best")
            axes[-1].set_xlabel("Time (s)")
            plt.suptitle(f"{title_prefix}Trajectory Comparison - Sample {sample_idx}")
            plt.tight_layout(rect=[0, 0, 1, 0.96])
            plt.show()

    def visualize_reconstruction_error_discrete(self, raw_traj: torch.Tensor):
        """Visualize discrete encoding reconstruction error."""
        # Accept unbatched 2D input, mirroring the continuous variant.
        raw_traj = raw_traj.to(torch.float32)
        if len(raw_traj.shape) == 2:
            raw_traj = raw_traj.unsqueeze(0)
        tokens = self.encode_discrete(raw_traj, update_bounds=True)
        reconstructed = self.decode_discrete(tokens)
        self._plot_trajectory_comparison(raw_traj, reconstructed, "Discrete ")

    def visualize_reconstruction_error_continuous(self, raw_traj: torch.Tensor):
        """Visualize continuous encoding reconstruction error."""
        raw_traj = raw_traj.to(torch.float32)
        if len(raw_traj.shape) == 2:
            raw_traj = raw_traj.unsqueeze(0)
        continuous_tokens = self.encode_continuous(raw_traj, update_bounds=True)
        reconstructed = self.decode_continuous(continuous_tokens)
        self._plot_trajectory_comparison(raw_traj, reconstructed, "Continuous ")


# =============================================================================
# Main
# =============================================================================

if __name__ == "__main__":
    processor = BeastTokenizer(num_dof=7, vocab_size=256)
    processor.push_to_hub("zhouhongyi/beast", use_auth_token=True)
    print("Processor pushed to the hub.")