|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DoubleConv(nn.Module):
    """(Conv => BN => ReLU) * 2"""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Two identical conv stages; only the first changes the channel count.
        layers = []
        for src, dst in ((in_ch, out_ch), (out_ch, out_ch)):
            layers.extend([
                nn.Conv2d(src, dst, 3, padding=1),
                nn.BatchNorm2d(dst),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run both conv-bn-relu stages; spatial size is preserved (3x3, pad=1)."""
        return self.net(x)
|
|
|
|
|
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # 2x2 max-pool halves H and W, then DoubleConv adjusts the channels.
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, conv)

    def forward(self, x):
        """Return the pooled-and-convolved feature map (half spatial size)."""
        return self.net(x)
|
|
|
|
|
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Parameter-free bilinear upsampling; in_ch is the channel count
        # AFTER concatenation with the skip connection.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, skip):
        """Upsample x, pad it to match skip's H/W, concat on channels, convolve."""
        x = self.up(x)

        # Odd input sizes can leave x one pixel smaller than skip; pad to match.
        pad_h = skip.size(2) - x.size(2)
        pad_w = skip.size(3) - x.size(3)
        left, top = pad_w // 2, pad_h // 2
        x = F.pad(x, [left, pad_w - left, top, pad_h - top])

        merged = torch.cat([skip, x], dim=1)
        return self.conv(merged)
|
|
|
|
|
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class logits."""

    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, kernel_size=1)

    def forward(self, x):
        """Return raw (unnormalized) class logits with the same H/W as x."""
        return self.conv(x)
|
|
|
|
|
|
|
|
class Encoder(nn.Module):
    """U-Net contracting path: a stem DoubleConv plus four Down stages.

    Channel widths double at every stage: base_c, 2x, 4x, 8x, 16x.
    """

    def __init__(self, in_channels, base_c=64):
        super().__init__()
        widths = [base_c * m for m in (1, 2, 4, 8, 16)]
        self.inc = DoubleConv(in_channels, widths[0])
        self.down1 = Down(widths[0], widths[1])
        self.down2 = Down(widths[1], widths[2])
        self.down3 = Down(widths[2], widths[3])
        self.down4 = Down(widths[3], widths[4])

    def forward(self, x):
        """Return the 5 feature maps (finest to coarsest) for skip connections."""
        feats = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            feats.append(stage(feats[-1]))
        return tuple(feats)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class model(nn.Module):
    """U-Net segmentation model: Encoder + bilinear-upsampling decoder.

    Args:
        in_channels: number of channels in the input image.
        num_classes: number of output logit channels.
        freeze_encoder: if True, disable gradients for all encoder parameters.
        base_c: channel width of the first encoder stage.
    """

    def __init__(self,
                 in_channels=3,
                 num_classes=1,
                 freeze_encoder=False,
                 base_c=64):
        super().__init__()

        self.encoder = Encoder(in_channels, base_c)

        if freeze_encoder:
            # Equivalent to setting requires_grad = False on every parameter.
            self.encoder.requires_grad_(False)

        # Decoder widths mirror the encoder; each Up's in_ch is the sum of the
        # upsampled channels and the matching skip connection's channels.
        c1, c2, c4, c8, c16 = (base_c * m for m in (1, 2, 4, 8, 16))
        self.up1 = Up(c16 + c8, c8)
        self.up2 = Up(c8 + c4, c4)
        self.up3 = Up(c4 + c2, c2)
        self.up4 = Up(c2 + c1, c1)

        self.outc = OutConv(c1, num_classes)

    def forward(self, x):
        """Encode x, decode with skip connections, return per-pixel logits."""
        x1, x2, x3, x4, x5 = self.encoder(x)

        y = self.up1(x5, x4)
        y = self.up2(y, x3)
        y = self.up3(y, x2)
        y = self.up4(y, x1)

        return self.outc(y)
|
|
|