# Source: Hugging Face upload by Timerns
# ("Upload folder using huggingface_hub", commit 984cdba, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
# -------------------------------------------------
# Basic blocks
# -------------------------------------------------
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    padding=1 with a 3x3 kernel keeps H and W unchanged; channels go
    in_ch -> out_ch in the first stage and out_ch -> out_ch in the second.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage reads in_ch channels, second re-reads its own out_ch output.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
class Down(nn.Module):
    """Halve the spatial resolution with 2x2 max-pooling, then apply DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.net(x)
class Up(nn.Module):
    """Bilinear 2x upsample, optional skip-tensor fusion, then DoubleConv.

    forward(x, skip):
        x: deeper feature map to be upsampled.
        skip: encoder feature map to concatenate (channels first), or None.
            When given, the upsampled x is padded (or cropped, for negative
            differences) to skip's H/W before concatenation, so in_ch must
            equal skip channels + x channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        if skip is None:
            return self.conv(x)
        # Align x to skip's spatial size; F.pad crops when an amount is negative.
        dh = skip.size(2) - x.size(2)
        dw = skip.size(3) - x.size(3)
        top = dh // 2
        left = dw // 2
        x = F.pad(x, [left, dw - left, top, dh - top])
        fused = torch.cat([skip, x], dim=1)
        return self.conv(fused)
class OutConv(nn.Module):
    """1x1 convolution mapping in_ch feature channels to num_classes output channels."""

    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, kernel_size=1)

    def forward(self, x):
        out = self.conv(x)
        return out
class ASPP(nn.Module):
    """Atrous-spatial-pyramid-style block with five parallel branches.

    Branches (each producing out_channels channels at the input's H x W):
      * a 1x1 conv,
      * three 3x3 convs with dilation 6, 12 and 18 (padding == dilation keeps size),
      * ``lower_features``: 2x2 average pool -> 1x1 conv -> 2x bilinear
        upsample -> 1x1 conv (a half-resolution context branch, not a
        global-pooling branch).
    The branches are concatenated and fused back to out_channels by a 1x1 conv.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of every branch and of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.atrous_block1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, dilation=1)
        self.atrous_block6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=6, dilation=6)
        self.atrous_block12 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=12, dilation=12)
        self.atrous_block18 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=18, dilation=18)
        # Module nesting is kept as-is so existing state_dict keys
        # (lower_features.1.*, lower_features.2.1.*) still load.
        self.lower_features = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, dilation=1),
            nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
            ),
        )
        self.conv_1x1_output = nn.Conv2d(out_channels * 5, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.atrous_block1(x)
        x2 = self.atrous_block6(x)
        x3 = self.atrous_block12(x)
        x4 = self.atrous_block18(x)
        x5 = self.lower_features(x)
        # Pool-then-upsample only restores 2 * floor(size / 2); for odd H or W
        # resample onto the input grid so the concatenation below is valid.
        if x5.shape[-2:] != x.shape[-2:]:
            x5 = F.interpolate(x5, size=x.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv_1x1_output(x)
        return x
class ResNetEncoder(nn.Module):
    """torchvision ResNet-34 backbone that returns five intermediate feature maps.

    Args:
        in_channels: input channels. If not 3, conv1 is replaced by a new
            7x7/stride-2 conv with 64 outputs; that conv is freshly
            initialised rather than taken from the pretrained weights.
        pretrained: load the "IMAGENET1K_V1" weights when True.

    forward(x) returns (stem, layer1, layer2, layer3, layer4) outputs in that
    order. Spatial sizes noted in the original code (presumably for a 256x256
    input): maxpool/layer1 64x64, layer2 32x32, layer3 16x16, layer4 8x8.
    """

    def __init__(self, in_channels=3, pretrained=True):
        super().__init__()
        weights = "IMAGENET1K_V1" if pretrained else None
        backbone = models.resnet34(weights=weights)
        if in_channels != 3:
            backbone.conv1 = nn.Conv2d(
                in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
            )
        self.initial = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, x):
        stem = self.initial(x)
        features = [stem]
        out = self.maxpool(stem)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
            features.append(out)
        return tuple(features)
# -------------------------------------------------
# U-Net decoder on a ResNet-34 encoder, with ASPP-refined skip connections
# -------------------------------------------------
class model(nn.Module):
    """U-Net-style segmentation network: ResNet-34 encoder, ASPP skips, Up decoder.

    Each encoder skip (stem, layer1, layer2, layer3) is refined by an ASPP block
    before being fused into the decoder; the deepest map (layer4) seeds the
    decoder, and a final skip-less Up restores the stem's 2x downsampling.

    Args:
        in_channels: channels of the input image.
        num_classes: channels of the output map.
        freeze_encoder: if True, set requires_grad=False on all encoder
            parameters. NOTE: BatchNorm running statistics in the encoder still
            update in train mode; call ``self.encoder.eval()`` to freeze them too.
        pretrained: forwarded to ResNetEncoder (download ImageNet weights when
            True). Defaults to True, the previously hard-coded behaviour; pass
            False when a full checkpoint will be loaded anyway.

    forward(x) assumes NCHW input and returns num_classes channels at the same
    H x W as x.
    """

    def __init__(self, in_channels=3, num_classes=1, freeze_encoder=False, pretrained=True):
        super().__init__()
        self.encoder = ResNetEncoder(in_channels, pretrained=pretrained)
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
        # Up(in_ch, out_ch): in_ch = upsampled decoder channels + skip channels.
        self.up1 = Up(512 + 256, 256)
        self.up2 = Up(256 + 128, 128)
        self.up3 = Up(128 + 64, 64)
        self.up4 = Up(64 + 64, 64)
        self.up5 = Up(64, 64)  # no skip: undoes the stem's stride-2 downsampling
        self.aspp1 = ASPP(64, 64)
        self.aspp2 = ASPP(64, 64)
        self.aspp3 = ASPP(128, 128)
        self.aspp4 = ASPP(256, 256)
        self.outc = OutConv(64, num_classes)

    def forward(self, x):
        input_size = x.shape[-2:]
        x1, x2, x3, x4, x5 = self.encoder(x)
        s4 = self.aspp4(x4)
        s3 = self.aspp3(x3)
        s2 = self.aspp2(x2)
        s1 = self.aspp1(x1)
        x = self.up1(x5, s4)
        x = self.up2(x, s3)
        x = self.up3(x, s2)
        x = self.up4(x, s1)
        x = self.up5(x, None)
        logits = self.outc(x)
        # The stride-2 stem rounds odd sizes up, so the final 2x upsample can
        # overshoot the input (e.g. 251 -> 126 -> 252); match the input grid.
        if logits.shape[-2:] != input_size:
            logits = F.interpolate(logits, size=input_size, mode="bilinear", align_corners=False)
        return logits