# Uploaded by Timerns via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 984cdba (verified).
import os
import torch
import numpy as np
class EarlyStopping():
    """Patience-based early stopping for a metric where lower is better.

    Call the instance with the full history of the monitored metric (one value
    per epoch, oldest first); it returns ``True`` when training should stop.

    Args:
        patience: Number of consecutive epochs without improvement that are
            tolerated before stopping. ``0`` disables early stopping.
        min_delta: Minimum decrease below the best value seen so far that
            counts as an improvement. ``0.0`` means any strict decrease counts.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        # Refreshed by check_patience(): best value seen and the epoch index it occurred at.
        self.best_metric = None
        self.best_index = 0

    def check_patience(self, metric_history):
        """Return True if no improvement has happened for ``patience`` epochs.

        An epoch only counts as an improvement when it lowers the best value by
        at least ``min_delta`` (strictly lower when ``min_delta`` is 0.0, so ties
        keep the earliest best epoch). Also records ``best_metric`` and
        ``best_index`` on the instance.
        """
        if self.patience == 0 or not metric_history:
            return False
        best_metric = metric_history[0]
        best_index = 0
        for index, value in enumerate(metric_history[1:], start=1):
            if value < best_metric and best_metric - value >= self.min_delta:
                best_metric = value
                best_index = index
        self.best_metric = best_metric
        self.best_index = best_index
        # Epochs recorded after the best one; the best epoch itself is not a
        # "no improvement" epoch, hence the -1.
        epochs_without_improvement = len(metric_history) - 1 - best_index
        return epochs_without_improvement >= self.patience

    def check_improvement(self, metric_history):
        """Return True if the last epoch lowered the metric by at least ``min_delta``.

        Always False when ``min_delta`` is 0.0 or fewer than two values exist.
        """
        if self.min_delta == 0.0 or len(metric_history) < 2:
            return False
        return bool(metric_history[-2] - metric_history[-1] >= self.min_delta)

    def __call__(self, metric_history):
        """Return True when training should stop given ``metric_history``."""
        # min_delta is folded into the patience check, so this single test
        # honours both settings; patience == 0 disables stopping entirely.
        return self.check_patience(metric_history)
class SaveBestModel():
    """Save a model's ``state_dict`` whenever the monitored metric improves.

    Args:
        folder: Directory the checkpoint file is written into.
        mode: ``'min'`` if lower metric values are better, ``'max'`` if higher
            values are better.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, folder="./", mode='min'):
        if mode not in ('min', 'max'):
            # An unknown mode would otherwise silently save only the first checkpoint.
            raise ValueError(f"Unsupported mode: {mode}")
        self.best_metric = None
        self.folder = folder
        self.mode = mode

    def _is_improvement(self, current_metric):
        """Return True if ``current_metric`` beats the best value seen so far."""
        if self.best_metric is None:
            return True
        if self.mode == 'min':
            return current_metric < self.best_metric
        return current_metric > self.best_metric

    def __call__(self, model, current_metric, model_name="best.pth"):
        """Save ``model.state_dict()`` to ``folder/model_name`` if the metric improved.

        The first call always saves. Ties with the best value do not save.

        Returns:
            True if a checkpoint was written, False otherwise.
        """
        if not self._is_improvement(current_metric):
            return False
        self.best_metric = current_metric
        torch.save(model.state_dict(), os.path.join(self.folder, model_name))
        return True
class ModelLoss():
    """Training loss for binary segmentation or MAE reconstruction.

    Args:
        task: ``'segmentation'`` or ``'mae'``.
        loss: Segmentation loss: ``'focal'`` (``'default'`` is an alias) or
            ``'cross_entropy'``. A soft Dice term is always added for
            segmentation. Ignored for ``'mae'``, which uses mean squared error.
        focal_alpha: Weight of the positive class in the focal loss; negatives
            are weighted by ``1 - focal_alpha``.
        focal_gamma: Focusing exponent of the focal loss.

    Raises:
        ValueError: On an unsupported ``task`` or segmentation ``loss``.
    """

    def __init__(self, task='segmentation', loss='focal', focal_alpha=0.25, focal_gamma=2.0):
        if task not in ['segmentation', 'mae']:
            raise ValueError(f"Unsupported task: {task}")
        self.task = task
        self.loss = loss
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        # 'default' is an alias for 'focal' (see __call__).
        if self.task == 'segmentation' and self.loss not in ['focal', 'default', 'cross_entropy']:
            raise ValueError(f"Unsupported loss for segmentation task: {self.loss}")

    @staticmethod
    def l1(input, target):
        """Mean absolute error."""
        return torch.nn.functional.l1_loss(input, target)

    @staticmethod
    def l2(input, target):
        """Mean squared error."""
        return torch.nn.functional.mse_loss(input, target)

    @staticmethod
    def cross_entropy(input, target):
        """Softmax cross-entropy over the class dimension (dim 1) of ``input`` logits."""
        return torch.nn.functional.cross_entropy(input, target)

    @staticmethod
    def binary_cross_entropy(input, target):
        """Mean sigmoid binary cross-entropy computed directly from logits.

        ``target`` must have the same number of elements as ``input``; both are
        flattened, so ``(B, H, W)`` targets pair with ``(B, 1, H, W)`` logits.
        """
        return torch.nn.functional.binary_cross_entropy_with_logits(
            input.reshape(-1), target.reshape(-1).to(input.dtype))

    @staticmethod
    def dice_loss(input: torch.Tensor, target: torch.Tensor, eps=1e-6):
        """Soft Dice loss ``1 - dice`` pooled over every element of the batch.

        ``input`` holds logits (sigmoid is applied here); ``target`` must have
        the same number of elements, with values in [0, 1].
        """
        # reshape (not view) so non-contiguous tensors are accepted.
        probs = torch.sigmoid(input).reshape(-1)
        target = target.reshape(-1).to(probs.dtype)
        intersection = (probs * target).sum()
        dice = (2. * intersection + eps) / (probs.sum() + target.sum() + eps)
        return 1 - dice

    @staticmethod
    def focal_loss(input: torch.Tensor, target: torch.Tensor, alpha=0.25, gamma=2.0, eps=1e-6):
        """Sigmoid focal loss (Lin et al., 2017), averaged over all elements.

        ``input`` holds logits; ``target`` holds {0, 1} labels with the same
        number of elements. Positives are weighted by ``alpha`` and negatives by
        ``1 - alpha``. ``eps`` is unused and kept for backward compatibility.
        """
        logits = input.reshape(-1)
        target = target.reshape(-1).to(logits.dtype)
        # BCE from logits stays exact where sigmoid saturates to 0 or 1.
        bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, target, reduction='none')
        probs = torch.sigmoid(logits)
        p_t = probs * target + (1 - probs) * (1 - target)
        alpha_t = alpha * target + (1 - alpha) * (1 - target)
        focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
        return focal_loss.mean()

    def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the configured loss for a batch of logits and targets."""
        if self.task == 'segmentation':
            if self.loss in ('focal', 'default'):
                loss = ModelLoss.focal_loss(input, target, alpha=self.focal_alpha, gamma=self.focal_gamma)
            elif input.dim() < 2 or input.shape[1] == 1:
                # Softmax cross-entropy over a single channel is identically zero,
                # so single-channel (binary) logits use sigmoid BCE instead.
                loss = ModelLoss.binary_cross_entropy(input, target)
            else:
                loss = ModelLoss.cross_entropy(input, target)
            dice = ModelLoss.dice_loss(input, target)
            return loss + dice
        if self.task == 'mae':
            return ModelLoss.l2(input, target)
        return None
class ModelMetrics():
    """Evaluation metrics for binary segmentation or MAE reconstruction.

    Args:
        task: ``'segmentation'`` or ``'mae'``.
        device: Stored on the instance; not used by the computations below.
        threshold: Cut-off applied to sigmoid probabilities to binarize
            segmentation predictions.

    Raises:
        ValueError: If ``task`` is not supported.
    """

    def __init__(self, task='segmentation', device='cpu', threshold=0.5):
        if task not in ['segmentation', 'mae']:
            raise ValueError(f"Unsupported task: {task}")
        self.task = task
        self.device = device
        self.threshold = threshold

    @staticmethod
    def iou_score(pred, target, eps=1e-6):
        """Per-sample IoU of 4-D binary masks, averaged over the batch."""
        intersection = (pred * target).sum(dim=(1, 2, 3))
        union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) - intersection
        return ((intersection + eps) / (union + eps)).mean()

    @staticmethod
    def dice_score(pred, target, eps=1e-6):
        """Per-sample Dice coefficient of 4-D binary masks, averaged over the batch."""
        intersection = (pred * target).sum(dim=(1, 2, 3))
        return ((2 * intersection + eps) /
                (pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) + eps)).mean()

    @staticmethod
    def pixel_accuracy(pred, target):
        """Fraction of elements where ``pred`` equals ``target``."""
        return (pred == target).float().mean()

    @staticmethod
    def pixel_precision(pred, target, eps=1e-6):
        """True positives / predicted positives, pooled over the whole batch."""
        true_positive = (pred * target).sum()
        predicted_positive = pred.sum()
        return (true_positive.float() + eps) / (predicted_positive.float() + eps)

    @staticmethod
    def recall(pred, target, eps=1e-6):
        """True positives / actual positives, pooled over the whole batch."""
        true_positive = (pred * target).sum()
        actual_positive = target.sum()
        return (true_positive.float() + eps) / (actual_positive.float() + eps)

    @staticmethod
    def l1(input, target):
        """Mean absolute error."""
        return torch.nn.functional.l1_loss(input, target)

    @staticmethod
    def l2(input, target):
        """Mean squared error."""
        return torch.nn.functional.mse_loss(input, target)

    def getLabels(self):
        """Names of the values returned by ``__call__``, in order."""
        if self.task == 'segmentation':
            return ['iou', 'dice', 'accuracy', 'precision', 'recall']
        elif self.task == 'mae':
            return ['l1', 'l2']
        return []

    def __call__(self, pred: torch.Tensor, target: torch.Tensor):
        """
        pred: (B, 1, H, W) => output of model BEFORE sigmoid
        target: (B, H, W) or (B, 1, H, W)

        Returns a tuple of Python floats in the order given by ``getLabels()``.
        """
        if self.task == 'segmentation':
            if target.dim() == pred.dim() - 1:
                # Add the channel axis so (B, H, W) targets line up with
                # (B, 1, H, W) predictions instead of broadcasting to (B, B, H, W).
                target = target.unsqueeze(1)
            # Binary segmentation: binarize probabilities at the configured threshold.
            pred = (torch.sigmoid(pred) > self.threshold).float()
            iou = self.iou_score(pred, target)
            dice = self.dice_score(pred, target)
            acc = self.pixel_accuracy(pred, target)
            pres = self.pixel_precision(pred, target)
            recall = self.recall(pred, target)
            return iou.item(), dice.item(), acc.item(), pres.item(), recall.item()
        if self.task == 'mae':
            l1 = self.l1(pred, target)
            l2 = self.l2(pred, target)
            return l1.item(), l2.item()
def validate_mae(model, val_loader, metrics, device=None):
    """Evaluate ``model`` on ``val_loader`` and average the per-batch metrics.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        val_loader: Iterable of ``(sources, targets)`` batches.
        metrics: Callable ``metrics(pred, targets)`` returning a tuple of numbers
            (e.g. a ``ModelMetrics`` instance built with ``task='mae'``).
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter, or CUDA (CPU if unavailable) when the model
            has no parameters.

    Returns:
        ``np.ndarray`` with the mean of each metric over batches (each batch
        weighted equally). An empty loader yields NaN with a NumPy RuntimeWarning.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            # Follow the model so CPU and GPU models both work.
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    validation_accumulator = []
    with torch.no_grad():
        for sources, targets in val_loader:
            sources = sources.to(device)
            targets = targets.to(device)
            pred = model(sources)
            validation_accumulator.append(metrics(pred, targets))
    return np.mean(validation_accumulator, axis=0)