# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
# Implemented by [Jinhui YE / HKUST University] in [2025].
import torch
import transformers
from typing import Optional, List
import copy
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Dict, Optional, List
from torch.nn.utils.rnn import pad_sequence
from transformers import BatchFeature
from qwen_vl_utils import process_vision_info
from accelerate.logging import get_logger
# Module-level logger obtained from accelerate's logging helper.
logger = get_logger(__name__)
# Label value written over positions that are masked out when build_qwenvl_inputs builds `labels`.
IGNORE_INDEX = -100
# Vision placeholder token ids and placeholder strings.
# NOTE(review): none of these four names is referenced in the visible part of this file;
# presumably the ids mirror the Qwen2.5-VL tokenizer's image/video tokens — confirm before relying on them.
IMAGE_TOKEN_INDEX = 151655
VIDEO_TOKEN_INDEX = 151656
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_VIDEO_TOKEN = "<video>"
# Inclusive token-id range that build_qwenvl_inputs treats as "action tokens" when masking labels.
# The range is hard-coded; the author's open question is how to derive it from the tokenizer.
_ACTION_TOKEN_MIN = 151665 # how can we know this range?
_ACTION_TOKEN_MAX = 153712 # here only for fast_tokenizer, see starVLA/model/modules/vlm/tools/add_qwen_special_tokens/README.md
import torch.nn as nn
class _QWen_VL_Interface(nn.Module):
    """
    Lightweight wrapper around Qwen2.5-VL (Qwen2_5_VLForConditionalGeneration).

    VLM backends differ a lot, so the backend-specific details are encapsulated here.

    Purpose:
        - Unify interface with other VLM backends (CausalLM-like usage).
        - Centralize preprocessing (tokenization + multimodal packing).
        - Provide consistent forward / generate signatures.

    Notes:
        - Keeps original model behavior; does not modify internal architecture.
        - Mixed precision is handled via torch.autocast: bfloat16 in forward, float16 in generate.
        - Adaptation layer can be extended for future multi-modal routing if needed.
    """

    def __init__(self, config: Optional[dict] = None, **kwargs):
        """
        Initialize the Qwen2.5-VL wrapper.

        Parameters:
            config (dict-like | None):
                Expected to expose attribute access plus `.get`, e.g. an OmegaConf config:
                    config.framework.get("qwenvl", {}) -> {"base_vlm": "Qwen/Qwen2.5-VL-3B-Instruct"}
                `base_vlm` is a HuggingFace model id or local path.
                `config.datasets.vla_data` is read later by build_qwenvl_inputs (optional "CoT_prompt").
                If None, the default checkpoint "Qwen/Qwen2.5-VL-3B-Instruct" is used.
            **kwargs:
                Ignored currently; placeholder for future extension (e.g., override device_map, dtype).

        Side Effects:
            - Downloads / loads pretrained Qwen2.5-VL weights (unless cached).
            - Instantiates AutoProcessor and forces the tokenizer to pad on the left.

        Attributes Set:
            self.model (Qwen2_5_VLForConditionalGeneration)
            self.processor (AutoProcessor)
            self.config (original config reference)
            self._ACTION_TOKEN_MIN / self._ACTION_TOKEN_MAX (inclusive action-token id range)

        Notes:
            - No device_map is passed to from_pretrained; device placement is left to the caller.
            - torch_dtype="auto" defers the dtype choice to HF (per the HF docs it follows the checkpoint).
            - attn_implementation="flash_attention_2" is requested unconditionally.
        """
        super().__init__()

        # The signature allows config=None, so fall back to defaults instead of failing on `None.framework`.
        qwenvl_config = config.framework.get("qwenvl", {}) if config is not None else {}
        model_id = qwenvl_config.get("base_vlm", "Qwen/Qwen2.5-VL-3B-Instruct")

        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            attn_implementation="flash_attention_2",
            torch_dtype="auto",
        )
        processor = AutoProcessor.from_pretrained(model_id)
        processor.tokenizer.padding_side = "left"

        self.model = model
        self.processor = processor
        self.config = config

        # Kept on the instance so the label-masking range can be overridden per wrapper.
        self._ACTION_TOKEN_MIN = _ACTION_TOKEN_MIN
        self._ACTION_TOKEN_MAX = _ACTION_TOKEN_MAX

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> "CausalLMOutputWithPast":
        """
        Forward pass delegating to the underlying Qwen2.5-VL backbone.

        Args:
            input_ids (LongTensor | None): [B, T] token ids (alternative to inputs_embeds).
            attention_mask (Tensor | None): [B, T], 1 = attend, 0 = masked.
            pixel_values (FloatTensor | None): Vision batch as produced by the processor.
            labels (LongTensor | None): [B, T] LM targets; ignored positions = -100 (IGNORE_INDEX).
            image_grid_thw (LongTensor | None): Per-image (temporal, height, width) grid sizes as produced
                by the processor; the HF docs describe it as shape (num_images, 3).
            inputs_embeds (FloatTensor | None): [B, T, D] alternative embedding input.
            past_key_values (List[FloatTensor] | None): Cached KV states for incremental decoding.
            use_cache (bool | None): If True, returns updated past_key_values.
            output_attentions (bool): Whether to include attention maps.
            output_hidden_states (bool): Defaults to True here because downstream modules consume hidden states.
            return_dict (bool): Return HF dataclass if True; else tuple.
            **kwargs: Extra args forwarded unchanged to the underlying model.

        Returns:
            CausalLMOutputWithPast | tuple: whatever the wrapped model returns.

        Notes:
            - Runs under torch.autocast("cuda", dtype=torch.bfloat16).
        """
        with torch.autocast("cuda", dtype=torch.bfloat16):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                **kwargs,
            )
        return outputs

    def generate(
        self,
        **kwargs,
    ):
        """
        High-level generation interface (auto-regressive decoding), optionally vision-conditioned.

        Args:
            **kwargs: forwarded unchanged to the wrapped model's generate().

        Returns:
            Whatever the wrapped model's generate() returns.
        """
        # NOTE(review): forward() autocasts to bfloat16 while generation autocasts to float16.
        # Left unchanged because switching dtype would alter numerics — confirm whether this is intended.
        with torch.autocast("cuda", dtype=torch.float16):
            generation_output = self.model.generate(
                **kwargs,
            )
        return generation_output

    @staticmethod
    def _mask_non_action_labels(input_ids, action_token_min, action_token_max, pad_token_id, ignore_index):
        """
        Build LM labels that only supervise tokens from the first action token onwards.

        Args:
            input_ids (LongTensor): [B, T] token ids; not modified (a clone is returned).
            action_token_min (int): lowest token id counted as an action token (inclusive).
            action_token_max (int): highest token id counted as an action token (inclusive).
            pad_token_id (int | None): id of the padding token; skipped when None.
            ignore_index (int): value written to masked positions.

        Returns:
            LongTensor [B, T]: copy of input_ids where, per row, everything before the first action
            token is ignore_index (the whole row if it has no action token), and every pad token is
            ignore_index as well.

        Warns:
            RuntimeWarning: once per call if at least one row contains no action token.
        """
        import warnings  # function-scope import keeps the module's import block untouched

        labels = input_ids.clone()
        is_action = (labels >= action_token_min) & (labels <= action_token_max)

        # Running count of action tokens: > 0 from the first action token to the end of the row,
        # and 0 everywhere for rows that contain none (so those rows are masked entirely).
        reached_action = is_action.long().cumsum(dim=-1) > 0

        rows_without_action = ~is_action.any(dim=-1)
        if bool(rows_without_action.any()):
            warnings.warn(
                f"{int(rows_without_action.sum())} sequence(s) contain no action tokens; the action tokens "
                "may not be in your tokenizer, please see "
                "starVLA/model/modules/vlm/tools/add_qwen_special_tokens/README.md.",
                RuntimeWarning,
                stacklevel=2,
            )

        labels[~reached_action] = ignore_index
        if pad_token_id is not None:
            labels[labels == pad_token_id] = ignore_index  # mask out pad tokens as well
        return labels

    def build_qwenvl_inputs(self, images, instructions, solutions=None, **kwargs):
        """
        Construct and tokenize multimodal chat-style inputs for Qwen2.5-VL (batched).

        Overview:
            For each sample i:
                - Takes a list of PIL images: images[i] = [img_0, img_1, ...]
                - Takes a matching instruction string instructions[i]
                - Optionally injects the instruction into the CoT_prompt template from config.
                - Builds a chat message:
                    [{"role": "user", "content": [
                        {"type": "image", "image": <PIL.Image>}, ...,
                        {"type": "text", "text": <final_prompt>}
                    ]}]
                  plus an assistant turn holding solutions[i] when solutions is given.
                - Applies processor.apply_chat_template(..., add_generation_prompt=True)
                - Extracts vision inputs via process_vision_info
                - Calls processor(...) to produce a BatchFeature with token + vision tensors.

        Parameters:
            images (List[List[PIL.Image.Image]]):
                Length B. Each element is a (possibly empty) list of images for that instruction.
            instructions (List[str]):
                Length B textual prompts or task instructions.
            solutions (List[str] | None):
                Optional length-B target strings. When given, each is appended as the assistant turn and
                a `labels` tensor is added to the output (see Returns).
            **kwargs:
                Reserved for future extensions; currently unused.

        Config Dependencies:
            self.config.datasets.vla_data["CoT_prompt"] (optional):
                If the key is present, "{instruction}" in the template is replaced by each instruction.
                If self.config is None, instructions are used as-is.

        Returns:
            BatchFeature moved to self.model.device, containing the processor outputs
            (e.g. input_ids, attention_mask, vision tensors). If solutions is given it also holds
            `labels`: a copy of input_ids where everything before the first action token
            (id in [self._ACTION_TOKEN_MIN, self._ACTION_TOKEN_MAX]) and every pad token is IGNORE_INDEX.

        Raises:
            AssertionError: if images and instructions differ in length.
            IndexError: if solutions is shorter than instructions.

        Notes:
            - CoT prompt replacement is a plain str.replace; the template must contain "{instruction}".
            - No augmentation, caching of pixel tensors, or streaming input is performed here.
        """
        assert len(images) == len(instructions), "Images and instructions must have the same length"

        # Read the optional CoT template once instead of once per sample.
        vla_data_cfg = self.config.datasets.vla_data if self.config is not None else {}
        use_cot_prompt = "CoT_prompt" in vla_data_cfg  # If using a grounding prompt to task
        cot_template = vla_data_cfg.get("CoT_prompt", "") if use_cot_prompt else None

        # Create messages: one conversation per sample
        messages = []
        for sample_idx, (imgs, instruction) in enumerate(zip(images, instructions)):
            content = [{"type": "image", "image": img} for img in imgs]
            prompt = cot_template.replace("{instruction}", instruction) if use_cot_prompt else instruction
            content.append({"type": "text", "text": prompt})

            msg = [{"role": "user", "content": content}]
            if solutions is not None:
                solution = solutions[sample_idx]
                msg.append({"role": "assistant", "content": [{"type": "text", "text": solution}]})
            messages.append(msg)

        # Prepare text prompts using processor
        # default process is json --> message --> texts --> input_ids
        # NOTE(review): add_generation_prompt=True is used even when an assistant solution is present.
        # Kept as-is because changing it alters tokenization — confirm whether this is intended for training.
        texts = [self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages]

        # image_inputs = list of PIL
        image_inputs, video_inputs = process_vision_info(messages)
        batch_input = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")

        # if solutions, mask out the non solution tokens in labels --> @JinhuiYE can we mask out system prompt?
        if solutions is not None:
            # The action-token range is hard-coded (fast_tokenizer only), see
            # starVLA/model/modules/vlm/tools/add_qwen_special_tokens/README.md
            batch_input["labels"] = self._mask_non_action_labels(
                batch_input["input_ids"],
                self._ACTION_TOKEN_MIN,
                self._ACTION_TOKEN_MAX,
                self.processor.tokenizer.pad_token_id,
                IGNORE_INDEX,
            )

        return batch_input.to(self.model.device)
if __name__ == "__main__":
    # Manual smoke test: wait for a debugger, load a YAML config, and instantiate the wrapper.
    import argparse

    import debugpy
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="./starVLA/config/training/starvla_cotrain_oxe.yaml",
        help="Path to YAML config",
    )
    # Unknown CLI arguments are tolerated and discarded.
    args, _unknown_args = parser.parse_known_args()

    # NOTE(review): listening on 0.0.0.0 exposes the debug port on every network interface;
    # prefer 127.0.0.1 unless remote attach is really needed. Execution blocks until a client attaches.
    debugpy.listen(("0.0.0.0", 10092))
    print("🔍 Rank 0 waiting for debugger attach on port 10092...")
    debugpy.wait_for_client()

    cfg = OmegaConf.load(args.config_yaml)
    # Point the config at a local checkpoint before building the wrapper.
    model_id = "./playground/Pretrained_models/Qwen2.5-VL-3B-Instruct"
    cfg.framework.qwenvl.base_vlm = model_id
    model = _QWen_VL_Interface(config=cfg)
    pass  # no-op line, convenient as a breakpoint anchor after the model is built