| """ |
| reference |
| - https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py |
| |
| Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import os |
| from .common import FrozenBatchNorm2d |
| from ..core import register |
| import logging |
|
|
| |
# Module-level shorthands for the torch initializers.
kaiming_normal_, zeros_, ones_ = (
    nn.init.kaiming_normal_,
    nn.init.zeros_,
    nn.init.ones_,
)
|
|
# Names exported by a star-import of this module: only the backbone class.
__all__ = ['HGNetv2']
|
|
|
|
class LearnableAffineBlock(nn.Module):
    """Element-wise affine transform ``scale * x + bias`` with learnable scalars.

    Args:
        scale_value: Initial value of the one-element ``scale`` parameter.
        bias_value: Initial value of the one-element ``bias`` parameter.
    """

    def __init__(self, scale_value=1.0, bias_value=0.0):
        super().__init__()
        # Each parameter holds a single element that broadcasts over the input.
        initial_scale = torch.tensor([scale_value])
        initial_bias = torch.tensor([bias_value])
        self.scale = nn.Parameter(initial_scale, requires_grad=True)
        self.bias = nn.Parameter(initial_bias, requires_grad=True)

    def forward(self, x):
        """Return ``x`` multiplied by ``scale`` and shifted by ``bias``."""
        scaled = self.scale * x
        return scaled + self.bias
|
|
|
|
class ConvBNAct(nn.Module):
    """Convolution -> BatchNorm2d -> optional ReLU -> optional learnable affine.

    Args:
        in_chs: Number of input channels.
        out_chs: Number of output channels.
        kernel_size: Convolution kernel size (an int; also used to derive the
            symmetric padding).
        stride: Convolution stride.
        groups: Number of convolution groups.
        padding: ``'same'`` zero-pads one pixel on the right and bottom and then
            runs an unpadded convolution; any other value uses a symmetric
            padding of ``(kernel_size - 1) // 2``.
        use_act: Apply ReLU after the batch norm.
        use_lab: Append a ``LearnableAffineBlock``; only takes effect when
            ``use_act`` is also true.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        kernel_size,
        stride=1,
        groups=1,
        padding='',
        use_act=True,
        use_lab=False
    ):
        super().__init__()
        self.use_act = use_act
        self.use_lab = use_lab

        if padding == 'same':
            pad_right_bottom = nn.ZeroPad2d([0, 1, 0, 1])
            unpadded_conv = nn.Conv2d(
                in_chs, out_chs, kernel_size, stride, groups=groups, bias=False
            )
            self.conv = nn.Sequential(pad_right_bottom, unpadded_conv)
        else:
            symmetric_pad = (kernel_size - 1) // 2
            self.conv = nn.Conv2d(
                in_chs,
                out_chs,
                kernel_size,
                stride,
                padding=symmetric_pad,
                groups=groups,
                bias=False
            )

        self.bn = nn.BatchNorm2d(out_chs)
        self.act = nn.ReLU() if self.use_act else nn.Identity()

        # The affine block is only attached on top of an activation.
        wants_lab = self.use_act and self.use_lab
        self.lab = LearnableAffineBlock() if wants_lab else nn.Identity()

    def forward(self, x):
        """Run conv, batch norm, activation and affine block in that order."""
        normalized = self.bn(self.conv(x))
        return self.lab(self.act(normalized))
|
|
|
|
class LightConvBNAct(nn.Module):
    """1x1 ``ConvBNAct`` without activation followed by a depthwise ``ConvBNAct``.

    Args:
        in_chs: Number of input channels.
        out_chs: Number of output channels (also the depthwise group count).
        kernel_size: Kernel size of the depthwise convolution.
        groups: Accepted for interface compatibility; not used by this block.
        use_lab: Forwarded to both ``ConvBNAct`` sub-blocks.
    """

    def __init__(self, in_chs, out_chs, kernel_size, groups=1, use_lab=False):
        super().__init__()
        # Pointwise projection: changes the channel count, no activation.
        pointwise = ConvBNAct(
            in_chs, out_chs, kernel_size=1, use_act=False, use_lab=use_lab
        )
        # Depthwise convolution: one group per channel, followed by ReLU.
        depthwise = ConvBNAct(
            out_chs,
            out_chs,
            kernel_size=kernel_size,
            groups=out_chs,
            use_act=True,
            use_lab=use_lab,
        )
        self.conv1 = pointwise
        self.conv2 = depthwise

    def forward(self, x):
        """Apply the pointwise block, then the depthwise block."""
        return self.conv2(self.conv1(x))
|
|
|
|
class StemBlock(nn.Module):
    """Stem made of two stride-2 convolutions around a conv/max-pool dual branch.

    Args:
        in_chs: Number of input channels.
        mid_chs: Channel width used inside the stem.
        out_chs: Number of output channels.
        use_lab: Forwarded to every ``ConvBNAct`` in the stem.
    """

    def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
        super().__init__()
        half_chs = mid_chs // 2
        self.stem1 = ConvBNAct(
            in_chs, mid_chs, kernel_size=3, stride=2, use_lab=use_lab
        )
        self.stem2a = ConvBNAct(
            mid_chs, half_chs, kernel_size=2, stride=1, use_lab=use_lab
        )
        self.stem2b = ConvBNAct(
            half_chs, mid_chs, kernel_size=2, stride=1, use_lab=use_lab
        )
        # stem3 consumes the concatenation of the pooled and convolved branches.
        self.stem3 = ConvBNAct(
            mid_chs * 2, mid_chs, kernel_size=3, stride=2, use_lab=use_lab
        )
        self.stem4 = ConvBNAct(
            mid_chs, out_chs, kernel_size=1, stride=1, use_lab=use_lab
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)

    def forward(self, x):
        """Run the stem and return the fused feature map."""
        # One extra pixel on the right and bottom ahead of each 2x2 convolution.
        right_bottom = (0, 1, 0, 1)

        stem = F.pad(self.stem1(x), right_bottom)

        conv_branch = self.stem2a(stem)
        conv_branch = self.stem2b(F.pad(conv_branch, right_bottom))
        pool_branch = self.pool(stem)

        fused = torch.cat([pool_branch, conv_branch], dim=1)
        return self.stem4(self.stem3(fused))
|
|
|
|
class EseModule(nn.Module):
    """Channel gate: spatial mean -> 1x1 conv -> sigmoid, multiplied onto the input.

    Args:
        chs: Number of input (and output) channels.
    """

    def __init__(self, chs):
        super().__init__()
        self.conv = nn.Conv2d(chs, chs, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by its per-channel sigmoid gate."""
        # Average over the two trailing dims, keeping them so the gate broadcasts.
        pooled = x.mean((2, 3), keepdim=True)
        gate = self.sigmoid(self.conv(pooled))
        return torch.mul(x, gate)
|
|
|
|
class HG_Block(nn.Module):
    """Stack of conv layers whose outputs are concatenated and aggregated.

    Args:
        in_chs: Number of input channels.
        mid_chs: Output channels of every stacked layer.
        out_chs: Output channels after aggregation.
        layer_num: Number of stacked layers.
        kernel_size: Kernel size of the stacked layers.
        residual: Add the block input to the aggregated output.
        light_block: Use ``LightConvBNAct`` layers instead of ``ConvBNAct``.
        use_lab: Forwarded to every ``ConvBNAct`` / ``LightConvBNAct``.
        agg: ``'se'`` aggregates with two 1x1 ``ConvBNAct`` blocks; any other
            value uses one 1x1 ``ConvBNAct`` followed by ``EseModule``.
        drop_path: Probability passed to ``nn.Dropout`` on the residual branch;
            a falsy value disables it.
    """

    def __init__(
        self,
        in_chs,
        mid_chs,
        out_chs,
        layer_num,
        kernel_size=3,
        residual=False,
        light_block=False,
        use_lab=False,
        agg='ese',
        drop_path=0.,
    ):
        super().__init__()
        self.residual = residual

        def build_layer(layer_in_chs):
            # Every stacked layer emits ``mid_chs`` channels.
            if light_block:
                return LightConvBNAct(
                    layer_in_chs,
                    mid_chs,
                    kernel_size=kernel_size,
                    use_lab=use_lab,
                )
            return ConvBNAct(
                layer_in_chs,
                mid_chs,
                kernel_size=kernel_size,
                stride=1,
                use_lab=use_lab,
            )

        self.layers = nn.ModuleList(
            [build_layer(in_chs if idx == 0 else mid_chs) for idx in range(layer_num)]
        )

        # The aggregation sees the block input plus the output of every layer.
        concat_chs = in_chs + layer_num * mid_chs
        if agg == 'se':
            squeezed_chs = out_chs // 2
            aggregation_parts = [
                ConvBNAct(
                    concat_chs, squeezed_chs, kernel_size=1, stride=1, use_lab=use_lab
                ),
                ConvBNAct(
                    squeezed_chs, out_chs, kernel_size=1, stride=1, use_lab=use_lab
                ),
            ]
        else:
            aggregation_parts = [
                ConvBNAct(
                    concat_chs, out_chs, kernel_size=1, stride=1, use_lab=use_lab
                ),
                EseModule(out_chs),
            ]
        self.aggregation = nn.Sequential(*aggregation_parts)

        self.drop_path = nn.Dropout(drop_path) if drop_path else nn.Identity()

    def forward(self, x):
        """Return the aggregated features, plus the input when ``residual`` is set."""
        features = [x]
        current = x
        for layer in self.layers:
            current = layer(current)
            features.append(current)

        aggregated = self.aggregation(torch.cat(features, dim=1))
        if not self.residual:
            return aggregated
        return self.drop_path(aggregated) + x
|
|
|
|
class HG_Stage(nn.Module):
    """Optional depthwise stride-2 downsample followed by a chain of ``HG_Block``.

    Args:
        in_chs: Number of input channels.
        mid_chs: Layer width inside each block.
        out_chs: Output channels of each block.
        block_num: Number of ``HG_Block`` instances.
        layer_num: Number of layers per block.
        downsample: Insert the stride-2 depthwise ``ConvBNAct`` before the blocks.
        light_block: Forwarded to every block.
        kernel_size: Forwarded to every block.
        use_lab: Forwarded to the downsample conv and every block.
        agg: Aggregation type forwarded to every block.
        drop_path: A single value shared by all blocks, or a list/tuple indexed
            by block position.
    """

    def __init__(
        self,
        in_chs,
        mid_chs,
        out_chs,
        block_num,
        layer_num,
        downsample=True,
        light_block=False,
        kernel_size=3,
        use_lab=False,
        agg='se',
        drop_path=0.,
    ):
        super().__init__()

        if downsample:
            # Depthwise 3x3, stride 2, no activation.
            self.downsample = ConvBNAct(
                in_chs,
                in_chs,
                kernel_size=3,
                stride=2,
                groups=in_chs,
                use_act=False,
                use_lab=use_lab,
            )
        else:
            self.downsample = nn.Identity()

        per_block_drop = isinstance(drop_path, (list, tuple))
        stage_blocks = [
            HG_Block(
                # Only the first block changes the channel count ...
                in_chs if idx == 0 else out_chs,
                mid_chs,
                out_chs,
                layer_num,
                # ... so only the later blocks can carry a residual connection.
                residual=idx != 0,
                kernel_size=kernel_size,
                light_block=light_block,
                use_lab=use_lab,
                agg=agg,
                drop_path=drop_path[idx] if per_block_drop else drop_path,
            )
            for idx in range(block_num)
        ]
        self.blocks = nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Downsample (if configured) and run every block in sequence."""
        return self.blocks(self.downsample(x))
|
|
|
|
|
|
@register()
class HGNetv2(nn.Module):
    """HGNetV2 backbone.

    Args:
        name: Architecture key into ``arch_configs`` (``'B0'`` ... ``'B6'``).
        use_lab: Whether to use ``LearnableAffineBlock`` in the network.
        return_idx: Indices of the stages whose outputs ``forward`` returns.
        freeze_stem_only: When ``freeze_at >= 0``, freeze only the stem; if
            false, stages ``0 .. freeze_at`` are frozen as well.
        freeze_at: A negative value disables parameter freezing entirely.
        freeze_norm: Replace every ``nn.BatchNorm2d`` with ``FrozenBatchNorm2d``.
        pretrained: Load the stage-1 checkpoint from ``local_model_dir``,
            downloading it from the architecture's ``url`` if it is missing.
        local_model_dir: Directory holding (or receiving) the checkpoint.

    Raises:
        KeyError: If ``name`` is not a key of ``arch_configs``.
        SystemExit: If the pretrained weights cannot be loaded.
    """

    arch_configs = {
        'B0': {
            'stem_channels': [3, 16, 16],
            'stage_config': {
                # in_channels, mid_channels, out_channels, block_num, downsample, light_block, kernel_size, layer_num
                "stage1": [16, 16, 64, 1, False, False, 3, 3],
                "stage2": [64, 32, 256, 1, True, False, 3, 3],
                "stage3": [256, 64, 512, 2, True, True, 5, 3],
                "stage4": [512, 128, 1024, 1, True, True, 5, 3],
            },
            'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B0_stage1.pth'
        },
        'B1': {
            'stem_channels': [3, 24, 32],
            'stage_config': {
                # in_channels, mid_channels, out_channels, block_num, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 64, 1, False, False, 3, 3],
                "stage2": [64, 48, 256, 1, True, False, 3, 3],
                "stage3": [256, 96, 512, 2, True, True, 5, 3],
                "stage4": [512, 192, 1024, 1, True, True, 5, 3],
            },
            'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B1_stage1.pth'
        },
        'B2': {
            'stem_channels': [3, 24, 32],
            'stage_config': {
                # in_channels, mid_channels, out_channels, block_num, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 96, 1, False, False, 3, 4],
                "stage2": [96, 64, 384, 1, True, False, 3, 4],
                "stage3": [384, 128, 768, 3, True, True, 5, 4],
                "stage4": [768, 256, 1536, 1, True, True, 5, 4],
            },
            'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B2_stage1.pth'
        },
        'B3': {
            'stem_channels': [3, 24, 32],
            'stage_config': {
                # in_channels, mid_channels, out_channels, block_num, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 128, 1, False, False, 3, 5],
                "stage2": [128, 64, 512, 1, True, False, 3, 5],
                "stage3": [512, 128, 1024, 3, True, True, 5, 5],
                "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
            },
            'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B3_stage1.pth'
        },
        'B4': {
            'stem_channels': [3, 32, 48],
            'stage_config': {
                # in_channels, mid_channels, out_channels, block_num, downsample, light_block, kernel_size, layer_num
                "stage1": [48, 48, 128, 1, False, False, 3, 6],
                "stage2": [128, 96, 512, 1, True, False, 3, 6],
                "stage3": [512, 192, 1024, 3, True, True, 5, 6],
                "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
            },
            'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B4_stage1.pth'
        },
        'B5': {
            'stem_channels': [3, 32, 64],
            'stage_config': {
                # in_channels, mid_channels, out_channels, block_num, downsample, light_block, kernel_size, layer_num
                "stage1": [64, 64, 128, 1, False, False, 3, 6],
                "stage2": [128, 128, 512, 2, True, False, 3, 6],
                "stage3": [512, 256, 1024, 5, True, True, 5, 6],
                "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
            },
            'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B5_stage1.pth'
        },
        'B6': {
            'stem_channels': [3, 48, 96],
            'stage_config': {
                # in_channels, mid_channels, out_channels, block_num, downsample, light_block, kernel_size, layer_num
                "stage1": [96, 96, 192, 2, False, False, 3, 6],
                "stage2": [192, 192, 512, 3, True, False, 3, 6],
                "stage3": [512, 384, 1024, 6, True, True, 5, 6],
                "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
            },
            'url': 'https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B6_stage1.pth'
        },
    }

    def __init__(self,
                 name,
                 use_lab=False,
                 return_idx=[1, 2, 3],
                 freeze_stem_only=True,
                 freeze_at=0,
                 freeze_norm=True,
                 pretrained=True,
                 local_model_dir='weight/hgnetv2/'):
        super().__init__()
        self.use_lab = use_lab
        # Copy so neither the shared default list nor a caller-owned list is
        # aliased by this instance.
        self.return_idx = list(return_idx)

        arch = self.arch_configs[name]
        stem_channels = arch['stem_channels']
        stage_config = arch['stage_config']
        download_url = arch['url']

        self._out_strides = [4, 8, 16, 32]
        self._out_channels = [cfg[2] for cfg in stage_config.values()]

        # stem
        self.stem = StemBlock(
            in_chs=stem_channels[0],
            mid_chs=stem_channels[1],
            out_chs=stem_channels[2],
            use_lab=use_lab)

        # stages
        self.stages = nn.ModuleList()
        for (in_channels, mid_channels, out_channels, block_num, downsample,
             light_block, kernel_size, layer_num) in stage_config.values():
            self.stages.append(
                HG_Stage(
                    in_channels,
                    mid_channels,
                    out_channels,
                    block_num,
                    layer_num,
                    downsample,
                    light_block,
                    kernel_size,
                    use_lab))

        if freeze_at >= 0:
            self._freeze_parameters(self.stem)
            if not freeze_stem_only:
                for i in range(min(freeze_at + 1, len(self.stages))):
                    self._freeze_parameters(self.stages[i])

        # Norm layers are swapped before the checkpoint is loaded, so the
        # frozen replacements receive their values from the state dict.
        if freeze_norm:
            self._freeze_norm(self)

        if pretrained:
            self._load_pretrained(name, download_url, local_model_dir)

    @staticmethod
    def _dist_ready():
        """Return True when a torch.distributed process group is initialized."""
        dist = torch.distributed
        return dist.is_available() and dist.is_initialized()

    @classmethod
    def _is_main_process(cls):
        """Return True for rank 0, or when not running distributed at all."""
        return (not cls._dist_ready()) or torch.distributed.get_rank() == 0

    @classmethod
    def _barrier(cls):
        """Synchronize all ranks; a no-op outside distributed runs."""
        if cls._dist_ready():
            torch.distributed.barrier()

    def _load_pretrained(self, name, download_url, local_model_dir):
        """Load the stage-1 checkpoint from disk, downloading it if missing.

        Logs the failure on the main process and raises ``SystemExit(1)`` if
        the weights cannot be loaded.
        """
        RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m"
        try:
            # os.path.join tolerates a directory given without a trailing slash.
            model_path = os.path.join(local_model_dir, 'PPHGNetV2_' + name + '_stage1.pth')
            if os.path.exists(model_path):
                state = torch.load(model_path, map_location='cpu')
                print(f"Loaded stage1 {name} HGNetV2 from local file.")
            else:
                # The file doesn't exist locally: rank 0 downloads it while the
                # other ranks wait, then read the downloaded file.
                # NOTE(review): every rank evaluates os.path.exists on its own,
                # so ranks may disagree if the file appears mid-check.
                if self._is_main_process():
                    print(GREEN + "If the pretrained HGNetV2 can't be downloaded automatically. Please check your network connection." + RESET)
                    print(GREEN + "Please check your network connection. Or download the model manually from " + RESET + f"{download_url}" + GREEN + " to " + RESET + f"{local_model_dir}." + RESET)
                    try:
                        state = torch.hub.load_state_dict_from_url(download_url, map_location='cpu', model_dir=local_model_dir)
                    finally:
                        # Release the waiting ranks even if the download failed,
                        # otherwise they would block on the barrier forever.
                        self._barrier()
                else:
                    self._barrier()
                    # Load the checkpoint file itself, not its directory.
                    state = torch.load(model_path, map_location='cpu')

                print(f"Loaded stage1 {name} HGNetV2 from URL.")

            self.load_state_dict(state)

        except (Exception, KeyboardInterrupt) as e:
            if self._is_main_process():
                print(f"{str(e)}")
                logging.error(RED + "CRITICAL WARNING: Failed to load pretrained HGNetV2 model" + RESET)
                logging.error(GREEN + "Please check your network connection. Or download the model manually from " \
                              + RESET + f"{download_url}" + GREEN + " to " + RESET + f"{local_model_dir}." + RESET)
            # Exit with a failure status so launchers notice the error.
            raise SystemExit(1)

    def _freeze_norm(self, m: nn.Module):
        """Recursively replace ``nn.BatchNorm2d`` with ``FrozenBatchNorm2d``.

        The replacement is built from ``num_features`` only. Returns the
        (possibly replaced) module so parents can re-attach it.
        """
        if isinstance(m, nn.BatchNorm2d):
            m = FrozenBatchNorm2d(m.num_features)
        else:
            for name, child in m.named_children():
                _child = self._freeze_norm(child)
                if _child is not child:
                    setattr(m, name, _child)
        return m

    def _freeze_parameters(self, m: nn.Module):
        """Disable gradients for every parameter of ``m``."""
        for p in m.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return the outputs of the stages listed in ``return_idx``, in order."""
        x = self.stem(x)
        outs = []
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs
|
|