import torch
import torch.nn as nn
import torch.nn.functional as F

from vlm_fo1.model.multimodal_encoder.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor
from torchvision.transforms import ToPILImage

class VisionFeaturesGather:
    """
    Collects and manages intermediate features for multi-level visual representation extraction
    (used for the region-feature/ROIAlign task). Each forward pass (per image) builds up a list of features.
    """

    def __init__(self) -> None:
        self.features_list = []
        self.grid_thw = None
        self.window_index = None
        self.merge_size = None

    def reset(self):
        """Clear all state before starting a new feature-gathering pass."""
        self.features_list.clear()
        self.grid_thw = None
        self.window_index = None
        self.merge_size = None

    def set_params(self, grid_thw, window_index, merge_size):
        """Store spatial and merge information for the current image or batch."""
        self.grid_thw = grid_thw
        self.window_index = window_index
        self.merge_size = merge_size

    def append(self, element):
        """Append one set of features (typically one per full-attention layer in the encoder)."""
        self.features_list.append(element)
    def extract_multi_level_features(self):
        """
        Assemble all gathered multi-level features into canonical tensor forms.
        The goal: for each visual sample, produce a list of region-aligned feature maps
        (e.g., multiple stage outputs for downstream region patching/ROIAlign).

        Returns:
            List of features, where each element is a list [stage1, stage2, ...] for one image.
        """
        # Concatenate all gathered feature tensors along the hidden dimension: [seq_len, hidden_size * k]
        concat_features = torch.cat(self.features_list, dim=1)
        merge_unit = self.merge_size * self.merge_size
        seq_len = concat_features.shape[0]
        # Regroup into [num_windows, merge_unit, hidden_size * k] and undo the window shuffle
        concat_features = concat_features.reshape(seq_len // merge_unit, merge_unit, -1)
        reverse_indices = torch.argsort(self.window_index)
        concat_features = concat_features[reverse_indices, :, :]
        concat_features = concat_features.reshape(seq_len, -1)
        # Split features per image/video by the product of grid height and width (per sample)
        split_size = (self.grid_thw[:, 1] * self.grid_thw[:, 2]).tolist()
        split_features = list(torch.split(concat_features, split_size, dim=0))
        assert len(split_features) == self.grid_thw.shape[0]
        for i in range(len(split_features)):
            # Recover the original spatial grid: unfold merge windows back into row-major patch order
            _, grid_h, grid_w = self.grid_thw[i]
            merge_h = grid_h // self.merge_size
            merge_w = grid_w // self.merge_size
            split_features[i] = split_features[i].reshape(merge_h, merge_w, merge_unit, -1)
            split_features[i] = split_features[i].reshape(merge_h, merge_w, self.merge_size, self.merge_size, -1)
            split_features[i] = split_features[i].permute(0, 2, 1, 3, 4)
            split_features[i] = split_features[i].flatten(start_dim=0, end_dim=-2)
            # Split the [h, w, k * d] map into k tensors of shape [1, d, h, w]
            # (one [N, C, H, W] feature map per gathered stage, for multi-stage vision encoding)
            hidden_dim = split_features[i].shape[-1]
            split_dim = hidden_dim // len(self.features_list)
            split_features[i] = split_features[i].reshape(grid_h, grid_w, -1)
            split_features[i] = [
                split_features[i][..., j * split_dim:(j + 1) * split_dim].permute(2, 0, 1).unsqueeze(0)
                for j in range(len(self.features_list))
            ]
        return split_features
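
# Shape walk-through for extract_multi_level_features (a hedged sketch with illustrative
# numbers, not part of the module): with merge_size=2, two gathered layers of hidden size
# 1280, and a single image with grid_thw=[[1, 8, 8]], concat_features is [64, 2560] and the
# method returns [[t0, t1]] where each t_j has shape [1, 1280, 8, 8] -- one [N, C, H, W]
# feature map per gathered layer, directly consumable by torchvision.ops.roi_align.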

# Global gather object to pass into Qwen2_5_VisionTransformer for monkey-patched feature gathering
GATHER = VisionFeaturesGather()


# --------------------------------- Monkey Patch ---------------------------------------
def custom_forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
    """
    Custom forward used with the monkey patch to support multi-level feature extraction.
    Applies patch embedding, window partitioning, position embedding, and passes through all blocks.
    Optionally collects features at each 'fullatt' block for multi-region support.

    Args:
        hidden_states (`torch.Tensor` of shape `(seq_len, patch_dim)`):
            Flattened image patches, fed through `patch_embed` before the transformer blocks.
        grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
            Temporal, height, and width of each feature sequence.

    Returns:
        `torch.Tensor`: Final hidden states after the MLP merger head.
    """
    hidden_states = self.patch_embed(hidden_states)
    rotary_pos_emb = self.rot_pos_emb(grid_thw)
    window_index, cu_window_seqlens = self.get_window_index(grid_thw)
    cu_window_seqlens = torch.tensor(
        cu_window_seqlens,
        device=hidden_states.device,
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

    seq_len, _ = hidden_states.size()
    hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    hidden_states = hidden_states[window_index, :, :]
    hidden_states = hidden_states.reshape(seq_len, -1)
    rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    rotary_pos_emb = rotary_pos_emb[window_index, :, :]
    rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
    position_embeddings = (emb.cos(), emb.sin())

    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
        dim=0,
        # FA2 requires that cu_seqlens_q must have dtype int32
        # torch.onnx.export requires that cu_seqlens_q must match grid_thw dtype
        # See https://github.com/huggingface/transformers/pull/34852 for more info
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

    # If monkey-patched feature gathering is enabled, prepare to collect intermediate features
    if hasattr(self, 'vision_features_gather'):
        self.vision_features_gather.reset()
        self.vision_features_gather.set_params(grid_thw, window_index, self.spatial_merge_size)

    # Forward pass through all transformer blocks; collect intermediate features if needed
    for layer_num, blk in enumerate(self.blocks):
        if layer_num in self.fullatt_block_indexes:
            cu_seqlens_now = cu_seqlens
        else:
            cu_seqlens_now = cu_window_seqlens
        if self.gradient_checkpointing and self.training:
            hidden_states = self._gradient_checkpointing_func(
                blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings, use_reentrant=False
            )
        else:
            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
        # Capture hidden states at every 'full attention' block as one multi-level feature
        # (the vision_features_gather attribute is set by the monkey patch)
        if hasattr(self, 'vision_features_gather') and layer_num in self.fullatt_block_indexes:
            self.vision_features_gather.append(hidden_states.clone())

    hidden_states = self.merger(hidden_states)
    reverse_indices = torch.argsort(window_index)
    hidden_states = hidden_states[reverse_indices, :]
    return hidden_states
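
# Note (hedged, based on Qwen2.5-VL defaults): fullatt_block_indexes is typically a small set
# of block indices (e.g., four out of 32), so after a forward pass the gather object holds that
# many pre-merger feature maps of shape [seq_len, hidden_size], while the returned tensor is
# post-merger, i.e. [seq_len // spatial_merge_unit, out_hidden_size].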

def init_vision_features_gather(self, vision_features_gather):
    """
    Helper method for the monkey patch: inject a VisionFeaturesGather instance into the model.
    """
    self.vision_features_gather = vision_features_gather


def replace_qwen_vit_forward():
    """
    Monkey-patch Qwen2_5_VisionTransformer to use the custom forward with multi-level feature support.
    """
    Qwen2_5_VisionTransformerPretrainedModel.forward = custom_forward
    Qwen2_5_VisionTransformerPretrainedModel.init_vision_features_gather = init_vision_features_gather
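
# Minimal usage sketch for the patch (hedged; `vit`, `pixel_values`, and `image_grid_thw`
# are illustrative names for a loaded vision transformer and its preprocessed inputs):
#
#   replace_qwen_vit_forward()                       # swap in custom_forward on the class
#   vit.init_vision_features_gather(GATHER)          # attach the shared gather object
#   _ = vit(pixel_values, grid_thw=image_grid_thw)   # forward pass fills GATHER
#   multi_level = GATHER.extract_multi_level_features()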

class Qwen2_5_VlVisionTower(nn.Module):
    """
    Vision backbone wrapper for Qwen2.5-VL (Vision Transformer).
    Handles both standard and region-level (multi-level) encoding with optional monkey-patch logic.
    """

    def __init__(self, image_tower, args, delay_load=False, min_pixels=56 * 56, max_pixels=2048 * 2048):
        super().__init__()
        self.is_loaded = False
        self.image_tower_name = image_tower
        # Determine whether multi-level region features are enabled (requires the monkey patch)
        self.use_vision_tower_region_feature = getattr(args, 'mm_use_vision_tower_region_feature', False)
        if self.use_vision_tower_region_feature:
            replace_qwen_vit_forward()  # Monkey patch: add multi-level feature extraction logic
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.delay_load = delay_load
        print(f"Qwen2_5_VlVisionTower loading_info: delay_load: {delay_load} min_pixels: {min_pixels} max_pixels: {max_pixels}")
        # Note: delay_load is recorded but not currently honored; the model is always loaded eagerly here.
        self.cfg_only = args.vision_config
        self.load_model(model_path=args.name_or_path)
    def load_model(self, model_path=None, image_size=336, is_train=True):
        """
        Actually load the Qwen2.5 vision tower backbone and processor.
        Sets up the image tower and the patch feed pipeline.
        """
        self.image_tower = Qwen2_5_VisionTransformerPretrainedModel._from_config(
            self.cfg_only, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
        )
        if model_path is not None:
            self.image_processor = Qwen2VLImageProcessor.from_pretrained(model_path, min_pixels=self.min_pixels, max_pixels=self.max_pixels)
        else:
            self.image_processor = Qwen2VLImageProcessor.from_pretrained(self.image_tower_name, min_pixels=self.min_pixels, max_pixels=self.max_pixels)
        if self.use_vision_tower_region_feature:
            # Set up the shared gather instance for monkey-patched feature extraction
            self.image_tower.init_vision_features_gather(GATHER)
        self.is_loaded = True
    def convert_image_format(self, image):
        """
        Convert a raw image tensor into the preprocessed model input tensor and grid shape.
        Handles PIL conversion and applies the Qwen2.5-VL preprocessor.
        """
        pil_image = ToPILImage()(image)
        inputs = self.image_processor(images=pil_image, videos=None, return_tensors="pt")
        return inputs['pixel_values'], inputs['image_grid_thw']
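
    # Example of the processor's output (hedged; exact sizes depend on min/max_pixels and
    # the processor's smart resize): a 448x448 RGB image with patch_size=14 and
    # temporal_patch_size=2 yields pixel_values of shape [1024, 1176]
    # (32*32 flattened 3*2*14*14 patches) and image_grid_thw = [[1, 32, 32]].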
    def forward(self, images, image_grid_thws=None):
        """
        Forward pass for a batch (list) of images.
        Returns image features, grid THWs, and optional multi-level features for each input image.
        """
        if isinstance(images, list):
            image_features = []
            multi_level_features_list = []
            output_image_grid_thws = []
            for i, image in enumerate(images):
                # If no grid is provided, convert the image and infer the grid via the processor
                if image_grid_thws is None or len(image_grid_thws) == 0:
                    image, image_grid_thw = self.convert_image_format(image=image)
                else:
                    image_grid_thw = image_grid_thws[i]
                image_forward_out = self.image_tower(image.to(device=self.device, dtype=self.dtype), grid_thw=image_grid_thw.to(device=self.device))
                image_feature = image_forward_out.unsqueeze(0).to(self.dtype)
                image_features.append(image_feature)
                output_image_grid_thws.append(image_grid_thw)
                # If region-feature mode is enabled, collect multi-level features for this image
                if self.use_vision_tower_region_feature:
                    multi_level_features_list.append(self.get_multi_level_features()[0])
        else:
            raise NotImplementedError("Qwen2_5_VlVisionTower only supports list-of-image input")
        return image_features, output_image_grid_thws, multi_level_features_list
    def get_multi_level_features(self):
        """
        Get the current (last-processed) multi-level region features from the VisionFeaturesGather helper.
        Used in the region-feature/ROIAlign branches.
        """
        return self.image_tower.vision_features_gather.extract_multi_level_features()
    @property
    def dummy_feature(self):
        """Return a zero-vector feature, for use as a fallback/null visual token."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Report the vision tower's active tensor dtype (inferred from the real weights)."""
        return self.image_tower.dtype

    @property
    def device(self):
        """Report the vision tower's tensor device (cuda/cpu)."""
        return self.image_tower.device

    @property
    def config(self):
        """Return the config, in both loaded and 'config only' (delay-load) modes."""
        if self.is_loaded:
            return self.image_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        """Return the backbone output hidden size (for projection or post-processing modules)."""
        return self.config.out_hidden_size

    @property
    def num_patches(self):
        """Return the number of vision tokens (patches) in a processed image."""
        return (self.config.image_size // self.config.patch_size) ** 2
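
# End-to-end usage sketch (hedged; the `args` fields mirror exactly what __init__ reads, and
# SimpleNamespace stands in for the real training/config object):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(mm_use_vision_tower_region_feature=True,
#                          vision_config=vision_config,
#                          name_or_path="Qwen/Qwen2.5-VL-7B-Instruct")
#   tower = Qwen2_5_VlVisionTower("Qwen/Qwen2.5-VL-7B-Instruct", args).cuda()
#   feats, grids, multi_level = tower([image_tensor])  # image_tensor: [3, H, W] in [0, 1]
#   # multi_level[0]: list of [1, C, h, w] maps, one per full-attention block, for ROIAlign.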