| import pdb | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from .loss_functions import Contrastive_Loss, Cosine_Sim_Loss | |
class _DMMI_Framework(nn.Module):
    """Compose a backbone and a classifier into one model with auxiliary losses.

    The backbone is called with the input, ``l_feats`` and ``l_mask``. It
    returns a feature ``l_1`` plus four feature maps. The classifier is then
    called with ``l_1``, ``l_feats1`` and the four maps in reverse order. It
    returns ``(de_feat, l_2, logits)``; ``logits`` is resized to the input's
    spatial size.

    When training with a ``target_flag``, a contrastive loss and a
    cosine-similarity loss are also computed from these intermediate outputs.
    """

    def __init__(self, backbone, classifier):
        """
        Args:
            backbone: callable ``(x, l_feats, l_mask) -> (l_1, (c1, c2, c3, c4))``.
            classifier: callable
                ``(l_1, l_feats1, c4, c3, c2, c1) -> (de_feat, l_2, logits)``.
        """
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.cossim = Cosine_Sim_Loss()
        self.contrastive = Contrastive_Loss()

    def forward(self, x, l_feats, l_feats1, l_mask, target_flag=None, training_flag=True):
        """Run backbone + classifier and, if requested, the auxiliary losses.

        Args:
            x: model input. Its last two dimensions are used as the output
                size for the resized logits (presumably an NCHW image batch --
                TODO confirm).
            l_feats: features passed to the backbone.
            l_feats1: features passed to the classifier.
            l_mask: mask passed to the backbone and to the cosine-similarity
                loss.
            target_flag: extra target information for both auxiliary losses.
                When ``None``, the losses are skipped.
            training_flag: when falsy, the losses are skipped.

        Returns:
            ``(loss_contrastive, loss_cossim, seg_map)``. Each loss is the
            integer ``0`` when skipped. ``seg_map`` is the classifier's logits
            bilinearly resized (``align_corners=True``) to ``x.shape[-2:]``.
        """
        input_shape = x.shape[-2:]
        l_1, features = self.backbone(x, l_feats, l_mask)
        x_c1, x_c2, x_c3, x_c4 = features

        # The classifier takes the four feature maps in reverse order (c4 first).
        de_feat, l_2, logits = self.classifier(l_1, l_feats1, x_c4, x_c3, x_c2, x_c1)
        seg_map = F.interpolate(logits, size=input_shape, mode='bilinear', align_corners=True)

        # Use an identity check, not `!= None`: for array-like targets `!=`
        # broadcasts elementwise and the truth test becomes ambiguous.
        if training_flag and target_flag is not None:
            loss_contrastive = self.contrastive(de_feat, l_1, target_flag)
            loss_cossim = self.cossim(l_1, l_2, l_mask, target_flag)
        else:
            loss_contrastive = 0
            loss_cossim = 0
        return loss_contrastive, loss_cossim, seg_map
class DMMI(_DMMI_Framework):
    """Public model class; inherits all behavior from ``_DMMI_Framework`` unchanged."""