"""
Qwen-GR00T Framework

A lightweight implementation that couples Qwen-VL with a flow-matching action head
to directly predict continuous actions. The flow-matching action head is adapted
from GR00T N1.5; its original copyright belongs to that project.
"""
import sys
from pathlib import Path

# Make the workspace root importable when this file is executed directly.
_workspace_root = Path(__file__).parent.parent.parent.parent
if str(_workspace_root) not in sys.path:
    sys.path.insert(0, str(_workspace_root))

from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from starVLA.training.trainer_utils import initialize_overwatch
from deployment.model_server.tools.image_tools import to_pil_preserve

logger = initialize_overwatch(__name__)

IGNORE_INDEX = -100

from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.GR00T_ActionHeader import get_action_model, FlowmatchingActionHead
from starVLA.training.trainer_utils.trainer_tools import resize_images
from starVLA.model.tools import FRAMEWORK_REGISTRY
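
# Expected configuration layout (a hedged sketch: the field names come from the
# attribute accesses in the class below; the values mirror this file's defaults
# and demo settings and are illustrative only):
#
#   framework:
#     qwenvl:
#       num_data_tokens: 0            # >0 enables learned dataset-ID prefix tokens
#     action_model:
#       future_action_window_size: 15
#       past_action_window_size: 0
#       dataset_vocab_size: 256
#       diffusion_model_cfg:
#         cross_attention_dim: ...    # overwritten with the LLM hidden size at init
#   trainer:
#     repeated_diffusion_steps: 4
#   datasets:
#     vla_data:
#       image_size: ...               # optional; used to resize images at inference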


@FRAMEWORK_REGISTRY.register("QwenGR00T")
class Qwen_GR00T(baseframework):
    """
    Multimodal vision-language-action model.

    Components:
        - Qwen2.5-VL interface producing fused language/vision token embeddings
        - Optional learned dataset-ID prefix embeddings prepended to the token sequence
        - Flow-matching action head (from GR00T N1.5) for future action sequence modeling

    Focus: predict future continuous actions conditioned on images + an instruction.
    """
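
    # Minimal usage sketch (assumes an OmegaConf config shaped as outlined above;
    # see also the smoke test at the bottom of this file):
    #
    #   cfg = OmegaConf.load("path/to/train_config.yaml")
    #   model = Qwen_GR00T(cfg).to("cuda")
    #   loss = model(batch)["action_loss"]                    # training step
    #   out = model.predict_action(examples=batch)            # inference
    #   actions = out["normalized_actions"]                   # [B, T, action_dim]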

    def __init__(
        self,
        config: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """
        Construct all submodules and cache key configuration values.

        Args:
            config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
            **kwargs: Reserved for future overrides (unused).
        """
        super().__init__()
        self.config = config
        self.qwen_vl_interface = get_vlm_model(config=self.config)
        llm_hidden_size = self.qwen_vl_interface.model.config.hidden_size
        self.llm_hidden_size = llm_hidden_size

        # The action head cross-attends to the VLM hidden states, so its
        # cross-attention dimension must match the LLM hidden size.
        self.config.framework.action_model.diffusion_model_cfg.cross_attention_dim = llm_hidden_size

        self.action_model: FlowmatchingActionHead = get_action_model(config=self.config)

        # Action chunk layout: past actions, the current step, and future actions.
        self.future_action_window_size = config.framework.action_model.future_action_window_size
        self.past_action_window_size = config.framework.action_model.past_action_window_size
        self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

        # Optional learned dataset-ID prefix tokens prepended to the VLM input sequence.
        self.dataset_vocab_size = getattr(self.config.framework.action_model, "dataset_vocab_size", 256)
        self.num_data_tokens = getattr(self.config.framework.qwenvl, "num_data_tokens", 0)
        if self.num_data_tokens > 0:
            self.dataset_embed = nn.Embedding(
                self.dataset_vocab_size,
                llm_hidden_size * self.num_data_tokens,
            )

    def forward(
        self,
        examples: List[dict] = None,
        **kwargs,
    ) -> dict:
        """
        Training forward pass.

        Steps:
            1. Build Qwen-VL inputs from the batched images + instructions.
            2. Optionally prepend learned dataset-ID prefix embeddings.
            3. Encode with Qwen-VL and take the last-layer hidden states as conditioning.
            4. Repeat the batch for multiple flow-matching noise/timestep draws.
            5. Compute the flow-matching loss on the target action chunk.

        Returns:
            dict: {"action_loss": scalar flow-matching loss}
        """
        batch_images = [example["image"] for example in examples]
        instructions = [example["lang"] for example in examples]
        actions = [example["action"] for example in examples]

        state = [example["state"] for example in examples] if "state" in examples[0] else None
        dataset_ids = [example.get("dataset_id", 0) for example in examples]

        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)

        # Prepend learned dataset-ID prefix embeddings and extend the attention
        # mask / position ids accordingly.
        if self.num_data_tokens > 0 and "input_ids" in qwen_inputs:
            dataset_ids_tensor = torch.tensor(
                dataset_ids, device=qwen_inputs["input_ids"].device, dtype=torch.long
            )
            ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
                len(dataset_ids), self.num_data_tokens, self.llm_hidden_size
            )
            token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(qwen_inputs["input_ids"])
            qwen_inputs["inputs_embeds"] = torch.cat((ds_embeds, token_embeds), dim=1)
            qwen_inputs.pop("input_ids")
            if "attention_mask" in qwen_inputs:
                prefix_mask = torch.ones(
                    (qwen_inputs["attention_mask"].shape[0], self.num_data_tokens),
                    device=qwen_inputs["attention_mask"].device,
                    dtype=qwen_inputs["attention_mask"].dtype,
                )
                qwen_inputs["attention_mask"] = torch.cat(
                    (prefix_mask, qwen_inputs["attention_mask"]), dim=1
                )
            if "position_ids" in qwen_inputs:
                prefix_pos = torch.arange(
                    self.num_data_tokens,
                    device=qwen_inputs["position_ids"].device,
                    dtype=qwen_inputs["position_ids"].dtype,
                ).unsqueeze(0).expand(qwen_inputs["position_ids"].shape[0], -1)
                qwen_inputs["position_ids"] = torch.cat(
                    (prefix_pos, qwen_inputs["position_ids"] + self.num_data_tokens), dim=1
                )
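
        # Illustrative result of the splice (assuming num_data_tokens = 2, batch size B,
        # prompt length L, hidden size H):
        #   inputs_embeds:  [B, 2 + L, H]    attention_mask: [B, 2 + L]
        #   position_ids (if present) gain a prefix 0..1 and the original ids shift by 2.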

        # Encode with the VLM; keep hidden states for action-head conditioning.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )

        last_hidden = qwenvl_outputs.hidden_states[-1]
        encoder_attention_mask = qwen_inputs.get("attention_mask", None)

        with torch.autocast("cuda", dtype=torch.float32):
            actions = torch.tensor(
                np.array(actions), device=last_hidden.device, dtype=last_hidden.dtype
            )
            # Supervise only the current step + future actions of the chunk.
            actions_target = actions[:, -(self.future_action_window_size + 1):, :]

            repeated_diffusion_steps = (
                self.config.trainer.get("repeated_diffusion_steps", 4) if self.config and self.config.trainer else 4
            )
            actions_target_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1)
            last_hidden_repeated = last_hidden.repeat(repeated_diffusion_steps, 1, 1)
            encoder_attention_mask_repeated = (
                encoder_attention_mask.repeat(repeated_diffusion_steps, 1)
                if encoder_attention_mask is not None else None
            )
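
            # Repeating the batch gives the flow-matching head several independent
            # noise/timestep draws per VLM forward pass; the effective action-head
            # batch size becomes B * repeated_diffusion_steps.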

            state_repeated = None
            if state is not None:
                state = torch.tensor(
                    np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
                )
                state_repeated = state.repeat(repeated_diffusion_steps, 1, 1)

            action_loss = self.action_model(
                last_hidden_repeated,
                actions_target_repeated,
                state_repeated,
                encoder_attention_mask=encoder_attention_mask_repeated,
            )

        return {"action_loss": action_loss}

    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict],
        **kwargs,
    ) -> dict:
        """
        Inference: sample a normalized action trajectory for each example.

        Steps:
            1. Resize images to the training resolution (if specified in the config).
            2. Optionally prepend learned dataset-ID prefix embeddings.
            3. Encode with Qwen-VL (hidden states retained).
            4. Sample from the flow-matching head and return the normalized trajectory.

        Returns:
            dict:
                normalized_actions (np.ndarray): Shape [B, T, action_dim], sampled normalized actions.
        """
        if not isinstance(examples, list):
            examples = [examples]
        batch_images = [to_pil_preserve(example["image"]) for example in examples]
        instructions = [example["lang"] for example in examples]

        state = [example["state"] for example in examples] if "state" in examples[0] else None
        dataset_ids = [example.get("dataset_id", 0) for example in examples]

        # Match the image resolution used during training, if configured.
        train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
        if train_obs_image_size:
            batch_images = resize_images(batch_images, target_size=train_obs_image_size)

        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)

        # Same dataset-ID prefix splice as in forward(); see the shape note there.
        if self.num_data_tokens > 0 and "input_ids" in qwen_inputs:
            dataset_ids_tensor = torch.tensor(
                dataset_ids, device=qwen_inputs["input_ids"].device, dtype=torch.long
            )
            ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
                len(dataset_ids), self.num_data_tokens, self.llm_hidden_size
            )
            token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(qwen_inputs["input_ids"])
            qwen_inputs["inputs_embeds"] = torch.cat((ds_embeds, token_embeds), dim=1)
            qwen_inputs.pop("input_ids")
            if "attention_mask" in qwen_inputs:
                prefix_mask = torch.ones(
                    (qwen_inputs["attention_mask"].shape[0], self.num_data_tokens),
                    device=qwen_inputs["attention_mask"].device,
                    dtype=qwen_inputs["attention_mask"].dtype,
                )
                qwen_inputs["attention_mask"] = torch.cat(
                    (prefix_mask, qwen_inputs["attention_mask"]), dim=1
                )
            if "position_ids" in qwen_inputs:
                prefix_pos = torch.arange(
                    self.num_data_tokens,
                    device=qwen_inputs["position_ids"].device,
                    dtype=qwen_inputs["position_ids"].dtype,
                ).unsqueeze(0).expand(qwen_inputs["position_ids"].shape[0], -1)
                qwen_inputs["position_ids"] = torch.cat(
                    (prefix_pos, qwen_inputs["position_ids"] + self.num_data_tokens), dim=1
                )

        encoder_attention_mask = qwen_inputs.get("attention_mask", None)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )

        last_hidden = qwenvl_outputs.hidden_states[-1]

        state = torch.from_numpy(np.array(state)).to(last_hidden.device, dtype=last_hidden.dtype) if state is not None else None

        with torch.autocast("cuda", dtype=torch.float32):
            pred_actions = self.action_model.predict_action(
                last_hidden, state, encoder_attention_mask=encoder_attention_mask
            )
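
        # pred_actions is expected to be [B, future_action_window_size + 1, action_dim]
        # (the same horizon supervised during training) and lies in the normalized action space.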

        normalized_actions = pred_actions.detach().cpu().numpy()
        return {"normalized_actions": normalized_actions}


if __name__ == "__main__":
    import argparse

    import debugpy
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="./examples/Robotwin/train_files/starvla_cotrain_robotwin.yaml",
        help="Path to YAML config",
    )
    args, clipargs = parser.parse_known_args()

    # Block until a debugger attaches (debug-only smoke test).
    debugpy.listen(("0.0.0.0", 10092))
    print("🔍 Rank 0 waiting for debugger attach on port 10092...")
    debugpy.wait_for_client()

    # Debug override of the CLI argument.
    args.config_yaml = "examples/MultiRobot/train_files/starvla_cotrain_multiRobot.yaml"
    cfg = OmegaConf.load(args.config_yaml)

    model: Qwen_GR00T = Qwen_GR00T(cfg)
    print(model)

    # Two dummy training examples forming a batch of two. Each example provides a
    # list of PIL images, an instruction string, and a dummy action chunk of shape (16, 7).
    image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))

    sample = {
        "action": np.random.uniform(-1, 1, size=(16, 7)).astype(np.float16),
        "image": [image],
        "lang": "Put all the toys in the child's room - the three board games (two on the bed and one on the table), the two jigsaw puzzles on the table, and the tennis ball on the table - inside the toy box on the table in the child's room.",
    }
    sample2 = dict(sample)  # second, identical example
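
    # Optional per-example keys also consumed by forward() / predict_action():
    #   "state":      state array passed as extra conditioning to the action head
    #   "dataset_id": integer id for the learned dataset prefix embedding (defaults to 0)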

    batch = [sample, sample2]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    forward_output = model(batch)
    action_loss = forward_output["action_loss"]
    print(f"Action Loss: {action_loss.item()}")

    predict_output = model.predict_action(examples=[sample])
    normalized_actions = predict_output["normalized_actions"]
    print(f"Normalized Actions: {normalized_actions}")

    # Smoke test against the real dataloader.
    from torch.utils.data import DataLoader

    from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn

    vla_dataset_cfg = cfg.datasets.vla_data
    cfg.datasets.vla_data.include_state = "False"
    dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)

    train_dataloader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=1,
        collate_fn=collate_fn,
    )

    for batch in tqdm(train_dataloader, desc="Processing Batches"):
        model(batch)
        action = model.predict_action(examples=batch)

    print("Finished")