Spaces:
Sleeping
Sleeping
| from typing import List | |
| import torch | |
| import torch.nn as nn | |
class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> BatchNorm -> ReLU stages.

    Both convolutions use stride 1 and padding 1, so the spatial size is
    preserved. The channel count goes from ``ch_in`` to ``ch_out`` in the
    first convolution and stays at ``ch_out`` in the second.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding=1, bias=False)
        # Keep the creation order (conv1, conv2, then the re-init) so the
        # global RNG is consumed exactly as before for a given seed.
        self.conv1 = nn.Conv2d(ch_in, ch_out, **conv_opts)
        self.conv2 = nn.Conv2d(ch_out, ch_out, **conv_opts)
        # Replace PyTorch's default conv init with N(0, 0.02) weights.
        for conv in (self.conv1, self.conv2):
            nn.init.normal_(conv.weight, mean=0.0, std=0.02)
        # The convs carry no bias; the following BatchNorm supplies the shift.
        self.batchnorm1 = nn.BatchNorm2d(ch_out)
        self.batchnorm2 = nn.BatchNorm2d(ch_out)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv1/bn1/relu and then conv2/bn2/relu to ``x``."""
        stages = ((self.conv1, self.batchnorm1), (self.conv2, self.batchnorm2))
        h = x
        for conv, norm in stages:
            h = self.relu(norm(conv(h)))
        return h
class EncodeBlock(nn.Module):
    """Encoder stage: a ConvBlock followed by 2x2 max-pooling.

    ``forward`` returns ``(pooled, features)``. Here ``features`` is the
    ConvBlock output before pooling, which the caller passes on as a skip
    tensor.

    Args:
        ch_in: number of input channels.
        ch_out: number of channels produced by the ConvBlock.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = ConvBlock(ch_in, ch_out)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        """Return ``(pool(conv(x)), conv(x))``."""
        features = self.conv(x)
        return self.pool(features), features
class DecodeBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, ConvBlock.

    Args:
        ch_in: channels of the incoming (coarser) feature map.
        ch_out: channels produced by the upsampling and by this block. The
            skip tensor is assumed to also have ``ch_out`` channels, because
            the ConvBlock is built for ``ch_out * 2`` input channels after
            concatenation.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            ch_in, ch_out, kernel_size=2, stride=2, padding=0, bias=True
        )
        self.conv = ConvBlock(ch_out * 2, ch_out)

    def forward(self, x, skip):
        """Upsample ``x`` to ``skip``'s spatial size, concatenate, and convolve.

        A 2x2 max-pool (ceil_mode=False) floors odd sizes. A plain 2x
        transposed-conv upsample of ``x`` can therefore come out one pixel
        smaller than ``skip``, and ``torch.cat`` then fails. Passing
        ``output_size`` lets the transposed conv add the missing row/column
        through output padding. When ``skip`` is exactly twice the size of
        ``x``, the result is identical to a plain ``self.up(x)``.
        """
        h = self.up(x, output_size=list(skip.shape[-2:]))
        h = self.conv(torch.cat([h, skip], dim=1))
        return h
class UNet(nn.Module):
    """U-Net: 1x1 stem, four encoder stages, bottleneck, four decoder stages, 1x1 head.

    The encoder stages output 64, 128, 256 and 512 channels. The bottleneck
    widens to 1024, and the decoder mirrors these widths back down to 64.
    Each decoder stage receives the matching encoder stage's pre-pooling
    features as its skip input.

    Args:
        ch_in: number of channels in the input tensor.
        ch_out: number of channels produced by the final 1x1 convolution.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # NOTE: attribute names and creation order are kept exactly. They
        # determine the state_dict keys and the order of RNG draws during
        # weight initialisation.

        # Stem: 1x1 projection to 64 channels, re-initialised to N(0, 0.02).
        self.econv0 = nn.Conv2d(
            ch_in, 64, kernel_size=1, stride=1, padding=0, bias=True
        )
        nn.init.normal_(self.econv0.weight, mean=0.0, std=0.02)

        # Contracting path.
        self.econv1 = self.make_downblock(64, 64)
        self.econv2 = self.make_downblock(64, 128)
        self.econv3 = self.make_downblock(128, 256)
        self.econv4 = self.make_downblock(256, 512)

        # Bottleneck.
        self.bottle = self.make_bottleblock(512, 1024)

        # Expanding path (deepest first).
        self.dconv4 = self.make_upblock(1024, 512)
        self.dconv3 = self.make_upblock(512, 256)
        self.dconv2 = self.make_upblock(256, 128)
        self.dconv1 = self.make_upblock(128, 64)

        # Head: 1x1 projection to ch_out, re-initialised to N(0, 0.02).
        self.dconv0 = nn.Conv2d(
            64, ch_out, kernel_size=1, stride=1, padding=0, bias=True
        )
        nn.init.normal_(self.dconv0.weight, mean=0.0, std=0.02)

    def make_downblock(self, ch_in, ch_out):
        """Build one encoder stage (an EncodeBlock)."""
        return EncodeBlock(ch_in=ch_in, ch_out=ch_out)

    def make_bottleblock(self, ch_in, ch_out):
        """Build the bottleneck (a ConvBlock)."""
        return ConvBlock(ch_in=ch_in, ch_out=ch_out)

    def make_upblock(self, ch_in, ch_out):
        """Build one decoder stage (a DecodeBlock)."""
        return DecodeBlock(ch_in=ch_in, ch_out=ch_out)

    def forward(self, x):
        """Run stem -> encoders -> bottleneck -> decoders (with skips) -> head."""
        h = self.econv0(x)

        skips = []
        for encoder in (self.econv1, self.econv2, self.econv3, self.econv4):
            h, skip = encoder(h)
            skips.append(skip)

        h = self.bottle(h)

        # The deepest decoder pairs with the deepest skip, hence reversed().
        decoders = (self.dconv4, self.dconv3, self.dconv2, self.dconv1)
        for decoder, skip in zip(decoders, reversed(skips)):
            h = decoder(h, skip)

        return self.dconv0(h)