faranbutt789 commited on
Commit
f729e4d
·
verified ·
1 Parent(s): 1646f5a

Update unet_model.py

Browse files
Files changed (1) hide show
  1. unet_model.py +77 -0
unet_model.py CHANGED
@@ -1,6 +1,83 @@
1
  import torch
2
  import torch.nn as nn
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  class UNet(nn.Module):
5
  def __init__(self):
6
  super(UNet, self).__init__()
 
1
  import torch
2
  import torch.nn as nn
3
 
4
+ class ImprovedUNet(nn.Module):
5
+ def __init__(self, dropout_rate=0.2):
6
+ super(ImprovedUNet, self).__init__()
7
+
8
+ def conv_block(in_channels, out_channels, dropout=False):
9
+ layers = [
10
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
11
+ nn.BatchNorm2d(out_channels),
12
+ nn.ReLU(inplace=True),
13
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
14
+ nn.BatchNorm2d(out_channels),
15
+ nn.ReLU(inplace=True),
16
+ ]
17
+ if dropout:
18
+ layers.append(nn.Dropout2d(dropout_rate))
19
+ return nn.Sequential(*layers)
20
+
21
+ # Encoder
22
+ self.enc1 = conv_block(3, 64)
23
+ self.enc2 = conv_block(64, 128)
24
+ self.enc3 = conv_block(128, 256, dropout=True)
25
+ self.enc4 = conv_block(256, 512, dropout=True)
26
+
27
+ self.pool = nn.MaxPool2d(2)
28
+
29
+ # Bottleneck
30
+ self.bottleneck = conv_block(512, 1024, dropout=True)
31
+
32
+ # Decoder
33
+ self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
34
+ self.dec4 = conv_block(1024, 512, dropout=True)
35
+
36
+ self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
37
+ self.dec3 = conv_block(512, 256, dropout=True)
38
+
39
+ self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
40
+ self.dec2 = conv_block(256, 128)
41
+
42
+ self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
43
+ self.dec1 = conv_block(128, 64)
44
+
45
+ self.conv_last = nn.Conv2d(64, 1, kernel_size=1)
46
+
47
+ def forward(self, x):
48
+ c1 = self.enc1(x)
49
+ p1 = self.pool(c1)
50
+
51
+ c2 = self.enc2(p1)
52
+ p2 = self.pool(c2)
53
+
54
+ c3 = self.enc3(p2)
55
+ p3 = self.pool(c3)
56
+
57
+ c4 = self.enc4(p3)
58
+ p4 = self.pool(c4)
59
+
60
+ bottleneck = self.bottleneck(p4)
61
+
62
+ u4 = self.upconv4(bottleneck)
63
+ u4 = torch.cat([u4, c4], dim=1)
64
+ d4 = self.dec4(u4)
65
+
66
+ u3 = self.upconv3(d4)
67
+ u3 = torch.cat([u3, c3], dim=1)
68
+ d3 = self.dec3(u3)
69
+
70
+ u2 = self.upconv2(d3)
71
+ u2 = torch.cat([u2, c2], dim=1)
72
+ d2 = self.dec2(u2)
73
+
74
+ u1 = self.upconv1(d2)
75
+ u1 = torch.cat([u1, c1], dim=1)
76
+ d1 = self.dec1(u1)
77
+
78
+ return torch.sigmoid(self.conv_last(d1))
79
+
80
+ # For backward compatibility, keep the original UNet class as well
81
  class UNet(nn.Module):
82
  def __init__(self):
83
  super(UNet, self).__init__()