|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch.nn import init
|
|
|
import torch.nn.functional as F
|
|
|
from torch.optim import lr_scheduler
|
|
|
|
|
|
from rscd.models.backbones import resnet_bit
|
|
|
|
|
|
class BIT_Backbone(torch.nn.Module):
    """Weight-shared ResNet feature extractor for the BIT change-detection model.

    Runs the same ResNet trunk over two co-registered images and projects
    each resulting feature map to ``output_nc`` channels with a 3x3 conv.

    Args:
        input_nc: channel count of the input images. NOTE(review): currently
            unused -- the ResNet stem is built by ``resnet_bit`` and
            presumably expects 3-channel input; kept for interface
            compatibility. TODO confirm against ``resnet_bit``.
        output_nc: channel count of the returned feature maps.
        resnet_stages_num: how many ResNet stages to run (3, 4 or 5).
        backbone: one of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``.
        if_upsample_2x: if True, upsample the deepest feature map by 2x
            (nearest-neighbor, the ``nn.Upsample`` default) before the
            projection conv.

    Raises:
        NotImplementedError: for an unknown ``backbone`` or an unsupported
            ``resnet_stages_num``.
    """

    # Output channels of the last executed ResNet stage, keyed by stage
    # count (scaled by 4 for bottleneck backbones such as resnet50).
    _STAGE_CHANNELS = {5: 512, 4: 256, 3: 128}

    def __init__(self, input_nc, output_nc,
                 resnet_stages_num=5, backbone='resnet18',
                 if_upsample_2x=True):
        super(BIT_Backbone, self).__init__()
        # layer3/layer4 trade stride for dilation, so the deepest feature
        # map keeps the spatial resolution reached after layer2.
        dilation = [False, True, True]
        expand = 1
        if backbone == 'resnet18':
            self.resnet = resnet_bit.resnet18(
                pretrained=True, replace_stride_with_dilation=dilation)
        elif backbone == 'resnet34':
            self.resnet = resnet_bit.resnet34(
                pretrained=True, replace_stride_with_dilation=dilation)
        elif backbone == 'resnet50':
            self.resnet = resnet_bit.resnet50(
                pretrained=True, replace_stride_with_dilation=dilation)
            expand = 4  # bottleneck blocks expand channels 4x
        else:
            raise NotImplementedError

        self.upsamplex2 = nn.Upsample(scale_factor=2)
        self.resnet_stages_num = resnet_stages_num
        self.if_upsample_2x = if_upsample_2x

        if resnet_stages_num not in self._STAGE_CHANNELS:
            raise NotImplementedError
        layers = self._STAGE_CHANNELS[resnet_stages_num] * expand
        self.conv_pred = nn.Conv2d(layers, output_nc, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        """Extract features from both images with the same (shared) weights.

        Returns:
            list: ``[feat1, feat2]``, one feature map per input image.
        """
        x1 = self.forward_single(x1)
        x2 = self.forward_single(x2)
        return [x1, x2]

    def forward_single(self, x):
        """Run one image through the ResNet trunk and the projection conv."""
        # Stem: conv1 + bn + relu + maxpool.
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x_4 = self.resnet.layer1(x)
        x_8 = self.resnet.layer2(x_4)

        if self.resnet_stages_num > 3:
            x_8 = self.resnet.layer3(x_8)  # dilated: spatial size unchanged

        if self.resnet_stages_num == 5:
            x_8 = self.resnet.layer4(x_8)  # dilated: spatial size unchanged
        elif self.resnet_stages_num > 5:
            raise NotImplementedError

        x = self.upsamplex2(x_8) if self.if_upsample_2x else x_8
        x = self.conv_pred(x)
        return x
|
|
|
|
|
|
def BIT_backbone_func(cfg):
    """Build a ``BIT_Backbone`` from a config object with attribute access.

    Args:
        cfg: object exposing ``input_nc``, ``output_nc``,
            ``resnet_stages_num``, ``backbone`` and ``if_upsample_2x``.

    Returns:
        BIT_Backbone: the constructed backbone module.
    """
    return BIT_Backbone(
        input_nc=cfg.input_nc,
        output_nc=cfg.output_nc,
        resnet_stages_num=cfg.resnet_stages_num,
        backbone=cfg.backbone,
        if_upsample_2x=cfg.if_upsample_2x,
    )
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: run the backbone on a pair of random "images".
    x1 = torch.rand(4, 3, 512, 512)
    x2 = torch.rand(4, 3, 512, 512)
    cfg = dict(
        type='BIT_Backbone',
        input_nc=3,
        output_nc=32,
        resnet_stages_num=4,
        backbone='resnet18',
        if_upsample_2x=True,
    )
    # Stdlib SimpleNamespace provides the attribute access that
    # BIT_backbone_func expects -- no need for the third-party `munch`.
    from types import SimpleNamespace
    cfg = SimpleNamespace(**cfg)

    model = BIT_backbone_func(cfg)
    model.eval()
    print(model)
    with torch.no_grad():  # pure inference; skip autograd bookkeeping
        outs = model(x1, x2)
    print('BIT', outs)
    for out in outs:
        print(out.shape)