# Copyright 2025 starVLA community. All rights reserved. # Licensed under the MIT License, Version 1.0 (the "License"); # Implemented by Jinhui YE / HKUST University] in [2025]. """ Qwen-GROOT Framework A lightweight implementation that Qwen2.5-vl + Flow-matching head to directly predict continuous actions Flow-matching header is copyright from GR00T N1.5, but a sample MoE inspired by PI_0 """ import sys sys.path.append("/mnt/data/fangyu/code/rewardmodel") from typing import List from tqdm import tqdm from typing import List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from PIL import Image import copy from starVLA.training.trainer_utils import initialize_overwatch from deployment.model_server.tools.image_tools import to_pil_preserve from transformers import AutoImageProcessor, AutoModel from omegaconf import OmegaConf logger = initialize_overwatch(__name__) # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) IGNORE_INDEX = -100 from starVLA.model.framework.base_framework import baseframework from starVLA.model.modules.vlm import get_vlm_model from starVLA.model.modules.action_model.ActionModel_FM import ActionModelFM from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig from starVLA.dataloader.gr00t_lerobot.datasets import ACTION_REPRESENTATION_SLICES from starVLA.training.trainer_utils.trainer_tools import resize_images from starVLA.model.tools import FRAMEWORK_REGISTRY #################################################### # ⚠️ Warning: This framework has been restructured and is NOT compatible with checkpoints created before 2025-10-20. #################################################### @FRAMEWORK_REGISTRY.register("QwenLatent") class QwenLatent(baseframework): """ Multimodal vision-language-action model. Components: - Qwen2.5 VL interface for fused language/vision token embeddings - Layer-wise cross DiT diffusion head Focus: Predict future continuous actions conditioned on images + instruction. """ @staticmethod def _get_last_nonpad_indices(attention_mask: torch.Tensor) -> torch.Tensor: """ Return the index of the last non-padding token for each sequence. Works for both tokenizer.padding_side == "left" and "right". attention_mask: [B, T] with 1/True for real tokens and 0/False for pads. """ if attention_mask is None: raise ValueError("attention_mask cannot be None") if attention_mask.dim() != 2: raise ValueError(f"attention_mask must be 2D [B,T], got shape {tuple(attention_mask.shape)}") # Find distance-from-end to last 1 by reversing sequence dimension. # Example: # - left pad: [0,0,1,1,1] -> flip -> [1,1,1,0,0] -> argmax = 0 -> last = T-1 # - right pad: [1,1,1,0,0] -> flip -> [0,0,1,1,1] -> argmax = 2 -> last = T-1-2 = 2 mask = attention_mask.to(dtype=torch.long) rev_first_one = torch.flip(mask, dims=[1]).argmax(dim=1) last_nonpad = mask.size(1) - 1 - rev_first_one return last_nonpad # def __init__( self, config: Optional[dict] = None, **kwargs, ) -> None: """ Construct all submodules and cache key configuration values. Args: config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections. **kwargs: Reserved for future overrides (unused). """ super().__init__() self.config = config self.qwen_vl_interface = get_vlm_model(config=self.config) # dynamic get llm config num_vl_layers, llm_hidden_size = 36, self.qwen_vl_interface.model.config.hidden_size self.llm_hidden_size = llm_hidden_size self.config.framework.qwenvl.vl_hidden_dim = llm_hidden_size self.config.framework.qwenvl.num_vl_layers = num_vl_layers action_model_cfg = getattr(self.config.framework, "action_model", None) if action_model_cfg is not None: action_model_kwargs = OmegaConf.to_container(action_model_cfg, resolve=True) print(f"{action_model_kwargs=}") self.action_model = ActionModelFM(ActionModelConfig(**action_model_kwargs)) else: self.action_model = ActionModelFM(ActionModelConfig()) ckpt_path = getattr(self.config.framework.action_model, "ckpt_path", None) if ckpt_path: self.action_model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=True) print(f"✅ loaded action model from {ckpt_path}") print(f"action model loss mode: {self.action_model.config.loss_mode}") # Dataset soft prompt for QwenVL (conditioning on dataset_id) self.dataset_vocab_size = getattr(self.config.framework.action_model, "dataset_vocab_size", 256) self.num_data_tokens = getattr(self.config.framework.qwenvl, "num_data_tokens", 32) self.dataset_embed = nn.Embedding( self.dataset_vocab_size, llm_hidden_size * self.num_data_tokens, ) # Learnable query token appended to VLM inputs (for action embedding) self.query_token = nn.Parameter(torch.randn(1, 1, llm_hidden_size)) # 使用 MLP 投影器,增加表达能力(2048 → 2048 → 1024) action_hidden_size = self.action_model.config.hidden_size self.action_embed_projector = nn.Sequential( nn.Linear(llm_hidden_size, llm_hidden_size), nn.GELU(), nn.Linear(llm_hidden_size, action_hidden_size), ) self.chunk_size = self.config.datasets.vla_data.chunk_size self.num_history_steps = 0 self.use_state = self.action_model.use_state def _maybe_log_align_stats( self, predicted_action_embeddings: torch.Tensor, gt_action_embeddings: torch.Tensor, ) -> None: if getattr(self, "_align_stats_logged", False): return if torch.distributed.is_available() and torch.distributed.is_initialized(): if torch.distributed.get_rank() != 0: return with torch.no_grad(): pred = predicted_action_embeddings.float() gt = gt_action_embeddings.float() pred_norm = pred.norm(dim=-1).mean().item() gt_norm = gt.norm(dim=-1).mean().item() logger.info( "Align stats: pred(mean=%.4f,std=%.4f,avg_norm=%.4f) " "gt(mean=%.4f,std=%.4f,avg_norm=%.4f)", pred.mean().item(), pred.std().item(), pred_norm, gt.mean().item(), gt.std().item(), gt_norm, ) self._align_stats_logged = True def forward( self, examples: List[dict] = None, **kwargs, ): """ Args: examples: List[dict], each dict requires: - image: List[PIL.Image] (multi-view) - lang: str instruction - action: np.ndarray or list shaped [T, action_dim] Returns: dict: action_loss (torch.Tensor): Scalar diffusion noise prediction loss. """ batch_images = [example["image"] for example in examples] # [B,[PLT]] instructions = [example["lang"] for example in examples] # [B, str] actions = [example["action"] for example in examples] # label [B, L, action_dim] states = [example["state"] for example in examples] if self.use_state else None # [B, L, state_dim] when state_use_action_chunk dataset_ids = [example.get("dataset_id", 0) for example in examples] # Step 1: QWenVL input format qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs( images=batch_images, instructions=instructions, chunk_size=self.chunk_size, ) # Prepend dataset soft prompt tokens to VLM inputs if "input_ids" in qwen_inputs: dataset_ids_tensor = torch.tensor( dataset_ids, device=qwen_inputs["input_ids"].device, dtype=torch.long ) ds_embeds = self.dataset_embed(dataset_ids_tensor).view( len(dataset_ids), self.num_data_tokens, self.llm_hidden_size ) token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(qwen_inputs["input_ids"]) query_embeds = self.query_token.expand(len(dataset_ids), -1, -1) qwen_inputs["inputs_embeds"] = torch.cat((ds_embeds, token_embeds, query_embeds), dim=1) qwen_inputs.pop("input_ids") if "attention_mask" in qwen_inputs: prefix_mask = torch.ones( (qwen_inputs["attention_mask"].shape[0], self.num_data_tokens), device=qwen_inputs["attention_mask"].device, dtype=qwen_inputs["attention_mask"].dtype, ) query_mask = torch.ones( (qwen_inputs["attention_mask"].shape[0], 1), device=qwen_inputs["attention_mask"].device, dtype=qwen_inputs["attention_mask"].dtype, ) qwen_inputs["attention_mask"] = torch.cat( (prefix_mask, qwen_inputs["attention_mask"], query_mask), dim=1 ) if "position_ids" in qwen_inputs: prefix_pos = torch.arange( self.num_data_tokens, device=qwen_inputs["position_ids"].device, dtype=qwen_inputs["position_ids"].dtype, ).unsqueeze(0).expand(qwen_inputs["position_ids"].shape[0], -1) query_pos = ( torch.full( (qwen_inputs["position_ids"].shape[0], 1), qwen_inputs["position_ids"].shape[1] + self.num_data_tokens, device=qwen_inputs["position_ids"].device, dtype=qwen_inputs["position_ids"].dtype, ) ) qwen_inputs["position_ids"] = torch.cat( (prefix_pos, qwen_inputs["position_ids"] + self.num_data_tokens, query_pos), dim=1 ) with torch.autocast("cuda", dtype=torch.bfloat16): qwenvl_outputs = self.qwen_vl_interface( **qwen_inputs, output_attentions=False, output_hidden_states=True, return_dict=True, ) last_hidden_states = qwenvl_outputs.hidden_states[-1] if "attention_mask" in qwen_inputs: # 找到非 padding 的最后一个 token index(兼容 left/right padding) last_token_indices = self._get_last_nonpad_indices(qwen_inputs["attention_mask"]) batch_indices = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device) action_token_hidden = last_hidden_states[batch_indices, last_token_indices] else: action_token_hidden = last_hidden_states[:, -1, :] predicted_action_embeddings = self.action_embed_projector(action_token_hidden).float() # [B, Action_Hidden] predicted_action_embeddings = F.normalize(predicted_action_embeddings, p=2, dim=-1) # Step 2: Action Expert Forward and Loss loss_mode = getattr(self.action_model.config, "loss_mode", "full") with torch.autocast("cuda", dtype=torch.float32): actions_target = torch.as_tensor(np.array(actions), device=last_hidden_states.device, dtype=torch.float32) B = actions_target.shape[0] t = self.action_model._sample_fm_time(B, device=actions_target.device, dtype=actions_target.dtype) noise = torch.randn_like(actions_target) if loss_mode == "predict_only": # Only predict_loss: skip align_loss and recon_loss align_loss = None recon_loss = None predict_loss = self.action_model.recon_loss_from_embedding( actions=actions_target, action_embedding=predicted_action_embeddings, t=t, noise=noise, ) else: # Full mode: align + recon + predict # state chunk 与 action chunk 对齐(同长度) states_target = None if self.use_state: states_target = torch.as_tensor(np.array(states), device=last_hidden_states.device, dtype=torch.float32) gt_action_embeddings = self.action_model.encode_actions( actions=actions_target, dataset_ids=dataset_ids, state=states_target, ) self._maybe_log_align_stats(predicted_action_embeddings, gt_action_embeddings) align_loss = F.l1_loss(predicted_action_embeddings, gt_action_embeddings.float().detach()) recon_loss = self.action_model.recon_loss_from_embedding( actions=actions_target, action_embedding=gt_action_embeddings, t=t, noise=noise, ) predict_loss = self.action_model.recon_loss_from_embedding( actions=actions_target, action_embedding=predicted_action_embeddings, t=t, noise=noise, ) return { "align_loss": align_loss, "recon_loss": recon_loss, "predict_loss": predict_loss, } @torch.inference_mode() def predict_action( # TODO align predict_action with forward, make api more flexible self, examples: List[dict] = None, embodiment_tag: Optional[str] = None, **kwargs: str, ) -> np.ndarray: """ 推理:单次前向直接回归未来动作(无扩散采样)。 Steps: 1. Resize images to training resolution (if specified) 2. Encode with QwenVL (hidden states retained) Args: examples: List of example dicts containing image, lang, etc. embodiment_tag: Optional embodiment tag (e.g., "franka", "oxe_rt1", "oxe_bridge"). If provided, will extract valid action dimensions based on ACTION_REPRESENTATION_SLICES. If None, returns full unified action representation. Returns: dict: normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions. If embodiment_tag is provided, shape is [B, T, valid_dim] where valid_dim is determined by ACTION_REPRESENTATION_SLICES[embodiment_tag]. """ from deployment.model_server.tools.image_tools import to_pil_preserve batch_images = [to_pil_preserve(example["image"]) for example in examples] # [B,[PLT]] instructions = [example["lang"] for example in examples] # [B, str] dataset_ids = [example.get("dataset_id") for example in examples] train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None) if train_obs_image_size: batch_images = resize_images(batch_images, target_size=train_obs_image_size) # Step 1: QWenVL input format qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs( images=batch_images, instructions=instructions, ) # Prepend dataset soft prompt tokens to VLM inputs if "input_ids" in qwen_inputs: dataset_ids_tensor = torch.tensor( dataset_ids, device=qwen_inputs["input_ids"].device, dtype=torch.long ) ds_embeds = self.dataset_embed(dataset_ids_tensor).view( len(dataset_ids), self.num_data_tokens, self.llm_hidden_size ) token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(qwen_inputs["input_ids"]) query_embeds = self.query_token.expand(len(dataset_ids), -1, -1) qwen_inputs["inputs_embeds"] = torch.cat((ds_embeds, token_embeds, query_embeds), dim=1) qwen_inputs.pop("input_ids") if "attention_mask" in qwen_inputs: prefix_mask = torch.ones( (qwen_inputs["attention_mask"].shape[0], self.num_data_tokens), device=qwen_inputs["attention_mask"].device, dtype=qwen_inputs["attention_mask"].dtype, ) query_mask = torch.ones( (qwen_inputs["attention_mask"].shape[0], 1), device=qwen_inputs["attention_mask"].device, dtype=qwen_inputs["attention_mask"].dtype, ) qwen_inputs["attention_mask"] = torch.cat( (prefix_mask, qwen_inputs["attention_mask"], query_mask), dim=1 ) if "position_ids" in qwen_inputs: prefix_pos = torch.arange( self.num_data_tokens, device=qwen_inputs["position_ids"].device, dtype=qwen_inputs["position_ids"].dtype, ).unsqueeze(0).expand(qwen_inputs["position_ids"].shape[0], -1) query_pos = ( torch.full( (qwen_inputs["position_ids"].shape[0], 1), qwen_inputs["position_ids"].shape[1] + self.num_data_tokens, device=qwen_inputs["position_ids"].device, dtype=qwen_inputs["position_ids"].dtype, ) ) qwen_inputs["position_ids"] = torch.cat( (prefix_pos, qwen_inputs["position_ids"] + self.num_data_tokens, query_pos), dim=1 ) with torch.autocast("cuda", dtype=torch.bfloat16): qwenvl_outputs = self.qwen_vl_interface( **qwen_inputs, output_attentions=False, output_hidden_states=True, return_dict=True, ) last_hidden_states = qwenvl_outputs.hidden_states[-1] if "attention_mask" in qwen_inputs: # 找到非 padding 的最后一个 token index(兼容 left/right padding) last_token_indices = self._get_last_nonpad_indices(qwen_inputs["attention_mask"]) batch_indices = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device) action_token_hidden = last_hidden_states[batch_indices, last_token_indices] else: action_token_hidden = last_hidden_states[:, -1, :] predicted_action_embeddings = self.action_embed_projector(action_token_hidden).float() # [B, Action_Hidden] # L2 normalize before sending to decoder (consistent with training) predicted_action_embeddings = F.normalize(predicted_action_embeddings, p=2, dim=-1) # Step 4: 选择 decoder 进行推理 with torch.autocast("cuda", dtype=torch.float32): pred_actions = self.action_model.decode_actions( predicted_action_embeddings, chunk_size=self.chunk_size ) normalized_actions = pred_actions.detach().cpu().numpy() # 如果提供了 embodiment_tag,根据 tag 提取有效的动作维度 if embodiment_tag is not None: if embodiment_tag not in ACTION_REPRESENTATION_SLICES: raise ValueError( f"Unknown embodiment tag '{embodiment_tag}'. " f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES.keys())}" ) # 获取对应的 slice target_slice = ACTION_REPRESENTATION_SLICES[embodiment_tag] # 从统一表示中提取对应的维度 normalized_actions = normalized_actions[..., target_slice] return {"normalized_actions": normalized_actions} if __name__ == "__main__": from omegaconf import OmegaConf import argparse parser = argparse.ArgumentParser() parser.add_argument("--config_yaml", type=str, default="/fsx/home/yfang/projects/LearnLatent/starVLA/config/training/starvla_train_qwenlatent_oxe.yaml", help="Path to YAML config") args, clipargs = parser.parse_known_args() cfg = OmegaConf.load(args.config_yaml) # try get model model = QwenLatent(cfg) # ckpt="/mnt/petrelfs/yejinhui/Projects/llavavla/results/Checkpoints/1011_qwenpi/checkpoints/need_steps_10000_pytorch_model.pt" # model = Qwen_PI.from_pretrained(ckpt) print(model) # fake sample image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)) # Create a sample sample = { "action": np.random.uniform(-1, 1, size=(15, 7)).astype(np.float16), # action_chunk, action_dim "image": [image], # two views "image_past_half": [image], "image_past_one": [image], "image_future": [image], "lang": "put the ball on the table", "state": np.random.uniform(-1, 1, size=(1, 8)).astype(np.float16), # chunk, state_dim } batch = [sample, sample] # batch size 2 device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu") model = model.to(device) forward_output = model(batch) align_loss = forward_output['align_loss'] recon_loss = forward_output['recon_loss'] print(f"Align Loss: {align_loss.item()}") print(f"Recon Loss: {recon_loss.item()}") # # test predict action # predict_output = model.predict_action([sample]) # normalized_actions = predict_output['normalized_actions'] # print(f"Unnormalized Action: {normalized_actions}") # # Advance: try forward model with dataloader # # can be fake sample, but here get from dataloader for simpler # from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn # vla_dataset_cfg = cfg.datasets.vla_data # dataset = get_vla_dataset(data_cfg=vla_dataset_cfg) # from torch.utils.data import DataLoader # train_dataloader = DataLoader( # dataset, # batch_size=2, # num_workers=1, # For Debug # collate_fn=collate_fn, # ) # # # for batch in tqdm(train_dataloader, desc="Processing Batches"): # batch # break # # try get model # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # model = model.to(device) # model(batch) # action = model.predict_action(batch_images=[batch[0]["image"]], instructions=[batch[0]["lang"]]) # # fake state # for ba in batch: # ba["state"] = ba["action"][0][None] # model(batch) # action = model.predict_action(batch_images=[batch[0]["image"]], instructions=[batch[0]["lang"]], state=[batch[0]["state"]])