| |
| from typing import List, Sequence |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from mmcv.cnn import ConvModule |
| from mmengine.model import BaseModule |
|
|
| from mmdet.registry import MODELS |
| from mmdet.utils import ConfigType, OptMultiConfig |
| from ..layers import ResLayer |
| from .resnet import BasicBlock |
|
|
|
|
class HourglassModule(BaseModule):
    """Hourglass Module for HourglassNet backbone.

    Generate module recursively and use BasicBlock as the base unit.

    Args:
        depth (int): Depth of current HourglassModule. When ``depth > 1`` a
            nested HourglassModule of depth ``depth - 1`` is built on the
            remaining stages.
        stage_channels (list[int]): Feature channels of sub-modules in the
            current and follow-up HourglassModules. Must contain at least
            ``depth + 1`` entries.
        stage_blocks (list[int]): Number of sub-modules stacked in the current
            and follow-up HourglassModules. Must contain at least
            ``depth + 1`` entries.
        norm_cfg (ConfigType): Dictionary to construct and config norm layer.
            It is also forwarded to the nested HourglassModule.
            Defaults to `dict(type='BN', requires_grad=True)`.
        upsample_cfg (ConfigType): Config dict for the interpolate layer
            (keyword arguments of ``F.interpolate``). It is also forwarded to
            the nested HourglassModule. Defaults to `dict(mode='nearest')`.
        init_cfg (dict or ConfigDict, optional): the config to control the
            initialization.
    """

    def __init__(self,
                 depth: int,
                 stage_channels: List[int],
                 stage_blocks: List[int],
                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
                 upsample_cfg: ConfigType = dict(mode='nearest'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)

        self.depth = depth

        cur_block = stage_blocks[0]
        next_block = stage_blocks[1]

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        # Skip branch: stays at the input resolution and channel count.
        self.up1 = ResLayer(
            BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)

        # Low branch entry: stride-2 downsample and switch to next channels.
        self.low1 = ResLayer(
            BasicBlock,
            cur_channel,
            next_channel,
            cur_block,
            stride=2,
            norm_cfg=norm_cfg)

        if self.depth > 1:
            # Recurse on the remaining stages. The configs are forwarded so
            # every level of the hourglass uses the same norm layer and
            # upsampling mode as the caller requested.
            self.low2 = HourglassModule(
                depth - 1,
                stage_channels[1:],
                stage_blocks[1:],
                norm_cfg=norm_cfg,
                upsample_cfg=upsample_cfg)
        else:
            # Innermost level: plain residual layer instead of recursion.
            self.low2 = ResLayer(
                BasicBlock,
                next_channel,
                next_channel,
                next_block,
                norm_cfg=norm_cfg)

        # Map back from next_channel to cur_channel before upsampling.
        self.low3 = ResLayer(
            BasicBlock,
            next_channel,
            cur_channel,
            cur_block,
            norm_cfg=norm_cfg,
            downsample_first=False)

        self.up2 = F.interpolate
        self.upsample_cfg = upsample_cfg

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: Sum of the skip branch output and the upsampled
            output of the low-resolution branch.
        """
        up1 = self.up1(x)
        low1 = self.low1(x)
        low2 = self.low2(low1)
        low3 = self.low3(low2)

        # If the config gives an explicit scale_factor use it as-is;
        # otherwise resize to the skip branch's spatial size so the two
        # tensors can always be added element-wise.
        if 'scale_factor' in self.upsample_cfg:
            up2 = self.up2(low3, **self.upsample_cfg)
        else:
            shape = up1.shape[2:]
            up2 = self.up2(low3, size=shape, **self.upsample_cfg)
        return up1 + up2
|
|
|
|
@MODELS.register_module()
class HourglassNet(BaseModule):
    """HourglassNet backbone.

    Stacked Hourglass Networks for Human Pose Estimation.
    More details can be found in the `paper
    <https://arxiv.org/abs/1603.06937>`_ .

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104.
        stage_channels (Sequence[int]): Feature channel of each sub-module in a
            HourglassModule. Must be longer than ``downsample_times``.
        stage_blocks (Sequence[int]): Number of sub-modules stacked in a
            HourglassModule. Must have the same length as ``stage_channels``.
        feat_channel (int): Feature channel of conv after a HourglassModule.
        norm_cfg (ConfigType): Dictionary to construct and config norm layer.
            Applied to the stem, every HourglassModule and all intermediate
            convs. Defaults to `dict(type='BN', requires_grad=True)`.
        init_cfg (dict or ConfigDict, optional): the config to control the
            initialization. Must be None; see ``init_weights``.

    Example:
        >>> from mmdet.models import HourglassNet
        >>> import torch
        >>> self = HourglassNet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 256, 128, 128)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 downsample_times: int = 5,
                 num_stacks: int = 2,
                 stage_channels: Sequence = (256, 256, 384, 384, 384, 512),
                 stage_blocks: Sequence = (2, 2, 2, 2, 2, 4),
                 feat_channel: int = 256,
                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
                 init_cfg: OptMultiConfig = None) -> None:
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg)

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) == len(stage_blocks)
        # Each recursion level of HourglassModule consumes one extra stage.
        assert len(stage_channels) > downsample_times

        cur_channel = stage_channels[0]

        # Stem: 7x7 stride-2 conv followed by a stride-2 residual layer.
        self.stem = nn.Sequential(
            ConvModule(
                3, cur_channel // 2, 7, padding=3, stride=2,
                norm_cfg=norm_cfg),
            ResLayer(
                BasicBlock,
                cur_channel // 2,
                cur_channel,
                1,
                stride=2,
                norm_cfg=norm_cfg))

        # Forward norm_cfg so the hourglass stacks honour the same
        # normalization config as the rest of the backbone.
        self.hourglass_modules = nn.ModuleList([
            HourglassModule(
                downsample_times,
                stage_channels,
                stage_blocks,
                norm_cfg=norm_cfg) for _ in range(num_stacks)
        ])

        # NOTE(review): when num_stacks == 1, `inters` is still constructed
        # here but never used in `forward`; check how ResLayer handles
        # num_blocks == 0 before relying on this configuration.
        self.inters = ResLayer(
            BasicBlock,
            cur_channel,
            cur_channel,
            num_stacks - 1,
            norm_cfg=norm_cfg)

        self.conv1x1s = nn.ModuleList([
            ConvModule(
                cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.out_convs = nn.ModuleList([
            ConvModule(
                cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
            for _ in range(num_stacks)
        ])

        self.remap_convs = nn.ModuleList([
            ConvModule(
                feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self) -> None:
        """Init module weights.

        Runs the default initialization, then restores PyTorch's own default
        initialization on every ``nn.Conv2d``.
        """
        super().init_weights()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.reset_parameters()

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Forward function.

        Args:
            x (torch.Tensor): Input image batch.

        Returns:
            list[torch.Tensor]: One ``feat_channel``-channel feature map per
            hourglass stack, in stacking order.
        """
        inter_feat = self.stem(x)
        out_feats = []

        for ind, (single_hourglass, out_conv) in enumerate(
                zip(self.hourglass_modules, self.out_convs)):
            hourglass_feat = single_hourglass(inter_feat)
            out_feat = out_conv(hourglass_feat)
            out_feats.append(out_feat)

            # Between stacks, fuse the previous input with the remapped
            # output to build the next stack's input.
            if ind < self.num_stacks - 1:
                inter_feat = self.conv1x1s[ind](
                    inter_feat) + self.remap_convs[ind](
                        out_feat)
                inter_feat = self.inters[ind](self.relu(inter_feat))

        return out_feats
|
|