# Valley2.5 / modeling_vision_tower.py
# (Hugging Face page residue removed; commit 64c250f:
#  "feat: modify file type of *.py, *.txt, etc. to change storage method")
import torch
import torch.nn as nn
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
from transformers import PretrainedConfig
# SigLip so400m-patch14-384 vision-encoder config, pinned inline so the tower
# can be built via `SiglipVisionModel._from_config` without fetching remote files.
siglip_config = PretrainedConfig.from_dict(
    {
        "attention_dropout": 0.0,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "image_size": 384,
        "intermediate_size": 4304,
        "layer_norm_eps": 1e-06,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "patch_size": 14,
    }
)
# Qwen2-VL vision-transformer (NaViT) config. `embed_dim` (1280) is the ViT
# width; `hidden_size` (3584) is the LLM embedding width the merger maps into.
# NOTE(review): flash_attention_2 is hard-coded here, but build_vision_tower
# overrides it when `_vit_attn_implementation` is set on the model config.
qwen2vl_vit_config = PretrainedConfig.from_dict(
    {
        "depth": 32,
        "embed_dim": 1280,
        "hidden_act": "quick_gelu",
        "hidden_size": 3584,
        "in_channels": 3,
        "in_chans": 3,
        "mlp_ratio": 4,
        "model_type": "qwen2_vl",
        "num_heads": 16,
        "patch_size": 14,
        "spatial_merge_size": 2,
        "spatial_patch_size": 14,
        "temporal_patch_size": 2,
        "_attn_implementation": "flash_attention_2",
        "_attn_implementation_internal": "flash_attention_2"
    }
)
# Qwen2.5-VL vision-transformer config: windowed attention (window_size=112)
# with full attention at `fullatt_block_indexes`. `out_hidden_size` (3584)
# presumably matches the LLM embedding width — see merger comments below.
qwen2_5vl_vit_config = PretrainedConfig.from_dict(
    {
        "depth": 32,
        "hidden_act": "silu",
        "hidden_size": 1280,
        "intermediate_size": 3420,
        "num_heads": 16,
        "in_chans": 3,
        "out_hidden_size": 3584,
        "patch_size": 14,
        "spatial_merge_size": 2,
        "spatial_patch_size": 14,
        "window_size": 112,
        "fullatt_block_indexes": [
            7,
            15,
            23,
            31
        ],
        "tokens_per_second": 2,
        "temporal_patch_size": 2
    }
)
# AIMv2 vision-encoder config. `auto_map` routes AutoConfig/AutoModel to the
# remote-code implementation (loaded with trust_remote_code in AIMv2VisionTower).
aimv2_config = PretrainedConfig.from_dict(
    {
        "hidden_size": 1024,
        "image_size": 448,
        "intermediate_size": 2816,
        "model_type": "aimv2",
        "num_attention_heads": 8,
        "num_channels": 3,
        "num_hidden_layers": 24,
        "patch_size": 14,
        "projection_dropout": 0.0,
        "qkv_bias": False,
        "rms_norm_eps": 1e-05,
        "torch_dtype": "float32",
        "transformers_version": "4.46.3",
        "auto_map": {
            "AutoConfig": "configuration_aimv2.AIMv2Config",
            "AutoModel": "modeling_aimv2.AIMv2Model",
        },
    }
)
def wrapped_qwen2vl_vision_tower(vision_tower_cfg, qwen2vl_vision_tower):
    """Adapt the ``merger`` head of a Qwen2-VL/Qwen2.5-VL vision tower in place.

    Args:
        vision_tower_cfg: config object; reads ``only_navit``,
            ``navit_use_mm_projector``, ``navit_merger_hidden_dim`` and
            ``hidden_size`` (the LLM embedding width).
        qwen2vl_vision_tower: vision transformer whose ``merger`` submodule may
            be replaced.

    Returns:
        The same ``qwen2vl_vision_tower`` instance (mutated in place).
    """
    if getattr(vision_tower_cfg, "only_navit", False) and \
            getattr(vision_tower_cfg, "navit_use_mm_projector", False):
        # The external mm_projector takes over the merging, so the built-in
        # merger is bypassed entirely.
        # Bug fix: the original message claimed a new merger was being
        # initialized, while the code actually removes it.
        qwen2vl_vision_tower.merger = torch.nn.Identity()
        print("navit_use_mm_projector is set, replacing the built-in merger with Identity (mm_projector will handle merging)...")
    else:
        # Last linear of the stock merger maps the merger hidden dim to the LLM dim.
        old_linear = qwen2vl_vision_tower.merger.mlp[-1]  # e.g. 5120 -> 3584; 3584 is the LLM dim, 5120 the merger hidden dim
        navit_merger_hidden_dim = getattr(vision_tower_cfg, "navit_merger_hidden_dim", None)
        # Rebuild when the output dim does not match the LLM dim (rule1), or a
        # custom hidden dim is requested that differs from the stock one (rule2).
        rule1 = old_linear.out_features != vision_tower_cfg.hidden_size
        rule2 = navit_merger_hidden_dim is not None and navit_merger_hidden_dim != old_linear.in_features
        if rule1 or rule2:
            del qwen2vl_vision_tower.merger
            qwen2vl_vision_tower.merger = CustomPatchMerger(
                dim=vision_tower_cfg.hidden_size,  # output dim of merger, also the LLM dim
                context_dim=1280,  # ViT hidden dim; merger input is 1280*4=5120 (2x2 pixel shuffle)
                hidden_dim=navit_merger_hidden_dim if navit_merger_hidden_dim is not None else old_linear.in_features
            )
            print("output_dim of the original merger does not match, or navit_merger_hidden_dim does not match; initializing a new merger...")
    return qwen2vl_vision_tower
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Build the vision tower(s) selected by ``vision_tower_cfg``.

    Reads ``mm_vision_tower`` (or ``vision_tower``) to pick the backbone.

    Returns:
        - ``(siglip_tower, qwen2vl_tower)`` when an eagle vision tower is
          configured (``siglip_tower`` is ``None`` when ``only_navit`` is set);
        - a single ``SigLipVisionTower`` or ``AIMv2VisionTower`` otherwise.

    Raises:
        ValueError: when no tower name is configured or the name is unknown.
    """
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    if vision_tower is None:
        # Robustness fix: previously fell through to `in None` and raised an
        # uninformative TypeError.
        raise ValueError("vision_tower_cfg must define `mm_vision_tower` or `vision_tower`")
    if "siglip-so400m-patch14-384" in vision_tower or "Oryx-ViT" in vision_tower or "navit" in vision_tower.lower():
        # if 'navit' in vision_tower, vision_tower_cfg.eagle_vision_tower is not None and vision_tower_cfg.only_navit is True
        if "navit" in vision_tower.lower():
            assert getattr(vision_tower_cfg, "only_navit", False) and \
                getattr(vision_tower_cfg, "eagle_vision_tower", None) is not None
        if getattr(vision_tower_cfg, "eagle_vision_tower", None) is not None:
            if "Qwen2.5-VL" in vision_tower_cfg.eagle_vision_tower:
                if getattr(vision_tower_cfg, "_vit_attn_implementation", None) is not None:
                    # Allow the model config to override the hard-coded attention backend.
                    qwen2_5vl_vit_config._attn_implementation = vision_tower_cfg._vit_attn_implementation
                    qwen2_5vl_vit_config._attn_implementation_internal = vision_tower_cfg._vit_attn_implementation
                qwen2vl_vision_tower = Qwen2_5_VisionTransformerPretrainedModel._from_config(qwen2_5vl_vit_config)
            elif "Qwen2-VL" in vision_tower_cfg.eagle_vision_tower:
                if getattr(vision_tower_cfg, "_vit_attn_implementation", None) is not None:
                    qwen2vl_vit_config._attn_implementation = vision_tower_cfg._vit_attn_implementation
                    qwen2vl_vit_config._attn_implementation_internal = vision_tower_cfg._vit_attn_implementation
                qwen2vl_vision_tower = Qwen2VisionTransformerPretrainedModel._from_config(qwen2vl_vit_config)
            else:
                raise ValueError(f"Unknown vision tower: {vision_tower_cfg.eagle_vision_tower}")
            qwen2vl_vision_tower = wrapped_qwen2vl_vision_tower(vision_tower_cfg, qwen2vl_vision_tower)
            qwen2vl_vision_tower.requires_grad_(False)  # eagle tower is kept frozen
            if getattr(vision_tower_cfg, "only_navit", False):
                return None, qwen2vl_vision_tower
            siglip_vision_tower = SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
            return siglip_vision_tower, qwen2vl_vision_tower
        # only return siglip vision tower if eagle vision tower is None
        else:
            return SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif ("aimv2-huge-patch14-448" in vision_tower or "Ovis2-8B-visual" in vision_tower
          or "aimv2-large-patch14-448" in vision_tower or "Ovis2-2B-visual" in vision_tower):
        # Consistency: the two previously duplicated AIMv2 branches are merged.
        return AIMv2VisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    else:
        raise ValueError(f"Unknown vision tower: {vision_tower}")
class SigLipVisionTower(nn.Module):
    """SigLip vision encoder wrapped as a frozen feature extractor.

    The encoder is always instantiated from the inline ``siglip_config``;
    weights are expected to be loaded from the enclosing checkpoint.
    """

    def __init__(self, vision_tower, args, delay_load=False, cache_dir="./cache_dir"):
        super().__init__()
        self.is_loaded = False
        self.image_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
        self.cache_dir = cache_dir
        if delay_load:
            from transformers import SiglipVisionModel
            self.cfg_only = siglip_config
            # Dummy-load so the module tree exists before real weights arrive.
            self.vision_tower = SiglipVisionModel._from_config(siglip_config)
        else:
            self.load_model()

    def load_model(self):
        """Instantiate the SigLip encoder from config and freeze it."""
        from transformers import SiglipVisionModel
        self.vision_tower = SiglipVisionModel._from_config(siglip_config)
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        # Prepend a copy of the first token to the full token sequence.
        assert self.select_feature == "cls_patch"
        leading_token = image_forward_outs[:, :1, :]
        return torch.cat((leading_token, image_forward_outs), dim=1)

    def forward(self, images):
        """Encode a batched tensor, or a list of per-image tensors."""
        if type(images) is list:
            encoded = []
            for image in images:
                single = image.to(device=self.device, dtype=self.dtype).unsqueeze(0)
                out = self.vision_tower(single, output_hidden_states=True, return_dict=True)
                encoded.append(self.feature_select(out.last_hidden_state).to(image.dtype))
            return encoded
        batch = images.to(device=self.device, dtype=self.dtype)
        outs = self.vision_tower(batch, output_hidden_states=True, return_dict=True)
        return self.feature_select(outs.last_hidden_state).to(images.dtype)

    @property
    def dummy_feature(self):
        # Zero placeholder feature matching the encoder width.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the static config until real weights are loaded.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
class CustomPatchMerger(nn.Module):
    """Merge groups of adjacent patch embeddings and project to the LLM width.

    Each output token concatenates ``spatial_merge_size**2`` patch vectors of
    width ``context_dim`` and passes them through a two-layer MLP into ``dim``.
    Submodule names (``ln_q``, ``mlp``) mirror the stock Qwen2-VL merger so
    checkpoints remain compatible.
    """

    def __init__(self, dim: int, context_dim: int, hidden_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        # Width of one merged token: spatial_merge_size**2 concatenated patches.
        self.input_dim = context_dim * (spatial_merge_size**2)
        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize per patch, flatten each merge group, then project.
        normalized = self.ln_q(x)
        merged = normalized.view(-1, self.input_dim)
        return self.mlp(merged)
class AIMv2VisionTower(nn.Module):
    """AIMv2 vision encoder wrapped as a frozen feature extractor.

    Loaded via AutoModel with ``trust_remote_code`` from the named repo;
    NOTE(review): this executes remote code from the model hub by design.
    """

    def __init__(self, vision_tower, args, delay_load=False, cache_dir='./cache_dir'):
        super().__init__()
        self.is_loaded = False
        self.image_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.cache_dir = cache_dir
        if delay_load:
            from transformers import AutoConfig, AutoModel
            self.cfg_only = aimv2_config
            # Dummy-load so the module tree exists before real weights arrive.
            self.vision_tower = AutoModel._from_config(aimv2_config)
        else:
            self.load_model()

    def load_model(self):
        """Fetch processor and encoder from the hub, then freeze the encoder."""
        from transformers import AutoConfig, AutoModel, AutoProcessor
        self.image_processor = AutoProcessor.from_pretrained(self.image_tower_name, trust_remote_code=True)
        self.vision_tower = AutoModel.from_pretrained(self.image_tower_name, trust_remote_code=True)
        self.vision_tower.requires_grad_(False)
        # Downstream code expects an int crop size, not the size dict.
        self.image_processor.crop_size = self.image_processor.size["shortest_edge"]
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        # Prepend a copy of the first token to the full token sequence.
        assert self.select_feature == 'cls_patch'
        leading_token = image_forward_outs[:, :1, :]
        return torch.cat((leading_token, image_forward_outs), dim=1)

    def forward(self, images):
        """Encode a batched tensor, or a list of per-image tensors."""
        if type(images) is list:
            encoded = []
            for image in images:
                single = image.to(device=self.device, dtype=self.dtype).unsqueeze(0)
                out = self.vision_tower(single, output_hidden_states=True, return_dict=True)
                encoded.append(self.feature_select(out.last_hidden_state).to(image.dtype))
            return encoded
        batch = images.to(device=self.device, dtype=self.dtype)
        outs = self.vision_tower(batch, output_hidden_states=True, return_dict=True)
        return self.feature_select(outs.last_hidden_state).to(images.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the static config until real weights are loaded.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2