File size: 727 Bytes
226675b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch 
import torch.nn as nn
from rscd.models.backbones.seaformer import *
from rscd.models.backbones.resnet import get_resnet18, get_resnet50_OS32, get_resnet50_OS8
from rscd.models.backbones.swintransformer import *

class Base(nn.Module):
    """Siamese feature extractor: one shared backbone applied to two inputs.

    Args:
        name: Backbone identifier. Supported values are ``'Seaformer'``
            (``SeaFormer_L``), ``'Resnet18'`` (``get_resnet18``) and
            ``'Swin'`` (``swin_tiny``). Each backbone is constructed with
            pretrained weights requested.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """

    def __init__(self, name: str):
        super().__init__()
        if name == 'Seaformer':
            self.backbone = SeaFormer_L(pretrained=True)
        elif name == 'Resnet18':
            self.backbone = get_resnet18(pretrained=True)
        elif name == 'Swin':
            # NOTE(review): the positional ``True`` presumably means
            # pretrained=True — confirm against swin_tiny's signature.
            self.backbone = swin_tiny(True)
        else:
            # Fail fast here: previously an unknown name left
            # ``self.backbone`` unset, which only surfaced later as an
            # AttributeError inside forward().
            raise ValueError(
                f"Unknown backbone name {name!r}; "
                "expected one of 'Seaformer', 'Resnet18', 'Swin'"
            )

    def forward(self, xA, xB):
        """Run the same backbone (shared weights) on both inputs.

        Args:
            xA: First input, passed unchanged to ``self.backbone``.
            xB: Second input, passed unchanged to ``self.backbone``.

        Returns:
            A two-element list ``[featuresA, featuresB]``, where each element
            is whatever ``self.backbone`` returns for the corresponding input
            (assumed to be one tensor or a list of multi-scale feature
            maps, depending on the backbone — TODO confirm).
        """
        featuresA = self.backbone(xA)
        featuresB = self.backbone(xB)

        return [featuresA, featuresB]