# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License (the "License");
# Implemented by [Junqiu YU / Fudan University] in [2025].
# Designed and merged by [Jinhui YE / HKUST University] in [2025].
"""
Qwen-GR00T Framework

A lightweight implementation that couples Qwen-VL with a flow-matching action head to
directly predict continuous actions. The flow-matching head is adapted from GR00T N1.5.
"""

import sys
from pathlib import Path

# Add workspace root to Python path if not already there
_workspace_root = Path(__file__).parent.parent.parent.parent
if str(_workspace_root) not in sys.path:
    sys.path.insert(0, str(_workspace_root))

from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from starVLA.training.trainer_utils import initialize_overwatch
from deployment.model_server.tools.image_tools import to_pil_preserve

logger = initialize_overwatch(__name__)

# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
IGNORE_INDEX = -100

from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.GR00T_ActionHeader import get_action_model, FlowmatchingActionHead
from starVLA.training.trainer_utils.trainer_tools import resize_images
from starVLA.model.tools import FRAMEWORK_REGISTRY


@FRAMEWORK_REGISTRY.register("QwenGR00T")
class Qwen_GR00T(baseframework):
    """
    Multimodal vision-language-action model.

    Components:
      - Qwen2.5-VL interface producing fused language/vision token embeddings
      - Optional dataset soft-prompt embedding that conditions the VLM on dataset identity
      - GR00T N1.5 flow-matching (DiT) head for future action sequence modeling

    Focus: predict future continuous actions conditioned on images + instruction.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """
        Construct all submodules and cache key configuration values.

        Args:
            config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
            **kwargs: Reserved for future overrides (unused).
        """
        super().__init__()
        self.config = config

        self.qwen_vl_interface = get_vlm_model(config=self.config)
        llm_hidden_size = self.qwen_vl_interface.model.config.hidden_size
        self.llm_hidden_size = llm_hidden_size

        # Align dims between the VLM and the action head (TODO: move to config?)
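        # The GR00T flow-matching (DiT) head cross-attends to the VLM's last hidden
        # states, so its cross-attention width must equal the LLM hidden size. The
        # value is written into the config here (rather than read from it) so that the
        # action head constructed below receives the correct dimension.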
        self.config.framework.action_model.diffusion_model_cfg.cross_attention_dim = llm_hidden_size
        self.action_model: FlowmatchingActionHead = get_action_model(config=self.config)

        # Cache the action-window sizes for later references
        self.future_action_window_size = config.framework.action_model.future_action_window_size
        self.past_action_window_size = config.framework.action_model.past_action_window_size
        self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

        # Dataset soft prompt: conditions the VLM on dataset identity
        self.dataset_vocab_size = getattr(self.config.framework.action_model, "dataset_vocab_size", 256)
        self.num_data_tokens = getattr(self.config.framework.qwenvl, "num_data_tokens", 0)
        if self.num_data_tokens > 0:
            self.dataset_embed = nn.Embedding(
                self.dataset_vocab_size,
                llm_hidden_size * self.num_data_tokens,
            )

    def forward(
        self,
        examples: List[dict] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Training forward pass.

        Encodes images + instruction with Qwen-VL, then computes the flow-matching
        loss of the action head on the target action chunk.

        Args:
            examples: List of dicts with keys "image", "lang", "action", and
                optionally "state" and "dataset_id".

        Returns:
            dict with "action_loss" (scalar tensor).
        """
        batch_images = [example["image"] for example in examples]  # [B, [PIL]]
        instructions = [example["lang"] for example in examples]  # [B, str]
        actions = [example["action"] for example in examples]  # label [B, len, 7]
        state = [example["state"] for example in examples] if "state" in examples[0] else None  # [B, 1, state_dim]
        dataset_ids = [example.get("dataset_id", 0) for example in examples]

        # Step 1: Build Qwen-VL inputs (chat template + image preprocessing)
        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)

        # Step 2: Prepend dataset soft-prompt tokens to the VLM inputs
        if self.num_data_tokens > 0 and "input_ids" in qwen_inputs:
            dataset_ids_tensor = torch.tensor(
                dataset_ids, device=qwen_inputs["input_ids"].device, dtype=torch.long
            )
            ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
                len(dataset_ids), self.num_data_tokens, self.llm_hidden_size
            )
            token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(qwen_inputs["input_ids"])
            qwen_inputs["inputs_embeds"] = torch.cat((ds_embeds, token_embeds), dim=1)
            qwen_inputs.pop("input_ids")
            if "attention_mask" in qwen_inputs:
                prefix_mask = torch.ones(
                    (qwen_inputs["attention_mask"].shape[0], self.num_data_tokens),
                    device=qwen_inputs["attention_mask"].device,
                    dtype=qwen_inputs["attention_mask"].dtype,
                )
                qwen_inputs["attention_mask"] = torch.cat(
                    (prefix_mask, qwen_inputs["attention_mask"]), dim=1
                )
            if "position_ids" in qwen_inputs:
                prefix_pos = torch.arange(
                    self.num_data_tokens,
                    device=qwen_inputs["position_ids"].device,
                    dtype=qwen_inputs["position_ids"].dtype,
                ).unsqueeze(0).expand(qwen_inputs["position_ids"].shape[0], -1)
                qwen_inputs["position_ids"] = torch.cat(
                    (prefix_pos, qwen_inputs["position_ids"] + self.num_data_tokens), dim=1
                )

        # Step 3: VLM forward to obtain fused vision-language hidden states
        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
        last_hidden = qwenvl_outputs.hidden_states[-1]  # [B, seq_len, H]

        # Step 4: Action expert forward and loss
        # Reuse the attention_mask built above (qwen_inputs is still in scope).
        # In cross-embodied training, batch sequences have very different lengths due to
        # varying camera counts (different image token counts per environment). Without
        # masking, the DiT cross-attention attends to padding tokens, injecting
        # task-dependent noise that causes unstable performance across environments.
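        # Illustrative example (hypothetical shapes): with num_data_tokens = 2 and two
        # samples whose image token counts differ, the padded attention_mask after
        # prepending the soft-prompt prefix might look like
        #   [[1, 1, 1, 1, 1, 1, 1, 1],
        #    [1, 1, 1, 1, 1, 1, 0, 0]]
        # where 1 marks real (prefix + vision + text) tokens and 0 marks padding; the
        # mask is passed to the action head so its cross-attention ignores the zeros.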
        encoder_attention_mask = qwen_inputs.get("attention_mask", None)

        with torch.autocast("cuda", dtype=torch.float32):
            actions = torch.tensor(
                np.array(actions), device=last_hidden.device, dtype=last_hidden.dtype
            )  # [B, T_full, action_dim]
            actions_target = actions[:, -(self.future_action_window_size + 1):, :]  # [B, future+1, action_dim]

            repeated_diffusion_steps = (
                self.config.trainer.get("repeated_diffusion_steps", 4)
                if self.config and self.config.trainer
                else 4
            )
            actions_target_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1)
            last_hidden_repeated = last_hidden.repeat(repeated_diffusion_steps, 1, 1)
            encoder_attention_mask_repeated = (
                encoder_attention_mask.repeat(repeated_diffusion_steps, 1)
                if encoder_attention_mask is not None
                else None
            )

            state_repeated = None
            if state is not None:
                state = torch.tensor(
                    np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
                )
                state_repeated = state.repeat(repeated_diffusion_steps, 1, 1)

            action_loss = self.action_model(
                last_hidden_repeated,
                actions_target_repeated,
                state_repeated,
                encoder_attention_mask=encoder_attention_mask_repeated,
            )

        return {"action_loss": action_loss}

    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict],
        **kwargs,
    ) -> Dict[str, np.ndarray]:
        """
        Sample a normalized action chunk for each example.

        Steps:
            1. Resize images to the training resolution (if specified)
            2. Build Qwen-VL inputs and optionally prepend dataset soft-prompt tokens
            3. Encode with Qwen-VL (hidden states retained)
            4. Sample actions with the flow-matching head and return the trajectory

        Returns:
            dict:
                normalized_actions (np.ndarray): Shape [B, T, action_dim],
                    diffusion-sampled normalized actions.
        """
        if not isinstance(examples, list):
            examples = [examples]

        batch_images = [to_pil_preserve(example["image"]) for example in examples]  # [B, [PIL]]
        instructions = [example["lang"] for example in examples]  # [B, str]
        state = [example["state"] for example in examples] if "state" in examples[0] else None  # [B, 1, state_dim]
        dataset_ids = [example.get("dataset_id", 0) for example in examples]

        # Step 1: Resize to the training resolution, if one is configured
        train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
        if train_obs_image_size:
            batch_images = resize_images(batch_images, target_size=train_obs_image_size)

        # Step 2: Build Qwen-VL inputs (chat template + image preprocessing)
        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)

        # Prepend dataset soft-prompt tokens to the VLM inputs (same as in forward)
        if self.num_data_tokens > 0 and "input_ids" in qwen_inputs:
            dataset_ids_tensor = torch.tensor(
                dataset_ids, device=qwen_inputs["input_ids"].device, dtype=torch.long
            )
            ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
                len(dataset_ids), self.num_data_tokens, self.llm_hidden_size
            )
            token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(qwen_inputs["input_ids"])
            qwen_inputs["inputs_embeds"] = torch.cat((ds_embeds, token_embeds), dim=1)
            qwen_inputs.pop("input_ids")
            if "attention_mask" in qwen_inputs:
                prefix_mask = torch.ones(
                    (qwen_inputs["attention_mask"].shape[0], self.num_data_tokens),
                    device=qwen_inputs["attention_mask"].device,
                    dtype=qwen_inputs["attention_mask"].dtype,
                )
                qwen_inputs["attention_mask"] = torch.cat(
                    (prefix_mask, qwen_inputs["attention_mask"]), dim=1
                )
            if "position_ids" in qwen_inputs:
                prefix_pos = torch.arange(
                    self.num_data_tokens,
                    device=qwen_inputs["position_ids"].device,
                    dtype=qwen_inputs["position_ids"].dtype,
                ).unsqueeze(0).expand(qwen_inputs["position_ids"].shape[0], -1)
                qwen_inputs["position_ids"] = torch.cat(
                    (prefix_pos, qwen_inputs["position_ids"] + self.num_data_tokens), dim=1
                )

        encoder_attention_mask = qwen_inputs.get("attention_mask", None)
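        # Note: for single-example inference there is no padding, so this mask is all
        # ones; it only matters when batching observations with different camera
        # counts (and hence different sequence lengths).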
        # Step 3: VLM forward to obtain fused vision-language hidden states
        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
        last_hidden = qwenvl_outputs.hidden_states[-1]  # [B, seq_len, H]

        if state is not None:
            state = torch.from_numpy(np.array(state)).to(last_hidden.device, dtype=last_hidden.dtype)

        # Step 4: Action expert samples the action chunk
        with torch.autocast("cuda", dtype=torch.float32):
            pred_actions = self.action_model.predict_action(
                last_hidden, state, encoder_attention_mask=encoder_attention_mask
            )

        normalized_actions = pred_actions.detach().cpu().numpy()
        return {"normalized_actions": normalized_actions}


if __name__ == "__main__":
    import argparse

    import debugpy
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="./examples/Robotwin/train_files/starvla_cotrain_robotwin.yaml",
        help="Path to YAML config",
    )
    args, clipargs = parser.parse_known_args()

    debugpy.listen(("0.0.0.0", 10092))
    print("🔍 Rank 0 waiting for debugger attach on port 10092...")
    debugpy.wait_for_client()

    # Debug override of the CLI argument
    args.config_yaml = "examples/MultiRobot/train_files/starvla_cotrain_multiRobot.yaml"
    cfg = OmegaConf.load(args.config_yaml)

    # Build the model from config
    # cfg.framework.action_model.action_hidden_dim = 2048
    # cfg.framework.qwenvl.base_vlm = "./playground/Pretrained_models/Florence-2-large"
    model: Qwen_GR00T = Qwen_GR00T(cfg)
    print(model)

    # Create fake samples
    image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
    sample = {
        "action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16),  # action_chunk, action_dim
        "image": [image],  # list of camera views (single view here)
        "lang": "Put all the toys in the child's room - the three board games (two on the bed and one on the table), the two jigsaw puzzles on the table, and the tennis ball on the table - inside the toy box on the table in the child's room.",
        # "state": np.random.uniform(-1, 1, size=(1, 7)).astype(np.float16),  # chunk, state_dim
    }
    sample2 = {
        "action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16),  # action_chunk, action_dim
        "image": [image],  # list of camera views (single view here)
        "lang": "Put all the toys in the child's room - the three board games (two on the bed and one on the table), the two jigsaw puzzles on the table, and the tennis ball on the table - inside the toy box on the table in the child's room.",
        # "state": np.random.uniform(-1, 1, size=(1, 7)).astype(np.float16),  # chunk, state_dim
    }
    batch = [sample, sample2]  # batch size 2

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    forward_output = model(batch)
    action_loss = forward_output['action_loss']
    print(f"Action Loss: {action_loss.item()}")

    # Test predict_action
    predict_output = model.predict_action(examples=[sample])  # , state=[batch[0]["state"]]
    normalized_actions = predict_output['normalized_actions']
    print(f"Normalized Action: {normalized_actions}")

    # Advanced: forward the model with batches from a real dataloader instead of fake samples
    from torch.utils.data import DataLoader

    from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn

    vla_dataset_cfg = cfg.datasets.vla_data
    cfg.datasets.vla_data.include_state = "False"
    dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)
    train_dataloader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=1,  # for debugging
        collate_fn=collate_fn,
    )

    # Forward the model over dataloader batches
    for batch in tqdm(train_dataloader, desc="Processing Batches"):
torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) model(batch) # break action = model.predict_action(examples=batch) print("Finished")