Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
|
@@ -17,7 +17,6 @@ class PretrainedUNet(nn.Module):
|
|
| 17 |
super().__init__()
|
| 18 |
|
| 19 |
self.base_model = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
|
| 20 |
-
# Modify the first convolution layer to accept 6 channels (2x RGB images) instead of 3
|
| 21 |
self.base_model.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
| 22 |
|
| 23 |
self.encoder1 = nn.Sequential(self.base_model.conv1, self.base_model.bn1, self.base_model.relu)
|
|
@@ -42,18 +41,14 @@ class PretrainedUNet(nn.Module):
|
|
| 42 |
e3 = self.encoder3(e2)
|
| 43 |
e4 = self.encoder4(e3)
|
| 44 |
b = self.bottleneck(e4)
|
| 45 |
-
|
| 46 |
d4 = self.upconv4(b)
|
| 47 |
d4 = torch.cat([d4, e4], dim=1)
|
| 48 |
d4 = self.decoder4(d4)
|
| 49 |
-
|
| 50 |
d3 = self.upconv3(d4)
|
| 51 |
d3 = torch.cat([d3, e3], dim=1)
|
| 52 |
d3 = self.decoder3(d3)
|
| 53 |
-
|
| 54 |
d2 = self.upconv2(d3)
|
| 55 |
d2 = torch.cat([d2, e2], dim=1)
|
| 56 |
d2 = self.decoder2(d2)
|
| 57 |
-
|
| 58 |
d1 = self.final_upconv(d2)
|
| 59 |
return self.final_conv(d1)
|
|
|
|
| 17 |
super().__init__()
|
| 18 |
|
| 19 |
self.base_model = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
|
|
|
|
| 20 |
self.base_model.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
| 21 |
|
| 22 |
self.encoder1 = nn.Sequential(self.base_model.conv1, self.base_model.bn1, self.base_model.relu)
|
|
|
|
| 41 |
e3 = self.encoder3(e2)
|
| 42 |
e4 = self.encoder4(e3)
|
| 43 |
b = self.bottleneck(e4)
|
|
|
|
| 44 |
d4 = self.upconv4(b)
|
| 45 |
d4 = torch.cat([d4, e4], dim=1)
|
| 46 |
d4 = self.decoder4(d4)
|
|
|
|
| 47 |
d3 = self.upconv3(d4)
|
| 48 |
d3 = torch.cat([d3, e3], dim=1)
|
| 49 |
d3 = self.decoder3(d3)
|
|
|
|
| 50 |
d2 = self.upconv2(d3)
|
| 51 |
d2 = torch.cat([d2, e2], dim=1)
|
| 52 |
d2 = self.decoder2(d2)
|
|
|
|
| 53 |
d1 = self.final_upconv(d2)
|
| 54 |
return self.final_conv(d1)
|