# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------

from __future__ import annotations

from typing import Iterable, Tuple, Dict, Type

import torch
import torch.nn as nn


# =============================================================================
# Registry
# =============================================================================

# Lower-cased action-space name -> class. Populated by ``@register_action``.
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}


def register_action(name: str):
    """Return a class decorator that registers an action space under ``name``.

    The name is lower-cased. The decorated class is stored in
    ``ACTION_REGISTRY`` under that key, and ``cls.name`` is set to the key.

    Raises:
        KeyError: when the decorator is applied and the key is already
            registered.
    """
    key = name.lower()

    def _wrap(cls):
        if key in ACTION_REGISTRY:
            raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
        ACTION_REGISTRY[key] = cls
        cls.name = key
        return cls

    return _wrap


def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
    """Instantiate the action space registered under ``name`` (case-insensitive).

    Args:
        name: Registered action-space name.
        **kwargs: Forwarded unchanged to the class constructor.

    Raises:
        KeyError: if nothing is registered under ``name``.
    """
    try:
        cls = ACTION_REGISTRY[name.lower()]
    except KeyError:
        raise KeyError(
            f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}"
        ) from None
    return cls(**kwargs)


# =============================================================================
# Base class
# =============================================================================
class BaseActionSpace(nn.Module):
    """
    Abstract base class for all action-space definitions.

    Each subclass defines:
      - `dim_action`: dimension of the action vector.
      - `gripper_idx`: indices of gripper channels (where the space has them).
      - `compute_loss(pred, target)`: supervised loss for this space.
      - `preprocess(proprio, action, mode)`: pre-step modifications.
      - `postprocess(action, proprio)`: post-step corrections (e.g. apply sigmoid).
    """

    name: str = "base"
    dim_action: int = 0
    idx_for_delta: Tuple[int, ...] = ()

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored at this level.
        super().__init__()

    # ---------------------------------------------------------------------
    # Core supervised loss
    # ---------------------------------------------------------------------
    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return a dict of named loss tensors. Subclasses must override."""
        raise NotImplementedError

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Alias for compute_loss, so ``space(pred, target)`` works."""
        return self.compute_loss(pred, target)

    def prepare_for_training(self, action, proprio):
        """Prepare action and proprio for training (e.g. delta encoding).

        Default: returns ``(action, proprio)`` unchanged. Note that the
        argument/return order is (action, proprio), the reverse of
        :meth:`preprocess`.
        """
        return action, proprio

    # ---------------------------------------------------------------------
    # Space-level hooks
    # ---------------------------------------------------------------------
    def preprocess(
        self,
        proprio: torch.Tensor,
        action: torch.Tensor,
        mode: str = "train",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Default: return ``(proprio, action)`` unchanged."""
        return proprio, action

    def postprocess(
        self,
        action: torch.Tensor,
        proprio: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Default: return ``action`` unchanged.

        ``proprio`` is accepted (positionally or by keyword) so that subclasses
        can call ``super().postprocess(action, proprio)``. It is unused here.
        """
        return action


# =============================================================================
# Utilities
# =============================================================================
def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None:
    """Raise ``IndexError`` if any index in ``idx`` lies outside ``[0, D)``.

    Negative indices are rejected deliberately, even though tensor indexing
    would accept them.
    """
    bad = [i for i in idx if i < 0 or i >= D]
    if bad:
        raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}")


# =============================================================================
# Implementations
# =============================================================================
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
    """End-effector layout with xyz, 6D rotation, and gripper channels.

    The 20 channels form two groups of 10 that share one layout:
      - 0-2   / 10-12 : xyz position
      - 3-8   / 13-18 : 6D rotation
      - 9     / 19    : gripper (a logit on model output; see ``postprocess``)
    """

    dim_action = 20
    gripper_idx = (9, 19)

    # Per-term loss weights.
    GRIPPER_SCALE = 1.0
    XYZ_SCALE = 500.0
    ROT_SCALE = 10.0

    POS_IDX_1 = (0, 1, 2)
    POS_IDX_2 = (10, 11, 12)
    ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
    ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, target):
        """Compute the weighted position, rotation and gripper losses.

        ``pred`` and ``target`` must have the same 3-D shape ``(B, T, D)``.
        Gripper channels of ``pred`` are treated as logits (BCE-with-logits),
        and the matching ``target`` values are the BCE targets, presumably in
        [0, 1]; confirm this with the data pipeline.

        Returns:
            dict with ``position_loss`` (MSE over both xyz groups, summed,
            times XYZ_SCALE), ``rotate6D_loss`` (MSE over both 6D groups,
            summed, times ROT_SCALE) and ``gripper_loss`` (mean per-channel
            BCE times GRIPPER_SCALE).
        """
        assert pred.shape == target.shape, "pred/target shapes must match"
        _, _, D = pred.shape  # unpacking also enforces a 3-D layout
        _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")

        # Gripper BCE, averaged over the gripper channels.
        g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
        gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

        # XYZ position
        pos_loss = (
            self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
            + self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
        ) * self.XYZ_SCALE

        # Rotation 6D
        rot_loss = (
            self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
            + self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
        ) * self.ROT_SCALE

        return {
            "position_loss": pos_loss,
            "rotate6D_loss": rot_loss,
            "gripper_loss": gripper_loss,
        }

    def preprocess(self, proprio, action, mode="train"):
        """Return copies of proprio/action with the gripper channels set to 0.

        The inputs are cloned and not modified. ``mode`` is unused.
        """
        proprio_m = proprio.clone()
        action_m = action.clone()
        proprio_m[..., self.gripper_idx] = 0.0
        action_m[..., self.gripper_idx] = 0.0
        return proprio_m, action_m

    def postprocess(
        self,
        action: torch.Tensor,
        proprio: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply a sigmoid to the gripper logits.

        If ``action`` is too narrow to contain every gripper channel, it is
        passed through unchanged. Otherwise a new tensor is returned; the
        caller's ``action`` is never modified in place.
        """
        if action.size(-1) > max(self.gripper_idx):
            action = action.clone()
            action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
        # Pass proprio by keyword: the base hook accepts it either way.
        return super().postprocess(action, proprio=proprio, **kwargs)


@register_action("auto")
class AutoActionSpace(BaseActionSpace):
    """
    Auto-detecting action space that adapts to any action dimension.

    - The model outputs ``max_dim`` channels for compatibility with pretrained models.
    - Loss is computed only on the first ``real_dim`` channels.
    - ``postprocess`` trims the output back to ``real_dim``.

    Args:
        real_dim: The actual action dimension from the dataset/policy feature.
            Must not exceed ``max_dim``.
        max_dim: The model's output dimension for pretrained VLA compatibility.
        idx_for_delta: Channels encoded relative to proprio. Training uses
            ``action - proprio`` and ``postprocess`` uses ``action + proprio``.
        idx_for_mask_proprio: Proprio channels that are zeroed.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: if ``real_dim > max_dim``.
    """

    SCALE = 100.0

    def __init__(self,
                 real_dim: int,
                 max_dim: int = 20,
                 idx_for_delta: Tuple[int, ...] = (),
                 idx_for_mask_proprio: Tuple[int, ...] = (),
                 **kwargs
                 ):
        super().__init__()
        # A negative pad width would otherwise fail later, deep inside torch.
        if real_dim > max_dim:
            raise ValueError(f"real_dim ({real_dim}) must not exceed max_dim ({max_dim})")
        self.real_dim = real_dim
        self.dim_action = max_dim  # model-facing dimension
        self.idx_for_delta = idx_for_delta
        self.idx_for_mask_proprio = idx_for_mask_proprio
        self.mse = nn.MSELoss()

    def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with a last dimension of exactly ``max_dim``.

        ``None`` passes through unchanged. A tensor that already has width
        ``max_dim`` is returned as-is (the same object, not a copy). Any other
        width is first zero-padded or truncated to ``real_dim``, then
        zero-padded up to ``max_dim``.
        """
        if x is None:
            return None
        width = x.size(-1)
        if width == self.dim_action:
            return x
        if width < self.real_dim:
            fill = x.new_zeros(list(x.shape[:-1]) + [self.real_dim - width])
            x = torch.cat([x, fill], dim=-1)
        elif width > self.real_dim:
            x = x[..., : self.real_dim]
        pad = x.new_zeros(list(x.shape[:-1]) + [self.dim_action - self.real_dim])
        return torch.cat([x, pad], dim=-1)

    def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
        """Trim model output from max_dim to real_dim (slice of the last dim)."""
        return x[..., : self.real_dim]

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Compute the loss only on the first real_dim channels.

        pred:   [..., max_dim] from the model (typically [B, T, max_dim])
        target: [..., real_dim] or [..., max_dim]

        Loss = MSE(pred[..., :real_dim], target[..., :real_dim]) * SCALE
        """
        pred = self._pad_to_model_dim(pred)
        target = self._pad_to_model_dim(target)
        assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"

        # Padded channels are dummies; only the real ones are supervised.
        loss = self.mse(pred[..., : self.real_dim], target[..., : self.real_dim]) * self.SCALE
        return {"loss": loss}

    def prepare_for_training(self, action, proprio):
        """Delta-encode ``action`` and mask ``proprio``, working on copies.

        Returns ``(action, proprio)`` in that order. The inputs are cloned and
        not modified.

        NOTE(review): ``proprio[..., idx]`` must broadcast against
        ``action[..., idx]``. If action is (B, T, D), proprio presumably has to
        be (B, 1, D) or (D,); confirm this with the caller.
        """
        action = action.clone()
        proprio = proprio.clone()
        if self.idx_for_delta:
            action[..., self.idx_for_delta] -= proprio[..., self.idx_for_delta]
        if self.idx_for_mask_proprio:
            proprio[..., self.idx_for_mask_proprio] = 0.0
        return action, proprio

    def preprocess(self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train"):
        """
        Pad proprio and action from real_dim to max_dim for the model.

        Zeroes ``idx_for_mask_proprio`` on a copy of the padded proprio, so the
        caller's tensor is never modified. ``mode`` is unused.
        """
        proprio = self._pad_to_model_dim(proprio)
        if self.idx_for_mask_proprio:
            # _pad_to_model_dim may return the caller's tensor unchanged, so
            # clone before writing into it.
            proprio = proprio.clone()
            proprio[..., self.idx_for_mask_proprio] = 0.0
        return proprio, self._pad_to_model_dim(action)

    def postprocess(
        self,
        action: torch.Tensor,
        proprio: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Undo delta encoding and trim the output from max_dim to real_dim.

        Raises:
            TypeError: if ``idx_for_delta`` is set but ``proprio`` is None.
        """
        if self.idx_for_delta:
            if proprio is None:
                raise TypeError("AutoActionSpace.postprocess requires `proprio` when idx_for_delta is set")
            action = action.clone()
            action[..., self.idx_for_delta] += proprio[..., self.idx_for_delta]
        return self._trim_to_real_dim(action)


# =============================================================================
# Exports
# =============================================================================
# Only names defined in this module are listed. Listing undefined names makes
# ``from <module> import *`` raise AttributeError.
__all__ = [
    "BaseActionSpace",
    "build_action_space",
    "register_action",
    "EE6DActionSpace",
    "AutoActionSpace",
    "ACTION_REGISTRY",
]