# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
# Implemented by Jinhui YE (HKUST) in 2025.
"""
Qwen-GROOT Framework
A lightweight implementation that couples Qwen2.5-VL with a flow-matching head to directly predict continuous actions.
The flow-matching head is adapted from GR00T N1.5, with a simple MoE design inspired by PI_0.
"""
import sys
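# NOTE: environment-specific path below; adjust or remove for your setup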
sys.path.append("/mnt/data/fangyu/code/rewardmodel")
from tqdm import tqdm
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import copy
from starVLA.training.trainer_utils import initialize_overwatch
from deployment.model_server.tools.image_tools import to_pil_preserve
from transformers import AutoImageProcessor, AutoModel
from omegaconf import OmegaConf
logger = initialize_overwatch(__name__)
# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
IGNORE_INDEX = -100
from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.ActionModel_FM import ActionModelFM
from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig
from starVLA.dataloader.gr00t_lerobot.datasets import ACTION_REPRESENTATION_SLICES
from starVLA.training.trainer_utils.trainer_tools import resize_images
from starVLA.model.tools import FRAMEWORK_REGISTRY
####################################################
# ⚠️ Warning: This framework has been restructured and is NOT compatible with checkpoints created before 2025-10-20.
####################################################
@FRAMEWORK_REGISTRY.register("QwenLatent")
class QwenLatent(baseframework):
"""
Multimodal vision-language-action model.
Components:
    - Qwen2.5-VL interface for fused language/vision token embeddings
    - Flow-matching action head (ActionModelFM) that decodes continuous action chunks from a latent action embedding
Focus: Predict future continuous actions conditioned on images + instruction.
"""
@staticmethod
def _get_last_nonpad_indices(attention_mask: torch.Tensor) -> torch.Tensor:
"""
Return the index of the last non-padding token for each sequence.
Works for both tokenizer.padding_side == "left" and "right".
attention_mask: [B, T] with 1/True for real tokens and 0/False for pads.
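        Example (illustrative sanity check):
            >>> mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded
            ...                      [1, 1, 1, 0, 0]])  # right-padded
            >>> QwenLatent._get_last_nonpad_indices(mask).tolist()
            [4, 2]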
"""
if attention_mask is None:
raise ValueError("attention_mask cannot be None")
if attention_mask.dim() != 2:
raise ValueError(f"attention_mask must be 2D [B,T], got shape {tuple(attention_mask.shape)}")
# Find distance-from-end to last 1 by reversing sequence dimension.
# Example:
# - left pad: [0,0,1,1,1] -> flip -> [1,1,1,0,0] -> argmax = 0 -> last = T-1
# - right pad: [1,1,1,0,0] -> flip -> [0,0,1,1,1] -> argmax = 2 -> last = T-1-2 = 2
mask = attention_mask.to(dtype=torch.long)
rev_first_one = torch.flip(mask, dims=[1]).argmax(dim=1)
last_nonpad = mask.size(1) - 1 - rev_first_one
return last_nonpad
#
def __init__(
self,
config: Optional[dict] = None,
**kwargs,
) -> None:
"""
Construct all submodules and cache key configuration values.
Args:
config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
**kwargs: Reserved for future overrides (unused).
"""
super().__init__()
self.config = config
self.qwen_vl_interface = get_vlm_model(config=self.config)
        # Read the LLM hidden size from the VLM config; the number of VL layers is hard-coded for the current Qwen2.5-VL backbone
num_vl_layers, llm_hidden_size = 36, self.qwen_vl_interface.model.config.hidden_size
self.llm_hidden_size = llm_hidden_size
self.config.framework.qwenvl.vl_hidden_dim = llm_hidden_size
self.config.framework.qwenvl.num_vl_layers = num_vl_layers
action_model_cfg = getattr(self.config.framework, "action_model", None)
if action_model_cfg is not None:
action_model_kwargs = OmegaConf.to_container(action_model_cfg, resolve=True)
print(f"{action_model_kwargs=}")
self.action_model = ActionModelFM(ActionModelConfig(**action_model_kwargs))
else:
self.action_model = ActionModelFM(ActionModelConfig())
ckpt_path = getattr(self.config.framework.action_model, "ckpt_path", None)
if ckpt_path:
self.action_model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=True)
print(f"✅ loaded action model from {ckpt_path}")
print(f"action model loss mode: {self.action_model.config.loss_mode}")
# Dataset soft prompt for QwenVL (conditioning on dataset_id)
self.dataset_vocab_size = getattr(self.config.framework.action_model, "dataset_vocab_size", 256)
self.num_data_tokens = getattr(self.config.framework.qwenvl, "num_data_tokens", 32)
self.dataset_embed = nn.Embedding(
self.dataset_vocab_size,
llm_hidden_size * self.num_data_tokens,
)
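        # Each dataset_id maps to num_data_tokens soft-prompt vectors of size llm_hidden_size,
        # stored flattened in one embedding row and reshaped to [num_data_tokens, llm_hidden_size] at use time.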
# Learnable query token appended to VLM inputs (for action embedding)
self.query_token = nn.Parameter(torch.randn(1, 1, llm_hidden_size))
        # Use an MLP projector for extra capacity (2048 → 2048 → 1024)
action_hidden_size = self.action_model.config.hidden_size
self.action_embed_projector = nn.Sequential(
nn.Linear(llm_hidden_size, llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, action_hidden_size),
)
self.chunk_size = self.config.datasets.vla_data.chunk_size
self.num_history_steps = 0
self.use_state = self.action_model.use_state
# Multi-t sampling trick: sample K different t per example for the FM head
# to enlarge effective batch size without re-running the expensive VLM.
self.num_t_samples = getattr(self.config.framework.action_model, "num_t_samples", 1)
print(f"num_t_samples: {self.num_t_samples}")
def _maybe_log_align_stats(
self,
predicted_action_embeddings: torch.Tensor,
gt_action_embeddings: torch.Tensor,
) -> None:
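        """Log basic statistics of the predicted vs. GT action embeddings once (rank 0 only), to aid alignment debugging."""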
if getattr(self, "_align_stats_logged", False):
return
if torch.distributed.is_available() and torch.distributed.is_initialized():
if torch.distributed.get_rank() != 0:
return
with torch.no_grad():
pred = predicted_action_embeddings.float()
gt = gt_action_embeddings.float()
pred_norm = pred.norm(dim=-1).mean().item()
gt_norm = gt.norm(dim=-1).mean().item()
logger.info(
"Align stats: pred(mean=%.4f,std=%.4f,avg_norm=%.4f) "
"gt(mean=%.4f,std=%.4f,avg_norm=%.4f)",
pred.mean().item(),
pred.std().item(),
pred_norm,
gt.mean().item(),
gt.std().item(),
gt_norm,
)
self._align_stats_logged = True
def forward(
self,
examples: List[dict] = None,
**kwargs,
):
"""
Args:
examples: List[dict], each dict requires:
- image: List[PIL.Image] (multi-view)
- lang: str instruction
                - action: np.ndarray or list shaped [T, action_dim]
                - state: np.ndarray shaped [T, state_dim] (required when use_state is True)
                - dataset_id: int (optional, defaults to 0)
        Returns:
            dict:
                align_loss (torch.Tensor | None): L1 loss between predicted and GT action embeddings (None in "predict_only" mode).
                recon_loss (torch.Tensor | None): Flow-matching reconstruction loss conditioned on the GT embedding (None in "predict_only" mode).
                predict_loss (torch.Tensor): Flow-matching loss conditioned on the predicted embedding.
        """
        batch_images = [example["image"] for example in examples]  # [B, List[PIL.Image]]
instructions = [example["lang"] for example in examples] # [B, str]
actions = [example["action"] for example in examples] # label [B, L, action_dim]
states = [example["state"] for example in examples] if self.use_state else None # [B, L, state_dim] when state_use_action_chunk
dataset_ids = [example.get("dataset_id", 0) for example in examples]
        # Step 1: build QwenVL-formatted inputs
qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
images=batch_images,
instructions=instructions,
chunk_size=self.chunk_size,
)
# Prepend dataset soft prompt tokens to VLM inputs
if "input_ids" in qwen_inputs:
dataset_ids_tensor = torch.tensor(
dataset_ids, device=qwen_inputs["input_ids"].device, dtype=torch.long
)
ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
len(dataset_ids), self.num_data_tokens, self.llm_hidden_size
)
token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(qwen_inputs["input_ids"])
query_embeds = self.query_token.expand(len(dataset_ids), -1, -1)
qwen_inputs["inputs_embeds"] = torch.cat((ds_embeds, token_embeds, query_embeds), dim=1)
qwen_inputs.pop("input_ids")
if "attention_mask" in qwen_inputs:
prefix_mask = torch.ones(
(qwen_inputs["attention_mask"].shape[0], self.num_data_tokens),
device=qwen_inputs["attention_mask"].device,
dtype=qwen_inputs["attention_mask"].dtype,
)
query_mask = torch.ones(
(qwen_inputs["attention_mask"].shape[0], 1),
device=qwen_inputs["attention_mask"].device,
dtype=qwen_inputs["attention_mask"].dtype,
)
qwen_inputs["attention_mask"] = torch.cat(
(prefix_mask, qwen_inputs["attention_mask"], query_mask), dim=1
)
if "position_ids" in qwen_inputs:
prefix_pos = torch.arange(
self.num_data_tokens,
device=qwen_inputs["position_ids"].device,
dtype=qwen_inputs["position_ids"].dtype,
).unsqueeze(0).expand(qwen_inputs["position_ids"].shape[0], -1)
query_pos = (
torch.full(
(qwen_inputs["position_ids"].shape[0], 1),
qwen_inputs["position_ids"].shape[1] + self.num_data_tokens,
device=qwen_inputs["position_ids"].device,
dtype=qwen_inputs["position_ids"].dtype,
)
)
qwen_inputs["position_ids"] = torch.cat(
(prefix_pos, qwen_inputs["position_ids"] + self.num_data_tokens, query_pos), dim=1
)
with torch.autocast("cuda", dtype=torch.bfloat16):
qwenvl_outputs = self.qwen_vl_interface(
**qwen_inputs,
output_attentions=False,
output_hidden_states=True,
return_dict=True,
)
last_hidden_states = qwenvl_outputs.hidden_states[-1]
if "attention_mask" in qwen_inputs:
            # Find the index of the last non-padding token (works with both left and right padding)
last_token_indices = self._get_last_nonpad_indices(qwen_inputs["attention_mask"])
batch_indices = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
action_token_hidden = last_hidden_states[batch_indices, last_token_indices]
else:
action_token_hidden = last_hidden_states[:, -1, :]
predicted_action_embeddings = self.action_embed_projector(action_token_hidden).float() # [B, Action_Hidden]
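        # L2-normalize before passing to the FM head (the inference path applies the same normalization before decoding)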
predicted_action_embeddings = F.normalize(predicted_action_embeddings, p=2, dim=-1)
# Step 2: Action Expert Forward and Loss
loss_mode = getattr(self.action_model.config, "loss_mode", "full")
with torch.autocast("cuda", dtype=torch.float32):
actions_target = torch.as_tensor(np.array(actions), device=last_hidden_states.device, dtype=torch.float32)
# Multi-t sampling trick: expand the FM-head batch K times by sampling K
# independent t values per example. The expensive VLM embedding is computed
# only once and then tiled, so the extra cost is only in the lightweight FM head.
K = self.num_t_samples
def tile_batch(x: torch.Tensor, k: int) -> torch.Tensor:
"""Repeat tensor k times along dim-0, keeping all other dims intact."""
return x.repeat(k, *([1] * (x.dim() - 1)))
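            # tile_batch example: a [B, T, D] tensor with k=3 becomes [3*B, T, D], i.e. the whole batch tiled three times along dim 0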
if K > 1:
actions_target_fm = tile_batch(actions_target, K) # [K*B, T, D]
predicted_embeddings_fm = tile_batch(predicted_action_embeddings, K)
else:
actions_target_fm = actions_target
predicted_embeddings_fm = predicted_action_embeddings
B_fm = actions_target_fm.shape[0]
t = self.action_model._sample_fm_time(B_fm, device=actions_target.device, dtype=actions_target.dtype)
noise = torch.randn_like(actions_target_fm)
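            # The same (t, noise) pair is shared by the recon and predict losses below, so the two losses are directly comparable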
if loss_mode == "predict_only":
# Only predict_loss: skip align_loss and recon_loss
align_loss = None
recon_loss = None
predict_loss = self.action_model.recon_loss_from_embedding(
actions=actions_target_fm,
action_embedding=predicted_embeddings_fm,
t=t,
noise=noise,
)
else:
# Full mode: align + recon + predict
                # The state chunk is aligned with the action chunk (same length)
states_target = None
if self.use_state:
states_target = torch.as_tensor(np.array(states), device=last_hidden_states.device, dtype=torch.float32)
gt_action_embeddings = self.action_model.encode_actions(
actions=actions_target,
dataset_ids=dataset_ids,
state=states_target,
)
self._maybe_log_align_stats(predicted_action_embeddings, gt_action_embeddings)
# align_loss only needs the original (non-expanded) embeddings
align_loss = F.l1_loss(predicted_action_embeddings, gt_action_embeddings.float().detach())
gt_embeddings_fm = tile_batch(gt_action_embeddings, K) if K > 1 else gt_action_embeddings
recon_loss = self.action_model.recon_loss_from_embedding(
actions=actions_target_fm,
action_embedding=gt_embeddings_fm,
t=t,
noise=noise,
)
predict_loss = self.action_model.recon_loss_from_embedding(
actions=actions_target_fm,
action_embedding=predicted_embeddings_fm,
t=t,
noise=noise,
)
return {
"align_loss": align_loss,
"recon_loss": recon_loss,
"predict_loss": predict_loss,
}
@torch.inference_mode()
def predict_action( # TODO align predict_action with forward, make api more flexible
self,
examples: List[dict] = None,
embodiment_tag: Optional[str] = None,
**kwargs: str,
) -> np.ndarray:
"""
        Inference: a single forward pass produces a latent action embedding, which the action-model decoder
        turns into an action chunk (no iterative diffusion sampling).
        Steps:
            1. Resize images to training resolution (if specified)
            2. Encode with QwenVL (hidden states retained) and project the query-token hidden state
            3. Decode actions from the projected embedding with the action model
Args:
examples: List of example dicts containing image, lang, etc.
embodiment_tag: Optional embodiment tag (e.g., "franka", "oxe_rt1", "oxe_bridge").
If provided, will extract valid action dimensions based on ACTION_REPRESENTATION_SLICES.
If None, returns full unified action representation.
Returns:
dict:
                normalized_actions (np.ndarray): Shape [B, T, action_dim], decoded normalized actions.
If embodiment_tag is provided, shape is [B, T, valid_dim] where
valid_dim is determined by ACTION_REPRESENTATION_SLICES[embodiment_tag].
"""
        batch_images = [to_pil_preserve(example["image"]) for example in examples]  # [B, List[PIL.Image]]
instructions = [example["lang"] for example in examples] # [B, str]
        dataset_ids = [example.get("dataset_id", 0) for example in examples]
train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
if train_obs_image_size:
batch_images = resize_images(batch_images, target_size=train_obs_image_size)
        # Step 1: build QwenVL-formatted inputs
qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
images=batch_images,
instructions=instructions,
)
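        # NOTE: the soft-prompt construction below mirrors the training-side logic in forward(); keep the two paths in sync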
# Prepend dataset soft prompt tokens to VLM inputs
if "input_ids" in qwen_inputs:
dataset_ids_tensor = torch.tensor(
dataset_ids, device=qwen_inputs["input_ids"].device, dtype=torch.long
)
ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
len(dataset_ids), self.num_data_tokens, self.llm_hidden_size
)
token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(qwen_inputs["input_ids"])
query_embeds = self.query_token.expand(len(dataset_ids), -1, -1)
qwen_inputs["inputs_embeds"] = torch.cat((ds_embeds, token_embeds, query_embeds), dim=1)
qwen_inputs.pop("input_ids")
if "attention_mask" in qwen_inputs:
prefix_mask = torch.ones(
(qwen_inputs["attention_mask"].shape[0], self.num_data_tokens),
device=qwen_inputs["attention_mask"].device,
dtype=qwen_inputs["attention_mask"].dtype,
)
query_mask = torch.ones(
(qwen_inputs["attention_mask"].shape[0], 1),
device=qwen_inputs["attention_mask"].device,
dtype=qwen_inputs["attention_mask"].dtype,
)
qwen_inputs["attention_mask"] = torch.cat(
(prefix_mask, qwen_inputs["attention_mask"], query_mask), dim=1
)
if "position_ids" in qwen_inputs:
prefix_pos = torch.arange(
self.num_data_tokens,
device=qwen_inputs["position_ids"].device,
dtype=qwen_inputs["position_ids"].dtype,
).unsqueeze(0).expand(qwen_inputs["position_ids"].shape[0], -1)
query_pos = (
torch.full(
(qwen_inputs["position_ids"].shape[0], 1),
qwen_inputs["position_ids"].shape[1] + self.num_data_tokens,
device=qwen_inputs["position_ids"].device,
dtype=qwen_inputs["position_ids"].dtype,
)
)
qwen_inputs["position_ids"] = torch.cat(
(prefix_pos, qwen_inputs["position_ids"] + self.num_data_tokens, query_pos), dim=1
)
with torch.autocast("cuda", dtype=torch.bfloat16):
qwenvl_outputs = self.qwen_vl_interface(
**qwen_inputs,
output_attentions=False,
output_hidden_states=True,
return_dict=True,
)
last_hidden_states = qwenvl_outputs.hidden_states[-1]
if "attention_mask" in qwen_inputs:
            # Find the index of the last non-padding token (works with both left and right padding)
last_token_indices = self._get_last_nonpad_indices(qwen_inputs["attention_mask"])
batch_indices = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
action_token_hidden = last_hidden_states[batch_indices, last_token_indices]
else:
action_token_hidden = last_hidden_states[:, -1, :]
predicted_action_embeddings = self.action_embed_projector(action_token_hidden).float() # [B, Action_Hidden]
# L2 normalize before sending to decoder (consistent with training)
predicted_action_embeddings = F.normalize(predicted_action_embeddings, p=2, dim=-1)
        # Decode actions from the predicted embedding with the action-model decoder
with torch.autocast("cuda", dtype=torch.float32):
pred_actions = self.action_model.decode_actions(
predicted_action_embeddings,
chunk_size=self.chunk_size
)
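            # decode_actions maps the latent embedding to a [B, chunk_size, action_dim] chunk (see ActionModelFM)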
normalized_actions = pred_actions.detach().cpu().numpy()
        # If an embodiment_tag is provided, extract the valid action dimensions for that embodiment
if embodiment_tag is not None:
if embodiment_tag not in ACTION_REPRESENTATION_SLICES:
raise ValueError(
f"Unknown embodiment tag '{embodiment_tag}'. "
f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES.keys())}"
)
            # Look up the corresponding slice
target_slice = ACTION_REPRESENTATION_SLICES[embodiment_tag]
            # Extract the corresponding dimensions from the unified action representation
normalized_actions = normalized_actions[..., target_slice]
return {"normalized_actions": normalized_actions}
if __name__ == "__main__":
from omegaconf import OmegaConf
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--config_yaml", type=str,
default="/fsx/home/yfang/projects/LearnLatent/starVLA/config/training/starvla_train_qwenlatent_oxe.yaml",
help="Path to YAML config")
args, clipargs = parser.parse_known_args()
cfg = OmegaConf.load(args.config_yaml)
    # Instantiate the model from config
model = QwenLatent(cfg)
# ckpt="/mnt/petrelfs/yejinhui/Projects/llavavla/results/Checkpoints/1011_qwenpi/checkpoints/need_steps_10000_pytorch_model.pt"
# model = Qwen_PI.from_pretrained(ckpt)
print(model)
# fake sample
image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
# Create a sample
sample = {
"action": np.random.uniform(-1, 1, size=(15, 7)).astype(np.float16), # action_chunk, action_dim
"image": [image], # two views
"image_past_half": [image],
"image_past_one": [image],
"image_future": [image],
"lang": "put the ball on the table",
"state": np.random.uniform(-1, 1, size=(1, 8)).astype(np.float16), # chunk, state_dim
}
batch = [sample, sample] # batch size 2
device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
model = model.to(device)
forward_output = model(batch)
    align_loss = forward_output['align_loss']
    recon_loss = forward_output['recon_loss']
    predict_loss = forward_output['predict_loss']
    print(f"Align Loss: {align_loss.item() if align_loss is not None else None}")
    print(f"Recon Loss: {recon_loss.item() if recon_loss is not None else None}")
    print(f"Predict Loss: {predict_loss.item()}")
# # test predict action
# predict_output = model.predict_action([sample])
# normalized_actions = predict_output['normalized_actions']
# print(f"Unnormalized Action: {normalized_actions}")
# # Advance: try forward model with dataloader
# # can be fake sample, but here get from dataloader for simpler
# from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn
# vla_dataset_cfg = cfg.datasets.vla_data
# dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)
# from torch.utils.data import DataLoader
# train_dataloader = DataLoader(
# dataset,
# batch_size=2,
# num_workers=1, # For Debug
# collate_fn=collate_fn,
# )
# #
# for batch in tqdm(train_dataloader, desc="Processing Batches"):
# batch
# break
# # try get model
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)
# model(batch)
# action = model.predict_action(batch_images=[batch[0]["image"]], instructions=[batch[0]["lang"]])
# # fake state
# for ba in batch:
# ba["state"] = ba["action"][0][None]
# model(batch)
# action = model.predict_action(batch_images=[batch[0]["image"]], instructions=[batch[0]["lang"]], state=[batch[0]["state"]])