|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torchvision.models as models |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DoubleConv(nn.Module):
    """Apply (3x3 Conv => BatchNorm => ReLU) twice.

    padding=1 on each 3x3 conv keeps the spatial size unchanged;
    channels go in_ch -> out_ch -> out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        def stage(cin, cout):
            # One conv stage: 3x3 conv (same-size), then BN, then ReLU.
            return [
                nn.Conv2d(cin, cout, 3, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
            ]

        self.net = nn.Sequential(
            *stage(in_ch, out_ch),
            *stage(out_ch, out_ch),
        )

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
class Down(nn.Module):
    """Halve the spatial resolution with 2x2 max-pooling, then DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
class Up(nn.Module):
    """Upsample by 2x, optionally fuse a skip connection, then DoubleConv.

    in_ch must equal the channel count after concatenation
    (upsampled channels + skip channels), or just the upsampled
    channels when forward() is called with skip=None.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        # Parameter-free bilinear upsampling (no transposed conv).
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)

        if skip is not None:
            # Odd input sizes can make the upsampled map slightly smaller
            # than the skip tensor; pad symmetrically so they align.
            dh = skip.size(2) - x.size(2)
            dw = skip.size(3) - x.size(3)
            left, top = dw // 2, dh // 2
            x = F.pad(x, [left, dw - left, top, dh - top])

            # Channel-wise concat: skip features first, decoder features second.
            x = torch.cat([skip, x], dim=1)

        return self.conv(x)
|
|
|
|
|
class OutConv(nn.Module):
    """Final 1x1 convolution projecting feature maps to per-class logits."""

    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, 1)

    def forward(self, x):
        logits = self.conv(x)
        return logits
|
|
|
|
|
class ResNetEncoder(nn.Module):
    """ResNet-34 backbone repackaged as a 5-stage feature extractor.

    forward() returns the five intermediate feature maps (stem output
    plus layer1..layer4) so a decoder can use them as skip connections.
    """

    def __init__(self, in_channels=3, pretrained=True):
        super().__init__()

        weights = "IMAGENET1K_V1" if pretrained else None
        backbone = models.resnet34(weights=weights)

        if in_channels != 3:
            # Swap the stem conv so arbitrary channel counts are accepted.
            # NOTE(review): this discards the pretrained conv1 weights —
            # confirm that is acceptable for non-RGB inputs.
            backbone.conv1 = nn.Conv2d(
                in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
            )

        # Stem without the max-pool, so its output can serve as the
        # highest-resolution skip connection.
        self.initial = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.maxpool = backbone.maxpool

        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, x):
        x1 = self.initial(x)                # stem features
        x2 = self.layer1(self.maxpool(x1))  # pooled, then stage 1
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)
        return x1, x2, x3, x4, x5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class model(nn.Module):
    """U-Net-style segmentation network with a ResNet-34 encoder.

    The decoder fuses each encoder stage via skip connections and a final
    skip-free upsampling stage restores the input resolution before the
    1x1 classification head.
    """

    def __init__(self, in_channels=3, num_classes=1, freeze_encoder=False):
        super().__init__()

        self.encoder = ResNetEncoder(in_channels, pretrained=True)

        if freeze_encoder:
            # Stop gradients into the backbone. NOTE(review): BatchNorm
            # running statistics still update in train() mode even with
            # requires_grad=False — confirm this is intended.
            for param in self.encoder.parameters():
                param.requires_grad = False

        # Decoder stages: in_ch = upsampled channels + skip channels.
        self.up1 = Up(512 + 256, 256)
        self.up2 = Up(256 + 128, 128)
        self.up3 = Up(128 + 64, 64)
        self.up4 = Up(64 + 64, 64)
        self.up5 = Up(64, 64)  # final 2x upsample, no skip
        self.outc = OutConv(64, num_classes)

    def forward(self, x):
        x1, x2, x3, x4, x5 = self.encoder(x)

        d = self.up1(x5, x4)
        d = self.up2(d, x3)
        d = self.up3(d, x2)
        d = self.up4(d, x1)
        d = self.up5(d, None)
        return self.outc(d)