import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
# -------------------------------------------------
# Basic blocks
# -------------------------------------------------
class DoubleConv(nn.Module):
    """Two stacked ``Conv3x3 -> BatchNorm -> ReLU`` stages.

    Both convolutions use ``padding=1`` with the default stride, so the
    spatial size of the input is preserved; only the channel count changes
    (``in_ch -> out_ch -> out_ch``).

    Args:
        in_ch: Number of channels of the input feature map.
        out_ch: Number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        # Kept under the name ``net`` so state-dict keys stay ``net.0`` .. ``net.5``.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        features = self.net(x)
        return features
class Down(nn.Module):
    """Encoder step: 2x2 max-pooling followed by a ``DoubleConv``.

    The pooling halves height and width; the ``DoubleConv`` then maps
    ``in_ch`` channels to ``out_ch`` channels.

    Args:
        in_ch: Number of channels of the input feature map.
        out_ch: Number of channels of the output feature map.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        convs = DoubleConv(in_ch, out_ch)
        # Registered as ``net`` (pool at index 0, convs at index 1).
        self.net = nn.Sequential(pool, convs)

    def forward(self, x):
        """Return the pooled-then-convolved feature map for ``x``."""
        downsampled = self.net(x)
        return downsampled
class Up(nn.Module):
    """Decoder step: bilinear 2x upsampling, optional skip fusion, ``DoubleConv``.

    Args:
        in_ch: Channel count fed to the ``DoubleConv``; when a skip tensor is
            used this is the sum of the skip channels and the upsampled
            channels.
        out_ch: Channel count produced by the ``DoubleConv``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Bilinear upsampling followed by convolution avoids the checkerboard
        # artifacts associated with transposed convolutions.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, skip):
        """Upsample ``x``, merge it with ``skip`` (if given) and convolve.

        Args:
            x: Feature map to upsample by a factor of two.
            skip: Encoder feature map to concatenate in front of the
                upsampled tensor along the channel axis, or ``None`` to skip
                the fusion entirely.

        Returns:
            The output of the ``DoubleConv`` applied to the (fused) tensor.
        """
        upsampled = self.up(x)
        if skip is None:
            return self.conv(upsampled)

        # The upsampled map can be smaller than the skip map when an encoder
        # stage saw an odd spatial size; pad it up to the skip's size.
        gap_h = skip.size(2) - upsampled.size(2)
        gap_w = skip.size(3) - upsampled.size(3)
        pad_left = gap_w // 2
        pad_top = gap_h // 2
        # F.pad order for the last two dims: (left, right, top, bottom).
        padding = (pad_left, gap_w - pad_left, pad_top, gap_h - pad_top)
        upsampled = F.pad(upsampled, padding)

        fused = torch.cat((skip, upsampled), dim=1)
        return self.conv(fused)
class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to per-class output channels.

    Args:
        in_ch: Number of channels of the incoming feature map.
        num_classes: Number of output channels.
    """

    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=num_classes,
            kernel_size=1,
        )

    def forward(self, x):
        """Return the 1x1-convolved tensor; no activation is applied."""
        logits = self.conv(x)
        return logits
class ResNetEncoder(nn.Module):
    """torchvision ResNet-34 split into stages that expose skip features.

    Args:
        in_channels: Number of input image channels. When this is not 3 the
            stem convolution is replaced by a newly initialised 7x7/stride-2
            convolution, so any pretrained weights for that layer are
            discarded.
        pretrained: If true, load the ``"IMAGENET1K_V1"`` weights; otherwise
            the backbone is built without pretrained weights.
    """

    def __init__(self, in_channels=3, pretrained=True):
        super().__init__()
        weights = "IMAGENET1K_V1" if pretrained else None
        backbone = models.resnet34(weights=weights)

        stem_conv = backbone.conv1
        if in_channels != 3:
            # Same geometry as the ResNet stem, but sized for the new input.
            stem_conv = nn.Conv2d(
                in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
            )

        # Spatial sizes in the comments below are for a 256x256 input.
        self.initial = nn.Sequential(stem_conv, backbone.bn1, backbone.relu)  # 128x128
        self.maxpool = backbone.maxpool  # 64x64
        self.layer1 = backbone.layer1  # 64x64
        self.layer2 = backbone.layer2  # 32x32
        self.layer3 = backbone.layer3  # 16x16
        self.layer4 = backbone.layer4  # 8x8

    def forward(self, x):
        """Run the backbone and return the feature map of every stage.

        Returns:
            A 5-tuple ``(stem, layer1, layer2, layer3, layer4)`` ordered from
            the highest to the lowest spatial resolution. The stem output is
            taken before the max-pool.
        """
        stem = self.initial(x)
        features = [stem]
        current = self.maxpool(stem)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            current = stage(current)
            features.append(current)
        return tuple(features)
# -------------------------------------------------
# U-Net style decoder on a ResNet-34 encoder
# -------------------------------------------------
class model(nn.Module):
    """U-Net style segmentation network with a ResNet-34 encoder.

    The encoder yields five feature maps (stem + four residual stages); the
    decoder upsamples the deepest one step by step, fusing the matching
    encoder feature at each of the first four steps. A fifth upsampling step
    without a skip connection undoes the stride of the stem convolution, and
    a 1x1 convolution produces the output channels.

    For even input height/width the output has the input's spatial size; for
    an odd height/width the final 2x upsampling makes that dimension one
    pixel larger than the input.

    Args:
        in_channels: Number of input image channels.
        num_classes: Number of output channels of the final 1x1 convolution.
        freeze_encoder: If true, set ``requires_grad = False`` on the encoder
            parameters. When ``in_channels != 3`` the stem convolution is
            left trainable, because the encoder re-creates it with random
            weights and a frozen random layer could never be learned.
            NOTE: this only disables gradients; it does not put the encoder
            in eval mode, so its BatchNorm running statistics still update
            while the model is in training mode.
        pretrained: Forwarded to ``ResNetEncoder``; pass ``False`` to build
            the encoder without loading ImageNet weights (e.g. offline, or
            when a checkpoint is loaded right afterwards).
    """

    def __init__(self, in_channels=3, num_classes=1, freeze_encoder=False,
                 pretrained=True):
        super().__init__()
        self.encoder = ResNetEncoder(in_channels, pretrained=pretrained)

        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
            if in_channels != 3:
                # encoder.initial[0] is the replaced, randomly initialised
                # stem conv: keep it trainable so the frozen stages behind it
                # receive usable features.
                for p in self.encoder.initial[0].parameters():
                    p.requires_grad = True

        # Decoder input channels = upsampled channels + skip channels.
        self.up1 = Up(512 + 256, 256)  # layer4 output + layer3 skip
        self.up2 = Up(256 + 128, 128)  # + layer2 skip
        self.up3 = Up(128 + 64, 64)    # + layer1 skip
        self.up4 = Up(64 + 64, 64)     # + stem skip
        self.up5 = Up(64, 64)          # no skip: undo the stem's stride 2
        self.outc = OutConv(64, num_classes)

    def forward(self, x):
        """Return the raw (un-activated) output map for the image batch ``x``."""
        x1, x2, x3, x4, x5 = self.encoder(x)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.up5(x, None)
        return self.outc(x)