|
|
|
|
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import models |
|
|
|
|
|
|
|
|
|
|
class UNet(nn.Module):
    """Four-level U-Net for dense (per-pixel) prediction.

    Encoder: four stages of two 3x3 Conv-BN-ReLU blocks (64, 128, 256, 512
    channels), each stage's output max-pooled by 2 before the next stage.
    Bottleneck ("center"): 512 -> 1024 -> 512 channels.
    Decoder: four 2x transposed-conv upsamplings; each result is concatenated
    with the matching encoder feature map (skip connection) and refined by two
    Conv-BN-ReLU blocks. The last decoder stage ends in a 1x1 conv producing
    ``out_channels`` maps (raw logits, no activation).

    The output has the same height and width as the input for any input of at
    least 16x16 (four 2x poolings need a non-empty map); sizes do not need to
    be multiples of 16.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        def CBR(in_c, out_c):
            # 3x3 conv with padding=1 keeps H and W unchanged.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            )

        # Encoder. Attribute names and creation order are part of the
        # state_dict layout, so keep them stable.
        self.enc1 = nn.Sequential(CBR(in_channels, 64), CBR(64, 64))
        self.enc2 = nn.Sequential(CBR(64, 128), CBR(128, 128))
        self.enc3 = nn.Sequential(CBR(128, 256), CBR(256, 256))
        self.enc4 = nn.Sequential(CBR(256, 512), CBR(512, 512))

        self.pool = nn.MaxPool2d(2, 2)
        self.center = nn.Sequential(CBR(512, 1024), CBR(1024, 512))

        # Decoder. Each decN consumes upsampled channels + skip channels.
        self.up4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.dec4 = nn.Sequential(CBR(1024, 512), CBR(512, 256))
        self.up3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
        self.dec3 = nn.Sequential(CBR(512, 256), CBR(256, 128))
        self.up2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
        self.dec2 = nn.Sequential(CBR(256, 128), CBR(128, 64))
        self.up1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
        self.dec1 = nn.Sequential(CBR(128, 64), nn.Conv2d(64, out_channels, 1))

    def crop_to_match(self, x, target):
        """Return ``x`` cropped to the height/width of ``target``.

        The crop is anchored at the top-left corner. It is kept for backward
        compatibility; ``forward`` aligns feature maps with ``_match_size``
        instead, so no encoder pixels are discarded.
        """
        _, _, h, w = target.size()
        return x[:, :, :h, :w]

    @staticmethod
    def _match_size(x, target):
        """Pad (or crop) ``x`` so its H and W equal those of ``target``.

        The size difference is split evenly between both sides; an odd extra
        pixel goes to the bottom/right. Negative amounts passed to ``F.pad``
        crop instead of pad.
        """
        dh = target.size(2) - x.size(2)
        dw = target.size(3) - x.size(3)
        if dh == 0 and dw == 0:
            return x
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def _up_and_decode(self, x, skip, up, dec):
        """Upsample ``x``, align it to ``skip``, concatenate, and decode."""
        x = self._match_size(up(x), skip)
        return dec(torch.cat([x, skip], dim=1))

    def forward(self, x):
        """Run the U-Net on ``x`` (N, in_channels, H, W).

        Returns a tensor of shape (N, out_channels, H, W).
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        c = self.center(self.pool(e4))

        # Pooling floors odd sizes, so an upsampled map can be one pixel
        # smaller than its skip; _match_size pads it back so the final
        # output keeps the input's spatial size.
        d4 = self._up_and_decode(c, e4, self.up4, self.dec4)
        d3 = self._up_and_decode(d4, e3, self.up3, self.dec3)
        d2 = self._up_and_decode(d3, e2, self.up2, self.dec2)
        return self._up_and_decode(d2, e1, self.up1, self.dec1)
| |
|
|
|
|
| |
| |
| |
def get_deeplabv3(in_channels=1, out_channels=1, *, aux_loss=None):
    """Build a randomly initialised torchvision DeepLabV3-ResNet50 model.

    Neither the segmentation model nor the ResNet-50 backbone loads
    pretrained weights. torchvision builds the segmentation head (and the
    auxiliary head, when requested) with ``out_channels`` outputs, so no head
    replacement is needed here.

    Args:
        in_channels: Number of input image channels. When it differs from the
            backbone stem's input channel count, the first conv
            (``backbone.conv1``) is replaced with one accepting
            ``in_channels``. All other stem hyperparameters are kept.
        out_channels: Number of output channels/classes (torchvision's
            ``num_classes``).
        aux_loss: Forwarded to torchvision. Pass True to also build the
            auxiliary classifier. The default ``None`` keeps the previous
            behaviour (no auxiliary head).

    Returns:
        The torchvision DeepLabV3 model.

    NOTE: per torchvision's API, the model's forward returns a dict
    (``"out"``, plus ``"aux"`` when an auxiliary head exists), not a bare
    tensor like ``UNet`` in this module.
    """
    model = models.segmentation.deeplabv3_resnet50(
        weights=None,
        weights_backbone=None,
        num_classes=out_channels,
        aux_loss=aux_loss,
    )

    stem = model.backbone.conv1
    if stem.in_channels != in_channels:
        new_stem = nn.Conv2d(
            in_channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=stem.bias is not None,
        )
        # Use the same Conv2d initialisation torchvision's ResNet applies to
        # the rest of the backbone, rather than PyTorch's generic default.
        nn.init.kaiming_normal_(new_stem.weight, mode="fan_out", nonlinearity="relu")
        model.backbone.conv1 = new_stem

    return model