# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License (the "License");
# Implemented by [Jinhui YE / HKUST University] in [2025].
"""Implementations of various action heads, which serve as alternatives to VLM sequential token prediction."""
"this file is adap from https://github.com/moojink/openvla-oft/blob/main/prismatic/models/action_heads.py"
import torch.nn as nn
class MLPResNetBlock(nn.Module):
"""One MLP ResNet block with a residual connection."""
def __init__(self, dim):
super().__init__()
self.dim = dim
self.ffn = nn.Sequential( # feedforward network, similar to the ones in Transformers
nn.LayerNorm(dim),
nn.Linear(dim, dim),
nn.ReLU(),
)
def forward(self, x):
# x: (batch_size, hidden_dim)
# We follow the module ordering of "Pre-Layer Normalization" feedforward networks in Transformers as
# described here: https://arxiv.org/pdf/2002.04745.pdf
identity = x
x = self.ffn(x)
x = x + identity
return x
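
# Shape note (an illustrative sketch, not from the upstream file): each block maps
# (batch_size, dim) -> (batch_size, dim), so blocks stack without any reshaping:
#   block = MLPResNetBlock(dim=256)      # assumes `import torch`
#   out = block(torch.randn(4, 256))     # out.shape == (4, 256)
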
class MLPResNet(nn.Module):
"""MLP with residual connection blocks."""
def __init__(self, num_blocks, input_dim, hidden_dim, output_dim):
super().__init__()
self.layer_norm1 = nn.LayerNorm(input_dim)
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.mlp_resnet_blocks = nn.ModuleList()
for _ in range(num_blocks):
self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim))
self.layer_norm2 = nn.LayerNorm(hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
# x: (batch_size, input_dim)
x = self.layer_norm1(x) # shape: (batch_size, input_dim)
x = self.fc1(x) # shape: (batch_size, hidden_dim)
x = self.relu(x) # shape: (batch_size, hidden_dim)
for block in self.mlp_resnet_blocks:
x = block(x) # shape: (batch_size, hidden_dim)
x = self.layer_norm2(x) # shape: (batch_size, hidden_dim)
x = self.fc2(x) # shape: (batch_size, output_dim)
return x
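
# Usage sketch (illustrative, assumes `import torch` at module scope): project
# input_dim -> hidden_dim, apply the residual blocks, then project to output_dim:
#   net = MLPResNet(num_blocks=2, input_dim=2048, hidden_dim=4096, output_dim=7)
#   y = net(torch.randn(4, 2048))  # y.shape == (4, 7)
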
class L1RegressionActionHead(nn.Module):
"""Simple MLP-based action head that generates continuous actions via L1 regression."""
def __init__(
self,
input_dim=2048,
hidden_dim=4096,
action_dim=7,
NUM_ACTIONS_CHUNK=8,
):
super().__init__()
self.action_dim = action_dim
self.NUM_ACTIONS_CHUNK = NUM_ACTIONS_CHUNK
self.model = MLPResNet(
            num_blocks=2, input_dim=input_dim, hidden_dim=hidden_dim, output_dim=action_dim
)
def predict_action(self, actions_hidden_states):
"""
actions_hidden_states: (B, chunk_len, hidden_dim)
返回: (B, chunk_len, action_dim)
"""
batch_size, chunk_len, hidden_dim = actions_hidden_states.shape
x = actions_hidden_states.reshape(batch_size * chunk_len, hidden_dim)
x = self.model(x) # (B * chunk_len, action_dim)
actions = x.view(batch_size, chunk_len, self.action_dim)
return actions
def forward(self, actions_hidden_states):
return self.predict_action(actions_hidden_states)
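
# Training sketch (an assumption inferred from the class name, not code from the
# upstream repo): the head is typically supervised with an L1 loss between the
# predicted and ground-truth action chunks, e.g.
#   loss = torch.nn.functional.l1_loss(head(hidden_states), gt_actions)
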
def get_action_model(config=None):
"""
Factory: build ActionModel from global framework config.
Args:
config: Global config (expects config.framework.action_model namespace).
Returns:
ActionModel: Initialized diffusion action head.
"""
action_model_cfg = config.framework.action_model
    model_type = action_model_cfg.action_model_type  # read for forward compatibility; only the L1 head is built below
action_hidden_dim = action_model_cfg.action_hidden_dim
action_dim = action_model_cfg.action_dim
future_action_window_size = action_model_cfg.future_action_window_size
past_action_window_size = action_model_cfg.past_action_window_size
    action_model = L1RegressionActionHead(
        input_dim=action_hidden_dim,
        hidden_dim=action_hidden_dim * 2,
        action_dim=action_dim,
        NUM_ACTIONS_CHUNK=past_action_window_size + 1 + future_action_window_size,
    )
return action_model
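

if __name__ == "__main__":
    # Smoke test (a sketch added for illustration, not part of the upstream file):
    # build the head directly and via the factory with a minimal stand-in config,
    # then check output shapes on dummy hidden states.
    import torch
    from types import SimpleNamespace

    head = L1RegressionActionHead(input_dim=2048, hidden_dim=4096, action_dim=7, NUM_ACTIONS_CHUNK=8)
    hidden = torch.randn(2, 8, 2048)  # (B=2, chunk_len=8, hidden_dim=2048)
    actions = head(hidden)
    assert actions.shape == (2, 8, 7)

    # Hypothetical minimal config mirroring the config.framework.action_model
    # namespace that get_action_model expects.
    cfg = SimpleNamespace(
        framework=SimpleNamespace(
            action_model=SimpleNamespace(
                action_model_type="L1RegressionActionHead",
                action_hidden_dim=2048,
                action_dim=7,
                future_action_window_size=7,
                past_action_window_size=0,
            )
        )
    )
    head2 = get_action_model(cfg)
    assert head2(hidden).shape == (2, 8, 7)
    print("action head smoke test passed:", tuple(actions.shape))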