import torch
from torch import nn

from .drn import drn_d_54
from .modules import *
class Model(nn.Module):
    """DRN-D-54 backbone + pyramid module + a 1x1 conv head producing per-pixel maps.

    A single 1x1 convolution predicts every output channel at once;
    ``forward`` upsamples the result to the configured resolution and splits
    the channels into three groups (corner, icon and room channels).

    Args:
        options: Configuration object. ``options.height`` and
            ``options.width`` fix the output resolution; the whole object is
            also forwarded to ``PyramidModule``.
    """

    def __init__(self, options):
        super(Model, self).__init__()
        self.options = options
        # Backbone built with num_classes=-1 and out_middle=False.
        # NOTE(review): its output is assumed to have 512 channels, matching
        # the PyramidModule input width below -- confirm against drn_d_54.
        self.drn = drn_d_54(pretrained=True, out_map=32, num_classes=-1, out_middle=False)
        # NOTE(review): PyramidModule(options, 512, 128) is assumed to emit
        # 1024 channels, which is what feature_conv consumes -- confirm.
        self.pyramid = PyramidModule(options, 512, 128)
        self.feature_conv = ConvBlock(1024, 512)
        # One head for all groups: corners, icons (+2 extra channels) and
        # rooms (+2 extra channels). NOTE(review): the meaning of the two
        # extra channels per group is defined by the label encoding elsewhere.
        num_output_channels = NUM_CORNERS + NUM_ICONS + 2 + NUM_ROOMS + 2
        self.segmentation_pred = nn.Conv2d(512, num_output_channels, kernel_size=1)
        # align_corners=False is what 'bilinear' already fell back to when the
        # argument was omitted; stating it explicitly keeps identical output
        # and avoids PyTorch's default-behaviour warning.
        self.upsample = nn.Upsample(
            size=(options.height, options.width),
            mode='bilinear',
            align_corners=False,
        )

    def forward(self, inp):
        """Run the network and split the prediction into three channel groups.

        Args:
            inp: Input batch for the DRN backbone.
                NOTE(review): presumably (batch, 3, H, W) images -- confirm.

        Returns:
            Tuple ``(corners, icons, rooms)`` of channels-last tensors with
            shape (batch, options.height, options.width, k):

            * ``corners`` -- k = NUM_CORNERS, passed through a sigmoid;
            * ``icons``   -- k = NUM_ICONS + 2, raw (no activation applied);
            * ``rooms``   -- k = NUM_ROOMS + 2, raw (no activation applied).
        """
        features = self.drn(inp)
        features = self.pyramid(features)
        features = self.feature_conv(features)
        segmentation = self.upsample(self.segmentation_pred(features))
        # NCHW -> NHWC so channel groups can be sliced on the last axis
        # (equivalent to transpose(1, 2).transpose(2, 3)).
        segmentation = segmentation.permute(0, 2, 3, 1).contiguous()

        # Split points of the contiguous channel partition built in __init__.
        icon_end = NUM_CORNERS + NUM_ICONS + 2
        corner_pred = torch.sigmoid(segmentation[..., :NUM_CORNERS])
        icon_pred = segmentation[..., NUM_CORNERS:icon_end]
        room_pred = segmentation[..., icon_end:]
        return corner_pred, icon_pred, room_pred