# Copyright 2025 starVLA community. All rights reserved. # Licensed under the MIT License, Version 1.0 (the "License"); # Implemented by [Jinhui YE / HKUST University] in [2025]. """ Qwen-Fast Framework A lightweight implementation for autoregressive discrete action prediction conditioned on multi-view images + instruction. fast tokenizer is copyright from physical-intelligence/fast Key Points: - Qwen2.5 vision-language backbone - Unified action learning via next-token prediction (fast tokenizer) - Autoregressive action tokens derived from discretized / symbolized continuous actions Note: How to add special tokens to Qwen2.5: download our model checkpoint with special tokens added: https://huggingface.co/StarVLA/Qwen2.5-VL-3B-Instruct-Action """ from typing import List from tqdm import tqdm from typing import List, Optional, Tuple, Any import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from PIL import Image from qwen_vl_utils import process_vision_info from starVLA.training.trainer_utils import initialize_overwatch from starVLA.model.tools import FRAMEWORK_REGISTRY from deployment.model_server.tools.image_tools import to_pil_preserve logger = initialize_overwatch(__name__) # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) IGNORE_INDEX = -100 from starVLA.model.framework.base_framework import baseframework from starVLA.model.modules.vlm import get_vlm_model from starVLA.model.modules.action_model.fast_ActionHeader import get_action_model @FRAMEWORK_REGISTRY.register("QwenFast") class Qwenvl_Fast(baseframework): """ Multimodal vision-language-action model. Components: - Qwen2.5 VL interface for fused language/vision token embeddings - Layer-wise QFormer for multi-layer feature aggregation - DINO encoder for dense multi-view spatial tokens - DiT diffusion head for future action sequence modeling Focus: Predict future continuous actions conditioned on images + instruction. """ def __init__( self, config: Optional[dict] = None, **kwargs, ) -> None: """ Construct all submodules and cache key configuration values. Args: config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections. **kwargs: Reserved for future overrides (unused). """ super().__init__() self.config = config self.qwen_vl_interface = get_vlm_model(config=self.config) self.action_model = get_action_model(config=self.config) self.future_action_window_size = config.framework.action_model.future_action_window_size self.past_action_window_size = config.framework.action_model.past_action_window_size self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size # self.hidden_dim = config.framework.action_model.action_hidden_dim self.action_model.fast_tokenizer.time_horizon = self.future_action_window_size + 1 self.action_model.fast_tokenizer.action_dim = self.config.framework.action_model.action_dim def forward( self, examples: List[dict] = None, **kwargs, ) -> Tuple: """ Training forward: directly predict future actions via next-token prediction (no diffusion). Flow: 1. Build QwenVL inputs (images + instruction tokens) 2. Extract hidden states from configured layer range 7. Predict action and compute L1 loss Args: examples: List[dict], each dict requires: - image: List[PIL.Image] (multi-view) - lang: str instruction - action: np.ndarray or list shaped [T, action_dim] **kwargs: Reserved. Returns: dict: action_loss (torch.Tensor): Scalar diffusion noise prediction loss. """ batch_images = [example["image"] for example in examples] # [B, [PIL]] instructions = [example["lang"] for example in examples] # [B, str] actions = [example["action"] for example in examples] # label [B, len, 7] # step 0: map_raw_action_to_vlm_action batch_fast_tokens = self.action_model.encoder_action2fastoken(actions) # List[str] # batch_fast_tokens = [self.fast_tokenizer(raw_action)[0] for raw_action in raw_actions] vlm_action_tokens = [self.map_fast_token_to_vlm_action(fast_tokens) for fast_tokens in batch_fast_tokens] # Step 1: QWenVL input format qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions, solutions=vlm_action_tokens) with torch.autocast("cuda", dtype=torch.bfloat16): qwenvl_outputs = self.qwen_vl_interface( **qwen_inputs, output_attentions=False, output_hidden_states=False, return_dict=True, ) vlm_action_loss = qwenvl_outputs.loss if vlm_action_loss is None or torch.isnan(vlm_action_loss): vlm_action_loss = torch.tensor(0.0, device=self.qwen_vl_interface.model.device) return {"action_loss": vlm_action_loss} @torch.inference_mode() def predict_action( self, examples: List[dict] = None, **kwargs: str, ) -> np.ndarray: """ Inference: single forward pass to obtain future actions (no diffusion sampling). # can be batch forward Steps: 1. Resize images to training resolution (if specified) 2. Encode with QwenVL (hidden states retained) 6. Return normalized action trajectory Returns: dict: normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions. """ if type(examples) is not list: examples = [examples] batch_images = [to_pil_preserve(example["image"]) for example in examples] # [B,[PLT]] instructions = [example["lang"] for example in examples] # [B, str] # train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None) # if train_obs_image_size: # batch_images = resize_images(batch_images, target_size=train_obs_image_size) instructions = [instruction for instruction in instructions] # Step 1: QWenVL input format qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions) with torch.autocast("cuda", dtype=torch.bfloat16): generated_ids = self.qwen_vl_interface.model.generate( **qwen_inputs, max_length=2048, ) # --- Extract and decoder vlm_action to continue actions --- # --- extrace token (index based on VLM) --- batch_vlm_action_token_ids = self._extract_action_token_ids(generated_ids) # --- map index to fast tokenizer index space --- batch_fast_action_token_idx = self._decode_action_tokens(batch_vlm_action_token_ids) # --- decode fast tokenizer index to action semantic --- normalized_actions = self.action_model.fast_tokenizer.decode(batch_fast_action_token_idx) return {"normalized_actions": normalized_actions} def _extract_action_token_ids( self, generated_ids: torch.LongTensor, ) -> List[List[int]]: """ Extract action tokens (with offset) from the generated token sequence and return a 2D list: ret[b] = [vlm_action_token_id_0, vlm_action_token_id_1, ...] Rule: keep all tokens falling within [_ACTION_TOKEN_MIN, _ACTION_TOKEN_MAX] in order of appearance. You may change it to "take only the first occurrence followed by continuous segment" as needed. """ act_min = self.qwen_vl_interface._ACTION_TOKEN_MIN act_max = self.qwen_vl_interface._ACTION_TOKEN_MAX mask = (generated_ids >= act_min) & (generated_ids <= act_max) # [B, L] results = [] for b in range(generated_ids.size(0)): idx = mask[b].nonzero(as_tuple=False).flatten() if idx.numel() == 0: results.append([]) continue # all action tokens tokens = generated_ids[b, idx].tolist() results.append(tokens) return results def _decode_action_tokens(self, batch_vlm_tokens: List[List[int]]) -> List[Any]: """ Decode the offset VLM action token list back to fast tokenizer semantics. fast_tokenizer.decode expects the original fast token id sequence (without offset). """ act_min = self.qwen_vl_interface._ACTION_TOKEN_MIN batch_fast_token_ids = [] for seq in batch_vlm_tokens: if not seq: batch_fast_token_ids.append(None) continue fast_ids = [t - act_min for t in seq] batch_fast_token_ids.append(fast_ids) return batch_fast_token_ids def map_fast_token_to_vlm_action(self, tokens) -> str: """Maps fast action tokens to the VLM action format. Action token 0 is mapped to the string ... and so on """ return ''.join([f"" for token in tokens]) # you should add to VLM as special tokens, if __name__ == "__main__": from omegaconf import OmegaConf import debugpy import argparse parser = argparse.ArgumentParser() parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_cotrain_oxe.yaml", help="Path to YAML config") args, clipargs = parser.parse_known_args() debugpy.listen(("0.0.0.0", 10092)) print("🔍 Rank 0 waiting for debugger attach on port 10092...") debugpy.wait_for_client() args.config_yaml = "./examples/Robotwin/train_files/starvla_cotrain_robotwin.yaml" cfg = OmegaConf.load(args.config_yaml) # cfg.framework.qwenvl.base_vlm = "./playground/Pretrained_models/Qwen3-VL-4B-Instruct-Action" # try get model model = Qwenvl_Fast(cfg) print(model) # fake sample image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)) # Create a sample sample = { "action": np.random.uniform(-1, 1, size=(16, 14)).astype(np.float16), # action_chunk, action_dim "image": [image, image], # two views "lang": "This is a fake instruction for testing.", # "state" : np.random.uniform(-1, 1, size=(1, 7)).astype(np.float16), # chunk, state_dim } sample2 = { "action": np.random.uniform(-1, 1, size=(16, 14)).astype(np.float16), # action_chunk, action_dim "image": [image, image], # two views "lang": "The fake instruction for testing.", # "state" : np.random.uniform(-1, 1, size=(1, 7)).astype(np.float16), # chunk, state_dim } batch = [sample, sample2] # batch size 2 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) forward_output = model(batch) action_loss = forward_output['action_loss'] print(f"Action Loss: {action_loss.item()}") # test predict action. for new model, it didn't learn to predict action token, so you would meet empty action predict_output = model.predict_action([sample]) normalized_actions = predict_output['normalized_actions'] print(f"Unnormalized Action: {normalized_actions}") # # test with dataloader # # can be fake sample, but here get from dataloader for simpler from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn vla_dataset_cfg = cfg.datasets.vla_data vla_dataset_cfg.video_backend = "torchvision_av" dataset = get_vla_dataset(data_cfg=vla_dataset_cfg) from torch.utils.data import DataLoader train_dataloader = DataLoader( dataset, batch_size=2, num_workers=1, # For Debug collate_fn=collate_fn, ) for batch in tqdm(train_dataloader, desc="Processing Batches"): batch break device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) model(batch) pass action = model.predict_action(batch[0])