# network.py
import functools

import torch
import torch.nn as nn
import torch.nn.init as tinit


def conv3x3(in_channels, out_channels, stride=1):
    """3x3 convolution with padding=1 (stride=1 preserves spatial size)."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)


def dilation_conv_bn_act(in_channels, out_dim, act_fn, BatchNorm, dilation=4):
    """Dilated 3x3 conv -> norm -> activation; padding=dilation keeps the spatial size."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_dim, kernel_size=3, stride=1, padding=dilation, dilation=dilation),
        BatchNorm(out_dim),
        act_fn,
    )


def dilation_conv(in_channels, out_dim, stride=1, dilation=4, groups=1):
    """Bare dilated 3x3 conv, wrapped in nn.Sequential to keep the original module layout."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_dim, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, groups=groups),
    )
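

# Illustrative sketch (not used by the model): with kernel_size=3 and
# padding=dilation, a dilated conv preserves spatial dimensions, which is why
# the residual blocks below can substitute dilation_conv for conv3x3 freely.
# `_demo_dilation_padding` is a hypothetical helper added for illustration.
def _demo_dilation_padding():
    x = torch.randn(1, 8, 32, 32)
    conv = dilation_conv(8, 8, dilation=3)
    assert conv(x).shape == x.shape  # 32x32 in, 32x32 out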


class ResidualBlockWithDilatedV1(nn.Module):
    """Residual block: strided/top blocks use plain 3x3 convs, all others use dilated convs."""

    def __init__(self, in_channels, out_channels, BatchNorm, stride=1, downsample=None, is_activation=True, is_top=False, is_dropout=False):
        super(ResidualBlockWithDilatedV1, self).__init__()
        self.stride = stride
        self.is_activation = is_activation  # kept for interface compatibility; unused
        self.downsample = downsample
        self.is_top = is_top
        if self.stride != 1 or self.is_top:
            self.conv1 = conv3x3(in_channels, out_channels, self.stride)
        else:
            self.conv1 = dilation_conv(in_channels, out_channels, dilation=3)
        self.bn1 = BatchNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        if self.stride != 1 or self.is_top:
            self.conv2 = conv3x3(out_channels, out_channels)
        else:
            self.conv2 = dilation_conv(out_channels, out_channels, dilation=3)
        self.bn2 = BatchNorm(out_channels)
        self.is_dropout = is_dropout
        self.drop_out = nn.Dropout2d(p=0.2)

    def forward(self, x):
        residual = x
        out1 = self.relu(self.bn1(self.conv1(x)))
        if self.is_dropout:  # off by default; every call site below passes is_dropout=False
            out1 = self.drop_out(out1)
        out = self.bn2(self.conv2(out1))
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
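

# Minimal sketch (hypothetical values): a stride-2 block needs a projection
# shortcut so the residual addition matches in both channels and resolution.
def _demo_residual_block():
    down = nn.Sequential(conv3x3(16, 32, stride=2), nn.BatchNorm2d(32))
    block = ResidualBlockWithDilatedV1(16, 32, nn.BatchNorm2d, stride=2, downsample=down, is_top=True)
    out = block(torch.randn(1, 16, 64, 64))
    assert out.shape == (1, 32, 32, 32)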


class ResNetV2StraightV2(nn.Module):
    """Four-stage residual trunk; with the default strides it downsamples by 8x."""

    def __init__(self, num_filter, map_num, BatchNorm, block_nums=[3, 4, 6, 3], block=ResidualBlockWithDilatedV1, stride=[1, 2, 2, 2], dropRate=[0.2, 0.2, 0.2, 0.2], is_sub_dropout=False):
        super(ResNetV2StraightV2, self).__init__()
        self.in_channels = num_filter * map_num[0]
        self.dropRate = dropRate
        self.stride = stride
        self.is_sub_dropout = is_sub_dropout
        # The per-stage Dropout2d modules below are constructed but never applied
        # in forward(); they are kept to preserve the original module layout.
        self.drop_out = nn.Dropout2d(p=dropRate[0])
        self.drop_out_2 = nn.Dropout2d(p=dropRate[1])
        self.drop_out_3 = nn.Dropout2d(p=dropRate[2])
        self.drop_out_4 = nn.Dropout2d(p=dropRate[3])
        self.relu = nn.ReLU(inplace=True)
        self.block_nums = block_nums
        self.layer1 = self.blocklayer(block, num_filter * map_num[0], self.block_nums[0], BatchNorm, stride=self.stride[0])
        self.layer2 = self.blocklayer(block, num_filter * map_num[1], self.block_nums[1], BatchNorm, stride=self.stride[1])
        self.layer3 = self.blocklayer(block, num_filter * map_num[2], self.block_nums[2], BatchNorm, stride=self.stride[2])
        self.layer4 = self.blocklayer(block, num_filter * map_num[3], self.block_nums[3], BatchNorm, stride=self.stride[3])

    def blocklayer(self, block, out_channels, block_nums, BatchNorm, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            # Projection shortcut so the residual addition matches the new shape.
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                BatchNorm(out_channels))
        layers = [block(self.in_channels, out_channels, BatchNorm, stride, downsample, is_top=True, is_dropout=False)]
        self.in_channels = out_channels
        for _ in range(1, block_nums):
            layers.append(block(out_channels, out_channels, BatchNorm, is_activation=True, is_top=False, is_dropout=self.is_sub_dropout))
        return nn.Sequential(*layers)

    def forward(self, x, is_skip=False):
        # is_skip is kept for interface compatibility; no skip features are returned.
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out
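

# Sketch with assumed hyper-parameters: strides [1, 2, 2, 2] give an overall
# 8x downsample, and channels widen by map_num[3] / map_num[0] = 8x.
def _demo_resnet_trunk():
    trunk = ResNetV2StraightV2(32, [1, 2, 4, 8, 16], nn.BatchNorm2d)
    out = trunk(torch.randn(1, 32, 64, 64))
    assert out.shape == (1, 32 * 8, 8, 8)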


class FiducialPoints(nn.Module):
    """Wrapper that builds `architecture` with the requested normalization layer."""

    def __init__(self, n_classes, num_filter, architecture, BatchNorm='GN', in_channels=3):
        super(FiducialPoints, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.num_filter = num_filter
        if BatchNorm == 'IN':
            BatchNorm = nn.InstanceNorm2d
        elif BatchNorm == 'BN':
            BatchNorm = nn.BatchNorm2d
        elif BatchNorm == 'GN':
            # nn.GroupNorm takes (num_groups, num_channels), so bind num_groups
            # here to get a one-argument factory like the other branches.
            # num_groups=1 is an assumption (LayerNorm-style grouping).
            BatchNorm = functools.partial(nn.GroupNorm, 1)
        self.dilated_unet = architecture(self.n_classes, self.num_filter, BatchNorm, in_channels=self.in_channels)

    def forward(self, x, is_softmax=True):
        return self.dilated_unet(x, is_softmax)
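

# Sketch of the factory pattern above: each branch yields a callable that
# takes a single channel count, so downstream code can call BatchNorm(C)
# uniformly. num_groups=1 mirrors the assumption made in FiducialPoints.
def _demo_norm_factory():
    norm = functools.partial(nn.GroupNorm, 1)
    layer = norm(64)  # same as nn.GroupNorm(num_groups=1, num_channels=64)
    assert isinstance(layer, nn.GroupNorm) and layer.num_channels == 64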


''' Dilated ResNet for flattening by fiducial points (classification with regression), simple v2 '''
class DilatedResnetForFlatByFiducialPointsS2(nn.Module):
    def __init__(self, n_classes, num_filter, BatchNorm, in_channels=3):
        super(DilatedResnetForFlatByFiducialPointsS2, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.num_filter = num_filter
        act_fn = nn.ReLU(inplace=True)
        map_num = [1, 2, 4, 8, 16]
        print("\n------load DilatedResnetForFlatByFiducialPointsS2------\n")
        # Stem: two stride-2 convs, downsampling the input by 4x overall.
        self.resnet_head = nn.Sequential(
            nn.Conv2d(self.in_channels, self.num_filter * map_num[0], kernel_size=3, stride=2, padding=1),
            BatchNorm(self.num_filter * map_num[0]),
            act_fn,
            nn.Conv2d(self.num_filter * map_num[0], self.num_filter * map_num[0], kernel_size=3, stride=2, padding=1),
            BatchNorm(self.num_filter * map_num[0]),
            act_fn,
        )
        # Trunk: ResNet-34-style block counts [3, 4, 6, 3]; a further 8x downsample.
        self.resnet_down = ResNetV2StraightV2(num_filter, map_num, BatchNorm, block_nums=[3, 4, 6, 3], block=ResidualBlockWithDilatedV1, dropRate=[0, 0, 0, 0], is_sub_dropout=False)
        map_num_i = 3
        # Six parallel context branches over the trunk output, with increasing
        # dilation rates (multi-rate dilated context aggregation).
        self.bridge_1 = nn.Sequential(
            dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
                                 act_fn, BatchNorm, dilation=1),
        )
self.bridge_2 = nn.Sequential(
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=2),
)
self.bridge_3 = nn.Sequential(
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=5),
)
self.bridge_4 = nn.Sequential(
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=8),
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=3),
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=2),
)
self.bridge_5 = nn.Sequential(
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=12),
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=7),
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=4),
)
self.bridge_6 = nn.Sequential(
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=18),
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=12),
dilation_conv_bn_act(self.num_filter * map_num[map_num_i], self.num_filter * map_num[map_num_i],
act_fn, BatchNorm, dilation=6),
)
        self.bridge_concate = nn.Sequential(
            nn.Conv2d(self.num_filter * map_num[map_num_i] * 6, self.num_filter * map_num[2], kernel_size=1, stride=1, padding=0),
            BatchNorm(self.num_filter * map_num[2]),
            act_fn,
        )
        self.out_regress = nn.Sequential(
            nn.Conv2d(self.num_filter * map_num[2], self.num_filter * map_num[0], kernel_size=3, stride=1, padding=1),
            BatchNorm(self.num_filter * map_num[0]),
            nn.PReLU(),
            nn.Conv2d(self.num_filter * map_num[0], n_classes, kernel_size=3, stride=1, padding=1),
        )
        # The hard-coded 31 * 31 ties this head to a fixed feature-map size:
        # with a total stride of 32 it corresponds to 992 x 992 inputs.
        self.segment_regress = nn.Linear(self.num_filter * map_num[2] * 31 * 31, 2)
        self._initialize_weights()

    def _initialize_weights(self):
        # Xavier-initialize every (transposed) convolution with a small gain.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                tinit.xavier_normal_(m.weight, gain=0.2)

    def cat(self, trans, down):
        # Unused helper, kept from the original source.
        return torch.cat([trans, down], dim=1)

    def forward(self, x, is_softmax):
        # is_softmax is accepted for interface compatibility but is not used here.
        resnet_head = self.resnet_head(x)
        resnet_down = self.resnet_down(resnet_head)
        # Run the six dilated context branches over the same trunk features.
        bridge_1 = self.bridge_1(resnet_down)
        bridge_2 = self.bridge_2(resnet_down)
        bridge_3 = self.bridge_3(resnet_down)
        bridge_4 = self.bridge_4(resnet_down)
        bridge_5 = self.bridge_5(resnet_down)
        bridge_6 = self.bridge_6(resnet_down)
        bridge_concate = torch.cat([bridge_1, bridge_2, bridge_3, bridge_4, bridge_5, bridge_6], dim=1)
        bridge = self.bridge_concate(bridge_concate)
        out_regress = self.out_regress(bridge)
        segment_regress = self.segment_regress(bridge.view(x.size(0), -1))
        return out_regress, segment_regress
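

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the original training code): the
    # 31 * 31 regression head implies a fixed input resolution; with a total
    # stride of 32 that is 992 x 992. n_classes=2 and num_filter=32 are
    # assumed example values, not values prescribed by this file.
    net = FiducialPoints(n_classes=2, num_filter=32, architecture=DilatedResnetForFlatByFiducialPointsS2, BatchNorm='GN')
    points, segment = net(torch.randn(1, 3, 992, 992))
    print(points.shape, segment.shape)  # expected: (1, 2, 31, 31) and (1, 2)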