import torch
import torch.nn as nn
import torch.nn.init as tinit
from functools import partial


def conv3x3(in_channels, out_channels, stride=1):
    """3x3 convolution with padding 1 (preserves spatial size at stride 1)."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)


def dilation_conv_bn_act(in_channels, out_dim, act_fn, BatchNorm, dilation=4):
    """Dilated 3x3 conv -> norm -> activation; padding=dilation keeps the spatial size."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_dim, kernel_size=3, stride=1, padding=dilation, dilation=dilation),
        BatchNorm(out_dim),
        act_fn,
    )


def dilation_conv(in_channels, out_dim, stride=1, dilation=4, groups=1):
    """Bare dilated 3x3 conv (no norm or activation)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_dim, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, groups=groups),
    )
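

# Why padding=dilation works: for a 3x3 kernel at stride 1, the output size is
#   out = (in + 2*pad - dilation*(k - 1) - 1) // stride + 1
# and with pad=dilation, k=3, stride=1 this reduces to out = in, so the dilated
# convs above preserve spatial resolution. A minimal sanity check (the shapes
# here are illustrative, not from the source):
#   assert dilation_conv(8, 8, dilation=4)(torch.zeros(1, 8, 64, 64)).shape \
#       == torch.Size([1, 8, 64, 64])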


class ResidualBlockWithDilatedV1(nn.Module):
    """Residual block that uses dilated 3x3 convs in stride-1, non-top blocks."""

    def __init__(self, in_channels, out_channels, BatchNorm, stride=1, downsample=None,
                 is_activation=True, is_top=False, is_dropout=False):
        super(ResidualBlockWithDilatedV1, self).__init__()
        self.stride = stride
        self.is_activation = is_activation
        self.downsample = downsample
        self.is_top = is_top
        if self.stride != 1 or self.is_top:
            self.conv1 = conv3x3(in_channels, out_channels, self.stride)
        else:
            self.conv1 = dilation_conv(in_channels, out_channels, dilation=3)
        self.bn1 = BatchNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        if self.stride != 1 or self.is_top:
            self.conv2 = conv3x3(out_channels, out_channels)
        else:
            self.conv2 = dilation_conv(out_channels, out_channels, dilation=3)
        self.bn2 = BatchNorm(out_channels)
        self.is_dropout = is_dropout
        self.drop_out = nn.Dropout2d(p=0.2)

    def forward(self, x):
        residual = x
        out1 = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out1))
        if self.downsample is not None:
            # Project the identity branch when shape or channel count changes.
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNetV2StraightV2(nn.Module):
    """ResNet-34-style backbone (block_nums=[3, 4, 6, 3]); with stride=[1, 2, 2, 2]
    it downsamples its input by a further factor of 8."""

    def __init__(self, num_filter, map_num, BatchNorm, block_nums=[3, 4, 6, 3],
                 block=ResidualBlockWithDilatedV1, stride=[1, 2, 2, 2],
                 dropRate=[0.2, 0.2, 0.2, 0.2], is_sub_dropout=False):
        super(ResNetV2StraightV2, self).__init__()
        self.in_channels = num_filter * map_num[0]
        self.dropRate = dropRate
        self.stride = stride
        self.is_sub_dropout = is_sub_dropout
        self.drop_out = nn.Dropout2d(p=dropRate[0])
        self.drop_out_2 = nn.Dropout2d(p=dropRate[1])
        self.drop_out_3 = nn.Dropout2d(p=dropRate[2])
        self.drop_out_4 = nn.Dropout2d(p=dropRate[3])
        self.relu = nn.ReLU(inplace=True)
        self.block_nums = block_nums
        self.layer1 = self.blocklayer(block, num_filter * map_num[0], self.block_nums[0], BatchNorm, stride=self.stride[0])
        self.layer2 = self.blocklayer(block, num_filter * map_num[1], self.block_nums[1], BatchNorm, stride=self.stride[1])
        self.layer3 = self.blocklayer(block, num_filter * map_num[2], self.block_nums[2], BatchNorm, stride=self.stride[2])
        self.layer4 = self.blocklayer(block, num_filter * map_num[3], self.block_nums[3], BatchNorm, stride=self.stride[3])

    def blocklayer(self, block, out_channels, block_nums, BatchNorm, stride=1):
        # Project the residual when the spatial size or channel count changes.
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                BatchNorm(out_channels))
        layers = [block(self.in_channels, out_channels, BatchNorm, stride, downsample, is_top=True, is_dropout=False)]
        self.in_channels = out_channels
        for _ in range(1, block_nums):
            layers.append(block(out_channels, out_channels, BatchNorm, is_activation=True, is_top=False, is_dropout=self.is_sub_dropout))
        return nn.Sequential(*layers)

    def forward(self, x, is_skip=False):  # is_skip is kept for interface compatibility; unused here
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        return out4


class FiducialPoints(nn.Module):
    """Wrapper that resolves the normalization layer from a string name and
    instantiates the given backbone architecture."""

    def __init__(self, n_classes, num_filter, architecture, BatchNorm='GN', in_channels=3):
        super(FiducialPoints, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.num_filter = num_filter
        if BatchNorm == 'IN':
            BatchNorm = nn.InstanceNorm2d
        elif BatchNorm == 'BN':
            BatchNorm = nn.BatchNorm2d
        elif BatchNorm == 'GN':
            # nn.GroupNorm takes (num_groups, num_channels) while every call
            # site passes only the channel count, so bind num_groups here.
            # num_groups=1 (LayerNorm-like) is an assumption taken from the
            # commented-out `BatchNorm(1, ...)` calls in the original source.
            BatchNorm = partial(nn.GroupNorm, 1)
        self.dilated_unet = architecture(self.n_classes, self.num_filter, BatchNorm, in_channels=self.in_channels)

    def forward(self, x, is_softmax=True):
        return self.dilated_unet(x, is_softmax)


'''Dilated ResNet for flattening by classification with regression, simple, v2'''
class DilatedResnetForFlatByFiducialPointsS2(nn.Module):
    def __init__(self, n_classes, num_filter, BatchNorm, in_channels=3):
        super(DilatedResnetForFlatByFiducialPointsS2, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.num_filter = num_filter
        act_fn = nn.ReLU(inplace=True)
        map_num = [1, 2, 4, 8, 16]
        print("\n------load DilatedResnetForFlatByFiducialPointsS2------\n")
        # Stem: two stride-2 convs downsample the input by a factor of 4.
        self.resnet_head = nn.Sequential(
            nn.Conv2d(self.in_channels, self.num_filter * map_num[0], kernel_size=3, stride=2, padding=1),
            BatchNorm(self.num_filter * map_num[0]),
            act_fn,
            nn.Conv2d(self.num_filter * map_num[0], self.num_filter * map_num[0], kernel_size=3, stride=2, padding=1),
            BatchNorm(self.num_filter * map_num[0]),
            act_fn,
        )
        # Backbone: downsamples by a further factor of 8, for a total stride of 32.
        self.resnet_down = ResNetV2StraightV2(num_filter, map_num, BatchNorm, block_nums=[3, 4, 6, 3],
                                              block=ResidualBlockWithDilatedV1, dropRate=[0, 0, 0, 0],
                                              is_sub_dropout=False)
        # Context module: six parallel dilated-conv branches with increasing
        # receptive fields (ASPP-style), concatenated and fused by a 1x1 conv.
        map_num_i = 3
        self.bridge_1 = nn.Sequential(
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=1),
        )
        self.bridge_2 = nn.Sequential(
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=2),
        )
        self.bridge_3 = nn.Sequential(
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=5),
        )
        self.bridge_4 = nn.Sequential(
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=8),
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=3),
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=2),
        )
        self.bridge_5 = nn.Sequential(
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=12),
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=7),
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=4),
        )
        self.bridge_6 = nn.Sequential(
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=18),
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=12),
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=6),
        )
        self.bridge_concate = nn.Sequential(
            nn.Conv2d(self.num_filter * map_num[map_num_i] * 6, self.num_filter * map_num[2], kernel_size=1, stride=1, padding=0),
            BatchNorm(self.num_filter * map_num[2]),
            act_fn,
        )
        # Dense regression head (n_classes output channels per location).
        self.out_regress = nn.Sequential(
            nn.Conv2d(self.num_filter * map_num[2], self.num_filter * map_num[0], kernel_size=3, stride=1, padding=1),
            BatchNorm(self.num_filter * map_num[0]),
            nn.PReLU(),
            nn.Conv2d(self.num_filter * map_num[0], n_classes, kernel_size=3, stride=1, padding=1),
        )
        # The hard-coded 31*31 ties this head to a fixed input resolution
        # (total stride 32, e.g. 992x992 inputs yield 31x31 feature maps).
        self.segment_regress = nn.Linear(self.num_filter * map_num[2] * 31 * 31, 2)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                tinit.xavier_normal_(m.weight, gain=0.2)
            if isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                tinit.xavier_normal_(m.weight, gain=0.2)

    def cat(self, trans, down):
        return torch.cat([trans, down], dim=1)

    def forward(self, x, is_softmax):  # is_softmax is kept for interface compatibility; unused here
        resnet_head = self.resnet_head(x)
        resnet_down = self.resnet_down(resnet_head)
        # Run the six dilated branches in parallel on the backbone features.
        bridge_1 = self.bridge_1(resnet_down)
        bridge_2 = self.bridge_2(resnet_down)
        bridge_3 = self.bridge_3(resnet_down)
        bridge_4 = self.bridge_4(resnet_down)
        bridge_5 = self.bridge_5(resnet_down)
        bridge_6 = self.bridge_6(resnet_down)
        bridge_concate = torch.cat([bridge_1, bridge_2, bridge_3, bridge_4, bridge_5, bridge_6], dim=1)
        bridge = self.bridge_concate(bridge_concate)
        out_regress = self.out_regress(bridge)
        segment_regress = self.segment_regress(bridge.view(x.size(0), -1))
        return out_regress, segment_regress
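

# Minimal usage sketch (assumptions, not from the source): num_filter=32 and
# n_classes=2 are illustrative values only; the 992x992 input size is inferred
# from the hard-coded 31*31 in segment_regress (992 / 2**5 = 31).
if __name__ == '__main__':
    model = FiducialPoints(n_classes=2, num_filter=32,
                           architecture=DilatedResnetForFlatByFiducialPointsS2,
                           BatchNorm='GN')
    dummy = torch.zeros(1, 3, 992, 992)
    out_regress, segment_regress = model(dummy)
    print(out_regress.shape)      # torch.Size([1, 2, 31, 31])
    print(segment_regress.shape)  # torch.Size([1, 2])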