Spaces:
Running on Zero
Running on Zero
| from torch import nn | |
| from torchmetrics import StructuralSimilarityIndexMeasure as SSIM | |
class BackboneL2SSIMLoss(nn.Module):
    """Composite loss: weighted backbone L2 + prediction L2 + prediction SSIM loss.

    ``forward`` returns::

        alpha * MSE(backbone, target) + MSE(prediction, target)
        + (1 - SSIM(prediction, target))

    Args:
        ssim_window_size: ``kernel_size`` passed to torchmetrics' SSIM metric.
        alpha: Weight applied to the backbone L2 term only; the prediction L2
            and SSIM terms are unweighted.
        device: Device to place the SSIM metric on. ``None`` (default) moves it
            to CUDA when CUDA is available and otherwise leaves it on CPU, so the
            criterion can also be constructed on CPU-only machines.
    """

    def __init__(self, ssim_window_size=5, alpha=0.5, device=None):
        super().__init__()
        self.ssim_window_size = ssim_window_size
        self.alpha = alpha
        # Despite the attribute name this holds the SSIM *metric* (a similarity
        # score); ``forward`` converts it into a loss via ``1 - ssim``.
        ssim = SSIM(kernel_size=ssim_window_size)
        target_device = self._resolve_device(device)
        if target_device is not None:
            ssim = ssim.to(target_device)
        self.ssim_loss = ssim
        self.l2_loss = nn.MSELoss()

    @staticmethod
    def _resolve_device(device):
        """Choose where the SSIM metric should live.

        Returns ``device`` unchanged when one is given. Otherwise returns
        ``"cuda"`` if a CUDA device is available, else ``None`` (keep on CPU).
        """
        if device is not None:
            return device
        import torch  # local import: the module only brings ``nn`` into scope

        return "cuda" if torch.cuda.is_available() else None

    def forward(self, backbone, prediction, target):
        """Compute the combined loss.

        Args:
            backbone: Backbone output, compared with ``target`` by L2 only.
            prediction: Final output, compared with ``target`` by L2 and SSIM.
            target: Ground truth. Presumably an image batch shaped
                (N, C, H, W) as torchmetrics SSIM expects — TODO confirm.

        Returns:
            The loss tensor (a scalar under the default mean reductions).
        """
        ssim_term = 1.0 - self.ssim_loss(prediction, target)
        pred_l2 = self.l2_loss(prediction, target)
        backbone_l2 = self.l2_loss(backbone, target)
        return self.alpha * backbone_l2 + pred_l2 + ssim_term
| def get_criterion(criterion_config): | |
| criterion_type = criterion_config.type | |
| if criterion_type == 'backbone-L2-SSIM': | |
| return BackboneL2SSIMLoss(**criterion_config['params']) | |
| #TODO: Add more criterion types here (L1, L2 ...) | |
| else: | |
| raise ValueError(f"Unsupported criterion type: {criterion_type}") |