Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1`` with the default stride, so the
    spatial size of the input is preserved; only the channel count changes
    from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The layer order fixes the Sequential indices (0-3) and therefore the
        # parameter names ``conv.0.*`` / ``conv.2.*`` in the state dict.
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through conv -> ReLU -> conv -> ReLU and return the result."""
        return self.conv(x)
class DownSample(nn.Module):
    """Encoder stage: a ``DoubleConv`` followed by 2x2 max pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the convolution output before pooling (the caller
        keeps it as a skip connection); ``pooled`` is the same tensor after
        the 2x2, stride-2 max pool.
        """
        features = self.conv(x)
        pooled = self.pool(features)
        return features, pooled
class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, ``DoubleConv``.

    The fused tensor is fed to ``DoubleConv(in_channels, out_channels)``, so
    the skip tensor must contribute ``in_channels - out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of ``x2``, and fuse the two.

        ``x1`` is the lower-resolution input and ``x2`` the skip tensor; both
        are indexed as 4-D tensors (dims 2 and 3 are height and width).
        """
        upsampled = self.up(x1)

        # Size difference between the skip tensor and the upsampled tensor.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]

        # Split each gap across both sides; any odd remainder goes to the
        # right / bottom edge.
        left, top = gap_w // 2, gap_h // 2
        right, bottom = gap_w - left, gap_h - top

        # F.pad takes the last dimension first: [left, right, top, bottom].
        upsampled = nn.functional.pad(upsampled, [left, right, top, bottom])

        # Skip features come first along the channel axis.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
class UNet(nn.Module):
    """U-Net built from four ``DownSample`` stages, a bottleneck, and four ``UpSample`` stages.

    Channel widths are 32 -> 64 -> 128 -> 256 in the encoder, 512 in the
    bottleneck, then mirrored in the decoder. A final 1x1 convolution maps
    the 32 decoder channels to ``num_classes`` output channels.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        w1, w2, w3, w4, w_mid = 32, 64, 128, 256, 512

        # Encoder
        self.down_conv_1 = DownSample(in_channels, w1)
        self.down_conv_2 = DownSample(w1, w2)
        self.down_conv_3 = DownSample(w2, w3)
        self.down_conv_4 = DownSample(w3, w4)

        # Bottleneck
        self.bottle_neck = DoubleConv(w4, w_mid)

        # Decoder
        self.up_conv_1 = UpSample(w_mid, w4)
        self.up_conv_2 = UpSample(w4, w3)
        self.up_conv_3 = UpSample(w3, w2)
        self.up_conv_4 = UpSample(w2, w1)

        # Per-pixel projection to the requested number of output channels.
        self.out = nn.Conv2d(in_channels=w1, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the output of the 1x1 head."""
        # Each encoder stage returns (pre-pool features, pooled tensor).
        skip_1, x = self.down_conv_1(x)
        skip_2, x = self.down_conv_2(x)
        skip_3, x = self.down_conv_3(x)
        skip_4, x = self.down_conv_4(x)

        x = self.bottle_neck(x)

        # Decoder stages consume the skips in reverse order (deepest first).
        x = self.up_conv_1(x, skip_4)
        x = self.up_conv_2(x, skip_3)
        x = self.up_conv_3(x, skip_2)
        x = self.up_conv_4(x, skip_1)

        return self.out(x)
if __name__ == "__main__":
    # Script entry point: build a network with 3 input channels and a single
    # output channel, then print its module structure.
    model = UNet(in_channels=3, num_classes=1)
    print(model)