import torch import torch.nn as nn import torchvision.models as models class SegNet(nn.Module): def __init__(self, num_classes=32): super(SegNet, self).__init__() vgg16 = models.vgg16_bn(pretrained=True) self.pool = nn.MaxPool2d(2, 2, return_indices=True) self.unpool = nn.MaxUnpool2d(2, 2) self.enc1 = nn.Sequential(*vgg16.features[:6]) self.enc2 = nn.Sequential(*vgg16.features[7:13]) self.enc3 = nn.Sequential(*vgg16.features[14:23]) self.enc4 = nn.Sequential(*vgg16.features[24:33]) self.dec4 = self.decoder_block(512, 256) self.dec3 = self.decoder_block(256, 128) self.dec2 = self.decoder_block(128, 64) self.dec1 = self.decoder_block(64, 64) self.classifier = nn.Conv2d(64, num_classes, kernel_size=1) def decoder_block(self, in_channels, out_channels): return nn.Sequential( nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True), nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): x1 = self.enc1(x) x1p, ind1 = self.pool(x1) x2 = self.enc2(x1p) x2p, ind2 = self.pool(x2) x3 = self.enc3(x2p) x3p, ind3 = self.pool(x3) x4 = self.enc4(x3p) x4p, ind4 = self.pool(x4) d4 = self.unpool(x4p, ind4, output_size=x4.size()) d4 = self.dec4(d4) d3 = self.unpool(d4, ind3, output_size=x3.size()) d3 = self.dec3(d3) d2 = self.unpool(d3, ind2, output_size=x2.size()) d2 = self.dec2(d2) d1 = self.unpool(d2, ind1, output_size=x1.size()) d1 = self.dec1(d1) return self.classifier(d1)