| import torch |
| from torch import nn |
| import sys |
| sys.path.append('rscd') |
| from utils.build import build_from_cfg |
|
|
class myModel(nn.Module):
    """Two-input model composed of a backbone followed by a decoder head.

    Both sub-modules are constructed from the config via ``build_from_cfg``.
    """

    def __init__(self, cfg):
        """Build the backbone and decoder head.

        Args:
            cfg: model config exposing ``backbone`` and ``decoderhead``
                attributes; each is passed to ``build_from_cfg``.
        """
        super().__init__()
        self.backbone = build_from_cfg(cfg.backbone)
        self.decoderhead = build_from_cfg(cfg.decoderhead)

    def forward(self, x1, x2, gtmask=None):
        """Run both inputs through the backbone, then the decoder head.

        Args:
            x1: first input batch, passed to the backbone together with ``x2``
                (presumably the image at time 1 — TODO confirm).
            x2: second input batch.
            gtmask: optional ground-truth mask. It is forwarded to the decoder
                head only when it is not ``None`` (presumably used by some
                heads during training — confirm against the configured head).

        Returns:
            Whatever the decoder head returns. The original code called this
            ``x_list``, so it may be a list of outputs — verify per head.
        """
        backbone_outputs = self.backbone(x1, x2)
        # Use an identity check: ``gtmask == None`` would dispatch to
        # Tensor.__eq__ whenever a mask tensor is passed, instead of a plain
        # "was a mask supplied?" test.
        if gtmask is None:
            return self.decoderhead(backbone_outputs)
        return self.decoderhead(backbone_outputs, gtmask)
|
|
| """ |
Models that do not fit this backbone + decoder-head pattern can be defined entirely in the backbone part and imported here.
| """ |
|
|
| |
def build_model(cfg):
    """Construct the model described by ``cfg``.

    Args:
        cfg: model config, passed unchanged to ``myModel``.

    Returns:
        A freshly constructed ``myModel`` instance.
    """
    return myModel(cfg)
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model from a config file, run one forward pass on
    # random inputs, and evaluate the configured loss on a random 0/1 target.
    # Usage: python <this file> [config_path]
    # The fallback path is the original author's local config; pass a path on
    # the command line to run this anywhere else.
    file_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else r"E:\zjuse\2308CD\rschangedetection\configs\SARASNet.py"
    )

    x1 = torch.randn(4, 3, 512, 512)
    x2 = torch.randn(4, 3, 512, 512)
    # Random integer labels in {0, 1}, with shape (4, 512, 512).
    target = torch.randint(low=0, high=2, size=[4, 512, 512])

    # Project-local imports. They presumably resolve via the working directory
    # and the sys.path tweak at the top of the file.
    from utils.config import Config
    from rscd.losses import build_loss

    cfg = Config.fromfile(file_path)
    net = build_model(cfg.model_config)
    res = net(x1, x2)
    # The decoder head may return a single tensor or a list/tuple of outputs
    # (myModel.forward names its result `x_list`), so report shapes either way.
    if isinstance(res, (list, tuple)):
        print([getattr(r, "shape", type(r)) for r in res])
    else:
        print(res.shape)

    loss = build_loss(cfg.loss_config)
    compute = loss(res, target)
    print(compute)