Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision.models import resnet34, ResNet34_Weights | |
def conv_block(in_channels, out_channels):
    """Build a double 3x3 convolution unit for the U-Net decoder.

    Two identical stages are stacked, each being Conv2d(3x3, padding=1,
    no bias) -> BatchNorm2d -> in-place ReLU. The first stage maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    Padding of 1 preserves the spatial size.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both stages.

    Returns:
        An ``nn.Sequential`` with six modules (indices 0-5).
    """
    stages = []
    for stage_in in (in_channels, out_channels):
        stages.extend([
            # BatchNorm follows, so a conv bias would be redundant.
            nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ])
    return nn.Sequential(*stages)
class PretrainedUNet(nn.Module):
    """U-Net that fuses two images and predicts a single-channel map.

    The two inputs are concatenated along the channel axis and encoded by a
    ResNet-34 backbone whose stem convolution is widened to 6 input channels.
    A transposed-convolution decoder with skip connections restores the
    input resolution, and a 1x1 convolution produces one output channel.

    Args:
        pretrained: Load ``ResNet34_Weights.IMAGENET1K_V1`` into the encoder
            (the default, matching the previous behaviour). Pass ``False`` to
            build a randomly initialised encoder, e.g. offline or when a
            checkpoint is loaded right afterwards.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        weights = ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = resnet34(weights=weights)

        # The stock stem takes 3 channels; the concatenated image pair has 6.
        old_stem = self.base_model.conv1.weight.detach()
        self.base_model.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Keep the ImageNet stem filters instead of discarding them: apply
            # them to both images and halve them, so two identical inputs give
            # the same stem response as the original 3-channel layer.
            # NOTE(review): assumes each image is 3-channel in ImageNet channel
            # order (channels 0-2 = img1, 3-5 = img2) -- confirm with callers.
            with torch.no_grad():
                self.base_model.conv1.weight.copy_(old_stem.repeat(1, 2, 1, 1) * 0.5)

        # Encoder. The ResNet maxpool is deliberately not used, so the stem
        # output stays at 1/2 resolution. base_model.maxpool/avgpool/fc stay
        # registered (state_dict keys unchanged) but never run in forward; under
        # DistributedDataParallel this may require find_unused_parameters=True.
        self.encoder1 = nn.Sequential(self.base_model.conv1, self.base_model.bn1, self.base_model.relu)
        self.encoder2 = self.base_model.layer1  # 64 ch
        self.encoder3 = self.base_model.layer2  # 128 ch
        self.encoder4 = self.base_model.layer3  # 256 ch
        self.bottleneck = self.base_model.layer4  # 512 ch

        # Decoder: each stage doubles the resolution, concatenates the matching
        # encoder features, then fuses them with a double 3x3 conv.
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = conv_block(256 + 256, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = conv_block(128 + 128, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = conv_block(64 + 64, 64)
        self.final_upconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(32, 1, kernel_size=1)

    @staticmethod
    def _match_size(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``x`` to ``ref``'s spatial size; no-op if equal."""
        if x.shape[-2:] != ref.shape[-2:]:
            x = nn.functional.interpolate(x, size=ref.shape[-2:], mode="bilinear", align_corners=False)
        return x

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Predict a single-channel map for an image pair.

        Args:
            img1: Tensor of shape (N, C1, H, W).
            img2: Tensor of shape (N, C2, H, W) with C1 + C2 == 6.

        Returns:
            Tensor of shape (N, 1, H, W). Any H and W are accepted: when they
            are not multiples of 16, the stride-2 stages round sizes up, so
            decoder features are resized to their skip tensors and the output
            to the input size. For multiples of 16 these resizes are no-ops.
        """
        x = torch.cat([img1, img2], dim=1)
        e1 = self.encoder1(x)  # not used as a skip connection
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        b = self.bottleneck(e4)

        d4 = self._match_size(self.upconv4(b), e4)
        d4 = self.decoder4(torch.cat([d4, e4], dim=1))
        d3 = self._match_size(self.upconv3(d4), e3)
        d3 = self.decoder3(torch.cat([d3, e3], dim=1))
        d2 = self._match_size(self.upconv2(d3), e2)
        d2 = self.decoder2(torch.cat([d2, e2], dim=1))

        out = self.final_conv(self.final_upconv(d2))
        return self._match_size(out, x)