# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License (the "License");
# Implemented by [Jinhui YE / HKUST] in [2025].
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Union

from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from accelerate.logging import get_logger
logger = get_logger(__name__)
# IGNORE_INDEX = -100
# IMAGE_TOKEN_INDEX = 151655
# VIDEO_TOKEN_INDEX = 151656
# DEFAULT_IMAGE_TOKEN = "<image>"
# DEFAULT_VIDEO_TOKEN = "<video>"
# [151936, 153984]
def _construct_prompts(text):
    # Identity override: bypass Florence-2's task-prompt rewriting so raw
    # instruction strings reach the tokenizer unchanged.
    return text
class _Florence_Interface(nn.Module):
    """
    Lightweight wrapper around Florence-2 (microsoft/Florence-2-large).

    This exists because of the diversity of VLMs; backend-specific changes are
    encapsulated here.
    Purpose:
        - Unify the interface with other VLM backends (CausalLM-like usage).
        - Centralize preprocessing (tokenization + multimodal packing).
        - Provide consistent forward / generate signatures.
    """
    def __init__(self, config: Optional[dict] = None, **kwargs):
        """
        Initialize the VLM wrapper.
        Following https://huggingface.co/microsoft/Florence-2-large
        """
        super().__init__()
        # The config key stays "qwenvl" for compatibility with the shared framework schema.
        qwenvl_config = config.framework.get("qwenvl", {})
        model_id = qwenvl_config.get("base_vlm", "microsoft/Florence-2-large")
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch_dtype, trust_remote_code=True, attn_implementation="eager"
        )  # force eager attention
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        self.processor._construct_prompts = _construct_prompts
        self.config = config
        # Align with the Qwen2.5-VL interface: expose hidden_size for downstream heads.
        self.model.config.hidden_size = self.model.config.projection_dim
        # Delete unused modules to save memory (encoder-only usage).
        if hasattr(self.model, "decoder"):
            del self.model.decoder
        if hasattr(self.model, "lm_head"):
            del self.model.lm_head
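
    # A minimal config sketch (an assumption for illustration; the real schema
    # lives in the training YAML, and the "qwenvl" key is shared across backends):
    #
    #   from omegaconf import OmegaConf
    #   cfg = OmegaConf.create(
    #       {"framework": {"qwenvl": {"base_vlm": "microsoft/Florence-2-large"}}}
    #   )
    #   vlm = _Florence_Interface(cfg)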
    def forward(
        self,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Forward pass delegating to the underlying Florence-2 encoder.
        """
        with torch.autocast("cuda", dtype=torch.bfloat16):
            outputs = self.forward_vlm(
                **kwargs,
            )
        return outputs
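
    # Call sketch (hypothetical tensors; `forward` simply routes **kwargs to
    # `forward_vlm` under bf16 autocast):
    #   out = vlm(input_ids=ids, pixel_values=pix)
    #   feats = out.last_hidden_state  # [B, T_enc, D]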
    # ============================= Florence2 encoder =============================
    def forward_vlm(
        self,
        input_ids: torch.LongTensor,  # [B, L]
        pixel_values: torch.FloatTensor,  # [B*N_view, C, H, W]
        **kwargs,
    ):
        """
        Adapted from X-VLA: https://github.com/2toinf/X-VLA/blob/main/models/modeling_florence2.py
        Encode text + multi-view images via the Florence2 encoder.
        Returns:
            enc_out.hidden_states: [B, T_enc, D]
        """
        # Get image features.
        param_dtype = next(self.model.parameters()).dtype
        pixel_values = pixel_values.to(self.model.device, dtype=param_dtype)
        valid_feats = self.model._encode_image(pixel_values)  # [B*N_view, N, D]
        B_multiview, N, D = valid_feats.shape
        # Get text embeddings.
        inputs_embeds = self.model.get_input_embeddings()(input_ids)  # [B, L, D]
        # Florence-2 only supports a single image per text input; for multi-image
        # support, merge per-view features here:
        # pixel_values: [B*N_view, C, H, W] -> [B*N_view, N, D] -> [B, N_view*N, D]
        B, L, _ = inputs_embeds.shape
        image_features = valid_feats.view(B, -1, D)  # [B, N_view*N, D]
        # Merge image features and text embeddings.
        merged_embeds, attention_mask = self.model._merge_input_ids_with_image_features(
            image_features,  # [B, N_view*N, D]
            inputs_embeds,  # [B, L, D]
        )
        # TODO: return text and image indices for later index masking.
        enc_out = self.model.language_model.model.encoder(
            attention_mask=attention_mask,
            inputs_embeds=merged_embeds,
        )
        enc_out.hidden_states = [enc_out.last_hidden_state]
        return enc_out
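
    # Downstream slicing sketch (assumes Florence-2's merge places image tokens
    # before text tokens along the sequence dimension):
    #   n_img = image_features.shape[1]
    #   img_tokens = enc_out.last_hidden_state[:, :n_img]  # [B, N_view*N, D]
    #   txt_tokens = enc_out.last_hidden_state[:, n_img:]  # [B, L, D]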
    def build_qwenvl_inputs(self, images, instructions, **kwargs):
        """
        Build model inputs from raw data (images + instructions).
        Follows the official Florence-2 format: https://huggingface.co/microsoft/Florence-2-large
        """
        # Create messages: one message per sample.
        assert len(images) == len(instructions), "Images and instructions must have the same length"
        assert len(images[0]) == 1, "Florence2 only supports a single image per sample for now"
        # Florence-2 only supports a single image per text input; modify here for
        # multi-image support by merging per-view features.
        flatten_batch_images = []
        for example_images in images:
            flatten_batch_images.extend(example_images)
        task_prompt = "Locate the objects with category name in the image."
        for index in range(len(instructions)):
            instructions[index] = task_prompt + " " + instructions[index]
        inputs = self.processor(
            text=instructions, images=flatten_batch_images, return_tensors="pt", padding=True, truncation=True
        )
        inputs["labels"] = inputs["input_ids"].clone()
        return inputs.to(self.model.device)
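
# End-to-end usage sketch (hypothetical data; see the smoke test below):
#   batch = vlm.build_qwenvl_inputs(images=[[pil_image]], instructions=["pick up the cube"])
#   out = vlm(input_ids=batch["input_ids"], pixel_values=batch["pixel_values"])
#   # batch["labels"] is a clone of input_ids, ready for LM-style supervision.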
if __name__ == "__main__":
    from omegaconf import OmegaConf
    import argparse
    import debugpy

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="./starVLA/config/training/starvla_cotrain_oxe.yaml",
        help="Path to YAML config",
    )
    args, clipargs = parser.parse_known_args()

    debugpy.listen(("0.0.0.0", 10092))
    print("🔍 Rank 0 waiting for debugger attach on port 10092...")
    debugpy.wait_for_client()

    cfg = OmegaConf.load(args.config_yaml)
    # model_id = "microsoft/Florence-2-large"
    model_id = "playground/Pretrained_models/Florence-2-large"
    cfg.framework.qwenvl.base_vlm = model_id

    qwen_vl = _Florence_Interface(cfg)
    qwen_vl.model.eval()

    import requests

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    prompt = "<OD>"
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
    image = Image.open(requests.get(url, stream=True).raw)

    inputs = qwen_vl.build_qwenvl_inputs(images=[[image]], instructions=[prompt])
    with torch.no_grad():
        with torch.autocast("cuda", dtype=torch.bfloat16):
            outputs = qwen_vl.forward_vlm(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
            )
    print(f"forward_vlm last_hidden_state shape: {outputs.last_hidden_state.shape}")
    print(f"forward_vlm hidden_states length: {len(outputs.hidden_states)}")