| | |
| | from typing import List, Sequence |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from mmcv.cnn import ConvModule |
| | from mmengine.model import BaseModule |
| |
|
| | from mmdet.registry import MODELS |
| | from mmdet.utils import ConfigType, OptMultiConfig |
| | from ..layers import ResLayer |
| | from .resnet import BasicBlock |
| |
|
| |
|
class HourglassModule(BaseModule):
    """Hourglass Module for HourglassNet backbone.

    Generate module recursively and use BasicBlock as the base unit.

    Each module has two branches that are summed in ``forward``:

    * ``up1`` keeps the input resolution and channel count.
    * ``low1 -> low2 -> low3`` downsamples with stride 2, processes the
      features (``low2`` is another ``HourglassModule`` while
      ``depth > 1``, otherwise a plain ``ResLayer``), maps them back to the
      input channel count and is then interpolated before the sum.

    Args:
        depth (int): Depth of current HourglassModule.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule. Indices 0 and 1 are read here;
            ``stage_channels[1:]`` is handed to the nested module.
        stage_blocks (list[int]): Number of sub-modules stacked in current and
            follow-up HourglassModule. Indices 0 and 1 are read here;
            ``stage_blocks[1:]`` is handed to the nested module.
        norm_cfg (ConfigType): Dictionary to construct and config norm layer.
            Defaults to `dict(type='BN', requires_grad=True)`
        upsample_cfg (ConfigType): Config dict for interpolate layer.
            Defaults to `dict(mode='nearest')`
        init_cfg (dict or ConfigDict, optional): the config to control the
            initialization.
    """

    def __init__(self,
                 depth: int,
                 stage_channels: List[int],
                 stage_blocks: List[int],
                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
                 upsample_cfg: ConfigType = dict(mode='nearest'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)

        self.depth = depth

        cur_block = stage_blocks[0]
        next_block = stage_blocks[1]

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        # Skip branch: same resolution, same channel count.
        self.up1 = ResLayer(
            BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)

        # Downsampling branch entry: stride 2, cur_channel -> next_channel.
        self.low1 = ResLayer(
            BasicBlock,
            cur_channel,
            next_channel,
            cur_block,
            stride=2,
            norm_cfg=norm_cfg)

        if self.depth > 1:
            # Forward norm_cfg / upsample_cfg so that every nested level is
            # built with the same settings as this one instead of silently
            # falling back to the defaults.
            self.low2 = HourglassModule(
                depth - 1,
                stage_channels[1:],
                stage_blocks[1:],
                norm_cfg=norm_cfg,
                upsample_cfg=upsample_cfg)
        else:
            # Innermost level: a plain residual stack ends the recursion.
            self.low2 = ResLayer(
                BasicBlock,
                next_channel,
                next_channel,
                next_block,
                norm_cfg=norm_cfg)

        # Map back to cur_channel so the result can be added to ``up1``.
        self.low3 = ResLayer(
            BasicBlock,
            next_channel,
            cur_channel,
            cur_block,
            norm_cfg=norm_cfg,
            downsample_first=False)

        self.up2 = F.interpolate
        self.upsample_cfg = upsample_cfg

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: Sum of the skip branch (``up1``) and the
            interpolated downsampling branch.
        """
        up1 = self.up1(x)
        low1 = self.low1(x)
        low2 = self.low2(low1)
        low3 = self.low3(low2)

        if 'scale_factor' in self.upsample_cfg:
            # An explicit scale factor was configured: use it as given.
            up2 = self.up2(low3, **self.upsample_cfg)
        else:
            # Otherwise resize to the spatial size of ``up1`` so that the
            # two branches can always be added.
            shape = up1.shape[2:]
            up2 = self.up2(low3, size=shape, **self.upsample_cfg)
        return up1 + up2
| |
|
| |
|
@MODELS.register_module()
class HourglassNet(BaseModule):
    """HourglassNet backbone.

    Stacked Hourglass Networks for Human Pose Estimation.
    More details can be found in the `paper
    <https://arxiv.org/abs/1603.06937>`_ .

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104.
        stage_channels (Sequence[int]): Feature channel of each sub-module in a
            HourglassModule.
        stage_blocks (Sequence[int]): Number of sub-modules stacked in a
            HourglassModule.
        feat_channel (int): Feature channel of conv after a HourglassModule.
        norm_cfg (ConfigType): Dictionary to construct and config norm layer.
            It is used for the stem, every HourglassModule and the convs
            that connect consecutive stacks.
        init_cfg (dict or ConfigDict, optional): the config to control the
            initialization. Must be ``None``.

    Example:
        >>> from mmdet.models import HourglassNet
        >>> import torch
        >>> self = HourglassNet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 256, 128, 128)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 downsample_times: int = 5,
                 num_stacks: int = 2,
                 stage_channels: Sequence = (256, 256, 384, 384, 384, 512),
                 stage_blocks: Sequence = (2, 2, 2, 2, 2, 4),
                 feat_channel: int = 256,
                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
                 init_cfg: OptMultiConfig = None) -> None:
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg)

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) == len(stage_blocks)
        # Each HourglassModule level reads entries 0 and 1 of the sequences
        # it receives, so ``downsample_times`` levels need at least
        # ``downsample_times + 1`` entries.
        assert len(stage_channels) > downsample_times

        cur_channel = stage_channels[0]

        # Two stride-2 stages: a 7x7 conv followed by one residual block.
        self.stem = nn.Sequential(
            ConvModule(
                3, cur_channel // 2, 7, padding=3, stride=2,
                norm_cfg=norm_cfg),
            ResLayer(
                BasicBlock,
                cur_channel // 2,
                cur_channel,
                1,
                stride=2,
                norm_cfg=norm_cfg))

        # Pass norm_cfg through so the hourglass stacks use the same norm
        # layer as the rest of the backbone.
        self.hourglass_modules = nn.ModuleList([
            HourglassModule(
                downsample_times,
                stage_channels,
                stage_blocks,
                norm_cfg=norm_cfg) for _ in range(num_stacks)
        ])

        # Indexed per stack in ``forward`` (``self.inters[ind]``).
        self.inters = ResLayer(
            BasicBlock,
            cur_channel,
            cur_channel,
            num_stacks - 1,
            norm_cfg=norm_cfg)

        self.conv1x1s = nn.ModuleList([
            ConvModule(
                cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.out_convs = nn.ModuleList([
            ConvModule(
                cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
            for _ in range(num_stacks)
        ])

        self.remap_convs = nn.ModuleList([
            ConvModule(
                feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self) -> None:
        """Init module weights.

        Runs the base-class initialization first and then calls
        ``reset_parameters()`` on every ``nn.Conv2d`` in the network.
        """
        super().init_weights()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.reset_parameters()

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Forward function.

        Args:
            x (torch.Tensor): Input image batch; the stem expects 3 input
                channels.

        Returns:
            list[torch.Tensor]: One output feature map per hourglass stack,
            in stack order.
        """
        inter_feat = self.stem(x)
        out_feats = []
        last_stack = self.num_stacks - 1

        for ind in range(self.num_stacks):
            hourglass_feat = self.hourglass_modules[ind](inter_feat)
            out_feat = self.out_convs[ind](hourglass_feat)
            out_feats.append(out_feat)

            if ind < last_stack:
                # Build the input of the next stack: fuse the current input
                # with the remapped output, then apply ReLU and one
                # residual block.
                inter_feat = self.conv1x1s[ind](
                    inter_feat) + self.remap_convs[ind](out_feat)
                inter_feat = self.inters[ind](self.relu(inter_feat))

        return out_feats
| |
|