File size: 727 Bytes
226675b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
import torch.nn as nn
from rscd.models.backbones.seaformer import *
from rscd.models.backbones.resnet import get_resnet18, get_resnet50_OS32, get_resnet50_OS8
from rscd.models.backbones.swintransformer import *
class Base(nn.Module):
    """Wrap one backbone and apply it to both inputs of an image pair.

    A single backbone module is created in ``__init__`` and reused for ``xA``
    and ``xB`` in ``forward``, so both branches share the same parameters.

    Args:
        name: Backbone identifier. Supported values are ``'Seaformer'``,
            ``'Resnet18'`` and ``'Swin'``.
        pretrained: Value forwarded as the ``pretrained`` flag of the chosen
            backbone constructor. Defaults to ``True``, matching the previous
            hard-coded behaviour.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """

    def __init__(self, name, pretrained=True):
        super().__init__()
        if name == 'Seaformer':
            self.backbone = SeaFormer_L(pretrained=pretrained)
        elif name == 'Resnet18':
            self.backbone = get_resnet18(pretrained=pretrained)
        elif name == 'Swin':
            # Passed positionally, exactly as the original ``swin_tiny(True)`` call did.
            self.backbone = swin_tiny(pretrained)
        else:
            # Previously an unknown name left ``self.backbone`` unset, and the
            # mistake only surfaced later as an AttributeError in ``forward``.
            raise ValueError(
                f"Unsupported backbone name {name!r}; "
                "expected one of 'Seaformer', 'Resnet18', 'Swin'"
            )

    def forward(self, xA, xB):
        """Run the shared backbone on both inputs.

        Args:
            xA: First input, passed unchanged to the backbone.
            xB: Second input, passed unchanged to the backbone.

        Returns:
            A two-element list ``[featuresA, featuresB]`` holding whatever the
            backbone returns for ``xA`` and ``xB`` respectively.
        """
        featuresA = self.backbone(xA)
        featuresB = self.backbone(xB)
        return [featuresA, featuresB]
|