# Source: GlacialLakes/code/lake_detection_deep_learning/trainer/models/unet.py (Jupyter checkpoint copy)
import torch
import torch.nn as nn
import torch.nn.functional as F
# -------------------------------------------------
# Basic blocks
# -------------------------------------------------
class DoubleConv(nn.Module):
    """Two successive (Conv2d 3x3 => BatchNorm2d => ReLU) stages.

    Spatial size is preserved (padding=1); channels go
    in_ch -> out_ch at the first stage and stay at out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for cin, cout in ((in_ch, out_ch), (out_ch, out_ch)):
            layers.extend([
                nn.Conv2d(cin, cout, 3, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Down(nn.Module):
    """Halve the spatial resolution (2x2 max-pool), then apply DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_ch, out_ch),
        )

    def forward(self, x):
        return self.net(x)
class Up(nn.Module):
    """Bilinear 2x upsample, concatenate the skip connection, then DoubleConv.

    `in_ch` is the channel count AFTER concatenation
    (upsampled channels + skip channels).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Bilinear upsampling avoids the checkerboard artifacts that
        # transposed convolutions can introduce.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        # Odd encoder sizes can leave the upsampled map slightly smaller
        # than the skip tensor; pad symmetrically to match its H and W.
        dh = skip.size(2) - x.size(2)
        dw = skip.size(3) - x.size(3)
        left, top = dw // 2, dh // 2
        x = F.pad(x, [left, dw - left, top, dh - top])
        return self.conv(torch.cat([skip, x], dim=1))
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class logits."""

    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, 1)

    def forward(self, x):
        return self.conv(x)
class Encoder(nn.Module):
    """U-Net contracting path: a stem DoubleConv plus four Down stages.

    Returns all five feature maps so the decoder can use them as skip
    connections. Channel widths are base_c, 2x, 4x, 8x, 16x; each Down
    stage halves H and W.
    """

    def __init__(self, in_channels, base_c=64):
        super().__init__()
        widths = [base_c * m for m in (1, 2, 4, 8, 16)]
        self.inc = DoubleConv(in_channels, widths[0])
        self.down1 = Down(widths[0], widths[1])
        self.down2 = Down(widths[1], widths[2])
        self.down3 = Down(widths[2], widths[3])
        self.down4 = Down(widths[3], widths[4])

    def forward(self, x):
        feats = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            feats.append(stage(feats[-1]))
        # (full-res, 1/2, 1/4, 1/8, 1/16) feature maps
        return tuple(feats)
# -------------------------------------------------
# Standard U-Net
# -------------------------------------------------
class model(nn.Module):
    """Standard U-Net segmentation model with a bilinear-upsampling decoder.

    Args:
        in_channels: number of input image channels.
        num_classes: number of output channels (per-class logits).
        freeze_encoder: if True, disable gradients for all encoder weights.
            NOTE(review): encoder BatchNorm running stats still update in
            train mode even when parameters are frozen — confirm intended.
        base_c: channel width of the first encoder stage.
    """

    def __init__(self,
                 in_channels=3,
                 num_classes=1,
                 freeze_encoder=False,
                 base_c=64):
        super().__init__()
        # Contracting path.
        self.encoder = Encoder(in_channels, base_c)
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
        # Expanding path: each Up receives (upsampled + skip) channels,
        # hence the channel sums below.
        self.up1 = Up(base_c * 16 + base_c * 8, base_c * 8)
        self.up2 = Up(base_c * 8 + base_c * 4, base_c * 4)
        self.up3 = Up(base_c * 4 + base_c * 2, base_c * 2)
        self.up4 = Up(base_c * 2 + base_c, base_c)
        # 1x1 projection to class logits.
        self.outc = OutConv(base_c, num_classes)

    def forward(self, x):
        x1, x2, x3, x4, x5 = self.encoder(x)
        out = x5
        for up, skip in ((self.up1, x4), (self.up2, x3),
                         (self.up3, x2), (self.up4, x1)):
            out = up(out, skip)
        return self.outc(out)