# EBench-XVLA-Generalist / action_hub.py
# Uploaded by 2toINF: "Upload ckpt-200000 (X-VLA generalist)" (commit dd37dbc, verified)
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
from typing import Iterable, Tuple, Dict, Type
import torch
import torch.nn as nn
# =============================================================================
# Registry
# =============================================================================
# Lower-cased action-space name -> class registered under that name.
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}


def register_action(name: str):
    """Build a class decorator that files the class in ACTION_REGISTRY.

    The registry key is ``name.lower()``; the decorator also stores that key
    on the class as ``cls.name`` and returns the class unchanged.

    Raises (at decoration time):
        KeyError: if the lower-cased key is already registered.
    """

    def _register(cls):
        registry_key = name.lower()
        if registry_key not in ACTION_REGISTRY:
            ACTION_REGISTRY[registry_key] = cls
            cls.name = registry_key
            return cls
        raise KeyError(f"ActionSpace '{registry_key}' already registered -> {ACTION_REGISTRY[registry_key]}")

    return _register


def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
    """Look up *name* (case-insensitively) and instantiate it with ``**kwargs``.

    Raises:
        KeyError: if no action space is registered under ``name.lower()``.
    """
    lookup_key = name.lower()
    if lookup_key not in ACTION_REGISTRY:
        raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
    space_cls = ACTION_REGISTRY[lookup_key]
    return space_cls(**kwargs)
# =============================================================================
# Base class
# =============================================================================
class BaseActionSpace(nn.Module):
    """
    Abstract base class for all action-space definitions.

    Each subclass defines:
      - `dim_action`: dimension of the action vector.
      - `idx_for_delta`: channels that are delta-encoded (empty by default).
      - `gripper_idx`: indices of gripper channels (only for subclasses that
        have grippers; the base class does not define it).
      - `compute_loss(pred, target)`: supervised loss for this space.
      - `prepare_for_training(action, proprio)`: training-time encoding.
      - `preprocess(proprio, action, mode)`: pre-step modifications.
      - `postprocess(action, proprio=None)`: post-step corrections
        (e.g. apply sigmoid).
    """

    name: str = "base"
    dim_action: int = 0
    idx_for_delta: Tuple[int, ...] = ()

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored here; subclasses
        # may consume them.
        super().__init__()

    # ---------------------------------------------------------------------
    # Core supervised loss
    # ---------------------------------------------------------------------
    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return a dict of named loss tensors. Must be overridden."""
        raise NotImplementedError(f"{type(self).__name__} must implement compute_loss()")

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Alias for compute_loss."""
        return self.compute_loss(pred, target)

    def prepare_for_training(self, action, proprio):
        """Prepare action and proprio for training (e.g. delta encoding).

        Default: return both unchanged, in the order ``(action, proprio)``.
        """
        return action, proprio

    # ---------------------------------------------------------------------
    # Space-level hooks
    # ---------------------------------------------------------------------
    def preprocess(
        self,
        proprio: torch.Tensor,
        action: torch.Tensor,
        mode: str = "train",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Default: return ``(proprio, action)`` unchanged."""
        return proprio, action

    def postprocess(
        self,
        action: torch.Tensor,
        proprio: "torch.Tensor | None" = None,
        **kwargs,
    ) -> torch.Tensor:
        """Default: return *action* unchanged.

        ``proprio`` is an explicit optional parameter so that subclasses can
        delegate with ``super().postprocess(action, proprio)``. With the old
        ``(action, **kwargs)`` signature that call raised TypeError.
        """
        return action
# =============================================================================
# Utilities
# =============================================================================
def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None:
bad = [i for i in idx if i < 0 or i >= D]
if bad:
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}")
# =============================================================================
# Implementations
# =============================================================================
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
    """End-effector layout with xyz, 6D rotation, and gripper channels.

    The 20-channel action is made of two groups with the same layout:
      - channels 0-2   / 10-12 : xyz position          (MSE, x XYZ_SCALE)
      - channels 3-8   / 13-18 : 6D rotation           (MSE, x ROT_SCALE)
      - channel  9     / 19    : gripper logit         (BCE-with-logits)
    At inference the gripper logits are squashed with a sigmoid.
    """

    dim_action = 20
    gripper_idx = (9, 19)
    GRIPPER_SCALE = 1.0
    XYZ_SCALE = 500.0
    ROT_SCALE = 10.0
    POS_IDX_1 = (0, 1, 2)
    POS_IDX_2 = (10, 11, 12)
    ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
    ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, target):
        """Return position, 6D-rotation and gripper losses.

        pred/target must have identical 3-D shapes (they are unpacked into
        three dims and indexed as ``[:, :, channel]``); presumably
        (batch, time, 20) — confirm against callers. Gripper targets are fed
        to BCEWithLogitsLoss, so they are expected to lie in [0, 1].

        Raises:
            IndexError: if the last dim is too small for the channel layout.
        """
        assert pred.shape == target.shape, "pred/target shapes must match"
        _, _, D = pred.shape
        # Validate every channel group up front so a narrow tensor fails with
        # a message naming the group rather than a bare indexing error.
        for label, idx in (
            ("gripper_idx", self.gripper_idx),
            ("POS_IDX_1", self.POS_IDX_1),
            ("POS_IDX_2", self.POS_IDX_2),
            ("ROT_IDX_1", self.ROT_IDX_1),
            ("ROT_IDX_2", self.ROT_IDX_2),
        ):
            _ensure_indices_valid(D, idx, label)

        # Gripper BCE, averaged over the gripper channels.
        g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
        gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

        # XYZ position
        pos_loss = (
            self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
            + self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
        ) * self.XYZ_SCALE

        # Rotation 6D
        rot_loss = (
            self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
            + self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
        ) * self.ROT_SCALE

        return {
            "position_loss": pos_loss,
            "rotate6D_loss": rot_loss,
            "gripper_loss": gripper_loss,
        }

    def preprocess(self, proprio, action, mode="train"):
        """Return copies of proprio/action with the gripper channels zeroed.

        ``mode`` is accepted for interface compatibility and is unused.
        """
        proprio_m = proprio.clone()
        action_m = action.clone()
        proprio_m[..., self.gripper_idx] = 0.0
        action_m[..., self.gripper_idx] = 0.0
        return proprio_m, action_m

    def postprocess(self, action: torch.Tensor, proprio: "torch.Tensor | None" = None) -> torch.Tensor:
        """Apply a sigmoid to the gripper logits and return the result.

        Works on a copy, so the caller's tensor is left untouched. When
        *action* is too narrow to contain the gripper channels it is passed
        through unchanged.
        """
        if action.size(-1) > max(self.gripper_idx):
            action = action.clone()
            action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
        # Pass proprio by keyword: a positional second argument is rejected by
        # a base signature of the form (action, **kwargs).
        return super().postprocess(action, proprio=proprio)
@register_action("auto")
class AutoActionSpace(BaseActionSpace):
    """
    Auto-detecting action space that adapts to any action dimension.

    - Model outputs max_dim for compatibility with pretrained models
    - Loss is computed only on the first real_dim dimensions
    - Postprocess trims output back to real_dim

    Args:
        real_dim: The actual action dimension from the dataset/policy feature.
        max_dim: The model's output dimension for pretrained VLA compatibility.
        idx_for_delta: Channels that prepare_for_training encodes as
            ``action - proprio`` and that postprocess decodes by adding
            proprio back.
        idx_for_mask_proprio: Proprio channels that are zeroed before being
            given to the model.

    Raises:
        ValueError: if ``real_dim`` is not in ``[1, max_dim]``.
    """

    SCALE = 100.0  # multiplier applied to the MSE loss

    def __init__(
        self,
        real_dim: int,
        max_dim: int = 20,
        idx_for_delta: Tuple[int, ...] = (),
        idx_for_mask_proprio: Tuple[int, ...] = (),
        **kwargs,
    ):
        super().__init__()
        if not 0 < real_dim <= max_dim:
            # Padding real_dim -> max_dim needs a non-negative pad width, and
            # an empty action would make the MSE loss meaningless. Failing here
            # is clearer than a negative-size error deep inside padding.
            raise ValueError(
                f"AutoActionSpace requires 0 < real_dim <= max_dim, "
                f"got real_dim={real_dim}, max_dim={max_dim}"
            )
        self.real_dim = real_dim
        self.dim_action = max_dim  # model-facing dimension
        self.idx_for_delta = idx_for_delta
        self.idx_for_mask_proprio = idx_for_mask_proprio
        self.mse = nn.MSELoss()

    def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
        """Widen *x* to max_dim along the last axis (zeros in the extra channels).

        - ``None`` is passed through.
        - A tensor that is already max_dim wide is returned as-is (same object,
          no copy).
        - Any other width is first zero-padded or truncated to real_dim, then
          zero-padded up to max_dim.
        """
        if x is None:
            return None
        if x.size(-1) == self.dim_action:
            return x
        if x.size(-1) != self.real_dim:
            # Width matches neither real_dim nor max_dim: coerce to real_dim first.
            if x.size(-1) < self.real_dim:
                pad_shape = list(x.shape[:-1]) + [self.real_dim - x.size(-1)]
                x = torch.cat([x, x.new_zeros(pad_shape)], dim=-1)
            else:
                x = x[..., : self.real_dim]
        pad_shape = list(x.shape[:-1]) + [self.dim_action - self.real_dim]
        return torch.cat([x, x.new_zeros(pad_shape)], dim=-1)

    def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
        """Trim model output max_dim -> real_dim (keeps the first real_dim channels)."""
        return x[..., : self.real_dim]

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute loss only on the first real_dim dimensions.

        pred:   [B, T, max_dim] from the model
        target: [B, T, real_dim] or [B, T, max_dim]
        Loss = SCALE * MSE(pred[:, :, :real_dim], target[:, :, :real_dim])
        """
        pred = self._pad_to_model_dim(pred)
        target = self._pad_to_model_dim(target)
        assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"
        # Only the real channels contribute; padded dummy channels are ignored.
        loss = self.mse(pred[:, :, : self.real_dim], target[:, :, : self.real_dim]) * self.SCALE
        return {"loss": loss}

    def prepare_for_training(self, action, proprio):
        """Return delta-encoded/masked copies of ``(action, proprio)``.

        Channels in idx_for_delta become ``action - proprio``. Channels in
        idx_for_mask_proprio are zeroed in proprio. The inputs are cloned
        first and are not modified.
        """
        action = action.clone()
        proprio = proprio.clone()
        if self.idx_for_delta:
            action[..., self.idx_for_delta] -= proprio[..., self.idx_for_delta]
        if self.idx_for_mask_proprio:
            proprio[..., self.idx_for_mask_proprio] = 0.0
        return action, proprio

    def preprocess(self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train"):
        """
        Pad proprio and action from real_dim to max_dim for the model.

        Channels in idx_for_mask_proprio are zeroed in the padded proprio.
        ``None`` inputs pass through. ``mode`` is accepted for interface
        compatibility and is unused.
        """
        proprio = self._pad_to_model_dim(proprio)
        if self.idx_for_mask_proprio and proprio is not None:
            # _pad_to_model_dim returns its input unchanged when it is already
            # max_dim wide, so clone before masking to avoid writing into the
            # caller's tensor.
            proprio = proprio.clone()
            proprio[..., self.idx_for_mask_proprio] = 0.0
        return proprio, self._pad_to_model_dim(action)

    def postprocess(self, action: torch.Tensor, proprio: torch.Tensor) -> torch.Tensor:
        """
        Undo delta encoding, then trim model output from max_dim to real_dim.

        Channels in idx_for_delta get ``proprio`` added back (on a copy of
        *action*), so proprio is required when idx_for_delta is non-empty.
        """
        if self.idx_for_delta:
            action = action.clone()
            action[..., self.idx_for_delta] += proprio[..., self.idx_for_delta]
        return self._trim_to_real_dim(action)
# =============================================================================
# Exports
# =============================================================================
# List only names that are actually defined in this module: a star import
# raises AttributeError for any entry that does not exist.
__all__ = [
    "BaseActionSpace",
    "build_action_space",
    "register_action",
    "EE6DActionSpace",
    "AutoActionSpace",
    "ACTION_REGISTRY",
]