| """ |
| Qwen-Fast Framework |
| |
| A lightweight implementation for autoregressive discrete action prediction conditioned on multi-view images + instruction. |
| fast tokenizer is copyright from physical-intelligence/fast |
| |
| Key Points: |
| - Qwen2.5 vision-language backbone |
| - Unified action learning via next-token prediction (fast tokenizer) |
| - Autoregressive action tokens derived from discretized / symbolized continuous actions |
| |
| Note: How to add special tokens to Qwen2.5: |
| download our model checkpoint with special tokens added: https://huggingface.co/StarVLA/Qwen2.5-VL-3B-Instruct-Action |
| """ |
|
|
from typing import Any, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from qwen_vl_utils import process_vision_info
|
|
|
|
from starVLA.training.trainer_utils import initialize_overwatch
from starVLA.model.tools import FRAMEWORK_REGISTRY
from deployment.model_server.tools.image_tools import to_pil_preserve
|
|
logger = initialize_overwatch(__name__)

IGNORE_INDEX = -100
|
|
from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.fast_ActionHeader import get_action_model
|
|
|
|
@FRAMEWORK_REGISTRY.register("QwenFast")
class Qwenvl_Fast(baseframework):
    """
    Multimodal vision-language-action model.

    Components:
    - Qwen2.5-VL interface producing fused language/vision token embeddings
    - FAST tokenizer mapping continuous action chunks to discrete action tokens

    Focus: Predict future actions as discrete token sequences conditioned on images + instruction.
    """
|
|
    def __init__(
        self,
        config: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """
        Construct all submodules and cache key configuration values.

        Args:
            config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
            **kwargs: Reserved for future overrides (unused).
        """
        super().__init__()
        self.config = config
        self.qwen_vl_interface = get_vlm_model(config=self.config)
        self.action_model = get_action_model(config=self.config)

        self.future_action_window_size = config.framework.action_model.future_action_window_size
        self.past_action_window_size = config.framework.action_model.past_action_window_size
        self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

        self.action_model.fast_tokenizer.time_horizon = self.future_action_window_size + 1
        self.action_model.fast_tokenizer.action_dim = self.config.framework.action_model.action_dim
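        # e.g., past_action_window_size=0 and future_action_window_size=15 give
        # chunk_len = 0 + 1 + 15 = 16, and a FAST time_horizon of 15 + 1 = 16 steps.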
|
|
|
|
    def forward(
        self,
        examples: List[dict] = None,
        **kwargs,
    ) -> dict:
        """
        Training forward: directly predict future actions via next-token prediction (no diffusion).

        Flow:
            1. Encode continuous action chunks into FAST tokens and map them to VLM action-token strings
            2. Build QwenVL inputs (images + instruction + action-token targets)
            3. Run the VLM and take its next-token cross-entropy loss over the action tokens

        Args:
            examples: List[dict], each dict requires:
                - image: List[PIL.Image] (multi-view)
                - lang: str instruction
                - action: np.ndarray or list shaped [T, action_dim]
            **kwargs: Reserved.

        Returns:
            dict:
                action_loss (torch.Tensor): Scalar next-token cross-entropy loss.
        """
        batch_images = [example["image"] for example in examples]
        instructions = [example["lang"] for example in examples]
        actions = [example["action"] for example in examples]

        batch_fast_tokens = self.action_model.encoder_action2fastoken(actions)

        vlm_action_tokens = [self.map_fast_token_to_vlm_action(fast_tokens) for fast_tokens in batch_fast_tokens]

        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions, solutions=vlm_action_tokens)
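        # build_qwenvl_inputs is expected to apply the chat template to images +
        # instruction and append the action-token string as the supervised target.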
        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
            )

        vlm_action_loss = qwenvl_outputs.loss
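        # Fall back to a zero loss if the VLM returns no loss or a NaN
        # (e.g., a batch with no valid action labels), so the training step survives.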
        if vlm_action_loss is None or torch.isnan(vlm_action_loss):
            vlm_action_loss = torch.tensor(0.0, device=self.qwen_vl_interface.model.device)

        return {"action_loss": vlm_action_loss}
|
|
    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict] = None,
        **kwargs,
    ) -> dict:
        """
        Inference: generate action tokens with the VLM and decode them into actions (no diffusion sampling).
        Supports batched inputs.

        Steps:
            1. Build QwenVL inputs from images + instructions
            2. Autoregressively generate token ids with the VLM
            3. Extract the action tokens, strip the vocabulary offset, and decode with the FAST tokenizer
            4. Return the normalized action trajectory

        Returns:
            dict:
                normalized_actions (np.ndarray): Shape [B, T, action_dim], normalized actions.
        """
        if not isinstance(examples, list):
            examples = [examples]
        batch_images = [to_pil_preserve(example["image"]) for example in examples]
        instructions = [example["lang"] for example in examples]
|
|
|
|
        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)
|
|
        with torch.autocast("cuda", dtype=torch.bfloat16):
            generated_ids = self.qwen_vl_interface.model.generate(
                **qwen_inputs,
                max_length=2048,
            )
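        # max_length bounds the total sequence (prompt + generated action tokens);
        # generation also stops earlier at the EOS token.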
| |
        batch_vlm_action_token_ids = self._extract_action_token_ids(generated_ids)

        batch_fast_action_token_idx = self._decode_action_tokens(batch_vlm_action_token_ids)

        normalized_actions = self.action_model.fast_tokenizer.decode(batch_fast_action_token_idx)
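        # normalized_actions: the de-offset FAST ids decoded back into a normalized
        # continuous action chunk, per the contract above shaped [B, time_horizon, action_dim].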
|
| return {"normalized_actions": normalized_actions} |
|
|
    def _extract_action_token_ids(
        self,
        generated_ids: torch.LongTensor,
    ) -> List[List[int]]:
        """
        Extract action tokens (with offset) from the generated token sequences and return a 2D list:
            ret[b] = [vlm_action_token_id_0, vlm_action_token_id_1, ...]
        Rule: keep all tokens falling within [_ACTION_TOKEN_MIN, _ACTION_TOKEN_MAX] in order of appearance.
        You may change this to keep only the first contiguous segment, as needed.
        """
        act_min = self.qwen_vl_interface._ACTION_TOKEN_MIN
        act_max = self.qwen_vl_interface._ACTION_TOKEN_MAX
        mask = (generated_ids >= act_min) & (generated_ids <= act_max)
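        # Boolean mask over the generated ids marking positions that fall inside
        # the reserved action-token id range.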
        results = []
        for b in range(generated_ids.size(0)):
            idx = mask[b].nonzero(as_tuple=False).flatten()
            if idx.numel() == 0:
                results.append([])
                continue

            tokens = generated_ids[b, idx].tolist()
            results.append(tokens)
        return results
|
|
    def _decode_action_tokens(self, batch_vlm_tokens: List[List[int]]) -> List[Any]:
        """
        Decode the offset VLM action token lists back to FAST tokenizer semantics.
        fast_tokenizer.decode expects the original FAST token id sequences (without offset).
        """
        act_min = self.qwen_vl_interface._ACTION_TOKEN_MIN
        batch_fast_token_ids = []
        for seq in batch_vlm_tokens:
            if not seq:
                batch_fast_token_ids.append(None)
                continue
            fast_ids = [t - act_min for t in seq]

            batch_fast_token_ids.append(fast_ids)

        return batch_fast_token_ids
|
|
    def map_fast_token_to_vlm_action(self, tokens) -> str:
        """Map FAST action tokens to the VLM action-token string format.
        Action token 0 is mapped to the string "<robot_action_0>", and so on.
        """
        return ''.join([f"<robot_action_{token}>" for token in tokens])
|
|
|
|
if __name__ == "__main__":
    from omegaconf import OmegaConf
    import debugpy
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_cotrain_oxe.yaml", help="Path to YAML config")
    args, clipargs = parser.parse_known_args()

    debugpy.listen(("0.0.0.0", 10092))
    print("🔍 Rank 0 waiting for debugger attach on port 10092...")
    debugpy.wait_for_client()
    # Debug override of the CLI default config path:
    args.config_yaml = "./examples/Robotwin/train_files/starvla_cotrain_robotwin.yaml"
    cfg = OmegaConf.load(args.config_yaml)
|
|
    model = Qwenvl_Fast(cfg)
    print(model)

    image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
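    # Dummy observation: two identical random 224x224 frames stand in for a
    # multi-view input; the target is a 16-step, 14-dim action chunk.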
    sample = {
        "action": np.random.uniform(-1, 1, size=(16, 14)).astype(np.float16),
        "image": [image, image],
        "lang": "This is a fake instruction for testing.",
    }

    sample2 = {
        "action": np.random.uniform(-1, 1, size=(16, 14)).astype(np.float16),
        "image": [image, image],
        "lang": "The fake instruction for testing.",
    }

    batch = [sample, sample2]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    forward_output = model(batch)
    action_loss = forward_output['action_loss']
    print(f"Action Loss: {action_loss.item()}")
|
|
    predict_output = model.predict_action([sample])
    normalized_actions = predict_output['normalized_actions']
    print(f"Normalized Actions: {normalized_actions}")
|
|
|
|
|
|
|
|
    from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn

    vla_dataset_cfg = cfg.datasets.vla_data
    vla_dataset_cfg.video_backend = "torchvision_av"
    dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)

    from torch.utils.data import DataLoader

    train_dataloader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=1,
        collate_fn=collate_fn,
    )

    # Smoke test: pull one real batch, run a training forward, then inference.
    for batch in tqdm(train_dataloader, desc="Processing Batches"):
        break
    model(batch)
    action = model.predict_action(batch[0])
|
|