danicor commited on
Commit
0e96e5f
·
verified ·
1 Parent(s): 58a5967

Create model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +48 -0
model_utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class unetConv2(nn.Module):
    """Double 3x3 convolution block used throughout the U-Net.

    Applies two (Conv2d 3x3, stride 1, padding 1) -> [BatchNorm2d] -> ReLU
    stages in sequence, so spatial dimensions are preserved and the channel
    count goes in_size -> out_size -> out_size.

    Args:
        in_size: number of input channels.
        out_size: number of output channels for both stages.
        is_batchnorm: insert BatchNorm2d between each conv and its ReLU.
    """

    def __init__(self, in_size, out_size, is_batchnorm):
        super(unetConv2, self).__init__()
        # Stage 1 changes the channel count; stage 2 keeps it fixed.
        self.conv1 = self._make_stage(in_size, out_size, is_batchnorm)
        self.conv2 = self._make_stage(out_size, out_size, is_batchnorm)

    @staticmethod
    def _make_stage(c_in, c_out, with_bn):
        # Build one conv stage; BatchNorm is optional but ReLU always follows.
        layers = [nn.Conv2d(c_in, c_out, 3, 1, 1)]
        if with_bn:
            layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Run both conv stages; output has shape (N, out_size, H, W)."""
        return self.conv2(self.conv1(inputs))
32
+
33
+
34
class unetUp(nn.Module):
    """U-Net decoder block: upsample the deep feature map, pad the skip
    connection to match it spatially, concatenate along channels, then fuse
    with a double conv.

    Args:
        in_size: channel count of the concatenated tensor fed to the conv.
        out_size: output channel count.
        is_deconv: use a learned ConvTranspose2d for 2x upsampling (also
            reduces channels in_size -> out_size); otherwise use fixed
            bilinear upsampling (channels unchanged).
        is_batchnorm: forwarded to the fusing unetConv2 block.
    """

    def __init__(self, in_size, out_size, is_deconv, is_batchnorm):
        super(unetUp, self).__init__()
        # Fuses the concatenated (skip + upsampled) features down to out_size.
        self.conv = unetConv2(in_size, out_size, is_batchnorm)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        """Upsample `inputs2`, pad `inputs1` to the same spatial size,
        concatenate on the channel dim, and apply the double conv.

        Assumes the upsampled map is at least as large as the skip map in
        both spatial dimensions (offset >= 0) — TODO confirm against callers.
        """
        outputs2 = self.up(inputs2)
        # BUG FIX: the original used `2 * [offset // 2, offset // 2]`, which
        # under-pads by one pixel per dimension when the size offset is odd
        # (and assumed height/width offsets are equal), making torch.cat
        # fail. Pad each dimension asymmetrically so the remainder pixel
        # goes on the high side; behavior is unchanged for even offsets.
        dh = outputs2.size(2) - inputs1.size(2)
        dw = outputs2.size(3) - inputs1.size(3)
        outputs1 = F.pad(inputs1, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return self.conv(torch.cat([outputs1, outputs2], 1))