File size: 2,182 Bytes
8f72b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch


def iou_torch(inst1, inst2):
    """Intersection-over-union of two boolean instance masks.

    Args:
        inst1: boolean torch.Tensor mask.
        inst2: boolean torch.Tensor mask, same shape as ``inst1``.

    Returns:
        0-dim float tensor holding IoU in [0, 1], or NaN when both masks
        are empty (union == 0), meaning the IoU is undefined.
    """
    inter = torch.logical_and(inst1, inst2).sum().float()
    union = torch.logical_or(inst1, inst2).sum().float()
    if union == 0:
        # Fix: the original returned a CPU NaN tensor even for CUDA inputs,
        # which is device-inconsistent with the normal return path and can
        # break torch.stack over collected IoUs downstream.
        return torch.full((), float('nan'), device=inst1.device)
    return inter / union

def get_instances_torch(mask):
    """Return a boolean mask for every non-background instance in ``mask``.

    Label 0 is treated as background and skipped; one boolean tensor is
    produced per remaining unique label, in ``torch.unique`` order.
    """
    instance_masks = []
    for label in torch.unique(mask):
        if label == 0:
            continue
        instance_masks.append(mask == label)
    return instance_masks

def compute_instance_miou(pred_mask, gt_mask):
    """Mean IoU over ground-truth instances vs. their best-matching prediction.

    Both ``pred_mask`` and ``gt_mask`` are integer torch.Tensors of shape
    [H, W]; value 0 is background, every other value labels one instance.
    Each GT instance is scored by the highest IoU it achieves against any
    predicted instance (0 when nothing overlaps); the result is the
    nan-mean of those scores, or NaN when there are no GT instances.
    """
    predicted = get_instances_torch(pred_mask)
    ground_truth = get_instances_torch(gt_mask)

    best_per_gt = []
    for gt_inst in ground_truth:
        # Greedy match: keep the best IoU over all predictions.
        # NaN candidates never win a `>` comparison, so they are skipped.
        best = torch.tensor(0.0).to(pred_mask.device)
        for pred_inst in predicted:
            candidate = iou_torch(pred_inst, gt_inst)
            if candidate > best:
                best = candidate
        best_per_gt.append(best)

    # No ground-truth instances: the metric is undefined.
    if not best_per_gt:
        return torch.tensor(float('nan'))
    return torch.nanmean(torch.stack(best_per_gt))

from torch import Tensor


def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Dice coefficient averaged over the batch (or for a single mask).

    When ``reduce_batch_first`` is True, ``input`` must be 3-D and the Dice
    score is computed per batch item before averaging; otherwise all spatial
    (and, for 3-D input, batch) dims are reduced together.  ``epsilon``
    guards against division by zero; two empty masks score 1.
    """
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    if input.dim() == 2 or not reduce_batch_first:
        reduce_dims = (-1, -2)
    else:
        reduce_dims = (-1, -2, -3)

    intersection = 2 * (input * target).sum(dim=reduce_dims)
    denominator = input.sum(dim=reduce_dims) + target.sum(dim=reduce_dims)
    # Empty prediction AND target: force denominator == numerator so the
    # ratio evaluates to 1 (perfect agreement on "nothing here").
    denominator = torch.where(denominator == 0, intersection, denominator)

    per_item_dice = (intersection + epsilon) / (denominator + epsilon)
    return per_item_dice.mean()


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Dice coefficient averaged over all classes.

    Folds the class dimension into the batch dimension and delegates to
    ``dice_coeff``, so every (batch, class) slice is scored independently.
    """
    flat_input = input.flatten(0, 1)
    flat_target = target.flatten(0, 1)
    return dice_coeff(flat_input, flat_target, reduce_batch_first, epsilon)


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Dice loss (objective to minimize), in [0, 1].

    Set ``multiclass=True`` when ``input``/``target`` carry a class
    dimension; the score is then averaged across classes.
    """
    if multiclass:
        coeff = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        coeff = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - coeff