"""
action_heads.py
Implementations of various action heads, which serve as alternatives to VLM sequential token prediction.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class VLA_Adapter_L1RegressionActionHead(nn.Module):
"""Simple MLP-based action head that generates continuous actions via L1 regression."""
def __init__(
self,
full_config,
):
super().__init__()
self.config = full_config
input_dim = full_config.framework.qwenvl.vl_hidden_dim
hidden_dim = full_config.framework.action_model.hidden_dim
action_dim = full_config.framework.action_model.action_dim
self.action_query_num = full_config.framework.action_model.get("action_query_num", 64)
use_pro_version = full_config.framework.action_model.use_pro_version
self.action_dim = action_dim
self.hidden_dim = hidden_dim
self.num_actions_chunk = self.config.framework.action_model.get("num_actions_chunk", None)
if self.num_actions_chunk is None:
raise ValueError("num_actions_chunk must be specified in action_model config.")
# Learnable action chunk embeddings (like positional embeddings)
# Applied during both training and inference
self.action_chunk_embeddings = nn.Parameter(
torch.zeros(self.num_actions_chunk, action_dim * hidden_dim)
)
nn.init.normal_(self.action_chunk_embeddings, mean=0.0, std=0.02)
self.model = MLPResNet(
num_blocks=24,
input_dim=input_dim*action_dim,
hidden_dim=hidden_dim,
output_dim=action_dim,
use_pro_version=use_pro_version
)
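        # Note (derived from the shapes used in predict_action): the chunk
        # queries fed to this MLP have width action_dim * hidden_dim, while the
        # MLP is built with input_dim * action_dim (input_dim = vl_hidden_dim),
        # so the two only line up when vl_hidden_dim == hidden_dim.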
def predict_action(
self,
actions_hidden_states,
vision_hidden_len: int,
state_projected=None,
phase="Inference"
):
"""
Args:
actions_hidden_states: [B, Layers, Total_Len, D]
根据 Qwen_Adapter 的逻辑,Total_Len = (Vision_Len + Action_Query_Num)。
Language Tokens 已经在 Adapter 阶段被过滤掉了,所以这里不需要额外处理 Language。
"""
batch_size = actions_hidden_states.shape[0]
device = actions_hidden_states.device
# 1. Proprioception Processing
if state_projected is not None:
proprio_features = state_projected.unsqueeze(dim=1) # (bsz, 1, llm_dim)
else:
proprio_features = None
        # 2. Split Action Query Tokens (h_a) from Task (Vision) Tokens (h_t)
action_query_states = actions_hidden_states[:, :, -self.action_query_num:, :]
task_hidden_states = actions_hidden_states[:, :, :-self.action_query_num, :]
assert vision_hidden_len == task_hidden_states.shape[2], "Vision hidden length mismatch"
# 3. Action Chunk Queries Init
cond_actions_hidden_states = torch.zeros(
(batch_size, self.action_dim * self.num_actions_chunk, self.hidden_dim),
device=device, dtype=actions_hidden_states.dtype
).detach()
rearranged_actions_hidden_states = cond_actions_hidden_states.reshape(
batch_size, self.num_actions_chunk, -1
)
# Add learnable action chunk embeddings (applied during both training and inference)
embeddings = self.action_chunk_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
rearranged_actions_hidden_states = rearranged_actions_hidden_states + embeddings
# 4. MLP Forward
action = self.model(
rearranged_actions_hidden_states,
h_a=action_query_states, # [B, Layers, query_num, D]
p=proprio_features, # [B, 1, D]
h_t=task_hidden_states # [B, Layers, vis_len, D]
)
# Assert shape
assert action.shape == (batch_size, self.num_actions_chunk, self.action_dim), "Action shape mismatch"
return action
class MLPResNet(nn.Module):
"""MLP with residual connection blocks."""
def __init__(
self,
num_blocks,
input_dim,
hidden_dim,
output_dim,
use_pro_version=False
):
super().__init__()
self.layer_norm1 = nn.LayerNorm(input_dim)
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.mlp_resnet_blocks = nn.ModuleList()
for _ in range(num_blocks):
if use_pro_version:
self.mlp_resnet_blocks.append(MLPResNetBlock_Pro(dim=hidden_dim))
else:
self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim))
self.layer_norm2 = nn.LayerNorm(hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x, h_a=None, h_t=None, p=None):
        # x: (batch_size, num_actions_chunk, input_dim)
x = self.layer_norm1(x)
x = self.fc1(x)
x = self.relu(x)
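        # Residual block i (0-based) is conditioned on layer index i + 1 of the
        # stacked VLM hidden states h_a / h_t; if that layer index is out of
        # range, the block simply runs without the corresponding conditioning.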
for i, block in enumerate(self.mlp_resnet_blocks):
idx = i + 1
cur_h_t = None
if h_t is not None and h_t.shape[1] > idx:
cur_h_t = h_t[:, idx, :]
cur_h_a = None
if h_a is not None and h_a.shape[1] > idx:
cur_h_a = h_a[:, idx, :]
x = block(x, h_t=cur_h_t, h_a=cur_h_a, p=p)
x = self.layer_norm2(x)
x = self.fc2(x)
return x
def apply_rope(q, k, cos, sin):
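    # cos/sin arrive as [seq_len, head_dim]; add batch and head axes so they
    # broadcast against q/k of shape [batch, num_heads, seq_len, head_dim].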
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
def rotate_half(x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).reshape_as(x)
q_rot = (q * cos) + (rotate_half(q) * sin)
k_rot = (k * cos) + (rotate_half(k) * sin)
return q_rot, k_rot
class RotaryPositionEmbedding(nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
assert dim % 2 == 0, "RoPE head_dim must be an even number"
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seq_len, device, dtype):
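        # Build cos/sin tables of shape [seq_len, dim], laid out as [freqs, freqs].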
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
return emb.cos().to(dtype), emb.sin().to(dtype)
class MLPResNetBlock(nn.Module):
"""
Standard MLP ResNet Block.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
self.ffn = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim),
nn.ReLU(),
)
self.num_heads = 8
self.head_dim = dim // self.num_heads
self.q_proj = nn.Linear(dim, dim)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.o_proj = nn.Linear(dim, dim)
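        # Learnable scalar gate on the condition-stream attention logits below;
        # tanh(0) = 0 at init, so those logits start at zero (not masked out).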
self.gating_factor = nn.Parameter(torch.zeros(1))
def forward(self, x, h_t=None, h_a=None, p=None):
g = self.gating_factor
ratio_g = torch.tanh(g)
conditions = []
if h_a is not None:
if h_a.dim() == 2: h_a = h_a.unsqueeze(1)
conditions.append(h_a)
if p is not None:
if p.dim() == 2: p = p.unsqueeze(1)
conditions.append(p)
h_cond = torch.cat(conditions, dim=1) if len(conditions) > 0 else None
if h_t is not None:
if h_t.dim() == 2: h_t = h_t.unsqueeze(1)
B, T, C = x.shape
K_cond = h_cond.size(1) if h_cond is not None else 0
K_task = h_t.size(1) if h_t is not None else 0
# Self Attention Projection
q_1 = self.q_proj(x)
k_tokens = self.k_proj(x)
v_tokens = self.v_proj(x)
# Reshape Self
q_1 = q_1.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
k_tokens = k_tokens.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
v_tokens = v_tokens.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
attn_scores_list = []
# Score: Self
attn_scores_list.append(torch.matmul(q_1, k_tokens.transpose(-2, -1)))
# Process Task (Vision)
v_task_reshaped = None
if h_t is not None:
k_task = self.k_proj(h_t)
v_task = self.v_proj(h_t)
k_task = k_task.view(B, K_task, self.num_heads, self.head_dim).transpose(1, 2)
v_task_reshaped = v_task.view(B, K_task, self.num_heads, self.head_dim).transpose(1, 2)
attn_scores_list.append(torch.matmul(q_1, k_task.transpose(-2, -1)))
# Process Adapter (Action/Proprio)
v_cond_reshaped = None
if h_cond is not None:
k_cond = self.k_proj(h_cond)
v_cond = self.v_proj(h_cond)
k_cond = k_cond.view(B, K_cond, self.num_heads, self.head_dim).transpose(1, 2)
v_cond_reshaped = v_cond.view(B, K_cond, self.num_heads, self.head_dim).transpose(1, 2)
attn_scores_list.append(torch.matmul(q_1, k_cond.transpose(-2, -1)) * ratio_g)
# Softmax
attn_scores = torch.cat(attn_scores_list, dim=-1)
attn_scores = attn_scores / math.sqrt(self.head_dim)
attn_weights = torch.softmax(attn_scores, dim=-1)
# Combine Values
v_combined_list = [v_tokens]
if v_task_reshaped is not None: v_combined_list.append(v_task_reshaped)
if v_cond_reshaped is not None: v_combined_list.append(v_cond_reshaped)
v_combined = torch.cat(v_combined_list, dim=2)
# Output Projection
output = torch.matmul(attn_weights, v_combined)
output = output.transpose(1, 2).contiguous().view(B, T, C)
output = self.o_proj(output)
x = self.ffn(output + x)
return x
class MLPResNetBlock_Pro(nn.Module):
"""
MLP ResNet Block Pro with RoPE and dimension checks.
"""
def __init__(self, dim, num_heads=8):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.ffn = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim),
nn.ReLU(),
)
self.q_proj = nn.Linear(dim, dim)
self.k_self = nn.Linear(dim, dim)
self.v_self = nn.Linear(dim, dim)
self.k_adapter = nn.Linear(dim, dim)
self.v_adapter = nn.Linear(dim, dim)
self.k_task = nn.Linear(dim, dim)
self.v_task = nn.Linear(dim, dim)
self.o_proj = nn.Linear(dim, dim)
self.gating_factor = nn.Parameter(torch.zeros(1))
self.rope = RotaryPositionEmbedding(self.head_dim)
def forward(self, x, h_a=None, h_t=None, p=None):
g = self.gating_factor
ratio_g = torch.tanh(g)
# 1. Prepare Conditions
cond_list = []
if h_a is not None:
if h_a.dim() == 2: h_a = h_a.unsqueeze(1)
cond_list.append(h_a)
if p is not None:
if p.dim() == 2: p = p.unsqueeze(1)
cond_list.append(p)
h_adapter = torch.cat(cond_list, dim=1) if cond_list else None
if h_t is not None:
if h_t.dim() == 2: h_t = h_t.unsqueeze(1)
B, T, C = x.shape
K_a = h_adapter.size(1) if h_adapter is not None else 0
K_t = h_t.size(1) if h_t is not None else 0
def to_heads(t, L):
return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
# Self Attention
q_1 = self.q_proj(x)
k_self = self.k_self(x)
v_self = self.v_self(x)
q_1 = to_heads(q_1, T)
k_self = to_heads(k_self, T)
v_self = to_heads(v_self, T)
# RoPE: Self
cos_main, sin_main = self.rope(seq_len=T, device=x.device, dtype=x.dtype)
q_1, k_self = apply_rope(q_1, k_self, cos_main, sin_main)
attn_scores = [torch.matmul(q_1, k_self.transpose(-2, -1))]
v_list = [v_self]
# Adapter Attention (Action/Proprio) - With RoPE
if h_adapter is not None:
k_adp = self.k_adapter(h_adapter)
v_adp = self.v_adapter(h_adapter)
k_adp, v_adp = to_heads(k_adp, K_a), to_heads(v_adp, K_a)
cos_a, sin_a = self.rope(seq_len=K_a, device=x.device, dtype=x.dtype)
_, k_adp = apply_rope(k_adp, k_adp, cos_a, sin_a)
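            # RoPE here is applied to the adapter keys only; the query keeps the
            # positions it already received in the self-attention stream above.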
attn_scores.append(torch.matmul(q_1, k_adp.transpose(-2, -1)))
v_list.append(v_adp)
# Task Attention (Vision) - With RoPE & Gating
if h_t is not None:
k_tsk = self.k_task(h_t)
v_tsk = self.v_task(h_t)
k_tsk, v_tsk = to_heads(k_tsk, K_t), to_heads(v_tsk, K_t)
cos_t, sin_t = self.rope(seq_len=K_t, device=x.device, dtype=x.dtype)
_, k_tsk = apply_rope(k_tsk, k_tsk, cos_t, sin_t)
attn_scores.append(torch.matmul(q_1, k_tsk.transpose(-2, -1)) * ratio_g)
v_list.append(v_tsk)
# Merge & Output
attn_scores = torch.cat(attn_scores, dim=-1) / math.sqrt(self.head_dim)
attn_weights = torch.softmax(attn_scores, dim=-1)
v_combined = torch.cat(v_list, dim=2)
output = torch.matmul(attn_weights, v_combined)
output = output.transpose(1, 2).contiguous().view(B, T, C)
output = self.o_proj(output)
x = self.ffn(output + x)
return x
def get_action_model(config=None):
return VLA_Adapter_L1RegressionActionHead(
full_config=config
)
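

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# It assumes the config behaves like an OmegaConf DictConfig, since the code
# above mixes attribute access with ``.get(...)``; the ``omegaconf`` import
# and all field values below are placeholder assumptions chosen only so the
# tensor shapes line up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf  # assumed config backend

    _cfg = OmegaConf.create({
        "framework": {
            "qwenvl": {"vl_hidden_dim": 1024},
            "action_model": {
                "hidden_dim": 1024,      # must equal vl_hidden_dim (see note in __init__)
                "action_dim": 7,
                "action_query_num": 64,
                "num_actions_chunk": 8,
                "use_pro_version": True,
            },
        }
    })
    head = get_action_model(_cfg)

    # Dummy inputs: [B, Layers, Vision_Len + Action_Query_Num, D] hidden states
    # (Layers > 24 blocks so every block gets per-layer conditioning) and a
    # projected proprio state.
    hidden_states = torch.randn(2, 25, 256 + 64, 1024)
    state_projected = torch.randn(2, 1024)
    actions = head.predict_action(hidden_states, vision_hidden_len=256,
                                  state_projected=state_projected)
    print(actions.shape)  # torch.Size([2, 8, 7])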