|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Created in September 2022 |
|
|
@author: fabrizio.guillaro |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
|
|
|
class CrossEntropy(nn.Module):
    """Pixel-wise cross-entropy loss for segmentation logits.

    The prediction is resized to the target's spatial size before the
    loss is computed, so coarser-resolution predictions are supported.

    Args:
        ignore_label: target value excluded from the loss (default -1).
        weight: optional per-class rescaling weights, forwarded to
            ``nn.CrossEntropyLoss``.
    """

    def __init__(self, ignore_label=-1, weight=None):
        super(CrossEntropy, self).__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label)

    def forward(self, score, target):
        """Compute the loss.

        Args:
            score: B x C x ph x pw raw logits.
            target: B x H x W long tensor of class indices.

        Returns:
            Scalar loss tensor.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated (removed in recent PyTorch);
            # F.interpolate is the drop-in replacement with identical
            # default behavior for mode='bilinear'.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        loss = self.criterion(score, target)
        return loss
|
|
|
|
|
|
|
|
|
|
|
class DiceLoss(nn.Module):
    """Multi-class Dice loss computed on softmax probabilities.

    The prediction is resized to the target's spatial size, turned into
    probabilities with softmax, and a per-class binary Dice loss is
    averaged over the classes. Pixels equal to ``ignore_label`` are
    zeroed out through a validity mask.

    Args:
        ignore_label: target value excluded from the loss (default -1).
        smooth: additive smoothing term for numerator/denominator.
        exponent: power applied to pred/target in the denominator.
    """

    def __init__(self, ignore_label=-1, smooth=1, exponent=2):
        super(DiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def dice_loss(self, pred, target, valid_mask):
        """Average the binary Dice loss over all classes.

        Args:
            pred: B x C x H x W softmax probabilities.
            target: B x H x W x C one-hot encoded target.
            valid_mask: B x H x W mask, 1 where the pixel counts.
        """
        assert pred.shape[0] == target.shape[0]
        total_loss = 0
        num_classes = pred.shape[1]
        for i in range(num_classes):
            if i != self.ignore_index:
                dice_loss = self.binary_dice_loss(
                    pred[:, i],
                    target[..., i],
                    valid_mask=valid_mask,)
                total_loss += dice_loss
        # NOTE: division is by the total number of classes even when one
        # class was skipped via ignore_index (kept for compatibility).
        return total_loss / num_classes

    def binary_dice_loss(self, pred, target, valid_mask):
        """Dice loss for a single class; operands flattened per sample."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """Compute the loss.

        Args:
            score: B x C x ph x pw raw logits.
            target: B x H x W tensor of class indices.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the replacement.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        # Clamp so ignore_label values still produce a valid one-hot
        # index; those pixels are zeroed by valid_mask anyway.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        loss = self.dice_loss(score, one_hot_target, valid_mask)
        return loss
|
|
|
|
|
|
|
|
|
|
|
class BinaryDiceLoss(nn.Module):
    """Dice loss on the foreground channel of a two-class prediction.

    The prediction is resized to the target's spatial size, softmaxed,
    and the Dice loss is computed on channel 1 (foreground) only.

    Args:
        smooth: additive smoothing term for numerator/denominator.
        exponent: power applied to pred/target in the denominator.
        ignore_label: target value excluded from the loss (default -1).
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1):
        super(BinaryDiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def binary_dice_loss(self, pred, target, valid_mask):
        """Dice loss for a single class; operands flattened per sample."""
        assert pred.shape[0] == target.shape[0]
        # NOTE: leftover debug print statements removed here.
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """Compute the loss.

        Args:
            score: B x C x ph x pw raw logits (C >= 2).
            target: B x 1 x H x W mask (spatial size read from dims 2, 3).
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(2), target.size(3)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the replacement.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        # Only the foreground channel (class 1) enters the Dice loss.
        loss = self.binary_dice_loss(
            score[:, 1],
            one_hot_target[..., 1],
            valid_mask)
        return loss
|
|
|
|
|
def create_target_from_mask_and_label(mask, data_label):
    """
    Convert binary mask to class-labeled target based on data_label.

    Args:
        mask: B H W (or B 1 H W) with values 0 (black/background) or
            1 (white/foreground)
        data_label: B x 1 tensor or B tensor with values [0, 1, 2, 3]
            - 0: background (no edit)
            - 1: physical edit (Photoshop)
            - 2: synthetic AI edit
            - 3: other edit type

    Returns:
        target: B H W long tensor with values [0, 1, 2, 3]
            - 0: unedited pixels (mask == 0)
            - 1, 2, 3: edited pixels with their respective class labels
    """
    # Drop a singleton channel dimension so mask is B x H x W.
    if mask.dim() == 4:
        mask = mask.squeeze(1)

    if data_label.dim() > 1:
        data_label = data_label.squeeze()

    # Vectorized replacement for the original per-sample Python loop:
    # foreground pixels take the sample's class label, background stays 0.
    # reshape(-1) also handles the 0-d tensor that squeeze() produces
    # when the batch size is 1 (it then broadcasts over the batch).
    labels = data_label.long().reshape(-1).to(mask.device)
    target = (mask == 1).long() * labels.view(-1, 1, 1)
    return target
|
|
|
|
|
|
|
|
def debug_target_creation(target, data_label, batch_size=4):
    """
    Debug function to print data_label and target mapping before and after conversion.

    Args:
        target: Binary mask B×H×W or B×1×H×W with values 0 or 1
        data_label: B tensor with class labels [0, 1, 2, 3]
        batch_size: maximum number of samples to report per batch

    Returns:
        The converted target produced by ``create_target_from_mask_and_label``.
    """
    print("="*80)
    print("DEBUGGING TARGET CREATION")
    print("="*80)

    # Report the raw inputs before any conversion happens.
    print("\n--- BEFORE CONVERSION ---")
    print(f"Data Label shape: {data_label.shape}")
    print(f"Data Label values: {data_label}")
    print(f"Data Label dtype: {data_label.dtype}")

    print(f"\nTarget (mask) shape: {target.shape}")
    print(f"Target (mask) unique values: {torch.unique(target)}")
    print(f"Target (mask) dtype: {target.dtype}")

    print("\n--- PER-SAMPLE BREAKDOWN (BEFORE) ---")
    # Drop a singleton channel dimension so indexing below is B x H x W.
    if target.dim() == 4:
        target_2d = target.squeeze(1)
    else:
        target_2d = target

    B = target_2d.shape[0]
    for b in range(min(B, batch_size)):
        edited_pixels = (target_2d[b] == 1).sum().item()
        total_pixels = target_2d[b].numel()
        # data_label may be 0-d after an upstream squeeze; handle both.
        label = data_label[b].item() if data_label.dim() > 0 else data_label.item()
        print(f"  Sample {b}: Label={label}, Edited pixels={edited_pixels}/{total_pixels}")

    # Run the actual conversion under inspection.
    target_converted = create_target_from_mask_and_label(target, data_label)

    print("\n--- AFTER CONVERSION ---")
    print(f"Target (converted) shape: {target_converted.shape}")
    print(f"Target (converted) unique values: {torch.unique(target_converted)}")
    print(f"Target (converted) dtype: {target_converted.dtype}")

    print("\n--- PER-SAMPLE BREAKDOWN (AFTER) ---")
    for b in range(min(B, batch_size)):
        label = data_label[b].item() if data_label.dim() > 0 else data_label.item()

        # Count how many pixels ended up in each of the 4 classes.
        class_counts = {}
        for class_id in range(4):
            count = (target_converted[b] == class_id).sum().item()
            class_counts[class_id] = count

        print(f"  Sample {b}:")
        print(f"    Label (expected): {label}")
        print(f"    Class distribution: {class_counts}")

        # Sanity-check the conversion against the expected label.
        if label == 0:
            # Background sample: every pixel must remain class 0.
            if class_counts[0] == target_converted[b].numel():
                print(f"      ✓ CORRECT: All pixels are class 0 (background)")
            else:
                print(f"      ✗ ERROR: Expected all pixels to be 0, but got {class_counts}")
        else:
            # Edited sample: at least one pixel should carry the label.
            if class_counts[label] > 0:
                print(f"      ✓ CORRECT: Found {class_counts[label]} pixels with class {label}")
            else:
                print(f"      ✗ ERROR: Expected class {label} pixels but found none")

    print("\n" + "="*80)

    return target_converted
|
|
|
|
|
class MultiClassDiceEntropyLoss(nn.Module):
    """Weighted combination of Dice and cross-entropy for multi-class
    segmentation.

    Supports classes 0 (background), 1, 2 and 3 by default. The final
    loss is ``dice_weight * dice + ce_weight * ce``.

    Args:
        num_classes: number of output classes (default 4).
        smooth: smoothing constant added to the Dice ratio.
        dice_weight: weight of the Dice term.
        ce_weight: weight of the cross-entropy term.
        ignore_index: target value excluded from both terms.
    """

    def __init__(self, num_classes=4, smooth=1e-5, dice_weight=0.5, ce_weight=0.5,
                 ignore_index=-1):
        super(MultiClassDiceEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.smooth = smooth
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def _single_class_dice(self, probs, target, class_id, valid_mask):
        # One-vs-rest Dice loss for one class, flattened over the batch.
        flat_probs = probs[:, class_id, :, :].reshape(-1)
        flat_target = (target == class_id).float().reshape(-1)
        if valid_mask is not None:
            flat_valid = valid_mask.reshape(-1)
            flat_probs = flat_probs * flat_valid
            flat_target = flat_target * flat_valid
        overlap = torch.sum(flat_probs * flat_target)
        totals = torch.sum(flat_probs) + torch.sum(flat_target)
        return 1 - (2 * overlap + self.smooth) / (totals + self.smooth)

    def dice_loss(self, pred, target, valid_mask=None):
        """Average one-vs-rest Dice loss over all classes.

        Args:
            pred: B x C x H x W softmax probabilities.
            target: B x H x W class indices (0..num_classes-1).
            valid_mask: optional B x H x W mask (1 = valid pixel).
        """
        per_class = [self._single_class_dice(pred, target, c, valid_mask)
                     for c in range(self.num_classes)]
        return torch.mean(torch.stack(per_class))

    def forward(self, score, target, data_label):
        """Compute the combined loss.

        Args:
            score: B x C x H x W raw logits.
            target: B x H x W (or B x 1 x H x W) class labels 0..3.
            data_label: unused here; kept for interface compatibility.
        """
        if target.dim() == 4:
            target = target.squeeze(1)

        target = target.long()

        if score.shape[2:] != target.shape[1:]:
            score = F.interpolate(score, size=target.shape[1:], mode='bilinear', align_corners=False)

        probabilities = F.softmax(score, dim=1)

        # Cross-entropy runs on raw logits, Dice on probabilities.
        ce_term = self.ce_loss(score, target)
        pixel_valid = (target != self.ignore_index).float()
        dice_term = self.dice_loss(probabilities, target, pixel_valid)

        return self.dice_weight * dice_term + self.ce_weight * ce_term
|
|
|
|
|
|
|
|
class DiceEntropyLoss(nn.Module):
    """Combined cross-entropy + Dice loss.

    The target is clamped to {0, 1} before use; the final loss is
    ``0.3 * CE + 0.7 * Dice``, with the Dice term averaged over the
    non-background classes actually present in the target.

    Args:
        smooth: additive smoothing term for numerator/denominator.
        exponent: power applied to pred/target in the denominator.
        ignore_label: target value excluded from the loss (default -1).
        weight: optional per-class weights for the cross-entropy term.
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1, weight=None):
        super(DiceEntropyLoss, self).__init__()
        self.ignore_label = ignore_label
        self.smooth = smooth
        self.exponent = exponent
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight,
                                                 ignore_index=ignore_label)

    def binary_dice_loss(self, pred, target, valid_mask):
        """Dice loss for a single class; operands flattened per sample.

        Args:
            pred: B x H x W probabilities for the class.
            target: B x H x W binary ground truth for the class.
            valid_mask: B x H x W mask, 1 where the pixel counts.
        """
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """Compute the loss.

        Args:
            score: B x C x ph x pw raw logits.
            target: B x 1 x H x W mask; values clamped to {0, 1}.
        """
        target = target.squeeze(1).long()

        target = torch.clamp(target, min=0, max=1)
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the replacement.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        CE_loss = self.cross_entropy(score, target)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_label).long()

        # BUGFIX: the original body referenced undefined names (`pred`,
        # `target_onehot`, `dice`), raising NameError at runtime. Use the
        # class's own binary_dice_loss on the softmax scores / one-hot
        # target, and average over the classes actually present (the
        # original variable name `number_of_present_classes` indicates
        # this was the intent).
        dice_loss = 0
        number_of_present_classes = 0
        for class_id in [1, 2, 3]:
            if class_id < num_classes and (target == class_id).sum() > 0:
                dice_loss += self.binary_dice_loss(
                    score[:, class_id],
                    one_hot_target[..., class_id],
                    valid_mask)
                number_of_present_classes += 1
        if number_of_present_classes > 0:
            dice_loss /= number_of_present_classes

        return 0.3*CE_loss + 0.7*dice_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """Focal loss for segmentation logits.

    Down-weights well-classified pixels by modulating the per-pixel
    cross-entropy with ``alpha * (1 - pt) ** gamma``.

    Args:
        alpha: global scaling factor (default 0.25).
        gamma: focusing exponent (default 2).
        ignore_label: target value excluded from the loss (default -1).
    """

    def __init__(self, alpha=0.25, gamma=2., ignore_label=-1):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # reduction="none" keeps the per-pixel CE so the focal term can
        # modulate each pixel before averaging.
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label, reduction="none")

    def forward(self, score, target):
        """Compute the loss.

        Args:
            score: B x C x ph x pw raw logits.
            target: B x H x W long tensor of class indices.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the replacement.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        ce_loss = self.criterion(score, target)
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        f_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
        return f_loss.mean()
|
|
|
|
|
|