danicor commited on
Commit
2a86efa
ยท
verified ยท
1 Parent(s): 15e7622

Create unet.py

Browse files
Files changed (1) hide show
  1. unet.py +65 -0
unet.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class UNet(nn.Module):
5
+ def __init__(self, n_channels=3, n_classes=19):
6
+ super(UNet, self).__init__()
7
+
8
+ def CBR(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
9
+ return nn.Sequential(
10
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
11
+ nn.BatchNorm2d(out_channels),
12
+ nn.ReLU(inplace=True)
13
+ )
14
+
15
+ # Encoder
16
+ self.enc1 = nn.Sequential(CBR(n_channels, 64), CBR(64, 64))
17
+ self.enc2 = nn.Sequential(nn.MaxPool2d(2), CBR(64, 128), CBR(128, 128))
18
+ self.enc3 = nn.Sequential(nn.MaxPool2d(2), CBR(128, 256), CBR(256, 256))
19
+ self.enc4 = nn.Sequential(nn.MaxPool2d(2), CBR(256, 512), CBR(512, 512))
20
+ self.enc5 = nn.Sequential(nn.MaxPool2d(2), CBR(512, 1024), CBR(1024, 1024))
21
+
22
+ # Decoder
23
+ self.dec4 = nn.Sequential(CBR(1024+512, 512), CBR(512, 512))
24
+ self.dec3 = nn.Sequential(CBR(512+256, 256), CBR(256, 256))
25
+ self.dec2 = nn.Sequential(CBR(256+128, 128), CBR(128, 128))
26
+ self.dec1 = nn.Sequential(CBR(128+64, 64), CBR(64, 64))
27
+
28
+ # Upsampling
29
+ self.up4 = nn.ConvTranspose2d(1024, 512, 2, 2)
30
+ self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
31
+ self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
32
+ self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
33
+
34
+ # Final layer
35
+ self.final = nn.Conv2d(64, n_classes, 1)
36
+
37
+ def forward(self, x):
38
+ # Encoder
39
+ e1 = self.enc1(x)
40
+ e2 = self.enc2(e1)
41
+ e3 = self.enc3(e2)
42
+ e4 = self.enc4(e3)
43
+ e5 = self.enc5(e4)
44
+
45
+ # Decoder with skip connections
46
+ d4 = self.up4(e5)
47
+ d4 = torch.cat([d4, e4], dim=1)
48
+ d4 = self.dec4(d4)
49
+
50
+ d3 = self.up3(d4)
51
+ d3 = torch.cat([d3, e3], dim=1)
52
+ d3 = self.dec3(d3)
53
+
54
+ d2 = self.up2(d3)
55
+ d2 = torch.cat([d2, e2], dim=1)
56
+ d2 = self.dec2(d2)
57
+
58
+ d1 = self.up1(d2)
59
+ d1 = torch.cat([d1, e1], dim=1)
60
+ d1 = self.dec1(d1)
61
+
62
+ return self.final(d1)
63
+
64
+ def unet(**kwargs):
65
+ return UNet(**kwargs)