Spaces:
Build error
Build error
File size: 1,890 Bytes
43c8d65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import torch
import torch.nn as nn
import torchvision.models as models
class SegNet(nn.Module):
def __init__(self, num_classes=32):
super(SegNet, self).__init__()
vgg16 = models.vgg16_bn(pretrained=True)
self.pool = nn.MaxPool2d(2, 2, return_indices=True)
self.unpool = nn.MaxUnpool2d(2, 2)
self.enc1 = nn.Sequential(*vgg16.features[:6])
self.enc2 = nn.Sequential(*vgg16.features[7:13])
self.enc3 = nn.Sequential(*vgg16.features[14:23])
self.enc4 = nn.Sequential(*vgg16.features[24:33])
self.dec4 = self.decoder_block(512, 256)
self.dec3 = self.decoder_block(256, 128)
self.dec2 = self.decoder_block(128, 64)
self.dec1 = self.decoder_block(64, 64)
self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)
def decoder_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, in_channels, 3, padding=1),
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
x1 = self.enc1(x)
x1p, ind1 = self.pool(x1)
x2 = self.enc2(x1p)
x2p, ind2 = self.pool(x2)
x3 = self.enc3(x2p)
x3p, ind3 = self.pool(x3)
x4 = self.enc4(x3p)
x4p, ind4 = self.pool(x4)
d4 = self.unpool(x4p, ind4, output_size=x4.size())
d4 = self.dec4(d4)
d3 = self.unpool(d4, ind3, output_size=x3.size())
d3 = self.dec3(d3)
d2 = self.unpool(d3, ind2, output_size=x2.size())
d2 = self.dec2(d2)
d1 = self.unpool(d2, ind1, output_size=x1.size())
d1 = self.dec1(d1)
return self.classifier(d1)
|