import torch
import torch.nn as nn
import torch.nn.functional as F


# -------------------------------------------------
# Basic blocks
# -------------------------------------------------
class DoubleConv(nn.Module):
    """Apply (Conv3x3 => BatchNorm => ReLU) twice.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so the spatial
    size of the input is preserved; only the channel count changes.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by both convolutions.
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of running ``x`` through both conv stages."""
        return self.net(x)


class Down(nn.Module):
    """Downscale with a 2x2 max-pool, then apply a :class:`DoubleConv`.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels of the output tensor.
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_ch, out_ch),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the pooled and convolved feature map."""
        return self.net(x)


class Up(nn.Module):
    """Upscale by 2, merge with a skip connection, then apply a DoubleConv.

    Args:
        in_ch: Channel count of the tensor fed to the convolution, i.e. the
            channels of the skip tensor plus those of the upsampled tensor
            (the two are concatenated along dim 1 in :meth:`forward`).
        out_ch: Number of channels of the output tensor.
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        # bilinear upsample + conv (no checkerboard artifacts)
        self.up = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False
        )
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, align it to ``skip`` and fuse the two.

        Args:
            x: Feature map from the deeper (lower-resolution) level.
            skip: Feature map from the encoder at the target resolution.

        Returns:
            The fused feature map, with the spatial size of ``skip``.
        """
        x = self.up(x)

        # Pad if needed (for odd input sizes): pooling floors odd sizes, so
        # the upsampled map can be smaller than the skip tensor. The
        # difference is split between the two sides of each axis.
        diff_y = skip.size(2) - x.size(2)
        diff_x = skip.size(3) - x.size(3)
        # F.pad takes pads for the last dim first: (left, right, top, bottom).
        x = F.pad(
            x,
            [
                diff_x // 2, diff_x - diff_x // 2,
                diff_y // 2, diff_y - diff_y // 2,
            ],
        )

        x = torch.cat([skip, x], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """Project features to ``num_classes`` channels with a 1x1 convolution.

    Args:
        in_ch: Number of channels of the input tensor.
        num_classes: Number of output channels.
    """

    def __init__(self, in_ch: int, num_classes: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 1x1-convolved tensor; no activation is applied."""
        return self.conv(x)


class Encoder(nn.Module):
    """Five-level contracting path of the U-Net.

    Channel width doubles at every level (``base_c`` up to ``base_c * 16``)
    while each :class:`Down` stage halves the spatial size.

    Args:
        in_channels: Number of channels of the input image.
        base_c: Channel width of the first level.
    """

    def __init__(self, in_channels: int, base_c: int = 64) -> None:
        super().__init__()
        # Sizes on the right are relative to an H x W input (e.g. 256x256
        # at the top level gives 16x16 at the bottom).
        self.inc = DoubleConv(in_channels, base_c)    # H    x W
        self.down1 = Down(base_c, base_c * 2)         # H/2  x W/2
        self.down2 = Down(base_c * 2, base_c * 4)     # H/4  x W/4
        self.down3 = Down(base_c * 4, base_c * 8)     # H/8  x W/8
        self.down4 = Down(base_c * 8, base_c * 16)    # H/16 x W/16

    def forward(self, x: torch.Tensor):
        """Return the feature maps of all five levels, shallowest first."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        return x1, x2, x3, x4, x5


# -------------------------------------------------
# Standard U-Net
# -------------------------------------------------
class model(nn.Module):
    """Standard U-Net: :class:`Encoder`, four :class:`Up` stages, 1x1 head.

    Args:
        in_channels: Number of channels of the input image.
        num_classes: Number of output channels (raw scores, no activation).
        freeze_encoder: If true, encoder parameters get
            ``requires_grad = False`` and the encoder is kept in eval mode
            even while the rest of the network trains, so its BatchNorm
            running statistics are not updated either. The choice is fixed
            at construction time.
        base_c: Channel width of the first encoder level.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 1,
        freeze_encoder: bool = False,
        base_c: int = 64,
    ) -> None:
        super().__init__()

        # Encoder
        self.encoder = Encoder(in_channels, base_c)
        self._freeze_encoder = bool(freeze_encoder)
        if self._freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
            # requires_grad=False alone does not stop BatchNorm from using
            # batch statistics and updating its running buffers; eval mode
            # does.
            self.encoder.eval()

        # Decoder: each stage consumes the deeper features concatenated
        # with the encoder skip tensor of the target level.
        self.up1 = Up(base_c * 16 + base_c * 8, base_c * 8)   # H/8 x W/8
        self.up2 = Up(base_c * 8 + base_c * 4, base_c * 4)    # H/4 x W/4
        self.up3 = Up(base_c * 4 + base_c * 2, base_c * 2)    # H/2 x W/2
        self.up4 = Up(base_c * 2 + base_c, base_c)            # H   x W

        # Output
        self.outc = OutConv(base_c, num_classes)

    def train(self, mode: bool = True):
        """Set the training mode, keeping a frozen encoder in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        if self._freeze_encoder:
            # super().train() recursed into the encoder; undo that so the
            # frozen BatchNorm layers keep their stored statistics.
            self.encoder.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full encoder/decoder pass.

        Args:
            x: Input batch; indexed as a 4-D tensor with channels on dim 1.

        Returns:
            A tensor with ``num_classes`` channels and the spatial size of
            ``x``.
        """
        # Encoder
        x1, x2, x3, x4, x5 = self.encoder(x)

        # Decoder
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        # Output
        return self.outc(x)