| import torch | |
| import torch.nn as nn | |
| from rscd.models.backbones.seaformer import * | |
| from rscd.models.backbones.resnet import get_resnet18, get_resnet50_OS32, get_resnet50_OS8 | |
| from rscd.models.backbones.swintransformer import * | |
class Base(nn.Module):
    """Siamese feature extractor: one shared backbone applied to two inputs.

    The same ``self.backbone`` module (and therefore the same weights) is
    used for both inputs passed to :meth:`forward`.

    Args:
        name: Backbone identifier. Supported values are ``'Seaformer'``,
            ``'Resnet18'`` and ``'Swin'``.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
            Previously an unknown name silently left ``self.backbone``
            unset, and the error only surfaced later in ``forward``.
    """

    def __init__(self, name):
        super().__init__()
        if name == 'Seaformer':
            self.backbone = SeaFormer_L(pretrained=True)
        elif name == 'Resnet18':
            self.backbone = get_resnet18(pretrained=True)
        elif name == 'Swin':
            # NOTE(review): the positional True is presumably swin_tiny's
            # `pretrained` flag, mirroring the other branches -- confirm.
            self.backbone = swin_tiny(True)
        else:
            # Fail fast at construction instead of an AttributeError in forward().
            raise ValueError(
                f"Unknown backbone name {name!r}; "
                f"expected one of 'Seaformer', 'Resnet18', 'Swin'"
            )

    def forward(self, xA, xB):
        """Run the shared backbone on both inputs.

        Args:
            xA: First input (presumably an image tensor, e.g. the
                pre-change image -- confirm against caller).
            xB: Second input, same kind as ``xA``.

        Returns:
            A two-element list ``[featuresA, featuresB]``, holding whatever
            the backbone returns for ``xA`` and ``xB`` respectively.
        """
        featuresA = self.backbone(xA)
        featuresB = self.backbone(xB)
        return [featuresA, featuresB]