| import torch | |
| from torch import nn | |
| from .modules import Conv2dBlock, Concat | |
class SkipEncoderDecoder(nn.Module):
    """Encoder/decoder network with optional skip connections at every level.

    The network is built level by level. Level ``i`` contains everything
    deeper than it, so level 0 is the outermost level and the last entry of
    ``num_channels_down`` is the innermost (bottleneck) level.

    Each level consists of:

    * a ``deeper`` branch: ``Conv2dBlock(in, down[i], 3, 2)``,
      ``Conv2dBlock(down[i], down[i], 3)``, the next level's container
      (omitted at the bottleneck), then a 2x nearest ``nn.Upsample``.
      The stride-2 first conv presumably halves the resolution; this depends
      on ``Conv2dBlock``'s positional args, which are defined in ``.modules``
      (confirm).
    * a ``skip`` branch, present only when ``num_channels_skip[i] != 0``:
      ``Conv2dBlock(in, skip[i], 1)``. It is combined with ``deeper`` through
      ``Concat(1, skip, deeper)``, presumably a channel-dim concat (confirm).
    * ``BatchNorm2d`` over the combined channels, followed by a 3x3 and a
      1x1 ``Conv2dBlock`` producing ``num_channels_up[i]`` channels.

    The outermost level is followed by a 1x1 ``nn.Conv2d`` to
    ``num_output_channels`` channels and a ``Sigmoid``.

    Args:
        input_depth: channel count of the network input.
        num_channels_down: encoder conv output channels per level.
        num_channels_up: decoder conv output channels per level.
        num_channels_skip: skip-branch channels per level; ``0`` disables the
            skip branch of that level.
        num_output_channels: channels of the final output (default ``3``).

    Raises:
        ValueError: if ``num_channels_down`` is empty, or if
            ``num_channels_up`` / ``num_channels_skip`` have fewer entries
            than ``num_channels_down`` (extra trailing entries are ignored).
    """

    def __init__(self, input_depth, num_channels_down=(128,) * 5,
                 num_channels_up=(128,) * 5, num_channels_skip=(128,) * 5,
                 num_output_channels=3):
        n_levels = len(num_channels_down)
        if n_levels == 0:
            raise ValueError("num_channels_down must contain at least one level")
        if len(num_channels_up) < n_levels or len(num_channels_skip) < n_levels:
            raise ValueError(
                "num_channels_up and num_channels_skip need at least "
                "len(num_channels_down) entries"
            )
        # Zero-arg super() does not depend on the class name, which was
        # previously misspelled and made construction fail with NameError.
        super().__init__()

        def _append(seq, module):
            # Keys "1", "2", ... match the original naming, so state_dict
            # layouts (and saved checkpoints) stay compatible.
            seq.add_module(str(len(seq) + 1), module)

        last = n_levels - 1
        self.model = nn.Sequential()
        # Container receiving the current level's layers; it descends one
        # level per iteration (set to that level's `deeper_main`).
        level = self.model

        for i in range(n_levels):
            deeper = nn.Sequential()
            skip = nn.Sequential()
            has_skip = num_channels_skip[i] != 0
            # Channels coming out of `deeper`: the bottleneck emits its own
            # down-channels, every other level emits the next level's
            # up-channels.
            k = num_channels_down[i] if i == last else num_channels_up[i + 1]

            # `skip`/`deeper` are registered now and filled below; they are
            # the same objects, so later additions are part of the model.
            _append(level, Concat(1, skip, deeper) if has_skip else deeper)
            _append(level, nn.BatchNorm2d(num_channels_skip[i] + k))

            # Layer creation order is kept identical to the original, so
            # parameter initialisation consumes the RNG in the same order.
            if has_skip:
                _append(skip, Conv2dBlock(input_depth, num_channels_skip[i], 1, bias=False))

            _append(deeper, Conv2dBlock(input_depth, num_channels_down[i], 3, 2, bias=False))
            _append(deeper, Conv2dBlock(num_channels_down[i], num_channels_down[i], 3, bias=False))

            deeper_main = nn.Sequential()
            if i != last:
                _append(deeper, deeper_main)
            _append(deeper, nn.Upsample(scale_factor=2, mode='nearest'))

            _append(level, Conv2dBlock(num_channels_skip[i] + k, num_channels_up[i], 3, 1, bias=False))
            _append(level, Conv2dBlock(num_channels_up[i], num_channels_up[i], 1, bias=False))

            input_depth = num_channels_down[i]
            level = deeper_main

        _append(self.model, nn.Conv2d(num_channels_up[0], num_output_channels, 1, bias=True))
        _append(self.model, nn.Sigmoid())

    def forward(self, x):
        """Run the full encoder/decoder on ``x`` and return the sigmoid output."""
        return self.model(x)


# Backward-compatible alias for the previously misspelled class name.
SkipEnrgcoderDecoder = SkipEncoderDecoder
def input_noise(INPUT_DEPTH, spatial_size, scale=1. / 10):
    """Create a random noise tensor to feed the network as input.

    Args:
        INPUT_DEPTH: number of channels of the noise.
        spatial_size: indexable object; its items ``[0]`` and ``[1]`` give
            the last two dimensions (presumably height and width).
        scale: multiplier applied to the noise (default ``0.1``).

    Returns:
        Tensor of shape ``(1, INPUT_DEPTH, spatial_size[0], spatial_size[1])``
        holding ``torch.rand`` samples (uniform on ``[0, 1)``) times ``scale``.
    """
    dim_a, dim_b = spatial_size[0], spatial_size[1]
    noise = torch.rand(1, INPUT_DEPTH, dim_a, dim_b)
    return noise * scale