import sys

import torch
import torch.nn as nn

from . import GP_ShuffleUNet_pixel_shuffle, GP_ShuffleUNet_pixel_unshuffle
| class _double_conv(nn.Module): |
| """ |
| Double Convolution Block |
| """ |
|
|
| def __init__(self, in_channels, out_channels, k_size, stride, bias=True, conv_layer=nn.Conv3d): |
| super(_double_conv, self).__init__() |
| self.conv_1 = conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size, |
| stride=stride, padding=k_size // 2, bias=bias) |
| self.conv_2 = conv_layer(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size, |
| stride=stride, padding=k_size // 2, bias=bias) |
|
|
| self.relu = nn.ReLU(inplace=True) |
|
|
| def forward(self, x): |
| x = self.conv_1(x) |
| x = self.relu((x)) |
| x = self.conv_2(x) |
| x = self.relu((x)) |
|
|
| return x |
|
|
|
|
| class _conv_decomp(nn.Module): |
| """ |
| Convolutional Decomposition Block |
| """ |
|
|
| def __init__(self, in_channels, out_channels, k_size, stride, bias=True, conv_layer=nn.Conv3d): |
| super(_conv_decomp, self).__init__() |
| self.conv1 = conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size, |
| stride=stride, padding=k_size // 2, bias=bias) |
| self.conv2 = conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size, |
| stride=stride, padding=k_size // 2, bias=bias) |
| self.conv3 = conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size, |
| stride=stride, padding=k_size // 2, bias=bias) |
| self.conv4 = conv_layer(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size, |
| stride=stride, padding=k_size // 2, bias=bias) |
| self.relu = nn.ReLU(inplace=True) |
|
|
| def forward(self, x): |
| x1 = self.conv1(x) |
| x1 = self.relu((x1)) |
| x2 = self.conv2(x) |
| x2 = self.relu((x2)) |
| x3 = self.conv3(x) |
| x3 = self.relu((x3)) |
| x4 = self.conv4(x) |
| x4 = self.relu((x4)) |
| return x1, x2, x3, x4 |
|
|
|
|
| class _concat(nn.Module): |
| """ |
| Skip-Addition block |
| """ |
|
|
| def __init__(self): |
| super(_concat, self).__init__() |
|
|
| def forward(self, e1, e2, e3, e4, d1, d2, d3, d4): |
| self.X1 = e1 + d1 |
| self.X2 = e2 + d2 |
| self.X3 = e3 + d3 |
| self.X4 = e4 + d4 |
| x = torch.cat([self.X1, self.X2, self.X3, self.X4], dim=1) |
|
|
| return x |
|
|
| |
|
|
class GP_ShuffleUNet(nn.Module):
    """Global-pooling ShuffleUNet.

    Encoder: repeated (_conv_decomp -> PixelUnshuffle -> _double_conv), doubling
    the feature count per level. Decoder: repeated (PixelShuffle ->
    _conv_decomp -> skip-add/concat with the saved encoder decomposition ->
    _double_conv). A global max-pool head yields a per-image prediction; in
    eval mode the dense decoder output is additionally returned as a mask.

    Args:
        d: spatial dimensionality (2 or 3); anything else exits the program.
        in_ch: number of input channels.
        num_features: feature count at the first level, doubled per level.
        n_levels: number of encoder/decoder levels.
        out_ch: number of output channels / classes.
        kernel_size, stride: configuration shared by all conv blocks.
        dropout: apply channel dropout at the bottleneck if True.
        out_act: "softmax" appends a channel softmax to the output head.
    """

    def __init__(self, d=3, in_ch=1, num_features=64, n_levels=3, out_ch=1, kernel_size=3, stride=1, dropout=False, out_act="softmax"):
        super(GP_ShuffleUNet, self).__init__()

        self.n_levels = n_levels

        # Select dimensionality-dependent layers up front.
        # (was: Dropout2d / AdaptiveMaxPool2d / Softmax2d unconditionally,
        # which fail or misbehave on the 5-D tensors of the d == 3 case)
        if d == 3:
            conv_layer = nn.Conv3d
            dropout_layer = nn.Dropout3d
            self.GMP = nn.AdaptiveMaxPool3d((1, 1, 1))
            # NOTE(review): original value kept; a factor-2 shuffle in 3-D
            # would normally imply 2**3 channels — confirm against the
            # PixelShuffle implementation.
            ps_fact = 2 ** 2
        elif d == 2:
            conv_layer = nn.Conv2d
            dropout_layer = nn.Dropout2d
            self.GMP = nn.AdaptiveMaxPool2d((1, 1))
            ps_fact = 2
        else:
            sys.exit("Invalid d")

        self.dropout = dropout_layer() if dropout else nn.Sequential()

        # Feature counts per level: [F, 2F, 4F, ...]
        filters = [num_features]
        for _ in range(n_levels):
            filters.append(filters[-1] * 2)

        self.conv_inp = _double_conv(in_ch, filters[0], kernel_size, stride, conv_layer=conv_layer)

        # ---------------- encoder ----------------
        self.wave_down = nn.ModuleList()
        self.pix_unshuff = nn.ModuleList()
        self.conv_enc = nn.ModuleList()
        for i in range(n_levels):
            self.wave_down.append(_conv_decomp(filters[i], filters[i], kernel_size, stride, conv_layer=conv_layer))
            # was: pixel_unshuffle.PixelUnshuffle — that name is never
            # imported (NameError); the module is imported as
            # GP_ShuffleUNet_pixel_unshuffle.
            self.pix_unshuff.append(GP_ShuffleUNet_pixel_unshuffle.PixelUnshuffle(
                num_features * (2 ** i), num_features * (2 ** i), kernel_size, stride, d=d))
            self.conv_enc.append(_double_conv(filters[i], filters[i + 1], kernel_size, stride, conv_layer=conv_layer))

        # ---------------- decoder ----------------
        self.cat = _concat()
        self.pix_shuff = nn.ModuleList()
        self.wave_up = nn.ModuleList()
        self.convup = nn.ModuleList()
        for i in range(n_levels - 1, -1, -1):
            # was: pixel_shuffle.PixelShuffle — same undefined-name bug.
            self.pix_shuff.append(GP_ShuffleUNet_pixel_shuffle.PixelShuffle(
                num_features * (2 ** (i + 1)), num_features * (2 ** (i + 1)) * ps_fact, kernel_size, stride, d=d))
            self.wave_up.append(_conv_decomp(filters[i], filters[i], kernel_size, stride, conv_layer=conv_layer))
            # input is skip (1x) + concat of 4 decomposition sums (4x) = 5x filters[i]
            self.convup.append(_double_conv(filters[i] * 5, filters[i], kernel_size, stride, conv_layer=conv_layer))

        # ---------------- output head ----------------
        head = conv_layer(filters[0], out_ch, kernel_size=1, stride=1, padding=0, bias=True)
        if out_act == "softmax":
            # was: stored under self.last while forward() used self.out,
            # so the default softmax path raised AttributeError at runtime.
            # nn.Softmax(dim=1) matches Softmax2d on 4-D input and also
            # handles the 5-D (d == 3) case.
            self.out = nn.Sequential(head, nn.Softmax(dim=1))
        else:
            self.out = head

        # Kaiming init for all conv weights; zero biases.
        for m in self.modules():
            if isinstance(m, conv_layer):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')  # in-place
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        encs = [self.conv_inp(x)]

        # Encoder: keep each level's 4-way decomposition for the skip path;
        # only the last branch is unshuffled (downsampled) and passed on.
        waves = []
        for i in range(self.n_levels):
            waves.append(self.wave_down[i](encs[-1]))
            unshuffled = self.pix_unshuff[i](waves[-1][-1])
            encs.append(self.conv_enc[i](unshuffled))

        dec = encs.pop()          # bottleneck features
        dec = self.dropout(dec)

        # Decoder: upsample, re-decompose, add the saved encoder decomposition
        # branch-wise, concatenate, then fuse with the encoder skip tensor.
        for i in range(self.n_levels):
            up = self.pix_shuff[i](dec)
            # tuple concatenation: 4 decoder branches followed by 4 encoder branches
            branches = self.wave_up[i](up) + waves.pop()
            merged = self.cat(*branches)
            dec = self.convup[i](torch.cat([encs.pop(), merged], dim=1))

        if self.training:
            pooled = self.GMP(dec)
            return self.out(pooled).view(pooled.shape[0], -1)
        else:
            mask = self.out(dec)                              # dense prediction
            pooled = self.GMP(dec)
            pred = self.out(pooled).view(pooled.shape[0], -1)  # per-image prediction
            return pred, mask
|
|
|
|
| |
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Smoke test: 2-D model on a random batch of two single-channel images.
    image = torch.rand(2, 1, 512, 512)
    model = GP_ShuffleUNet(d=2, in_ch=1, num_features=64, n_levels=3, out_ch=3, kernel_size=3, stride=1)
    model.eval()
    # was: model(image) was called twice, discarding the first (identical)
    # result and doubling the inference cost; run once, under no_grad.
    with torch.no_grad():
        out = model(image)  # eval mode returns (pred, mask)
    print(out)
|
|