mustafa2ak committed on
Commit
5f9b7b7
·
verified ·
1 Parent(s): 622fd36

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +59 -0
model.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.models import resnet34, ResNet34_Weights
4
+
5
def conv_block(in_channels, out_channels):
    """Build a double-convolution stage: two (Conv3x3 -> BatchNorm -> ReLU) units.

    The 3x3 convolutions use padding=1, so spatial size is preserved; bias is
    omitted because each conv is immediately followed by BatchNorm.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.

    Returns:
        nn.Sequential of six layers (conv, bn, relu, conv, bn, relu).
    """
    layers = []
    for c_in, c_out in ((in_channels, out_channels), (out_channels, out_channels)):
        layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False))
        layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
14
+
15
class PretrainedUNet(nn.Module):
    """U-Net with an ImageNet-pretrained ResNet-34 encoder over a 6-channel input.

    Takes two RGB images (stacked along the channel axis in ``forward``) and
    produces a single-channel map at the input resolution. ResNet's maxpool is
    intentionally skipped, so layer1..layer4 run at strides 2/4/8/16 rather
    than the usual 4/8/16/32, which is what makes the skip connections line up
    with the three transposed-conv upsampling steps plus the final one.

    NOTE(review): assumes img1/img2 are 3-channel tensors whose spatial dims
    are divisible by 16 — confirm against the data pipeline.
    """

    def __init__(self):
        super().__init__()

        self.base_model = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)

        # Replace the first conv to accept 6 channels (2x RGB images) instead
        # of 3. Rather than discarding the pretrained stem and starting from a
        # random init, replicate the pretrained 3-channel kernels across both
        # image slots and halve them, so the expected activation magnitude of
        # the stem matches the original pretrained network.
        pretrained_conv1 = self.base_model.conv1
        new_conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            new_conv1.weight.copy_(pretrained_conv1.weight.repeat(1, 2, 1, 1) * 0.5)
        self.base_model.conv1 = new_conv1

        # Encoder taps (no maxpool — see class docstring for the strides).
        self.encoder1 = nn.Sequential(self.base_model.conv1, self.base_model.bn1, self.base_model.relu)
        self.encoder2 = self.base_model.layer1   # 64 ch, stride 2
        self.encoder3 = self.base_model.layer2   # 128 ch, stride 4
        self.encoder4 = self.base_model.layer3   # 256 ch, stride 8
        self.bottleneck = self.base_model.layer4  # 512 ch, stride 16

        # Decoder: upsample by 2, concat the matching encoder feature map,
        # then fuse with a double-conv block.
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = conv_block(256 + 256, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = conv_block(128 + 128, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = conv_block(64 + 64, 64)
        self.final_upconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, img1, img2):
        """Run both images through the shared encoder-decoder.

        Args:
            img1: first RGB image batch, shape (B, 3, H, W).
            img2: second RGB image batch, same shape as img1.

        Returns:
            Tensor of shape (B, 1, H, W) — raw logits (no activation applied).
        """
        x = torch.cat([img1, img2], dim=1)
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        b = self.bottleneck(e4)

        d4 = self.upconv4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.decoder4(d4)

        d3 = self.upconv3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.decoder3(d3)

        d2 = self.upconv2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.decoder2(d2)

        d1 = self.final_upconv(d2)
        return self.final_conv(d1)