| import torch |
| import torch.nn as nn |
| from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel |
| from transformers import PretrainedConfig |
|
|
# Architecture hyperparameters for the SigLIP vision encoder. This config is
# handed to ``SiglipVisionModel._from_config`` by ``SigLipVisionTower`` below.
siglip_config = PretrainedConfig.from_dict(
    dict(
        attention_dropout=0.0,
        hidden_act="gelu_pytorch_tanh",
        hidden_size=1152,
        image_size=384,
        intermediate_size=4304,
        layer_norm_eps=1e-06,
        model_type="siglip_vision_model",
        num_attention_heads=16,
        num_channels=3,
        num_hidden_layers=27,
        patch_size=14,
    )
)
|
|
# Architecture hyperparameters for the Qwen2-VL vision transformer built in
# ``build_vision_tower``. Both attention-implementation fields default to
# "flash_attention_2"; ``build_vision_tower`` can override them per call.
qwen2vl_vit_config = PretrainedConfig.from_dict(
    dict(
        depth=32,
        embed_dim=1280,
        hidden_act="quick_gelu",
        hidden_size=3584,
        in_channels=3,
        in_chans=3,
        mlp_ratio=4,
        model_type="qwen2_vl",
        num_heads=16,
        patch_size=14,
        spatial_merge_size=2,
        spatial_patch_size=14,
        temporal_patch_size=2,
        _attn_implementation="flash_attention_2",
        _attn_implementation_internal="flash_attention_2",
    )
)
|
|
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Build the vision encoder(s) selected by ``vision_tower_cfg``.

    The tower name is read from ``vision_tower_cfg.mm_vision_tower`` and, if
    that attribute is absent, from ``vision_tower_cfg.vision_tower``. Only
    names containing ``"siglip-so400m-patch14-384"`` are supported.

    Args:
        vision_tower_cfg: Configuration object. Attributes read here:
            ``mm_vision_tower`` / ``vision_tower``, ``eagle_vision_tower``,
            ``_vit_attn_implementation``, ``navit_merger_hidden_dim``,
            ``hidden_size`` (only when the merger is replaced) and
            ``only_navit``.
        **kwargs: Forwarded unchanged to ``SigLipVisionTower``.

    Returns:
        * If ``eagle_vision_tower`` is unset or ``None``: a single
          ``SigLipVisionTower``.
        * Otherwise: a ``(siglip_vision_tower, qwen2vl_vision_tower)`` tuple,
          where ``siglip_vision_tower`` is ``None`` when ``only_navit`` is
          truthy.

    Raises:
        ValueError: If the tower name is missing, is not a string, or does
            not name a supported tower.
    """
    import copy

    vision_tower = getattr(
        vision_tower_cfg,
        "mm_vision_tower",
        getattr(vision_tower_cfg, "vision_tower", None),
    )

    # A missing (None) or non-string name would make the substring test raise
    # TypeError; report it as the same "unknown tower" error instead.
    if not isinstance(vision_tower, str) or "siglip-so400m-patch14-384" not in vision_tower:
        raise ValueError(f"Unknown vision tower: {vision_tower}")

    if getattr(vision_tower_cfg, "eagle_vision_tower", None) is None:
        return SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    # Work on a private copy so a per-call attention override does not leak
    # into the module-level default used by later calls.
    vit_config = copy.deepcopy(qwen2vl_vit_config)
    attn_implementation = getattr(vision_tower_cfg, "_vit_attn_implementation", None)
    if attn_implementation is not None:
        vit_config._attn_implementation = attn_implementation
        vit_config._attn_implementation_internal = attn_implementation

    # The Qwen2-VL tower is constructed before the SigLIP tower, as before.
    qwen2vl_vision_tower = Qwen2VisionTransformerPretrainedModel._from_config(vit_config)

    merger_hidden_dim = getattr(vision_tower_cfg, "navit_merger_hidden_dim", None)
    if merger_hidden_dim is not None:
        # Swap the stock merger for one with a configurable hidden width.
        # context_dim=1280 equals ``embed_dim`` in ``qwen2vl_vit_config``.
        del qwen2vl_vision_tower.merger
        qwen2vl_vision_tower.merger = CustomPatchMerger(
            vision_tower_cfg.hidden_size,
            context_dim=1280,
            hidden_dim=merger_hidden_dim,
        )

    # Freeze every parameter, including a replacement merger if one was set.
    qwen2vl_vision_tower.requires_grad_(False)

    if getattr(vision_tower_cfg, "only_navit", False):
        siglip_vision_tower = None
    else:
        siglip_vision_tower = SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    return siglip_vision_tower, qwen2vl_vision_tower
|
|
class SigLipVisionTower(nn.Module):
    """Wrapper around a ``SiglipVisionModel`` built from ``siglip_config``.

    The wrapped model is instantiated from the module-level ``siglip_config``
    via ``_from_config``; this class does not read any checkpoint itself.
    """

    def __init__(self, vision_tower, args, delay_load=False, cache_dir="./cache_dir"):
        """Store settings and instantiate the underlying vision model.

        Args:
            vision_tower: Tower name; kept as ``image_tower_name``.
            args: Object providing ``mm_vision_select_layer`` (required) and
                optionally ``mm_vision_select_feature`` (default ``"patch"``).
            delay_load: When false, ``load_model`` is called immediately
                (model frozen, ``is_loaded`` set). When true, the model is
                still instantiated but left trainable and ``is_loaded`` stays
                ``False``; ``cfg_only`` then backs the ``config`` property.
            cache_dir: Stored on the instance; not used elsewhere in this
                class.
        """
        super().__init__()
        self.is_loaded = False
        self.image_tower_name = vision_tower
        # Stored for callers; ``forward`` always uses ``last_hidden_state``.
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
        self.cache_dir = cache_dir

        if not delay_load:
            self.load_model()
        else:
            from transformers import SiglipVisionModel

            self.cfg_only = siglip_config
            self.vision_tower = SiglipVisionModel._from_config(siglip_config)

    def load_model(self):
        """(Re)instantiate the vision model, freeze it and mark it loaded."""
        from transformers import SiglipVisionModel

        self.vision_tower = SiglipVisionModel._from_config(siglip_config)
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Prepend a copy of the first token along dim 1.

        Only the ``"cls_patch"`` selection mode is supported; any other
        ``select_feature`` value fails the assertion.
        """
        assert self.select_feature == "cls_patch", (
            f"Unsupported select_feature: {self.select_feature!r}"
        )
        # The first token is duplicated to occupy the leading ("cls") slot.
        image_features = torch.cat([image_forward_outs[:, :1, :], image_forward_outs], dim=1)
        return image_features

    def _encode(self, pixel_values, out_dtype):
        """Run the tower on one batch and return selected features in ``out_dtype``."""
        outputs = self.vision_tower(
            pixel_values.to(device=self.device, dtype=self.dtype),
            output_hidden_states=True,
            return_dict=True,
        )
        return self.feature_select(outputs.last_hidden_state).to(out_dtype)

    def forward(self, images):
        """Encode a batch tensor or a list of single-image tensors.

        Args:
            images: Either one tensor passed to the model as-is, or a list of
                tensors, each of which gets a leading batch dimension added
                and is encoded separately.

        Returns:
            A tensor for tensor input, or a list of tensors (one per image,
            each with batch dimension 1) for list input. Outputs are cast
            back to the dtype of the corresponding input.
        """
        if isinstance(images, list):
            return [self._encode(image.unsqueeze(0), image.dtype) for image in images]
        return self._encode(images, images.dtype)

    @property
    def dummy_feature(self):
        """All-zero tensor of shape ``(1, hidden_size)`` on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype reported by the wrapped vision model."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device reported by the wrapped vision model."""
        return self.vision_tower.device

    @property
    def config(self):
        """Model config once loaded, otherwise the ``cfg_only`` placeholder."""
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        """``hidden_size`` from the active config."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """Number of patches per image: ``(image_size // patch_size) ** 2``."""
        return (self.config.image_size // self.config.patch_size) ** 2
|
|
|
|
class CustomPatchMerger(nn.Module):
    """Layer-norm each token, group neighbouring tokens, and project them.

    ``spatial_merge_size ** 2`` consecutive rows of width ``context_dim`` are
    flattened into one row of width ``input_dim`` and passed through a
    Linear -> GELU -> Linear stack that outputs ``dim`` features.
    """

    def __init__(self, dim: int, context_dim: int, hidden_dim: int, spatial_merge_size: int = 2) -> None:
        """Create the normalisation layer and the projection MLP.

        Args:
            dim: Output feature width.
            context_dim: Width of each incoming token.
            hidden_dim: Width of the MLP's hidden layer.
            spatial_merge_size: Side of the merge window; its square is the
                number of tokens folded into one output row.
        """
        super().__init__()
        tokens_per_group = spatial_merge_size**2
        self.input_dim = context_dim * tokens_per_group
        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)

        # Layers are created in the same order as they are applied.
        projection_layers = [
            nn.Linear(self.input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.mlp = nn.Sequential(*projection_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x``, regroup it to ``(-1, input_dim)`` and project it.

        Assumes the last dimension of ``x`` is ``context_dim`` and that its
        element count is divisible by ``input_dim`` -- TODO confirm with callers.
        """
        normalised = self.ln_q(x)
        grouped = normalised.view(-1, self.input_dim)
        return self.mlp(grouped)