# Source: AI_Fraud_Detector/models/networks/pose_efficientNet.py
# Uploaded by astrosbd ("Upload 78 files"), revision e7824d0 (verified).
#-*- coding: utf-8 -*-
import math
import sys
import os
if not os.getcwd() in sys.path:
sys.path.append(os.getcwd())
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
from ..builder import MODELS, build_model
from .efficientNet import (
round_filters,
round_repeats,
drop_connect,
get_same_padding_conv2d,
get_model_params,
efficientnet_params,
load_pretrained_weights,
Swish,
MemoryEfficientSwish,
calculate_output_image_size,
url_map_advprop,
url_map
)
from .common import (
InceptionBlock,
conv_block,
BN_MOMENTUM,
SELayer
)
# Model names accepted by ``EfficientNet._check_model_name_is_valid``.
# 'efficientnet-l2' is listed so that the architecture can be constructed,
# but it is supported without pretrained weights.
VALID_MODELS = tuple(
    'efficientnet-b{}'.format(variant) for variant in range(9)
) + ('efficientnet-l2',)
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    The block chains, in this order: an optional 1x1 expansion convolution
    (skipped when ``expand_ratio == 1``), a depthwise convolution, an
    optional squeeze-and-excitation gate and a 1x1 projection convolution.
    A residual connection is added when the block keeps both resolution and
    channel count.

    Args:
        block_args (namedtuple): BlockArgs; the fields read here are
            ``input_filters``, ``output_filters``, ``expand_ratio``,
            ``kernel_size``, ``stride``, ``se_ratio`` and ``id_skip``.
        global_params (namedtuple): GlobalParams; only
            ``batch_norm_momentum`` and ``batch_norm_epsilon`` are read.
        image_size (tuple or list): [image_height, image_width].

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # PyTorch's batch-norm momentum is the complement of TensorFlow's.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        se_ratio = self._block_args.se_ratio
        self.has_se = (se_ratio is not None) and (0 < se_ratio <= 1)
        # Whether to use the skip connection (and drop connect) at all.
        self.id_skip = block_args.id_skip

        def make_bn(num_features):
            # Every batch norm in the block shares momentum and epsilon.
            return nn.BatchNorm2d(
                num_features=num_features,
                momentum=self._bn_mom,
                eps=self._bn_eps,
            )

        in_ch = self._block_args.input_filters
        expanded_ch = in_ch * self._block_args.expand_ratio

        # Expansion phase (inverted bottleneck); a 1x1 convolution does not
        # change the spatial size, so image_size is left as it is.
        if self._block_args.expand_ratio != 1:
            conv_cls = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = conv_cls(
                in_channels=in_ch, out_channels=expanded_ch,
                kernel_size=1, bias=False)
            self._bn0 = make_bn(expanded_ch)

        # Depthwise phase: groups == channels makes the convolution depthwise.
        stride = self._block_args.stride
        conv_cls = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = conv_cls(
            in_channels=expanded_ch, out_channels=expanded_ch,
            groups=expanded_ch,
            kernel_size=self._block_args.kernel_size,
            stride=stride, bias=False)
        self._bn1 = make_bn(expanded_ch)
        image_size = calculate_output_image_size(image_size, stride)

        # Squeeze-and-excitation phase; it operates on a 1x1 pooled map and
        # its bottleneck width is derived from the block's *input* filters.
        if self.has_se:
            conv_cls = get_same_padding_conv2d(image_size=(1, 1))
            squeezed_ch = max(
                1,
                int(self._block_args.input_filters * self._block_args.se_ratio),
            )
            self._se_reduce = conv_cls(
                in_channels=expanded_ch, out_channels=squeezed_ch, kernel_size=1)
            self._se_expand = conv_cls(
                in_channels=squeezed_ch, out_channels=expanded_ch, kernel_size=1)

        # Pointwise projection phase.
        projected_ch = self._block_args.output_filters
        conv_cls = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = conv_cls(
            in_channels=expanded_ch, out_channels=projected_ch,
            kernel_size=1, bias=False)
        self._bn2 = make_bn(projected_ch)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """Run the block on ``inputs``.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float): Drop connect rate, between 0 and 1.
                A falsy value (``None`` or ``0``) disables drop connect.

        Returns:
            Output of this block after processing.
        """
        args = self._block_args

        # Expansion followed by the depthwise convolution.
        x = inputs
        if args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze and excitation: a per-channel gate from the pooled map.
        if self.has_se:
            gate = F.adaptive_avg_pool2d(x, 1)
            gate = self._se_expand(self._swish(self._se_reduce(gate)))
            x = torch.sigmoid(gate) * x

        # Pointwise projection (no activation afterwards).
        x = self._bn2(self._project_conv(x))

        # The residual is only valid when the block changes neither the
        # resolution nor the number of channels.
        keeps_shape = (
            self.id_skip
            and args.stride == 1
            and args.input_filters == args.output_filters
        )
        if not keeps_shape:
            return x

        # Skip connection combined with drop connect gives stochastic depth.
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        return x + inputs

    def set_swish(self, memory_efficient=True):
        """Select the swish implementation used by this block.

        Args:
            memory_efficient (bool): Use the memory-efficient version (for
                training) when true, the standard one (for export) otherwise.
        """
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            self._swish = Swish()
@MODELS.register_module()
class EfficientNet(nn.Module):
    """EfficientNet backbone with an optional classifier or heatmap decoder.

    Depending on ``global_params`` the network ends in one of:

    * a classifier (``include_top``): global pooling, dropout and a linear
      layer;
    * a heatmap decoder (``include_hm_decoder``): an upsampling path -- plain
      deconvolutions, or one of the two FPN variants selected by ``efpn`` /
      ``tfpn`` -- followed by one output head per entry of ``heads``.

    Args:
        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.

    References:
        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters (PyTorch momentum = 1 - TensorFlow momentum).
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size.
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
            )

            # The first block of a stage takes care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:
                # Repeated blocks keep the output size: same width, stride 1.
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        if self._global_params.include_top:
            self._dropout = nn.Dropout(self._global_params.dropout_rate)
            self._fc = nn.Linear(out_channels, self._global_params.num_classes)

        # Heatmap decoder construction
        if self._global_params.include_hm_decoder:
            self._build_hm_decoder()

        # set activation to memory efficient swish by default
        self._swish = MemoryEfficientSwish()

    def _build_hm_decoder(self):
        """Create the decoder layers and the output heads.

        Attributes and sub-modules are created in a fixed order so that
        parameter ordering stays stable.
        """
        params = self._global_params
        print("Constructing the heatmap Decoder!")
        self.efpn = params.efpn
        self.tfpn = params.tfpn
        assert not (self.efpn and self.tfpn), "Only one of E-FPN or FPN is intergrated!"
        self.se_layer = params.se_layer  # flag; the SE modules are se_layer_<n>

        # NOTE(review): these widths, and the 32 input channels of ``conv_c2``
        # below, are hard-coded; they look tied to a single backbone size (the
        # __main__ demo uses efficientnet-b4) -- confirm before using another.
        self.hm_decoder_filters = [1792, 448, 160, 56]
        use_fpn = self.efpn or self.tfpn
        num_kernels = [4, 4, 4, 4] if use_fpn else [4, 4, 4]

        self._dropout = nn.Dropout(params.dropout_rate)
        self._sigmoid = nn.Sigmoid()
        self._relu = nn.ReLU(inplace=True)
        self._relu1 = nn.ReLU(inplace=False)
        self.deconv_with_bias = False

        # The last decoder stage sees twice the channels when C3 is concatenated.
        inception_width = 112 if params.use_c3 else 56
        self.inception_block = InceptionBlock(inception_width, inception_width, stride=1, pool_size=3)

        self.heads = params.heads
        # fpn_layers[i] tells whether the input of decoder stage i + 1 was
        # widened by concatenating a backbone feature map.
        self.fpn_layers = [params.use_c51, params.use_c4, params.use_c3]

        if use_fpn:
            self._build_fpn_stages(num_kernels)
        else:
            self.deconv_layers = self._make_deconv_layer(
                len(num_kernels),
                self.hm_decoder_filters,
                num_kernels,
            )
        self._build_heads()

    def _build_fpn_stages(self, num_kernels):
        """Register ``deconv_<n>`` (and ``se_layer_<n>``) for both FPN variants."""
        filters = self.hm_decoder_filters
        last_stage = len(filters) - 1
        for idx, in_filters in enumerate(filters):
            if idx == 0:
                # First stage only reduces channels; it does not upsample.
                out_filters = filters[idx + 1]
                stage = nn.Sequential(
                    conv_block(in_filters, out_filters, (3, 3), stride=1, padding=1),
                )
            else:
                # The input is twice as wide when the previous stage's output
                # was concatenated with a backbone feature map.
                if self.fpn_layers[idx - 1]:
                    in_filters = in_filters * 2
                kernel, padding, output_padding = self._get_deconv_cfg(num_kernels[idx])
                if idx < last_stage:
                    out_filters = filters[idx + 1]
                    stage = nn.Sequential(
                        conv_block(in_filters, out_filters, (3, 3), stride=1, padding=1),
                        nn.ConvTranspose2d(
                            in_channels=out_filters,
                            out_channels=out_filters,
                            kernel_size=kernel,
                            stride=2,
                            padding=padding,
                            output_padding=output_padding,
                            bias=self.deconv_with_bias),
                        nn.BatchNorm2d(out_filters, momentum=BN_MOMENTUM),
                    )
                else:
                    # The final stage keeps its width and starts with the
                    # inception block; the same module instance is also
                    # reachable as ``self.inception_block``.
                    out_filters = in_filters
                    stage = nn.Sequential(
                        self.inception_block,
                        nn.ConvTranspose2d(
                            in_channels=out_filters,
                            out_channels=out_filters,
                            kernel_size=kernel,
                            stride=2,
                            padding=padding,
                            output_padding=output_padding,
                            bias=self.deconv_with_bias),
                        nn.BatchNorm2d(out_filters, momentum=BN_MOMENTUM),
                    )
                    # When C2 is used, this conv brings the C2 features to the
                    # same number of filters as the last deconv.
                    # NOTE(review): nesting inside the final-stage branch was
                    # reconstructed from a copy without indentation -- confirm.
                    if self._global_params.use_c2:
                        self.conv_c2 = conv_block(32, out_filters, (3, 3), stride=1, padding=1)

            if self.se_layer:
                # The SE layer runs after concatenation, hence twice the width.
                setattr(self, f'se_layer_{idx+1}', SELayer(channel=out_filters * 2))
            setattr(self, f'deconv_{idx+1}', stage)

    def _build_heads(self):
        """Register one output head per entry of ``self.heads`` under its own name."""
        params = self._global_params
        for head, num_output in self.heads.items():
            head_conv = int(params.head_conv)
            num_output = int(num_output)

            # Width of the decoder output: each concatenated map doubles it.
            if params.use_c2:
                assert params.efpn or params.tfpn, "FPN Design must be set active!"
                assert params.use_c3, "C3 must be utilized for FPN intergration of C2"
                in_head_filters = self.hm_decoder_filters[-1] * 4
            elif params.use_c3:
                in_head_filters = self.hm_decoder_filters[-1] * 2
            else:
                in_head_filters = self.hm_decoder_filters[-1]

            if head_conv > 0:
                if head != 'cls':
                    # Dense heads: 3x3 conv, batch norm, ReLU, 1x1 conv.
                    fc = nn.Sequential(
                        nn.Conv2d(in_head_filters, head_conv,
                                  kernel_size=3, padding=1, bias=True),
                        nn.BatchNorm2d(head_conv),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(head_conv, num_output,
                                  kernel_size=1, stride=1, padding=0),
                    )
                else:
                    # 'cls' head: pooled to one vector, then a linear layer.
                    fc = nn.Sequential(
                        nn.Conv2d(in_head_filters, head_conv, kernel_size=3,
                                  padding=1, bias=True),
                        nn.BatchNorm2d(head_conv, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True),
                        nn.AdaptiveAvgPool2d(1),
                        nn.Flatten(),
                        nn.Linear(head_conv, num_output, bias=True),
                    )
            else:
                fc = nn.Conv2d(
                    in_channels=in_head_filters,
                    out_channels=num_output,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            setattr(self, head, fc)

    def _get_deconv_cfg(self, deconv_kernel):
        """Return ``(kernel, padding, output_padding)`` for a deconvolution.

        Args:
            deconv_kernel (int): Kernel size; 2, 3 and 4 are supported.

        Raises:
            ValueError: If ``deconv_kernel`` is not one of the supported sizes
                (previously this surfaced as an unbound-local error).
        """
        if deconv_kernel == 4:
            padding, output_padding = 1, 0
        elif deconv_kernel == 3:
            padding, output_padding = 1, 1
        elif deconv_kernel == 2:
            padding, output_padding = 0, 0
        else:
            raise ValueError(
                'Unsupported deconv kernel size: {!r} (expected 2, 3 or 4)'.format(deconv_kernel))
        return deconv_kernel, padding, output_padding

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build the plain decoder: a chain of deconv + batch norm + ReLU stages.

        Args:
            num_layers (int): Number of upsampling stages.
            num_filters (list[int]): Channel widths; entry ``i`` is the input
                and entry ``i + 1`` the output of stage ``i``, so one more
                entry than ``num_layers`` is required.
            num_kernels (list[int]): Kernel size of each stage.

        Returns:
            nn.Sequential holding one nn.Sequential per stage.
        """
        assert num_layers == (len(num_filters) - 1), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different len(num_deconv_kernels)'

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = self._get_deconv_cfg(num_kernels[i])
            layers.append(nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=num_filters[i],
                    out_channels=num_filters[i + 1],
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias),
                nn.BatchNorm2d(num_filters[i + 1], momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True),
            ))
        return nn.Sequential(*layers)

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def _block_drop_connect_rate(self, idx):
        """Return the drop connect rate for block ``idx``.

        The configured rate is scaled linearly with the block index; a falsy
        configured rate is returned unchanged.
        """
        rate = self._global_params.drop_connect_rate
        if rate:
            rate *= float(idx) / len(self._blocks)
        return rate

    def extract_endpoints(self, inputs):
        """Collect intermediate feature maps of the backbone.

        A feature map is recorded each time a block reduces the spatial
        height (the map *before* that block is stored), the output of the
        last block is recorded when that block did not reduce the resolution,
        and the output of the head convolution is always appended last. Keys
        are ``'reduction_1'``, ``'reduction_2'``, ... in recording order.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Dictionary mapping ``'reduction_<i>'`` to the recorded tensors.

        Example (shapes quoted from the upstream documentation for
        efficientnet-b0 with a 224x224 input; not re-verified here):
            >>> endpoints = model.extract_endpoints(torch.rand(1, 3, 224, 224))
            >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
            >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
            >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
            >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
            >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 320, 7, 7])
            >>> print(endpoints['reduction_6'].shape)  # torch.Size([1, 1280, 7, 7])
        """
        endpoints = dict()

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        last_idx = len(self._blocks) - 1
        for idx, block in enumerate(self._blocks):
            x = block(x, drop_connect_rate=self._block_drop_connect_rate(idx))
            if prev_x.size(2) > x.size(2):
                endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
            elif idx == last_idx:
                endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
            prev_x = x

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
        return endpoints

    def extract_features(self, inputs):
        """Run the stem, all blocks and the head convolution.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of the final convolution layer in the efficientnet model.
        """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            x = block(x, drop_connect_rate=self._block_drop_connect_rate(idx))

        # Head
        return self._swish(self._bn1(self._conv_head(x)))

    def _efpn_fuse(self, x, skip, se_name):
        """Concatenate ``x`` with ``skip`` gated by ``1 - sigmoid(x)``.

        When SE layers are enabled, the layer registered as ``se_name`` is
        applied to the concatenated tensor.
        """
        gate = self._sigmoid(x)
        inverse_gate = torch.sub(1, gate, alpha=1)
        gated_skip = torch.multiply(inverse_gate, skip)
        x = torch.cat([x, gated_skip], dim=1)
        if self.se_layer:
            x = getattr(self, se_name)(x)
        return x

    def _decode_efpn(self, x1, x2, x3, x4, x5):
        """Decoder path used when ``efpn`` is set."""
        params = self._global_params
        assert params.use_c51, "C51 must be utilized for FPN intergration"

        x = self.deconv_1(x1)
        if params.use_c51:
            x = self._efpn_fuse(x, x2, 'se_layer_1')
        else:
            x = self._relu(x)

        x = self.deconv_2(x)
        if params.use_c4:
            x = self._efpn_fuse(x, x3, 'se_layer_2')
        else:
            x = self._relu(x)

        x = self.deconv_3(x)
        if params.use_c3:
            assert params.use_c4, "C4 must be utilized for FPN intergration of C3"
            x = self._efpn_fuse(x, x4, 'se_layer_3')
        else:
            x = self._relu(x)

        x = self.deconv_4(x)
        if not params.use_c2:
            return self._relu(x)
        assert params.use_c3, "C3 must be utilized for FPN intergration of C2"
        # C2 gets dropout and a conv to match the decoder width before fusing.
        x5 = self.conv_c2(self._dropout(x5))
        return self._efpn_fuse(x, x5, 'se_layer_4')

    def _decode_tfpn(self, x1, x2, x3, x4, x5):
        """Decoder path used when ``tfpn`` is set (plain concatenation)."""
        params = self._global_params
        assert params.use_c51, "C51 must be utilized for FPN intergration"

        x = self.deconv_1(x1)
        x = self._relu1(x)
        x = torch.cat([x, x2], dim=1)

        x = self.deconv_2(x)
        if not params.use_c4:
            x = self._relu1(x)
        else:
            x = torch.cat([x, x3], dim=1)

        x = self.deconv_3(x)
        if not params.use_c3:
            x = self._relu1(x)
        else:
            assert params.use_c4, "C4 must be utilized for FPN intergration of C3"
            x = torch.cat([x, x4], dim=1)

        x = self.deconv_4(x)
        if not params.use_c2:
            x = self._relu(x)
        else:
            assert params.use_c3, "C3 must be utilized for FPN intergration of C2"
            x5 = self._dropout(x5)
            x5 = self.conv_c2(x5)
            x = self._relu1(x)
            x = torch.cat([x, x5], dim=1)
        return x

    def forward(self, inputs):
        """EfficientNet's forward function.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            * ``include_top``: the logits of the final linear layer.
            * ``include_hm_decoder``: a one-element list holding a dict that
              maps each head name to that head's output.
            * neither: the final backbone feature map (``'reduction_6'``).
        """
        params = self._global_params

        # Convolution layers
        endpoints = self.extract_endpoints(inputs)
        x1 = endpoints['reduction_6']
        x2 = endpoints['reduction_5']
        x3 = endpoints['reduction_4']
        x4 = endpoints['reduction_3']
        x5 = endpoints['reduction_2']

        if params.include_top:
            # Pooling and final linear layer
            x = self._avg_pooling(x1)
            x = x.flatten(start_dim=1)
            x = self._dropout(x)
            return self._fc(x)

        if not params.include_hm_decoder:
            # No classifier and no decoder were built (``self.heads`` only
            # exists with the decoder), so there is nothing left to apply:
            # hand back the last backbone feature map.
            return x1

        x1 = self._dropout(x1)
        x2 = self._dropout(x2)
        x3 = self._dropout(x3)
        x4 = self._dropout(x4)

        if self.efpn:
            x = self._decode_efpn(x1, x2, x3, x4, x5)
        elif self.tfpn:
            x = self._decode_tfpn(x1, x2, x3, x4, x5)
        else:
            x = self.deconv_layers(x1)

        return [{head: getattr(self, head)(x) for head in self.heads}]

    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """Create an efficientnet model according to name.

        Args:
            model_name (str): Name for efficientnet.
            in_channels (int): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'num_classes', 'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            An efficientnet model.

        Raises:
            ValueError: If ``model_name`` is not in ``VALID_MODELS``.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        model = cls(blocks_args, global_params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
        """Create an efficientnet model according to name and load weights.

        Args:
            model_name (str): Name for efficientnet.
            weights_path (None or str):
                str: path to pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
            advprop (bool):
                Whether to load pretrained weights
                trained with advprop (valid when weights_path is None).
            in_channels (int): Input data's channel number.
            num_classes (int):
                Number of categories for classification.
                It controls the output size for final linear layer.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            A pretrained efficientnet model.
        """
        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
        # The classifier weights are only requested for the stock 1000-class
        # top layer.
        load_fc = (num_classes == 1000) and (model._global_params.include_top)
        load_pretrained_weights(model, model_name, weights_path=weights_path,
                                load_fc=load_fc, advprop=advprop)
        # The stem is replaced after loading, so a non-RGB stem starts untrained.
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Get the input image size for a given efficientnet model.

        Args:
            model_name (str): Name for efficientnet.

        Returns:
            Input image size (resolution).

        Raises:
            ValueError: If ``model_name`` is not in ``VALID_MODELS``.
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validate a model name.

        Args:
            model_name (str): Name for efficientnet.

        Raises:
            ValueError: If ``model_name`` is not in ``VALID_MODELS``. Nothing
                is returned for a valid name.
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
@MODELS.register_module()
class PoseEfficientNet(EfficientNet):
    """EfficientNet variant that is built directly from a model name.

    Unlike the parent class, instances are created through the constructor
    (for example via the ``MODELS`` registry); the ``from_name`` and
    ``from_pretrained`` factories are disabled. Weights are set up afterwards
    with :meth:`init_weights`.

    Args:
        model_name (str): Name for efficientnet; must be in ``VALID_MODELS``.
        in_channels (int): Input data's channel number. It is stored, but see
            the note on ``_change_in_channels``.
        override_params (other key word params): Params to override the
            model's global_params.
    """

    def __init__(self, model_name, in_channels=3, **override_params):
        # Validate first so an unknown name fails before anything is built.
        super()._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        super().__init__(blocks_args, global_params)

        # The parent constructor does not read these, so they can be set once
        # the module machinery is initialised.
        self.model_name = model_name
        self.in_channels = in_channels

    @classmethod
    def from_name(cls, model_name, in_channels, **override_params):
        """Disabled factory; construct the class directly instead.

        Raises:
            NotImplementedError: Always. (Returning the ``NotImplemented``
                singleton would hand callers a non-model object.)
        """
        raise NotImplementedError(
            'PoseEfficientNet.from_name is not supported; call the constructor instead')

    @classmethod
    def from_pretrained(cls, model_name, weights_path, advprop, in_channels, num_classes, **override_params):
        """Disabled factory; use the constructor followed by ``init_weights``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            'PoseEfficientNet.from_pretrained is not supported; '
            'call the constructor and then init_weights(pretrained=True)')

    def _change_in_channels(self, in_channels):
        """Do nothing: the stem convolution is left as built.

        This must not raise because ``init_weights`` calls it.
        NOTE(review): as a consequence ``in_channels != 3`` has no effect on
        the stem for this class -- confirm that this is intended.
        """
        return NotImplemented

    def init_weights(self, pretrained=False, advprop=False, verbose=True):
        """Load backbone weights and initialise the decoder and the heads.

        Args:
            pretrained (bool): Download and load the weights registered for
                ``self.model_name``. Loading is non-strict.
            advprop (bool): Take the URL from ``url_map_advprop`` instead of
                ``url_map``.
            verbose (bool): Print a confirmation after weights were loaded.

        NOTE(review): this method was reconstructed from a copy without
        indentation. The decoder/head initialisation is placed outside the
        ``if pretrained`` block here, i.e. it always runs -- confirm against
        the original file.
        """
        if pretrained:
            url_map_ = url_map_advprop if advprop else url_map
            state_dict = model_zoo.load_url(url_map_[self.model_name])
            self.load_state_dict(state_dict, strict=False)

        if self._global_params.include_hm_decoder:
            # Initialize weights for the deconvolution layers.
            if self.efpn or self.tfpn:
                deconv_layers = [self.deconv_1, self.deconv_2, self.deconv_3, self.deconv_4]
            else:
                deconv_layers = self.deconv_layers

            for layer in deconv_layers:
                for m in layer.modules():
                    if isinstance(m, nn.ConvTranspose2d):
                        # Normal init with std = sqrt(2 / (k_h * k_w * out_channels)).
                        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                        m.weight.data.normal_(0, math.sqrt(2. / n))
                        if self.deconv_with_bias:
                            nn.init.constant_(m.bias, 0)
                    elif isinstance(m, nn.BatchNorm2d):
                        nn.init.constant_(m.weight, 1)
                        nn.init.constant_(m.bias, 0)

            # Init head parameters: only the convolution whose output width
            # equals the head's configured output count is touched.
            for head in self.heads:
                final_layer = getattr(self, head)
                for m in final_layer.modules():
                    if not isinstance(m, nn.Conv2d):
                        continue
                    if m.weight.shape[0] != self.heads[head]:
                        continue
                    if 'hm' in head:
                        # Heatmap heads only get a constant bias.
                        nn.init.constant_(m.bias, -2.19)
                    else:
                        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                        m.weight.data.normal_(0, math.sqrt(2. / n))
                        nn.init.constant_(m.bias, 0)

        self._change_in_channels(in_channels=self.in_channels)

        # Only claim that weights were loaded when they actually were.
        if pretrained and verbose:
            print('Loaded pretrained weights for {}'.format(self.model_name))
if __name__ == '__main__':
    # Demo: build a PoseEfficientNet through the registry, load pretrained
    # weights, list its parameters and print the shape of every head output.
    cfg = {
        'type': 'PoseEfficientNet',
        'model_name': 'efficientnet-b4',
        'include_top': False,
        'include_hm_decoder': True,
        'head_conv': 64,
        'heads': {'hm': 1, 'cls': 1, 'cstency': 256},
        'use_c2': True,
    }

    model = build_model(cfg, MODELS)
    model.init_weights(pretrained=True)
    model.eval()

    inputs = torch.rand((1, 3, 384, 384))

    for index, (param_name, _) in enumerate(model.named_parameters()):
        print(index, param_name)

    # To show the whole pose EFN model outputs shape
    head_outputs = model(inputs)[0]
    for head_name, head_output in head_outputs.items():
        print(f'{head_name} shape is --- {head_output.shape}')

    # To show the endpoint feature shapes instead, iterate over
    # model.extract_endpoints(inputs) and print each tensor's shape.