"""Multi-backbone vision tower: runs several vision encoders on the same image
and fuses their features (channel concatenation on a shared token grid, or
mode-specific dict outputs depending on `moe_version_type`)."""

import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy

from .siglip_vision_tower import SiglipVisionTower


class MultiBackboneChannelConcatenationVisionTower(nn.Module):
    """Wraps several vision backbones. Depending on `moe_version_type`,
    each backbone's features are resampled to a `grid_size` x `grid_size`
    token grid and concatenated along the channel dimension, or returned
    as a dict of per-mode embeddings (see `forward`).
    """

    def __init__(self,
                 vision_tower,
                 args,
                 grid_size=32,
                 convnext_img_size=1024,
                 normalize_type=None,
                 raw_config=None):
        super().__init__()

        self.is_loaded = False
        self.grid_size = grid_size
        self.num_tokens = self.grid_size ** 2
        # NOTE: args.normalize_type takes precedence; the `normalize_type`
        # constructor argument is ignored.
        self.normalize_type = args.normalize_type
        self.moe_version_type = args.moe_version_type
        self.raw_config = raw_config
        print("moe_version_type: ", self.moe_version_type)
        assert self.moe_version_type in [None, 'all_tiling', 'seq_concat', 'feat_concat', 'convnext_512_siglip_448'], \
            f"Unknown self.moe_version_type: {self.moe_version_type}"

        # `vision_tower` is a ";"-separated list of backbone names,
        # e.g. "convnext-1024;palisiglip".
        vision_tower_name_list = vision_tower.split(";")
        self.input_image_size = 1024
        self.convnext_img_size = convnext_img_size
        self.load_vision_towers(vision_tower_name_list, args)

    def load_vision_towers(self, vision_tower_name_list, args):
        self.vision_towers = nn.ModuleList()

        # `freeze_backbones` is matched by substring below, so a plain
        # comma-separated string of backbone names works.
        freeze_backbone_list = args.freeze_backbones
        if freeze_backbone_list is not None and len(freeze_backbone_list) > 0:
            print("The frozen backbones: ", freeze_backbone_list)
        else:
            freeze_backbone_list = ""

        for name in vision_tower_name_list:
            if name == 'convnext-1024':
                convnext_args = deepcopy(args)

                convnext_args.freeze_vision = False
                if 'convnext-1024' in freeze_backbone_list:
                    convnext_args.freeze_vision = True

                from .convnext_encoder import ConvNextVisionTower
                convnext_args.input_image_size = self.convnext_img_size
                convnext_vision_tower = ConvNextVisionTower(args.vision_tower_convnext_path,
                                                            convnext_args,
                                                            delay_load=args.delay_load,
                                                            normalize_type=self.normalize_type)
                convnext_vision_tower.load_model()
                self.vision_towers.append(convnext_vision_tower)

            elif name == 'palisiglip':
                palisiglip_args = deepcopy(args)
                palisiglip_args.input_image_size = 448

                palisiglip_args.freeze_vision = False
                if 'palisiglip' in freeze_backbone_list:
                    palisiglip_args.freeze_vision = True

                palisiglip_vision_tower = SiglipVisionTower(args.vision_tower_siglip_path,
                                                            palisiglip_args,
                                                            delay_load=args.delay_load,
                                                            raw_config=self.raw_config)
                palisiglip_vision_tower.load_model()
                self.vision_towers.append(palisiglip_vision_tower)

        # Pre-processing is expected to happen outside this module.
        self.image_processor = None
        self.is_loaded = True

    def load_model(self):
        assert self.is_loaded, "All the vision encoders should be loaded during initialization!"

    def forward(self, x):
        """Returns a (B, num_tokens, C) tensor for `moe_version_type` in
        [None, 'all_tiling'], and a dict of embeddings otherwise."""
        if self.moe_version_type in [None, 'all_tiling']:
            # Every backbone sees the same square image; each feature map is
            # resampled to the shared token grid and the results are
            # concatenated along the channel dimension.
            features = []
            image_input_size = x.shape[2]
            assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
            for vision_tower in self.vision_towers:
                if vision_tower.input_image_size != image_input_size:
                    resized_x = F.interpolate(x.float(),
                                              size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                              mode='bilinear',
                                              align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x

                feature = vision_tower(resized_x)

                if len(feature.shape) == 3:
                    # (B, N, C) token sequence: keep as-is if it already
                    # matches the grid, otherwise fold back to (B, C, H, W).
                    b, n, c = feature.shape
                    if n == self.num_tokens:
                        features.append(feature)
                        continue
                    w = h = int(n ** 0.5)
                    feature = feature.transpose(1, 2).reshape(b, c, h, w)
                else:
                    # (B, C, H, W) feature map (e.g. from a ConvNet).
                    b, c, h, w = feature.shape

                if w != self.grid_size:
                    feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
                features.append(feature.flatten(2, 3).transpose(1, 2))

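            # Illustrative shapes (assuming, say, a ConvNeXt stage with 3072
            # channels and a SigLIP with 1152-d tokens): two (B, 1024, C_i)
            # tensors concatenate below to (B, 1024, 3072 + 1152).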
            features = torch.cat(features, dim=-1)
        elif self.moe_version_type == 'convnext_512_siglip_448':
            # Keep the towers' outputs separate, keyed by tower name.
            features = {}
            image_input_size = x.shape[2]
            assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
            for vision_tower in self.vision_towers:
                if vision_tower.input_image_size != image_input_size:
                    resized_x = F.interpolate(x.float(),
                                              size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                              mode='bilinear',
                                              align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x

                feature = vision_tower(resized_x)
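                # `vision_tower.name` is assumed to be set by the individual
                # tower classes; no fusion happens in this mode.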
                features[vision_tower.name] = feature

        else:
            # 'seq_concat' / 'feat_concat': x is a dict of tiled crops with
            # per-image crop counts.
            assert isinstance(x, dict), "x is expected to be a dict but {}".format(type(x))
            pixel_values = x['pixel_values']
            num_patches = x['num_patches']

            # Under 'seq_concat', each image's count includes one entry (the
            # thumbnail, handled on the sequence side) that is not present in
            # pixel_values, so subtract it here.
            if self.moe_version_type == 'seq_concat':
                image_in_num_patches = [i - 1 for i in num_patches]
            else:
                image_in_num_patches = list(num_patches)

            assert sum(image_in_num_patches) == pixel_values.size(0), "sum(image_in_num_patches) ({}) != pixel_values.size(0) ({})".format(sum(image_in_num_patches), pixel_values.size(0))

            # Index of the last crop of each image, used as its thumbnail.
            thumbnail_image_id = torch.cumsum(torch.tensor(image_in_num_patches).to(pixel_values.device), 0) - 1
            image_no_tiling = pixel_values[thumbnail_image_id]
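            # Worked example: num_patches = [5, 3] under 'seq_concat' gives
            # image_in_num_patches = [4, 2] and thumbnail_image_id = [3, 5],
            # i.e. the last crop of each image is taken as its thumbnail.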

            # The first tower runs on every crop; the remaining towers only on
            # the per-image thumbnails.
            features = []
            for layer_id, vision_tower in enumerate(self.vision_towers):
                if layer_id == 0:
                    tower_input = pixel_values
                else:
                    tower_input = image_no_tiling

                if vision_tower.input_image_size != self.input_image_size:
                    resized_x = F.interpolate(tower_input.float(),
                                              size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                              mode='bilinear',
                                              align_corners=True).to(dtype=tower_input.dtype)
                else:
                    resized_x = tower_input

                feature = vision_tower(resized_x)
                if len(feature.shape) == 3:
                    b, n, c = feature.shape
                    if n == self.num_tokens:
                        features.append(feature)
                        continue

                    w = h = int(n ** 0.5)
                    feature = feature.transpose(1, 2).reshape(b, c, h, w)
                else:
                    b, c, h, w = feature.shape

                if w != self.grid_size:
                    feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=tower_input.dtype)
                features.append(feature.flatten(2, 3).transpose(1, 2))

            clip_embeds = features[0]
            if len(features) <= 1:
                no_tiling_embeds = None
            else:
                no_tiling_embeds = torch.cat(features[1:], dim=-1)

            if self.moe_version_type == 'feat_concat':
                # Prepend the first tower's thumbnail features to the
                # thumbnail-only features of the remaining towers ...
                clip_thumbnail_embeds = clip_embeds[thumbnail_image_id]
                if no_tiling_embeds is not None:
                    no_tiling_embeds = torch.cat([clip_thumbnail_embeds, no_tiling_embeds], dim=-1)
                else:
                    no_tiling_embeds = clip_thumbnail_embeds

                # ... and drop the thumbnail rows from the tiled features so
                # they are not returned twice.
                clip_embeds_mask = ~torch.isin(torch.arange(clip_embeds.shape[0]).to(clip_embeds.device), thumbnail_image_id)
                clip_embeds = clip_embeds[clip_embeds_mask]

            features = {
                'clip_embeds': clip_embeds,
                'no_tiling_embeds': no_tiling_embeds,
                'num_patches': num_patches
            }

        return features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Use the first tower as the reference for dtype/device.
        return next(self.vision_towers[0].parameters()).dtype

    @property
    def device(self):
        return next(self.vision_towers[0].parameters()).device

    @property
    def config(self):
        raise NotImplementedError

    @property
    def hidden_size(self):
        if self.moe_version_type == 'convnext_512_siglip_448':
            return {vision_tower.name: vision_tower.hidden_size for vision_tower in self.vision_towers}
        else:
            return sum(tower.hidden_size for tower in self.vision_towers)

    @property
    def num_patches(self):
        return self.num_tokens
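

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module's API). The
# attribute names on `args` mirror what this class reads; the checkpoint
# paths are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        normalize_type='siglip',           # assumption: a value the towers accept
        moe_version_type=None,             # plain channel concatenation
        freeze_backbones='',               # freeze nothing
        delay_load=False,
        vision_tower_convnext_path='/path/to/convnext',  # hypothetical path
        vision_tower_siglip_path='/path/to/siglip',      # hypothetical path
    )
    tower = MultiBackboneChannelConcatenationVisionTower(
        'convnext-1024;palisiglip', args, grid_size=32)

    # With moe_version_type in [None, 'all_tiling'], forward takes a square
    # image batch and returns (B, grid_size**2, sum of tower hidden sizes).
    images = torch.randn(2, 3, 1024, 1024)
    out = tower(images)
    print(out.shape)  # expected: torch.Size([2, 1024, tower.hidden_size])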