| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
| from typing import Iterable, Tuple, Dict, Type |
| import torch |
| import torch.nn as nn |
|
|
| |
| |
| |
# Global lookup table mapping a lower-cased action-space name to the class
# registered under it (filled by `register_action`, read by `build_action_space`).
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = dict()
|
|
|
|
def register_action(name: str):
    """Build a class decorator that files a class in ``ACTION_REGISTRY``.

    The registry key is ``name`` lower-cased. On success the decorated class
    is stored under that key, receives the key as its ``name`` attribute, and
    is returned unchanged otherwise.

    Raises:
        KeyError: at decoration time, if the lower-cased name is already
            present in the registry.
    """
    def _decorate(action_cls):
        registry_key = name.lower()
        if registry_key not in ACTION_REGISTRY:
            ACTION_REGISTRY[registry_key] = action_cls
            action_cls.name = registry_key
            return action_cls
        # Name collision: refuse to overwrite and report the current owner.
        raise KeyError(
            f"ActionSpace '{registry_key}' already registered -> {ACTION_REGISTRY[registry_key]}"
        )
    return _decorate
|
|
|
|
def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
    """Look up a registered action space (case-insensitively) and construct it.

    Args:
        name: Registry name; lower-cased before the lookup.
        **kwargs: Forwarded verbatim to the registered class's constructor.

    Returns:
        A new instance of the class registered under ``name``.

    Raises:
        KeyError: if nothing is registered under the lower-cased name.
    """
    lookup = name.lower()
    if lookup in ACTION_REGISTRY:
        factory = ACTION_REGISTRY[lookup]
        return factory(**kwargs)
    known = list(ACTION_REGISTRY.keys())
    raise KeyError(f"Unknown action space '{name}'. Available: {known}")
|
|
|
|
| |
| |
| |
class BaseActionSpace(nn.Module):
    """
    Abstract base class for all action-space definitions.

    Class attributes (overridable by subclasses):
    - `name`: registry key of the action space ("base" here).
    - `dim_action`: dimension of the action vector (0 here).
    - `idx_for_delta`: tuple of action channel indices (empty here); a
      subclass may use it to encode those channels relative to proprio.

    Methods a subclass is expected to override as needed:
    - `compute_loss(pred, target)`: supervised loss for this space (required).
    - `prepare_for_training(action, proprio)`: training-time transformation.
    - `preprocess(proprio, action, mode)`: pre-step modifications.
    - `postprocess(action, proprio)`: post-step corrections (e.g. apply sigmoid).
    """

    name: str = "base"
    dim_action: int = 0
    idx_for_delta: Tuple[int, ...] = ()

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored so that every
        # registered space can be built through the same factory call.
        super().__init__()

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return a dict of named loss tensors; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Alias for compute_loss."""
        return self.compute_loss(pred, target)

    def prepare_for_training(self, action, proprio):
        """Prepare action and proprio for training (e.g. delta encoding).

        Default: return ``(action, proprio)`` unchanged.
        """
        return action, proprio

    def preprocess(
        self,
        proprio: torch.Tensor,
        action: torch.Tensor,
        mode: str = "train",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Default: return ``(proprio, action)`` unchanged (``mode`` is unused)."""
        return proprio, action

    def postprocess(
        self,
        action: torch.Tensor,
        proprio: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Default: return ``action`` unchanged.

        ``proprio`` is accepted positionally or by keyword so the base
        signature agrees with the subclasses, which take
        ``postprocess(action, proprio)`` and may forward both to ``super()``.
        It is not used here; further keyword arguments are ignored as well.
        """
        return action
|
|
| |
| |
| |
| def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None: |
| bad = [i for i in idx if i < 0 or i >= D] |
| if bad: |
| raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}") |
|
|
|
|
| |
| |
| |
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
    """End-effector layout with xyz, 6D rotation, and gripper channels.

    The 20 action channels form two groups of ten, addressed by the index
    constants below: three position channels, six rotation channels and one
    gripper channel per group. Gripper channels are trained with
    BCE-with-logits, so the model emits logits there and `postprocess`
    maps them through a sigmoid.
    """

    dim_action = 20
    gripper_idx = (9, 19)

    # Per-term weights applied to the losses in `compute_loss`.
    GRIPPER_SCALE = 1.0
    XYZ_SCALE = 500.0
    ROT_SCALE = 10.0

    # Channel indices of the position / rotation components of each group.
    POS_IDX_1 = (0, 1, 2)
    POS_IDX_2 = (10, 11, 12)
    ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
    ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, target):
        """Compute the weighted position, rotation and gripper losses.

        Args:
            pred: 3-D tensor of model outputs; gripper channels hold logits.
            target: Tensor of the same shape as ``pred``.
                NOTE(review): BCE expects float gripper targets, presumably
                in [0, 1] - confirm against the dataset.

        Returns:
            Dict with keys ``"position_loss"``, ``"rotate6D_loss"`` and
            ``"gripper_loss"``.

        Raises:
            AssertionError: if ``pred`` and ``target`` shapes differ.
            IndexError: if a gripper index is out of range for the last dim.
        """
        assert pred.shape == target.shape, "pred/target shapes must match"
        # Unpacking also enforces that the tensors are exactly 3-D.
        _, _, D = pred.shape
        _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")

        # Gripper: mean of the per-channel BCE-with-logits losses.
        g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
        gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

        # Position: sum of the two groups' MSE terms.
        pos_loss = (
            self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
            self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
        ) * self.XYZ_SCALE

        # Rotation: sum of the two groups' MSE terms.
        rot_loss = (
            self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
            self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
        ) * self.ROT_SCALE

        return {
            "position_loss": pos_loss,
            "rotate6D_loss": rot_loss,
            "gripper_loss": gripper_loss,
        }

    def preprocess(self, proprio, action, mode="train"):
        """Zero-out gripper channels in proprio/action.

        Works on clones, so the caller's tensors are left untouched.
        ``mode`` is accepted for interface compatibility and is unused.
        """
        proprio_m = proprio.clone()
        action_m = action.clone()
        proprio_m[..., self.gripper_idx] = 0.0
        action_m[..., self.gripper_idx] = 0.0
        return proprio_m, action_m

    def postprocess(self, action: torch.Tensor, proprio: torch.Tensor | None = None) -> torch.Tensor:
        """Apply sigmoid to gripper logits.

        The sigmoid is only applied when the last dimension is wide enough to
        contain every gripper index; otherwise ``action`` is returned as is.
        ``proprio`` is not used by this space and may be omitted.
        """
        if action.size(-1) > max(self.gripper_idx):
            # Work on a copy: the caller's tensor must not be modified.
            action = action.clone()
            action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
        # Forward proprio by keyword: the base method takes a single
        # positional tensor, so a positional proprio raises TypeError.
        return super().postprocess(action, proprio=proprio)
|
|
|
|
|
|
@register_action("auto")
class AutoActionSpace(BaseActionSpace):
    """
    Auto-detecting action space that adapts to any action dimension.

    - Model outputs max_dim for compatibility with pretrained models
    - Loss is computed only on the first real_dim dimensions
    - Postprocess trims output back to real_dim

    Args:
        real_dim: The actual action dimension from the dataset/policy feature
        max_dim: The model's output dimension for pretrained VLA compatibility
        idx_for_delta: Action channels stored as deltas relative to the same
            proprio channels (subtracted in `prepare_for_training`, added
            back in `postprocess`).
        idx_for_mask_proprio: Proprio channels that are zeroed before use.
        **kwargs: Forwarded to the base class, which ignores them.
    """

    # Weight applied to the MSE loss in `compute_loss`.
    SCALE = 100.0

    def __init__(self,
                 real_dim: int,
                 max_dim: int = 20,
                 idx_for_delta: Tuple[int, ...] = (),
                 idx_for_mask_proprio: Tuple[int, ...] = (),
                 **kwargs
                 ):
        super().__init__(**kwargs)
        self.real_dim = real_dim
        self.dim_action = max_dim
        self.idx_for_delta = idx_for_delta
        self.idx_for_mask_proprio = idx_for_mask_proprio
        self.mse = nn.MSELoss()

    def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
        """Pad real_dim → max_dim (zeros for the dummy channels).

        ``None`` passes through, and a tensor whose last dimension already
        equals max_dim is returned as the same object. Any other width is
        first cut to at most real_dim channels and then zero-padded on the
        right up to max_dim.

        Raises:
            ValueError: if the tensor cannot be padded because real_dim
                exceeds max_dim.
        """
        if x is None:
            return None
        if x.size(-1) == self.dim_action:
            return x

        # Slicing is a no-op for tensors narrower than real_dim.
        x = x[..., : self.real_dim]
        missing = self.dim_action - x.size(-1)
        if missing < 0:
            raise ValueError(
                f"real_dim ({self.real_dim}) exceeds max_dim ({self.dim_action}); cannot pad"
            )
        pad = x.new_zeros(list(x.shape[:-1]) + [missing])
        return torch.cat([x, pad], dim=-1)

    def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
        """Trim model output max_dim → real_dim."""
        return x[..., : self.real_dim]

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute loss only on the first real_dim dimensions.

        pred: [B, T, max_dim] from the model
        target: [B, T, real_dim] or [B, T, max_dim]

        Loss = MSE(pred[:,:,:real_dim], target[:,:,:real_dim]) * SCALE,
        returned as ``{"loss": ...}``.
        """
        pred = self._pad_to_model_dim(pred)
        target = self._pad_to_model_dim(target)
        assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"

        # Padded dummy channels are excluded from the loss.
        loss = (
            self.mse(
                pred[:, :, : self.real_dim],
                target[:, :, : self.real_dim],
            )
            * self.SCALE
        )
        return {"loss": loss}

    def prepare_for_training(self, action, proprio):
        """Delta-encode the configured action channels and mask proprio.

        Both tensors are cloned first, so the inputs are not modified.
        Returns ``(action, proprio)``.
        """
        action = action.clone()
        proprio = proprio.clone()

        if self.idx_for_delta:
            action[..., self.idx_for_delta] -= proprio[..., self.idx_for_delta]
        if self.idx_for_mask_proprio:
            proprio[..., self.idx_for_mask_proprio] = 0.0
        return action, proprio

    def preprocess(self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train"):
        """
        Pad proprio and action from real_dim to max_dim for the model.

        Masked proprio channels are zeroed on a copy, so the caller's tensor
        is never modified. ``None`` inputs pass through as ``None``.
        ``mode`` is accepted for interface compatibility and is unused.
        """
        padded = self._pad_to_model_dim(proprio)
        if padded is not None and self.idx_for_mask_proprio:
            if padded is proprio:
                # No padding happened, so this is still the caller's tensor.
                padded = padded.clone()
            padded[..., self.idx_for_mask_proprio] = 0.0
        return padded, self._pad_to_model_dim(action)

    def postprocess(self, action: torch.Tensor, proprio: torch.Tensor | None = None) -> torch.Tensor:
        """
        Trim model output from max_dim to real_dim for real robot control.

        When delta channels are configured, the matching proprio channels
        are added back first (on a copy of ``action``).

        Raises:
            ValueError: if delta channels are configured but ``proprio`` is
                not given.
        """
        if self.idx_for_delta:
            if proprio is None:
                raise ValueError("proprio is required to decode delta channels (idx_for_delta is set)")
            action = action.clone()
            action[..., self.idx_for_delta] += proprio[..., self.idx_for_delta]
        return self._trim_to_real_dim(action)
|
|
| |
| |
| |
# Public API of this module. Every entry must be a name this file defines:
# `from <module> import *` raises AttributeError for any listed name the
# module does not provide, so action spaces that are neither defined nor
# imported here are not exported.
__all__ = [
    "BaseActionSpace",
    "build_action_space",
    "register_action",
    "EE6DActionSpace",
    "AutoActionSpace",
    "ACTION_REGISTRY",
]
|
|