import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models # ------------------------------------------------- # Basic blocks # ------------------------------------------------- class DoubleConv(nn.Module): """(Conv => BN => ReLU) * 2""" def __init__(self, in_ch, out_ch): super().__init__() self.net = nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), ) def forward(self, x): return self.net(x) class Down(nn.Module): """Downscaling with maxpool then double conv""" def __init__(self, in_ch, out_ch): super().__init__() self.net = nn.Sequential( nn.MaxPool2d(2), DoubleConv(in_ch, out_ch) ) def forward(self, x): return self.net(x) class Up(nn.Module): """Upscaling then double conv""" def __init__(self, in_ch, out_ch): super().__init__() self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) self.conv = DoubleConv(in_ch, out_ch) def forward(self, x, skip): x = self.up(x) if skip is not None: # pad if needed (for odd input sizes) diffY = skip.size(2) - x.size(2) diffX = skip.size(3) - x.size(3) x = F.pad(x, [ diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2 ]) x = torch.cat([skip, x], dim=1) return self.conv(x) class OutConv(nn.Module): def __init__(self, in_ch, num_classes): super().__init__() self.conv = nn.Conv2d(in_ch, num_classes, 1) def forward(self, x): return self.conv(x) class ASPP(nn.Module): def __init__(self, in_channels, out_channels): super(ASPP, self).__init__() self.atrous_block1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, dilation=1) self.atrous_block6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=6, dilation=6) self.atrous_block12 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=12, dilation=12) self.atrous_block18 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=18, dilation=18) self.lower_features = nn.Sequential( nn.AvgPool2d(kernel_size=2, stride=2), nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, dilation=1), nn.Sequential( nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False) ) ) self.conv_1x1_output = nn.Conv2d(out_channels * 5, out_channels, kernel_size=1) def forward(self, x): x1 = self.atrous_block1(x) x2 = self.atrous_block6(x) x3 = self.atrous_block12(x) x4 = self.atrous_block18(x) x5 = self.lower_features(x) x = torch.cat((x1, x2, x3, x4, x5), dim=1) x = self.conv_1x1_output(x) return x class ResNetEncoder(nn.Module): def __init__(self, in_channels=3, pretrained=True): super().__init__() resnet = models.resnet34(weights="IMAGENET1K_V1" if pretrained else None) if in_channels != 3: resnet.conv1 = nn.Conv2d( in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False ) self.initial = nn.Sequential( resnet.conv1, resnet.bn1, resnet.relu ) self.maxpool = resnet.maxpool # 64x64 self.layer1 = resnet.layer1 # 64x64 self.layer2 = resnet.layer2 # 32x32 self.layer3 = resnet.layer3 # 16x16 self.layer4 = resnet.layer4 # 8x8 def forward(self, x): x1 = self.initial(x) x2 = self.layer1(self.maxpool(x1)) x3 = self.layer2(x2) x4 = self.layer3(x3) x5 = self.layer4(x4) return x1, x2, x3, x4, x5 # ------------------------------------------------- # Standard U-Net # ------------------------------------------------- class model(nn.Module): def __init__(self, in_channels=3, num_classes=1, freeze_encoder=False): super().__init__() self.encoder = ResNetEncoder(in_channels, pretrained=True) if freeze_encoder: for p in self.encoder.parameters(): p.requires_grad = False self.up1 = Up(512 + 256, 256) self.up2 = Up(256 + 128, 128) self.up3 = Up(128 + 64, 64) self.up4 = Up(64 + 64, 64) self.up5 = Up(64, 64) self.aspp1 = ASPP(64, 64) self.aspp2 = ASPP(64, 64) self.aspp3 = ASPP(128, 128) self.aspp4 = ASPP(256, 256) self.outc = OutConv(64, num_classes) def forward(self, x): x1, x2, x3, x4, x5 = self.encoder(x) s4 = self.aspp4(x4) s3 = self.aspp3(x3) s2 = self.aspp2(x2) s1 = self.aspp1(x1) x = self.up1(x5, s4) x = self.up2(x, s3) x = self.up3(x, s2) x = self.up4(x, s1) x = self.up5(x, None) return self.outc(x)