import torch
import torch.nn as nn
import torch.nn.functional as F

from config import NUM_CHANNELS, NUM_CLASSES


class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    The convolutions use ``padding=1``, so the spatial size of the input
    is carried through unchanged; only the channel count goes from
    ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend(
                (
                    nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                )
            )
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv/norm/activation stages to ``x``."""
        return self.net(x)


def _crop_and_concat(upsampled: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    """Concatenate ``upsampled`` with ``skip`` along the channel axis.

    ``skip`` is first trimmed to the leading rows/columns so that its
    height and width (dims 2 and 3) match those of ``upsampled``.
    """
    rows, cols = upsampled.shape[2], upsampled.shape[3]
    trimmed_skip = skip[:, :, :rows, :cols]
    return torch.cat([upsampled, trimmed_skip], dim=1)


class SmallUNet(nn.Module):
    """Compact U-Net with three down-sampling levels and skip connections.

    Channel widths double at each encoder level, starting from
    ``base_channels`` and reaching ``base_channels * 8`` in the
    bottleneck. The decoder mirrors the encoder using transposed
    convolutions, and a 1x1 convolution produces ``num_classes`` output
    channels.
    """

    def __init__(
        self,
        in_channels: int = NUM_CHANNELS,
        num_classes: int = NUM_CLASSES,
        base_channels: int = 16,
    ):
        super().__init__()
        width1 = base_channels
        width2 = base_channels * 2
        width4 = base_channels * 4
        width8 = base_channels * 8

        # Encoder: each level is a DoubleConv followed by 2x max pooling.
        self.enc1 = DoubleConv(in_channels, width1)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(width1, width2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(width2, width4)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(width4, width8)

        # Decoder: every DoubleConv takes twice the up-conv's output
        # channels because the matching encoder features are concatenated.
        self.up3 = nn.ConvTranspose2d(width8, width4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(width8, width4)
        self.up2 = nn.ConvTranspose2d(width4, width2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(width4, width2)
        self.up1 = nn.ConvTranspose2d(width2, width1, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(width2, width1)

        self.head = nn.Conv2d(width1, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and return per-class maps at the input resolution.

        The height and width are read from dims 2 and 3 of ``x``. If the
        decoder output ends up with a different spatial size (e.g. for
        sizes that do not survive three halvings exactly), it is
        bilinearly resized back to the input height and width.
        """
        height, width = x.shape[2], x.shape[3]

        # Contracting path.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        feats = self.bottleneck(self.pool3(skip3))

        # Expanding path: upsample, merge with the encoder skip, refine.
        feats = self.dec3(_crop_and_concat(self.up3(feats), skip3))
        feats = self.dec2(_crop_and_concat(self.up2(feats), skip2))
        feats = self.dec1(_crop_and_concat(self.up1(feats), skip1))

        logits = self.head(feats)
        if logits.shape[2] == height and logits.shape[3] == width:
            return logits
        return F.interpolate(
            logits, size=(height, width), mode='bilinear', align_corners=False
        )