Kim Mạnh Hưng
Add U-Net app and weights
aa04f76
import torch
from torch import nn
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by a ReLU.

    When ``with_bn`` is truthy, a ``BatchNorm2d`` layer is inserted between
    every convolution and its ReLU. Both convolutions use ``padding=1``.

    Args:
        in_channels: Number of channels fed to the first convolution.
        out_channels: Number of channels produced by both convolutions.
        with_bn: Whether to add batch normalisation after each convolution.
    """

    def __init__(self, in_channels, out_channels, with_bn=False):
        super().__init__()
        layers = []
        # The first convolution changes the channel count; the second keeps it.
        for conv_in in (in_channels, out_channels):
            layers.append(
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1)
            )
            if with_bn:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
        # Layer order (and therefore the Sequential indices) is
        # conv, [bn], relu, conv, [bn], relu.
        self.step = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked conv/(bn)/ReLU layers."""
        return self.step(x)
class UNet(nn.Module):
    """U-Net style encoder/decoder with three down- and up-sampling steps.

    The encoder applies four ``DoubleConv`` stages with 32, 64, 128 and 256
    output channels, separated by 2x2 max pooling. The decoder upsamples
    bilinearly, concatenates the matching encoder output along the channel
    dimension, and applies a ``DoubleConv``. A final 1x1 convolution maps to
    ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        with_bn: Forwarded to every ``DoubleConv`` stage.
    """

    def __init__(self, in_channels, out_channels, with_bn=False):
        super().__init__()
        init_channels = 32
        self.out_channels = out_channels

        # Encoder: channel width doubles at every stage.
        self.en_1 = DoubleConv(in_channels, init_channels, with_bn)
        self.en_2 = DoubleConv(1 * init_channels, 2 * init_channels, with_bn)
        self.en_3 = DoubleConv(2 * init_channels, 4 * init_channels, with_bn)
        self.en_4 = DoubleConv(4 * init_channels, 8 * init_channels, with_bn)

        # Decoder: input width is skip channels + upsampled channels.
        self.de_1 = DoubleConv((4 + 8) * init_channels, 4 * init_channels, with_bn)
        self.de_2 = DoubleConv((2 + 4) * init_channels, 2 * init_channels, with_bn)
        self.de_3 = DoubleConv((1 + 2) * init_channels, 1 * init_channels, with_bn)
        self.de_4 = nn.Conv2d(init_channels, out_channels, 1)

        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def _up_and_merge(self, deep, skip):
        """Upsample ``deep`` to ``skip``'s spatial size and concatenate them."""
        target = tuple(skip.shape[-2:])
        doubled = tuple(2 * side for side in deep.shape[-2:])
        if doubled == target:
            # Common case (sizes divisible by the pooling factor): plain 2x
            # upsampling, exactly as before.
            up = self.upsample(deep)
        else:
            # Max pooling floors odd sizes, so a 2x upsample would come out
            # smaller than the skip tensor and torch.cat would fail. Resize
            # directly to the skip tensor's height/width instead.
            up = nn.functional.interpolate(
                deep, size=target, mode='bilinear', align_corners=False
            )
        return torch.cat([up, skip], dim=1)

    def forward(self, x):
        """Compute the network output for a batched ``(N, C, H, W)`` tensor.

        Heights and widths that are not multiples of 8 are supported: the
        decoder resizes to each skip connection's spatial size.

        Returns:
            The output of the final 1x1 convolution with ``out_channels``
            channels and the same height/width as ``x``. No sigmoid or
            softmax is applied; callers add the activation they need.
        """
        e1 = self.en_1(x)
        e2 = self.en_2(self.maxpool(e1))
        e3 = self.en_3(self.maxpool(e2))
        e4 = self.en_4(self.maxpool(e3))

        d1 = self.de_1(self._up_and_merge(e4, e3))
        d2 = self.de_2(self._up_and_merge(d1, e2))
        d3 = self.de_3(self._up_and_merge(d2, e1))
        return self.de_4(d3)