| """ |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from collections import OrderedDict |
|
|
| from .common import get_activation, FrozenBatchNorm2d |
|
|
| from ..core import register |
| import os |
|
|
|
|
# Public API of this module: only the backbone class is exported.
__all__ = ['PResNet']
|
|
|
|
# Number of residual blocks in each of the four stages, keyed by network depth.
# PResNet indexes this table with its ``depth`` argument.
ResNet_cfg = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
}
|
|
|
|
# Download location of the pretrained weights for each supported depth.
# NOTE: the identifier is misspelled ("donwload"), but it is kept as-is because
# other code looks the table up under this exact name.
donwload_url = {
    18: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth',
    34: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth',
    50: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth',
    101: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth',
}
|
|
# File name of the pretrained checkpoint for each supported depth; these match
# the last path component of the corresponding download URL.
local_weights = {
    18: "ResNet18_vd_pretrained_from_paddle.pth",
    34: "ResNet34_vd_pretrained_from_paddle.pth",
    50: "ResNet50_vd_ssld_v2_pretrained_from_paddle.pth",
    101: "ResNet101_vd_ssld_pretrained_from_paddle.pth",
}
|
|
class ConvNormLayer(nn.Module):
    """Conv2d followed by BatchNorm2d and an optional activation.

    Args:
        ch_in: Number of input channels of the convolution.
        ch_out: Number of output channels (also the BatchNorm feature count).
        kernel_size: Convolution kernel size; an int is expected because the
            default padding is derived from it arithmetically.
        stride: Convolution stride.
        padding: Explicit padding. When ``None`` it defaults to
            ``(kernel_size - 1) // 2``.
        bias: Whether the convolution carries a bias term.
        act: Activation spec handed to ``get_activation``. ``None`` means no
            activation (an ``nn.Identity`` is used).
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
        super().__init__()
        self.conv = nn.Conv2d(
            ch_in,
            ch_out,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2 if padding is None else padding,
            bias=bias)
        self.norm = nn.BatchNorm2d(ch_out)
        # Guard ``None`` here, exactly as the residual blocks in this file do,
        # so the default path does not depend on how get_activation treats None.
        self.act = nn.Identity() if act is None else get_activation(act)

    def forward(self, x):
        """Apply convolution, normalisation and activation in that order."""
        return self.act(self.norm(self.conv(x)))
|
|
|
|
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3) with channel expansion 1.

    Args:
        ch_in: Input channel count.
        ch_out: Output channel count.
        stride: Stride of the first 3x3 convolution.
        shortcut: ``True`` to add the input unchanged; ``False`` to send it
            through a learned projection first.
        act: Activation spec used after the first convolution and after the
            residual sum; ``None`` disables both.
        variant: With ``'d'`` and ``stride == 2`` the projection is a 2x2
            average pool followed by a stride-1 1x1 conv; otherwise it is a
            1x1 conv carrying ``stride``.
    """

    expansion = 1

    def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
        super().__init__()

        self.shortcut = shortcut

        # The projection is created before the main branch so that submodules
        # are registered in the same order as before.
        if not shortcut:
            if stride == 2 and variant == 'd':
                projection = nn.Sequential()
                projection.add_module('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True))
                projection.add_module('conv', ConvNormLayer(ch_in, ch_out, 1, 1))
            else:
                projection = ConvNormLayer(ch_in, ch_out, 1, stride)
            self.short = projection

        self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act)
        self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None)
        self.act = get_activation(act) if act is not None else nn.Identity()

    def forward(self, x):
        """Return ``act(branch(x) + skip(x))``."""
        main = self.branch2b(self.branch2a(x))
        skip = x if self.shortcut else self.short(x)
        return self.act(main + skip)
|
|
|
|
class BottleNeck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1), expansion 4.

    Args:
        ch_in: Input channel count.
        ch_out: Width of the two inner convolutions; the block outputs
            ``ch_out * expansion`` channels.
        stride: Spatial stride of the block.
        shortcut: ``True`` to add the input unchanged; ``False`` to send it
            through a learned projection first.
        act: Activation spec used after the first two convolutions and after
            the residual sum; ``None`` disables them.
        variant: ``'a'`` places the stride on the first 1x1 conv, every other
            value places it on the 3x3 conv. With ``'d'`` and ``stride == 2``
            the projection is a 2x2 average pool plus a stride-1 1x1 conv.
    """

    expansion = 4

    def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
        super().__init__()

        stride1, stride2 = (stride, 1) if variant == 'a' else (1, stride)
        mid_channels = ch_out
        out_channels = ch_out * self.expansion

        self.branch2a = ConvNormLayer(ch_in, mid_channels, 1, stride1, act=act)
        self.branch2b = ConvNormLayer(mid_channels, mid_channels, 3, stride2, act=act)
        self.branch2c = ConvNormLayer(mid_channels, out_channels, 1, 1)

        self.shortcut = shortcut
        if not shortcut:
            if stride == 2 and variant == 'd':
                projection = nn.Sequential()
                projection.add_module('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True))
                projection.add_module('conv', ConvNormLayer(ch_in, out_channels, 1, 1))
            else:
                projection = ConvNormLayer(ch_in, out_channels, 1, stride)
            self.short = projection

        self.act = get_activation(act) if act is not None else nn.Identity()

    def forward(self, x):
        """Return ``act(branch(x) + skip(x))``."""
        main = self.branch2c(self.branch2b(self.branch2a(x)))
        skip = x if self.shortcut else self.short(x)
        return self.act(main + skip)
|
|
|
|
class Blocks(nn.Module):
    """One ResNet stage: ``count`` residual blocks applied in sequence.

    Args:
        block: Block factory exposing an ``expansion`` attribute; called as
            ``block(ch_in, ch_out, stride=..., shortcut=..., variant=..., act=...)``.
        ch_in: Channel count entering the stage.
        ch_out: Base channel count handed to every block.
        count: Number of blocks in the stage.
        stage_num: Stage number; the leading block uses stride 2 unless this
            equals 2.
        act: Activation spec forwarded to every block.
        variant: Variant flag forwarded to every block.
    """

    def __init__(self, block, ch_in, ch_out, count, stage_num, act='relu', variant='b'):
        super().__init__()

        layers = []
        in_channels = ch_in
        for index in range(count):
            leading = index == 0
            layers.append(
                block(
                    in_channels,
                    ch_out,
                    stride=2 if leading and stage_num != 2 else 1,
                    shortcut=not leading,
                    variant=variant,
                    act=act,
                )
            )
            # After the leading block every input already has the expanded width.
            if leading:
                in_channels = ch_out * block.expansion

        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        for layer in self.blocks:
            x = layer(x)
        return x
|
|
|
|
@register()
class PResNet(nn.Module):
    """ResNet backbone that returns feature maps from selected stages.

    Args:
        depth: Network depth; must be a key of ``ResNet_cfg``. Depths of 50
            and above use ``BottleNeck`` blocks, smaller ones ``BasicBlock``.
        variant: ``'c'`` or ``'d'`` build a stem of three 3x3 convolutions;
            any other value builds a single 7x7 stem convolution. The value
            is also forwarded to every residual block.
        num_stages: Number of residual stages to build.
        return_idx: Indices of the stages whose outputs ``forward`` returns.
        act: Activation spec forwarded to the stem and to every block.
        freeze_at: When ``>= 0`` the stem and the first
            ``min(freeze_at, num_stages)`` stages get ``requires_grad=False``.
        freeze_norm: When true, every ``nn.BatchNorm2d`` in the model is
            replaced by a ``FrozenBatchNorm2d``.
        pretrained: Falsy to skip weight loading. ``True`` or a string
            containing ``'http'`` downloads the checkpoint for ``depth``;
            any other string is treated as a checkpoint path. In all cases a
            checkpoint already present in ``local_model_dir`` takes priority.
        local_model_dir: Directory searched for a cached checkpoint and used
            as the download target.

    Attributes:
        out_channels: Channel count of each returned feature map.
        out_strides: Stride (4, 8, 16 or 32) of each returned feature map.
    """

    def __init__(
        self,
        depth,
        variant='d',
        num_stages=4,
        return_idx=[0, 1, 2, 3],
        act='relu',
        freeze_at=-1,
        freeze_norm=True,
        pretrained=False,
        local_model_dir='weights/resnets',
    ):
        super().__init__()

        block_nums = ResNet_cfg[depth]
        ch_in = 64
        if variant in ['c', 'd']:
            conv_def = [
                [3, ch_in // 2, 3, 2, "conv1_1"],
                [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
                [ch_in // 2, ch_in, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, ch_in, 7, 2, "conv1_1"]]

        self.conv1 = nn.Sequential(OrderedDict([
            (name, ConvNormLayer(cin, cout, k, s, act=act)) for cin, cout, k, s, name in conv_def
        ]))

        ch_out_list = [64, 128, 256, 512]
        block = BottleNeck if depth >= 50 else BasicBlock

        _out_channels = [block.expansion * v for v in ch_out_list]
        _out_strides = [4, 8, 16, 32]

        self.res_layers = nn.ModuleList()
        for i in range(num_stages):
            stage_num = i + 2
            self.res_layers.append(
                Blocks(block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant)
            )
            ch_in = _out_channels[i]

        # Store a copy so the instance never aliases the shared default list.
        self.return_idx = list(return_idx)
        self.out_channels = [_out_channels[_i] for _i in self.return_idx]
        self.out_strides = [_out_strides[_i] for _i in self.return_idx]

        if freeze_at >= 0:
            self._freeze_parameters(self.conv1)
            for i in range(min(freeze_at, num_stages)):
                self._freeze_parameters(self.res_layers[i])

        # Norm layers are swapped before any checkpoint is loaded, so the
        # loaded state dict fills the replacement layers.
        if freeze_norm:
            self._freeze_norm(self)

        if pretrained:
            # Join with a path separator: plain string concatenation produced
            # e.g. 'weights/resnetsResNet18_...pth', so the cached file was
            # never found unless local_model_dir ended with a slash.
            model_path = os.path.join(local_model_dir, local_weights[depth])
            # SECURITY NOTE: torch.load unpickles its input; only point
            # `pretrained` / `local_model_dir` at trusted checkpoints.
            if os.path.exists(model_path):
                state = torch.load(model_path, map_location='cpu')
                print(f"Loaded PResNet{depth} from local file@{model_path}.")
            elif isinstance(pretrained, bool) or 'http' in pretrained:
                # NOTE(review): an http string in `pretrained` only selects
                # this branch; the URL fetched is always donwload_url[depth].
                state = torch.hub.load_state_dict_from_url(
                    donwload_url[depth], map_location='cpu', model_dir=local_model_dir)
            else:
                state = torch.load(pretrained, map_location='cpu')
            self.load_state_dict(state)
            print(f'Load PResNet{depth} state_dict')

    def _freeze_parameters(self, m: nn.Module):
        """Disable gradients for every parameter of ``m``."""
        for p in m.parameters():
            p.requires_grad = False

    def _freeze_norm(self, m: nn.Module):
        """Recursively replace ``nn.BatchNorm2d`` layers under ``m``.

        Each replacement is a new ``FrozenBatchNorm2d`` built from
        ``num_features`` only; tensors of the replaced layer are not copied.
        Returns the replacement when ``m`` itself is a BatchNorm2d, else ``m``.
        """
        if isinstance(m, nn.BatchNorm2d):
            m = FrozenBatchNorm2d(m.num_features)
        else:
            for name, child in m.named_children():
                _child = self._freeze_norm(child)
                if _child is not child:
                    setattr(m, name, _child)
        return m

    def forward(self, x):
        """Run the stem, a 3x3 stride-2 max pool and the residual stages.

        Returns:
            A list with the output of every stage listed in ``return_idx``,
            in stage order.
        """
        conv1 = self.conv1(x)
        x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs
|
|