import torch
import torch.nn as nn
import torch.nn.functional as F

from vlm_fo1.model.multimodal_encoder.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor
from torchvision.transforms import ToPILImage

class VisionFeaturesGather:
    """
    Collects and manages intermediate features for multi-level visual representation extraction
    (used for region feature/ROIAlign task). Each forward pass (per image) builds up a list of features.
    """
    def __init__(self) -> None:
        self.features_list = []
        self.grid_thw = None
        self.window_index = None
        self.merge_size = None
    
    def reset(self):
        """Clear all states before starting a new feature-gathering process."""
        self.features_list.clear()
        self.grid_thw = None
        self.window_index = None
        self.merge_size = None
    
    def set_params(self, grid_thw, window_index, merge_size):
        """Store spatial and merge information for the current image or batch."""
        self.grid_thw = grid_thw
        self.window_index = window_index
        self.merge_size = merge_size

    def append(self, element):
        """Append a set of features (typically per layer in encoder)."""
        self.features_list.append(element)
    
    def extract_multi_level_features(self):
        """
        Assemble all gathered multi-level features into canonical tensor forms.

        The goal: for each visual sample, produce a list of region-aligned feature maps
        (e.g., multiple stage outputs for downstream region patching/ROIAlign).

        Returns:
            List of features, where each element is a list [stage1, stage2, ...] for one image.
        """
        # Concatenate all feature tensors along hidden dimension: [seq_len, hidden_size * k]
        concat_features = torch.cat(self.features_list, dim=1)
        merge_unit = self.merge_size * self.merge_size
        seq_len = concat_features.shape[0]

        # Rearrange into [windows, merge_unit, hidden_dim*layers]
        concat_features = concat_features.reshape(seq_len // merge_unit, merge_unit, -1)
        reverse_indices = torch.argsort(self.window_index)
        concat_features = concat_features[reverse_indices, :, :]
        concat_features = concat_features.reshape(seq_len, -1)
        
        # Split features for each image/video by product of grid h and w (per sample)
        split_size = (self.grid_thw[:, 1] * self.grid_thw[:, 2]).tolist()
        split_features = list(torch.split(concat_features, split_size, dim=0))
        assert len(split_features) == self.grid_thw.shape[0]
        for i in range(len(split_features)):
            # Recover original grid shape and merge windowing into stages, then split
            _, grid_h, grid_w = self.grid_thw[i]
            merge_h = grid_h // self.merge_size
            merge_w = grid_w // self.merge_size
            split_features[i] = split_features[i].reshape(merge_h, merge_w, merge_unit, -1)
            split_features[i] = split_features[i].reshape(merge_h, merge_w, self.merge_size, self.merge_size, -1)
            split_features[i] = split_features[i].permute(0, 2, 1, 3, 4)
            split_features[i] = split_features[i].flatten(start_dim=0, end_dim=-2)
            # Split [h, w, dim] into k tensors [1, dim/k, h, w] (for compatibility with multi-stage vision encoding)
            hidden_dim = split_features[i].shape[-1]
            split_dim = hidden_dim // len(self.features_list)
            split_features[i] = split_features[i].reshape(grid_h, grid_w, -1)
            split_features[i] = [
                split_features[i][..., j*split_dim:(j+1)*split_dim].permute(2, 0, 1).unsqueeze(0)
                for j in range(len(self.features_list))
            ]

        return split_features
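
# Minimal usage sketch for VisionFeaturesGather (illustrative only; the shapes below are
# hypothetical and assume one image with grid_thw = [[1, 4, 4]], merge_size = 2 and two
# gathered layers with hidden size 8, i.e. 16 patch tokens grouped into 4 merge windows):
#
#     gather = VisionFeaturesGather()
#     gather.reset()
#     gather.set_params(
#         grid_thw=torch.tensor([[1, 4, 4]]),
#         window_index=torch.arange(4),        # identity permutation over 16 // 4 windows
#         merge_size=2,
#     )
#     for _ in range(2):                       # one append per collected transformer layer
#         gather.append(torch.randn(16, 8))
#     feats = gather.extract_multi_level_features()
#     # feats[0] is a list of 2 stage tensors, each of shape [1, 8, 4, 4]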

# Global gather object to pass into Qwen2_5_VisionTransformer for monkey-patched feature gathering
GATHER = VisionFeaturesGather()

# --------------------------------- Monkey Patch ---------------------------------------
def custom_forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
    """
    Custom forward used with monkey patch to support multi-level feature extraction.
    Applies patch embedding, window partition, position embedding, and passes through all blocks.
    Optionally collects features at each 'fullatt' block for multi-region support.

    Args:
        hidden_states (`torch.Tensor` of shape `(seq_len, in_features)`):
            Flattened image patches (pixel values) to be embedded by `patch_embed`.
        grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
            Temporal, height, width of each feature sequence.

    Returns:
        `torch.Tensor`: Final hidden states after MLP head (merger).
    """
    hidden_states = self.patch_embed(hidden_states)
    rotary_pos_emb = self.rot_pos_emb(grid_thw)
    window_index, cu_window_seqlens = self.get_window_index(grid_thw)
    cu_window_seqlens = torch.tensor(
        cu_window_seqlens,
        device=hidden_states.device,
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

    seq_len, _ = hidden_states.size()
    hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    hidden_states = hidden_states[window_index, :, :]
    hidden_states = hidden_states.reshape(seq_len, -1)
    rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    rotary_pos_emb = rotary_pos_emb[window_index, :, :]
    rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
    position_embeddings = (emb.cos(), emb.sin())

    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
        dim=0,
        # FA2 requires that cu_seqlens_q must have dtype int32
        # torch.onnx.export requires that cu_seqlens_q must match grid_thw dtype
        # See https://github.com/huggingface/transformers/pull/34852 for more info
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
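    # Worked example of the bookkeeping above (illustrative values, not from the original
    # code): for grid_thw = [[1, 2, 2], [1, 4, 4]] the per-image token counts h*w are
    # [4, 16], repeated once each since t = 1, their cumsum is [4, 20], and the leading
    # pad gives cu_seqlens = [0, 4, 20] -- the boundaries used by the full-attention
    # blocks to keep attention confined to each image.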

    # If monkey-patched feature gather enabled, prepare to collect intermediate features
    if hasattr(self, 'vision_features_gather'):
        self.vision_features_gather.reset()
        self.vision_features_gather.set_params(grid_thw, window_index, self.spatial_merge_size)

    # Forward pass through all transformer blocks; collect intermediate features if needed
    for layer_num, blk in enumerate(self.blocks):
        if layer_num in self.fullatt_block_indexes:
            cu_seqlens_now = cu_seqlens
        else:
            cu_seqlens_now = cu_window_seqlens
        if self.gradient_checkpointing and self.training:
            hidden_states = self._gradient_checkpointing_func(
                blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings, use_reentrant=False
            )
        else:
            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
        
        if hasattr(self, 'vision_features_gather'):
            # Capture hidden states at all 'full attention' blocks as multi-level features
            if layer_num in self.fullatt_block_indexes:
                # This property is set by monkey patching
                self.vision_features_gather.append(hidden_states.clone())
    
    hidden_states = self.merger(hidden_states)
    reverse_indices = torch.argsort(window_index)
    hidden_states = hidden_states[reverse_indices, :]

    return hidden_states

def init_vision_features_gather(self, vision_features_gather):
    """
    Helper method for monkey patch to inject a VisionFeaturesGather instance into model.
    """
    self.vision_features_gather = vision_features_gather

def replace_qwen_vit_forward():
    """
    Monkey-patch Qwen2_5_VisionTransformer to use custom forward with multi-level feature support.
    """
    Qwen2_5_VisionTransformerPretrainedModel.forward = custom_forward
    Qwen2_5_VisionTransformerPretrainedModel.init_vision_features_gather = init_vision_features_gather
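
# Hedged usage sketch of the monkey patch (names other than those defined above are
# placeholders). The patch must be applied before the vision transformer is used; features
# are then collected at the blocks listed in the config's `fullatt_block_indexes`:
#
#     replace_qwen_vit_forward()                        # swap in custom_forward class-wide
#     vit = Qwen2_5_VisionTransformerPretrainedModel._from_config(vision_config)
#     vit.init_vision_features_gather(GATHER)           # inject the shared gather helper
#     out = vit(pixel_values, grid_thw=grid_thw)        # custom_forward fills GATHER as a side effect
#     multi_level = GATHER.extract_multi_level_features()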


class Qwen2_5_VlVisionTower(nn.Module):
    """
    Vision backbone wrapper for Qwen2.5-VL (Vision Transformer).
    Handles both standard and region-level (multi-level) encoding with optional monkey patch logic.
    """
    def __init__(self, image_tower, args, delay_load=False, min_pixels=56*56, max_pixels=2048*2048):
        super().__init__()

        self.is_loaded = False

        self.image_tower_name = image_tower
        
        # Determine if multi-level region feature is to be enabled (monkey patch required)
        self.use_vision_tower_region_feature = getattr(args, 'mm_use_vision_tower_region_feature', False)
        if self.use_vision_tower_region_feature:
            replace_qwen_vit_forward()    # Monkey patch: add multi-level feature extraction logic
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.delay_load = delay_load
        print(f"Qwen2_5_VlVisionTower loading_info: delay_load: {delay_load}, min_pixels: {min_pixels}, max_pixels: {max_pixels}")

        # NOTE: `delay_load` is accepted for API compatibility but is not honored here;
        # the tower is always built from `args.vision_config` and loaded immediately.
        # if not delay_load:
        #     self.load_model()
        # else:
        #     # Defer actual model loading (e.g. for model-parallel or delayed-download scenarios)
        #     self.cfg_only = args.vision_config
        self.cfg_only = args.vision_config
        self.load_model(model_path=args.name_or_path)

    def load_model(self, model_path=None, image_size=336, is_train=True):
        """
        Actually load Qwen2.5 Vision Tower backbone and processor.
        Sets up the image tower and patch feed pipeline.
        """
        self.image_tower = Qwen2_5_VisionTransformerPretrainedModel._from_config(self.cfg_only, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
        # print(f'Qwen2_5_VlVisionTower loading_info: {loading_info}')

        if model_path is not None:
            self.image_processor = Qwen2VLImageProcessor.from_pretrained(model_path, min_pixels=self.min_pixels, max_pixels=self.max_pixels)
        else:
            self.image_processor = Qwen2VLImageProcessor.from_pretrained(self.image_tower_name, min_pixels=self.min_pixels, max_pixels=self.max_pixels)

        if self.use_vision_tower_region_feature:
            # Setup gather instance for monkey-patched feature extraction
            self.image_tower.init_vision_features_gather(GATHER)
        self.is_loaded = True
    
    def convert_image_format(self, image):
        """
        Convert raw image tensor to pre-processed model input tensor and grid shape, using appropriate processor.
        Handles PIL conversion and applies preprocessor for Qwen2.5-VL.
        """
        pil_image = ToPILImage()(image)
        inputs = self.image_processor(images=pil_image, videos=None, return_tensors="pt")
        return inputs['pixel_values'], inputs['image_grid_thw']
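
    # Hedged example of convert_image_format (values are illustrative and assume the
    # processor's defaults: 14-pixel patches, temporal_patch_size 2, 2x2 spatial merge):
    #
    #     px, thw = tower.convert_image_format(torch.rand(3, 224, 224))
    #     # thw -> tensor([[1, 16, 16]])   (t, h, w) measured in patches
    #     # px  -> shape [1*16*16, 3*2*14*14] = [256, 1176]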

    def forward(self, images, image_grid_thws=None):
        """
        Forward pass for a batch (list) of images.
        Returns image features, grid THWs, and (when region features are enabled) multi-level
        features for each input image.
        """
        if isinstance(images, list):
            image_features = []
            multi_level_features_list = []
            output_image_grid_thws = []

            for i, image in enumerate(images):
                # If no grid provided, convert and infer via processor
                if image_grid_thws is None or len(image_grid_thws) == 0:
                    image, image_grid_thw = self.convert_image_format(image=image)
                else:
                    image_grid_thw = image_grid_thws[i]
                image_forward_out = self.image_tower(image.to(device=self.device, dtype=self.dtype), grid_thw=image_grid_thw.to(device=self.device))
                image_feature = image_forward_out.unsqueeze(0).to(self.dtype)

                image_features.append(image_feature)
                output_image_grid_thws.append(image_grid_thw)

                # If region feature mode enabled, collect multi-level features for this image
                if self.use_vision_tower_region_feature:
                    multi_level_features_list.append(self.get_multi_level_features()[0])
                
        else:
            raise NotImplementedError("Qwen2_5_VlVisionTower only supports list-of-image input")

        return image_features, output_image_grid_thws, multi_level_features_list
    
    def get_multi_level_features(self):
        """
        Get the current (last-processed) multi-level region features from the VisionFeaturesGather helper.
        Used in region-feature/ROIAlign branches.
        """
        multi_level_features = self.image_tower.vision_features_gather.extract_multi_level_features()
        return multi_level_features

    @property
    def dummy_feature(self):
        """Returns a zero-vector feature, for use as fallback/null visual token."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Report vision tower's expected/active tensor dtype (inferred from real weights)."""
        return self.image_tower.dtype

    @property
    def device(self):
        """Report vision tower's tensor device (cuda/cpu) for autoflow/compatibility."""
        return self.image_tower.device

    @property
    def config(self):
        """Yield config, for both loaded-and-ready and 'config only' modes (delay load etc)."""
        if self.is_loaded:
            return self.image_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        """Return backbone output hidden size (for proj or post-processing modules)."""
        return self.config.out_hidden_size

    @property
    def num_patches(self):
        """Return number of vision tokens (patches) in processed image."""
        return (self.config.image_size // self.config.patch_size) ** 2
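
# --------------------------------- Usage sketch ----------------------------------------
# Hedged end-to-end example (not executed on import). `args` is a hypothetical namespace;
# the field names simply mirror what __init__ reads above (mm_use_vision_tower_region_feature,
# vision_config, name_or_path), and the model id is only a placeholder:
#
#     tower = Qwen2_5_VlVisionTower("Qwen/Qwen2.5-VL-7B-Instruct", args)
#     pixel_values, grid_thw = tower.convert_image_format(image_tensor)   # image_tensor: [3, H, W]
#     feats, grids, multi_level = tower([pixel_values], [grid_thw])
#     # feats[0]:       [1, num_merged_tokens, tower.hidden_size]
#     # multi_level[0]: list of per-stage maps [1, C, grid_h, grid_w] when region features are enabled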