| import logging |
| import torch |
| import torch.nn as nn |
| from mmcv.runner.checkpoint import load_state_dict |
| from mmdet.models.builder import BACKBONES |
| from .vit import ViT, SimpleFeaturePyramid, partial |
| from .fpn import LastLevelMaxPool |
|
|
|
|
@BACKBONES.register_module()
class EVA02(nn.Module):
    """ViT backbone wrapped in a ``SimpleFeaturePyramid``, registered with mmdet.

    The constructor builds a ``ViT`` from the first group of arguments, wraps
    it in a ``SimpleFeaturePyramid`` configured by the ``fpn_*`` arguments,
    stores the result as ``self.backbone`` and finally calls
    :meth:`init_weights` with ``pretrained``.

    Args:
        img_size, real_img_size, patch_size, in_chans, embed_dim, depth,
        num_heads, mlp_ratio, qkv_bias, drop_path_rate, norm_layer,
        use_abs_pos, pt_hw_seq_len, intp_freq, window_size,
        window_block_indexes, residual_block_indexes, use_act_checkpoint,
        pretrain_img_size, pretrain_use_cls_token, out_feature, xattn,
        frozen_blocks: Forwarded unchanged, by keyword, to ``ViT``. Their
            meaning is defined by that class.
        fpn_in_feature: Passed to ``SimpleFeaturePyramid`` as ``in_feature``.
        fpn_out_channels: Passed to ``SimpleFeaturePyramid`` as
            ``out_channels``.
        fpn_scale_factors: Passed to ``SimpleFeaturePyramid`` as
            ``scale_factors``.
        fpn_top_block: When truthy, a ``LastLevelMaxPool()`` instance is
            passed as ``top_block``; otherwise ``None`` is passed.
        fpn_norm: Passed to ``SimpleFeaturePyramid`` as ``norm``.
        fpn_square_pad: Passed to ``SimpleFeaturePyramid`` as ``square_pad``.
        pretrained: Optional path of a checkpoint to load right after
            construction; ``None`` skips loading.
    """

    def __init__(
        self,
        img_size=1024,
        real_img_size=(256, 704),
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_abs_pos=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        xattn=False,
        frozen_blocks=-1,
        fpn_in_feature="last_feat",
        fpn_out_channels=256,
        fpn_scale_factors=(4.0, 2.0, 1.0, 0.5),
        fpn_top_block=False,
        fpn_norm="LN",
        fpn_square_pad=0,
        pretrained=None
    ):
        super().__init__()

        vit = ViT(
            img_size=img_size,
            real_img_size=real_img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            use_abs_pos=use_abs_pos,
            pt_hw_seq_len=pt_hw_seq_len,
            intp_freq=intp_freq,
            window_size=window_size,
            window_block_indexes=window_block_indexes,
            residual_block_indexes=residual_block_indexes,
            use_act_checkpoint=use_act_checkpoint,
            pretrain_img_size=pretrain_img_size,
            pretrain_use_cls_token=pretrain_use_cls_token,
            out_feature=out_feature,
            xattn=xattn,
            frozen_blocks=frozen_blocks,
        )
        # The pyramid takes the ViT as its first positional argument.
        self.backbone = SimpleFeaturePyramid(
            vit,
            in_feature=fpn_in_feature,
            out_channels=fpn_out_channels,
            scale_factors=fpn_scale_factors,
            top_block=LastLevelMaxPool() if fpn_top_block else None,
            norm=fpn_norm,
            square_pad=fpn_square_pad,
        )
        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Load weights from the checkpoint file at ``pretrained``, if given.

        The checkpoint must be a mapping whose ``'model'`` entry holds the
        state dict. Loading is non-strict (``strict=False``), so keys that do
        not match this module are tolerated rather than treated as errors.

        Args:
            pretrained: Path (or file-like object) accepted by ``torch.load``.
                ``None`` makes this method a no-op.

        Raises:
            KeyError: If the checkpoint has no ``'model'`` entry.
        """
        if pretrained is None:
            return
        logging.info('Loading pretrained weights from %s', pretrained)
        # map_location='cpu' keeps deserialization independent of the device
        # the checkpoint was saved from: without it, tensors saved on a GPU
        # are restored onto that same device, which fails on CPU-only hosts
        # and otherwise places a temporary copy of the weights on the GPU.
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(pretrained, map_location='cpu')
        state_dict = checkpoint['model']
        load_state_dict(self, state_dict, strict=False)

    def forward(self, x):
        """Run the wrapped backbone and return its outputs as a list.

        Args:
            x: Input forwarded as-is to ``self.backbone``.

        Returns:
            list: The values of the mapping returned by ``self.backbone``, in
            that mapping's iteration order (presumably one feature map per
            pyramid level; confirm against ``SimpleFeaturePyramid``).
        """
        outs = self.backbone(x)
        return list(outs.values())
|
|