# Source: cross13tasks/code/model/framework/QwenFast.py
# Uploaded by Timsty via huggingface_hub (commit e94400c, verified).
# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
# Implemented by [Jinhui YE / HKUST University] in [2025].
"""
Qwen-Fast Framework
A lightweight implementation for autoregressive discrete action prediction conditioned on multi-view images + instruction.
fast tokenizer is copyright from physical-intelligence/fast
Key Points:
- Qwen2.5 vision-language backbone
- Unified action learning via next-token prediction (fast tokenizer)
- Autoregressive action tokens derived from discretized / symbolized continuous actions
Note: How to add special tokens to Qwen2.5:
download our model checkpoint with special tokens added: https://huggingface.co/StarVLA/Qwen2.5-VL-3B-Instruct-Action
"""
from typing import List
from tqdm import tqdm
from typing import List, Optional, Tuple, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from qwen_vl_utils import process_vision_info
from starVLA.training.trainer_utils import initialize_overwatch
from starVLA.model.tools import FRAMEWORK_REGISTRY
from deployment.model_server.tools.image_tools import to_pil_preserve
# Module-level logger created via starVLA's `initialize_overwatch` helper.
logger = initialize_overwatch(__name__)
# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels): label positions holding this value are
# excluded from the loss (it is also the default `ignore_index` of torch.nn.CrossEntropyLoss).
IGNORE_INDEX = -100
from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.fast_ActionHeader import get_action_model
@FRAMEWORK_REGISTRY.register("QwenFast")
class Qwenvl_Fast(baseframework):
"""
Multimodal vision-language-action model.
Components:
- Qwen2.5 VL interface for fused language/vision token embeddings
- Layer-wise QFormer for multi-layer feature aggregation
- DINO encoder for dense multi-view spatial tokens
- DiT diffusion head for future action sequence modeling
Focus: Predict future continuous actions conditioned on images + instruction.
"""
def __init__(
self,
config: Optional[dict] = None,
**kwargs,
) -> None:
"""
Construct all submodules and cache key configuration values.
Args:
config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
**kwargs: Reserved for future overrides (unused).
"""
super().__init__()
self.config = config
self.qwen_vl_interface = get_vlm_model(config=self.config)
self.action_model = get_action_model(config=self.config)
self.future_action_window_size = config.framework.action_model.future_action_window_size
self.past_action_window_size = config.framework.action_model.past_action_window_size
self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size
# self.hidden_dim = config.framework.action_model.action_hidden_dim
self.action_model.fast_tokenizer.time_horizon = self.future_action_window_size + 1
self.action_model.fast_tokenizer.action_dim = self.config.framework.action_model.action_dim
def forward(
self,
examples: List[dict] = None,
**kwargs,
) -> Tuple:
"""
Training forward: directly predict future actions via next-token prediction (no diffusion).
Flow:
1. Build QwenVL inputs (images + instruction tokens)
2. Extract hidden states from configured layer range
7. Predict action and compute L1 loss
Args:
examples: List[dict], each dict requires:
- image: List[PIL.Image] (multi-view)
- lang: str instruction
- action: np.ndarray or list shaped [T, action_dim]
**kwargs: Reserved.
Returns:
dict:
action_loss (torch.Tensor): Scalar diffusion noise prediction loss.
"""
batch_images = [example["image"] for example in examples] # [B, [PIL]]
instructions = [example["lang"] for example in examples] # [B, str]
actions = [example["action"] for example in examples] # label [B, len, 7]
# step 0: map_raw_action_to_vlm_action
batch_fast_tokens = self.action_model.encoder_action2fastoken(actions) # List[str]
# batch_fast_tokens = [self.fast_tokenizer(raw_action)[0] for raw_action in raw_actions]
vlm_action_tokens = [self.map_fast_token_to_vlm_action(fast_tokens) for fast_tokens in batch_fast_tokens]
# Step 1: QWenVL input format
qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions, solutions=vlm_action_tokens)
with torch.autocast("cuda", dtype=torch.bfloat16):
qwenvl_outputs = self.qwen_vl_interface(
**qwen_inputs,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
)
vlm_action_loss = qwenvl_outputs.loss
if vlm_action_loss is None or torch.isnan(vlm_action_loss):
vlm_action_loss = torch.tensor(0.0, device=self.qwen_vl_interface.model.device)
return {"action_loss": vlm_action_loss}
@torch.inference_mode()
def predict_action(
self,
examples: List[dict] = None,
**kwargs: str,
) -> np.ndarray:
"""
Inference: single forward pass to obtain future actions (no diffusion sampling).
# can be batch forward
Steps:
1. Resize images to training resolution (if specified)
2. Encode with QwenVL (hidden states retained)
6. Return normalized action trajectory
Returns:
dict:
normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions.
"""
if type(examples) is not list:
examples = [examples]
batch_images = [to_pil_preserve(example["image"]) for example in examples] # [B,[PLT]]
instructions = [example["lang"] for example in examples] # [B, str]
# train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
# if train_obs_image_size:
# batch_images = resize_images(batch_images, target_size=train_obs_image_size)
instructions = [instruction for instruction in instructions]
# Step 1: QWenVL input format
qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)
with torch.autocast("cuda", dtype=torch.bfloat16):
generated_ids = self.qwen_vl_interface.model.generate(
**qwen_inputs,
max_length=2048,
)
# --- Extract and decoder vlm_action to continue actions ---
# --- extrace token (index based on VLM) ---
batch_vlm_action_token_ids = self._extract_action_token_ids(generated_ids)
# --- map index to fast tokenizer index space ---
batch_fast_action_token_idx = self._decode_action_tokens(batch_vlm_action_token_ids)
# --- decode fast tokenizer index to action semantic ---
normalized_actions = self.action_model.fast_tokenizer.decode(batch_fast_action_token_idx)
return {"normalized_actions": normalized_actions}
def _extract_action_token_ids(
self,
generated_ids: torch.LongTensor,
) -> List[List[int]]:
"""
Extract action tokens (with offset) from the generated token sequence and return a 2D list:
ret[b] = [vlm_action_token_id_0, vlm_action_token_id_1, ...]
Rule: keep all tokens falling within [_ACTION_TOKEN_MIN, _ACTION_TOKEN_MAX] in order of appearance.
You may change it to "take only the first occurrence followed by continuous segment" as needed.
"""
act_min = self.qwen_vl_interface._ACTION_TOKEN_MIN
act_max = self.qwen_vl_interface._ACTION_TOKEN_MAX
mask = (generated_ids >= act_min) & (generated_ids <= act_max) # [B, L]
results = []
for b in range(generated_ids.size(0)):
idx = mask[b].nonzero(as_tuple=False).flatten()
if idx.numel() == 0:
results.append([])
continue
# all action tokens
tokens = generated_ids[b, idx].tolist()
results.append(tokens)
return results
def _decode_action_tokens(self, batch_vlm_tokens: List[List[int]]) -> List[Any]:
"""
Decode the offset VLM action token list back to fast tokenizer semantics.
fast_tokenizer.decode expects the original fast token id sequence (without offset).
"""
act_min = self.qwen_vl_interface._ACTION_TOKEN_MIN
batch_fast_token_ids = []
for seq in batch_vlm_tokens:
if not seq:
batch_fast_token_ids.append(None)
continue
fast_ids = [t - act_min for t in seq]
batch_fast_token_ids.append(fast_ids)
return batch_fast_token_ids
def map_fast_token_to_vlm_action(self, tokens) -> str:
"""Maps fast action tokens to the VLM action format.
Action token 0 is mapped to the string <robot_action_0> ... and so on
"""
return ''.join([f"<robot_action_{token}>" for token in tokens]) # you should add <robot_action_{token}> to VLM as special tokens,
if __name__ == "__main__":
from omegaconf import OmegaConf
import debugpy
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_cotrain_oxe.yaml", help="Path to YAML config")
args, clipargs = parser.parse_known_args()
debugpy.listen(("0.0.0.0", 10092))
print("🔍 Rank 0 waiting for debugger attach on port 10092...")
debugpy.wait_for_client()
args.config_yaml = "./examples/Robotwin/train_files/starvla_cotrain_robotwin.yaml"
cfg = OmegaConf.load(args.config_yaml)
# cfg.framework.qwenvl.base_vlm = "./playground/Pretrained_models/Qwen3-VL-4B-Instruct-Action"
# try get model
model = Qwenvl_Fast(cfg)
print(model)
# fake sample
image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
# Create a sample
sample = {
"action": np.random.uniform(-1, 1, size=(16, 14)).astype(np.float16), # action_chunk, action_dim
"image": [image, image], # two views
"lang": "This is a fake instruction for testing.",
# "state" : np.random.uniform(-1, 1, size=(1, 7)).astype(np.float16), # chunk, state_dim
}
sample2 = {
"action": np.random.uniform(-1, 1, size=(16, 14)).astype(np.float16), # action_chunk, action_dim
"image": [image, image], # two views
"lang": "The fake instruction for testing.",
# "state" : np.random.uniform(-1, 1, size=(1, 7)).astype(np.float16), # chunk, state_dim
}
batch = [sample, sample2] # batch size 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
forward_output = model(batch)
action_loss = forward_output['action_loss']
print(f"Action Loss: {action_loss.item()}")
# test predict action. for new model, it didn't learn to predict action token, so you would meet empty action
predict_output = model.predict_action([sample])
normalized_actions = predict_output['normalized_actions']
print(f"Unnormalized Action: {normalized_actions}")
# # test with dataloader
# # can be fake sample, but here get from dataloader for simpler
from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn
vla_dataset_cfg = cfg.datasets.vla_data
vla_dataset_cfg.video_backend = "torchvision_av"
dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
dataset,
batch_size=2,
num_workers=1, # For Debug
collate_fn=collate_fn,
)
for batch in tqdm(train_dataloader, desc="Processing Batches"):
batch
break
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model(batch)
pass
action = model.predict_action(batch[0])