import time

import torch
from torch import nn
from torch.nn import functional as F
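
# Cascaded two-stage U-Net for 2x image super-resolution. The overall layout
# (UNet1 + UNet2 with SE blocks, valid 3x3 convolutions, and reflection
# pre-padding) closely follows Real-CUGAN's UpCunet2x.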


class UNet_Full(nn.Module):
    """Cascaded two-stage U-Net for 2x super-resolution."""

    def __init__(self):
        super(UNet_Full, self).__init__()
        self.unet1 = UNet1(3, 3, deconv=True)   # stage 1: upsamples 2x
        self.unet2 = UNet2(3, 3, deconv=False)  # stage 2: refines at target scale

    def forward(self, x):
        n, c, h0, w0 = x.shape

        # Round height and width up to the next even number, then add an
        # 18-pixel reflection border so the unpadded 3x3 convolutions inside
        # the U-Nets have enough context at the image edges.
        ph = ((h0 - 1) // 2 + 1) * 2
        pw = ((w0 - 1) // 2 + 1) * 2
        x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), 'reflect')

        x1 = self.unet1(x)
        x2 = self.unet2(x1)

        # Negative padding crops x1 by 20 pixels per side so it lines up with
        # x2, which UNet2's valid convolutions have shrunk; the two outputs
        # are summed as a residual.
        x1 = F.pad(x1, (-20, -20, -20, -20))
        output = torch.add(x2, x1)

        # Undo the even-size rounding so the output is exactly 2x the input.
        if w0 != pw or h0 != ph:
            output = output[:, :, :h0 * 2, :w0 * 2]
        return output
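
# Shape contract: UNet_Full maps (N, 3, H, W) -> (N, 3, 2H, 2W); the smoke
# test in main() below checks this for a 180x180 input.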


class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention: global pooling + two 1x1 convs."""

    def __init__(self, in_channels, reduction=8, bias=False):
        super(SEBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels // reduction, 1, 1, 0, bias=bias)
        self.conv2 = nn.Conv2d(in_channels // reduction, in_channels, 1, 1, 0, bias=bias)

    def forward(self, x):
        # Global average pooling ("squeeze"); under half precision the mean is
        # computed in fp32 to avoid overflow, then cast back.
        if "Half" in x.type():
            x0 = torch.mean(x.float(), dim=(2, 3), keepdim=True).half()
        else:
            x0 = torch.mean(x, dim=(2, 3), keepdim=True)
        # Bottleneck + sigmoid produces per-channel gates ("excitation").
        x0 = self.conv1(x0)
        x0 = F.relu(x0, inplace=True)
        x0 = self.conv2(x0)
        x0 = torch.sigmoid(x0)
        return torch.mul(x, x0)


class UNetConv(nn.Module):
    """Two unpadded 3x3 convolutions with LeakyReLU, optionally followed by an SE block."""

    def __init__(self, in_channels, mid_channels, out_channels, se):
        super(UNetConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
        )
        if se:
            self.seblock = SEBlock(out_channels, reduction=8, bias=True)
        else:
            self.seblock = None

    def forward(self, x):
        z = self.conv(x)
        if self.seblock is not None:
            z = self.seblock(z)
        return z


class UNet1(nn.Module):
    """Shallow two-level U-Net built from valid (unpadded) convolutions."""

    def __init__(self, in_channels, out_channels, deconv):
        super(UNet1, self).__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)

        # With deconv=True the bottom layer upsamples 2x; otherwise it keeps scale.
        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2(x2)
        x2 = self.conv2_up(x2)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)

        # Crop the skip connection by 4 pixels per side to match x2.
        x1 = F.pad(x1, (-4, -4, -4, -4))
        x3 = self.conv3(x1 + x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        z = self.conv_bottom(x3)
        return z


class UNet2(nn.Module):
    """Deeper three-level U-Net; refines the upsampled output of UNet1."""

    def __init__(self, in_channels, out_channels, deconv):
        super(UNet2, self).__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 64, 128, se=True)
        self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
        self.conv3 = UNetConv(128, 256, 128, se=True)
        self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
        self.conv4 = UNetConv(128, 64, 64, se=True)
        self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)

        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Encoder: two strided-conv downsampling stages.
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2(x2)

        x3 = self.conv2_down(x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        x3 = self.conv3(x3)
        x3 = self.conv3_up(x3)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)

        # Decoder: crop each skip connection to match the upsampled maps.
        x2 = F.pad(x2, (-4, -4, -4, -4))
        x4 = self.conv4(x2 + x3)
        x4 = self.conv4_up(x4)
        x4 = F.leaky_relu(x4, 0.1, inplace=True)

        x1 = F.pad(x1, (-16, -16, -16, -16))
        x5 = self.conv5(x1 + x4)
        x5 = F.leaky_relu(x5, 0.1, inplace=True)

        z = self.conv_bottom(x5)
        return z


def main():
    # Smoke test: report the parameter count and time one forward pass.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNet_Full().to(device).eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"CuNet has {total_params // 1000}K params")

    # A 180x180 input should come out as 360x360 (2x upscaling).
    x = torch.randn((1, 3, 180, 180), device=device)
    start = time.time()
    with torch.no_grad():
        x = model(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()  # CUDA kernels run async; sync before timing
    total = time.time() - start
    print("output size is", x.shape)
    print(f"forward pass took {total:.4f}s")


if __name__ == "__main__":
    main()
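

# Example usage (sketch; "cunet_2x.pth" is a hypothetical checkpoint name):
#
#     weights = torch.load("cunet_2x.pth", map_location="cpu")
#     model = UNet_Full()
#     model.load_state_dict(weights)
#     model.eval()
#     with torch.no_grad():
#         sr = model(lr)  # lr: (N, 3, H, W) float tensor -> sr: (N, 3, 2H, 2W)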