Commit
·
269009c
1
Parent(s):
d984345
Upload 2 files
Browse files- generator.py +55 -0
- modules.py +63 -0
generator.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from .modules import Conv2dBlock, Concat
|
| 4 |
+
|
| 5 |
+
class SkipEncoderDecoder(nn.Module):
    """U-Net-style encoder-decoder with skip connections, built from
    Conv2dBlock units and nearest-neighbour upsampling.

    The network is assembled iteratively but reads recursively: each level
    wraps the next-deeper level inside a `Concat` of a skip branch and a
    deeper branch, so the whole model is one nested `nn.Sequential`.

    Args:
        input_depth: number of channels of the input tensor.
        num_channels_down: channels of the downsampling path, one per level.
        num_channels_up: channels of the upsampling path, one per level.
        num_channels_skip: channels of each skip branch (0 disables the skip
            at that level).

    NOTE(review): the list defaults are mutable, but they are only read,
    never mutated, so the shared-default pitfall does not bite here.
    """
    def __init__(self, input_depth, num_channels_down = [128] * 5, num_channels_up = [128] * 5, num_channels_skip = [128] * 5):
        super(SkipEncoderDecoder, self).__init__()

        self.model = nn.Sequential()
        # model_tmp points at the Sequential currently being filled; it is
        # re-pointed at the next-deeper Sequential at the end of each turn.
        model_tmp = self.model

        for i in range(len(num_channels_down)):

            deeper = nn.Sequential()
            skip = nn.Sequential()

            # Level entry: either a skip/deeper fork (channel-wise Concat) or
            # just the deeper branch when this level has no skip channels.
            if num_channels_skip[i] != 0:
                model_tmp.add_module(str(len(model_tmp) + 1), Concat(1, skip, deeper))
            else:
                model_tmp.add_module(str(len(model_tmp) + 1), deeper)

            # Normalize the merged features: skip channels plus whatever the
            # deeper branch emits (next level's up-channels, or this level's
            # down-channels at the deepest level).
            model_tmp.add_module(str(len(model_tmp) + 1), nn.BatchNorm2d(num_channels_skip[i] + (num_channels_up[i + 1] if i < (len(num_channels_down) - 1) else num_channels_down[i])))

            if num_channels_skip[i] != 0:
                # 1x1 projection feeding the skip branch.
                skip.add_module(str(len(skip) + 1), Conv2dBlock(input_depth, num_channels_skip[i], 1, bias = False))

            # Downsampling conv (stride 2) followed by a refining 3x3 conv.
            deeper.add_module(str(len(deeper) + 1), Conv2dBlock(input_depth, num_channels_down[i], 3, 2, bias = False))
            deeper.add_module(str(len(deeper) + 1), Conv2dBlock(num_channels_down[i], num_channels_down[i], 3, bias = False))

            deeper_main = nn.Sequential()

            if i == len(num_channels_down) - 1:
                # Deepest level: nothing below, this level's down-channels
                # feed the decoder directly.
                k = num_channels_down[i]
            else:
                # Recurse: the next level will be built inside deeper_main.
                deeper.add_module(str(len(deeper) + 1), deeper_main)
                k = num_channels_up[i + 1]

            # Undo this level's stride-2 downsampling.
            deeper.add_module(str(len(deeper) + 1), nn.Upsample(scale_factor = 2, mode = 'nearest'))

            # Decoder convs for this level, consuming skip + deeper channels.
            model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_skip[i] + k, num_channels_up[i], 3, 1, bias = False))
            model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_up[i], num_channels_up[i], 1, bias = False))

            # The next level sees this level's down-channels as its input.
            input_depth = num_channels_down[i]
            model_tmp = deeper_main

        # Final 1x1 conv to 3 output channels, squashed to [0, 1].
        self.model.add_module(str(len(self.model) + 1), nn.Conv2d(num_channels_up[0], 3, 1, bias = True))
        self.model.add_module(str(len(self.model) + 1), nn.Sigmoid())

    def forward(self, x):
        # Run the assembled encoder-decoder end to end.
        return self.model(x)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def input_noise(INPUT_DEPTH, spatial_size, scale = 1./10):
    """Create a random noise tensor to seed the generator network.

    Args:
        INPUT_DEPTH: number of channels in the generated tensor.
        spatial_size: (height, width) pair for the spatial dimensions.
        scale: multiplier applied to the uniform [0, 1) samples.

    Returns:
        A tensor of shape (1, INPUT_DEPTH, height, width) with values
        in [0, scale).
    """
    noise = torch.rand(1, INPUT_DEPTH, spatial_size[0], spatial_size[1])
    return noise * scale
|
modules.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
class DepthwiseSeperableConv2d(nn.Module):
    """Depthwise-separable convolution: a per-channel (grouped) conv followed
    by a 1x1 pointwise conv that mixes channels.

    Extra keyword arguments (kernel_size, stride, padding, bias, ...) are
    forwarded to the depthwise stage only; the pointwise stage always uses a
    1x1 kernel with nn.Conv2d defaults.
    """

    def __init__(self, input_channels, output_channels, **kwargs):
        super(DepthwiseSeperableConv2d, self).__init__()

        # groups == input_channels: each filter sees exactly one channel.
        self.depthwise = nn.Conv2d(input_channels, input_channels, groups = input_channels, **kwargs)
        self.pointwise = nn.Conv2d(input_channels, output_channels, kernel_size = 1)

    def forward(self, x):
        # Spatial filtering per channel, then channel mixing.
        return self.pointwise(self.depthwise(x))
|
| 17 |
+
|
| 18 |
+
class Conv2dBlock(nn.Module):
    """Reflection-padded depthwise-separable conv, then BatchNorm + LeakyReLU.

    Reflection padding of (kernel_size - 1) / 2 keeps the spatial size
    unchanged for stride 1 and odd kernel sizes; the conv itself runs with
    padding = 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, bias = False):
        super(Conv2dBlock, self).__init__()

        pad = int((kernel_size - 1) / 2)
        self.model = nn.Sequential(
            nn.ReflectionPad2d(pad),
            DepthwiseSeperableConv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = 0, bias = bias),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        return self.model(x)
|
| 31 |
+
|
| 32 |
+
class Concat(nn.Module):
    """Apply every child module to the same input and concatenate the outputs
    along `dim`, center-cropping each output to the smallest spatial size
    when the children disagree on height or width.
    """

    def __init__(self, dim, *args):
        super(Concat, self).__init__()
        self.dim = dim

        # Register children under string indices so they are tracked as
        # proper sub-modules.
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, input):
        outputs = [module(input) for module in self._modules.values()]

        heights = [out.shape[2] for out in outputs]
        widths = [out.shape[3] for out in outputs]

        target_h = min(heights)
        target_w = min(widths)

        if all(h == target_h for h in heights) and all(w == target_w for w in widths):
            cropped = outputs
        else:
            # Center-crop every output down to the smallest height/width.
            cropped = []
            for out in outputs:
                off_h = (out.size(2) - target_h) // 2
                off_w = (out.size(3) - target_w) // 2
                cropped.append(out[:, :, off_h: off_h + target_h, off_w: off_w + target_w])

        return torch.cat(cropped, dim=self.dim)

    def __len__(self):
        return len(self._modules)
|