# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: fabrizio.guillaro
"""
import torch
import torch.nn as nn
from torch.nn import functional as F


class CrossEntropy(nn.Module):
    """Pixel-wise cross-entropy loss that upsamples the logits to the target size.

    Args:
        ignore_label: target value to ignore (default -1, used for padding).
        weight: optional per-class rescaling weight passed to CrossEntropyLoss.
    """

    def __init__(self, ignore_label=-1, weight=None):
        super(CrossEntropy, self).__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label)

    def forward(self, score, target):
        """
        score:  B x C x ph x pw raw logits.
        target: B x H x W integer class labels.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the direct replacement.
            score = F.interpolate(input=score, size=(h, w), mode='bilinear')
        loss = self.criterion(score, target)
        return loss


class DiceLoss(nn.Module):
    """Multi-class Dice loss averaged over the class dimension.

    Args:
        ignore_label: label to mask out (default -1, because padding adds -1
            to the targets).
        smooth: additive smoothing term in numerator and denominator.
        exponent: power applied to predictions/targets in the denominator.
    """

    def __init__(self, ignore_label=-1, smooth=1, exponent=2):
        super(DiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def dice_loss(self, pred, target, valid_mask):
        """Average binary Dice loss over all classes except ignore_index.

        pred:   B x C x ... softmax probabilities.
        target: one-hot target with the class dimension last (B x ... x C).
        valid_mask: 1 where the pixel is valid, 0 where it must be ignored.
        """
        assert pred.shape[0] == target.shape[0]
        total_loss = 0
        num_classes = pred.shape[1]
        for i in range(num_classes):
            if i != self.ignore_index:
                dice_loss = self.binary_dice_loss(pred[:, i],
                                                  target[..., i],
                                                  valid_mask=valid_mask)
                total_loss += dice_loss
        # NOTE(review): the mean divides by num_classes even when the
        # ignore_index class is skipped; kept as-is to preserve behavior.
        return total_loss / num_classes

    def binary_dice_loss(self, pred, target, valid_mask):
        """1 - Dice coefficient for a single class, masked by valid_mask."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        # max(smooth, 1e-5) guards the denominator even if smooth is set to 0
        den = torch.sum(pred.pow(self.exponent) * valid_mask
                        + target.pow(self.exponent) * valid_mask,
                        dim=1) + max(self.smooth, 1e-5)
        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """
        score:  B x C x ph x pw raw logits.
        target: B x H x W integer class labels (may contain ignore_label).
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            score = F.interpolate(input=score, size=(h, w), mode='bilinear')
        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]
        # clamp so ignore_label (-1) does not break one_hot; those pixels are
        # zeroed out by valid_mask below anyway
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()
        loss = self.dice_loss(score, one_hot_target, valid_mask)
        return loss


class BinaryDiceLoss(nn.Module):
    """Dice loss on the foreground channel (class 1) only.

    Args:
        smooth: additive smoothing term.
        exponent: power used in the denominator.
        ignore_label: label to mask out (default -1, because padding adds -1
            to the targets).
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1):
        super(BinaryDiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def binary_dice_loss(self, pred, target, valid_mask):
        """1 - Dice coefficient between pred and target, masked by valid_mask.

        (Removed leftover debug print statements from the hot path.)
        """
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent) * valid_mask
                        + target.pow(self.exponent) * valid_mask,
                        dim=1) + max(self.smooth, 1e-5)
        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """
        score:  B x C x ph x pw raw logits (C >= 2).
        target: B x 1 x H x W binary mask (note: unlike DiceLoss, the target
            here keeps its channel dimension — sizes are read from dims 2, 3).
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(2), target.size(3)
        if ph != h or pw != w:
            score = F.interpolate(input=score, size=(h, w), mode='bilinear')
        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()
        # only the foreground channel contributes to the loss
        loss = self.binary_dice_loss(score[:, 1],
                                     one_hot_target[..., 1],
                                     valid_mask)
        return loss


def create_target_from_mask_and_label(mask, data_label):
    """
    Convert binary mask to class-labeled target based on data_label.

    Args:
        mask: B H W with values 0 (black/background) or 1 (white/foreground)
        data_label: B×1 tensor or B tensor with values [0, 1, 2, 3]
            - 0: background (no edit)
            - 1: physical edit (Photoshop)
            - 2: synthetic AI edit
            - 3: other edit type

    Returns:
        target: B H W with values [0, 1, 2, 3]
            - 0: unedited pixels (mask == 0)
            - 1, 2, 3: edited pixels with their respective class labels
    """
    # Handle if mask has channel dimension
    if mask.dim() == 4:  # B×1×H×W
        mask = mask.squeeze(1)  # B×H×W

    # Handle if data_label has extra dimensions
    if data_label.dim() > 1:
        data_label = data_label.squeeze()  # B

    B, H, W = mask.shape

    # Initialize target with zeros (background class)
    target = torch.zeros(B, H, W, dtype=torch.long, device=mask.device)

    # For each sample in batch
    for b in range(B):
        # Get the class label for this sample
        class_label = data_label[b].item() if data_label.dim() > 0 else data_label.item()
        # Where mask is white (1), set the target to the class label
        # Where mask is black (0), keep target as 0 (background)
        target[b][mask[b] == 1] = class_label

    return target


def debug_target_creation(target, data_label, batch_size=4):
    """
    Debug function to print data_label and target mapping before and
    after conversion.

    Args:
        target: Binary mask B×H×W or B×1×H×W with values 0 or 1
        data_label: B tensor with class labels [0, 1, 2, 3]

    Returns:
        The converted target produced by create_target_from_mask_and_label.
    """
    print("=" * 80)
    print("DEBUGGING TARGET CREATION")
    print("=" * 80)

    # Print original inputs
    print("\n--- BEFORE CONVERSION ---")
    print(f"Data Label shape: {data_label.shape}")
    print(f"Data Label values: {data_label}")
    print(f"Data Label dtype: {data_label.dtype}")
    print(f"\nTarget (mask) shape: {target.shape}")
    print(f"Target (mask) unique values: {torch.unique(target)}")
    print(f"Target (mask) dtype: {target.dtype}")

    # Print per-sample details BEFORE
    print("\n--- PER-SAMPLE BREAKDOWN (BEFORE) ---")
    if target.dim() == 4:  # B×1×H×W
        target_2d = target.squeeze(1)
    else:
        target_2d = target

    B = target_2d.shape[0]
    for b in range(min(B, batch_size)):
        edited_pixels = (target_2d[b] == 1).sum().item()
        total_pixels = target_2d[b].numel()
        label = data_label[b].item() if data_label.dim() > 0 else data_label.item()
        print(f"  Sample {b}: Label={label}, Edited pixels={edited_pixels}/{total_pixels}")

    # Create target
    target_converted = create_target_from_mask_and_label(target, data_label)

    # Print AFTER conversion
    print("\n--- AFTER CONVERSION ---")
    print(f"Target (converted) shape: {target_converted.shape}")
    print(f"Target (converted) unique values: {torch.unique(target_converted)}")
    print(f"Target (converted) dtype: {target_converted.dtype}")

    # Print per-sample details AFTER
    print("\n--- PER-SAMPLE BREAKDOWN (AFTER) ---")
    for b in range(min(B, batch_size)):
        label = data_label[b].item() if data_label.dim() > 0 else data_label.item()

        # Count pixels for each class
        class_counts = {}
        for class_id in range(4):
            count = (target_converted[b] == class_id).sum().item()
            class_counts[class_id] = count

        print(f"  Sample {b}:")
        print(f"    Label (expected): {label}")
        print(f"    Class distribution: {class_counts}")

        # Verify correctness
        if label == 0:
            # All pixels should be background (0)
            if class_counts[0] == target_converted[b].numel():
                print(f"    ✓ CORRECT: All pixels are class 0 (background)")
            else:
                print(f"    ✗ ERROR: Expected all pixels to be 0, but got {class_counts}")
        else:
            # Non-background pixels should have the label
            if class_counts[label] > 0:
                print(f"    ✓ CORRECT: Found {class_counts[label]} pixels with class {label}")
            else:
                print(f"    ✗ ERROR: Expected class {label} pixels but found none")

    print("\n" + "=" * 80)
    return target_converted


class MultiClassDiceEntropyLoss(nn.Module):
    """
    Multi-class segmentation loss combining Dice and CrossEntropy.
    Supports classes: 0 (background), 1, 2, 3
    """

    def __init__(self, num_classes=4, smooth=1e-5, dice_weight=0.5,
                 ce_weight=0.5, ignore_index=-1):
        super(MultiClassDiceEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.smooth = smooth
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        # CrossEntropy loss
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def dice_loss(self, pred, target, valid_mask=None):
        """
        Compute Dice loss per class and average.

        pred: B C H W (softmax probabilities)
        target: B H W (class indices 0-3)
        valid_mask: B H W (1 for valid, 0 for ignore)
        """
        dice_losses = []
        for class_id in range(self.num_classes):
            # One-hot encode for this class
            pred_class = pred[:, class_id, :, :]          # B×H×W
            target_class = (target == class_id).float()   # B×H×W

            # Flatten across batch and spatial dimensions
            pred_flat = pred_class.reshape(-1)
            target_flat = target_class.reshape(-1)

            # Apply valid mask if provided
            if valid_mask is not None:
                valid_flat = valid_mask.reshape(-1)
                pred_flat = pred_flat * valid_flat
                target_flat = target_flat * valid_flat

            # Dice computation
            intersection = torch.sum(pred_flat * target_flat)
            union = torch.sum(pred_flat) + torch.sum(target_flat)
            dice = (2 * intersection + self.smooth) / (union + self.smooth)
            dice_losses.append(1 - dice)

        return torch.mean(torch.stack(dice_losses))

    def forward(self, score, target, data_label):
        """
        score: B C H W raw logits (C must equal num_classes).
        target: B H W (class labels: 0, 1, 2, or 3) or B 1 H W.
        data_label: currently unused; kept so existing callers that pass
            the per-sample edit label keep working.
        """
        # Handle if target has channel dimension
        if target.dim() == 4:  # B×1×H×W
            target = target.squeeze(1)  # B×H×W

        # Ensure target is long type (required by CrossEntropyLoss)
        target = target.long()

        # Upsample score if needed
        if score.shape[2:] != target.shape[1:]:
            score = F.interpolate(score, size=target.shape[1:],
                                  mode='bilinear', align_corners=False)

        # Apply softmax to get probabilities for the Dice term;
        # CrossEntropy gets the raw logits.
        score_probs = F.softmax(score, dim=1)  # B×C×H×W

        # CrossEntropy loss
        ce_loss = self.ce_loss(score, target)

        # Valid mask (exclude ignore_index)
        valid_mask = (target != self.ignore_index).float()

        # Dice loss
        dice_loss = self.dice_loss(score_probs, target, valid_mask)

        # Combined loss
        total_loss = self.dice_weight * dice_loss + self.ce_weight * ce_loss
        return total_loss


class DiceEntropyLoss(nn.Module):
    """Weighted combination of CrossEntropy (0.3) and per-class Dice (0.7).

    Args:
        smooth: Dice smoothing term.
        exponent: power used in the Dice denominator.
        ignore_label: label to mask out (default -1, because padding adds -1
            to the targets).
        weight: optional per-class weight for the CrossEntropy term.
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1, weight=None):
        super(DiceEntropyLoss, self).__init__()
        self.ignore_label = ignore_label
        self.smooth = smooth
        self.exponent = exponent
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight,
                                                 ignore_index=ignore_label)

    def binary_dice_loss(self, pred, target, valid_mask):
        """1 - Dice coefficient for a single class channel."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent) * valid_mask
                        + target.pow(self.exponent) * valid_mask,
                        dim=1) + max(self.smooth, 1e-5)
        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """
        score:  B x C x ph x pw raw logits.
        target: B x 1 x H x W mask; squeezed and clamped to {0, 1}.
        """
        target = target.squeeze(1).long()
        target = torch.clamp(target, min=0, max=1)
        ph, pw = score.size(2), score.size(3)   # e.g. (B, C, 224, 224)
        h, w = target.size(1), target.size(2)   # e.g. (B, 224, 224)
        if ph != h or pw != w:
            score = F.interpolate(input=score, size=(h, w), mode='bilinear')

        CE_loss = self.cross_entropy(score, target)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_label).long()

        # BUGFIX: the original loop referenced undefined names (`dice`, `pred`,
        # `target_onehot`) and raised NameError whenever a foreground class was
        # present. Use this class's own binary_dice_loss instead, as the
        # commented-out call above it intended.
        # NOTE(review): because target is clamped to {0, 1} above, classes 2
        # and 3 can never be present; the loop is kept for interface parity.
        number_of_present_classes = 4
        dice_loss = 0
        for class_id in [1, 2, 3]:
            if (target == class_id).sum() > 0:
                dice_loss += self.binary_dice_loss(score[:, class_id],
                                                   one_hot_target[..., class_id],
                                                   valid_mask)
        dice_loss /= number_of_present_classes
        return 0.3 * CE_loss + 0.7 * dice_loss


class FocalLoss(nn.Module):
    """Focal loss (Lin et al.) built on per-pixel CrossEntropy.

    Args:
        alpha: balancing factor (default 0.25).
        gamma: focusing exponent (default 2.0).
        ignore_label: target value to ignore.
    """

    def __init__(self, alpha=0.25, gamma=2., ignore_label=-1):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # reduction="none" keeps the per-pixel CE so the focal term can
        # modulate each pixel before averaging
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label,
                                             reduction="none")

    def forward(self, score, target):
        """
        score:  B x C x ph x pw raw logits.
        target: B x H x W integer class labels.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            score = F.interpolate(input=score, size=(h, w), mode='bilinear')
        ce_loss = self.criterion(score, target)
        pt = torch.exp(-ce_loss)  # probability of the true class
        f_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return f_loss.mean()