Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from config import NUM_CHANNELS, NUM_CLASSES | |
class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels; the second keeps
    ``out_ch``. With a 3x3 kernel, stride 1 and padding 1, the spatial size
    of the input is preserved.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend(
                (
                    nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                )
            )
        # The attribute name and layer order determine the state_dict keys
        # (net.0.weight, net.1.running_mean, ...); keep them stable so saved
        # checkpoints still load.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages."""
        return self.net(x)
class SmallUNet(nn.Module):
    """Compact three-level U-Net that outputs per-pixel class logits.

    Encoder: three ``DoubleConv`` stages, each followed by 2x2 max-pooling,
    then a ``DoubleConv`` bottleneck. Decoder: three stride-2 transposed
    convolutions; each result is concatenated (upsampled first, skip second)
    with the matching encoder feature map and refined by a ``DoubleConv``.
    A final 1x1 convolution produces ``num_classes`` output channels.

    Args:
        in_channels: Number of channels in the input tensor.
        num_classes: Number of output channels (logits per pixel).
        base_channels: Width of the first encoder stage; the width doubles at
            each deeper level (base, 2*base, 4*base, bottleneck 8*base).

    An input of shape ``(N, in_channels, H, W)`` yields an output of shape
    ``(N, num_classes, H, W)``. H and W need not be multiples of 8, but each
    must survive three 2x halvings by the pooling layers (i.e. be at least 8).
    """

    def __init__(self, in_channels: int = NUM_CHANNELS, num_classes: int = NUM_CLASSES, base_channels: int = 16):
        super().__init__()
        # Encoder path.
        self.enc1 = DoubleConv(in_channels, base_channels)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(base_channels, base_channels * 2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(base_channels * 2, base_channels * 4)
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(base_channels * 4, base_channels * 8)
        # Decoder path: each dec* input is upsampled channels + skip channels
        # (e.g. dec3 sees 4*base from up3 plus 4*base from enc3 = 8*base).
        self.up3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(base_channels * 8, base_channels * 4)
        self.up2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(base_channels * 4, base_channels * 2)
        self.up1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(base_channels * 2, base_channels)
        self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)

    @staticmethod
    def _match_skip(up, skip):
        """Zero-pad ``up`` on the bottom/right so its H and W equal ``skip``'s.

        MaxPool2d(2) floors odd sizes, and the 2x2 stride-2 transposed conv then
        doubles the floored size, so ``up`` can be one row/column short of
        ``skip``. The transposed conv maps source pixel i to outputs 2i and
        2i+1, so ``up`` is already aligned with ``skip`` at the top-left.
        Padding only at the far edge keeps that alignment and, unlike cropping
        the skip, keeps every encoder row/column.
        """
        dh = skip.shape[2] - up.shape[2]
        dw = skip.shape[3] - up.shape[3]
        if dh or dw:
            # F.pad's last-dim-first order: (left, right, top, bottom).
            up = F.pad(up, (0, dw, 0, dh))
        return up

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as ``x``."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))

        # Concatenation order (upsampled, skip) matches the dec* weights.
        d3 = self.dec3(torch.cat([self._match_skip(self.up3(b), e3), e3], dim=1))
        d2 = self.dec2(torch.cat([self._match_skip(self.up2(d3), e2), e2], dim=1))
        d1 = self.dec1(torch.cat([self._match_skip(self.up1(d2), e1), e1], dim=1))
        # d1 has e1's spatial size, which DoubleConv keeps equal to x's, so no
        # resize of the logits is needed (resizing would misalign predictions).
        return self.head(d1)