|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.init as init |
|
|
import torchvision.models as models |
|
|
from torchvision.models import ResNet34_Weights |
|
|
|
|
|
|
|
|
class ResNetEncoder(nn.Module):
    """Grayscale ResNet-34 feature extractor for a U-Net.

    The pretrained 3-channel stem convolution is collapsed to a single
    input channel by averaging its RGB kernels, so the ImageNet features
    can be reused for 1-channel input.
    """

    def __init__(self, freeze=True):
        super().__init__()
        backbone = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)

        # Single-channel stem: average the pretrained RGB kernels.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.conv1.weight.copy_(backbone.conv1.weight.mean(dim=1, keepdim=True))

        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        if freeze:
            # NOTE(review): this freezes parameters only; BatchNorm running
            # stats still update in train mode — confirm that is intended.
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, x):
        """Return the post-stem activation plus the four residual-stage maps.

        Returns a 5-tuple ordered fine-to-coarse:
        (stem, layer1, layer2, layer3, layer4).
        """
        stem = self.relu(self.bn1(self.conv1(x)))
        pooled = self.maxpool(stem)
        f1 = self.layer1(pooled)
        # Drop the pooled tensor early to reduce peak memory.
        del pooled
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        f4 = self.layer4(f3)
        return stem, f1, f2, f3, f4
|
|
|
|
|
|
|
|
def icnr(tensor, scale=2, init_func=init.kaiming_normal_):
    """ICNR init for a conv feeding ``nn.PixelShuffle`` (Aitken et al., 2017).

    Initializes every group of ``scale**2`` consecutive output-channel
    kernels to identical values so the subsequent pixel shuffle starts out
    as nearest-neighbour upsampling, avoiding checkerboard artifacts.

    Args:
        tensor: conv weight of shape (out_channels, in_channels, h, w);
            ``out_channels`` must be divisible by ``scale**2``.
        scale: pixel-shuffle upscale factor.
        init_func: in-place initializer applied to the base sub-kernel.

    Raises:
        ValueError: if ``out_channels`` is not divisible by ``scale**2``
            (previously this surfaced as an opaque size-mismatch error).
    """
    ni, nf, h, w = tensor.shape
    groups = scale ** 2
    if ni % groups != 0:
        raise ValueError(
            f"out_channels ({ni}) must be divisible by scale**2 ({groups})"
        )
    ni2 = ni // groups
    k = init_func(torch.zeros([ni2, nf, h, w]))
    # Repeat each sub-kernel scale**2 times consecutively: PixelShuffle
    # reads output channel c from input channels c*r^2 .. c*r^2 + r^2 - 1.
    k = k.repeat_interleave(groups, 0)
    with torch.no_grad():
        tensor.copy_(k)
|
|
|
|
|
|
|
|
class PixelShuffleICNR(nn.Module):
    """Upsampling head: 3x3 conv (ICNR-initialized) -> PixelShuffle -> BN -> ReLU."""

    def __init__(self, in_channels, out_channels, scale=2):
        super().__init__()
        expanded = out_channels * (scale ** 2)
        self.conv = nn.Conv2d(in_channels, expanded, kernel_size=3, padding=1)
        # ICNR makes the shuffle start out as nearest-neighbour upsampling.
        icnr(self.conv.weight, scale=scale)
        self.pixel_shuffle = nn.PixelShuffle(scale)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upscale spatially by `scale`; returns (N, out_channels, H*scale, W*scale)."""
        shuffled = self.pixel_shuffle(self.conv(x))
        return self.relu(self.bn(shuffled))
|
|
|
|
|
|
|
|
class DecoderBlock(nn.Module):
    """U-Net decoder stage: 2x bilinear upsample, optional skip concat, double conv."""

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        merged_channels = in_channels + skip_channels
        self.conv = nn.Sequential(
            nn.Conv2d(merged_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        """Upsample x; concatenate skip along channels when given; convolve."""
        up = self.upsample(x)
        merged = up if skip is None else torch.cat([up, skip], dim=1)
        return self.conv(merged)
|
|
|
|
|
|
|
|
class Decoder(nn.Module):
    """U-Net decoder: four skip-merging stages, pixel-shuffle upsample, 2-channel tanh head."""

    def __init__(self):
        super().__init__()
        # Channel widths mirror the ResNet-34 encoder stages.
        self.dec4 = DecoderBlock(512, 256, 256)
        self.dec3 = DecoderBlock(256, 128, 128)
        self.dec2 = DecoderBlock(128, 64, 64)
        self.dec1 = DecoderBlock(64, 64, 64)
        self.pixel_shuffle = PixelShuffleICNR(64, 16, scale=2)
        self.final = nn.Conv2d(16, 2, kernel_size=3, padding=1)

    def forward(self, x5, x4, x3, x2, x1):
        """Fuse encoder features coarse-to-fine; return a tanh-bounded 2-channel map.

        Rebinding `out` (plus the explicit dels of consumed skips) releases
        intermediates early to keep peak memory down.
        """
        out = self.dec4(x5, x4)
        out = self.dec3(out, x3)
        del x4, x3
        out = self.dec2(out, x2)
        del x2
        out = self.dec1(out, x1)
        del x1
        out = self.final(self.pixel_shuffle(out))
        return torch.tanh(out)
|
|
|
|
|
|
|
|
class UNet(nn.Module):
    """End-to-end model: frozen pretrained ResNet-34 encoder feeding the decoder.

    Maps a 1-channel input to a 2-channel tanh-bounded output at the input
    resolution.
    """

    def __init__(self):
        super().__init__()
        self.encoder = ResNetEncoder()
        self.decoder = Decoder()

    def forward(self, x):
        # Encoder returns features fine-to-coarse; decoder consumes them
        # coarse-to-fine.
        stem, f1, f2, f3, f4 = self.encoder(x)
        return self.decoder(f4, f3, f2, f1, stem)
|
|
|