"""ACT: Action Chunking with Transformers implementation.
This module implements the ACT (Action Chunking with Transformers) model
from "Learning fine-grained bimanual manipulation with low-cost hardware"
(Zhao et al., 2023). ACT uses a transformer architecture with latent variable
modeling to predict action sequences for robot manipulation tasks.
Reference: Zhao, Tony Z., et al. "Learning fine-grained bimanual manipulation
with low-cost hardware." arXiv preprint arXiv:2304.13705 (2023).
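
Example (illustrative sketch only; constructing the ``ModelInitDescription`` and
the batched inputs is handled by the surrounding Neuracore training framework):

    model = ACT(model_init_description, hidden_dim=512, kl_weight=10.0)
    training_outputs = model.training_step(training_batch)  # losses and metrics
    predictions = model(inference_batch)  # dict[DataType, list[BatchedNCData]]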
"""
import logging
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from neuracore_types import (
BatchedJointData,
BatchedNCData,
BatchedParallelGripperOpenAmountData,
BatchedRGBData,
CameraDataStats,
DataItemStats,
DataType,
JointDataStats,
ModelInitDescription,
ParallelGripperOpenAmountDataStats,
)
from neuracore.ml import (
BatchedInferenceInputs,
BatchedTrainingOutputs,
BatchedTrainingSamples,
NeuracoreModel,
)
from neuracore.ml.algorithm_utils.normalizer import MeanStdNormalizer
from .modules import (
ACTImageEncoder,
PositionalEncoding,
TransformerDecoder,
TransformerEncoder,
)
logger = logging.getLogger(__name__)
PROPRIO_NORMALIZER = MeanStdNormalizer # or MinMaxNormalizer
ACTION_NORMALIZER = MeanStdNormalizer # or MinMaxNormalizer
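# ImageNet channel statistics used by torchvision's pretrained ResNet backbones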
RESNET_MEAN = [0.485, 0.456, 0.406]
RESNET_STD = [0.229, 0.224, 0.225]
class ACT(NeuracoreModel):
"""Implementation of ACT (Action Chunking Transformer) model.
ACT is a transformer-based architecture that learns to predict sequences
of robot actions by encoding visual observations and proprioceptive state
    into a latent representation, and then decoding entire action chunks in a
    single non-autoregressive pass.
    The model uses a conditional variational autoencoder (CVAE) framework with
    separate encoders
for visual features and action sequences, combined with a transformer
decoder for action generation.
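
    The training objective follows the conditional VAE formulation of the ACT
    paper:

        L = L1(a_hat, a) + kl_weight * KL(q(z | a, s) || N(0, I))

    where the latent z is sampled via the reparametrization trick during
    training and replaced by the prior mean (zeros) at inference time.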
"""
def __init__(
self,
model_init_description: ModelInitDescription,
hidden_dim: int = 512,
num_encoder_layers: int = 4,
num_decoder_layers: int = 1,
nheads: int = 8,
dim_feedforward: int = 3200,
dropout: float = 0.1,
use_resnet_stats: bool = True,
lr: float = 1e-4,
freeze_backbone: bool = False,
lr_backbone: float = 1e-5,
weight_decay: float = 1e-4,
kl_weight: float = 10.0,
latent_dim: int = 512,
):
"""Initialize the ACT model.
Args:
model_init_description: Model initialization parameters
hidden_dim: Hidden dimension for transformer layers
num_encoder_layers: Number of transformer encoder layers
num_decoder_layers: Number of transformer decoder layers
nheads: Number of attention heads
dim_feedforward: Feedforward network dimension
dropout: Dropout probability
use_resnet_stats: Whether to use ResNet normalization statistics
lr: Learning rate for main parameters
freeze_backbone: Whether to freeze image encoder backbone
lr_backbone: Learning rate for image encoder backbone
weight_decay: Weight decay for optimizer
kl_weight: Weight for KL divergence loss
latent_dim: Dimension of latent variable space
"""
super().__init__(model_init_description)
self.hidden_dim = hidden_dim
self.use_resnet_stats = use_resnet_stats
self.freeze_backbone = freeze_backbone
self.lr = lr
self.lr_backbone = lr_backbone
self.weight_decay = weight_decay
self.kl_weight = kl_weight
self.latent_dim = latent_dim
data_stats: dict[DataType, DataItemStats] = {}
# Setup proprioceptive data
self.proprio_dims: dict[DataType, tuple[int, int]] = {}
proprio_stats = []
current_dim = 0
for data_type in [
DataType.JOINT_POSITIONS,
DataType.JOINT_VELOCITIES,
DataType.JOINT_TORQUES,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
]:
if data_type in self.data_types:
if data_type == DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS:
stats = cast(
list[ParallelGripperOpenAmountDataStats],
self.dataset_statistics[data_type],
)
combined_stats = DataItemStats()
for stat in stats:
combined_stats = combined_stats.concatenate(stat.open_amount)
data_stats[data_type] = combined_stats
else:
stats = cast(
list[JointDataStats], self.dataset_statistics[data_type]
)
combined_stats = DataItemStats()
for stat in stats:
combined_stats = combined_stats.concatenate(stat.value)
data_stats[data_type] = combined_stats
if data_type in self.input_data_types:
proprio_stats.append(combined_stats)
dim = len(combined_stats.mean)
self.proprio_dims[data_type] = (current_dim, current_dim + dim)
current_dim += dim
# State embedding
state_input_dim = current_dim
self.state_embed = None
if state_input_dim > 0:
self.state_embed = nn.Linear(state_input_dim, hidden_dim)
# Setup output data
self.max_output_size = 0
output_stats = []
for data_type in self.output_data_types:
if data_type in [DataType.JOINT_TARGET_POSITIONS, DataType.JOINT_POSITIONS]:
stats = cast(
list[JointDataStats],
self.dataset_statistics[data_type],
)
combined_stats = DataItemStats()
for stat in stats:
combined_stats = combined_stats.concatenate(stat.value)
elif data_type in [
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
]:
stats = cast(
list[ParallelGripperOpenAmountDataStats],
self.dataset_statistics[data_type],
)
combined_stats = DataItemStats()
for stat in stats:
combined_stats = combined_stats.concatenate(stat.open_amount)
else:
raise ValueError(f"Unsupported output data type: {data_type}")
data_stats[data_type] = combined_stats
output_stats.append(combined_stats)
self.max_output_size += len(combined_stats.mean)
# Action embedding
self.action_embed = nn.Linear(self.max_output_size, hidden_dim)
# Setup normalizers
# Only create proprio_normalizer if there are proprioception stats
# This allows the algorithm to work without proprioception (visual-only)
self.proprio_normalizer = (
PROPRIO_NORMALIZER(name="proprioception", statistics=proprio_stats)
if proprio_stats
else None
)
self.action_normalizer = ACTION_NORMALIZER(
name="actions", statistics=output_stats
)
# Vision components
if DataType.RGB_IMAGES in self.input_data_types:
stats = cast(
list[CameraDataStats], self.dataset_statistics[DataType.RGB_IMAGES]
)
max_cameras = len(stats)
self.image_encoders = nn.ModuleList()
for i in range(max_cameras):
if use_resnet_stats:
mean, std = RESNET_MEAN, RESNET_STD
else:
mean_c_h_w, std_c_h_w = stats[i].frame.mean, stats[i].frame.std
mean = mean_c_h_w.mean(axis=(1, 2)).tolist()
std = std_c_h_w.mean(axis=(1, 2)).tolist()
encoder = nn.ModuleDict({
"transform": torch.nn.Sequential(
T.Resize((224, 224)),
T.Normalize(mean=mean, std=std),
),
"encoder": ACTImageEncoder(output_dim=hidden_dim),
})
self.image_encoders.append(encoder)
# CLS token embedding for latent encoder
self.cls_embed = nn.Parameter(torch.randn(1, 1, hidden_dim))
# Main transformer for vision and action generation
self.transformer = nn.ModuleDict({
"encoder": TransformerEncoder(
d_model=hidden_dim,
nhead=nheads,
num_encoder_layers=num_encoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
),
"decoder": TransformerDecoder(
d_model=hidden_dim,
nhead=nheads,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
),
})
# Separate encoder for latent space
self.latent_encoder = TransformerEncoder(
d_model=hidden_dim,
nhead=nheads,
num_encoder_layers=num_encoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
)
# Positional encoding
self.pos_encoder = PositionalEncoding(hidden_dim, dropout)
# Output heads
self.action_head = nn.Linear(hidden_dim, self.max_output_size)
# Latent projections
self.latent_mu = nn.Linear(hidden_dim, latent_dim)
self.latent_logvar = nn.Linear(hidden_dim, latent_dim)
self.latent_out_proj = nn.Linear(latent_dim, hidden_dim)
# Query embedding for decoding
self.query_embed = nn.Parameter(
torch.randn(self.output_prediction_horizon, 1, hidden_dim)
)
# Additional position embeddings for proprio and latent
self.additional_pos_embed = nn.Parameter(torch.randn(2, 1, hidden_dim))
# Setup parameter groups
self._setup_optimizer_param_groups()
def _setup_optimizer_param_groups(self) -> None:
"""Setup parameter groups for optimizer."""
backbone_params, other_params = [], []
for name, param in self.named_parameters():
if any(backbone in name for backbone in ["image_encoders"]):
backbone_params.append(param)
else:
other_params.append(param)
if self.freeze_backbone:
for param in backbone_params:
param.requires_grad = False
self.param_groups = [{"params": other_params, "lr": self.lr}]
else:
self.param_groups = [
{"params": backbone_params, "lr": self.lr_backbone},
{"params": other_params, "lr": self.lr},
]
def _reparametrize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
"""Sample from latent distribution using reparametrization trick.
During training, samples from the distribution N(mu, exp(logvar)).
During inference, returns the mean mu.
Args:
mu: Mean of latent distribution
logvar: Log variance of latent distribution
Returns:
torch.Tensor: Sampled latent variable
"""
if self.training:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
return mu
    def _combine_proprio(
        self, batch: BatchedInferenceInputs
    ) -> torch.FloatTensor | None:
"""Combine different types of joint state data.
Concatenates joint positions, velocities, and torques into a single
feature vector, applying masks and normalization.
Args:
batch: Input batch containing joint state data
Returns:
            torch.FloatTensor | None: Combined and normalized proprioceptive
                features, or None if no proprioceptive inputs are available
"""
if self.state_embed is None:
return None
proprio_list = []
for data_type in [
DataType.JOINT_POSITIONS,
DataType.JOINT_VELOCITIES,
DataType.JOINT_TORQUES,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
]:
if data_type not in batch.inputs:
continue
batched_nc_data = batch.inputs[data_type]
mask = batch.inputs_mask[data_type]
if data_type == DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS:
batched_gripper_data = cast(
list[BatchedParallelGripperOpenAmountData], batched_nc_data
)
proprio_data = torch.cat(
[bgd.open_amount for bgd in batched_gripper_data], dim=-1
)
else:
batched_joint_data = cast(list[BatchedJointData], batched_nc_data)
proprio_data = torch.cat(
[bjd.value for bjd in batched_joint_data], dim=-1
)
last_proprio = proprio_data[:, -1, :] # (B, num_features)
masked_proprio = last_proprio * mask
proprio_list.append(masked_proprio)
# If no proprioception data is available, return None
# This allows the algorithm to work with visual-only inputs
if not proprio_list:
return None
# Concatenate all proprio together: (B, total_proprio_dim)
all_proprio = torch.cat(proprio_list, dim=-1)
# Normalize once on all proprio
# Check if normalizer exists (it should if we have proprio data)
if self.proprio_normalizer is None:
raise ValueError(
"Proprioception inputs were provided but no normalizer was available."
)
normalized_proprio = self.proprio_normalizer.normalize(all_proprio)
return normalized_proprio
def _encode_latent(
self,
        state: torch.FloatTensor | None,
actions: torch.FloatTensor,
actions_mask: torch.FloatTensor,
actions_sequence_mask: torch.FloatTensor,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
"""Encode actions to latent space during training.
Uses a separate transformer encoder to encode the action sequence
along with proprioceptive state into latent distribution parameters.
Args:
            state: Proprioceptive state features, or None when no
                proprioceptive inputs are configured
actions: Target action sequence
actions_mask: Mask for valid action dimensions
actions_sequence_mask: Mask for valid sequence positions
Returns:
tuple[torch.FloatTensor, torch.FloatTensor]: Latent mean and log variance
"""
        batch_size = actions.shape[0]
        # Project proprioceptive state and actions into the transformer dimension
        state_embed = (
            self.state_embed(state)
            if self.state_embed is not None and state is not None
            else None
        )  # [B, H]
        action_embed = self.action_embed(
            actions * actions_mask.unsqueeze(1)
        )  # [B, T, H]
        # Reshape to sequence first
        state_embed = (
            state_embed.unsqueeze(0) if state_embed is not None else None
        )  # [1, B, H]
        action_embed = action_embed.transpose(0, 1)  # [T, B, H]
        # Concatenate [CLS, state_embed, action_embed]; the state token is
        # omitted when no proprioceptive inputs are available (visual-only)
        cls_token = self.cls_embed.expand(-1, batch_size, -1)  # [1, B, H]
        prefix_tokens = (
            [cls_token] if state_embed is None else [cls_token, state_embed]
        )
        encoder_input = torch.cat(prefix_tokens + [action_embed], dim=0)
        # Update padding mask to cover the prepended CLS/state tokens
        if actions_sequence_mask is not None:
            prefix_pad = torch.zeros(
                batch_size,
                len(prefix_tokens),
                dtype=torch.bool,
                device=self.device,
            )
            actions_sequence_mask = torch.cat(
                [prefix_pad, actions_sequence_mask], dim=1
            )
# Add positional encoding
encoder_input = self.pos_encoder(encoder_input)
# Encode sequence
memory = self.latent_encoder(
encoder_input, src_key_padding_mask=actions_sequence_mask
)
# Get latent parameters from CLS token
mu = self.latent_mu(memory[0]) # Take CLS token output
logvar = self.latent_logvar(memory[0])
return mu, logvar
def _encode_visual(
self,
        states: torch.FloatTensor | None,
batched_nc_data: list[BatchedNCData],
camera_images_mask: torch.FloatTensor,
latent: torch.FloatTensor,
) -> torch.FloatTensor:
"""Encode visual inputs with latent and proprioceptive features.
Processes RGB images through vision encoders and combines them with
proprioceptive state and latent features using a transformer encoder.
Args:
            states: Proprioceptive state features, or None for visual-only
                inputs
batched_nc_data: List of BatchedRGBData
camera_images_mask: Mask for valid camera inputs
latent: Latent features from action encoding
Returns:
torch.FloatTensor: Encoded visual and proprioceptive memory
"""
        batched_rgb_data = cast(list[BatchedRGBData], batched_nc_data)
        batch_size = latent.shape[0]
# Process images
image_features = []
image_pos = []
for cam_id, (encoder_dict, input_rgb) in enumerate(
zip(self.image_encoders, batched_rgb_data)
):
last_frame = input_rgb.frame[:, -1, :, :, :] # (B, 3, H, W)
transformed = encoder_dict["transform"](last_frame)
features, pos = encoder_dict["encoder"](
transformed
) # Vision backbone provides features and pos
features *= camera_images_mask[:, cam_id].view(batch_size, 1, 1, 1)
image_features.append(features)
image_pos.append(pos)
# Combine image features and positions
        combined_features = torch.cat(image_features, dim=3)  # [B, C, H, W*num_cams]
        combined_pos = torch.cat(image_pos, dim=3)  # [B, C, H, W*num_cams]
# Convert to sequence [H*W, B, C]
src = combined_features.flatten(2).permute(2, 0, 1)
pos = combined_pos.flatten(2).permute(2, 0, 1)
# Process joint positions and latent
# If no proprioception, create zero tensor with same shape as latent
if states is None:
state_features = torch.zeros_like(latent) # [B, H]
else:
state_features = (
self.state_embed(states)
if self.state_embed is not None
else torch.zeros_like(latent)
) # [B, H]
# Stack latent and proprio features
additional_features = torch.stack([latent, state_features], dim=0) # [2, B, H]
# Add position embeddings from additional_pos_embed
additional_pos = self.additional_pos_embed.expand(
-1, batch_size, -1
) # [2, B, H]
# Concatenate everything
src = torch.cat([additional_features, src], dim=0)
pos = torch.cat([additional_pos, pos], dim=0)
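        # Final encoder sequence: [latent token, proprio token, image tokens]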
# Fuse positional embeddings with source
src = src + pos
# Encode
memory = self.transformer["encoder"](src)
return memory
def _decode(
self,
latent: torch.FloatTensor,
memory: torch.FloatTensor,
) -> torch.Tensor:
"""Decode latent and visual features to action sequence.
Uses a transformer decoder with learned query embeddings to generate
a sequence of action predictions conditioned on visual and latent features.
Args:
latent: Latent features
memory: Encoded visual and proprioceptive memory
Returns:
torch.Tensor: Predicted action sequence [B, T, action_dim]
"""
batch_size = latent.shape[0]
# Convert to sequence first and expand
query_embed = self.query_embed.expand(-1, batch_size, -1) # [T, B, H]
latent = latent.unsqueeze(0).expand_as(query_embed) # [T, B, H]
# Add latent to query embedding
query_embed = query_embed + latent
# Initialize target with zeros
tgt = torch.zeros_like(query_embed)
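        # All actions in the chunk are decoded in a single pass from the learned
        # (DETR-style) queries; no autoregressive feeding of previous actions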
# Decode sequence
hs = self.transformer["decoder"](tgt, memory, query_pos=query_embed)
# Project to action space (keeping sequence first)
actions = self.action_head(hs) # [T, B, A]
# Convert back to batch first
actions = actions.transpose(0, 1) # [B, T, A]
return actions
def _predict_action(
self,
mu: torch.FloatTensor,
logvar: torch.FloatTensor,
batch: BatchedInferenceInputs,
) -> torch.FloatTensor:
"""Predict action sequence from latent distribution and observations.
Args:
mu: Mean of latent distribution
logvar: Log variance of latent distribution
batch: Input observations
Returns:
torch.FloatTensor: Predicted action sequence
"""
# Sample latent
latent_sample = self._reparametrize(mu, logvar)
# Project latent
latent = self.latent_out_proj(latent_sample) # [B, H]
if DataType.RGB_IMAGES not in batch.inputs:
raise ValueError("No RGB images in batch")
# Encode visual features
proprio_state = self._combine_proprio(batch)
memory = self._encode_visual(
proprio_state,
batch.inputs[DataType.RGB_IMAGES],
batch.inputs_mask[DataType.RGB_IMAGES],
latent,
)
# Decode actions
action_preds = self._decode(latent, memory)
return action_preds
def forward(
self, batch: BatchedInferenceInputs
) -> dict[DataType, list[BatchedNCData]]:
"""Perform inference to predict action sequence.
Args:
batch: Input batch with observations
Returns:
dict[DataType, list[BatchedNCData]]: Model predictions with action sequences
"""
batch_size = len(batch)
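        # At inference time the latent is taken from the prior N(0, I): mu and
        # logvar are set to zeros, matching the CVAE test-time behaviour of ACT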
mu = torch.zeros(batch_size, self.latent_dim, device=self.device)
logvar = torch.zeros(batch_size, self.latent_dim, device=self.device)
action_preds = self._predict_action(mu, logvar, batch)
# (B, T, action_dim)
predictions = self.action_normalizer.unnormalize(action_preds)
output_tensors: dict[DataType, list[BatchedNCData]] = {}
start_slice_idx = 0
for data_type in self.output_data_types:
end_slice_idx = start_slice_idx + len(self.dataset_statistics[data_type])
dt_preds = predictions[
:, :, start_slice_idx:end_slice_idx
] # (B, T, dt_size)
if data_type in [DataType.JOINT_TARGET_POSITIONS, DataType.JOINT_POSITIONS]:
batched_outputs = []
for i in range(len(self.dataset_statistics[data_type])):
joint_preds = dt_preds[:, :, i : i + 1] # (B, T, 1)
batched_outputs.append(BatchedJointData(value=joint_preds))
output_tensors[data_type] = batched_outputs
elif data_type in [
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
]:
batched_outputs = []
for i in range(len(self.dataset_statistics[data_type])):
gripper_preds = dt_preds[:, :, i : i + 1] # (B, T, 1)
batched_outputs.append(
BatchedParallelGripperOpenAmountData(open_amount=gripper_preds)
)
output_tensors[data_type] = batched_outputs
else:
raise ValueError(f"Unsupported output data type: {data_type}")
start_slice_idx = end_slice_idx
return output_tensors
def training_step(self, batch: BatchedTrainingSamples) -> BatchedTrainingOutputs:
"""Perform a single training step.
Encodes action sequences to latent space, predicts actions, and computes
L1 reconstruction loss plus KL divergence regularization.
Args:
batch: Training batch with inputs and targets
Returns:
BatchedTrainingOutputs: Training outputs with losses and metrics
"""
inference_sample = BatchedInferenceInputs(
inputs=batch.inputs,
inputs_mask=batch.inputs_mask,
batch_size=batch.batch_size,
)
# Extract target actions
action_targets = []
for data_type in self.output_data_types:
if data_type in [DataType.JOINT_TARGET_POSITIONS, DataType.JOINT_POSITIONS]:
batched_joints = cast(list[BatchedJointData], batch.outputs[data_type])
action_targets.extend([bjd.value for bjd in batched_joints])
elif data_type in [
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
]:
grippers = cast(
list[BatchedParallelGripperOpenAmountData], batch.outputs[data_type]
)
action_targets.extend([gripper.open_amount for gripper in grippers])
else:
raise ValueError(f"Unsupported output data type: {data_type}")
action_data = torch.cat(action_targets, dim=-1) # (B, T, action_dim)
# Get masks
pred_sequence_mask = torch.ones_like(
action_data[:, :, 0]
) # All time steps valid
max_action_mask = torch.ones(
batch.batch_size, self.max_output_size, device=self.device
) # All actions valid
proprio_state = self._combine_proprio(inference_sample)
        # Normalize actions once; reused for latent encoding and as targets
normalized_actions = self.action_normalizer.normalize(action_data)
mu, logvar = self._encode_latent(
proprio_state,
normalized_actions,
max_action_mask,
pred_sequence_mask,
)
action_preds = self._predict_action(mu, logvar, inference_sample)
        target_actions = normalized_actions
l1_loss_all = F.l1_loss(action_preds, target_actions, reduction="none")
l1_loss = (l1_loss_all * pred_sequence_mask.unsqueeze(-1)).mean()
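        # Closed-form KL divergence between N(mu, exp(logvar)) and the standard
        # normal prior, averaged over the batch:
        #   KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))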
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.shape[0]
loss = l1_loss + self.kl_weight * kl_loss
losses = {
"l1_and_kl_loss": loss,
}
metrics = {
"l1_loss": l1_loss,
"kl_loss": kl_loss,
}
return BatchedTrainingOutputs(
losses=losses,
metrics=metrics,
)
def configure_optimizers(
self,
) -> list[torch.optim.Optimizer]:
"""Configure optimizer with different learning rates.
Uses separate learning rates for image encoder backbone and other
model parameters.
Returns:
list[torch.optim.Optimizer]: List of optimizers for model parameters
"""
return [torch.optim.AdamW(self.param_groups, weight_decay=self.weight_decay)]
@staticmethod
def get_supported_input_data_types() -> set[DataType]:
"""Get the input data types supported by this model.
Returns:
set[DataType]: Set of supported input data types
"""
return {
DataType.JOINT_POSITIONS,
DataType.JOINT_VELOCITIES,
DataType.JOINT_TORQUES,
DataType.RGB_IMAGES,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
}
@staticmethod
def get_supported_output_data_types() -> set[DataType]:
"""Get the output data types supported by this model.
Returns:
set[DataType]: Set of supported output data types
"""
return {
DataType.JOINT_POSITIONS,
DataType.JOINT_TARGET_POSITIONS,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS,
}