Spaces:
Sleeping
Sleeping
| import torch | |
def iou_torch(inst1, inst2):
    """Return the intersection-over-union of two masks as a 0-d float tensor.

    Args:
        inst1, inst2: masks whose nonzero/True entries mark the region; they
            are combined with elementwise logical ops, so shapes must be
            broadcast-compatible.

    Returns:
        A 0-d float32 tensor on the masks' device: ``|A & B| / |A | B|``, or
        NaN when both masks are empty (union of zero pixels).
    """
    inter = torch.logical_and(inst1, inst2).sum().float()
    union = torch.logical_or(inst1, inst2).sum().float()
    if union == 0:
        # Build the NaN from `inter` so it shares the normal result's dtype and
        # device; a bare torch.tensor(...) would land on CPU with the global
        # default dtype and break torch.stack over GPU results.
        return inter.new_tensor(float('nan'))
    return inter / union
def get_instances_torch(mask):
    """Split a label map into one boolean mask per non-background instance.

    Value 0 is treated as background and skipped. Masks are returned in
    ascending instance-id order, because ``torch.unique`` sorts by default.

    Args:
        mask: tensor of instance ids.

    Returns:
        A list of boolean tensors shaped like ``mask``, one per nonzero id.
    """
    instance_ids = torch.unique(mask)
    foreground_ids = instance_ids[instance_ids != 0]
    return [mask == instance_id for instance_id in foreground_ids]
def compute_instance_miou(pred_mask, gt_mask):
    """Return the mean best-match IoU of ground-truth instances.

    Each nonzero id in ``gt_mask`` is one ground-truth instance. For each one,
    the highest IoU against any predicted instance is taken (0 if there are no
    predictions). The result is the NaN-ignoring mean over ground-truth
    instances. Matching is not one-to-one: a single predicted instance may be
    the best match for several ground-truth instances.

    Args:
        pred_mask, gt_mask: integer instance-id tensors of the same shape
            (the original author's note says ``[H, W]``); 0 is background.

    Returns:
        A 0-d float tensor on ``pred_mask.device``. It is NaN when ``gt_mask``
        contains no instances.
    """
    pred_instances = get_instances_torch(pred_mask)
    gt_instances = get_instances_torch(gt_mask)

    # No ground-truth instances: the metric is undefined. Keep the NaN on the
    # same device as the non-empty result so callers can stack scores.
    if not gt_instances:
        return torch.tensor(float('nan'), device=pred_mask.device)

    ious = []
    for gt in gt_instances:
        best_iou = torch.tensor(0.0, device=pred_mask.device)
        for pred in pred_instances:
            candidate = iou_torch(pred, gt)
            if candidate > best_iou:
                best_iou = candidate
        ious.append(best_iou)
    return torch.nanmean(torch.stack(ious))
| from torch import Tensor | |
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Return the mean smoothed Dice coefficient ``(2|A.B| + eps) / (|A| + |B| + eps)``.

    Sums always run over the last two axes. When ``reduce_batch_first`` is set
    (which requires 3-D input), they also run over the leading axis, so the
    whole batch is pooled into a single score. Otherwise, one score is produced
    per leading index and the scores are averaged.

    Raises:
        AssertionError: shapes differ, or ``reduce_batch_first`` with non-3-D input.
    """
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    pool_batch = reduce_batch_first and input.dim() != 2
    dims = (-1, -2, -3) if pool_batch else (-1, -2)

    overlap = (input * target).sum(dim=dims) * 2
    denom = input.sum(dim=dims) + target.sum(dim=dims)
    # Where both sets sum to zero, substitute the intersection term so
    # numerator and denominator coincide and that slice scores exactly 1.
    denom = torch.where(denom == 0, overlap, denom)

    scores = (overlap + epsilon) / (denom + epsilon)
    return scores.mean()
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Return the Dice coefficient averaged over classes.

    NOTE(review): the class axis is assumed to be dim 1, i.e. (N, C, H, W)
    input, inferred from ``flatten(0, 1)`` and the 4-D requirement below.
    Confirm against callers.

    Args:
        input, target: same-shape tensors.
        reduce_batch_first: if False, every (sample, class) slice gets its own
            Dice score and the scores are averaged. If True, each class is
            pooled over the whole batch first, and the per-class scores are then
            averaged, so every class carries equal weight.
        epsilon: smoothing term forwarded to ``dice_coeff``.

    Raises:
        AssertionError: shapes differ, or ``reduce_batch_first`` with non-4-D input.
    """
    assert input.size() == target.size()
    if not reduce_batch_first:
        # Each (sample, class) slice becomes a separate mask; dice_coeff averages them.
        return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), False, epsilon)

    assert input.dim() == 4
    # Summing (-1, -2, -3) after flatten(0, 1) would merge all classes into one
    # pooled score dominated by the largest class. Instead, bring the class axis
    # to the front and fold the batch into the row axis: (C, N*H, W). dice_coeff
    # then pools each class over the batch and averages across classes.
    per_class_input = input.transpose(0, 1).flatten(1, 2)
    per_class_target = target.transpose(0, 1).flatten(1, 2)
    return dice_coeff(per_class_input, per_class_target, False, epsilon)
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Return ``1 - Dice`` as an objective to minimize.

    Both variants pool over the batch (``reduce_batch_first=True``). For
    inputs in [0, 1], the result lies in [0, 1].
    """
    if multiclass:
        coefficient = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        coefficient = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - coefficient