danicor commited on
Commit
58a5967
ยท
verified ยท
1 Parent(s): bc5c2cc

Update unet.py

Browse files
Files changed (1) hide show
  1. unet.py +65 -63
unet.py CHANGED
@@ -1,65 +1,67 @@
1
- import torch
2
  import torch.nn as nn
 
3
 
4
- class UNet(nn.Module):
5
- def __init__(self, n_channels=3, n_classes=19):
6
- super(UNet, self).__init__()
7
-
8
- def CBR(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
9
- return nn.Sequential(
10
- nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
11
- nn.BatchNorm2d(out_channels),
12
- nn.ReLU(inplace=True)
13
- )
14
-
15
- # Encoder
16
- self.enc1 = nn.Sequential(CBR(n_channels, 64), CBR(64, 64))
17
- self.enc2 = nn.Sequential(nn.MaxPool2d(2), CBR(64, 128), CBR(128, 128))
18
- self.enc3 = nn.Sequential(nn.MaxPool2d(2), CBR(128, 256), CBR(256, 256))
19
- self.enc4 = nn.Sequential(nn.MaxPool2d(2), CBR(256, 512), CBR(512, 512))
20
- self.enc5 = nn.Sequential(nn.MaxPool2d(2), CBR(512, 1024), CBR(1024, 1024))
21
-
22
- # Decoder
23
- self.dec4 = nn.Sequential(CBR(1024+512, 512), CBR(512, 512))
24
- self.dec3 = nn.Sequential(CBR(512+256, 256), CBR(256, 256))
25
- self.dec2 = nn.Sequential(CBR(256+128, 128), CBR(128, 128))
26
- self.dec1 = nn.Sequential(CBR(128+64, 64), CBR(64, 64))
27
-
28
- # Upsampling
29
- self.up4 = nn.ConvTranspose2d(1024, 512, 2, 2)
30
- self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
31
- self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
32
- self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
33
-
34
- # Final layer
35
- self.final = nn.Conv2d(64, n_classes, 1)
36
-
37
- def forward(self, x):
38
- # Encoder
39
- e1 = self.enc1(x)
40
- e2 = self.enc2(e1)
41
- e3 = self.enc3(e2)
42
- e4 = self.enc4(e3)
43
- e5 = self.enc5(e4)
44
-
45
- # Decoder with skip connections
46
- d4 = self.up4(e5)
47
- d4 = torch.cat([d4, e4], dim=1)
48
- d4 = self.dec4(d4)
49
-
50
- d3 = self.up3(d4)
51
- d3 = torch.cat([d3, e3], dim=1)
52
- d3 = self.dec3(d3)
53
-
54
- d2 = self.up2(d3)
55
- d2 = torch.cat([d2, e2], dim=1)
56
- d2 = self.dec2(d2)
57
-
58
- d1 = self.up1(d2)
59
- d1 = torch.cat([d1, e1], dim=1)
60
- d1 = self.dec1(d1)
61
-
62
- return self.final(d1)
63
-
64
- def unet(**kwargs):
65
- return UNet(**kwargs)
 
 
 
 
1
  import torch.nn as nn
2
+ from model_utils import *
3
 
4
+ class unet(nn.Module):
5
+ def __init__(
6
+ self,
7
+ feature_scale=4,
8
+ n_classes=19,
9
+ is_deconv=True,
10
+ in_channels=3,
11
+ is_batchnorm=True,
12
+ ):
13
+ super(unet, self).__init__()
14
+ self.is_deconv = is_deconv
15
+ self.in_channels = in_channels
16
+ self.is_batchnorm = is_batchnorm
17
+ self.feature_scale = feature_scale
18
+
19
+ filters = [64, 128, 256, 512, 1024]
20
+ filters = [int(x / self.feature_scale) for x in filters]
21
+
22
+ # downsampling
23
+ self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
24
+ self.maxpool1 = nn.MaxPool2d(kernel_size=2)
25
+
26
+ self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
27
+ self.maxpool2 = nn.MaxPool2d(kernel_size=2)
28
+
29
+ self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
30
+ self.maxpool3 = nn.MaxPool2d(kernel_size=2)
31
+
32
+ self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
33
+ self.maxpool4 = nn.MaxPool2d(kernel_size=2)
34
+
35
+ self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
36
+
37
+ # upsampling
38
+ self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv, self.is_batchnorm)
39
+ self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv, self.is_batchnorm)
40
+ self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv, self.is_batchnorm)
41
+ self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv, self.is_batchnorm)
42
+
43
+ # final conv (without any concat)
44
+ self.final = nn.Conv2d(filters[0], n_classes, 1)
45
+
46
+ def forward(self, inputs):
47
+ conv1 = self.conv1(inputs)
48
+ maxpool1 = self.maxpool1(conv1)
49
+
50
+ conv2 = self.conv2(maxpool1)
51
+ maxpool2 = self.maxpool2(conv2)
52
+
53
+ conv3 = self.conv3(maxpool2)
54
+ maxpool3 = self.maxpool3(conv3)
55
+
56
+ conv4 = self.conv4(maxpool3)
57
+ maxpool4 = self.maxpool4(conv4)
58
+
59
+ center = self.center(maxpool4)
60
+ up4 = self.up_concat4(conv4, center)
61
+ up3 = self.up_concat3(conv3, up4)
62
+ up2 = self.up_concat2(conv2, up3)
63
+ up1 = self.up_concat1(conv1, up2)
64
+
65
+ final = self.final(up1)
66
+
67
+ return final