import torch
import torch.nn as nn
from typing import Dict, List, Optional
from torch.nn.utils.rnn import pad_sequence

from transformers import AutoProcessor, BatchFeature, Qwen3VLForConditionalGeneration
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
from qwen_vl_utils import process_vision_info
|
|
|
|
from accelerate.logging import get_logger


logger = get_logger(__name__)
|
|
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = 151655
VIDEO_TOKEN_INDEX = 151656
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_VIDEO_TOKEN = "<video>"
|
|
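# The [_ACTION_TOKEN_MIN, _ACTION_TOKEN_MAX] range below is assumed to cover the
# 2048 discretized action-bin tokens appended to the Qwen3-VL vocabulary (see
# starVLA/model/modules/vlm/tools/add_qwen_special_tokens/README.md); ids outside
# this range are treated as ordinary text during label masking.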
_ACTION_TOKEN_MIN = 151669
_ACTION_TOKEN_MAX = 153716
|
|
|
|
|
|
|
|
class _QWen3_VL_Interface(nn.Module):
    """
    Lightweight wrapper around Qwen3-VL (Qwen3VLForConditionalGeneration).

    VLM backbones differ in their APIs, so the model-specific handling is
    encapsulated here.

    Purpose:
    - Unify the interface with other VLM backends (CausalLM-like usage).
    - Centralize preprocessing (tokenization + multimodal packing).
    - Provide consistent forward / generate signatures.
    """
|
|
    def __init__(self, config: Optional[dict] = None, **kwargs):
        """
        Initialize the Qwen3-VL wrapper.
        Following https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct

        Args:
            config: OmegaConf-style config with a `framework.qwenvl` section.
        """
        super().__init__()

        qwenvl_config = config.framework.get("qwenvl", {})
        model_id = qwenvl_config.get("base_vlm", "Qwen/Qwen3-VL-4B-Instruct")

        model = Qwen3VLForConditionalGeneration.from_pretrained(
            model_id,
            attn_implementation="flash_attention_2",
            dtype=torch.bfloat16,
        )
        processor = AutoProcessor.from_pretrained(model_id)
        processor.tokenizer.padding_side = "left"

        self.model = model
        self.processor = processor
        self.config = config

        # Expose the text backbone's hidden size at the top level so callers
        # can query it uniformly across VLM backends.
        self.model.config.hidden_size = self.model.config.text_config.hidden_size

        # Action checkpoints (model ids containing "-Action") ship with the
        # extended action vocabulary; record the token-id range for label masking.
        if "-Action" in model_id:
            self._ACTION_TOKEN_MIN = _ACTION_TOKEN_MIN
            self._ACTION_TOKEN_MAX = _ACTION_TOKEN_MAX
|
|
    def forward(
        self,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Forward pass delegating to the underlying Qwen3-VL backbone.
        """
        with torch.autocast("cuda", dtype=torch.bfloat16):
            outputs = self.model(
                **kwargs,
            )
        return outputs
|
|
    def generate(
        self,
        **kwargs,
    ):
        """
        High-level generation interface (auto-regressive decoding), optionally vision-conditioned.

        Args:
            **kwargs: fully follows the raw model.generate() signature.
        Returns:
            GenerateOutput | model-dependent generation return.
        """
        # Match the bfloat16 autocast used in forward(); the backbone weights
        # are loaded in bfloat16.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            generation_output = self.model.generate(
                **kwargs,
            )
        return generation_output
|
|
    def build_qwenvl_inputs(self, images, instructions, solutions=None, **kwargs):
        """
        Build model inputs from raw data (images + instructions + optional solutions).
        Follows the official Qwen3-VL Instruct format: https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct
        """
        messages = []
        assert len(images) == len(instructions), "Images and instructions must have the same length"
        for imgs, instruction in zip(images, instructions):
            content = [{"type": "image", "image": img} for img in imgs]

            # Optionally wrap the instruction in a chain-of-thought prompt template.
            if "CoT_prompt" in self.config.datasets.vla_data:
                CoT_prompt = self.config.datasets.vla_data.get("CoT_prompt", "")
                prompt = CoT_prompt.replace("{instruction}", instruction)
            else:
                prompt = instruction

            content.append({"type": "text", "text": prompt})
            msg = [{"role": "user", "content": content}]

            # During training, append the target answer as the assistant turn.
            if solutions is not None:
                solution = solutions[len(messages)]
                msg.append({"role": "assistant", "content": [{"type": "text", "text": solution}]})
            messages.append(msg)
|
|
        batch_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            padding=True,
            # Only append the assistant generation prompt at inference time;
            # training conversations already end with the assistant turn.
            add_generation_prompt=solutions is None,
            return_dict=True,
            return_tensors="pt",
        )
|
|
        # Supervise only the action tokens: mask everything before the first
        # action token (and the whole sequence if none are present).
        if solutions is not None:
            action_token_min = getattr(self, "_ACTION_TOKEN_MIN", _ACTION_TOKEN_MIN)
            action_token_max = getattr(self, "_ACTION_TOKEN_MAX", _ACTION_TOKEN_MAX)
            labels = batch_inputs["input_ids"].clone()

            for i in range(labels.size(0)):
                seq = labels[i]

                mask_seq = (seq >= action_token_min) & (seq <= action_token_max)
                nonzero_indices = torch.nonzero(mask_seq, as_tuple=False)
                if nonzero_indices.numel() > 0:
                    first_action_index = nonzero_indices[0].item()
                    # Ignore the prompt prefix; loss is computed on the action span.
                    seq[:first_action_index] = IGNORE_INDEX
                else:
                    # No action tokens found: ignore the whole sequence.
                    seq[:] = IGNORE_INDEX
                    logger.warning(
                        "No action tokens found in the sequence; make sure action tokens "
                        "are registered in your tokenizer, see "
                        "starVLA/model/modules/vlm/tools/add_qwen_special_tokens/README.md."
                    )

            # Never compute loss on padding.
            labels[labels == self.processor.tokenizer.pad_token_id] = IGNORE_INDEX
            batch_inputs["labels"] = labels

        return batch_inputs.to(self.model.device)
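
# A minimal sketch (not part of the original module) of how the discretized
# action tokens produced by generate() could be mapped back to bin indices,
# assuming each action token id encodes its bin as `id - _ACTION_TOKEN_MIN`:
def _decode_action_bins(token_ids: torch.Tensor) -> List[int]:
    """Hypothetical helper: extract action-bin indices from generated ids."""
    mask = (token_ids >= _ACTION_TOKEN_MIN) & (token_ids <= _ACTION_TOKEN_MAX)
    return (token_ids[mask] - _ACTION_TOKEN_MIN).tolist()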
|
|
|
|
|
|
|
|
| if __name__ == "__main__": |
| from omegaconf import OmegaConf |
| import debugpy |
| import argparse |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_cotrain_oxe.yaml", help="Path to YAML config") |
| args, clipargs = parser.parse_known_args() |
|
|
| debugpy.listen(("0.0.0.0", 10092)) |
| print("🔍 Rank 0 waiting for debugger attach on port 10092...") |
| debugpy.wait_for_client() |
|
|
| cfg = OmegaConf.load(args.config_yaml) |
| |
| cfg.framework.qwenvl.base_vlm = "./playground/Pretrained_models/Qwen3-VL-4B-Instruct" |
| qwen_vl = _QWen3_VL_Interface(cfg) |
| pass |
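
    # A minimal smoke-test sketch (assumptions: a local RGB image at the
    # hypothetical path below, and a GPU with flash-attention-2 support):
    #
    #   from PIL import Image
    #   img = Image.open("./playground/example.jpg")
    #   inputs = qwen_vl.build_qwenvl_inputs([[img]], ["Pick up the red block."])
    #   out_ids = qwen_vl.generate(**inputs, max_new_tokens=64)
    #   print(qwen_vl.processor.batch_decode(out_ids, skip_special_tokens=False))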
|
|