| |
| |
| |
| import math |
| from os.path import join |
| import numpy as np |
| import copy |
| from functools import partial |
|
|
| import torch |
| from torch import nn |
| import torch.utils.model_zoo as model_zoo |
| import torch.nn.functional as F |
| import fvcore.nn.weight_init as weight_init |
|
|
| from detectron2.modeling.backbone import FPN |
| from detectron2.modeling.backbone.build import BACKBONE_REGISTRY |
| from detectron2.layers.batch_norm import get_norm, FrozenBatchNorm2d |
| from detectron2.modeling.backbone import Backbone |
|
|
| from timm import create_model |
| from timm.models.helpers import build_model_with_cfg |
| from timm.models.registry import register_model |
| from timm.models.resnet import ResNet, Bottleneck |
| from timm.models.resnet import default_cfgs as default_cfgs_resnet |
|
|
|
|
class CustomResNet(ResNet):
    """timm ``ResNet`` whose forward returns selected intermediate feature maps.

    Forward collects five tensors: the stem output (taken after ``maxpool``),
    followed by the outputs of ``layer1`` .. ``layer4``. It returns the subset
    selected by ``self.out_indices``, in that order.
    """

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: forwarded to ``ResNet.__init__``. Must contain
                ``out_indices`` (indices into the five collected feature
                maps), which is removed before forwarding.
        """
        self.out_indices = kwargs.pop('out_indices')
        super().__init__(**kwargs)

    def forward(self, x):
        """Return ``[features[i] for i in self.out_indices]``."""
        # NOTE(review): index 0 is the post-maxpool stem output; confirm this
        # matches the stride that the stem entry of ``feature_info`` reports.
        x = self.maxpool(self.act1(self.bn1(self.conv1(x))))
        feats = [x]
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            feats.append(x)
        return [feats[idx] for idx in self.out_indices]

    def load_pretrained(self, cached_file):
        """Load weights from a checkpoint file into this model (strict load).

        The checkpoint may be a bare state dict, or a dict that stores the
        state dict under ``'state_dict'``.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        checkpoint = torch.load(cached_file, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        self.load_state_dict(state_dict)
|
|
|
|
# Constructor arguments for each supported ResNet variant. The ImageNet-21K
# variant shares the ResNet-50 architecture; only its pretrained cfg differs.
model_params = dict(
    resnet50=dict(block=Bottleneck, layers=[3, 4, 6, 3]),
    resnet50_in21k=dict(block=Bottleneck, layers=[3, 4, 6, 3]),
)
|
|
|
|
def create_timm_resnet(variant, out_indices, pretrained=False, **kwargs):
    """Build a ``CustomResNet`` through timm's ``build_model_with_cfg``.

    Side effect: on every call, (re)registers a ``'resnet50_in21k'`` entry in
    timm's ResNet default-cfg table. The entry is a deep copy of the
    ``'resnet50'`` cfg, with an ImageNet-21K checkpoint URL and
    ``num_classes = 11221``.

    Args:
        variant: key of ``model_params`` (e.g. ``'resnet50'``).
        out_indices: forwarded to ``CustomResNet`` to select returned features.
        pretrained: forwarded to ``build_model_with_cfg``.
        **kwargs: extra model constructor arguments.

    Raises:
        KeyError: if ``variant`` is not in ``model_params``.
    """
    # Look this up first so an unknown variant fails before any global mutation.
    arch_kwargs = model_params[variant]

    default_cfgs_resnet['resnet50_in21k'] = copy.deepcopy(
        default_cfgs_resnet['resnet50'])
    in21k_cfg = default_cfgs_resnet['resnet50_in21k']
    in21k_cfg['url'] = (
        'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/'
        'ImageNet_21K_P/models/resnet50_miil_21k.pth')
    in21k_cfg['num_classes'] = 11221

    # pretrained_custom_load=True: presumably makes timm defer weight loading
    # to CustomResNet.load_pretrained - confirm against the installed timm.
    return build_model_with_cfg(
        CustomResNet,
        variant,
        pretrained,
        default_cfg=default_cfgs_resnet[variant],
        out_indices=out_indices,
        pretrained_custom_load=True,
        **arch_kwargs,
        **kwargs,
    )
|
|
|
|
class LastLevelP6P7_P5(nn.Module):
    """FPN top block that derives two extra levels (P6, P7) from ``"p5"``.

    P6 is a 3x3 stride-2 convolution of the input map. P7 is a 3x3 stride-2
    convolution of ``relu(P6)``. Both convs are initialized with
    ``weight_init.c2_xavier_fill``.
    """

    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels: channels of the input feature map.
            out_channels: channels of both P6 and P7.
        """
        super().__init__()
        self.num_levels = 2
        # Name of the feature this block consumes; presumably looked up by
        # the FPN that owns the block - confirm against detectron2's FPN.
        self.in_feature = "p5"
        self.p6 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.p7 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        weight_init.c2_xavier_fill(self.p6)
        weight_init.c2_xavier_fill(self.p7)

    def forward(self, c5):
        """Return ``[p6, p7]`` computed from the input map ``c5``."""
        p6 = self.p6(c5)
        return [p6, self.p7(F.relu(p6))]
|
|
|
|
def freeze_module(x):
    """Freeze a module so it is no longer trained.

    Sets ``requires_grad = False`` on every parameter of ``x``, then converts
    its BatchNorm layers to ``FrozenBatchNorm2d``.

    Args:
        x: the ``nn.Module`` to freeze.

    Returns:
        The frozen module. Callers must use this return value: when ``x`` is
        itself a BatchNorm layer, the conversion yields a new module rather
        than modifying ``x`` in place.
    """
    for p in x.parameters():
        p.requires_grad = False
    # Keep the result: convert_frozen_batchnorm returns a *new* module when
    # `x` itself is a BatchNorm layer, so discarding it loses the conversion.
    x = FrozenBatchNorm2d.convert_frozen_batchnorm(x)
    return x
|
|
|
|
class TIMM(Backbone):
    """Detectron2 ``Backbone`` that wraps a timm ResNet or EfficientNet.

    The wrapped model returns one feature map per requested level. This
    wrapper names the maps ``'layer{level}'``, and records their channel
    counts and strides from the timm model's ``feature_info``.
    """

    def __init__(self, base_name, out_levels, freeze_at=0, norm='FrozenBN'):
        """
        Args:
            base_name: timm model name. Names containing ``'resnet'`` are
                built with ``create_timm_resnet(pretrained=False)``. Names
                containing ``'eff'`` are built with ``timm.create_model(...,
                features_only=True, pretrained=True)``.
            out_levels: 1-based feature levels to output. Level ``l`` maps
                to index ``l - 1`` of the model outputs and of
                ``feature_info``.
            freeze_at: ResNet only. ``>= 1`` freezes ``conv1``; ``>= 2``
                also freezes ``layer1``.
            norm: if ``'FrozenBN'``, every BatchNorm layer is converted to
                ``FrozenBatchNorm2d`` after construction.

        Raises:
            ValueError: if ``base_name`` matches neither supported family.
        """
        super().__init__()
        out_indices = [level - 1 for level in out_levels]
        if 'resnet' in base_name:
            self.base = create_timm_resnet(
                base_name, out_indices=out_indices, pretrained=False)
        elif 'eff' in base_name:
            self.base = create_model(
                base_name, features_only=True,
                out_indices=out_indices, pretrained=True)
        else:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise ValueError(
                'Unsupported timm base model: {}'.format(base_name))
        feature_info = [
            dict(num_chs=f['num_chs'], reduction=f['reduction'])
            for f in self.base.feature_info]
        self._out_features = ['layer{}'.format(level) for level in out_levels]
        self._out_feature_channels = {
            'layer{}'.format(level): feature_info[level - 1]['num_chs']
            for level in out_levels}
        self._out_feature_strides = {
            'layer{}'.format(level): feature_info[level - 1]['reduction']
            for level in out_levels}
        # Exposed via `size_divisibility`: the largest (coarsest) output stride.
        self._size_divisibility = max(self._out_feature_strides.values())
        if 'resnet' in base_name:
            self.freeze(freeze_at)
        if norm == 'FrozenBN':
            # `self` is not itself a BatchNorm layer, so the conversion
            # replaces BN submodules in place; its return value is `self`.
            FrozenBatchNorm2d.convert_frozen_batchnorm(self)

    def freeze(self, freeze_at=0):
        """Freeze the early stages of the ResNet in ``self.base``.

        Args:
            freeze_at: ``>= 1`` freezes ``conv1``; ``>= 2`` also freezes
                ``layer1``. ``0`` (the default) freezes nothing.
        """
        # NOTE(review): ``bn1`` is not passed to freeze_module here, so when
        # norm != 'FrozenBN' its parameters stay trainable - confirm intended.
        if freeze_at >= 1:
            print('Freezing', self.base.conv1)
            self.base.conv1 = freeze_module(self.base.conv1)
        if freeze_at >= 2:
            print('Freezing', self.base.layer1)
            self.base.layer1 = freeze_module(self.base.layer1)

    def forward(self, x):
        """Run the base model; return ``{'layer{level}': feature_map}``."""
        return dict(zip(self._out_features, self.base(x)))

    @property
    def size_divisibility(self):
        """Largest stride among the output features."""
        return self._size_divisibility
|
|
|
|
@BACKBONE_REGISTRY.register()
def build_timm_backbone(cfg, input_shape):
    """Build a ``TIMM`` backbone from the ``cfg.MODEL.TIMM`` config node.

    Reads ``BASE_NAME``, ``OUT_LEVELS``, ``FREEZE_AT`` and ``NORM``.
    ``input_shape`` is accepted for the registry signature but not used.
    """
    timm_cfg = cfg.MODEL.TIMM
    return TIMM(
        timm_cfg.BASE_NAME,
        timm_cfg.OUT_LEVELS,
        freeze_at=timm_cfg.FREEZE_AT,
        norm=timm_cfg.NORM,
    )
|
|
|
|
@BACKBONE_REGISTRY.register()
def build_p67_timm_fpn_backbone(cfg, input_shape):
    """Build a timm bottom-up backbone with an FPN that adds P6/P7 levels.

    The FPN is configured from ``cfg.MODEL.FPN`` (``IN_FEATURES``,
    ``OUT_CHANNELS``, ``NORM``, ``FUSE_TYPE``). Its top block is a
    ``LastLevelP6P7_P5`` with ``OUT_CHANNELS`` in and out.
    """
    # Build the bottom-up network before the top block, to keep the order in
    # which module initializers draw random numbers.
    bottom_up = build_timm_backbone(cfg, input_shape)
    fpn_cfg = cfg.MODEL.FPN
    channels = fpn_cfg.OUT_CHANNELS
    top_block = LastLevelP6P7_P5(channels, channels)
    return FPN(
        bottom_up=bottom_up,
        in_features=fpn_cfg.IN_FEATURES,
        out_channels=channels,
        norm=fpn_cfg.NORM,
        top_block=top_block,
        fuse_type=fpn_cfg.FUSE_TYPE,
    )
|
|
@BACKBONE_REGISTRY.register()
def build_p35_timm_fpn_backbone(cfg, input_shape):
    """Build a timm bottom-up backbone with a plain FPN (no extra top levels).

    The FPN is configured from ``cfg.MODEL.FPN`` (``IN_FEATURES``,
    ``OUT_CHANNELS``, ``NORM``, ``FUSE_TYPE``), with ``top_block=None``.
    """
    bottom_up = build_timm_backbone(cfg, input_shape)
    fpn_cfg = cfg.MODEL.FPN
    return FPN(
        bottom_up=bottom_up,
        in_features=fpn_cfg.IN_FEATURES,
        out_channels=fpn_cfg.OUT_CHANNELS,
        norm=fpn_cfg.NORM,
        top_block=None,
        fuse_type=fpn_cfg.FUSE_TYPE,
    )