"""
Qwen-Adapter Framework

A lightweight implementation that combines Qwen-VL with an adapter action head to
directly predict continuous actions. The action head is adapted from VLA-Adapter.
"""
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from deployment.model_server.tools.image_tools import to_pil_preserve
from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.action_model.VLA_AdapterHeader import get_action_model, VLA_Adapter_L1RegressionActionHead
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.vlm.QWen3 import IMAGE_TOKEN_INDEX, VIDEO_TOKEN_INDEX
from starVLA.model.tools import FRAMEWORK_REGISTRY
from starVLA.training.trainer_utils import initialize_overwatch
from starVLA.training.trainer_utils.trainer_tools import resize_images

logger = initialize_overwatch(__name__)

IGNORE_INDEX = -100


def get_image_token_counts(batch_inputs):
    """Per-sample count and span of image placeholder tokens in `input_ids`."""
    image_token_mask = batch_inputs["input_ids"] == IMAGE_TOKEN_INDEX

    # Number of image placeholder tokens in each sample.
    num_tokens_per_sample = torch.sum(image_token_mask, dim=1)

    # Last image token: the cumulative sum reaches its maximum there, and argmax
    # returns the first index attaining that maximum.
    last_index_per_sample = image_token_mask.int().cumsum(dim=1).argmax(dim=1)

    # First image token: argmax over the boolean mask returns the first True position.
    first_index_per_sample = image_token_mask.int().argmax(dim=1)

    return num_tokens_per_sample, first_index_per_sample, last_index_per_sample
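
# Toy illustration of get_image_token_counts (token ids are made up, not real Qwen
# ids): for a single-row batch with input_ids = [[A, IMG, IMG, IMG, B]], it returns
# num_tokens_per_sample = [3], first_index_per_sample = [1], last_index_per_sample = [3].
# With several images per sample, first/last bracket the full run of image tokens,
# including any separator tokens the processor placed between the images.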


class ProprioProjector(nn.Module):
    """
    Projects proprio state inputs into the LLM's embedding space.
    """
    def __init__(self, llm_dim: int, proprio_dim: int) -> None:
        super().__init__()
        self.llm_dim = llm_dim
        self.proprio_dim = proprio_dim

        self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True)
        self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
        self.act_fn1 = nn.GELU()

    def forward(self, proprio: torch.Tensor = None) -> torch.Tensor:
        projected_features = self.fc1(proprio)
        projected_features = self.act_fn1(projected_features)
        projected_features = self.fc2(projected_features)
        return projected_features
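
# Minimal shape sketch (dimensions are illustrative, not prescriptive): with
# llm_dim=2048 and proprio_dim=14, ProprioProjector maps a [batch, 14] proprio
# state to a [batch, 2048] feature in the LLM hidden size, which is later handed
# to the action head alongside the VLM hidden states.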


@FRAMEWORK_REGISTRY.register("QwenAdapter")
class Qwen_Adapter(baseframework):
    """
    Multimodal vision-language-action model.

    Components:
        - Qwen2.5-VL interface for fused language/vision token embeddings
        - VLA-Adapter L1-regression action head for continuous action prediction

    Focus: predict future continuous actions conditioned on images + instruction.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """
        Construct all submodules and cache key configuration values.

        Args:
            config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
            **kwargs: Reserved for future overrides (unused).
        """
        super().__init__()
        self.config = config
        self.phase = self.config.framework.action_model.get("phase", "Training")
        self.qwen_vl_interface = get_vlm_model(config=self.config)
        self.config.framework.qwenvl.vl_hidden_dim = self.qwen_vl_interface.model.config.hidden_size
        self.action_query_num = self.config.framework.action_model.get("action_query_num", 64)
        self.action_model: VLA_Adapter_L1RegressionActionHead = get_action_model(config=self.config)

        # Learnable action queries that replace the placeholder-token embeddings at runtime.
        self.action_query = nn.Parameter(torch.randn(self.action_query_num, self.qwen_vl_interface.model.config.hidden_size))

        # The 🔍 emoji serves as the action placeholder token inside the text prompt.
        self.dummy_action_token = "🔍"
        self.dummy_action_token_id = self.qwen_vl_interface.processor.tokenizer("🔍", add_special_tokens=False)["input_ids"][0]
        self.dummy_action_prompt = self.dummy_action_token * self.action_query_num

        self.chunk_len = self.config.framework.action_model.get("num_actions_chunk", None)
        if self.chunk_len is None:
            raise ValueError("num_actions_chunk must be specified in action_model config.")

        if self.config.framework.action_model.get("use_proprio", False):
            self.proprio_projector = ProprioProjector(
                llm_dim=self.qwen_vl_interface.model.config.hidden_size,
                proprio_dim=self.config.framework.action_model.get("state_dim", 14),
            )
        else:
            self.proprio_projector = None

        nn.init.normal_(self.action_query, mean=0.0, std=0.02)
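
    # The constructor reads only a handful of config fields. A minimal sketch of the
    # assumed layout (illustrative values; the real training YAML contains more entries):
    #
    #   framework:
    #     qwenvl:
    #       base_vlm: <path to a Qwen2.5-VL checkpoint>
    #     action_model:
    #       phase: Training
    #       action_query_num: 64
    #       num_actions_chunk: 16
    #       use_proprio: true
    #       state_dim: 14
    #   datasets:
    #     vla_data:
    #       image_size: [224, 224]   # optional, used only by predict_action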

    def forward(
        self,
        examples: List[dict] = None,
        **kwargs,
    ) -> dict:
        """
        Training forward pass.

        Encodes images + instruction with QwenVL (placeholder embeddings are swapped
        for learnable action queries via a forward hook), gathers multi-layer hidden
        states at the vision-token and action-query positions, predicts continuous
        actions with the adapter head, and returns the L1 loss against the
        ground-truth action chunk.
        """
        batch_images = [example["image"] for example in examples]
        instructions = [example["lang"] for example in examples]
        gt_actions = [example["action"] for example in examples]
        state = [example["state"] for example in examples] if "state" in examples[0] else None

        # Append the action placeholder tokens to every instruction.
        prompt_suffix = f" Please predict the next {self.chunk_len} robot actions: <action>{self.dummy_action_prompt}<action>."
        instructions = [instruction + prompt_suffix for instruction in instructions]

        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
            images=batch_images,
            instructions=instructions
        )
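
        # `qwen_inputs` is the processor output consumed by the VLM below; for Qwen2.5-VL
        # this typically holds input_ids, attention_mask, pixel_values and image_grid_thw
        # (the exact keys depend on the interface implementation).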

        # Locate the placeholder positions that will receive the learnable action queries.
        input_ids = qwen_inputs['input_ids']
        action_mask = (input_ids == self.dummy_action_token_id)

        batch_size = qwen_inputs['input_ids'].shape[0]
        device = qwen_inputs['input_ids'].device
        action_positions_tensor = torch.full((batch_size, self.action_query_num), 0, dtype=torch.long, device=device)
        valid_counts = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # A sample is only marked valid if the tokenizer produced exactly one placeholder
        # token per action query; otherwise its embeddings are left untouched.
        for b in range(batch_size):
            act_pos = torch.where(action_mask[b])[0]
            if len(act_pos) == self.action_query_num:
                action_positions_tensor[b] = act_pos
                valid_counts[b] = True

        def inject_query_hook(module, inputs, output):
            """Replace action placeholder embeddings with learnable queries (vectorized)."""
            query_embed = self.action_query.to(dtype=output.dtype, device=output.device)

            batch_indices = torch.arange(batch_size, device=output.device).unsqueeze(1).expand(-1, self.action_query_num)

            # Only touch samples whose placeholder count matched the query count.
            valid_batch_indices = batch_indices[valid_counts]
            valid_action_positions = action_positions_tensor[valid_counts]

            if len(valid_batch_indices) > 0:
                output[valid_batch_indices, valid_action_positions, :] = query_embed.unsqueeze(0)

            return output

        # The hook runs on the input-embedding layer, so the queries are injected before
        # any transformer block sees the sequence; it is removed right after the forward.
        embedding_layer = self.qwen_vl_interface.model.model.get_input_embeddings()
        hook_handle = embedding_layer.register_forward_hook(inject_query_hook)
        try:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                qwenvl_outputs = self.qwen_vl_interface(
                    **qwen_inputs,
                    output_attentions=False,
                    output_hidden_states=True,
                    return_dict=True,
                )
        finally:
            hook_handle.remove()
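
        # With output_hidden_states=True, `hidden_states` below is a tuple with one entry
        # per returned layer (embedding output plus every transformer block), each of
        # shape [batch, seq_len, hidden_dim].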

        hidden_states = qwenvl_outputs.hidden_states

        # Gather, for every layer, the vision-token span and the action-query slots.
        multi_layer_hidden_states = []
        num_images, first_index_per_sample, last_index_per_sample = get_image_token_counts(qwen_inputs)

        # Longest vision-token span in the batch (shorter samples are zero-padded below).
        max_patch_len = int((last_index_per_sample - first_index_per_sample + 1).max().item())

        for layer_hidden in hidden_states:
            # Build [batch, max_patch_len] gather indices starting at each sample's first
            # image token ...
            batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_patch_len)
            seq_indices = torch.arange(max_patch_len, device=device).unsqueeze(0).expand(batch_size, -1)
            seq_indices = seq_indices + first_index_per_sample.unsqueeze(1)

            # ... and clamp them so shorter samples do not index past their last image token.
            seq_indices = torch.clamp(seq_indices, max=last_index_per_sample.unsqueeze(1))

            batch_vision_states = layer_hidden[batch_indices, seq_indices, :]

            # Zero out the positions beyond each sample's true vision span.
            vision_patch_lengths = last_index_per_sample - first_index_per_sample + 1
            padding_mask = torch.arange(max_patch_len, device=device).unsqueeze(0) >= vision_patch_lengths.unsqueeze(1)
            batch_vision_states = batch_vision_states.masked_fill(padding_mask.unsqueeze(-1), 0.0)

            # Hidden states at the injected action-query positions.
            batch_indices_action = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, self.action_query_num)
            action_query_states = layer_hidden[batch_indices_action, action_positions_tensor, :]

            # Concatenate the vision span and the action queries along the sequence dimension.
            all_hidden_states = torch.cat([
                batch_vision_states.unsqueeze(1),
                action_query_states.unsqueeze(1)
            ], dim=2)

            multi_layer_hidden_states.append(all_hidden_states)

        multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim=1)
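
        # Resulting shape: [batch, num_returned_layers, max_patch_len + action_query_num,
        # hidden_dim]; every layer contributes its vision span plus the action-query slots.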
        state_projected = None
        if state is not None:
            state = torch.tensor(
                np.array(state), device=multi_layer_hidden_states.device, dtype=multi_layer_hidden_states.dtype
            )
            if self.proprio_projector is not None:
                state_projected = self.proprio_projector(proprio=state.squeeze(1))

        # Keep the action head on the same device/dtype as the gathered hidden states.
        self.action_model = self.action_model.to(device=multi_layer_hidden_states.device, dtype=multi_layer_hidden_states.dtype)
        predicted_actions = self.action_model.predict_action(
            multi_layer_hidden_states,
            vision_hidden_len=max_patch_len,
            state_projected=state_projected,
            phase=self.phase,
        )

        gt_actions = torch.tensor(np.stack(gt_actions)).to(
            device=predicted_actions.device,
            dtype=predicted_actions.dtype
        )

        # L1 regression objective on the predicted action chunk.
        loss = torch.nn.L1Loss()(predicted_actions, gt_actions)

        return {"action_loss": loss}

    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict] = None,
        **kwargs: str,
    ) -> dict:
        """
        Inference: predict future continuous actions, mirroring the forward() logic
        (embedding hook + multi-layer hidden states).

        Steps:
            1. Resize images to the training resolution (if specified)
            2. Insert action placeholder tokens into the instruction
            3. Encode with QwenVL (hidden states retained), with a hook injecting action queries
            4. Extract multi-layer features at the vision and action-query positions
            5. Predict actions via the action model
            6. Return the normalized action trajectory

        Returns:
            dict:
                normalized_actions (np.ndarray): Shape [B, chunk_len, action_dim], predicted normalized actions.
        """
        batch_images = [to_pil_preserve(example["image"]) for example in examples]
        instructions = [example["lang"] for example in examples]
        state = [example["state"] for example in examples] if "state" in examples[0] else None

        # Match the observation resolution used at training time, if configured.
        train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
        if train_obs_image_size:
            batch_images = resize_images(batch_images, target_size=train_obs_image_size)

        prompt_suffix = f" Please predict the next {self.chunk_len} robot actions: <action>{self.dummy_action_prompt}<action>."
        instructions = [instruction + prompt_suffix for instruction in instructions]

        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
            images=batch_images,
            instructions=instructions
        )

        input_ids = qwen_inputs['input_ids']
        action_mask = (input_ids == self.dummy_action_token_id)

        batch_size = qwen_inputs['input_ids'].shape[0]
        device = qwen_inputs['input_ids'].device
        action_positions_tensor = torch.full((batch_size, self.action_query_num), 0, dtype=torch.long, device=device)
        valid_counts = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for b in range(batch_size):
            act_pos = torch.where(action_mask[b])[0]
            if len(act_pos) == self.action_query_num:
                action_positions_tensor[b] = act_pos
                valid_counts[b] = True

        def inject_query_hook(module, inputs, output):
            """Replace action placeholder embeddings with learnable queries (vectorized)."""
            query_embed = self.action_query.to(dtype=output.dtype, device=output.device)

            batch_indices = torch.arange(batch_size, device=output.device).unsqueeze(1).expand(-1, self.action_query_num)

            valid_batch_indices = batch_indices[valid_counts]
            valid_action_positions = action_positions_tensor[valid_counts]

            if len(valid_batch_indices) > 0:
                output[valid_batch_indices, valid_action_positions, :] = query_embed.unsqueeze(0)

            return output

        embedding_layer = self.qwen_vl_interface.model.model.get_input_embeddings()
        hook_handle = embedding_layer.register_forward_hook(inject_query_hook)
        try:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                qwenvl_outputs = self.qwen_vl_interface(
                    **qwen_inputs,
                    output_attentions=False,
                    output_hidden_states=True,
                    return_dict=True,
                )
        finally:
            hook_handle.remove()

        hidden_states = qwenvl_outputs.hidden_states

        multi_layer_hidden_states = []
        num_images, first_index_per_sample, last_index_per_sample = get_image_token_counts(qwen_inputs)

        max_patch_len = int((last_index_per_sample - first_index_per_sample + 1).max().item())

        for layer_hidden in hidden_states:
            batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_patch_len)
            seq_indices = torch.arange(max_patch_len, device=device).unsqueeze(0).expand(batch_size, -1)
            seq_indices = seq_indices + first_index_per_sample.unsqueeze(1)
            seq_indices = torch.clamp(seq_indices, max=last_index_per_sample.unsqueeze(1))

            batch_vision_states = layer_hidden[batch_indices, seq_indices, :]

            vision_patch_lengths = last_index_per_sample - first_index_per_sample + 1
            padding_mask = torch.arange(max_patch_len, device=device).unsqueeze(0) >= vision_patch_lengths.unsqueeze(1)
            batch_vision_states = batch_vision_states.masked_fill(padding_mask.unsqueeze(-1), 0.0)

            batch_indices_action = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, self.action_query_num)
            action_query_states = layer_hidden[batch_indices_action, action_positions_tensor, :]

            all_hidden_states = torch.cat([
                batch_vision_states.unsqueeze(1),
                action_query_states.unsqueeze(1)
            ], dim=2)

            multi_layer_hidden_states.append(all_hidden_states)

        multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim=1)
        state_projected = None
        if state is not None:
            state = torch.tensor(
                np.array(state), device=multi_layer_hidden_states.device, dtype=multi_layer_hidden_states.dtype
            )
            if self.proprio_projector is not None:
                state_projected = self.proprio_projector(proprio=state.squeeze(1))

        with torch.autocast("cuda", dtype=torch.float32):
            self.action_model = self.action_model.to(
                device=multi_layer_hidden_states.device,
                dtype=multi_layer_hidden_states.dtype
            )
            predicted_actions = self.action_model.predict_action(
                multi_layer_hidden_states,
                vision_hidden_len=max_patch_len,
                state_projected=state_projected,
                phase=self.phase,
            )

        # Cast to float32 before converting to NumPy (bfloat16 tensors cannot be
        # converted directly).
        normalized_actions = predicted_actions.detach().float().cpu().numpy()
        return {"normalized_actions": normalized_actions}


if __name__ == "__main__":
    from omegaconf import OmegaConf
    import debugpy
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_train_adapter.yaml", help="Path to YAML config")
    args, clipargs = parser.parse_known_args()

    # Optional: block until a remote debugger attaches.
    debugpy.listen(("0.0.0.0", 10092))
    print("🔍 Rank 0 waiting for debugger attach on port 10092...")
    debugpy.wait_for_client()

    cfg = OmegaConf.load(args.config_yaml)
    cfg.framework.qwenvl.base_vlm = "./playground/Pretrained_models/Qwen2.5-VL-3B-Instruct"

    model: Qwen_Adapter = Qwen_Adapter(cfg)
    print(model)

    # Build a fake two-view sample to smoke-test the training and inference paths.
    image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))

    sample = {
        "action": np.random.uniform(-1, 1, size=(16, 14)).astype(np.float16),
        "image": [image, image],
        "lang": "This is a fake instruction for testing.",
    }

    batch = [sample, sample]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    forward_output = model(batch)
    action_loss = forward_output['action_loss']
    print(f"Action Loss: {action_loss.item()}")

    predict_output = model.predict_action(examples=[batch[0]])
    normalized_actions = predict_output['normalized_actions']
    print(f"Normalized Actions: {normalized_actions}")