File size: 16,395 Bytes
8bd3ef8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 |
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: fabrizio.guillaro
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
class CrossEntropy(nn.Module):
    """Cross-entropy loss that first resizes the logits to the target size.

    Args:
        ignore_label: target value excluded from the loss; forwarded to
            ``nn.CrossEntropyLoss`` as ``ignore_index``.
        weight: optional per-class weights forwarded to
            ``nn.CrossEntropyLoss``.
    """

    def __init__(self, ignore_label=-1, weight=None):
        super().__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label)

    def forward(self, score, target):
        """Return the cross-entropy between ``score`` and ``target``.

        Args:
            score: logits; dimensions 2 and 3 are treated as height/width.
            target: class indices; dimensions 1 and 2 are treated as
                height/width.

        Returns:
            The loss produced by ``nn.CrossEntropyLoss``.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate with an explicit
            # align_corners=False yields the same values without warnings.
            score = F.interpolate(score, size=(h, w), mode='bilinear',
                                  align_corners=False)
        return self.criterion(score, target)
class DiceLoss(nn.Module):
    """Soft Dice loss computed per class on softmax scores and averaged.

    Args:
        ignore_label: target value marking pixels to exclude from the loss
            (padding adds -1 to the targets).
        smooth: additive smoothing term of the Dice ratio.
        exponent: power applied to predictions and targets in the denominator.
    """

    def __init__(self, ignore_label=-1, smooth=1, exponent=2):
        super().__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def dice_loss(self, pred, target, valid_mask):
        """Sum the binary Dice loss of every class and divide by the class count.

        Args:
            pred: class scores, indexed by class along dimension 1.
            target: one-hot labels, indexed by class along the last dimension.
            valid_mask: 1 where a pixel counts, 0 where it is ignored.
        """
        assert pred.shape[0] == target.shape[0]
        num_classes = pred.shape[1]
        total_loss = 0
        for class_id in range(num_classes):
            # Skips a class only if its index equals the ignore value; with the
            # default of -1 every class is included.
            if class_id != self.ignore_index:
                total_loss += self.binary_dice_loss(
                    pred[:, class_id],
                    target[..., class_id],
                    valid_mask=valid_mask)
        return total_loss / num_classes

    def binary_dice_loss(self, pred, target, valid_mask):
        """Return ``1 - mean Dice`` for one class, one Dice value per sample."""
        assert pred.shape[0] == target.shape[0]
        # Flatten everything but the batch dimension.
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(
            pred.pow(self.exponent) * valid_mask + target.pow(self.exponent) * valid_mask,
            dim=1) + max(self.smooth, 1e-5)  # floor keeps the ratio finite when smooth == 0
        dice = torch.mean(num / den)
        return 1 - dice

    def forward(self, score, target):
        """Return the class-averaged Dice loss of logits ``score`` against ``target``.

        Args:
            score: logits; dimension 1 indexes classes, dimensions 2 and 3 are
                height/width.
            target: class indices; dimensions 1 and 2 are height/width.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate with an explicit
            # align_corners=False yields the same values without warnings.
            score = F.interpolate(score, size=(h, w), mode='bilinear',
                                  align_corners=False)
        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]
        # Clamp so that ignored (negative) labels can be one-hot encoded; those
        # pixels are removed from the loss through valid_mask.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()
        return self.dice_loss(score, one_hot_target, valid_mask)
class BinaryDiceLoss(nn.Module):
    """Soft Dice loss of class 1 computed on softmax scores.

    Args:
        smooth: additive smoothing term of the Dice ratio.
        exponent: power applied to predictions and targets in the denominator.
        ignore_label: target value marking pixels to exclude from the loss
            (padding adds -1 to the targets).
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1):
        super().__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def binary_dice_loss(self, pred, target, valid_mask):
        """Return ``1 - mean Dice``, with one Dice value per sample.

        All three inputs are flattened to (batch, -1) before the computation,
        so they only need matching batch size and element count per sample.
        """
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(
            pred.pow(self.exponent) * valid_mask + target.pow(self.exponent) * valid_mask,
            dim=1) + max(self.smooth, 1e-5)  # floor keeps the ratio finite when smooth == 0
        dice = torch.mean(num / den)
        return 1 - dice

    def forward(self, score, target):
        """Return the Dice loss of channel 1 of ``score`` against ``target``.

        Args:
            score: logits; dimension 1 indexes classes (at least two), the
                last two dimensions are height/width.
            target: class indices whose last two dimensions are height/width;
                both a 4-D (singleton channel) and a 3-D layout are accepted.
        """
        ph, pw = score.size(2), score.size(3)
        # Reading the trailing dimensions supports 3-D as well as 4-D targets.
        h, w = target.shape[-2:]
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate with an explicit
            # align_corners=False yields the same values without warnings.
            score = F.interpolate(score, size=(h, w), mode='bilinear',
                                  align_corners=False)
        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]
        # Clamp so that ignored (negative) labels can be one-hot encoded; those
        # pixels are removed from the loss through valid_mask.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()
        return self.binary_dice_loss(
            score[:, 1],
            one_hot_target[..., 1],
            valid_mask)
def create_target_from_mask_and_label(mask, data_label):
    """Turn a binary edit mask into a class-labelled target.

    Pixels where ``mask == 1`` receive the class label of their sample; every
    other pixel is 0 (background).

    Args:
        mask: B×H×W tensor, or B×1×H×W (the channel dimension is squeezed),
            with 1 marking edited pixels.
        data_label: per-sample class labels holding B values in any shape
            (e.g. B or B×1). A single value is applied to every sample.
            Documented label meanings: 0 background (no edit), 1 physical edit
            (Photoshop), 2 synthetic AI edit, 3 other edit type.

    Returns:
        A new ``torch.long`` tensor of shape B×H×W on ``mask``'s device.

    Raises:
        ValueError: if ``data_label`` holds neither one nor B values.
    """
    if mask.dim() == 4:  # B×1×H×W
        mask = mask.squeeze(1)  # B×H×W

    # Unpacking three names enforces that the mask is 3-D at this point.
    batch_size, _, _ = mask.shape

    # reshape(-1) flattens B, B×1 or scalar labels alike, so a batch of one
    # needs no special casing (a bare squeeze() would yield a 0-dim tensor).
    labels = data_label.reshape(-1).to(device=mask.device, dtype=torch.long)
    if labels.numel() == 1:
        labels = labels.expand(batch_size)
    elif labels.numel() != batch_size:
        raise ValueError(
            f"data_label holds {labels.numel()} values but mask has "
            f"{batch_size} samples")

    # Broadcast each sample's label over its foreground pixels in one tensor
    # operation instead of a Python loop with a host sync per sample.
    foreground = (mask == 1).to(torch.long)
    return foreground * labels.view(batch_size, 1, 1)
def debug_target_creation(target, data_label, batch_size=4):
    """Print a before/after report of the mask-to-class-target conversion.

    Args:
        target: binary mask, B×H×W or B×1×H×W, with values 0 or 1.
        data_label: tensor of per-sample class labels (documented range 0-3).
        batch_size: maximum number of samples listed in the per-sample
            breakdowns.

    Returns:
        The tensor returned by ``create_target_from_mask_and_label``.
    """
    def label_of(index):
        # A 0-dim label tensor cannot be indexed; use its single value instead.
        return data_label[index].item() if data_label.dim() > 0 else data_label.item()

    print("="*80)
    print("DEBUGGING TARGET CREATION")
    print("="*80)

    # Inputs as received.
    print("\n--- BEFORE CONVERSION ---")
    print(f"Data Label shape: {data_label.shape}")
    print(f"Data Label values: {data_label}")
    print(f"Data Label dtype: {data_label.dtype}")
    print(f"\nTarget (mask) shape: {target.shape}")
    print(f"Target (mask) unique values: {torch.unique(target)}")
    print(f"Target (mask) dtype: {target.dtype}")

    # Per-sample details before the conversion.
    print("\n--- PER-SAMPLE BREAKDOWN (BEFORE) ---")
    target_2d = target.squeeze(1) if target.dim() == 4 else target  # drop channel dim
    B = target_2d.shape[0]
    shown = min(B, batch_size)
    for b in range(shown):
        edited_pixels = (target_2d[b] == 1).sum().item()
        total_pixels = target_2d[b].numel()
        label = label_of(b)
        print(f" Sample {b}: Label={label}, Edited pixels={edited_pixels}/{total_pixels}")

    target_converted = create_target_from_mask_and_label(target, data_label)

    print("\n--- AFTER CONVERSION ---")
    print(f"Target (converted) shape: {target_converted.shape}")
    print(f"Target (converted) unique values: {torch.unique(target_converted)}")
    print(f"Target (converted) dtype: {target_converted.dtype}")

    # Per-sample details after the conversion.
    print("\n--- PER-SAMPLE BREAKDOWN (AFTER) ---")
    for b in range(shown):
        label = label_of(b)
        sample = target_converted[b]
        class_counts = {
            class_id: (sample == class_id).sum().item() for class_id in range(4)
        }
        print(f" Sample {b}:")
        print(f" Label (expected): {label}")
        print(f" Class distribution: {class_counts}")

        if label == 0:
            # A background sample must not contain any edited pixel.
            if class_counts[0] == sample.numel():
                print(f" ✓ CORRECT: All pixels are class 0 (background)")
            else:
                print(f" ✗ ERROR: Expected all pixels to be 0, but got {class_counts}")
        else:
            # Count the label directly rather than indexing class_counts, so a
            # label outside 0..3 is reported instead of raising KeyError.
            label_count = (sample == label).sum().item()
            if label_count > 0:
                print(f" ✓ CORRECT: Found {label_count} pixels with class {label}")
            else:
                print(f" ✗ ERROR: Expected class {label} pixels but found none")

    print("\n" + "="*80)
    return target_converted
class MultiClassDiceEntropyLoss(nn.Module):
    """Multi-class segmentation loss: weighted sum of Dice and cross-entropy.

    Supports classes 0 (background), 1, 2, 3 by default.

    Args:
        num_classes: number of classes the Dice term iterates over.
        smooth: additive smoothing term of the Dice ratio.
        dice_weight: weight of the Dice term in the total loss.
        ce_weight: weight of the cross-entropy term in the total loss.
        ignore_index: target value excluded from both terms.
    """

    def __init__(self, num_classes=4, smooth=1e-5, dice_weight=0.5, ce_weight=0.5,
                 ignore_index=-1):
        super().__init__()
        self.num_classes = num_classes
        self.smooth = smooth
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def dice_loss(self, pred, target, valid_mask=None):
        """Return the Dice loss averaged over ``num_classes`` classes.

        The Dice ratio of each class is computed over the whole batch at once
        (all samples are flattened together).

        Args:
            pred: B×C×H×W class probabilities.
            target: B×H×W class indices.
            valid_mask: optional B×H×W tensor, 1 for valid and 0 for ignored
                pixels.
        """
        valid_flat = None if valid_mask is None else valid_mask.reshape(-1)
        dice_losses = []
        for class_id in range(self.num_classes):
            pred_flat = pred[:, class_id, :, :].reshape(-1)
            target_flat = (target == class_id).float().reshape(-1)
            if valid_flat is not None:
                pred_flat = pred_flat * valid_flat
                target_flat = target_flat * valid_flat

            intersection = torch.sum(pred_flat * target_flat)
            union = torch.sum(pred_flat) + torch.sum(target_flat)
            dice = (2 * intersection + self.smooth) / (union + self.smooth)
            dice_losses.append(1 - dice)
        return torch.mean(torch.stack(dice_losses))

    def forward(self, score, target, data_label=None):
        """Return ``dice_weight * dice + ce_weight * cross_entropy``.

        Args:
            score: B×C×H×W raw logits with at least ``num_classes`` channels.
            target: B×H×W or B×1×H×W class labels.
            data_label: unused; accepted (and optional) so that callers
                passing it keep working.

        Raises:
            ValueError: if ``score`` has fewer than ``num_classes`` channels.
        """
        if target.dim() == 4:  # B×1×H×W
            target = target.squeeze(1)  # B×H×W
        target = target.long()

        if score.shape[1] < self.num_classes:
            raise ValueError(
                f"score has {score.shape[1]} channel(s) but {self.num_classes} "
                "classes are expected; the model output layer must provide "
                "one channel per class")

        # Resize the logits when their resolution differs from the target's.
        if score.shape[2:] != target.shape[1:]:
            score = F.interpolate(score, size=target.shape[1:], mode='bilinear',
                                  align_corners=False)

        score_probs = F.softmax(score, dim=1)  # B×C×H×W
        ce_loss = self.ce_loss(score, target)

        # Pixels carrying ignore_index are excluded from the Dice term too.
        valid_mask = (target != self.ignore_index).float()
        dice_loss = self.dice_loss(score_probs, target, valid_mask)

        return self.dice_weight * dice_loss + self.ce_weight * ce_loss
class DiceEntropyLoss(nn.Module):
    """Combined loss ``0.3 * cross-entropy + 0.7 * Dice``.

    The Dice term is the mean binary Dice loss over the foreground classes
    (index >= 1) that occur among the valid target pixels; it is 0 when none
    occurs.

    Args:
        smooth: additive smoothing term of the Dice ratio.
        exponent: power applied to predictions and targets in the denominator.
        ignore_label: target value marking pixels to exclude from both terms
            (padding adds -1 to the targets).
        weight: optional per-class weights forwarded to
            ``nn.CrossEntropyLoss``.
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1, weight=None):
        super().__init__()
        self.ignore_label = ignore_label
        self.smooth = smooth
        self.exponent = exponent
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight,
                                                 ignore_index=ignore_label)

    def binary_dice_loss(self, pred, target, valid_mask):
        """Return ``1 - mean Dice`` for one class, one Dice value per sample."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(
            pred.pow(self.exponent) * valid_mask + target.pow(self.exponent) * valid_mask,
            dim=1) + max(self.smooth, 1e-5)  # floor keeps the ratio finite when smooth == 0
        dice = torch.mean(num / den)
        return 1 - dice

    def forward(self, score, target):
        """Return the combined loss of logits ``score`` against ``target``.

        Args:
            score: B×C×H×W logits.
            target: B×H×W or B×1×H×W labels. Values are clamped into
                ``[0, C - 1]``, except pixels equal to ``ignore_label``, which
                stay excluded from both loss terms.
        """
        target = target.long()
        if target.dim() == 4:  # B×1×H×W
            target = target.squeeze(1)  # B×H×W
        num_classes = score.shape[1]

        # Remember the ignored pixels BEFORE clamping: clamping first would
        # turn the ignore label into background and silently disable both
        # ignore_index in the cross-entropy and valid_mask in the Dice term.
        ignored = target == self.ignore_label
        class_ids = torch.clamp(target, min=0, max=num_classes - 1)
        ce_target = class_ids.masked_fill(ignored, self.ignore_label)

        ph, pw = score.size(2), score.size(3)
        h, w = class_ids.size(1), class_ids.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate with an explicit
            # align_corners=False yields the same values without warnings.
            score = F.interpolate(score, size=(h, w), mode='bilinear',
                                  align_corners=False)

        CE_loss = self.cross_entropy(score, ce_target)

        score = F.softmax(score, dim=1)
        one_hot_target = F.one_hot(class_ids, num_classes=num_classes)
        valid_mask = (~ignored).long()

        # Dice over every foreground class that is present among valid pixels.
        dice_terms = [
            self.binary_dice_loss(score[:, class_id],
                                  one_hot_target[..., class_id],
                                  valid_mask)
            for class_id in range(1, num_classes)
            if bool(((class_ids == class_id) & ~ignored).any())
        ]
        if dice_terms:
            dice_loss = torch.stack(dice_terms).mean()
        else:
            dice_loss = score.new_zeros(())
        return 0.3*CE_loss + 0.7*dice_loss
class FocalLoss(nn.Module):
    """Focal loss built on top of the per-pixel cross-entropy.

    Each pixel contributes ``alpha * (1 - pt) ** gamma * ce`` with
    ``pt = exp(-ce)``.

    Args:
        alpha: constant scaling factor of the loss.
        gamma: focusing exponent applied to ``1 - pt``.
        ignore_label: target value forwarded to ``nn.CrossEntropyLoss`` as
            ``ignore_index``.
    """

    def __init__(self, alpha=0.25, gamma=2., ignore_label=-1):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # reduction="none" keeps one loss value per pixel for the modulation.
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label, reduction="none")

    def forward(self, score, target):
        """Return the mean focal loss of logits ``score`` against ``target``.

        Args:
            score: logits; dimensions 2 and 3 are treated as height/width.
            target: class indices; dimensions 1 and 2 are treated as
                height/width.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate with an explicit
            # align_corners=False yields the same values without warnings.
            score = F.interpolate(score, size=(h, w), mode='bilinear',
                                  align_corners=False)
        ce_loss = self.criterion(score, target)
        pt = torch.exp(-ce_loss)
        f_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        # NOTE(review): the mean runs over every pixel, so pixels carrying the
        # ignore label (zero loss) still count in the denominator - confirm
        # this is the intended normalisation.
        return f_loss.mean()
|