# File size: 2,464 bytes; revision 7b615ae
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class ResNetUNet(nn.Module):
    """U-Net style network built on a torchvision ResNet-50 encoder.

    The ResNet-50 stem convolution is replaced by a freshly initialised
    single-channel convolution, so the network consumes 1-channel input.
    The decoder upsamples with transposed convolutions and merges encoder
    features by element-wise addition (not concatenation); every merge is
    followed by a 3x3 conv + batch norm + ReLU block.

    Inputs whose height/width are not multiples of 32 are supported: decoder
    features are resized to the matching encoder feature map before each
    addition, and the logits are resized to the input's spatial size.

    Args:
        num_classes: Number of channels produced by the final 1x1 convolution.
        pretrained: If True (default), build the ResNet-50 encoder with
            ImageNet weights. Pass False to skip loading them, e.g. when a
            checkpoint is restored right after construction.
    """

    def __init__(self, num_classes=2, *, pretrained=True):
        super().__init__()
        resnet = self._load_resnet50(pretrained)

        # The stock stem expects 3 channels; swap in a 1-channel stem. Its
        # weights are re-initialised, so the stem does not reuse any
        # pretrained values.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        nn.init.kaiming_normal_(resnet.conv1.weight, mode='fan_out', nonlinearity='relu')

        # ---- Encoder (modules taken from ResNet-50) ----
        self.input_block = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
        )
        self.maxpool = resnet.maxpool
        self.encoder1 = resnet.layer1    # 64 -> 256
        self.encoder2 = resnet.layer2    # 256 -> 512
        self.encoder3 = resnet.layer3    # 512 -> 1024
        self.bottleneck = resnet.layer4  # 1024 -> 2048

        # ---- Decoder: learned 2x upsampling steps ----
        self.up1 = nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up4 = nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2)
        # The last step has no skip connection and uses parameter-free
        # bilinear upsampling.
        self.up5 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # ---- Decoder: refinement after each skip addition ----
        self.conv1 = nn.Sequential(
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    @staticmethod
    def _load_resnet50(pretrained):
        """Build torchvision's ResNet-50, optionally with ImageNet weights."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # Older torchvision releases only understand the boolean flag.
            return models.resnet50(pretrained=pretrained)
        # Newer releases deprecate `pretrained=`; IMAGENET1K_V1 is the weight
        # set that the legacy `pretrained=True` maps to.
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    @staticmethod
    def _resize_like(tensor, reference):
        """Return `tensor` resized to the spatial size of `reference`.

        The tensor is returned untouched when the sizes already agree, so
        inputs that divide evenly through the encoder follow exactly the
        same computation as a plain 2x upsampling decoder.
        """
        target = reference.shape[-2:]
        if tensor.shape[-2:] == target:
            return tensor
        return F.interpolate(tensor, size=target, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Run the encoder-decoder and return per-pixel class logits.

        Args:
            x: Input batch; the stem convolution expects 1 channel.

        Returns:
            Tensor with `num_classes` channels and the same spatial size as `x`.
        """
        # Encoder: keep every stage output that feeds a skip connection.
        x0 = self.input_block(x)   # 64 channels
        x1 = self.maxpool(x0)
        x2 = self.encoder1(x1)     # 256 channels
        x3 = self.encoder2(x2)     # 512 channels
        x4 = self.encoder3(x3)     # 1024 channels
        x5 = self.bottleneck(x4)   # 2048 channels

        # Decoder: upsample, align to the skip tensor (a stride-2 stage rounds
        # odd sizes up, so 2x upsampling can overshoot by one pixel), add,
        # then refine.
        d1 = self.conv1(F.relu(self._resize_like(self.up1(x5), x4) + x4))
        d2 = self.conv2(F.relu(self._resize_like(self.up2(d1), x3) + x3))
        d3 = self.conv3(F.relu(self._resize_like(self.up3(d2), x2) + x2))
        # x1 (post-maxpool) is not used as a skip; the stem output x0 is.
        d4 = self.conv4(F.relu(self._resize_like(self.up4(d3), x0) + x0))

        # Back to input resolution; corrects the one-pixel overshoot that an
        # odd-sized input would otherwise produce.
        d5 = self._resize_like(self.up5(d4), x)
        return self.out_conv(d5)
|