from functools import partial
from typing import Literal

import torch
import torch.nn as nn
from mmdet.registry import MODELS
from mmengine.model import BaseModule
from mmengine.logging import MMLogger

from ext.sam import ImageEncoderViT
from ext.meta.sam_meta import meta_dict, checkpoint_dict
from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class SAMBackbone(BaseModule):
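    """Wrap SAM's ``ImageEncoderViT`` as an MMDetection backbone.

    The encoder's transformer blocks are split into stages at the
    global-attention indices, and each stage's output is returned as an
    NCHW feature map; the SAM neck is applied to the last stage only.
    With ``fix=True`` the backbone stays in eval mode and all parameters
    are frozen.
    """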

    def __init__(
            self,
            model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
            fix: bool = True,
            init_cfg=None,
    ):
        assert init_cfg is not None and init_cfg['type'] in \
               ['sam_pretrain', 'Pretrained'], f"init_cfg {init_cfg} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        backbone_meta = meta_dict[model_name]

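        # Build the full SAM image encoder with the hyper-parameters of the
        # chosen variant.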
        backbone = ImageEncoderViT(
            depth=backbone_meta['encoder_depth'],
            embed_dim=backbone_meta['encoder_embed_dim'],
            num_heads=backbone_meta['encoder_num_heads'],
            patch_size=backbone_meta['vit_patch_size'],
            img_size=backbone_meta['image_size'],
            global_attn_indexes=backbone_meta['encoder_global_attn_indexes'],
            out_chans=backbone_meta['prompt_embed_dim'],
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            qkv_bias=True,
            use_rel_pos=True,
            mlp_ratio=4,
            window_size=14,
        )
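        # 'sam_pretrain': load the official SAM weights of the image encoder.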
        if self.init_cfg['type'] == 'sam_pretrain':
            checkpoint_path = checkpoint_dict[pretrained]
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix='image_encoder')
            backbone.load_state_dict(state_dict, strict=True)

        self.stem = backbone.patch_embed
        self.pos_embed = backbone.pos_embed

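        # Split the transformer blocks into stages, one per global-attention
        # index, so that intermediate features can be collected stage by stage.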
        self.res_layers = []
        last_pos = 0
        for idx, cur_pos in enumerate(backbone_meta['encoder_global_attn_indexes']):
            blocks = backbone.blocks[last_pos:cur_pos + 1]
            layer_name = f'layer{idx + 1}'
            self.add_module(layer_name, nn.Sequential(*blocks))
            self.res_layers.append(layer_name)
            last_pos = cur_pos + 1

        self.out_proj = backbone.neck

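        # 'Pretrained': load weights for this wrapper module itself from a
        # checkpoint, stripping the configured prefix.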
        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        self.model_name = model_name
        self.fix = fix
        self.model_type = 'vit'
        self.output_channels = None
        self.out_indices = (0, 1, 2, 3)
        if self.fix:
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

    def init_weights(self):
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    def train(self, mode: bool = True) -> torch.nn.Module:
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self

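    # forward_func: patch-embed the image, add the positional embedding, run
    # the stages in order, and collect each stage's output in NCHW layout; the
    # SAM neck (out_proj) is applied to the last stage only.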
    def forward_func(self, x):
        x = self.stem(x)
        x = x + self.pos_embed
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x.permute(0, 3, 1, 2).contiguous())
        outs[-1] = self.out_proj(outs[-1])
        return tuple(outs)

    def forward(self, x):
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)
        return outs
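

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original module): once this file
# is imported so that @MODELS.register_module() has run, the backbone can be
# built from an MMDetection-style config dict. The checkpoint key 'vit_h' below
# is hypothetical; use whichever keys `checkpoint_dict` actually defines.
#
#   backbone = MODELS.build(dict(
#       type='SAMBackbone',
#       model_name='vit_h',
#       fix=True,
#       init_cfg=dict(type='sam_pretrain', checkpoint='vit_h'),
#   ))
#   feats = backbone(torch.randn(1, 3, 1024, 1024))  # tuple of 4 NCHW features
# ---------------------------------------------------------------------------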