"""
UNet Network in PyTorch, modified from https://github.com/milesial/Pytorch-UNet
with architecture referenced from https://keras.io/examples/vision/depth_estimation
for monocular depth estimation from RGB images, i.e. one output channel.
"""
import torch
from torch import nn
class UNet(nn.Module):
    """
    The overall UNet architecture.

    An RGB image (3 channels) is projected to 16 feature channels, passed
    through four downscaling stages (doubling channels up to 256), a
    bottleneck bridge, and four upscaling stages that consume the stored
    skip features in reverse order. A final 1x1 convolution plus sigmoid
    produces a single-channel map in (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Channel widths at each resolution level; consecutive pairs
        # parameterize the down/up blocks.
        widths = (16, 32, 64, 128, 256)
        self.downscale_blocks = nn.ModuleList(
            [DownBlock(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        )
        self.upscale_blocks = nn.ModuleList(
            [UpBlock(c_in, c_out) for c_in, c_out in zip(widths[:0:-1], widths[-2::-1])]
        )
        self.input_conv = nn.Conv2d(3, 16, kernel_size=3, padding="same")
        self.output_conv = nn.Conv2d(16, 1, kernel_size=1)
        self.bridge = BottleNeckBlock(256)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Run the full encoder-bridge-decoder pass on a (N, 3, H, W) batch."""
        x = self.input_conv(x)
        skips = []
        for down in self.downscale_blocks:
            # Each block yields (full-resolution features, pooled features).
            skip, x = down(x)
            skips.append(skip)
        x = self.bridge(x)
        # Deepest skip is consumed first, hence the reversal.
        for up, skip in zip(self.upscale_blocks, reversed(skips)):
            x = up(x, skip)
        return self.activation(self.output_conv(x))
class DownBlock(nn.Module):
    """
    Module that performs downscaling with residual connections.

    Two 3x3 conv -> BatchNorm -> LeakyReLU stages, with a shortcut from the
    first convolution's raw (pre-normalization) output added back in, then
    2x2 max pooling. Returns both the pre-pool features (for the skip
    connection) and the pooled features (for the next stage).
    """

    def __init__(self, in_channels, out_channels, padding="same", stride=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=stride, padding=padding, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2)
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        # Shortcut is taken before normalization, mirroring the Keras
        # depth-estimation reference implementation.
        residual = self.conv1(x)
        out = self.relu(self.bn1(residual))
        out = self.relu(self.bn2(self.conv2(out)))
        out = out + residual
        return out, self.maxpool(out)
class UpBlock(nn.Module):
    """
    Module that performs upscaling after concatenation with skip connections.

    Bilinearly upsamples the input by 2x, concatenates the same-channel-count
    skip tensor (hence conv1 takes ``in_channels * 2``), then applies two
    3x3 conv -> BatchNorm -> LeakyReLU stages, reducing to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, padding="same", stride=1):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv1 = nn.Conv2d(
            in_channels * 2, in_channels,
            kernel_size=3, stride=stride, padding=padding, bias=False,
        )
        self.conv2 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=padding, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x, skip):
        # Skip tensor must spatially match the upsampled input (2x the
        # incoming resolution) and carry the same channel count.
        merged = torch.cat([self.up(x), skip], dim=1)
        merged = self.relu(self.bn1(self.conv1(merged)))
        return self.relu(self.bn2(self.conv2(merged)))
class BottleNeckBlock(nn.Module):
    """
    BottleNeckBlock that serves as the UNet bridge.

    Two 3x3 convolutions (each followed by LeakyReLU) that keep the channel
    count unchanged. With the default ``padding="same"`` and ``strides=1``
    the spatial dimensions are preserved as well.
    """

    def __init__(self, channels, padding="same", strides=1):
        super().__init__()
        # Fix: padding/strides were previously accepted but silently ignored
        # (the convs hard-coded stride 1 and padding "same"); pass them
        # through so the constructor arguments actually take effect.
        # Defaults reproduce the original behavior exactly.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3,
                               stride=strides, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3,
                               stride=strides, padding=padding)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        return self.relu(self.conv2(x))