# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
# Implemented by [Jinhui YE / HKUST University] in [2025].
"""
Qwen-OFT Framework
A lightweight implementation that appends action placeholder tokens to the prompt and predicts all continuous actions of a chunk in parallel,
conditioned on multi-view images plus a language instruction (the action queries are produced by the VLM itself, so they share its parameters).
Inspired by OpenVLA-OFT
Key Points:
- Qwen2.5 vision-language backbone
- Appends action placeholder tokens to the instruction (currently the '🔍' emoji rather than a newly added special token)
- Continuous action prediction via L1 regression over the action special token hidden states
Note: to add dedicated special tokens to Qwen2.5, either
download our model checkpoint with the special tokens already added: https://huggingface.co/StarVLA/Qwen2.5-VL-3B-Instruct-Action
or follow /starVLA/model/modules/vlm/tools/add_qwen_special_tokens/README.md (the code there may need minor adaptation).
"""
from typing import List
from tqdm import tqdm
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from starVLA.training.trainer_utils import initialize_overwatch
from starVLA.model.tools import FRAMEWORK_REGISTRY
from deployment.model_server.tools.image_tools import to_pil_preserve
# Module-level logger created via initialize_overwatch.
logger = initialize_overwatch(__name__)
# HuggingFace default / LLaMA-2 IGNORE_INDEX label value; not referenced by the code below.
IGNORE_INDEX = -100
from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.MLP_ActionHeader import get_action_model
from starVLA.training.trainer_utils.trainer_tools import resize_images
@FRAMEWORK_REGISTRY.register("QwenOFT")
class Qwenvl_OFT(baseframework):
    """
    Vision-language-action model that regresses continuous action chunks from a Qwen-VL backbone.

    Components:
      - Qwen2.5-VL interface (``self.qwen_vl_interface``) that encodes images + instruction tokens.
      - ``chunk_len`` copies of an action placeholder token appended to every instruction; the
        last-layer hidden states at those positions are used as action queries.
      - MLP action head (``self.action_model``) that maps the action queries to actions and is
        trained with an L1 loss (no diffusion, no sampling).

    Focus: predict future continuous actions conditioned on images + instruction.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """
        Construct all submodules and cache key configuration values.

        Args:
            config: Hierarchical configuration (OmegaConf/dict) with ``framework`` and
                ``datasets`` sections. ``framework.action_model.action_hidden_dim`` is
                overwritten in place with the VLM hidden size.
            **kwargs: Reserved for future overrides (unused).
        """
        super().__init__()
        self.config = config
        self.qwen_vl_interface = get_vlm_model(config=self.config)
        # The action head consumes VLM hidden states directly, so its input width must equal
        # the VLM hidden size. TODO: consider moving this into the config itself.
        config.framework.action_model.action_hidden_dim = self.qwen_vl_interface.model.config.hidden_size
        self.action_model = get_action_model(config=self.config)

        self.future_action_window_size = config.framework.action_model.future_action_window_size
        self.past_action_window_size = config.framework.action_model.past_action_window_size
        # Number of placeholder tokens == number of predicted action steps per sample.
        self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

        # An existing token is reused as the action placeholder instead of adding a new special
        # token to the tokenizer. TODO: add a dedicated special token to Qwen.
        self.action_token = "🔍"
        # NOTE(review): only the first token id is kept; this assumes the placeholder maps to a
        # single token id in the Qwen tokenizer — confirm when switching tokenizers.
        self.action_token_id = self.qwen_vl_interface.processor.tokenizer(
            self.action_token, add_special_tokens=False
        )["input_ids"][0]

        self.l1_loss = nn.L1Loss()

    def _append_action_prompt(self, instructions: List[str]) -> List[str]:
        """Append the action prompt containing ``chunk_len`` placeholder tokens to each instruction."""
        # Placeholders are concatenated without separators: a space between them would change
        # how they are tokenized.
        action_tokens = self.action_token * self.chunk_len
        prompt_suffix = f" Please predict the next {self.chunk_len} robot actions: <action>{action_tokens}<action>."
        return [instruction + prompt_suffix for instruction in instructions]

    def _regress_action_chunk(self, batch_images: list, instructions: List[str]) -> torch.Tensor:
        """
        Encode images + action-augmented instructions and regress one action per placeholder.

        Shared by ``forward`` (training) and ``predict_action`` (inference).

        Args:
            batch_images: Per-sample image lists, in the format ``build_qwenvl_inputs`` expects.
            instructions: Raw language instructions, one per sample.

        Returns:
            Predicted actions, [B, chunk_len, action_dim].
        """
        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
            images=batch_images, instructions=self._append_action_prompt(instructions)
        )
        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
            last_hidden = qwenvl_outputs.hidden_states[-1]  # [B, L, H]

        # The action head runs under a float32 autocast context.
        with torch.autocast("cuda", dtype=torch.float32):
            # Hidden states at the placeholder positions serve as the action queries.
            action_queries = self._gather_action_token_embeddings(
                last_hidden,
                qwen_inputs.get("input_ids", None),
                action_token_id=self.action_token_id,
            )  # [B, chunk_len, H]
            pred_actions = self.action_model.predict_action(action_queries)  # [B, chunk_len, action_dim]
        return pred_actions

    def forward(
        self,
        examples: List[dict] = None,
        **kwargs,
    ) -> dict:
        """
        Training forward pass: directly regress the action chunk (no diffusion).

        Flow:
            1. Append the action prompt with ``chunk_len`` placeholder tokens to each instruction.
            2. Encode images + instructions with Qwen-VL and gather the last-layer hidden states
               at the placeholder positions.
            3. Predict actions with the action head and compute the L1 loss against the last
               ``chunk_len`` steps of the ground-truth actions.

        Args:
            examples: List[dict], each dict requires:
                - image: List[PIL.Image] (multi-view)
                - lang: str instruction
                - action: np.ndarray or list shaped [T, action_dim], with T >= chunk_len
            **kwargs: Reserved.

        Returns:
            dict:
                action_loss (torch.Tensor): Scalar L1 regression loss.

        Raises:
            ValueError: If the aligned action targets do not match the predicted shape.
        """
        batch_images = [example["image"] for example in examples]  # [B, [PIL]]
        instructions = [example["lang"] for example in examples]  # [B, str]
        actions = [example["action"] for example in examples]  # labels, [B, T, action_dim]

        pred_actions = self._regress_action_chunk(batch_images, instructions)  # [B, chunk_len, action_dim]

        with torch.autocast("cuda", dtype=torch.float32):
            actions = torch.tensor(
                np.array(actions), device=pred_actions.device, dtype=pred_actions.dtype
            )  # [B, T, action_dim]
            # One placeholder is emitted per step (past + current + future), so the labels are the
            # last chunk_len steps — not only the future window.
            actions_target = actions[:, -self.chunk_len:, :]
            if actions_target.shape != pred_actions.shape:
                # Fail loudly instead of letting L1Loss broadcast mismatched shapes.
                raise ValueError(
                    f"Action target shape {tuple(actions_target.shape)} does not match predicted "
                    f"shape {tuple(pred_actions.shape)} (chunk_len={self.chunk_len})."
                )
            action_loss = self.l1_loss(pred_actions, actions_target)

        return {"action_loss": action_loss}

    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict] = None,
        **kwargs,
    ) -> dict:
        """
        Inference: regress the action chunk with a single forward pass (no diffusion sampling).

        Steps:
            1. Convert images to PIL and resize them to the training resolution
               (``config.datasets.vla_data.image_size``) when it is set.
            2. Encode with Qwen-VL and gather the hidden states at the action placeholders.
            3. Regress and return the normalized action chunk.

        Args:
            examples: List[dict], each with ``image`` (multi-view images) and ``lang`` (str).
            **kwargs: Reserved (unused).

        Returns:
            dict:
                normalized_actions (np.ndarray): float32 array of shape [B, chunk_len, action_dim]
                holding the regressed (normalized) actions.
        """
        batch_images = [to_pil_preserve(example["image"]) for example in examples]  # [B, [PIL]]
        instructions = [example["lang"] for example in examples]  # [B, str]

        train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
        if train_obs_image_size:
            batch_images = resize_images(batch_images, target_size=train_obs_image_size)

        pred_actions = self._regress_action_chunk(batch_images, instructions)
        # Upcast before .numpy(): NumPy has no bfloat16, so a half-precision head output would
        # otherwise fail to convert.
        normalized_actions = pred_actions.detach().float().cpu().numpy()
        return {"normalized_actions": normalized_actions}

    def _gather_action_token_embeddings(
        self,
        last_hidden: torch.Tensor,  # [B, L, H]
        input_ids: torch.Tensor,  # [B, L]
        action_token_id=None,  # int or List[int]
    ) -> torch.Tensor:
        """
        Gather the hidden states of the action placeholder tokens in a vectorized way.

        For every sample, the last ``chunk_len`` positions whose id matches ``action_token_id``
        are selected (in sequence order), so placeholder-like tokens appearing earlier in the
        prompt are ignored.

        Args:
            last_hidden: [B, L, H] hidden states.
            input_ids: [B, L] token ids aligned with ``last_hidden``.
            action_token_id: A single id, or a list/tuple/set of accepted ids.

        Returns:
            action_queries: [B, chunk_len, H]

        Raises:
            ValueError: If ``action_token_id`` or ``input_ids`` is None.
            RuntimeError: If any sample contains fewer than ``chunk_len`` placeholder tokens.
        """
        if action_token_id is None:
            raise ValueError("action_token_id 不能为空")
        if input_ids is None:
            raise ValueError("input_ids is required to locate the action placeholder tokens")

        device = input_ids.device
        B, L, H = last_hidden.shape

        # Several ids may be accepted (e.g. token variants).
        if isinstance(action_token_id, (list, tuple, set)):
            id_list = torch.tensor(list(action_token_id), device=device, dtype=input_ids.dtype)
            mask = torch.isin(input_ids, id_list)  # torch.isin requires PyTorch >= 1.10
        else:
            mask = input_ids == action_token_id  # [B, L]

        counts = mask.sum(dim=1)  # [B]
        too_few = counts < self.chunk_len
        if too_few.any():
            insufficient = too_few.nonzero(as_tuple=False).flatten().tolist()
            raise RuntimeError(
                f"以下样本动作 token 数量不足 {self.chunk_len}: {insufficient} | counts={counts.tolist()}"
            )

        # Non-placeholder positions become -1. Because every row has at least chunk_len valid
        # positions (checked above), top-k of the positions never picks a -1.
        positions = torch.arange(L, device=device).unsqueeze(0).expand(B, L)  # [B, L]
        masked_pos = torch.where(mask, positions, torch.full_like(positions, -1))
        last_positions = masked_pos.topk(k=self.chunk_len, dim=-1).values  # unsorted
        selected_pos = last_positions.sort(dim=-1).values  # chronological order, [B, chunk_len]

        # The index must live on the same device as the tensor being gathered.
        index = selected_pos.to(last_hidden.device).unsqueeze(-1).expand(-1, -1, H)  # [B, chunk_len, H]
        return last_hidden.gather(dim=1, index=index)  # [B, chunk_len, H]
if __name__ == "__main__":
    # Manual smoke test: builds the model, runs a fake batch and one dataloader batch.
    from omegaconf import OmegaConf
    import debugpy
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_cotrain_oxe.yaml", help="Path to YAML config")
    args, clipargs = parser.parse_known_args()

    # Blocks until a debugger attaches.
    debugpy.listen(("0.0.0.0", 10092))
    print("🔍 Rank 0 waiting for debugger attach on port 10092...")
    debugpy.wait_for_client()

    cfg = OmegaConf.load(args.config_yaml)
    # Overwritten by Qwenvl_OFT.__init__ with the VLM hidden size.
    cfg.framework.action_model.action_hidden_dim = 2048
    # NOTE(review): a Florence-2 checkpoint for a Qwen-based framework looks suspicious — confirm.
    cfg.framework.qwenvl.base_vlm = "./playground/Pretrained_models/Florence-2-large"

    model = Qwenvl_OFT(cfg)
    print(model)

    # Fake samples (single view each).
    image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
    sample = {
        "action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16),  # [T, action_dim]
        "image": [image],
        "lang": "This is a fake instruction for testing.",
    }
    sample2 = {
        "action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16),  # [T, action_dim]
        "image": [image],
        "lang": "For testing.",
    }
    batch = [sample, sample2]  # batch size 2

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    forward_output = model(batch)
    action_loss = forward_output["action_loss"]
    print(f"Action Loss: {action_loss.item()}")

    # predict_action takes a list of example dicts (same format as forward).
    predict_output = model.predict_action(examples=[batch[0]])
    normalized_actions = predict_output["normalized_actions"]
    print(f"Normalized Action: {normalized_actions}")

    # Run one real batch from the dataloader as well.
    from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn
    from torch.utils.data import DataLoader

    vla_dataset_cfg = cfg.datasets.vla_data
    dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)
    train_dataloader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=1,  # For Debug
        collate_fn=collate_fn,
    )
    for batch in tqdm(train_dataloader, desc="Processing Batches"):
        break  # only the first batch is needed

    model(batch)
    action = model.predict_action(batch)