# Shahidmuneer's picture
# Upload folder using huggingface_hub
# 8bd3ef8 verified
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: fabrizio.guillaro
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
class CrossEntropy(nn.Module):
    """Pixel-wise cross-entropy loss.

    Upsamples the prediction to the target's spatial size when they differ,
    then applies ``nn.CrossEntropyLoss``.

    Args:
        ignore_label: target value excluded from the loss (default -1,
            used for padded pixels).
        weight: optional per-class rescaling weights passed to
            ``nn.CrossEntropyLoss``.
    """
    def __init__(self, ignore_label=-1, weight=None):
        super(CrossEntropy, self).__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label)

    def forward(self, score, target):
        """score: B×C×H'×W' logits; target: B×H×W integer class map."""
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate with
            # align_corners=False reproduces its default behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)
        loss = self.criterion(score, target)
        return loss
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss averaged over all classes.

    The prediction is softmaxed, the target one-hot encoded, and a binary
    Dice term is computed per class, then averaged over ``num_classes``.

    Args:
        ignore_label: target value marking invalid pixels (default -1,
            because padding adds -1 to the targets).
        smooth: additive smoothing in the Dice numerator/denominator.
        exponent: power applied to pred/target in the denominator.
    """
    def __init__(self, ignore_label=-1, smooth=1, exponent=2):
        super(DiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def dice_loss(self, pred, target, valid_mask):
        """Average binary Dice loss over channels of pred (B×C×H×W)."""
        assert pred.shape[0] == target.shape[0]
        total_loss = 0
        num_classes = pred.shape[1]
        for i in range(num_classes):
            # ignore_index is -1 by default, so in practice every channel
            # contributes; the guard only matters for non-negative values.
            if i != self.ignore_index:
                dice_loss = self.binary_dice_loss(
                    pred[:, i],
                    target[..., i],
                    valid_mask=valid_mask,)
                total_loss += dice_loss
        return total_loss / num_classes

    def binary_dice_loss(self, pred, target, valid_mask):
        """1 - mean Dice over the batch; invalid pixels are zeroed out."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)
        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """score: B×C×H'×W' logits; target: B×H×W integer class map."""
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate with
            # align_corners=False reproduces its default behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)
        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]
        # Clamp so ignore_label (-1) pixels get a valid one-hot index; they
        # are excluded from the loss by valid_mask instead.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()
        loss = self.dice_loss(score, one_hot_target, valid_mask)
        return loss
class BinaryDiceLoss(nn.Module):
    """Soft Dice loss on the foreground channel (class 1) only.

    Args:
        smooth: additive smoothing in the Dice numerator/denominator.
        exponent: power applied to pred/target in the denominator.
        ignore_label: target value marking invalid pixels (default -1,
            because padding adds -1 to the targets).
    """
    def __init__(self, smooth=1, exponent=2, ignore_label=-1):
        super(BinaryDiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def binary_dice_loss(self, pred, target, valid_mask):
        """1 - mean Dice over the batch; invalid pixels are zeroed out."""
        assert pred.shape[0] == target.shape[0]
        # NOTE: leftover debug print() calls removed here.
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)
        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """score: B×C×H'×W' logits; target: B×1×H×W binary mask
        (spatial sizes read from dims 2 and 3, so target is 4-D here)."""
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(2), target.size(3)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate with
            # align_corners=False reproduces its default behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)
        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()
        # Only the foreground channel (index 1) enters the Dice term.
        loss = self.binary_dice_loss(
            score[:, 1],
            one_hot_target[..., 1],
            valid_mask)
        return loss
def create_target_from_mask_and_label(mask, data_label):
    """
    Convert binary mask to class-labeled target based on data_label.
    Args:
        mask: B H W (or B 1 H W) with values 0 (black/background) or 1 (white/foreground)
        data_label: B×1 tensor or B tensor (or scalar tensor) with values [0, 1, 2, 3]
            - 0: background (no edit)
            - 1: physical edit (Photoshop)
            - 2: synthetic AI edit
            - 3: other edit type
    Returns:
        target: B H W long tensor with values [0, 1, 2, 3]
            - 0: unedited pixels (mask == 0)
            - 1, 2, 3: edited pixels with their respective class labels
    """
    # Handle if mask has channel dimension
    if mask.dim() == 4:  # B×1×H×W
        mask = mask.squeeze(1)  # B×H×W
    # Handle if data_label has extra dimensions
    if data_label.dim() > 1:
        data_label = data_label.squeeze()  # B
    # Vectorized equivalent of the original per-sample loop:
    # where mask == 1 the pixel gets the sample's label, elsewhere 0.
    # reshape(-1) turns a 0-dim label into a length-1 tensor, which then
    # broadcasts over the batch (matching the original scalar-label path).
    labels = data_label.long().reshape(-1).to(mask.device)
    target = (mask == 1).long() * labels.view(-1, 1, 1)
    return target
def debug_target_creation(target, data_label, batch_size=4):
    """
    Debug function to print data_label and target mapping before and after conversion.
    Args:
        target: Binary mask B×H×W or B×1×H×W with values 0 or 1
        data_label: B tensor with class labels [0, 1, 2, 3]
        batch_size: max number of samples whose per-sample stats are printed
    Returns:
        The converted target from create_target_from_mask_and_label(target, data_label).
    """
    print("="*80)
    print("DEBUGGING TARGET CREATION")
    print("="*80)
    # Print original inputs
    print("\n--- BEFORE CONVERSION ---")
    print(f"Data Label shape: {data_label.shape}")
    print(f"Data Label values: {data_label}")
    print(f"Data Label dtype: {data_label.dtype}")
    print(f"\nTarget (mask) shape: {target.shape}")
    print(f"Target (mask) unique values: {torch.unique(target)}")
    print(f"Target (mask) dtype: {target.dtype}")
    # Print per-sample details BEFORE
    print("\n--- PER-SAMPLE BREAKDOWN (BEFORE) ---")
    # Drop the channel dim (if any) so per-sample stats index B×H×W.
    if target.dim() == 4:  # B×1×H×W
        target_2d = target.squeeze(1)
    else:
        target_2d = target
    B = target_2d.shape[0]
    for b in range(min(B, batch_size)):
        # Count foreground (edited) pixels in the raw binary mask.
        edited_pixels = (target_2d[b] == 1).sum().item()
        total_pixels = target_2d[b].numel()
        # 0-dim data_label has no index dimension, so fall back to .item().
        label = data_label[b].item() if data_label.dim() > 0 else data_label.item()
        print(f"  Sample {b}: Label={label}, Edited pixels={edited_pixels}/{total_pixels}")
    # Create target
    target_converted = create_target_from_mask_and_label(target, data_label)
    # Print AFTER conversion
    print("\n--- AFTER CONVERSION ---")
    print(f"Target (converted) shape: {target_converted.shape}")
    print(f"Target (converted) unique values: {torch.unique(target_converted)}")
    print(f"Target (converted) dtype: {target_converted.dtype}")
    # Print per-sample details AFTER
    print("\n--- PER-SAMPLE BREAKDOWN (AFTER) ---")
    for b in range(min(B, batch_size)):
        label = data_label[b].item() if data_label.dim() > 0 else data_label.item()
        # Count pixels for each class
        class_counts = {}
        for class_id in range(4):
            count = (target_converted[b] == class_id).sum().item()
            class_counts[class_id] = count
        print(f"  Sample {b}:")
        print(f"    Label (expected): {label}")
        print(f"    Class distribution: {class_counts}")
        # Verify correctness
        if label == 0:
            # All pixels should be background (0)
            if class_counts[0] == target_converted[b].numel():
                print(f"    ✓ CORRECT: All pixels are class 0 (background)")
            else:
                print(f"    ✗ ERROR: Expected all pixels to be 0, but got {class_counts}")
        else:
            # Non-background pixels should have the label
            if class_counts[label] > 0:
                print(f"    ✓ CORRECT: Found {class_counts[label]} pixels with class {label}")
            else:
                print(f"    ✗ ERROR: Expected class {label} pixels but found none")
    print("\n" + "="*80)
    return target_converted
class MultiClassDiceEntropyLoss(nn.Module):
    """
    Multi-class segmentation loss combining Dice and CrossEntropy.
    Supports classes: 0 (background), 1, 2, 3

    Total loss = dice_weight * Dice + ce_weight * CrossEntropy.
    """
    def __init__(self, num_classes=4, smooth=1e-5, dice_weight=0.5, ce_weight=0.5,
                 ignore_index=-1):
        super(MultiClassDiceEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.smooth = smooth
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        # Cross-entropy operates on the raw logits and skips ignore_index.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def dice_loss(self, pred, target, valid_mask=None):
        """
        Mean per-class Dice loss.
        pred: B C H W (softmax probabilities)
        target: B H W (class indices)
        valid_mask: B H W (1 for valid, 0 for ignore), optional
        """
        keep = valid_mask.reshape(-1) if valid_mask is not None else None
        per_class = []
        for cls in range(self.num_classes):
            prob = pred[:, cls, :, :].reshape(-1)
            truth = (target == cls).float().reshape(-1)
            if keep is not None:
                # Zero out ignored pixels in both operands.
                prob = prob * keep
                truth = truth * keep
            overlap = torch.sum(prob * truth)
            size_sum = torch.sum(prob) + torch.sum(truth)
            score = (2 * overlap + self.smooth) / (size_sum + self.smooth)
            per_class.append(1 - score)
        return torch.stack(per_class).mean()

    def forward(self, score, target, data_label):
        """
        score: B C H W raw logits
        target: B H W (or B 1 H W) class labels
        data_label: accepted for interface compatibility; unused here
        """
        # Drop a singleton channel dim so target is B×H×W.
        if target.dim() == 4:
            target = target.squeeze(1)
        target = target.long()
        # Resize logits to the target resolution when they differ.
        if score.shape[2:] != target.shape[1:]:
            score = F.interpolate(score, size=target.shape[1:], mode='bilinear', align_corners=False)
        probs = F.softmax(score, dim=1)
        ce = self.ce_loss(score, target)
        keep = (target != self.ignore_index).float()
        dsc = self.dice_loss(probs, target, keep)
        return self.dice_weight * dsc + self.ce_weight * ce
class DiceEntropyLoss(nn.Module):
    """Combined loss: 0.3 * CrossEntropy + 0.7 * Dice.

    Dice is accumulated over the non-background classes (1, 2, 3) that are
    present in the target and divided by a fixed 4, matching the original
    weighting scheme.

    Args:
        smooth: additive smoothing in the Dice numerator/denominator.
        exponent: power applied to pred/target in the Dice denominator.
        ignore_label: target value marking invalid pixels (default -1,
            because padding adds -1 to the targets).
        weight: optional per-class weights for the cross-entropy term.
    """
    def __init__(self, smooth=1, exponent=2, ignore_label=-1, weight=None):
        super(DiceEntropyLoss, self).__init__()
        self.ignore_label = ignore_label
        self.smooth = smooth
        self.exponent = exponent
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight,
                                                 ignore_index=ignore_label)

    def binary_dice_loss(self, pred, target, valid_mask):
        """1 - mean Dice over the batch; invalid pixels are zeroed out."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)
        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """score: B×C×H'×W' logits; target: B×1×H×W label map.

        NOTE(review): target is clamped to [0, 1] below, so classes 2 and 3
        can never be present afterwards — only class 1 ever contributes to
        the Dice term. Kept as-is to preserve the original behavior.
        """
        target = target.squeeze(1).long()
        target = torch.clamp(target, min=0, max=1)
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate with
            # align_corners=False reproduces its default behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)
        CE_loss = self.cross_entropy(score, target)
        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_label).long()
        # BUG FIX: the original loop referenced undefined names (`dice`,
        # `pred`, `target_onehot`) and raised NameError. Reconstructed from
        # the commented-out binary_dice_loss call that preceded it.
        number_of_present_classes = 4
        dice_loss = score.new_zeros(())
        for class_id in (1, 2, 3):
            # Only classes that exist in this batch (and in the channel
            # dimension) contribute to the Dice term.
            if class_id < num_classes and (target == class_id).any():
                dice_loss = dice_loss + self.binary_dice_loss(
                    score[:, class_id],
                    one_hot_target[..., class_id],
                    valid_mask)
        dice_loss = dice_loss / number_of_present_classes
        return 0.3 * CE_loss + 0.7 * dice_loss
class FocalLoss(nn.Module):
    """Focal loss over per-pixel cross-entropy.

    f_loss = alpha * (1 - pt)^gamma * CE, with pt = exp(-CE), averaged
    over all pixels. Pixels equal to ignore_label contribute a zero CE
    (reduction="none") but are still counted in the final mean.

    Args:
        alpha: scalar weighting factor (default 0.25).
        gamma: focusing exponent (default 2.0).
        ignore_label: target value excluded from the cross-entropy term.
    """
    def __init__(self, alpha=0.25, gamma=2., ignore_label=-1):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label, reduction="none")

    def forward(self, score, target):
        """score: B×C×H'×W' logits; target: B×H×W integer class map."""
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate with
            # align_corners=False reproduces its default behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)
        ce_loss = self.criterion(score, target)
        pt = torch.exp(-ce_loss)  # model's probability of the true class
        f_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return f_loss.mean()