import os

import numpy as np
import torch


class EarlyStopping():
    """Decide when training should stop, given a history of a metric
    that is being minimized (e.g. validation loss).

    Exactly one criterion is consulted per call:
      * patience  > 0 : stop when the best (minimum) value in the history
                        is `patience` or more entries old.
      * min_delta > 0 : stop based on the size of the last improvement
                        (only used when patience == 0).
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        # Kept for backward compatibility with external readers;
        # the checks below do not use these attributes.
        self.best_metric = None
        self.best_index = 0

    def check_patience(self, metric_history):
        """Return True when the best entry of `metric_history` is at
        least `self.patience` positions before the end."""
        if self.patience == 0:
            return False
        best_metric = min(metric_history)
        best_index = metric_history.index(best_metric)
        if len(metric_history) - best_index >= self.patience:
            return True
        return False

    def check_improvement(self, metric_history):
        """Return True when the last step improved the metric by at
        least `self.min_delta`.

        NOTE(review): returning True ("stop") on a *large* improvement
        looks inverted — conventional early stopping triggers when the
        improvement falls *below* min_delta. Behavior kept exactly
        as-is; confirm the intended semantics with the callers.
        """
        if self.min_delta == 0.0:
            return False
        if len(metric_history) < 2:
            return False
        if metric_history[-2] - metric_history[-1] >= self.min_delta:
            return True
        return False

    def __call__(self, metric_history):
        """Return True if training should stop."""
        # BUGFIX: the two dispatch conditions were swapped — a non-zero
        # min_delta used to route to check_patience and a non-zero
        # patience to check_improvement, so with the default arguments
        # (patience=5, min_delta=0.0) the patience criterion was never
        # evaluated and this always returned False.
        if self.patience != 0:
            return self.check_patience(metric_history)
        if self.min_delta != 0.0:
            return self.check_improvement(metric_history)
        return False


class SaveBestModel():
    """Checkpoint `model.state_dict()` whenever the tracked metric
    improves.

    mode='min' saves on decreases (e.g. loss); mode='max' saves on
    increases (e.g. accuracy). The first call always saves.
    """

    def __init__(self, folder="./", mode='min'):
        self.best_metric = None
        self.folder = folder
        self.mode = mode

    def __call__(self, model, current_metric, model_name="best.pth"):
        # First observation always counts as an improvement.
        if self.best_metric is None:
            improved = True
        elif self.mode == 'min':
            improved = current_metric < self.best_metric
        elif self.mode == 'max':
            improved = current_metric > self.best_metric
        else:
            improved = False
        if improved:
            self.best_metric = current_metric
            save_path = os.path.join(self.folder, model_name)
            torch.save(model.state_dict(), save_path)


class ModelLoss():
    """Loss dispatcher for the supported training tasks.

    task='segmentation': focal loss (default) or cross-entropy, each
    combined additively with a dice term.
    task='mae': plain MSE reconstruction loss.
    """

    def __init__(self, task='segmentation', loss='focal', focal_alpha=0.25, focal_gamma=2.0):
        if task not in ['segmentation', 'mae']:
            raise ValueError(f"Unsupported task: {task}")
        self.task = task
        self.loss = loss
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        if self.task == 'segmentation' and self.loss not in ['focal', 'cross_entropy']:
            raise ValueError(f"Unsupported loss for segmentation task: {self.loss}")

    @staticmethod
    def l1(input, target):
        """Mean absolute error."""
        return torch.nn.functional.l1_loss(input, target)

    @staticmethod
    def l2(input, target):
        """Mean squared error."""
        return torch.nn.functional.mse_loss(input, target)

    @staticmethod
    def cross_entropy(input, target):
        return torch.nn.functional.cross_entropy(input, target)

    @staticmethod
    def dice_loss(input: torch.Tensor, target: torch.Tensor, eps=1e-6):
        """Soft dice loss on flattened, sigmoid-activated logits.
        `input` is expected to be raw logits; `target` binary."""
        input = torch.sigmoid(input)
        input = input.view(-1)
        target = target.view(-1)
        intersection = (input * target).sum()
        dice = (2. * intersection + eps) / (input.sum() + target.sum() + eps)
        return 1 - dice

    @staticmethod
    def focal_loss(input: torch.Tensor, target: torch.Tensor, alpha=0.25, gamma=2.0, eps=1e-6):
        """Binary focal loss (Lin et al.) on flattened logits."""
        input = torch.sigmoid(input)
        input = input.view(-1)
        target = target.view(-1)
        bce_loss = torch.nn.functional.binary_cross_entropy(input, target, reduction='none')
        # p_t is the model's probability for the true class of each pixel.
        p_t = input * target + (1 - input) * (1 - target)
        focal_loss = alpha * (1 - p_t) ** gamma * bce_loss
        return focal_loss.mean()

    def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the configured loss; returns None for unknown tasks
        (unreachable given __init__ validation)."""
        if self.task == 'segmentation':
            loss = 0.0
            # NOTE(review): 'default' is rejected by __init__, so that
            # half of this condition is dead code; kept for behavioral
            # parity in case subclasses bypass the constructor check.
            if self.loss == 'focal' or self.loss == 'default':
                loss = ModelLoss.focal_loss(input, target, alpha=self.focal_alpha, gamma=self.focal_gamma)
            elif self.loss == 'cross_entropy':
                loss = ModelLoss.cross_entropy(input, target)
            dice = ModelLoss.dice_loss(input, target)
            return loss + dice
        elif self.task == 'mae':
            return ModelLoss.l2(input, target)
        return None


class ModelMetrics():
    """Metric bundle for the supported tasks.

    segmentation -> (iou, dice, accuracy, precision, recall)
    mae          -> (l1, l2)
    """

    def __init__(self, task='segmentation', device='cpu', threshold=0.5):
        if task not in ['segmentation', 'mae']:
            raise ValueError(f"Unsupported task: {task}")
        self.task = task
        self.device = device
        # Probability cutoff used to binarize sigmoid outputs.
        self.threshold = threshold

    @staticmethod
    def iou_score(pred, target, eps=1e-6):
        """Mean intersection-over-union over the batch dimension."""
        intersection = (pred * target).sum(dim=(1, 2, 3))
        union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) - intersection
        return ((intersection + eps) / (union + eps)).mean()

    @staticmethod
    def dice_score(pred, target, eps=1e-6):
        """Mean dice coefficient over the batch dimension."""
        intersection = (pred * target).sum(dim=(1, 2, 3))
        return ((2 * intersection + eps) / (pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) + eps)).mean()

    @staticmethod
    def pixel_accuracy(pred, target):
        return (pred == target).float().mean()

    @staticmethod
    def pixel_precision(pred, target, eps=1e-6):
        # (B, H, W) shape of pred and target
        true_positive = (pred * target).sum()
        predicted_positive = pred.sum()
        return (true_positive.float() + eps) / (predicted_positive.float() + eps)

    @staticmethod
    def recall(pred, target, eps=1e-6):
        true_positive = (pred * target).sum()
        actual_positive = target.sum()
        return (true_positive.float() + eps) / (actual_positive.float() + eps)

    @staticmethod
    def l1(input, target):
        return torch.nn.functional.l1_loss(input, target)

    @staticmethod
    def l2(input, target):
        return torch.nn.functional.mse_loss(input, target)

    def getLabels(self):
        """Names of the metrics returned by __call__, in order."""
        if self.task == 'segmentation':
            return ['iou', 'dice', 'accuracy', 'precision', 'recall']
        elif self.task == 'mae':
            return ['l1', 'l2']
        return []

    def __call__(self, pred: torch.Tensor, target: torch.Tensor):
        """
        pred: (B, 1, H, W) => output of model BEFORE sigmoid
        target: (B, H, W) or (B, 1, H, W)

        NOTE(review): the per-batch sums assume pred and target have the
        same rank; a (B, H, W) target would broadcast oddly against a
        (B, 1, H, W) pred — confirm callers pass matching shapes.
        """
        if self.task == 'segmentation':
            # Compute metrics
            pred = torch.sigmoid(pred)
            # BUGFIX: binarize with the configured cutoff instead of a
            # hard-coded 0.5 — the `threshold` constructor parameter
            # was previously stored but never used.
            pred = (pred > self.threshold).float()
            iou = self.iou_score(pred, target)
            dice = self.dice_score(pred, target)
            acc = self.pixel_accuracy(pred, target)
            pres = self.pixel_precision(pred, target)
            recall = self.recall(pred, target)
            return iou.item(), dice.item(), acc.item(), pres.item(), recall.item()
        if self.task == 'mae':
            l1 = self.l1(pred, target)
            l2 = self.l2(pred, target)
            return l1.item(), l2.item()


def validate_mae(model, val_loader, metrics):
    """Run one no-grad validation pass for the MAE task and return the
    per-metric mean (np.ndarray) over all batches.

    NOTE(review): moves every batch with .cuda(), so this requires a
    CUDA device — consider parameterizing the device if CPU validation
    is ever needed.
    """
    model.eval()
    validation_accumulator = []
    with torch.no_grad():
        for sources, targets in val_loader:
            sources = sources.cuda()
            targets = targets.cuda()
            pred = model(sources)
            validation_accumulator.append(metrics(pred, targets))
    return np.mean(validation_accumulator, axis=0)