rawanessam's picture
Update pytorch/models/model.py
d2c2839 verified
from .drn import drn_d_54
from torch import nn
from .modules import *
class Model(nn.Module):
def __init__(self, options):
super(Model, self).__init__()
self.options = options
self.drn = drn_d_54(pretrained=True, out_map=32, num_classes=-1, out_middle=False)
self.pyramid = PyramidModule(options, 512, 128)
self.feature_conv = ConvBlock(1024, 512)
self.segmentation_pred = nn.Conv2d(512, NUM_CORNERS + NUM_ICONS + 2 + NUM_ROOMS + 2, kernel_size=1)
self.upsample = torch.nn.Upsample(size=(options.height, options.width), mode='bilinear')
return
def forward(self, inp):
features = self.drn(inp)
features = self.pyramid(features)
features = self.feature_conv(features)
segmentation = self.upsample(self.segmentation_pred(features))
segmentation = segmentation.transpose(1, 2).transpose(2, 3).contiguous()
return torch.sigmoid(segmentation[:, :, :, :NUM_CORNERS]), segmentation[:, :, :, NUM_CORNERS:NUM_CORNERS + NUM_ICONS + 2], segmentation[:, :, :, -(NUM_ROOMS + 2):]