| """ | |
| Mostly copy-paste from LLaVA-HR | |
| https://github.com/luogen1996/LLaVA-HR | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.checkpoint import checkpoint | |
| from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig | |
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from typing import List, Optional | |
def forward_embeddings(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    # Bound onto CLIPVisionEmbeddings in load_model(); unlike the stock CLIP
    # forward, it resamples the position embeddings whenever the input
    # resolution differs from the pretraining resolution.
    batch_size = pixel_values.shape[0]
    target_dtype = self.patch_embedding.weight.dtype
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    position_embeddings = self.position_embedding(self.position_ids)
    if position_embeddings.shape[1] != embeddings.shape[1]:
        position_embeddings = resample_pos_embed(position_embeddings, embeddings.shape[1])
    embeddings = embeddings + position_embeddings
    return embeddings
def resample_pos_embed(
        posemb,
        new_size: int,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    # target grid is square: new_size counts prefix tokens plus patch tokens
    new_size = [int(math.sqrt(new_size - num_prefix_tokens))] * 2
    num_pos_tokens = posemb.shape[1] - num_prefix_tokens
    old_size = int(math.sqrt(num_pos_tokens))
    bs = posemb.shape[0]

    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix, posemb = None, posemb

    # do the interpolation
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate needs float32
    posemb = posemb.reshape(bs, old_size, old_size, -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(bs, -1, embed_dim)
    posemb = posemb.to(dtype=orig_dtype)

    # add back extra (class, etc.) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)

    if not torch.jit.is_scripting() and verbose:
        print(f'Resized position embedding: {old_size} to {new_size}.')
    return posemb
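
# Sanity-check sketch (an assumption, not part of the original file): for a
# CLIP-L/14 tower, a 224px input yields a 16x16 patch grid and a 448px input a
# 32x32 grid, so the resampled table should hold the matching number of
# position tokens. The embedding width below is illustrative only.
def _demo_resample_pos_embed():
    embed_dim = 1024                                       # CLIP-L hidden size (assumption)
    posemb_224 = torch.randn(1, 1 + 16 * 16, embed_dim)    # [CLS] + 256 patch positions
    posemb_448 = resample_pos_embed(posemb_224, new_size=1 + 32 * 32)
    assert posemb_448.shape == (1, 1 + 32 * 32, embed_dim)
    return posemb_448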
class HRCLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.freeze_vision = args.freeze_vision
        self.input_image_size = args.input_image_size
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)

        # checkpointing for clip
        self.vision_tower.vision_model.encoder.gradient_checkpointing = True

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        # bind forward_embeddings onto the embeddings module so position
        # embeddings get resampled for non-default input resolutions
        cls_ = self.vision_tower.vision_model.embeddings
        bound_method = forward_embeddings.__get__(cls_, cls_.__class__)
        setattr(cls_, 'forward', bound_method)

        if self.input_image_size is not None:
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size
            }

        self.is_loaded = True
    def forward(self, x):
        # 448 image input
        blks = self.vision_tower.vision_model.encoder.layers
        x = self.vision_tower.vision_model.embeddings(x)
        # drop the [CLS] token; `pre_layrnorm` is the attribute's actual
        # (misspelled) name in transformers' CLIPVisionTransformer
        x = self.vision_tower.vision_model.pre_layrnorm(x[:, 1:])

        # inference of fast branch
        for blk in blks:
            if self.training:
                x = checkpoint(
                    blk.__call__,
                    x,
                    None,
                    None,
                )[0]
            else:
                x = blk(x, None, None)[0]
        return x
    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def num_attention_heads(self):
        return self.config.num_attention_heads

    @property
    def num_layers(self):
        return self.config.num_hidden_layers

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
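

# Minimal usage sketch (an assumption, not part of the original file): the
# argument names mirror the config attributes this tower reads in __init__;
# the checkpoint id and the 448px resolution below are illustrative only.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        freeze_vision=True,
        input_image_size=448,
        mm_vision_select_layer=-2,
        mm_vision_select_feature='patch',
    )
    tower = HRCLIPVisionTower('openai/clip-vit-large-patch14-336', args)
    tower.eval()

    pixel_values = torch.randn(2, 3, 448, 448)   # already-preprocessed images
    with torch.no_grad():
        feats = tower(pixel_values)
    # (448 / 14) ** 2 = 1024 patch tokens per image; [CLS] is dropped in forward()
    print(feats.shape)   # -> torch.Size([2, 1024, 1024]) for CLIP-L/14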