# Segmentation / losses.py
# Uploaded by riha55 ("Upload 7 files"), commit 6f774ac (verified).
# from .train3 import deeplabv3_encoder_decoder
# # from .train3 import pl
# # from .train3 import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch
# class mIoULoss(nn.Module):
# def __init__(self, weight=None, size_average=True, n_classes=4):
# super().__init__()
# self.classes = n_classes
# def to_one_hot(self, tensor):
# n, h, w = tensor.size()
# one_hot = torch.zeros(n, self.classes, h, w).to(tensor.device)
# one_hot.scatter_(1, tensor.unsqueeze(1), 1)
# return one_hot
# def forward(self, inputs, target):
# N = inputs.size(0)
# inputs = F.softmax(inputs, dim=1)
# target_oneHot = self.to_one_hot(target)
# inter = inputs * target_oneHot
# inter = inter.view(N, self.classes, -1).sum(2)
# union = inputs + target_oneHot - inter
# union = union.view(N, self.classes, -1).sum(2)
# loss = inter / union
# return 1 - loss.mean()
import torch.nn as nn
import torch.nn.functional as F
import torch
class DiceLoss(nn.Module):
    """Soft Dice loss evaluated over every element of the batch at once.

    Both tensors are fully flattened, so the score is
    ``(2 * sum(p * l) + smooth) / (sum(p) + sum(l) + smooth)`` and the loss
    is ``1 - score``.

    Args:
        smooth: Constant added to numerator and denominator. When non-zero it
            keeps the ratio defined if both tensors sum to zero.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, labels):
        """Return the scalar Dice loss between ``preds`` and ``labels``.

        Args:
            preds: Model output. A 4-D tensor is passed through a sigmoid
                (treated as logits); tensors of any other rank are used as-is
                (presumably already probabilities -- confirm with callers).
            labels: Targets with the same number of elements as ``preds``.

        Returns:
            0-dim tensor equal to ``1 - dice``.
        """
        # Only 4-D inputs are squashed into (0, 1).
        probs = torch.sigmoid(preds) if preds.dim() == 4 else preds

        flat_probs = probs.reshape(-1)
        flat_labels = labels.reshape(-1)

        overlap = torch.sum(flat_probs * flat_labels)
        denominator = torch.sum(flat_probs) + torch.sum(flat_labels) + self.smooth
        return 1 - (2. * overlap + self.smooth) / denominator
class mIoULoss(nn.Module):
    """Soft mean-IoU (Jaccard) loss for multi-class segmentation.

    Logits are turned into per-pixel class probabilities with a softmax and the
    integer target is one-hot encoded. A soft IoU
    ``sum(p * t) / sum(p + t - p * t)`` is then computed for every
    (sample, class) pair. The loss is ``1 - mean`` of those IoUs over batch
    and classes.

    Args:
        weight: Accepted for interface compatibility; currently unused.
        size_average: Accepted for interface compatibility; currently unused.
        n_classes: Number of classes ``C``; the logits' channel dimension must
            equal this value.
    """

    def __init__(self, weight=None, size_average=True, n_classes=4):
        super().__init__()
        self.classes = n_classes

    def to_one_hot(self, tensor):
        """One-hot encode an integer label map.

        Args:
            tensor: Class indices shaped ``(N, H, W)`` or ``(N, 1, H, W)``,
                with values in ``[0, n_classes)``.

        Returns:
            Float tensor of shape ``(N, n_classes, H, W)`` on the same device
            as ``tensor``, holding a 1 in each pixel's class channel.

        Raises:
            ValueError: If ``tensor`` is neither 3-D nor 4-D.
        """
        tensor = tensor.long()  # scatter_ requires int64 indices
        if tensor.dim() == 3:
            # Plain label maps have no channel axis; add one for scatter_.
            tensor = tensor.unsqueeze(1)
        elif tensor.dim() != 4:
            raise ValueError(
                "expected target of shape (N, H, W) or (N, 1, H, W), "
                f"got {tuple(tensor.shape)}"
            )
        n, _, h, w = tensor.shape
        # Allocate directly on the label's device rather than CPU-then-copy.
        one_hot = torch.zeros(n, self.classes, h, w, device=tensor.device)
        one_hot.scatter_(1, tensor, 1)
        return one_hot

    def forward(self, inputs, target):
        """Return ``1 - mean soft IoU`` between ``inputs`` and ``target``.

        Args:
            inputs: Raw logits shaped ``(N, C, H, W)`` with ``C == n_classes``.
            target: Integer class indices shaped ``(N, H, W)`` or
                ``(N, 1, H, W)``.

        Returns:
            0-dim tensor holding the loss.

        Raises:
            ValueError: If ``inputs`` does not have ``n_classes`` channels.
                Without this check a 1-channel input would silently broadcast
                against the one-hot target.
        """
        if inputs.dim() < 2 or inputs.size(1) != self.classes:
            raise ValueError(
                f"expected inputs with {self.classes} channels in dim 1, "
                f"got shape {tuple(inputs.shape)}"
            )
        N = inputs.size(0)

        # Per-pixel class probabilities along the channel axis.
        probs = F.softmax(inputs, dim=1)
        target_one_hot = self.to_one_hot(target)

        overlap = probs * target_one_hot
        # Sum over all pixels: (N, C, H, W) -> (N, C).
        inter = overlap.reshape(N, self.classes, -1).sum(2)
        union = (probs + target_one_hot - overlap).reshape(N, self.classes, -1).sum(2)

        # NOTE(review): if a class is absent from the target AND its softmax
        # probabilities all underflow to 0, union is 0 and this yields NaN;
        # consider a smoothing term if that occurs in practice.
        iou = inter / union
        # Average over batch and classes.
        return 1 - iou.mean()