| import torch |
| import torch.nn as nn |
| from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel |
| from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel |
| from transformers import PretrainedConfig |
|
|
# Hard-coded SigLIP-SO400M vision-encoder config (patch size 14, 384 px
# input) used by SigLipVisionTower, so the encoder can be instantiated
# without fetching a config from the hub.  Values appear to mirror
# google/siglip-so400m-patch14-384 -- TODO confirm against the checkpoint.
siglip_config = PretrainedConfig.from_dict(
    {
        "attention_dropout": 0.0,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "image_size": 384,
        "intermediate_size": 4304,
        "layer_norm_eps": 1e-06,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "patch_size": 14,
    }
)
|
|
# Hard-coded Qwen2-VL vision-transformer config used when building the
# Eagle-style tower from scratch (see build_vision_tower).  `embed_dim`
# (1280) is the ViT width; `hidden_size` (3584) is the post-merger output
# width -- presumably sized for a Qwen2-7B LLM; confirm against callers.
# Flash-attention 2 is pinned here but build_vision_tower may override it
# via `_vit_attn_implementation`.
qwen2vl_vit_config = PretrainedConfig.from_dict(
    {
        "depth": 32,
        "embed_dim": 1280,
        "hidden_act": "quick_gelu",
        "hidden_size": 3584,
        "in_channels": 3,
        "in_chans": 3,
        "mlp_ratio": 4,
        "model_type": "qwen2_vl",
        "num_heads": 16,
        "patch_size": 14,
        "spatial_merge_size": 2,
        "spatial_patch_size": 14,
        "temporal_patch_size": 2,
        "_attn_implementation": "flash_attention_2",
        "_attn_implementation_internal": "flash_attention_2"
    }
)
|
|
# Hard-coded Qwen2.5-VL vision-transformer config.  Qwen2.5-VL uses
# windowed attention (`window_size` 112) with full attention only at the
# listed block indexes; `out_hidden_size` (3584) is the post-merger output
# width.  Unlike the Qwen2-VL config above, no attention implementation is
# pinned here -- build_vision_tower may set one via
# `_vit_attn_implementation`.
qwen2_5vl_vit_config = PretrainedConfig.from_dict(
    {
        "depth": 32,
        "hidden_act": "silu",
        "hidden_size": 1280,
        "intermediate_size": 3420,
        "num_heads": 16,
        "in_chans": 3,
        "out_hidden_size": 3584,
        "patch_size": 14,
        "spatial_merge_size": 2,
        "spatial_patch_size": 14,
        "window_size": 112,
        "fullatt_block_indexes": [
            7,
            15,
            23,
            31
        ],
        "tokens_per_second": 2,
        "temporal_patch_size": 2
    }
)
|
|
# Hard-coded AIMv2 vision config used by AIMv2VisionTower's delay_load
# path.  `auto_map` points at remote code, so materializing a model from
# this config relies on trust_remote_code support.  Dimensions (1024 wide,
# 24 layers) look like the *large* 448px variant -- TODO confirm, since
# build_vision_tower routes both large and huge checkpoints here.
aimv2_config = PretrainedConfig.from_dict(
    {
        "hidden_size": 1024,
        "image_size": 448,
        "intermediate_size": 2816,
        "model_type": "aimv2",
        "num_attention_heads": 8,
        "num_channels": 3,
        "num_hidden_layers": 24,
        "patch_size": 14,
        "projection_dropout": 0.0,
        "qkv_bias": False,
        "rms_norm_eps": 1e-05,
        "torch_dtype": "float32",
        "transformers_version": "4.46.3",
        "auto_map": {
            "AutoConfig": "configuration_aimv2.AIMv2Config",
            "AutoModel": "modeling_aimv2.AIMv2Model",
        },
    }
)
|
|
def wrapped_qwen2vl_vision_tower(vision_tower_cfg, qwen2vl_vision_tower):
    """Adapt the merger head of a Qwen2-VL-style vision tower to the config.

    Args:
        vision_tower_cfg: config object; reads `only_navit`,
            `navit_use_mm_projector`, `navit_merger_hidden_dim` and
            `hidden_size`.
        qwen2vl_vision_tower: a Qwen2-VL / Qwen2.5-VL vision transformer
            whose `merger.mlp[-1]` is the final projection Linear.

    Returns:
        The same tower, with its `merger` replaced when needed.
    """
    if getattr(vision_tower_cfg, "only_navit", False) and \
            getattr(vision_tower_cfg, "navit_use_mm_projector", False):
        # The external mm_projector will do the projection, so the tower's
        # built-in merger is bypassed entirely.
        # (The original message claimed a new merger was being initialized,
        # which was misleading.)
        qwen2vl_vision_tower.merger = torch.nn.Identity()
        print("only_navit with navit_use_mm_projector: replacing the built-in merger with Identity...")
    else:
        old_linear = qwen2vl_vision_tower.merger.mlp[-1]
        navit_merger_hidden_dim = getattr(vision_tower_cfg, "navit_merger_hidden_dim", None)

        # Rebuild the merger when either its output width does not match the
        # target hidden size, or an explicit hidden width was requested that
        # differs from the existing MLP's hidden width.
        output_dim_mismatch = old_linear.out_features != vision_tower_cfg.hidden_size
        hidden_dim_mismatch = (
            navit_merger_hidden_dim is not None
            and navit_merger_hidden_dim != old_linear.in_features
        )

        if output_dim_mismatch or hidden_dim_mismatch:
            del qwen2vl_vision_tower.merger
            qwen2vl_vision_tower.merger = CustomPatchMerger(
                dim=vision_tower_cfg.hidden_size,
                # 1280 is the ViT width of both Qwen2-VL and Qwen2.5-VL
                # configs above -- TODO confirm if other towers are routed here.
                context_dim=1280,
                hidden_dim=navit_merger_hidden_dim if navit_merger_hidden_dim is not None else old_linear.in_features,
            )
            print("merger output dim or hidden dim does not match the config; initializing a new merger...")

    return qwen2vl_vision_tower
|
|
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the vision tower(s) described by ``vision_tower_cfg``.

    Returns either a single tower (SigLip / AIMv2) or, when an Eagle-style
    Qwen2-VL tower is configured, a ``(siglip_tower_or_None, qwen2vl_tower)``
    pair.

    Raises:
        ValueError: if no vision tower name is configured, or the name /
            eagle tower name is not recognized.
    """
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    if vision_tower is None:
        # Previously a missing name fell through to a confusing TypeError
        # on the substring checks below.
        raise ValueError("No vision tower specified in the config (mm_vision_tower / vision_tower).")

    if "siglip-so400m-patch14-384" in vision_tower or "Oryx-ViT" in vision_tower or "navit" in vision_tower.lower():
        if "navit" in vision_tower.lower():
            # navit-only mode requires an eagle tower to pair with.
            assert getattr(vision_tower_cfg, "only_navit", False) and \
                getattr(vision_tower_cfg, "eagle_vision_tower", None) is not None

        if getattr(vision_tower_cfg, "eagle_vision_tower", None) is not None:
            attn_impl = getattr(vision_tower_cfg, "_vit_attn_implementation", None)
            if "Qwen2.5-VL" in vision_tower_cfg.eagle_vision_tower:
                if attn_impl is not None:
                    # NOTE: mutates the module-level config object.
                    qwen2_5vl_vit_config._attn_implementation = attn_impl
                    qwen2_5vl_vit_config._attn_implementation_internal = attn_impl
                qwen2vl_vision_tower = Qwen2_5_VisionTransformerPretrainedModel._from_config(qwen2_5vl_vit_config)
            elif "Qwen2-VL" in vision_tower_cfg.eagle_vision_tower:
                if attn_impl is not None:
                    qwen2vl_vit_config._attn_implementation = attn_impl
                    qwen2vl_vit_config._attn_implementation_internal = attn_impl
                qwen2vl_vision_tower = Qwen2VisionTransformerPretrainedModel._from_config(qwen2vl_vit_config)
            else:
                raise ValueError(f"Unknown vision tower: {vision_tower_cfg.eagle_vision_tower}")

            # Adapt the merger head to the config and freeze the tower.
            qwen2vl_vision_tower = wrapped_qwen2vl_vision_tower(vision_tower_cfg, qwen2vl_vision_tower)
            qwen2vl_vision_tower.requires_grad_(False)
            if getattr(vision_tower_cfg, "only_navit", False):
                return None, qwen2vl_vision_tower
            siglip_vision_tower = SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
            return siglip_vision_tower, qwen2vl_vision_tower

        return SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif "aimv2-huge-patch14-448" in vision_tower or "Ovis2-8B-visual" in vision_tower \
            or "aimv2-large-patch14-448" in vision_tower or "Ovis2-2B-visual" in vision_tower:
        # huge/8B and large/2B variants share the same wrapper class.
        return AIMv2VisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    else:
        raise ValueError(f"Unknown vision tower: {vision_tower}")
|
|
class SigLipVisionTower(nn.Module):
    """SigLIP vision encoder built from the module-level ``siglip_config``.

    Exposes the feature-extraction interface the multimodal stack expects:
    ``forward`` plus the ``dtype`` / ``device`` / ``config`` /
    ``hidden_size`` / ``num_patches`` properties.
    """

    def __init__(self, vision_tower, args, delay_load=False, cache_dir="./cache_dir"):
        super().__init__()
        self.is_loaded = False
        self.image_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
        self.cache_dir = cache_dir

        if delay_load:
            # Even the delayed path instantiates a model from the hard-coded
            # config; only `is_loaded` / `cfg_only` differ from load_model().
            from transformers import SiglipVisionModel
            self.cfg_only = siglip_config
            self.vision_tower = SiglipVisionModel._from_config(siglip_config)
        else:
            self.load_model()

    def load_model(self):
        """Build the encoder from ``siglip_config`` and freeze it."""
        from transformers import SiglipVisionModel
        self.vision_tower = SiglipVisionModel._from_config(siglip_config)
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        # Only "cls_patch" mode is supported: token 0 is prepended again in
        # front of the full token sequence.
        assert self.select_feature == "cls_patch"
        leading_token = image_forward_outs[:, :1, :]
        return torch.cat([leading_token, image_forward_outs], dim=1)

    def _encode(self, pixel_values, out_dtype):
        # Run the encoder on one batch and cast the selected features back
        # to the caller's dtype.
        outputs = self.vision_tower(
            pixel_values.to(device=self.device, dtype=self.dtype),
            output_hidden_states=True,
            return_dict=True,
        )
        return self.feature_select(outputs.last_hidden_state).to(out_dtype)

    def forward(self, images):
        """Encode a batch tensor, or a list of single images one by one."""
        if type(images) is list:
            return [self._encode(image.unsqueeze(0), image.dtype) for image in images]
        return self._encode(images, images.dtype)

    @property
    def dummy_feature(self):
        """A zero feature vector on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        patches_per_side = self.config.image_size // self.config.patch_size
        return patches_per_side ** 2
|
|
|
|
class CustomPatchMerger(nn.Module):
    """Merge ``spatial_merge_size**2`` neighbouring patch tokens into one.

    Tokens of width ``context_dim`` are layer-normed, grouped into blocks of
    ``context_dim * spatial_merge_size**2``, then projected to ``dim``
    through a two-layer GELU MLP with hidden width ``hidden_dim``.
    """

    def __init__(self, dim: int, context_dim: int, hidden_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        merged_tokens = spatial_merge_size ** 2
        self.input_dim = context_dim * merged_tokens
        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.ln_q(x)
        grouped = normed.view(-1, self.input_dim)
        return self.mlp(grouped)
|
|
|
|
|
|
| |
| class AIMv2VisionTower(nn.Module): |
| def __init__(self, vision_tower, args, delay_load=False, cache_dir='./cache_dir'): |
| super().__init__() |
|
|
| self.is_loaded = False |
|
|
| self.image_tower_name = vision_tower |
| self.select_layer = args.mm_vision_select_layer |
| self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') |
|
|
| self.cache_dir = cache_dir |
| if not delay_load: |
| self.load_model() |
| else: |
| from transformers import AutoConfig, AutoModel |
| |
| |
| self.cfg_only = aimv2_config |
| self.vision_tower = AutoModel._from_config(aimv2_config) |
|
|
|
|
| def load_model(self): |
| from transformers import AutoConfig, AutoModel, AutoProcessor |
| self.image_processor = AutoProcessor.from_pretrained(self.image_tower_name, trust_remote_code=True) |
| self.vision_tower = AutoModel.from_pretrained(self.image_tower_name, trust_remote_code=True) |
| self.vision_tower.requires_grad_(False) |
| |
| self.image_processor.crop_size = self.image_processor.size["shortest_edge"] |
|
|
| self.is_loaded = True |
|
|
| def feature_select(self, image_forward_outs): |
| assert self.select_feature == 'cls_patch' |
| image_features = torch.cat([image_forward_outs[:, :1, :], image_forward_outs], dim=1) |
| return image_features |
|
|
| def forward(self, images): |
| if type(images) is list: |
| image_features = [] |
| for image in images: |
| image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True |
| ,return_dict=True,) |
| image_feature = self.feature_select(image_forward_out.last_hidden_state).to(image.dtype) |
| image_features.append(image_feature) |
| else: |
| image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True |
| ,return_dict=True,) |
| image_features = self.feature_select(image_forward_outs.last_hidden_state).to(images.dtype) |
|
|
| return image_features |
|
|
| |
| |
| |
|
|
| @property |
| def dtype(self): |
| return self.vision_tower.dtype |
|
|
| @property |
| def device(self): |
| return self.vision_tower.device |
|
|
| @property |
| def config(self): |
| if self.is_loaded: |
| return self.vision_tower.config |
| else: |
| return self.cfg_only |
|
|
| @property |
| def hidden_size(self): |
| return self.config.hidden_size |
|
|
| @property |
| def num_patches(self): |
| return (self.config.image_size // self.config.patch_size) ** 2 |
|
|