# eagle/model/multimodal_encoder/multi_backbone_channel_concatenation_encoder.py
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .convnext_encoder import ConvNextVisionTower
from .hr_clip_encoder import HRCLIPVisionTower
from .vision_models.eva_vit import EVAVITVisionTower
from .sam_encoder import SAMVisionTower
from .pix2struct_encoder import Pix2StructLargeVisionTower
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from copy import deepcopy
import random
import math

class MultiBackboneChannelConcatenationVisionTower(nn.Module):
    """Vision tower that runs several backbones on the same image and
    concatenates their token features along the channel dimension."""

    def __init__(self,
                 vision_tower,
                 args,
                 grid_size=32):
        super().__init__()

        self.is_loaded = False
        self.grid_size = grid_size
        self.num_tokens = self.grid_size ** 2

        # Backbone names arrive as a single ";"-separated string,
        # e.g. "det-1024;convnext-1024;sam-1024".
        vision_tower_name_list = vision_tower.split(";")
        self.input_image_size = 1024  # hardcode
        self.load_vision_towers(vision_tower_name_list, args)
    def load_vision_towers(self, vision_tower_name_list, args):
        self.vision_towers = nn.ModuleList()
        for name in vision_tower_name_list:
            if name == 'det-1024':
                det_args = deepcopy(args)
                det_args.input_image_size = 1024
                det_args.freeze_vision = False
                det_args.vision_tower_pretrained_from = '/lustre/fsw/portfolios/llmservice/users/fuxiaol/eva02_L_coco_det_sys_o365.pth'
                det_vision_tower = EVAVITVisionTower("eva02-l-16", det_args)
                det_vision_tower.load_model()
                self.vision_towers.append(det_vision_tower)

            elif name == 'convnext-1024':
                ## ConvNeXt
                convnext_args = deepcopy(args)
                convnext_args.freeze_vision = False
                convnext_args.input_image_size = 1024
                convnext_name = "convnext_xxlarge.clip_laion2b_soup"  # hardcode
                convnext_vision_tower = ConvNextVisionTower(convnext_name,
                                                            convnext_args)
                convnext_vision_tower.load_model()
                self.vision_towers.append(convnext_vision_tower)

            elif name == "sam-1024":
                sam_args = deepcopy(args)
                sam_args.freeze_vision = False
                sam_args.input_image_size = 1024
                sam_args.add_pixel_shuffle = True
                sam_vision_tower = SAMVisionTower("SAM-L", sam_args)
                sam_vision_tower.load_model()
                self.vision_towers.append(sam_vision_tower)

            elif name == 'pix2struct-1024':
                pix_args = deepcopy(args)
                # pix_args.freeze_vision = True
                pix_args.input_image_size = 1024
                pix_args.freeze_vision = False
                pix_args.do_resize = True
                pix_args.de_normalize = True
                pix_vision_tower = Pix2StructLargeVisionTower("pix2struct-large", pix_args)
                pix_vision_tower.load_model()
                self.vision_towers.append(pix_vision_tower)
            elif name == 'clip-448':
                clip_args = deepcopy(args)
                clip_args.input_image_size = 336  # actually 448, will have no effect
                clip_args.freeze_vision = False
                clip_vision_tower = HRCLIPVisionTower("openai/clip-vit-large-patch14-336", clip_args)
                clip_vision_tower.load_model()
                self.vision_towers.append(clip_vision_tower)

            else:
                # Guard against typos in the ";"-separated tower list.
                raise ValueError(f"Unknown vision tower name: {name}")

        # A hardcode here: ConvNeXt is assumed to always be part of the vision
        # encoder mixture, so its image processor is reused for the whole tower.
        self.image_processor = convnext_vision_tower.image_processor
        self.is_loaded = True

    def load_model(self):
        # All towers are loaded eagerly in load_vision_towers(); this is only a guard.
        assert self.is_loaded, "All the vision encoders should be loaded during initialization!"

    def forward(self, x):
        features = []
        for vision_tower in self.vision_towers:
            # Resize the input to the resolution each backbone expects.
            if vision_tower.input_image_size != self.input_image_size:
                resized_x = F.interpolate(x.float(),
                                          size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                          mode='bilinear',
                                          align_corners=True).to(dtype=x.dtype)
            else:
                resized_x = x

            feature = vision_tower(resized_x)

            if len(feature.shape) == 3:  # b, n, c
                b, n, c = feature.shape
                if n == self.num_tokens:
                    features.append(feature)
                    continue
                # Reshape a token sequence back into a square feature map.
                w = h = int(n ** 0.5)
                feature = feature.transpose(1, 2).reshape(b, c, h, w)
            else:
                b, c, h, w = feature.shape

            # Resample every backbone to the shared grid, then flatten to tokens.
            if w != self.grid_size:
                feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size),
                                        mode='bilinear', align_corners=True).to(dtype=x.dtype)
            features.append(feature.flatten(2, 3).transpose(1, 2))

        # Channel-wise concatenation: token count stays grid_size**2, channels add up.
        features = torch.cat(features, dim=-1)
        return features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # The first backbone in the mixture is used as the reference for dtype/device.
        return next(self.vision_towers[0].parameters()).dtype

    @property
    def device(self):
        return next(self.vision_towers[0].parameters()).device

    @property
    def config(self):
        raise NotImplementedError

    @property
    def hidden_size(self):
        # Channel concatenation means the hidden size is the sum over all backbones.
        return sum([tower.hidden_size for tower in self.vision_towers])

    @property
    def num_patches(self):
        return self.num_tokens
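

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the module above: it mirrors the resampling
# and channel-concatenation logic of forward() on dummy tensors so the token
# and channel shapes can be checked without loading any backbone. The batch
# size, channel widths, and spatial sizes below are arbitrary placeholders,
# not the real encoders' dimensions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    grid_size = 32

    # Pretend outputs from three backbones: one token-style output (B, N, C)
    # and two map-style outputs (B, C, H, W).
    dummy_outputs = [
        torch.randn(2, 24 * 24, 256),   # token sequence on a 24x24 grid
        torch.randn(2, 512, 64, 64),    # feature map above the target grid
        torch.randn(2, 128, 32, 32),    # feature map already at the target grid
    ]

    features = []
    for feature in dummy_outputs:
        if feature.dim() == 3:
            # (B, N, C) -> (B, C, H, W)
            b, n, c = feature.shape
            h = w = int(n ** 0.5)
            feature = feature.transpose(1, 2).reshape(b, c, h, w)
        if feature.shape[-1] != grid_size:
            # Resample to the shared grid, as forward() does.
            feature = F.interpolate(feature.float(), size=(grid_size, grid_size),
                                    mode='bilinear', align_corners=True)
        # (B, C, H, W) -> (B, H*W, C)
        features.append(feature.flatten(2, 3).transpose(1, 2))

    fused = torch.cat(features, dim=-1)
    # Expected: torch.Size([2, 1024, 896]) -- 32*32 tokens, 256+512+128 channels.
    print(fused.shape)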