Kim Mạnh Hưng
Add U-Net app and weights
aa04f76
import torch.nn as nn
import torch
def get_activation(activation_type):
    """Return a new activation module selected by name.

    The name is matched case-insensitively against the ``nn.Module``
    subclasses exposed by ``torch.nn`` (so ``'relu'``, ``'ReLU'`` and
    ``'RELU'`` all select ``nn.ReLU``).

    Args:
        activation_type: Name of the activation class, e.g. ``'ReLU'`` or
            ``'LeakyReLU'``.

    Returns:
        A freshly constructed instance of the matching module, or
        ``nn.ReLU()`` when no module matches or the matching module cannot
        be built without constructor arguments.
    """
    wanted = activation_type.lower()
    # Class names in torch.nn are CamelCase, so a lowercased name can never be
    # looked up directly with getattr; compare lowercased names instead.
    for name in dir(nn):
        if name.lower() != wanted:
            continue
        candidate = getattr(nn, name)
        # Skip submodules/functions (e.g. nn.functional) and the abstract base.
        if not (isinstance(candidate, type) and issubclass(candidate, nn.Module)):
            continue
        if candidate is nn.Module:
            continue
        try:
            return candidate()
        except TypeError:
            # Module needs constructor arguments; keep the best-effort
            # behaviour and use the default activation instead.
            break
    return nn.ReLU()
def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
    """Stack ``ConvBatchNorm`` layers into one ``nn.Sequential``.

    The first layer maps ``in_channels`` to ``out_channels``; each of the
    remaining ``nb_Conv - 1`` layers keeps ``out_channels``. At least one
    layer is always created, even when ``nb_Conv`` is less than 1.
    """
    head = ConvBatchNorm(in_channels, out_channels, activation)
    tail = [
        ConvBatchNorm(out_channels, out_channels, activation)
        for _ in range(nb_Conv - 1)
    ]
    return nn.Sequential(head, *tail)
class ConvBatchNorm(nn.Module):
    """Convolution => batch normalisation => activation.

    The convolution uses a 3x3 kernel with padding 1. The activation is
    whatever ``get_activation`` returns for the ``activation`` name.
    """

    def __init__(self, in_channels, out_channels, activation='ReLU'):
        super().__init__()
        # Attribute names and creation order are left as they were, since
        # they determine the submodule names of this layer.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = get_activation(activation)

    def forward(self, x):
        """Apply conv, then norm, then the activation to ``x``."""
        normalised = self.norm(self.conv(x))
        return self.activation(normalised)
class DownBlock(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a stack of conv layers."""

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.nConvs = _make_nConv(
            in_channels, out_channels, nb_Conv, activation
        )

    def forward(self, x):
        """Pool ``x`` and run the result through the conv stack."""
        pooled = self.maxpool(x)
        features = self.nConvs(pooled)
        return features
class UpBlock(nn.Module):
    """Decoder stage: upsample, concatenate with the skip tensor, then conv.

    ``in_channels`` is the channel count *after* concatenation; the incoming
    feature map and the skip tensor are each expected to carry
    ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super(UpBlock, self).__init__()
        # Learned 2x upsampling that keeps the channel count unchanged.
        self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, (2, 2), 2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x, skip_x):
        """Upsample ``x``, merge it with ``skip_x`` and apply the conv stack.

        Args:
            x: Feature map from the deeper stage.
            skip_x: Feature map from the matching encoder stage.

        Returns:
            Output of the conv stack, with the spatial size of ``skip_x``.
        """
        out = self.up(x)
        # Pooling an odd-sized map and upsampling it again loses one row or
        # column, which would make the concatenation fail. Zero-pad the
        # upsampled map to the skip size; this does nothing when sizes match.
        diff_h = skip_x.size(2) - out.size(2)
        diff_w = skip_x.size(3) - out.size(3)
        if diff_h or diff_w:
            out = nn.functional.pad(
                out,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )
        merged = torch.cat([out, skip_x], dim=1)  # dim 1 is the channel dimension
        return self.nConvs(merged)
class UNet(nn.Module):
    """U-Net with four downsampling and four upsampling stages."""

    def __init__(self, n_channels=3, n_classes=9):
        '''
        n_channels : number of channels of the input.
        By default 3 (e.g. RGB images).
        n_classes : number of channels of the output.
        By default 9. When it is 1, a sigmoid is applied to the output;
        otherwise the raw output of the final 1x1 convolution is returned.
        '''
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        width = 64  # channel count of the first feature map

        # Encoder: 64 -> 128 -> 256 -> 512 -> 512 channels.
        self.inc = ConvBatchNorm(n_channels, width)
        self.down1 = DownBlock(width, width * 2, nb_Conv=2)
        self.down2 = DownBlock(width * 2, width * 4, nb_Conv=2)
        self.down3 = DownBlock(width * 4, width * 8, nb_Conv=2)
        self.down4 = DownBlock(width * 8, width * 8, nb_Conv=2)

        # Decoder: the first argument counts channels after concatenation
        # with the skip tensor.
        self.up4 = UpBlock(width * 16, width * 4, nb_Conv=2)
        self.up3 = UpBlock(width * 8, width * 2, nb_Conv=2)
        self.up2 = UpBlock(width * 4, width, nb_Conv=2)
        self.up1 = UpBlock(width * 2, width, nb_Conv=2)

        self.outc = nn.Conv2d(width, n_classes, kernel_size=(1, 1))
        self.last_activation = nn.Sigmoid() if n_classes == 1 else None

    def forward(self, x):
        """Run the network on ``x`` (cast to float) and return the output."""
        # Encoder pass, keeping every intermediate map for the skip links.
        skips = [self.inc(x.float())]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Decoder pass: start from the deepest map and consume the skips
        # from deepest to shallowest.
        out = skips.pop()
        for up in (self.up4, self.up3, self.up2, self.up1):
            out = up(out, skips.pop())

        logits = self.outc(out)
        if self.last_activation is None:
            return logits
        return self.last_activation(logits)