| import torch
|
| import torch.nn as nn
|
| from torch.nn import init
|
| import torch.nn.functional as F
|
| from torch.optim import lr_scheduler
|
|
|
| from rscd.models.backbones import resnet_bit
|
|
|
| class BIT_Backbone(torch.nn.Module):
|
| def __init__(self, input_nc, output_nc,
|
| resnet_stages_num=5, backbone='resnet18',
|
| if_upsample_2x=True):
|
| """
|
| In the constructor we instantiate two nn.Linear modules and assign them as
|
| member variables.
|
| """
|
| super(BIT_Backbone, self).__init__()
|
| expand = 1
|
| if backbone == 'resnet18':
|
| self.resnet = resnet_bit.resnet18(pretrained=True,
|
| replace_stride_with_dilation=[False,True,True])
|
| elif backbone == 'resnet34':
|
| self.resnet = resnet_bit.resnet34(pretrained=True,
|
| replace_stride_with_dilation=[False,True,True])
|
| elif backbone == 'resnet50':
|
| self.resnet = resnet_bit.resnet50(pretrained=True,
|
| replace_stride_with_dilation=[False,True,True])
|
| expand = 4
|
| else:
|
| raise NotImplementedError
|
|
|
| self.upsamplex2 = nn.Upsample(scale_factor=2)
|
|
|
| self.resnet_stages_num = resnet_stages_num
|
|
|
| self.if_upsample_2x = if_upsample_2x
|
| if self.resnet_stages_num == 5:
|
| layers = 512 * expand
|
| elif self.resnet_stages_num == 4:
|
| layers = 256 * expand
|
| elif self.resnet_stages_num == 3:
|
| layers = 128 * expand
|
| else:
|
| raise NotImplementedError
|
| self.conv_pred = nn.Conv2d(layers, output_nc, kernel_size=3, padding=1)
|
|
|
| def forward(self, x1, x2):
|
|
|
| x1 = self.forward_single(x1)
|
| x2 = self.forward_single(x2)
|
| return [x1, x2]
|
|
|
| def forward_single(self, x):
|
|
|
| x = self.resnet.conv1(x)
|
| x = self.resnet.bn1(x)
|
| x = self.resnet.relu(x)
|
| x = self.resnet.maxpool(x)
|
|
|
| x_4 = self.resnet.layer1(x)
|
| x_8 = self.resnet.layer2(x_4)
|
|
|
| if self.resnet_stages_num > 3:
|
| x_8 = self.resnet.layer3(x_8)
|
|
|
| if self.resnet_stages_num == 5:
|
| x_8 = self.resnet.layer4(x_8)
|
| elif self.resnet_stages_num > 5:
|
| raise NotImplementedError
|
|
|
| if self.if_upsample_2x:
|
| x = self.upsamplex2(x_8)
|
| else:
|
| x = x_8
|
|
|
| x = self.conv_pred(x)
|
| return x
|
|
|
| def BIT_backbone_func(cfg):
|
| net = BIT_Backbone(input_nc=cfg.input_nc,
|
| output_nc=cfg.output_nc,
|
| resnet_stages_num=cfg.resnet_stages_num,
|
| backbone=cfg.backbone,
|
| if_upsample_2x=cfg.if_upsample_2x)
|
| return net
|
|
|
| if __name__ == '__main__':
|
| x1 = torch.rand(4, 3, 512, 512)
|
| x2 = torch.rand(4, 3, 512, 512)
|
| cfg = dict(
|
| type = 'BIT_Backbone',
|
| input_nc=3,
|
| output_nc=32,
|
| resnet_stages_num=4,
|
| backbone='resnet18',
|
| if_upsample_2x=True,
|
| )
|
| from munch import DefaultMunch
|
| cfg = DefaultMunch.fromDict(cfg)
|
| model = BIT_backbone_func(cfg)
|
| model.eval()
|
| print(model)
|
| outs = model(x1, x2)
|
| print('BIT', outs)
|
| for out in outs:
|
| print(out.shape) |