# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License (the "License");
# Implemented by [Jinhui YE / HKUST University] in [2025].
"""
Qwen-OFT Framework
A lightweight implementation that uses an action special token to parallelly predict continuous actions
conditioned on multi-view images plus a language instruction (shares parameters with the VLM).
Inspired by OpenVLA-OFT
Key Points:
- Qwen2.5 vision-language backbone
- Injects an action special token into the VLM
- Continuous action prediction via L1 regression over the action special token hidden states
Note: How to add special tokens to Qwen2.5:
download our model checkpoint with special tokens added: https://huggingface.co/StarVLA/Qwen2.5-VL-3B-Instruct-Action
or /starVLA/model/modules/vlm/tools/add_qwen_special_tokens/README.md (adpat a little code)
"""
from typing import List, Optional, Tuple

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from starVLA.training.trainer_utils import initialize_overwatch
from starVLA.model.tools import FRAMEWORK_REGISTRY
from deployment.model_server.tools.image_tools import to_pil_preserve
logger = initialize_overwatch(__name__)
# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
IGNORE_INDEX = -100
from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.MLP_ActionHeader import get_action_model
from starVLA.training.trainer_utils.trainer_tools import resize_images
@FRAMEWORK_REGISTRY.register("QwenOFT")
class Qwenvl_OFT(baseframework):
"""
Multimodal vision-language-action model.
Components:
- Qwen2.5 VL interface for fused language/vision token embeddings
- Layer-wise QFormer for multi-layer feature aggregation
- DINO encoder for dense multi-view spatial tokens
- DiT diffusion head for future action sequence modeling
Focus: Predict future continuous actions conditioned on images + instruction.
"""
def __init__(
self,
config: Optional[dict] = None,
**kwargs,
) -> None:
"""
Construct all submodules and cache key configuration values.
Args:
config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
**kwargs: Reserved for future overrides (unused).
"""
super().__init__()
self.config = config
self.qwen_vl_interface = get_vlm_model(config=self.config)
        # Align dims: the action head hidden size must match the VLM hidden size (TODO: consider exposing this in the config)
config.framework.action_model.action_hidden_dim = self.qwen_vl_interface.model.config.hidden_size
self.action_model = get_action_model(config=self.config)
self.future_action_window_size = config.framework.action_model.future_action_window_size
self.past_action_window_size = config.framework.action_model.past_action_window_size
self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size
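        # chunk_len = past context + current step + future horizon; it equals the number of action placeholder tokens appended to the prompt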
# self.hidden_dim = config.framework.action_model.action_hidden_dim
        self.action_token = "🔍"  # TODO: a dedicated special token could be added to the Qwen vocabulary instead, but that complicates the setup
self.action_token_id = self.qwen_vl_interface.processor.tokenizer("🔍", add_special_tokens=False)["input_ids"][0]
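        # "🔍" is assumed to map to a single tokenizer id, so each occurrence stands for exactly one action step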
        # L1 loss
self.l1_loss = nn.L1Loss()
def forward(
self,
examples: List[dict] = None,
**kwargs,
    ) -> dict:
"""
训练前向:直接回归未来动作(无扩散)。
Flow:
1. Build QwenVL inputs (images + instruction tokens)
2. Extract hidden states from configured layer range
7. Predict action and compute L1 loss
Args:
examples: List[dict], each dict requires:
- image: List[PIL.Image] (multi-view)
- lang: str instruction
- action: np.ndarray or list shaped [T, action_dim]
**kwargs: Reserved.
Returns:
dict:
action_loss (torch.Tensor): Scalar diffusion noise prediction loss.
"""
        batch_images = [example["image"] for example in examples]  # [B, List[PIL.Image]]
instructions = [example["lang"] for example in examples] # [B, str]
actions = [example["action"] for example in examples] # label [B, len, 7]
# step 0: add special action token to instruction
        action_tokens = self.action_token * self.chunk_len  # no spaces between placeholders, otherwise the separator is tokenized into extra tokens
prompt_suffix = f" Please predict the next {self.chunk_len} robot actions: <action>{action_tokens}<action>."
instructions = [instruction + prompt_suffix for instruction in instructions]
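        # Resulting prompt (chunk_len = 3): "<instruction> Please predict the next 3 robot actions: <action>🔍🔍🔍<action>."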
        # Step 1: Build QwenVL inputs
qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)
with torch.autocast("cuda", dtype=torch.bfloat16):
qwenvl_outputs = self.qwen_vl_interface(
**qwen_inputs,
output_attentions=False,
output_hidden_states=True,
return_dict=True,
)
# last_hidden_state: [B, seq_len, H]
last_hidden = qwenvl_outputs.hidden_states[-1] # [B, L, H]
        # Step 2: Action expert forward and loss
with torch.autocast("cuda", dtype=torch.float32):
            # Extract the action-token hidden states as action prediction queries
input_ids = qwen_inputs.get("input_ids", None)
action_queries = self._gather_action_token_embeddings(last_hidden, input_ids, action_token_id=self.action_token_id) # [B, chunk_len, H]
pred_actions = self.action_model.predict_action(action_queries) # (B, chunk_len, action_dim)
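            # Each query is mapped to one action vector, so all chunk_len actions are predicted in parallel (OFT-style)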
            # Align labels: take the last (future_action_window_size + 1) steps
actions = torch.tensor(
np.array(actions), device=pred_actions.device, dtype=pred_actions.dtype
) # [B, T_full, action_dim]
actions_target = actions[:, -(self.future_action_window_size+1):, :] # (B, chunk_len, action_dim)
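            # Note: the target has (future_action_window_size + 1) steps, which matches chunk_len only when past_action_window_size == 0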
            # Compute the L1 loss
action_loss = self.l1_loss(pred_actions, actions_target)
return {"action_loss": action_loss}
@torch.inference_mode()
def predict_action(
self,
examples: List[dict] = None,
**kwargs: str,
    ) -> dict:
"""
推理:单次前向直接回归未来动作(无扩散采样)。
Steps:
1. Resize images to training resolution (if specified)
2. Encode with QwenVL (hidden states retained)
6. Return normalized action trajectory
Returns:
dict:
normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions.
"""
        batch_images = [to_pil_preserve(example["image"]) for example in examples]  # [B, List[PIL.Image]]
instructions = [example["lang"] for example in examples] # [B, str]
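        # Resize observations to the training resolution (datasets.vla_data.image_size), if one was configured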
train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
if train_obs_image_size:
batch_images = resize_images(batch_images, target_size=train_obs_image_size)
# step 0: add special action token to instruction
        action_tokens = self.action_token * self.chunk_len  # no spaces between placeholders, otherwise the separator is tokenized into extra tokens
prompt_suffix = f" Please predict the next {self.chunk_len} robot actions: <action>{action_tokens}<action>."
instructions = [instruction + prompt_suffix for instruction in instructions]
        # Step 1: Build QwenVL inputs
qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)
with torch.autocast("cuda", dtype=torch.bfloat16):
qwenvl_outputs = self.qwen_vl_interface(
**qwen_inputs,
output_attentions=False,
output_hidden_states=True,
return_dict=True,
)
# last_hidden_state: [B, seq_len, H]
last_hidden = qwenvl_outputs.hidden_states[-1] # [B, L, H]
        # Step 2: Action expert forward (no loss at inference)
with torch.autocast("cuda", dtype=torch.float32):
            # Extract the action-token hidden states as action prediction queries
input_ids = qwen_inputs.get("input_ids", None)
action_queries = self._gather_action_token_embeddings(last_hidden, input_ids, action_token_id=self.action_token_id) # [B, chunk_len, H]
pred_actions = self.action_model.predict_action(action_queries) # (B, chunk_len, action_dim)
normalized_actions = pred_actions.detach().cpu().numpy()
return {"normalized_actions": normalized_actions}
def _gather_action_token_embeddings(
self,
last_hidden: torch.Tensor, # [B, L, H]
input_ids: torch.Tensor, # [B, L]
        action_token_id=None,  # int or List[int]
) -> torch.Tensor:
"""
向量化批量提取动作 token embedding:
- 不再逐样本 for 循环
- 取每个样本里最靠后的 chunk_len 个动作占位 token
Args:
last_hidden: [B, L, H]
input_ids: [B, L]
action_token_id: int 或 List[int]
Returns:
action_queries: [B, chunk_len, H]
"""
if action_token_id is None:
raise ValueError("action_token_id 不能为空")
device = input_ids.device
B, L, H = last_hidden.shape
        # Support multiple ids (e.g., several token variants)
if isinstance(action_token_id, (list, tuple, set)):
id_list = torch.tensor(list(action_token_id), device=device, dtype=input_ids.dtype)
            # torch.isin requires PyTorch >= 1.10
mask = torch.isin(input_ids, id_list)
else:
mask = (input_ids == action_token_id) # [B, L]
counts = mask.sum(dim=1) # [B]
if (counts < self.chunk_len).any():
insufficient = (counts < self.chunk_len).nonzero(as_tuple=False).flatten().tolist()
raise RuntimeError(
f"以下样本动作 token 数量不足 {self.chunk_len}: {insufficient} | counts={counts.tolist()}"
)
        # Position indices
        idx = torch.arange(L, device=device).unsqueeze(0).expand(B, L)  # [B, L]
        masked_pos = torch.where(mask, idx, torch.full_like(idx, -1))  # non-action positions set to -1
        # Take the last chunk_len placeholders (larger index == later in the sequence)
        # Note: the count check above guarantees no -1 entry can be selected
        topk_pos = masked_pos.topk(k=self.chunk_len, dim=-1).values  # [B, chunk_len], unsorted
        # Restore temporal (left-to-right) order
        selected_pos = topk_pos.sort(dim=-1).values  # [B, chunk_len]
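        # e.g. chunk_len = 3 with placeholders at positions [30, 31, 32]: topk -> [32, 31, 30], sort -> [30, 31, 32]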
# Gather
expanded_index = selected_pos.unsqueeze(-1).expand(-1, -1, H) # [B, chunk_len, H]
action_queries = last_hidden.gather(dim=1, index=expanded_index) # [B, chunk_len, H]
return action_queries
if __name__ == "__main__":
from omegaconf import OmegaConf
import debugpy
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_cotrain_oxe.yaml", help="Path to YAML config")
args, clipargs = parser.parse_known_args()
debugpy.listen(("0.0.0.0", 10092))
print("🔍 Rank 0 waiting for debugger attach on port 10092...")
debugpy.wait_for_client()
cfg = OmegaConf.load(args.config_yaml)
cfg.framework.action_model.action_hidden_dim = 2048
cfg.framework.qwenvl.base_vlm = "./playground/Pretrained_models/Florence-2-large"
# try get model
model = Qwenvl_OFT(cfg)
print(model)
# fake sample
image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
# Create a sample
sample = {
"action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16), # action_chunk, action_dim
"image": [image], # two views
"lang": "This is a fake instruction for testing.",
# "state" : np.random.uniform(-1, 1, size=(1, 7)).astype(np.float16), # chunk, state_dim
}
sample2 = {
"action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16), # action_chunk, action_dim
"image": [image], # two views
"lang": "For testing.",
# "state" : np.random.uniform(-1, 1, size=(1, 7)).astype(np.float16), # chunk, state_dim
}
batch = [sample, sample2] # batch size 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
forward_output = model(batch)
action_loss = forward_output['action_loss']
print(f"Action Loss: {action_loss.item()}")
    # test predict_action (it expects a list of example dicts, like forward)
    predict_output = model.predict_action([batch[0]])
    normalized_actions = predict_output['normalized_actions']
    print(f"Normalized Actions: {normalized_actions}")
# try forward model
    # could also use a fake sample, but fetching one from the dataloader is simpler here
from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn
vla_dataset_cfg = cfg.datasets.vla_data
dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
dataset,
batch_size=2,
num_workers=1, # For Debug
collate_fn=collate_fn,
)
    # Grab a single batch for the smoke test
    for batch in tqdm(train_dataloader, desc="Processing Batches"):
        break
    # Run a forward pass and an action prediction on the real dataloader batch
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    forward_output = model(batch)
    print(f"Action Loss (dataloader batch): {forward_output['action_loss'].item()}")
    predict_output = model.predict_action(batch)
    print(f"Predicted normalized actions shape: {predict_output['normalized_actions'].shape}")