| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torchvision import models |
| |
|
| |
|
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages, optionally followed by Dropout2d.

    The 3x3 convolutions use padding=1, so the spatial size of the input is
    preserved; only the channel count changes (``in_channels`` -> ``out_channels``).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        dropout: If > 0, a ``Dropout2d(p=dropout)`` layer is appended after the
            second ReLU; otherwise no dropout layer is created.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        if dropout > 0:
            stages.append(nn.Dropout2d(p=dropout))
        # Kept as a single Sequential named ``net`` so parameter names are net.<idx>.*
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv/BN/ReLU stack (and dropout, if configured) to ``x``."""
        return self.net(x)
| |
|
| |
|
def crop_to_match(enc_feat, dec_feat):
    """Center-crop an encoder feature map to the spatial size of a decoder feature map.

    Dimensions 2 and 3 of ``enc_feat`` are cropped; dimensions 0 and 1 are kept
    unchanged. When the size difference along an axis is odd, the extra
    row/column is dropped from the bottom/right side (the offset is floored).

    Args:
        enc_feat: Tensor with at least 4 dimensions; dims 2 and 3 are cropped and
            must each be at least as large as the corresponding dim of ``dec_feat``.
        dec_feat: 4-D tensor whose dims 2 and 3 give the target size.

    Returns:
        A slice (view) of ``enc_feat`` whose dims 2 and 3 equal those of ``dec_feat``.

    Raises:
        ValueError: If ``enc_feat`` is smaller than ``dec_feat`` along either
            cropped axis. Without this check the crop offset would be negative
            and the slice would silently come out with the wrong size.
    """
    _, _, H, W = dec_feat.shape
    enc_H, enc_W = enc_feat.shape[2], enc_feat.shape[3]

    if enc_H < H or enc_W < W:
        raise ValueError(
            f"crop_to_match: encoder feature map ({enc_H}x{enc_W}) is smaller "
            f"than decoder feature map ({H}x{W}); cannot crop."
        )

    crop_top = (enc_H - H) // 2
    crop_left = (enc_W - W) // 2
    return enc_feat[:, :, crop_top:crop_top + H, crop_left:crop_left + W]
| |
|
| |
|
class UNet(nn.Module):
    """Four-level U-Net producing an output with the same spatial size as its input.

    Encoder: four ``DoubleConv`` stages (64, 128, 256, 512 channels) separated by
    2x2 max-pooling, then a 1024-channel ``DoubleConv`` bottleneck. Decoder: four
    2x2/stride-2 transposed convolutions; each upsampled map is concatenated
    (decoder features first) with the matching encoder map and refined by a
    ``DoubleConv``. A final 1x1 convolution maps 64 channels to ``out_channels``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        dropout: Dropout probability forwarded to every ``DoubleConv``.

    Height and width of the input must each be at least 16, because the input
    is max-pooled four times before the bottleneck.
    """

    def __init__(self, in_channels=1, out_channels=1, dropout=0.1):
        super().__init__()

        # Encoder (registration order kept stable so state_dict keys/order are unchanged).
        self.enc1 = DoubleConv(in_channels, 64, dropout=dropout)
        self.enc2 = DoubleConv(64, 128, dropout=dropout)
        self.enc3 = DoubleConv(128, 256, dropout=dropout)
        self.enc4 = DoubleConv(256, 512, dropout=dropout)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(512, 1024, dropout=dropout)

        # Decoder: each DoubleConv takes the upsampled map concatenated with the
        # skip connection, hence twice the channels of the upsampled map.
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512, dropout=dropout)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256, dropout=dropout)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128, dropout=dropout)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64, dropout=dropout)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Map a 4-D batch ``x`` (N, in_channels, H, W) to (N, out_channels, H, W)."""
        # Encoder path; each pooling halves the spatial size (flooring odd sizes).
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        # Decoder path with skip connections.
        d4 = self._up_and_merge(self.up4, self.dec4, b, e4)
        d3 = self._up_and_merge(self.up3, self.dec3, d4, e3)
        d2 = self._up_and_merge(self.up2, self.dec2, d3, e2)
        d1 = self._up_and_merge(self.up1, self.dec1, d2, e1)

        # d1 is aligned to e1, which has the input's spatial size, so no final
        # resize is required.
        return self.final(d1)

    @staticmethod
    def _up_and_merge(up, dec, x, skip):
        """Upsample ``x``, align it with ``skip``, concatenate (decoder first) and refine with ``dec``."""
        x = up(x)
        # MaxPool2d(2) drops the last row/column of odd-sized maps, and the
        # stride-2 transposed conv maps pixel i to rows/cols 2i and 2i+1. The
        # upsampled map therefore covers the top-left region of ``skip`` and may
        # be one row/column short. Pad on the bottom/right (rather than
        # center-cropping ``skip``) so both maps stay spatially aligned.
        diff_h = skip.shape[2] - x.shape[2]
        diff_w = skip.shape[3] - x.shape[3]
        if diff_h or diff_w:
            x = F.pad(x, [0, diff_w, 0, diff_h])
        return dec(torch.cat([x, skip], dim=1))
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|