|
|
import os
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
class EarlyStopping():
    """Decide when to stop training from the history of a metric to minimise.

    Lower metric values are treated as better.

    Modes, selected by the constructor arguments:

    * ``patience != 0``: stop once ``patience`` entries have been recorded
      after the last improvement. An entry counts as an improvement only if it
      is lower than the best value so far by more than ``min_delta``.
    * ``patience == 0`` and ``min_delta != 0``: stop as soon as the most
      recent step lowered the metric by less than ``min_delta``.
    * ``patience == 0`` and ``min_delta == 0``: early stopping is disabled.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        # Refreshed by check_patience(): best value seen and its history index.
        self.best_metric = None
        self.best_index = 0

    def check_patience(self, metric_history):
        """Return True when ``patience`` entries have passed without improvement.

        Args:
            metric_history: sequence of metric values, oldest first.

        Returns:
            True if training should stop. False when patience is 0 (disabled),
            when the history is empty, or while still within patience.
        """
        if self.patience == 0 or len(metric_history) == 0:
            return False

        best_index = 0
        best_metric = metric_history[0]
        for index, value in enumerate(metric_history):
            # Strict comparison keeps the first of tied values as the best one,
            # and a gain must exceed min_delta to count as an improvement.
            if value < best_metric - self.min_delta:
                best_metric = value
                best_index = index

        self.best_metric = best_metric
        self.best_index = best_index

        # Count only entries recorded *after* the best one (counting the best
        # entry itself stopped one epoch early, and immediately for patience=1).
        epochs_without_improvement = len(metric_history) - 1 - best_index
        return epochs_without_improvement >= self.patience

    def check_improvement(self, metric_history):
        """Return True if the last step lowered the metric by at least ``min_delta``.

        Returns False when ``min_delta`` is 0 or fewer than two values exist.
        """
        if self.min_delta == 0.0 or len(metric_history) < 2:
            return False
        return metric_history[-2] - metric_history[-1] >= self.min_delta

    def __call__(self, metric_history):
        """Return True when training should stop given ``metric_history``."""
        if self.patience != 0:
            return self.check_patience(metric_history)

        if self.min_delta != 0.0 and len(metric_history) >= 2:
            # Plateau mode: stop once the latest step improved by less than min_delta.
            return not self.check_improvement(metric_history)

        return False
|
|
|
|
|
|
class SaveBestModel():
    """Save a model's ``state_dict`` whenever a monitored metric improves.

    Args:
        folder: directory the checkpoint is written to (created if missing).
        mode: ``'min'`` if lower metric values are better, ``'max'`` if higher.

    Raises:
        ValueError: if ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, folder="./", mode='min'):
        if mode not in ('min', 'max'):
            # Any other value used to make every later comparison False,
            # silently freezing the checkpoint at the first evaluation.
            raise ValueError(f"Unsupported mode: {mode}")
        self.best_metric = None
        self.folder = folder
        self.mode = mode

    def _is_better(self, current_metric):
        """Return True if ``current_metric`` strictly beats ``best_metric``."""
        if self.mode == 'min':
            return current_metric < self.best_metric
        return current_metric > self.best_metric

    def __call__(self, model, current_metric, model_name="best.pth"):
        """Checkpoint ``model`` if ``current_metric`` is the best seen so far.

        The first call always saves. Later calls save only on a strict
        improvement according to ``mode``.

        Args:
            model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
            current_metric: metric value for this evaluation.
            model_name: checkpoint file name inside ``folder``.

        Returns:
            True if a checkpoint was written, False otherwise.
        """
        if self.best_metric is not None and not self._is_better(current_metric):
            return False

        self.best_metric = current_metric
        if self.folder:
            # An empty folder means "current directory"; os.makedirs('') would raise.
            os.makedirs(self.folder, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(self.folder, model_name))
        return True
|
|
|
|
|
|
|
|
|
class ModelLoss():
    """Training loss for the supported tasks.

    * ``'segmentation'``: focal loss (``'focal'`` or its alias ``'default'``)
      or ``'cross_entropy'``, plus a soft Dice loss. Inputs are raw logits.
    * ``'mae'``: mean squared error between prediction and target.

    Args:
        task: ``'segmentation'`` or ``'mae'``.
        loss: loss name used for the segmentation task.
        focal_alpha: weight of the positive class in the focal loss.
        focal_gamma: focusing exponent of the focal loss.

    Raises:
        ValueError: for an unsupported task or segmentation loss.
    """

    def __init__(self, task='segmentation', loss='focal', focal_alpha=0.25, focal_gamma=2.0):
        if task not in ['segmentation', 'mae']:
            raise ValueError(f"Unsupported task: {task}")

        self.task = task
        self.loss = loss
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

        # 'default' is treated as focal loss in __call__; rejecting it here
        # previously made that branch unreachable.
        if self.task == 'segmentation' and self.loss not in ['focal', 'default', 'cross_entropy']:
            raise ValueError(f"Unsupported loss for segmentation task: {self.loss}")

    @staticmethod
    def l1(input, target):
        """Mean absolute error."""
        return torch.nn.functional.l1_loss(input, target)

    @staticmethod
    def l2(input, target):
        """Mean squared error."""
        return torch.nn.functional.mse_loss(input, target)

    @staticmethod
    def cross_entropy(input, target):
        """Wrapper around ``torch.nn.functional.cross_entropy`` (expects logits)."""
        return torch.nn.functional.cross_entropy(input, target)

    @staticmethod
    def dice_loss(input: torch.Tensor, target: torch.Tensor, eps=1e-6):
        """Soft Dice loss ``1 - Dice``, computed over all elements of the batch.

        Args:
            input: raw logits; passed through a sigmoid here.
            target: binary mask with the same number of elements as ``input``.
            eps: smoothing term that avoids division by zero on empty masks.
        """
        probs = torch.sigmoid(input).reshape(-1)
        # reshape (unlike view) accepts non-contiguous tensors; the cast lets
        # integer or boolean masks be used as targets.
        target = target.reshape(-1).to(probs.dtype)

        intersection = (probs * target).sum()
        dice = (2. * intersection + eps) / (probs.sum() + target.sum() + eps)
        return 1 - dice

    @staticmethod
    def focal_loss(input: torch.Tensor, target: torch.Tensor, alpha=0.25, gamma=2.0, eps=1e-6):
        """Binary focal loss (Lin et al., 2017), averaged over all elements.

        ``FL = alpha_t * (1 - p_t) ** gamma * BCE``. Here ``alpha_t`` is
        ``alpha`` for positive targets and ``1 - alpha`` for negative ones.

        Args:
            input: raw logits.
            target: binary mask with the same number of elements as ``input``.
            alpha: weight of the positive class.
            gamma: focusing exponent; 0 reduces to class-weighted BCE.
            eps: unused; kept for backward compatibility.
        """
        logits = input.reshape(-1)
        target = target.reshape(-1).to(logits.dtype)

        # BCE computed from logits is numerically stable; BCE on sigmoid outputs
        # saturates for large-magnitude logits.
        bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, target, reduction='none')
        probs = torch.sigmoid(logits)
        p_t = probs * target + (1 - probs) * (1 - target)
        # Class-balancing weight; a scalar alpha would merely rescale the loss.
        alpha_t = alpha * target + (1 - alpha) * (1 - target)
        focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
        return focal_loss.mean()

    def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the configured loss for ``input`` logits against ``target``."""
        if self.task == 'segmentation':
            if self.loss in ('focal', 'default'):
                loss = ModelLoss.focal_loss(input, target, alpha=self.focal_alpha, gamma=self.focal_gamma)
            elif self.loss == 'cross_entropy':
                # NOTE(review): F.cross_entropy applies softmax over dim 1. With a
                # single-channel (B, 1, H, W) output that softmax is identically 1,
                # so this term carries no signal. Confirm the model emits one
                # channel per class when this loss is selected.
                loss = ModelLoss.cross_entropy(input, target)
            else:
                raise ValueError(f"Unsupported loss for segmentation task: {self.loss}")
            return loss + ModelLoss.dice_loss(input, target)

        if self.task == 'mae':
            return ModelLoss.l2(input, target)

        return None
|
|
|
|
|
|
|
|
|
|
|
|
class ModelMetrics():
    """Evaluation metrics for the supported tasks.

    Args:
        task: ``'segmentation'`` or ``'mae'``.
        device: stored on the instance; not used by the computations below.
        threshold: sigmoid probability above which a pixel is predicted positive.

    Raises:
        ValueError: for an unsupported task.
    """

    def __init__(self, task='segmentation', device='cpu', threshold=0.5):
        if task not in ['segmentation', 'mae']:
            raise ValueError(f"Unsupported task: {task}")

        self.task = task
        self.device = device
        self.threshold = threshold

    @staticmethod
    def iou_score(pred, target, eps=1e-6):
        """Per-sample IoU of binary masks, averaged over the batch.

        Sums over every non-batch dimension, so (B, 1, H, W) and (B, H, W)
        inputs both work; ``pred`` and ``target`` must have the same shape.
        """
        dims = tuple(range(1, pred.dim()))
        intersection = (pred * target).sum(dim=dims)
        union = pred.sum(dim=dims) + target.sum(dim=dims) - intersection
        return ((intersection + eps) / (union + eps)).mean()

    @staticmethod
    def dice_score(pred, target, eps=1e-6):
        """Per-sample Dice coefficient of binary masks, averaged over the batch."""
        dims = tuple(range(1, pred.dim()))
        intersection = (pred * target).sum(dim=dims)
        return ((2 * intersection + eps) /
                (pred.sum(dim=dims) + target.sum(dim=dims) + eps)).mean()

    @staticmethod
    def pixel_accuracy(pred, target):
        """Fraction of elements where ``pred`` equals ``target``."""
        return (pred == target).float().mean()

    @staticmethod
    def pixel_precision(pred, target, eps=1e-6):
        """TP / predicted positives, pooled over the whole batch."""
        true_positive = (pred * target).sum()
        predicted_positive = pred.sum()
        return (true_positive.float() + eps) / (predicted_positive.float() + eps)

    @staticmethod
    def recall(pred, target, eps=1e-6):
        """TP / actual positives, pooled over the whole batch."""
        true_positive = (pred * target).sum()
        actual_positive = target.sum()
        return (true_positive.float() + eps) / (actual_positive.float() + eps)

    @staticmethod
    def l1(input, target):
        """Mean absolute error."""
        return torch.nn.functional.l1_loss(input, target)

    @staticmethod
    def l2(input, target):
        """Mean squared error."""
        return torch.nn.functional.mse_loss(input, target)

    def getLabels(self):
        """Names of the values returned by ``__call__``, in order."""
        if self.task == 'segmentation':
            return ['iou', 'dice', 'accuracy', 'precision', 'recall']
        elif self.task == 'mae':
            return ['l1', 'l2']
        return []

    def __call__(self, pred: torch.Tensor, target: torch.Tensor):
        """
        pred: (B, 1, H, W) => output of model BEFORE sigmoid
        target: (B, H, W) or (B, 1, H, W)

        Returns:
            segmentation: (iou, dice, accuracy, precision, recall) as floats.
            mae: (l1, l2) as floats.
        """
        if self.task == 'segmentation':
            # Use the configured threshold (it was previously hard-coded to 0.5).
            pred = (torch.sigmoid(pred) > self.threshold).float()
            if target.dim() == pred.dim() - 1:
                # (B, H, W) -> (B, 1, H, W). Otherwise elementwise ops broadcast
                # to (B, B, H, W) and the per-sample sums fail or are wrong.
                target = target.unsqueeze(1)
            target = target.to(pred.dtype)

            iou = self.iou_score(pred, target)
            dice = self.dice_score(pred, target)
            acc = self.pixel_accuracy(pred, target)
            pres = self.pixel_precision(pred, target)
            recall = self.recall(pred, target)
            return iou.item(), dice.item(), acc.item(), pres.item(), recall.item()

        if self.task == 'mae':
            l1 = self.l1(pred, target)
            l2 = self.l2(pred, target)
            return l1.item(), l2.item()

        return None
|
|
|
|
|
|
def validate_mae(model, val_loader, metrics, device=None):
    """Evaluate ``model`` on ``val_loader`` and average the per-batch metrics.

    The model is put in eval mode (and left there) and run without gradient
    tracking.

    Args:
        model: module mapping a batch of sources to predictions; must provide
            ``eval()`` and ``parameters()``.
        val_loader: iterable of ``(sources, targets)`` tensor pairs.
        metrics: callable ``metrics(pred, targets)`` returning a sequence of
            numbers (e.g. a ``ModelMetrics`` instance).
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (or CUDA if available, else CPU, for a
            parameter-less model) instead of a hard-coded ``.cuda()``.

    Returns:
        NumPy array with the mean of each metric over batches. Every batch is
        weighted equally regardless of its size. An empty loader yields ``nan``.
    """
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()
    batch_scores = []
    with torch.no_grad():
        for sources, targets in val_loader:
            pred = model(sources.to(device))
            batch_scores.append(metrics(pred, targets.to(device)))

    return np.mean(batch_scores, axis=0)