# Snapshot: commit e94400c (13,782 bytes).
"""
action_heads.py
Implementations of various action heads, which serve as alternatives to VLM sequential token prediction.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class VLA_Adapter_L1RegressionActionHead(nn.Module):
    """Simple MLP-based action head that generates continuous actions via L1 regression.

    Configuration is read from ``full_config.framework``:

    * ``qwenvl.vl_hidden_dim`` -- width ``D`` of the incoming VLM hidden states.
    * ``action_model.hidden_dim`` -- width of the ``MLPResNet`` blocks.
    * ``action_model.action_dim`` -- number of action dimensions per step.
    * ``action_model.action_query_num`` -- number of trailing action-query
      tokens in the hidden states (default 64; must be non-negative).
    * ``action_model.num_actions_chunk`` -- number of action steps predicted
      per call (required).
    * ``action_model.use_pro_version`` -- forwarded to ``MLPResNet`` to pick
      the block variant.

    Raises:
        ValueError: if ``num_actions_chunk`` is missing or
            ``action_query_num`` is negative.
    """

    def __init__(
        self,
        full_config,
    ):
        super().__init__()
        self.config = full_config
        action_cfg = full_config.framework.action_model
        input_dim = full_config.framework.qwenvl.vl_hidden_dim
        hidden_dim = action_cfg.hidden_dim
        action_dim = action_cfg.action_dim
        self.action_query_num = action_cfg.get("action_query_num", 64)
        if self.action_query_num < 0:
            raise ValueError(
                f"action_query_num must be non-negative, got {self.action_query_num}."
            )
        use_pro_version = action_cfg.use_pro_version
        self.input_dim = input_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.num_actions_chunk = action_cfg.get("num_actions_chunk", None)
        if self.num_actions_chunk is None:
            raise ValueError("num_actions_chunk must be specified in action_model config.")

        # Learnable per-step embeddings (like positional embeddings), applied
        # during both training and inference. Their width must equal the
        # MLPResNet input width (input_dim * action_dim). That is the same as
        # action_dim * hidden_dim whenever vl_hidden_dim == hidden_dim, so
        # checkpoints trained with equal widths keep the same parameter shape.
        self.action_chunk_embeddings = nn.Parameter(
            torch.zeros(self.num_actions_chunk, action_dim * input_dim)
        )
        nn.init.normal_(self.action_chunk_embeddings, mean=0.0, std=0.02)

        # NOTE: the conditioning tokens (action queries, task tokens, proprio)
        # are passed unchanged into blocks whose projections are hidden_dim
        # wide, so vl_hidden_dim must equal hidden_dim whenever they are used.
        self.model = MLPResNet(
            num_blocks=24,
            input_dim=input_dim * action_dim,
            hidden_dim=hidden_dim,
            output_dim=action_dim,
            use_pro_version=use_pro_version,
        )

    def predict_action(
        self,
        actions_hidden_states,
        vision_hidden_len: int,
        state_projected=None,
        phase="Inference",
    ):
        """Regress a chunk of continuous actions from VLM hidden states.

        Args:
            actions_hidden_states: Tensor of shape ``[B, Layers, Total_Len, D]``.
                Per the Qwen_Adapter logic, ``Total_Len = Vision_Len +
                Action_Query_Num``. Language tokens were already filtered out
                at the adapter stage, so no language handling is needed here.
            vision_hidden_len: Expected number of leading (task/vision) tokens.
                It must equal ``Total_Len - action_query_num`` (clamped at 0).
            state_projected: Optional proprioceptive features, presumably
                ``[B, D]``. They are unsqueezed to ``[B, 1, D]`` before use.
            phase: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``[B, num_actions_chunk, action_dim]``.
        """
        batch_size = actions_hidden_states.shape[0]
        device = actions_hidden_states.device

        # 1. Proprioception: one extra conditioning token per sample.
        if state_projected is not None:
            proprio_features = state_projected.unsqueeze(dim=1)  # (bsz, 1, llm_dim)
        else:
            proprio_features = None

        # 2. Split the sequence into leading task (vision) tokens and trailing
        #    action-query tokens (h_a). An explicit split index clamped at 0 is
        #    used because a ``[-0:]`` slice would select the whole sequence as
        #    queries when action_query_num == 0.
        total_len = actions_hidden_states.shape[2]
        num_task_tokens = max(total_len - self.action_query_num, 0)
        action_query_states = actions_hidden_states[:, :, num_task_tokens:, :]
        task_hidden_states = actions_hidden_states[:, :, :num_task_tokens, :]
        assert vision_hidden_len == task_hidden_states.shape[2], "Vision hidden length mismatch"

        # 3. Action chunk queries: zeros plus learnable per-step embeddings, at
        #    the width MLPResNet's first LayerNorm expects.
        zero_queries = torch.zeros(
            (batch_size, self.num_actions_chunk, self.action_dim * self.input_dim),
            device=device,
            dtype=actions_hidden_states.dtype,
        )
        embeddings = self.action_chunk_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        # Adding to zeros (instead of using the embeddings directly) makes the
        # result dtype the promotion of the hidden-state and parameter dtypes.
        action_queries = zero_queries + embeddings

        # 4. MLP forward.
        action = self.model(
            action_queries,
            h_a=action_query_states,  # [B, Layers, query_num, D]
            p=proprio_features,  # [B, 1, D] or None
            h_t=task_hidden_states,  # [B, Layers, vis_len, D]
        )
        assert action.shape == (batch_size, self.num_actions_chunk, self.action_dim), "Action shape mismatch"
        return action
class MLPResNet(nn.Module):
    """MLP with a stack of residual attention blocks between two projections.

    Pipeline: LayerNorm -> Linear(input_dim, hidden_dim) -> ReLU ->
    ``num_blocks`` blocks -> LayerNorm -> Linear(hidden_dim, output_dim).
    Block ``i`` (0-based) receives layer ``i + 1`` of ``h_t`` / ``h_a`` when
    that layer exists (otherwise ``None``), plus ``p`` unchanged.
    """

    def __init__(
        self,
        num_blocks,
        input_dim,
        hidden_dim,
        output_dim,
        use_pro_version=False,
    ):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        # The variant is chosen per block, so neither class name is looked up
        # when num_blocks == 0.
        self.mlp_resnet_blocks = nn.ModuleList(
            MLPResNetBlock_Pro(dim=hidden_dim) if use_pro_version else MLPResNetBlock(dim=hidden_dim)
            for _ in range(num_blocks)
        )
        self.layer_norm2 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h_a=None, h_t=None, p=None):
        # x is expected to be 3-D (batch, seq, input_dim): the blocks unpack
        # ``B, T, C = x.shape``.
        hidden = self.relu(self.fc1(self.layer_norm1(x)))
        for layer_idx, block in enumerate(self.mlp_resnet_blocks, start=1):
            # Layer 0 of the stacked hidden states is skipped; block i reads layer i + 1.
            task_layer = h_t[:, layer_idx, :] if h_t is not None and h_t.shape[1] > layer_idx else None
            query_layer = h_a[:, layer_idx, :] if h_a is not None and h_a.shape[1] > layer_idx else None
            hidden = block(hidden, h_t=task_layer, h_a=query_layer, p=p)
        return self.fc2(self.layer_norm2(hidden))
def apply_rope(q, k, cos, sin):
    """Rotate ``q`` and ``k`` with precomputed rotary tables.

    Args:
        q, k: tensors whose trailing dims are ``(seq_len, head_dim)``. The
            tables are broadcast over two leading dims, e.g. ``(B, H, T, D)``.
        cos, sin: ``(seq_len, head_dim)`` tables.

    Returns:
        ``(q_rot, k_rot)`` with ``t_rot = t * cos + rotate_pairs(t) * sin``.

    NOTE(review): ``rotate_pairs`` pairs adjacent channels ``(2i, 2i + 1)``,
    but ``RotaryPositionEmbedding`` lays its tables out as ``[freqs, freqs]``
    (half-split). The two channels of a pair therefore see different
    frequencies, so this is not an exact per-pair rotation. Changing either
    side would change numerical outputs, including those of any weights
    trained against the current behaviour.
    """
    cos = cos[None, None]
    sin = sin[None, None]

    def rotate_pairs(t):
        # (t0, t1, t2, t3, ...) -> (-t1, t0, -t3, t2, ...)
        even, odd = t[..., ::2], t[..., 1::2]
        return torch.stack((-odd, even), dim=-1).flatten(start_dim=-2)

    return q * cos + rotate_pairs(q) * sin, k * cos + rotate_pairs(k) * sin
class RotaryPositionEmbedding(nn.Module):
    """Build rotary-embedding cos/sin tables for positions ``0 .. seq_len - 1``.

    ``forward`` returns two tensors of shape ``(seq_len, dim)`` laid out as
    ``[freqs, freqs]`` along the last axis, cast to ``dtype``.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        assert dim % 2 == 0, "RoPE head_dim must be an even number"
        exponents = torch.arange(0, dim, 2).float() / dim
        # Non-persistent: rebuilt from dim/base, so it is excluded from state_dict.
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    def forward(self, seq_len, device, dtype):
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)  # (seq_len, dim // 2)
        table = torch.cat((angles, angles), dim=-1)  # (seq_len, dim)
        return table.cos().to(dtype), table.sin().to(dtype)
class MLPResNetBlock(nn.Module):
    """
    Standard MLP ResNet block.

    Runs multi-head attention from the running tokens ``x`` to three key/value
    sources: ``x`` itself, optional task tokens ``h_t``, and optional condition
    tokens (``h_a`` followed by ``p``). It then returns
    ``ffn(attention_out + x)``. Condition-token scores are multiplied by
    ``tanh(gating_factor)``, which is zero at init.

    NOTE(review): ``MLPResNetBlock_Pro`` in this file applies the gate to the
    task-token scores instead. Confirm which stream is meant to be gated
    before treating the two variants as interchangeable.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.ffn = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.ReLU(),
        )
        self.num_heads = 8
        self.head_dim = dim // self.num_heads
        # One shared set of k/v projections embeds self, task and condition tokens.
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)
        self.gating_factor = nn.Parameter(torch.zeros(1))

    def forward(self, x, h_t=None, h_a=None, p=None):
        gate = torch.tanh(self.gating_factor)

        def as_tokens(t):
            # Treat a single (B, D) vector as a one-token sequence (B, 1, D).
            return t.unsqueeze(1) if t.dim() == 2 else t

        cond_parts = [as_tokens(t) for t in (h_a, p) if t is not None]
        h_cond = torch.cat(cond_parts, dim=1) if cond_parts else None
        if h_t is not None:
            h_t = as_tokens(h_t)

        batch, seq_len, width = x.shape

        def split_heads(t):
            # (B, L, dim) -> (B, num_heads, L, head_dim)
            return t.view(batch, t.size(1), self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.q_proj(x))

        # Key/value sources in concatenation order: self, task, then condition
        # tokens. Only the condition scores are multiplied by the gate.
        sources = [(x, False)]
        if h_t is not None:
            sources.append((h_t, False))
        if h_cond is not None:
            sources.append((h_cond, True))

        score_blocks = []
        value_blocks = []
        for tokens, gated in sources:
            keys = split_heads(self.k_proj(tokens))
            scores = torch.matmul(queries, keys.transpose(-2, -1))
            score_blocks.append(scores * gate if gated else scores)
            value_blocks.append(split_heads(self.v_proj(tokens)))

        scores = torch.cat(score_blocks, dim=-1) / math.sqrt(self.head_dim)
        weights = torch.softmax(scores, dim=-1)
        mixed = torch.matmul(weights, torch.cat(value_blocks, dim=2))
        mixed = mixed.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.ffn(self.o_proj(mixed) + x)
class MLPResNetBlock_Pro(nn.Module):
    """
    MLP ResNet block (Pro variant) with RoPE and per-stream key/value projections.

    Queries come from ``x``. Keys and values come from up to three streams,
    each with its own projections:

    * the running tokens themselves (``k_self`` / ``v_self``);
    * adapter tokens, i.e. ``h_a`` followed by ``p`` (``k_adapter`` / ``v_adapter``);
    * task tokens ``h_t`` (``k_task`` / ``v_task``).

    RoPE is applied to the queries and to every key stream, each stream using
    positions starting at 0. Task-token scores are multiplied by
    ``tanh(gating_factor)``, which is zero at init. The block returns
    ``ffn(attention_out + x)``.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.ffn = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.ReLU(),
        )
        # Query projection for the running tokens.
        self.q_proj = nn.Linear(dim, dim)
        # Key/value projections, one pair per attended stream.
        self.k_self = nn.Linear(dim, dim)
        self.v_self = nn.Linear(dim, dim)
        self.k_adapter = nn.Linear(dim, dim)
        self.v_adapter = nn.Linear(dim, dim)
        self.k_task = nn.Linear(dim, dim)
        self.v_task = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)
        self.gating_factor = nn.Parameter(torch.zeros(1))
        self.rope = RotaryPositionEmbedding(self.head_dim)

    def forward(self, x, h_a=None, h_t=None, p=None):
        gate = torch.tanh(self.gating_factor)

        def as_tokens(t):
            # Treat a single (B, D) vector as a one-token sequence (B, 1, D).
            return t.unsqueeze(1) if t.dim() == 2 else t

        adapter_parts = [as_tokens(t) for t in (h_a, p) if t is not None]
        h_adapter = torch.cat(adapter_parts, dim=1) if adapter_parts else None
        if h_t is not None:
            h_t = as_tokens(h_t)

        batch, seq_len, width = x.shape

        def split_heads(t):
            # (B, L, dim) -> (B, num_heads, L, head_dim)
            return t.view(batch, t.size(1), self.num_heads, self.head_dim).transpose(1, 2)

        def rope_tables(length):
            return self.rope(seq_len=length, device=x.device, dtype=x.dtype)

        # Self stream: RoPE on both the queries and the self keys.
        queries = split_heads(self.q_proj(x))
        self_keys = split_heads(self.k_self(x))
        queries, self_keys = apply_rope(queries, self_keys, *rope_tables(seq_len))

        score_blocks = [torch.matmul(queries, self_keys.transpose(-2, -1))]
        value_blocks = [split_heads(self.v_self(x))]

        # Remaining streams in concatenation order: adapter, then task (gated).
        streams = []
        if h_adapter is not None:
            streams.append((h_adapter, self.k_adapter, self.v_adapter, False))
        if h_t is not None:
            streams.append((h_t, self.k_task, self.v_task, True))

        for tokens, key_proj, value_proj, gated in streams:
            keys = split_heads(key_proj(tokens))
            # apply_rope rotates two tensors; only the rotated keys are kept.
            _, keys = apply_rope(keys, keys, *rope_tables(tokens.size(1)))
            scores = torch.matmul(queries, keys.transpose(-2, -1))
            score_blocks.append(scores * gate if gated else scores)
            value_blocks.append(split_heads(value_proj(tokens)))

        scores = torch.cat(score_blocks, dim=-1) / math.sqrt(self.head_dim)
        weights = torch.softmax(scores, dim=-1)
        mixed = torch.matmul(weights, torch.cat(value_blocks, dim=2))
        mixed = mixed.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.ffn(self.o_proj(mixed) + x)
def get_action_model(config=None):
    """Build the L1-regression action head from ``config``.

    Args:
        config: Full framework config, forwarded to
            ``VLA_Adapter_L1RegressionActionHead`` as ``full_config``. Despite
            the ``None`` default, a config is effectively required because the
            head reads ``config.framework`` in its constructor.

    Returns:
        A ``VLA_Adapter_L1RegressionActionHead`` instance.
    """
    return VLA_Adapter_L1RegressionActionHead(
        full_config=config
    )