import torch
import torch.nn as nn
import torch.nn.functional as F
from vlm_fo1.model.multimodal_encoder.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor
from torchvision.transforms import ToPILImage
class VisionFeaturesGather:
"""
Collects and manages intermediate features for multi-level visual representation extraction
(used for region feature/ROIAlign task). Each forward pass (per image) builds up a list of features.
"""
def __init__(self) -> None:
self.features_list = []
self.grid_thw = None
self.window_index = None
self.merge_size = None
def reset(self):
"""Clear all states before starting a new feature-gathering process."""
self.features_list.clear()
self.grid_thw = None
self.window_index = None
self.merge_size = None
def set_params(self, grid_thw, window_index, merge_size):
"""Store spatial and merge information for the current image or batch."""
self.grid_thw = grid_thw
self.window_index = window_index
self.merge_size = merge_size
def append(self, element):
"""Append a set of features (typically per layer in encoder)."""
self.features_list.append(element)
def extract_multi_level_features(self):
"""
Assemble all gathered multi-level features into canonical tensor forms.
The goal: for each visual sample, produce a list of region-aligned feature maps
(e.g., multiple stage outputs for downstream region patching/ROIAlign).
Returns:
List of features, where each element is a list [stage1, stage2, ...] for one image.
"""
# Concatenate all feature tensors along hidden dimension: [seq_len, hidden_size * k]
concat_features = torch.cat(self.features_list, dim=1)
merge_unit = self.merge_size * self.merge_size
seq_len = concat_features.shape[0]
# Rearrange into [windows, merge_unit, hidden_dim*layers]
concat_features = concat_features.reshape(seq_len // merge_unit, merge_unit, -1)
reverse_indices = torch.argsort(self.window_index)
concat_features = concat_features[reverse_indices, :, :]
concat_features = concat_features.reshape(seq_len, -1)
# Split features for each image/video by product of grid h and w (per sample)
split_size = (self.grid_thw[:, 1] * self.grid_thw[:, 2]).tolist()
split_features = list(torch.split(concat_features, split_size, dim=0))
assert len(split_features) == self.grid_thw.shape[0]
for i in range(len(split_features)):
# Recover original grid shape and merge windowing into stages, then split
_, grid_h, grid_w = self.grid_thw[i]
merge_h = grid_h // self.merge_size
merge_w = grid_w // self.merge_size
split_features[i] = split_features[i].reshape(merge_h, merge_w, merge_unit, -1)
split_features[i] = split_features[i].reshape(merge_h, merge_w, self.merge_size, self.merge_size, -1)
split_features[i] = split_features[i].permute(0, 2, 1, 3, 4)
split_features[i] = split_features[i].flatten(start_dim=0, end_dim=-2)
# Split [h, w, dim] into k tensors [1, dim/k, h, w] (for compatibility with multi-stage vision encoding)
hidden_dim = split_features[i].shape[-1]
split_dim = hidden_dim // len(self.features_list)
split_features[i] = split_features[i].reshape(grid_h, grid_w, -1)
split_features[i] = [
split_features[i][..., j*split_dim:(j+1)*split_dim].permute(2, 0, 1).unsqueeze(0)
for j in range(len(self.features_list))
]
return split_features
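
# A minimal, self-contained sanity check of the shape contract above (hypothetical
# sizes, not part of the model): three gathered layers of hidden size 1280, one
# image with an 8x8 patch grid, and merge_size 2 (merge_unit = 4).
def _demo_gather_shapes():
    gather = VisionFeaturesGather()
    grid_thw = torch.tensor([[1, 8, 8]])       # one image: t=1, h=8, w=8 patches
    window_index = torch.randperm(8 * 8 // 4)  # stand-in for get_window_index()
    gather.set_params(grid_thw, window_index, merge_size=2)
    for _ in range(3):                         # pretend three fullatt blocks fired
        gather.append(torch.randn(64, 1280))
    feats = gather.extract_multi_level_features()
    assert len(feats) == 1 and len(feats[0]) == 3
    assert feats[0][0].shape == (1, 1280, 8, 8)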
# Global gather object to pass into Qwen2_5_VisionTransformer for monkey-patched feature gathering
GATHER = VisionFeaturesGather()
# --------------------------------- Monkey Patch ---------------------------------------
def custom_forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
"""
Custom forward used with monkey patch to support multi-level feature extraction.
Applies patch embedding, window partition, position embedding, and passes through all blocks.
Optionally collects features at each 'fullatt' block for multi-region support.
Args:
        hidden_states (`torch.Tensor` of shape `(seq_len, patch_dim)`):
            Packed, flattened image patches (raw pixel values) prior to patch embedding.
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
Temporal, height, width of each feature sequence.
Returns:
`torch.Tensor`: Final hidden states after MLP head (merger).
"""
hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
seq_len, _ = hidden_states.size()
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
hidden_states = hidden_states[window_index, :, :]
hidden_states = hidden_states.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
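    # The two reshuffles above regroup tokens into spatial_merge_unit-sized chunks
    # and permute them into window order, so the window-attention blocks below see
    # each local window as a contiguous span of the sequence.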
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
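    # rot_pos_emb produces angles for half the rotary dimension; concatenating the
    # table with itself yields full-width cos/sin position embeddings.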
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0,
# FA2 requires that cu_seqlens_q must have dtype int32
# torch.onnx.export requires that cu_seqlens_q must match grid_thw dtype
# See https://github.com/huggingface/transformers/pull/34852 for more info
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
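    # Prepend 0 so that cu_seqlens[i]:cu_seqlens[i+1] delimits the token span of
    # frame i when the full-attention blocks compute varlen attention.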
# If monkey-patched feature gather enabled, prepare to collect intermediate features
if hasattr(self, 'vision_features_gather'):
self.vision_features_gather.reset()
self.vision_features_gather.set_params(grid_thw, window_index, self.spatial_merge_size)
# Forward pass through all transformer blocks; collect intermediate features if needed
for layer_num, blk in enumerate(self.blocks):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings, use_reentrant=False
)
else:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
        # Capture hidden states at every full-attention block as one feature level.
        # `vision_features_gather` is attached via init_vision_features_gather().
        if hasattr(self, 'vision_features_gather') and layer_num in self.fullatt_block_indexes:
            self.vision_features_gather.append(hidden_states.clone())
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states
def init_vision_features_gather(self, vision_features_gather):
"""
Helper method for monkey patch to inject a VisionFeaturesGather instance into model.
"""
self.vision_features_gather = vision_features_gather
def replace_qwen_vit_forward():
"""
Monkey-patch Qwen2_5_VisionTransformer to use custom forward with multi-level feature support.
"""
Qwen2_5_VisionTransformerPretrainedModel.forward = custom_forward
Qwen2_5_VisionTransformerPretrainedModel.init_vision_features_gather = init_vision_features_gather
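
# Intended call order (sketch): patch the class before instantiating the model,
# then attach the shared GATHER instance, e.g.:
#
#   replace_qwen_vit_forward()
#   vit = Qwen2_5_VisionTransformerPretrainedModel._from_config(vision_config)
#   vit.init_vision_features_gather(GATHER)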
class Qwen2_5_VlVisionTower(nn.Module):
"""
Vision backbone wrapper for Qwen2.5-VL (Vision Transformer).
Handles both standard and region-level (multi-level) encoding with optional monkey patch logic.
"""
def __init__(self, image_tower, args, delay_load=False, min_pixels=56*56, max_pixels=2048*2048):
super().__init__()
self.is_loaded = False
self.image_tower_name = image_tower
# Determine if multi-level region feature is to be enabled (monkey patch required)
self.use_vision_tower_region_feature = getattr(args, 'mm_use_vision_tower_region_feature', False)
if self.use_vision_tower_region_feature:
replace_qwen_vit_forward() # Monkey patch: add multi-level feature extraction logic
self.min_pixels = min_pixels
self.max_pixels = max_pixels
        self.delay_load = delay_load  # NOTE: currently unused; the model is always loaded eagerly below
        print(f"Qwen2_5_VlVisionTower loading_info: delay_load: {delay_load} min_pixels: {min_pixels} max_pixels: {max_pixels}")
        self.cfg_only = args.vision_config
        self.load_model(model_path=args.name_or_path)
    def load_model(self, model_path=None, image_size=336, is_train=True):
        """
        Actually load the Qwen2.5 vision tower backbone and its image processor.
        `image_size` and `is_train` are kept for interface compatibility and are unused here.
        """
        self.image_tower = Qwen2_5_VisionTransformerPretrainedModel._from_config(
            self.cfg_only, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
        )
if model_path is not None:
self.image_processor = Qwen2VLImageProcessor.from_pretrained(model_path, min_pixels=self.min_pixels, max_pixels=self.max_pixels)
else:
self.image_processor = Qwen2VLImageProcessor.from_pretrained(self.image_tower_name, min_pixels=self.min_pixels, max_pixels=self.max_pixels)
if self.use_vision_tower_region_feature:
# Setup gather instance for monkey-patched feature extraction
self.image_tower.init_vision_features_gather(GATHER)
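            # NOTE: GATHER is module-global; multiple tower instances share one buffer,
            # so features must be read back immediately after each forward pass.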
self.is_loaded = True
def convert_image_format(self, image):
"""
Convert raw image tensor to pre-processed model input tensor and grid shape, using appropriate processor.
Handles PIL conversion and applies preprocessor for Qwen2.5-VL.
"""
pil_image = ToPILImage()(image)
inputs = self.image_processor(images=pil_image, videos=None, return_tensors="pt")
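        # `pixel_values` holds flattened patches of shape
        # [num_patches, channels * temporal_patch_size * patch_size**2]
        # (1176 with the default 3-channel, 2x14x14 patch layout);
        # `image_grid_thw` is a [1, 3] tensor of (t, h, w) patch counts.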
return inputs['pixel_values'], inputs['image_grid_thw']
    def forward(self, images, image_grid_thws=None):
        """
        Forward pass over a batch (list) of images.
        Returns image features, grid THWs, and (optionally) multi-level features for each input image.
        """
        if isinstance(images, list):
image_features = []
multi_level_features_list = []
output_image_grid_thws = []
for i, image in enumerate(images):
# If no grid provided, convert and infer via processor
if image_grid_thws is None or len(image_grid_thws) == 0:
image, image_grid_thw = self.convert_image_format(image=image)
else:
image_grid_thw = image_grid_thws[i]
                image_forward_out = self.image_tower(
                    image.to(device=self.device, dtype=self.dtype),
                    grid_thw=image_grid_thw.to(device=self.device),
                )
image_feature = image_forward_out.unsqueeze(0).to(self.dtype)
image_features.append(image_feature)
output_image_grid_thws.append(image_grid_thw)
# If region feature mode enabled, collect multi-level features for this image
if self.use_vision_tower_region_feature:
multi_level_features_list.append(self.get_multi_level_features()[0])
else:
raise NotImplementedError("Qwen2_5_VlVisionTower only supports list-of-image input")
return image_features, output_image_grid_thws, multi_level_features_list
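
    # Usage sketch (hypothetical caller; shapes follow from the code above):
    #
    #   feats, grids, regions = tower([img])    # img: CHW tensor or preprocessed patches
    #   feats[0]      -> [1, num_tokens, out_hidden_size]
    #   regions[0][j] -> [1, per_level_hidden, grid_h, grid_w]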
def get_multi_level_features(self):
"""
Get the current (last-processed) multi-level region features from the VisionFeaturesGather helper.
Used in region-feature/ROIAlign branches.
"""
multi_level_features = self.image_tower.vision_features_gather.extract_multi_level_features()
return multi_level_features
@property
def dummy_feature(self):
"""Returns a zero-vector feature, for use as fallback/null visual token."""
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
"""Report vision tower's expected/active tensor dtype (inferred from real weights)."""
return self.image_tower.dtype
@property
def device(self):
"""Report vision tower's tensor device (cuda/cpu) for autoflow/compatibility."""
return self.image_tower.device
@property
def config(self):
"""Yield config, for both loaded-and-ready and 'config only' modes (delay load etc)."""
if self.is_loaded:
return self.image_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
"""Return backbone output hidden size (for proj or post-processing modules)."""
return self.config.out_hidden_size
@property
def num_patches(self):
"""Return number of vision tokens (patches) in processed image."""
return (self.config.image_size // self.config.patch_size) ** 2