import torch
import torch.nn as nn

from .model_parts import CombinationModule


class DecNet(nn.Module):
    def __init__(self, heads, final_kernel, head_conv, channel):
        super(DecNet, self).__init__()
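        # Top-down decoder: each CombinationModule upsamples the deeper
        # feature map and fuses it with the skip connection from the stage
        # above it. The channel widths assume a ResNet-18/34-style backbone
        # with 64/128/256/512-channel stages.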
        self.dec_c2 = CombinationModule(128, 64, batch_norm=True)
        self.dec_c3 = CombinationModule(256, 128, batch_norm=True)
        self.dec_c4 = CombinationModule(512, 256, batch_norm=True)
        self.heads = heads
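        # Build one small conv head per entry in `heads`
        # (head name -> number of output channels).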
        for head in self.heads:
            classes = self.heads[head]
            if head == 'wh':
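                # The 'wh' branch uses 7x7 kernels throughout, giving the
                # box-geometry head a wider receptive field.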
                fc = nn.Sequential(
                    nn.Conv2d(channel, head_conv, kernel_size=7, padding=7 // 2, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(head_conv, classes, kernel_size=7, padding=7 // 2, bias=True))
            else:
                fc = nn.Sequential(
                    nn.Conv2d(channel, head_conv, kernel_size=3, padding=1, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(head_conv, classes, kernel_size=final_kernel, stride=1,
                              padding=final_kernel // 2, bias=True))
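            # Heatmap heads start with a negative bias so the initial
            # sigmoid activations are ~0.1, the usual focal-loss prior:
            # -2.19 ≈ -log((1 - 0.1) / 0.1).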
            if 'hm' in head:
                fc[-1].bias.data.fill_(-2.19)
            else:
                self.fill_fc_weights(fc)
            self.__setattr__(head, fc)

    def fill_fc_weights(self, layers):
        # Zero-initialise conv biases; weights keep PyTorch's default init.
        for m in layers.modules():
            if isinstance(m, nn.Conv2d) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x: backbone feature maps ordered shallow -> deep; decode
        # top-down, fusing a skip connection at each stage.
        c4_combine = self.dec_c4(x[-1], x[-2])
        c3_combine = self.dec_c3(c4_combine, x[-3])
        c2_combine = self.dec_c2(c3_combine, x[-4])
        dec_dict = {}
        for head in self.heads:
            dec_dict[head] = self.__getattr__(head)(c2_combine)
            if 'hm' in head:
                # Heatmap outputs are returned as probabilities in [0, 1].
                dec_dict[head] = torch.sigmoid(dec_dict[head])
        return dec_dict
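
# Minimal usage sketch (hypothetical values: the head names, output
# widths, and the backbone producing `feats` are illustrative, not
# fixed by this module):
#
#   heads = {'hm': num_classes, 'wh': 2, 'reg': 2}
#   net = DecNet(heads, final_kernel=1, head_conv=256, channel=64)
#   feats = backbone(images)   # [c1, c2, c3, c4], shallow -> deep
#   out = net(feats)           # dict of per-head tensors
#   # out['hm'] is already sigmoid-activated; other heads are raw logits.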