InPeerReview's picture
Upload 161 files
226675b verified
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from torch.optim import lr_scheduler
from rscd.models.backbones import resnet_bit
class BIT_Backbone(torch.nn.Module):
    """ResNet feature extractor applied to both images of an input pair.

    Both inputs go through the *same* ``self.resnet`` trunk (weights are
    shared between the two branches). The trunk output is optionally
    upsampled by 2x and then projected to ``output_nc`` channels with a 3x3
    convolution (``conv_pred``).

    Args:
        input_nc: Number of input channels. Not used by this class; kept so
            the constructor signature matches the config factory.
        output_nc: Number of channels produced by ``conv_pred``.
        resnet_stages_num: How many ResNet stages to run: 3, 4 or 5.
        backbone: One of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``.
        if_upsample_2x: If true, upsample the trunk output by 2x before
            ``conv_pred``.
        pretrained: Forwarded to the ``resnet_bit`` constructor (defaults to
            ``True``, matching the previous hard-coded behaviour).

    Raises:
        NotImplementedError: If ``backbone`` or ``resnet_stages_num`` is not
            supported.
    """

    # Channel multiplier of each supported backbone's stages relative to
    # resnet18/34 (resnet50's stages output 4x as many channels).
    _BACKBONE_EXPAND = {'resnet18': 1, 'resnet34': 1, 'resnet50': 4}
    # Output channels (before the multiplier) of the last stage that is run,
    # keyed by resnet_stages_num.
    _STAGE_CHANNELS = {3: 128, 4: 256, 5: 512}

    def __init__(self, input_nc, output_nc,
                 resnet_stages_num=5, backbone='resnet18',
                 if_upsample_2x=True, pretrained=True):
        super(BIT_Backbone, self).__init__()
        # Validate cheap arguments first so an invalid config fails before
        # a (possibly pretrained) backbone is constructed.
        if backbone not in self._BACKBONE_EXPAND:
            raise NotImplementedError(
                f"unsupported backbone {backbone!r}; expected one of "
                f"{sorted(self._BACKBONE_EXPAND)}")
        if resnet_stages_num not in self._STAGE_CHANNELS:
            raise NotImplementedError(
                f"unsupported resnet_stages_num {resnet_stages_num!r}; "
                f"expected one of {sorted(self._STAGE_CHANNELS)}")
        expand = self._BACKBONE_EXPAND[backbone]
        layers = self._STAGE_CHANNELS[resnet_stages_num] * expand

        # The last two stages use dilation instead of stride, so the trunk
        # output keeps the 1/8 resolution of layer2.
        build_resnet = getattr(resnet_bit, backbone)
        self.resnet = build_resnet(pretrained=pretrained,
                                   replace_stride_with_dilation=[False, True, True])
        self.upsamplex2 = nn.Upsample(scale_factor=2)
        self.resnet_stages_num = resnet_stages_num
        self.if_upsample_2x = if_upsample_2x
        self.conv_pred = nn.Conv2d(layers, output_nc, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        """Encode both images with the shared trunk.

        Returns:
            ``[features_of_x1, features_of_x2]``, each produced by
            :meth:`forward_single`.
        """
        return [self.forward_single(x1), self.forward_single(x2)]

    def forward_single(self, x):
        """Run one image through the ResNet trunk, optional 2x upsampling and
        the output convolution ``conv_pred``.

        Raises:
            NotImplementedError: If ``resnet_stages_num`` is greater than 5.
                This is only reachable when the attribute is changed after
                construction, because ``__init__`` rejects such values.
        """
        # ResNet stem.
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        # Channel counts below are for resnet18/34; resnet50 has 4x as many.
        x_4 = self.resnet.layer1(x)    # 1/4, in=64, out=64
        x_8 = self.resnet.layer2(x_4)  # 1/8, in=64, out=128

        if self.resnet_stages_num > 3:
            x_8 = self.resnet.layer3(x_8)  # 1/8 (dilated), in=128, out=256

        if self.resnet_stages_num == 5:
            x_8 = self.resnet.layer4(x_8)  # 1/8 (dilated), in=256, out=512
        elif self.resnet_stages_num > 5:
            raise NotImplementedError(
                f"unsupported resnet_stages_num {self.resnet_stages_num!r}")

        x = self.upsamplex2(x_8) if self.if_upsample_2x else x_8

        # Output projection to output_nc channels.
        return self.conv_pred(x)
def BIT_backbone_func(cfg):
    """Construct a ``BIT_Backbone`` from an attribute-style config.

    ``cfg`` must expose ``input_nc``, ``output_nc``, ``resnet_stages_num``,
    ``backbone`` and ``if_upsample_2x`` as attributes. Each one is forwarded
    to the constructor as the keyword argument of the same name.
    """
    constructor_fields = ('input_nc', 'output_nc', 'resnet_stages_num',
                          'backbone', 'if_upsample_2x')
    kwargs = {field: getattr(cfg, field) for field in constructor_fields}
    return BIT_Backbone(**kwargs)
if __name__ == '__main__':
    # Smoke test: build the backbone from a munch config and run one
    # random image pair through it, printing the model and output shapes.
    x1 = torch.rand(4, 3, 512, 512)
    x2 = torch.rand(4, 3, 512, 512)
    cfg = dict(
        type='BIT_Backbone',
        input_nc=3,
        output_nc=32,
        resnet_stages_num=4,
        backbone='resnet18',
        if_upsample_2x=True,
    )
    from munch import DefaultMunch
    # DefaultMunch gives attribute access (cfg.input_nc, ...) as expected
    # by BIT_backbone_func.
    cfg = DefaultMunch.fromDict(cfg)
    model = BIT_backbone_func(cfg)
    model.eval()
    print(model)
    # Inference only: no backward pass follows, so skip building the
    # autograd graph for this large demo batch.
    with torch.no_grad():
        outs = model(x1, x2)
    print('BIT', outs)
    for out in outs:
        print(out.shape)