# InPeerReview's picture
# Upload 161 files
# 226675b verified
import torch
from torch import nn
import sys
sys.path.append('rscd')
from utils.build import build_from_cfg
class myModel(nn.Module):
    """Two-input network built from a backbone and a decoder head.

    Both sub-modules are constructed with ``build_from_cfg`` from the
    ``cfg.backbone`` and ``cfg.decoderhead`` config entries.
    """

    def __init__(self, cfg):
        """Build the backbone and decoder head.

        Args:
            cfg: Model config exposing ``backbone`` and ``decoderhead``
                entries that ``build_from_cfg`` accepts.
        """
        super().__init__()
        self.backbone = build_from_cfg(cfg.backbone)
        self.decoderhead = build_from_cfg(cfg.decoderhead)

    def forward(self, x1, x2, gtmask=None):
        """Run the backbone on the input pair and decode its outputs.

        Args:
            x1: First input, passed to the backbone together with ``x2``.
                (Presumably the first image of a bitemporal pair -- confirm.)
            x2: Second input of the pair.
            gtmask: Optional ground-truth mask. When given, it is forwarded
                to the decoder head as a second positional argument. When
                ``None``, the head receives only the backbone outputs.

        Returns:
            Whatever the decoder head returns.
        """
        backbone_outputs = self.backbone(x1, x2)
        # Use an identity check, not ``==``: for array-like masks (e.g. NumPy
        # arrays), ``mask == None`` is element-wise and ``if`` on the result
        # raises "truth value is ambiguous".
        if gtmask is None:
            return self.decoderhead(backbone_outputs)
        return self.decoderhead(backbone_outputs, gtmask)
"""
对于不满足该范式的模型可在backbone部分进行定义, 并在此处导入
"""
# model_config
def build_model(cfg):
    """Instantiate ``myModel`` from a model config.

    Args:
        cfg: Model config, forwarded unchanged to ``myModel``.

    Returns:
        The newly constructed ``myModel`` instance.
    """
    return myModel(cfg)
if __name__ == "__main__":
x1 = torch.randn(4, 3, 512, 512)
x2 = torch.randn(4, 3, 512, 512)
target = torch.randint(low=0,high=2,size=[4, 512, 512])
file_path = r"E:\zjuse\2308CD\rschangedetection\configs\SARASNet.py"
from utils.config import Config
from rscd.losses import build_loss
cfg = Config.fromfile(file_path)
net = build_model(cfg.model_config)
res = net(x1, x2)
print(res.shape)
loss = build_loss(cfg.loss_config)
compute = loss(res,target)
print(compute)