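"""Vision tower wrappers for a LLaVA-style multimodal model: a generic
`VisionTower` base class, S2 multiscale variants, and SigLIP-backed
implementations of each."""
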
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.hooks import add_hook_to_module
from einops import rearrange
from s2wrapper import forward as multiscale_forward
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.models.siglip import SiglipVisionModel


class VisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        # Which hidden layer to take features from (-2 = second-to-last) and
        # whether to keep the CLS token alongside the patch tokens.
        self.select_layer = getattr(args, "mm_vision_select_layer", -2)
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        self.cfg_only = None

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            # Drop the leading CLS token, keeping patch tokens only.
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            # Keep the full sequence (CLS + patch tokens).
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features
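
    # Shape walk-through (sketch for a CLIP-style tower at 336 px with 14 px
    # patches): hidden_states[self.select_layer] is [B, 577, C], i.e. 1 CLS
    # token + 24 * 24 = 576 patch tokens, so "patch" yields [B, 576, C].
    # SigLIP backbones emit no CLS token, in which case "patch" would drop a
    # real patch embedding and "cls_patch" is the identity choice.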

    def _maybe_resize_pos_embeds(
        self,
        model: PreTrainedModel,
        image_processor: BaseImageProcessor,
        resolution: int = -1,
        interpolate_mode: str = "linear",
    ):
        if resolution in [model.config.image_size, -1]:
            return
        print(
            f"Resizing vision model's position embeddings to support higher vision resolution: "
            f"from {model.config.image_size} to {resolution} ..."
        )
        embeddings = model.vision_model.embeddings
        patch_size = embeddings.patch_size
        num_new_tokens = int((resolution // patch_size) ** 2)

        old_embeddings = embeddings.position_embedding
        match interpolate_mode:
            case "linear":
                # Step 1: map each new position index i onto the old grid:
                # pid = i / (num_new - 1) * (num_old - 1).
                # Step 2: linearly interpolate between the two nearest old
                # embeddings, weighting by the fractional part of pid.
                if is_deepspeed_zero3_enabled():
                    try:
                        import deepspeed
                    except ImportError:
                        raise ImportError("DeepSpeed is not installed. Please install it with `pip install deepspeed`.")
                    with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
                        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
                else:
                    old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
                new_embeddings = nn.Embedding(
                    num_new_tokens,
                    old_embedding_dim,
                    dtype=old_embeddings.weight.dtype,
                    device=old_embeddings.weight.device,
                )
                mapped_indices = (
                    torch.arange(num_new_tokens).to(old_embeddings.weight.device)
                    / (num_new_tokens - 1)
                    * (old_num_tokens - 1)
                )
                floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
                ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
                # Weight the two neighbors by the fractional offset; when
                # floor == ceil (pid is an integer), frac is 0 and the floor
                # embedding is copied through unchanged.
                frac = (mapped_indices - floor_indices)[:, None]
                if is_deepspeed_zero3_enabled():
                    params = [old_embeddings.weight, new_embeddings.weight]
                    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                        interpolated_embeds = (
                            frac * old_embeddings.weight.data[ceil_indices, :]
                            + (1 - frac) * old_embeddings.weight.data[floor_indices, :]
                        )
                else:
                    interpolated_embeds = (
                        frac * old_embeddings.weight.data[ceil_indices, :]
                        + (1 - frac) * old_embeddings.weight.data[floor_indices, :]
                    )
                new_embeddings.weight.data = interpolated_embeds
            case _:
                raise NotImplementedError

        if hasattr(old_embeddings, "_hf_hook"):
            hook = old_embeddings._hf_hook
            add_hook_to_module(new_embeddings, hook)
        new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
        model.config.image_size = resolution
        if hasattr(image_processor, "crop_size"):
            # CLIP-style processors expose `crop_size`.
            image_processor.crop_size = resolution
        else:
            # SigLIP-style processors expose `size` instead.
            assert hasattr(image_processor, "size")
            image_processor.size = {"height": resolution, "width": resolution}
        embeddings.position_embedding = new_embeddings
        embeddings.image_size = resolution
        embeddings.num_patches = embeddings.num_positions = num_new_tokens
        embeddings.position_ids = (
            torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
        )
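
    # Worked example for the resize above: siglip-so400m-patch14-384 has
    # (384 // 14) ** 2 = 729 position embeddings. Raising the resolution to
    # 448 needs (448 // 14) ** 2 = 1024 tokens, so each new index i in
    # [0, 1023] maps to pid = i / 1023 * 728 on the old grid and is linearly
    # interpolated between the embeddings at floor(pid) and ceil(pid).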

    def forward(self, images):
        if type(images) is list:
            # A list signals per-image inputs (e.g. mixed resolutions), so
            # each image is encoded separately with a singleton batch dim.
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
|
| |
|
| | class VisionTowerS2(VisionTower): |
| | def __init__(self, vision_tower, args, delay_load=False): |
| | super().__init__(vision_tower, args, delay_load) |
| |
|
| | self.scales = list(map(int, args.s2_scales.split(","))) |
| | self.scales.sort() |
| | self.max_split_size = args.s2_max_split_size |
| | self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0) |
| |
|
| | def forward_feature(self, images): |
| | image_forward_outs = self.vision_tower( |
| | images.to(device=self.device, dtype=self.dtype), output_hidden_states=True |
| | ) |
| | image_features = self.feature_select(image_forward_outs).to(images.dtype) |
| | return image_features |
| |
|
| | def forward(self, images): |
| | if type(images) is list: |
| | image_features = [] |
| | for image in images: |
| | image_feature = multiscale_forward( |
| | self.forward_feature, |
| | image.unsqueeze(0), |
| | img_sizes=self.scales, |
| | max_split_size=self.max_split_size, |
| | resize_output_to_idx=self.resize_output_to_scale_idx, |
| | ) |
| | image_features.append(image_feature) |
| | else: |
| | image_features = multiscale_forward( |
| | self.forward_feature, |
| | images, |
| | img_sizes=self.scales, |
| | max_split_size=self.max_split_size, |
| | resize_output_to_idx=self.resize_output_to_scale_idx, |
| | ) |
| |
|
| | return image_features |
| |
|
| | @property |
| | def hidden_size(self): |
| | return self.config.hidden_size * len(self.scales) |
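
    # Features from all scales are concatenated along the channel dimension:
    # e.g. with hypothetical scales "384,768" over a 1152-dim SigLIP
    # backbone, hidden_size reports 1152 * 2 = 2304.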


class VisionTowerDynamicS2(VisionTower):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)

        self.scales = list(map(int, args.s2_scales.split(",")))
        self.scales.sort()
        self.max_split_size = args.s2_max_split_size
        self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)

    def forward_feature(self, images):
        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
        )
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    def forward(self, images):
        # The dynamic variant only accepts an already-batched tensor; lists
        # of per-image tensors are not supported here.
        assert type(images) is not list
        image_features = self.forward_feature(images)

        return image_features
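
    # NOTE: hidden_size below still counts one backbone width per scale, so
    # the multiscale tiling and channel-wise merge are presumably performed
    # by the caller rather than inside this forward pass.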

    @property
    def hidden_size(self):
        return self.config.hidden_size * len(self.scales)


class SiglipVisionTower(VisionTower):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation=config._attn_implementation,
            # config.model_dtype holds a dtype expression such as
            # "torch.float16"; eval() turns it back into the dtype object.
            torch_dtype=eval(config.model_dtype),
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        self.is_loaded = True


class SiglipVisionTowerS2(VisionTowerS2):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation=config._attn_implementation,
            torch_dtype=eval(config.model_dtype),
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        # Preprocess to the largest S2 scale; the smaller scales are derived
        # by downsampling inside the multiscale forward.
        self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
        self.is_loaded = True


class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation=config._attn_implementation,
            torch_dtype=eval(config.model_dtype),
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        # Preprocess to the smallest S2 scale; larger scales are handled by
        # tiling outside the processor.
        self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
        self.is_loaded = True
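

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming local access to the public
    # "google/siglip-so400m-patch14-384" checkpoint; `cfg` is a hypothetical
    # stand-in carrying only the fields SiglipVisionTower actually reads.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        _attn_implementation="eager",
        model_dtype="torch.float32",
        mm_vision_select_layer=-2,
        mm_vision_select_feature="cls_patch",  # SigLIP emits no CLS token to drop
    )
    tower = SiglipVisionTower("google/siglip-so400m-patch14-384", cfg)
    pixel_values = torch.randn(1, 3, 384, 384)
    features = tower(pixel_values)
    print(features.shape)  # expected: torch.Size([1, 729, 1152])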