"""U-Net style encoder/decoder built from plain PyTorch layers.

The encoder is four ``DownSample`` stages (32, 64, 128, 256 channels), the
bottleneck is a ``DoubleConv`` with 512 channels, and the decoder is four
``UpSample`` stages that mirror the encoder and consume its skip tensors.
"""

import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1`` with the default stride, so the
    spatial size of the input is preserved; only the channel count changes.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv-ReLU-conv-ReLU stack to ``x``."""
        return self.conv(x)


class DownSample(nn.Module):
    """Encoder stage: ``DoubleConv`` followed by 2x2 max pooling."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor):
        """Return ``(features, pooled)``.

        ``features`` is the convolution output before pooling (used later as
        the skip connection); ``pooled`` is the same tensor after the 2x2,
        stride-2 max pool and is what the next encoder stage consumes.
        """
        features = self.conv(x)
        return features, self.pool(features)


class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concat, ``DoubleConv``.

    The transposed convolution maps ``in_channels`` to ``out_channels``; the
    following ``DoubleConv`` expects ``in_channels`` inputs, so the skip
    tensor passed to ``forward`` must carry ``in_channels - out_channels``
    channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Upsample ``x1``, align it with the skip tensor ``x2`` and fuse them.

        Args:
            x1: Tensor coming from the deeper (lower resolution) stage.
            x2: Skip tensor from the matching encoder stage.

        Returns:
            The result of ``DoubleConv`` applied to ``cat([x2, x1], dim=1)``.
        """
        x1 = self.up(x1)

        # Max pooling floors odd sizes, so the upsampled tensor can be smaller
        # than the skip tensor; pad it so the two can be concatenated.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        pad_left = diff_x // 2
        pad_top = diff_y // 2
        # F.pad takes pads for the last dimension first:
        # [left, right, top, bottom].
        x1 = nn.functional.pad(
            x1,
            [pad_left, diff_x - pad_left, pad_top, diff_y - pad_top],
        )

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)


# Channel widths of the four encoder stages; the bottleneck doubles the last.
_ENCODER_WIDTHS = (32, 64, 128, 256)


class UNet(nn.Module):
    """Four-level U-Net mapping ``in_channels`` to ``num_classes`` channels.

    The final layer is a 1x1 convolution with no activation, so the output
    holds raw per-pixel scores.
    """

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        w1, w2, w3, w4 = _ENCODER_WIDTHS
        bottleneck_width = 2 * w4

        # Encoder path.
        self.down_conv_1 = DownSample(in_channels, w1)
        self.down_conv_2 = DownSample(w1, w2)
        self.down_conv_3 = DownSample(w2, w3)
        self.down_conv_4 = DownSample(w3, w4)

        self.bottle_neck = DoubleConv(w4, bottleneck_width)

        # Decoder path, mirroring the encoder.
        self.up_conv_1 = UpSample(bottleneck_width, w4)
        self.up_conv_2 = UpSample(w4, w3)
        self.up_conv_3 = UpSample(w3, w2)
        self.up_conv_4 = UpSample(w2, w1)

        self.out = nn.Conv2d(
            in_channels=w1, out_channels=num_classes, kernel_size=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder; return the 1x1 conv output."""
        down_1, p1 = self.down_conv_1(x)
        down_2, p2 = self.down_conv_2(p1)
        down_3, p3 = self.down_conv_3(p2)
        down_4, p4 = self.down_conv_4(p3)

        b = self.bottle_neck(p4)

        # Each decoder stage is paired with the encoder output of equal depth.
        up_1 = self.up_conv_1(b, down_4)
        up_2 = self.up_conv_2(up_1, down_3)
        up_3 = self.up_conv_3(up_2, down_2)
        up_4 = self.up_conv_4(up_3, down_1)

        return self.out(up_4)


if __name__ == '__main__':
    model = UNet(in_channels=3, num_classes=1)
    print(model)