# Hugging Face Hub page residue: uploader "Timerns"; commit message
# "Upload folder using huggingface_hub"; commit 984cdba (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
# -------------------------------------------------
# Basic blocks
# -------------------------------------------------
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages, each followed by BN and ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes (``in_ch`` -> ``out_ch``).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend(
                [
                    nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                ]
            )
        # Sequential indices 0..5 match the layer order conv/bn/relu x 2.
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv/BN/ReLU stages to ``x``."""
        return self.net(x)
class Down(nn.Module):
    """Encoder stage: 2x2 max-pool followed by a ``DoubleConv``.

    The pooling halves the spatial size; the double convolution then maps
    ``in_ch`` channels to ``out_ch`` channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        convs = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, convs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample ``x`` and run it through the double convolution."""
        return self.net(x)
class Up(nn.Module):
    """Decoder stage: bilinear 2x upsample, concat with skip, ``DoubleConv``.

    ``in_ch`` is the channel count *after* concatenation, i.e. the channels
    of the upsampled tensor plus the channels of the skip connection.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Bilinear interpolation followed by convolution, chosen over a
        # transposed convolution to avoid checkerboard artifacts.
        self.up = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False
        )
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, align it to ``skip`` spatially, then fuse both."""
        upsampled = self.up(x)

        # Size gap between the skip tensor and the upsampled tensor; it is
        # zero when the sizes already match (then the padding is a no-op).
        gap_h = skip.size(2) - upsampled.size(2)
        gap_w = skip.size(3) - upsampled.size(3)
        top = gap_h // 2
        left = gap_w // 2

        # F.pad takes the last dimension first: [left, right, top, bottom].
        upsampled = F.pad(
            upsampled, [left, gap_w - left, top, gap_h - top]
        )

        # Skip features come first along the channel axis.
        fused = torch.cat([skip, upsampled], dim=1)
        return self.conv(fused)
class OutConv(nn.Module):
    """Output head: a 1x1 convolution mapping ``in_ch`` to ``num_classes``.

    No activation is applied here; the forward pass returns the raw
    convolution output.
    """

    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to ``num_classes`` channels per spatial location."""
        projected = self.conv(x)
        return projected
class Encoder(nn.Module):
    """Five-level U-Net encoder returning the feature map of every level.

    Channel widths double at each level: ``base_c``, ``2*base_c``,
    ``4*base_c``, ``8*base_c``, ``16*base_c``. Each ``Down`` stage halves the
    spatial size, so the levels sit at 1, 1/2, 1/4, 1/8 and 1/16 of the input
    resolution (e.g. 256, 128, 64, 32, 16 for a 256x256 input).
    """

    def __init__(self, in_channels, base_c=64):
        super().__init__()
        widths = [base_c * factor for factor in (1, 2, 4, 8, 16)]
        self.inc = DoubleConv(in_channels, widths[0])   # full resolution
        self.down1 = Down(widths[0], widths[1])         # 1/2
        self.down2 = Down(widths[1], widths[2])         # 1/4
        self.down3 = Down(widths[2], widths[3])         # 1/8
        self.down4 = Down(widths[3], widths[4])         # 1/16

    def forward(self, x):
        """Return ``(x1, x2, x3, x4, x5)``, shallowest to deepest features."""
        features = [self.inc(x)]
        # Each stage consumes the output of the previous one.
        for stage in (self.down1, self.down2, self.down3, self.down4):
            features.append(stage(features[-1]))
        return tuple(features)
# -------------------------------------------------
# Standard U-Net
# -------------------------------------------------
class model(nn.Module):
    """Standard U-Net: ``Encoder`` + four ``Up`` stages + 1x1 output conv.

    Args:
        in_channels: Number of channels of the input image.
        num_classes: Number of output channels produced by the head.
        freeze_encoder: If true, every encoder parameter gets
            ``requires_grad = False`` at construction time.
        base_c: Channel width of the first encoder level; deeper levels use
            2x, 4x, 8x and 16x this value.
    """

    def __init__(self,
                 in_channels=3,
                 num_classes=1,
                 freeze_encoder=False,
                 base_c=64):
        super().__init__()

        # ---- Encoder ----
        self.encoder = Encoder(in_channels, base_c)
        if freeze_encoder:
            # NOTE(review): this only disables gradients. BatchNorm layers
            # inside the encoder still update their running statistics while
            # the module is in train mode -- confirm that is intended.
            self.encoder.requires_grad_(False)

        # ---- Decoder ----
        # Each Up stage receives (channels from below + skip channels).
        c1, c2, c4, c8, c16 = (base_c * k for k in (1, 2, 4, 8, 16))
        self.up1 = Up(c16 + c8, c8)   # 1/16 -> 1/8 of input resolution
        self.up2 = Up(c8 + c4, c4)    # 1/8  -> 1/4
        self.up3 = Up(c4 + c2, c2)    # 1/4  -> 1/2
        self.up4 = Up(c2 + c1, c1)    # 1/2  -> full resolution

        # ---- Output head ----
        self.outc = OutConv(c1, num_classes)

    def forward(self, x):
        """Run the full encoder/decoder and return the head's raw output."""
        # The encoder yields features from shallowest to deepest; the last
        # one is the bottleneck, the rest are skip connections.
        *skips, bottleneck = self.encoder(x)

        decoded = bottleneck
        up_stages = (self.up1, self.up2, self.up3, self.up4)
        # Walk back up, pairing each Up stage with the next-shallower skip.
        for up_stage, skip in zip(up_stages, reversed(skips)):
            decoded = up_stage(decoded, skip)

        return self.outc(decoded)