from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmdet.registry import MODELS
from mmengine.dist import get_dist_info
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.runner.checkpoint import CheckpointLoader
from timm.layers import resample_abs_pos_embed

from . import open_clip

class Data:
    """Minimal stand-in for the config object of a HuggingFace-style
    vision tower (only ``hidden_size`` is consumed downstream)."""
    hidden_size = 1024


class Output:
    """Minimal stand-in for a HuggingFace-style model output."""

    def __init__(self, hidden_states):
        self.hidden_states = hidden_states

def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
    """Load a pretrained checkpoint, optionally keeping only the sub-module
    under a given prefix.

    Args:
        filename (str): Accepts a local filepath, URL, ``torchvision://xxx``
            or ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md``
            for details.
        prefix (str, optional): The prefix of the sub-module to extract.
            If ``None`` or empty, the full state dict is returned.
        map_location (str | None): Same as :func:`torch.load`.
            Defaults to 'cpu'.
        logger: The logger used during loading. Defaults to 'current'.

    Returns:
        dict or OrderedDict: The loaded (and possibly filtered) state dict.
    """
    checkpoint = CheckpointLoader.load_checkpoint(filename, map_location=map_location, logger=logger)

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if not prefix:
        return state_dict
    if not prefix.endswith('.'):
        prefix += '.'
    prefix_len = len(prefix)

    # Keep only the keys under `prefix` and strip the prefix itself.
    state_dict = {
        k[prefix_len:]: v
        for k, v in state_dict.items() if k.startswith(prefix)
    }

    assert state_dict, f'{prefix} is not in the pretrained model'
    return state_dict
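
# Usage sketch (the checkpoint path and key layout are hypothetical): given a
# checkpoint whose keys look like 'backbone.stem.0.weight', keep only the
# backbone weights with the prefix stripped:
#   state_dict = load_checkpoint_with_prefix('ckpt.pth', prefix='backbone')
#   model.load_state_dict(state_dict, strict=True)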


def flatten_permute(x):
    """Flatten a (B, C, H, W) feature map into a (B, H*W, C) token sequence."""
    x = x.reshape(x.shape[0], x.shape[1], -1)
    x = x.permute(0, 2, 1)
    return x
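
# e.g. a (2, 768, 32, 32) ViT patch grid becomes a (2, 1024, 768) sequence
# of (batch, tokens, channels).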


@MODELS.register_module()
class OpenCLIPBackbone(BaseModule):
    """OpenCLIP visual backbone.

    Please refer to:
    https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface
    for the supported models and checkpoints.
    """
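    # Construction sketch (the tags below are one valid pairing from the
    # open_clip pretrained-model interface, not the only supported values):
    #   backbone = OpenCLIPBackbone(
    #       model_name='convnext_large_d_320',
    #       fix=True,
    #       init_cfg=dict(type='clip_pretrain',
    #                     checkpoint='laion2b_s29b_b131k_ft_soup'))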
    STAGES = 4

    def __init__(
            self,
            img_size: int = 1024,
            model_name: str = '',
            fix: bool = True,
            fix_layers: Optional[List] = None,
            init_cfg=None,
            dtype=torch.float16,
            **kwargs,
    ):
        assert init_cfg is not None and init_cfg['type'] in ['clip_pretrain', 'image_pretrain', 'Pretrained'], \
            f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
        rank, world_size = get_dist_info()

        # In distributed runs, rank 0 downloads and caches the checkpoint
        # first; the other ranks wait at the barrier and then load it from
        # the local cache, avoiding concurrent downloads of the same file.
        if world_size > 1:
            if rank == 0:
                if init_cfg['type'] == 'clip_pretrain':
                    _ = open_clip.create_model_from_pretrained(
                        model_name, pretrained=pretrained,
                        return_transform=False, logger=self.logger)
                elif init_cfg['type'] == 'image_pretrain':
                    _ = open_clip.create_model(
                        model_name, pretrained_image=True, logger=self.logger)
            dist.barrier()

        if init_cfg['type'] == 'clip_pretrain':
            clip_model = open_clip.create_model_from_pretrained(
                model_name, pretrained=pretrained,
                return_transform=False, logger=self.logger)
        elif init_cfg['type'] == 'image_pretrain':
            clip_model = open_clip.create_model(
                model_name, pretrained_image=True, logger=self.logger)
        elif init_cfg['type'] == 'Pretrained':
            clip_model = open_clip.create_model(
                model_name, pretrained_image=False, logger=self.logger)
        else:
            raise NotImplementedError

        self.out_indices = (0, 1, 2, 3)
        model_name_lower = model_name.lower()
        # Per-stage output channels; `feat_size` is the spatial size the
        # ResNet attention pool expects (unused for the other trunks).
        if 'convnext_' in model_name_lower:
            model_type = 'convnext'
            if '_base' in model_name_lower:
                output_channels = [128, 256, 512, 1024]
                feat_size = 0
            elif '_large' in model_name_lower:
                output_channels = [192, 384, 768, 1536]
                feat_size = 0
            elif '_xxlarge' in model_name_lower:
                output_channels = [384, 768, 1536, 3072]
                feat_size = 0
            else:
                raise NotImplementedError(f"{model_name} not supported yet.")
        elif 'rn' in model_name_lower:
            model_type = 'resnet'
            if model_name_lower.replace('-quickgelu', '') in ['rn50', 'rn101']:
                output_channels = [256, 512, 1024, 2048]
                feat_size = 7
            elif model_name_lower == 'rn50x4':
                output_channels = [320, 640, 1280, 2560]
                feat_size = 9
            elif model_name_lower == 'rn50x16':
                output_channels = [384, 768, 1536, 3072]
                feat_size = 12
            elif model_name_lower == 'rn50x64':
                output_channels = [512, 1024, 2048, 4096]
                feat_size = 14
            else:
                raise NotImplementedError(f"{model_name} not supported yet.")
        elif 'vit' in model_name_lower:
            model_type = 'vit'
            if model_name_lower == 'vit-l-14':
                output_channels = [1024, 1024, 1024, 1024]
                feat_size = 0
                assert not clip_model.visual.input_patchnorm
                assert clip_model.visual.attn_pool is None
            elif model_name_lower == 'vit-b-32':
                output_channels = [768, 768, 768, 768]
                feat_size = 0
                assert not clip_model.visual.input_patchnorm
                assert clip_model.visual.attn_pool is None
            else:
                raise NotImplementedError(f"{model_name} not supported yet.")
        else:
            raise NotImplementedError(f"{model_name} not supported yet.")

        self.model_name = model_name
        self.fix = fix
        self.model_type = model_type
        self.output_channels = output_channels
        self.feat_size = feat_size

        # Instantiate the config stub (rather than aliasing the class, which
        # would share `hidden_size` across instances). `forward` returns the
        # channel-concatenated last two stages, hence the sum.
        self.config = Data()
        self.config.hidden_size = output_channels[-2] + output_channels[-1]

        if self.model_type == 'resnet':
            self.stem = nn.Sequential(
                clip_model.visual.conv1, clip_model.visual.bn1, clip_model.visual.act1,
                clip_model.visual.conv2, clip_model.visual.bn2, clip_model.visual.act2,
                clip_model.visual.conv3, clip_model.visual.bn3, clip_model.visual.act3,
            )
        elif self.model_type == 'convnext':
            self.stem = clip_model.visual.trunk.stem
        elif self.model_type == 'vit':
            self.stem = clip_model.visual.conv1
        else:
            raise ValueError

        if self.model_type == 'resnet':
            self.avgpool = clip_model.visual.avgpool
        elif self.model_type == 'convnext':
            self.avgpool = nn.Identity()
        elif self.model_type == 'vit':
            self.avgpool = flatten_permute
        else:
            raise ValueError

        self.res_layers = []
        if self.model_type in ['vit']:
            self.t_class_embedding = clip_model.visual.class_embedding
            self.t_positional_embedding = clip_model.visual.positional_embedding
            self.t_ln_pre_trans = clip_model.visual.ln_pre
            self.t_transformer = clip_model.visual.transformer
        else:
            for i in range(self.STAGES):
                layer_name = f'layer{i + 1}'
                if self.model_type == 'resnet':
                    layer = getattr(clip_model.visual, layer_name)
                elif self.model_type == 'convnext':
                    layer = clip_model.visual.trunk.stages[i]
                else:
                    raise ValueError
                self.add_module(layer_name, layer)
                self.res_layers.append(layer_name)

        if self.model_type == 'resnet':
            self.norm_pre = nn.Identity()
        elif self.model_type == 'convnext':
            self.norm_pre = clip_model.visual.trunk.norm_pre
        elif self.model_type == 'vit':
            self.norm_pre = nn.Identity()

        if self.model_type == 'resnet':
            self.head = clip_model.visual.attnpool
        elif self.model_type == 'convnext':
            self.head = nn.Sequential(
                clip_model.visual.trunk.head,
                clip_model.visual.head,
            )
        elif self.model_type == 'vit':
            self.head = clip_model.visual.ln_post
            self.proj = clip_model.visual.proj

        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        self.fix_layers = fix_layers

        if not self.fix:
            self.train()
            # Even when training, keep the CLIP-specific pre-norm and head
            # frozen, plus any explicitly listed stages.
            for param in self.norm_pre.parameters():
                param.requires_grad = False
            for param in self.head.parameters():
                param.requires_grad = False
            if self.fix_layers is not None:
                for i, layer_name in enumerate(self.res_layers):
                    if i in self.fix_layers:
                        res_layer = getattr(self, layer_name)
                        for param in res_layer.parameters():
                            param.requires_grad = False
                        if i == 0:
                            for param in self.stem.parameters():
                                param.requires_grad = False

        if self.fix:
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

        self.dtype = dtype
        self.backbone_type = None

        self.enable_output_gradient = False

    def init_weights(self):
        # Weights are already loaded in __init__; just log the init config.
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        # A fixed backbone always stays in eval mode.
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
            if self.fix_layers is not None:
                for i, layer_name in enumerate(self.res_layers):
                    if i in self.fix_layers:
                        res_layer = getattr(self, layer_name)
                        res_layer.train(mode=False)
                        if i == 0:
                            self.stem.train(mode=False)
        return self

    def forward_func(self, x):
        x = self.stem(x)
        # (h, w) is the post-stem grid size (the patch grid for ViT).
        h, w = x.shape[-2:]
        x = self.avgpool(x)
        outs = []
        if self.model_type == 'vit':
            # Prepend the class token.
            x = torch.cat(
                [self.t_class_embedding.to(x.dtype) +
                 torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                 x], dim=1
            )
            # Resample the positional embedding to the current grid size.
            new_pos_embed = resample_abs_pos_embed(
                self.t_positional_embedding[None],
                [h, w],
                num_prefix_tokens=1
            )
            x = x + new_pos_embed.to(x.dtype)
            x = self.t_ln_pre_trans(x)

            # The transformer expects (L, N, D).
            x = x.permute(1, 0, 2)
            x = self.t_transformer(x)
            x = x.permute(1, 0, 2)
            # Drop the class token and fold tokens back into a 2D map.
            x = x[:, 1:]
            x = x.permute(0, 2, 1).unflatten(2, (h, w))
            # Build a 4-level pyramid by rescaling the single ViT feature
            # map (4x, 2x, 1x and 0.5x the patch-grid resolution).
            for i in range(self.STAGES):
                outs.append(
                    F.interpolate(
                        x,
                        scale_factor=2 ** (2 - i),
                        mode='bilinear',
                        align_corners=False
                    )
                )
        else:
            for i, layer_name in enumerate(self.res_layers):
                res_layer = getattr(self, layer_name)
                x = res_layer(x).contiguous()
                if i in self.out_indices:
                    outs.append(x)
        return tuple(outs)
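    # Shape sketch (assuming a ConvNeXt-L trunk and a 1024x1024 input, both
    # hypothetical): forward_func returns four maps at strides 4/8/16/32:
    #   (B, 192, 256, 256), (B, 384, 128, 128),
    #   (B, 768, 64, 64), (B, 1536, 32, 32).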

    def get_clip_feature(self, backbone_feat):
        if self.model_type == 'resnet':
            return backbone_feat
        elif self.model_type == 'convnext':
            # ConvNeXt applies the pre-head norm before CLIP pooling.
            return self.norm_pre(backbone_feat)
        elif self.model_type == 'vit':
            return backbone_feat
        raise NotImplementedError

    def forward_feat(self, features):
        if self.model_type == 'convnext':
            batch, num_query, channel = features.shape
            features = features.reshape(batch * num_query, channel, 1, 1)
            features = self.head(features)
            return features.view(batch, num_query, features.shape[-1])
        elif self.model_type == 'resnet':
            # features: (N, C, 7, 7); the CLIP attention pool reduces it to
            # a single embedding per sample.
            features = self.head(features)
            return features
        elif self.model_type == 'vit':
            return (self.head(features) @ self.proj)[:, 0]
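    # Usage sketch for open-vocabulary classification (`query_feat` is a
    # hypothetical (B, Q, C) tensor of per-mask pooled backbone features):
    #   clip_feat = backbone.get_clip_feature(stage_feat)  # (B, C, H, W)
    #   cls_emb = backbone.forward_feat(query_feat)        # (B, Q, proj_dim)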

    def forward(self, x, output_hidden_states=True):
        # Lazily cache the parameter dtype and cast the input to it.
        if self.backbone_type is None:
            self.backbone_type = next(self.parameters()).dtype
        x = x.to(self.backbone_type)
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)

        # Fuse the last two pyramid levels: upsample the last stage to the
        # second-to-last stage's resolution, concatenate along channels and
        # flatten into a (B, H*W, C) token sequence.
        second_outs = outs[-2]
        second_shape = second_outs.shape[2:]
        last_outs = F.interpolate(
            outs[-1],
            size=second_shape,
            mode='bilinear',
            align_corners=False
        )
        outs = torch.cat([second_outs, last_outs], dim=1).flatten(2).permute(0, 2, 1)

        outs = self.set_output_gradient(outs)
        # Duplicate the first patch token as a CLS-like placeholder, so
        # downstream code that strips a leading token keeps every patch.
        images_feat = torch.cat([outs[:, :1, :], outs], dim=1)
        hidden_states = [images_feat, images_feat]
        output = Output(hidden_states=hidden_states)
        return output
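    # Usage sketch (same hypothetical ConvNeXt-L / 1024px setup as above):
    # the fused stride-16 grid is 64x64 and hidden_size = 768 + 1536 = 2304:
    #   out = backbone(images)          # images: (B, 3, 1024, 1024)
    #   out.hidden_states[-1].shape == (B, 1 + 64 * 64, 2304)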

    def enable_input_require_grads(self):
        # Mirrors the HuggingFace hook of the same name: mark the backbone
        # output as requiring grad so later modules can checkpoint through it.
        self.enable_output_gradient = True

    def set_output_gradient(self, output):
        output.requires_grad_(self.enable_output_gradient)
        return output

    def requires_grad_(self, state):
        if state:
            print("Unfreezing the visual encoder!")
        else:
            print("Freezing the visual encoder!")
        for p in self.parameters():
            p.requires_grad_(state)
        return self

    def get_text_model(self):
        return OpenCLIPBackboneText(
            self.model_name,
            init_cfg=self.init_cfg
        )


@MODELS.register_module()
class OpenCLIPBackbone_omgseg(OpenCLIPBackbone):
    """OpenCLIPBackbone variant for OMG-Seg.

    Identical to :class:`OpenCLIPBackbone` except that ``forward`` returns
    the raw four-level feature pyramid instead of the fused token sequence.
    Please refer to:
    https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface
    for the supported models and checkpoints.
    """

    def forward(self, x, output_hidden_states=True):
        # Lazily cache the parameter dtype and cast the input to it.
        if self.backbone_type is None:
            self.backbone_type = next(self.parameters()).dtype
        x = x.to(self.backbone_type)
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)
        return outs
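    # Usage sketch (hypothetical ConvNeXt-L trunk, 1024x1024 input): returns
    # a tuple of four maps at strides 4/8/16/32, e.g.
    #   outs = backbone(images)
    #   outs[0].shape == (B, 192, 256, 256); outs[3].shape == (B, 1536, 32, 32)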


@MODELS.register_module()
class OpenCLIPBackboneText(BaseModule):
    def __init__(
            self,
            model_name: str = '',
            init_cfg=None,
    ):
        assert init_cfg is not None and init_cfg['type'] == 'clip_pretrain', \
            f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
        rank, world_size = get_dist_info()

        # Rank 0 downloads and caches the checkpoint first; the other ranks
        # load it from the local cache after the barrier.
        if world_size > 1:
            if rank == 0:
                _ = open_clip.create_model_from_pretrained(
                    model_name, pretrained=pretrained, return_transform=False,
                    logger=self.logger)
            dist.barrier()

        clip_model = open_clip.create_model_from_pretrained(
            model_name, pretrained=pretrained, return_transform=False,
            logger=self.logger)

        # Keep only the text tower.
        self.text_tokenizer = open_clip.get_tokenizer(model_name)
        self.text_transformer = clip_model.transformer
        self.text_token_embedding = clip_model.token_embedding
        self.text_pe = clip_model.positional_embedding
        self.text_ln_final = clip_model.ln_final
        self.text_proj = clip_model.text_projection

        self.register_buffer('text_attn_mask', clip_model.attn_mask)

        self.param_dtype = torch.float32
        self.model_name = model_name

    def init_weights(self):
        # Weights are already loaded in __init__; just log the init config.
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    @torch.no_grad()
    def forward(self, text):
        text_tokens = self.text_tokenizer(text).to(device=self.text_proj.device)
        x = self.text_token_embedding(text_tokens).to(self.param_dtype)
        x = x + self.text_pe.to(self.param_dtype)
        # The transformer expects (L, N, D).
        x = x.permute(1, 0, 2)
        x = self.text_transformer(x, attn_mask=self.text_attn_mask)
        x = x.permute(1, 0, 2)
        x = self.text_ln_final(x)
        # Take the features at the EOT token (the highest token id in each
        # sequence) and project them into the joint embedding space.
        x = x[torch.arange(x.shape[0]), text_tokens.argmax(dim=-1)] @ self.text_proj
        return x
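
# Usage sketch (the prompt strings are illustrative):
#   text_model = backbone.get_text_model()
#   emb = text_model(['a photo of a cat', 'a photo of a dog'])  # (2, proj_dim)
#   emb = F.normalize(emb, dim=-1)  # unit-norm, ready for cosine scoring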