File size: 860 Bytes
226675b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from rscd.models.backbones.Decom_Backbone import ResNet3D
class AFCD3D_backbone(nn.Module):
    """Feature extractor that runs an image pair through a ``ResNet3D`` wrapper.

    The two inputs are stacked along a newly inserted axis (index 2) and the
    combined tensor is passed through the stem and the four residual stages
    of ``self.resnet``. Every intermediate activation is returned so that a
    downstream module can consume features from several stages.
    """

    def __init__(self, pretrained=True):
        """Create the ResNet-18 model and wrap it in ``ResNet3D``.

        Args:
            pretrained: Forwarded unchanged to
                ``torchvision.models.resnet18``. The default ``True`` keeps
                the previous hard-coded behaviour; pass ``False`` to build
                the network without requesting pretrained weights (for
                example when a checkpoint is restored afterwards).
        """
        super().__init__()
        # NOTE(review): recent torchvision releases prefer the ``weights=``
        # argument over ``pretrained=``; the keyword is kept as-is so older
        # installations keep working — confirm against the pinned version.
        resnet = torchvision.models.resnet18(pretrained=pretrained)
        self.resnet = ResNet3D(resnet)

    def forward(self, imageA, imageB):
        """Extract multi-stage features from two images.

        Args:
            imageA: First input tensor.
            imageB: Second input tensor; must have the same shape as
                ``imageA`` so the two can be stacked.
                (Presumably both are ``(batch, channels, H, W)`` — TODO
                confirm against the caller.)

        Returns:
            A list ``[size, x0, x1, x2, x3, x4]`` where ``size`` holds the
            dimensions of the stacked input from index 3 onward, ``x0`` is
            the output of ``conv1 -> bn1 -> relu``, and ``x1``..``x4`` are
            the outputs of ``layer1``..``layer4`` (``layer1`` is fed the
            max-pooled ``x0``).
        """
        # Stacking on dim 2 is the same as unsqueeze(2) on each input
        # followed by a concatenation along that axis.
        x = torch.stack([imageA, imageB], dim=2)
        # Captured before any layer runs, so it reflects the input size.
        size = x.size()[3:]

        backbone = self.resnet
        x0 = backbone.relu(backbone.bn1(backbone.conv1(x)))
        x1 = backbone.layer1(backbone.maxpool(x0))
        x2 = backbone.layer2(x1)
        x3 = backbone.layer3(x2)
        x4 = backbone.layer4(x3)
        return [size, x0, x1, x2, x3, x4]
|