keysun89 committed
Commit d6ba012 · verified · 1 Parent(s): d84a284

Create residual_unet.py

Files changed (1): residual_unet.py (+117, -0)
residual_unet.py ADDED
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin  # <-- 1. IMPORT THIS

class ResidualConvBlock(nn.Module):
    """
    A residual convolutional block consisting of two convolutional layers,
    batch normalization, ReLU activation, and a shortcut connection.
    """
    def __init__(self, in_channels, out_channels):
        super(ResidualConvBlock, self).__init__()
        # First convolutional layer
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Second convolutional layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Shortcut connection to match dimensions if in_channels != out_channels
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        # The output of the conv layers is added to the original input (shortcut)
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return self.relu(out)

# ↓
# ↓ 2. ADD THE MIXIN HERE
# ↓
class ResidualUNet(nn.Module, PyTorchModelHubMixin):
    """
    Residual U-Net architecture for semantic segmentation.
    The network consists of a contracting path (encoder) and an expansive path (decoder).
    Input size is assumed to be (448, 448, 3).
    """
    def __init__(self, in_channels=3, out_channels=1):
        super(ResidualUNet, self).__init__()

        # Save arguments so the mixin can serialize them to config.json
        self.in_channels = in_channels
        self.out_channels = out_channels

        # =====================================
        # Encoder (Contracting Path) - 5 levels
        # =====================================
        self.encoder1 = ResidualConvBlock(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = ResidualConvBlock(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = ResidualConvBlock(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = ResidualConvBlock(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder5 = ResidualConvBlock(512, 1024)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

        # =====================================
        # Bottleneck
        # =====================================
        self.bottleneck = ResidualConvBlock(1024, 2048)

        # =====================================
        # Decoder (Expansive Path) - 5 levels
        # =====================================
        self.upconv5 = nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2)
        self.decoder5 = ResidualConvBlock(1024 + 1024, 1024)  # Concatenates skip connection from encoder5
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = ResidualConvBlock(512 + 512, 512)  # Concatenates skip connection from encoder4
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = ResidualConvBlock(256 + 256, 256)  # Concatenates skip connection from encoder3
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = ResidualConvBlock(128 + 128, 128)  # Concatenates skip connection from encoder2
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = ResidualConvBlock(64 + 64, 64)  # Concatenates skip connection from encoder1

        # =====================================
        # Output Layer
        # =====================================
        self.outconv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path
        skip1 = self.encoder1(x)
        p1 = self.pool1(skip1)
        skip2 = self.encoder2(p1)
        p2 = self.pool2(skip2)
        skip3 = self.encoder3(p2)
        p3 = self.pool3(skip3)
        skip4 = self.encoder4(p3)
        p4 = self.pool4(skip4)
        skip5 = self.encoder5(p4)
        p5 = self.pool5(skip5)

        # Bottleneck
        b = self.bottleneck(p5)

        # Decoder path with skip connections
        d5 = self.upconv5(b)
        d5 = torch.cat((skip5, d5), dim=1)
        d5 = self.decoder5(d5)
        d4 = self.upconv4(d5)
        d4 = torch.cat((skip4, d4), dim=1)
        d4 = self.decoder4(d4)
        d3 = self.upconv3(d4)
        d3 = torch.cat((skip3, d3), dim=1)
        d3 = self.decoder3(d3)
        d2 = self.upconv2(d3)
        d2 = torch.cat((skip2, d2), dim=1)
        d2 = self.decoder2(d2)
        d1 = self.upconv1(d2)
        d1 = torch.cat((skip1, d1), dim=1)
        d1 = self.decoder1(d1)

        # Final output
        outputs = self.outconv(d1)
        return outputs
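
Below is a short usage sketch (not part of the commit) illustrating the two things the inline annotations point at: a forward-pass shape check for the assumed 448×448 input, and the serialization helpers that PyTorchModelHubMixin adds (save_pretrained, from_pretrained, push_to_hub are the mixin's documented methods). The local directory name and the repo id "keysun89/residual-unet" are hypothetical placeholders.

import torch

# Shape check: a 448x448 RGB batch maps to a single-channel mask of the same size.
model = ResidualUNet(in_channels=3, out_channels=1)
x = torch.randn(1, 3, 448, 448)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1, 448, 448])

# The mixin adds Hub-style serialization; the __init__ arguments
# are written alongside the weights in config.json.
model.save_pretrained("residual-unet-local")          # hypothetical local path
reloaded = ResidualUNet.from_pretrained("residual-unet-local")
# model.push_to_hub("keysun89/residual-unet")         # hypothetical repo id, requires auth

Note that because the encoder downsamples five times, input height and width should be divisible by 32; the assumed 448 works since 448 = 32 × 14.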