| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
| from typing import Iterable, Tuple, Dict, Type |
| import torch |
| import torch.nn as nn |
|
|
| |
| |
| |
# Global lookup table: lower-cased action-space name -> registered class.
# Entries are added by the ``register_action`` decorator.
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = dict()
|
|
|
|
def register_action(name: str):
    """Return a class decorator that files the class in ``ACTION_REGISTRY``.

    The registry key is ``name`` lower-cased. The decorated class is stored
    under that key, gets the key assigned to its ``name`` attribute, and is
    returned unchanged otherwise.

    Args:
        name: Name to register the class under (case-insensitive).

    Raises:
        KeyError: Raised by the returned decorator when the lower-cased
            name is already present in the registry.
    """

    def _register(action_cls):
        registry_key = name.lower()
        if registry_key not in ACTION_REGISTRY:
            ACTION_REGISTRY[registry_key] = action_cls
            action_cls.name = registry_key
            return action_cls
        # Refuse to silently overwrite an existing registration.
        raise KeyError(
            f"ActionSpace '{registry_key}' already registered -> {ACTION_REGISTRY[registry_key]}"
        )

    return _register
|
|
|
|
def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
    """Construct the action space registered under ``name``.

    Args:
        name: Registered action-space name; it is lower-cased before lookup.
        **kwargs: Forwarded unchanged to the registered class's constructor.

    Returns:
        A new instance of the class stored in ``ACTION_REGISTRY``.

    Raises:
        KeyError: If nothing is registered under the lower-cased name. The
            message lists the currently available keys.
    """
    registry_key = name.lower()
    if registry_key in ACTION_REGISTRY:
        space_cls = ACTION_REGISTRY[registry_key]
        return space_cls(**kwargs)
    raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY)}")
|
|
|
|
| |
| |
| |
class BaseActionSpace(nn.Module):
    """Common interface shared by every action-space definition.

    Subclasses are expected to provide:

    - ``dim_action``: size of the action vector.
    - ``gripper_idx``: positions of the gripper channels.
    - ``compute_loss(pred, target)``: supervised loss for the space.
    - ``preprocess(proprio, action, mode)``: adjustments before a step.
    - ``postprocess(action)``: corrections after a step (e.g. a sigmoid).
    """

    # Class-level defaults; concrete action spaces override these.
    name: str = "base"
    dim_action: int = 0
    gripper_idx: Tuple[int, ...] = ()

    def __init__(self):
        super().__init__()

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return the loss terms for ``pred`` vs ``target``.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError()

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Delegate to ``compute_loss`` so the module itself can be called."""
        loss_terms = self.compute_loss(pred, target)
        return loss_terms

    def preprocess(
        self,
        proprio: torch.Tensor,
        action: torch.Tensor,
        mode: str = "train",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pass both tensors through untouched (``mode`` is ignored here)."""
        return (proprio, action)

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Pass the action tensor through untouched."""
        return action
|
|
|
|
| |
| |
| |
| def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None: |
| bad = [i for i in idx if i < 0 or i >= D] |
| if bad: |
| raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}") |
|
|
|
|
| |
| |
| |
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
    """End-effector layout with xyz, 6D rotation, and gripper channels.

    The 20 action channels form two 10-channel groups (presumably one per
    arm -- confirm against the data pipeline):

    - channels 0-2 / 10-12: position (``POS_IDX_1`` / ``POS_IDX_2``)
    - channels 3-8 / 13-18: rotation (``ROT_IDX_1`` / ``ROT_IDX_2``)
    - channels 9 / 19: gripper (``gripper_idx``)

    Position and rotation are trained with MSE. Gripper channels are treated
    as logits: trained with ``BCEWithLogitsLoss`` and turned into
    probabilities by ``postprocess``.
    """

    dim_action = 20
    gripper_idx = (9, 19)

    # Per-component loss weights.
    GRIPPER_SCALE = 1.0
    XYZ_SCALE = 100.0
    ROT_SCALE = 10.0

    # Channel groups along the last (action) dimension.
    POS_IDX_1 = (0, 1, 2)
    POS_IDX_2 = (10, 11, 12)
    ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
    ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, target):
        """Compute the weighted position, rotation and gripper losses.

        Args:
            pred: 3-D tensor of predictions; the last dimension is the
                action dimension.
            target: Tensor with the same shape as ``pred``.

        Returns:
            Dict with scalar tensors under ``"position_loss"``,
            ``"rotate6D_loss"`` and ``"gripper_loss"``.

        Raises:
            AssertionError: If ``pred`` and ``target`` shapes differ.
            ValueError: If ``pred`` is not 3-D (shape unpacking fails).
            IndexError: If a gripper index is outside the action dimension.
        """
        assert pred.shape == target.shape, "pred/target shapes must match"
        # The unpack enforces a 3-D input; only the action dim is used.
        _, _, D = pred.shape
        _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")

        # Gripper: BCE-with-logits per channel, averaged over the channels.
        g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
        gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

        # Position: MSE over each xyz group, summed, then weighted.
        pos_loss = (
            self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
            + self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
        ) * self.XYZ_SCALE

        # Rotation: MSE over each rotation group, summed, then weighted.
        rot_loss = (
            self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
            + self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
        ) * self.ROT_SCALE

        return {
            "position_loss": pos_loss,
            "rotate6D_loss": rot_loss,
            "gripper_loss": gripper_loss,
        }

    def preprocess(self, proprio, action, mode="train"):
        """Return copies of ``proprio`` and ``action`` with gripper channels zeroed.

        The inputs are cloned first, so the caller's tensors are not
        modified. ``mode`` is accepted for interface compatibility and is
        not used.
        """
        proprio_m = proprio.clone()
        action_m = action.clone()
        proprio_m[..., self.gripper_idx] = 0.0
        action_m[..., self.gripper_idx] = 0.0
        return proprio_m, action_m

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid to the gripper logits and return the result.

        The caller's tensor is never modified. If the last dimension is too
        small to contain every gripper index, ``action`` is returned as-is.

        Args:
            action: Tensor whose last dimension is the action dimension.

        Returns:
            A new tensor with gripper channels passed through sigmoid, or
            ``action`` itself when no gripper channel could be addressed.
        """
        if action.size(-1) <= max(self.gripper_idx):
            return action
        # Work on a copy: writing into the caller's tensor would make a
        # repeated call apply sigmoid twice, and in-place writes fail on
        # leaf tensors that require grad.
        result = action.clone()
        result[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
        return result
|
|
|
|
@register_action("joint")
class JointActionSpace(BaseActionSpace):
    """Joint-space layout with joints + gripper only.

    Channels listed in ``gripper_idx`` (6 and 13) are gripper logits trained
    with ``BCEWithLogitsLoss``; every other channel of the action dimension
    is treated as a joint value and trained with MSE.
    """

    dim_action = 14
    gripper_idx = (6, 13)

    # Per-component loss weights.
    GRIPPER_SCALE = 0.1
    JOINTS_SCALE = 1.0

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, target):
        """Compute the weighted joint and gripper losses.

        Args:
            pred: 3-D tensor of predictions; the last dimension is the
                action dimension.
            target: Tensor with the same shape as ``pred``.

        Returns:
            Dict with scalar tensors under ``"joints_loss"`` and
            ``"gripper_loss"``.

        Raises:
            AssertionError: If ``pred`` and ``target`` shapes differ.
            ValueError: If ``pred`` is not 3-D (shape unpacking fails).
            IndexError: If a gripper index is outside the action dimension.
        """
        assert pred.shape == target.shape
        # The unpack enforces a 3-D input; only the action dim is used.
        _, _, D = pred.shape
        _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")

        # Gripper: BCE-with-logits per channel, averaged over the channels.
        g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
        gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

        # Joints: every channel that is not a gripper channel. The set is
        # built once rather than once per candidate index.
        gripper_set = set(self.gripper_idx)
        joints_idx = tuple(i for i in range(D) if i not in gripper_set)
        joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE

        return {
            "joints_loss": joints_loss,
            "gripper_loss": gripper_loss,
        }

    def preprocess(self, proprio, action, mode="train"):
        """Return copies of ``proprio`` and ``action`` with gripper channels zeroed.

        The inputs are cloned first, so the caller's tensors are not
        modified. ``mode`` is accepted for interface compatibility and is
        not used.
        """
        proprio_m = proprio.clone()
        action_m = action.clone()
        proprio_m[..., self.gripper_idx] = 0.0
        action_m[..., self.gripper_idx] = 0.0
        return proprio_m, action_m

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid to the gripper logits and return the result.

        The caller's tensor is never modified. If the last dimension is too
        small to contain every gripper index, ``action`` is returned as-is.

        Args:
            action: Tensor whose last dimension is the action dimension.

        Returns:
            A new tensor with gripper channels passed through sigmoid, or
            ``action`` itself when no gripper channel could be addressed.
        """
        if action.size(-1) <= max(self.gripper_idx):
            return action
        # Work on a copy: writing into the caller's tensor would make a
        # repeated call apply sigmoid twice, and in-place writes fail on
        # leaf tensors that require grad.
        result = action.clone()
        result[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
        return result
|
|
|
|
@register_action("agibot_ee6d")
class AGIBOTEE6DActionSpace(BaseActionSpace):
    """AGI-bot flavour of the EE6D layout where every component uses MSE.

    Channel grouping along the last dimension: position at 0-2 / 10-12,
    rotation at 3-8 / 13-18, gripper at 9 / 19. Unlike the plain EE6D space,
    gripper channels are regressed with MSE and neither ``preprocess`` nor
    ``postprocess`` changes its inputs.
    """

    dim_action = 20
    gripper_idx = (9, 19)

    # Loss weights per component.
    GRIPPER_SCALE = 10.0
    XYZ_SCALE = 100.0
    ROT_SCALE = 10.0

    # Channel groups along the action dimension.
    POS_IDX_1 = (0, 1, 2)
    POS_IDX_2 = (10, 11, 12)
    ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
    ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def compute_loss(self, pred, target):
        """Return weighted MSE terms for position, rotation and gripper.

        ``pred`` and ``target`` must be 3-D tensors of identical shape. The
        result maps ``"position_loss"``, ``"rotate6D_loss"`` and
        ``"gripper_loss"`` to scalar tensors.
        """
        assert pred.shape == target.shape
        # Unpacking into three names rejects anything that is not 3-D.
        _batch, _steps, action_dim = pred.shape
        _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")

        def channel_mse(channels):
            # MSE restricted to the given channels of the last dimension.
            return self.mse(pred[:, :, channels], target[:, :, channels])

        gripper_term = channel_mse(self.gripper_idx) * self.GRIPPER_SCALE
        position_term = (
            channel_mse(self.POS_IDX_1) + channel_mse(self.POS_IDX_2)
        ) * self.XYZ_SCALE
        rotation_term = (
            channel_mse(self.ROT_IDX_1) + channel_mse(self.ROT_IDX_2)
        ) * self.ROT_SCALE

        loss_terms = {}
        loss_terms["position_loss"] = position_term
        loss_terms["rotate6D_loss"] = rotation_term
        loss_terms["gripper_loss"] = gripper_term
        return loss_terms

    def preprocess(self, proprio, action, mode="train"):
        """Hand back ``proprio`` and ``action`` exactly as received."""
        return (proprio, action)

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Hand back ``action`` exactly as received."""
        return action
|
|
|
|
| |
| |
| |
# Names exported by a star-import of this module.
__all__ = [
    "BaseActionSpace", "build_action_space", "register_action",
    "EE6DActionSpace", "JointActionSpace", "AGIBOTEE6DActionSpace",
    "ACTION_REGISTRY",
]
|
|