InPeerReview's picture
Upload 161 files
226675b verified
import torch
import torch.nn as nn
from rscd.models.backbones.seaformer import *
from rscd.models.backbones.resnet import get_resnet18, get_resnet50_OS32, get_resnet50_OS8
from rscd.models.backbones.swintransformer import *
class Base(nn.Module):
def __init__(self, name):
super().__init__()
if name == 'Seaformer':
self.backbone = SeaFormer_L(pretrained=True)
elif name == 'Resnet18':
self.backbone = get_resnet18(pretrained=True)
elif name == 'Swin':
self.backbone = swin_tiny(True)
def forward(self, xA, xB):
featuresA = self.backbone(xA)
featuresB = self.backbone(xB)
return [featuresA, featuresB]