File size: 3,226 Bytes
61029c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
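############### DISTANCE TRANSFORM ###############
# img tensor: (bs,h,w) or (bs,1,h,w)
# returns same shape
# expects white lines (1) on a black background (0); each output pixel is
# the Euclidean distance to the nearest white pixel
# empty (all-black) images default to the diameter, sqrt(h**2 + w**2)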
from .utils import cuda_kernel, cuda_launch, cuda_int32, cuda_float32
import torch

# CUDA source for one pass of the separable distance transform, stored as
# (kernel name, source) and unpacked into cuda_launch(...) by batch_edt.
# One thread per element of a (bs, h, w) float buffer: thread idx decodes to
# (pb, pi, pj) and writes
#     output[pb, pi, pj] = min(diam2, min_j data[pb, pi, j] + (pj - j)**2),
# i.e. a brute-force 1-D squared-distance minimum along the last axis
# (O(w) work per thread). Threads with idx >= bs*h*w return immediately.
# NOTE(review): idx, bs*h*w and (pj-j)*(pj-j) are 32-bit ints, so this
# overflows for tensors with >= 2**31 elements or rows longer than ~46k pixels.
_batch_edt_kernel = (
    "kernel_dt",
    """
    extern "C" __global__ void kernel_dt(
        const int bs,
        const int h,
        const int w,
        const float diam2,
        float* data,
        float* output
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx >= bs*h*w) {
            return;
        }
        int pb = idx / (h*w);
        int pi = (idx - h*w*pb) / w;
        int pj = (idx - h*w*pb - w*pi);

        float cost;
        float mincost = diam2;
        for (int j = 0; j < w; j++) {
            cost = data[h*w*pb + w*pi + j] + (pj-j)*(pj-j);
            if (cost < mincost) {
                mincost = cost;
            }
        }
        output[idx] = mincost;
        return;
    }
""",
)
# Lazily-built launcher for _batch_edt_kernel; stays None until batch_edt
# first needs it (cuda/cupy must be initialized after forking, not at import).
_batch_edt = None


def _edt_pass_cuda(data, bs, h, w, diam2, block):
    """Run one ``kernel_dt`` pass on the GPU: a min over the last axis of ``data``.

    ``data`` must be a contiguous float32 CUDA tensor of shape (bs, h, w);
    returns a new tensor of the same shape.

    NOTE(review): the launcher comes from ``.utils.cuda_launch``; whether it
    targets ``data``'s device and torch's current stream is not visible here.
    Confirm before relying on this with multiple GPUs or non-default streams.
    """
    # must initialize cuda/cupy after forking, so the kernel is built lazily
    global _batch_edt
    if _batch_edt is None:
        _batch_edt = cuda_launch(*_batch_edt_kernel)

    out = torch.zeros_like(data)
    nelem = data.nelement()
    if nelem == 0:
        # nothing to compute, and CUDA rejects a launch with a 0-sized grid
        return out
    _batch_edt(
        grid=((nelem + block - 1) // block, 1, 1),
        block=(block, 1, 1),  # at most 1024 threads per block
        args=[
            cuda_int32(bs),
            cuda_int32(h),
            cuda_int32(w),
            cuda_float32(diam2),
            data.data_ptr(),
            out.data_ptr(),
        ],
    )
    return out


def _edt_pass_torch(data, diam2):
    """Device-agnostic torch equivalent of one ``kernel_dt`` pass.

    For every position p along the last axis, returns
    ``min(diam2, min_j data[..., j] + (p - j)**2)``. It loops over the last
    axis in Python, so it costs O(numel * w) time but only O(numel) memory.
    """
    n = data.shape[-1]
    pos = torch.arange(n, dtype=data.dtype, device=data.device)
    out = torch.full_like(data, float(diam2))
    for j in range(n):
        # data[..., j:j+1] keeps the axis so it broadcasts against pos
        out = torch.minimum(out, data[..., j : j + 1] + (pos - j) ** 2)
    return out


def batch_edt(img, block=1024):
    """Batched Euclidean distance transform of white-on-black images.

    Args:
        img: tensor of shape (bs, h, w) or (bs, 1, h, w). Pixels equal to 1
            ("white lines") are the features and pixels equal to 0 are
            background. Intermediate values scale the per-pixel seed cost
            linearly.
        block: threads per block for the CUDA launches (CUDA inputs only;
            at most 1024).

    Returns:
        Tensor with the same shape and dtype as ``img``. Each pixel holds the
        distance to the nearest white pixel. Images with no white pixels get
        sqrt(h**2 + w**2) everywhere. Integer/bool dtypes truncate the result.

    Raises:
        AssertionError: if a 4-D input has more than one channel.

    Uses the separable formulation: a 1-D squared-distance minimum along each
    row, then along each column of that result, then a square root. CUDA
    tensors use ``kernel_dt``; any other device uses an equivalent torch loop.
    """
    # bookkeeping: work on (bs, h, w) and restore the channel dim at the end
    expand = img.dim() == 4
    if expand:
        assert img.shape[1] == 1, (
            f"expected a single channel, got shape {tuple(img.shape)}"
        )
        img = img.squeeze(1)
    bs, h, w = img.shape
    diam2 = h**2 + w**2
    odtype = img.dtype

    # seed costs: 0 on white pixels and diam2 on black ones. diam2 exceeds any
    # in-image squared distance, so an all-black image stays at diam2.
    data = ((1 - img.type(torch.float32)) * diam2).contiguous()

    if img.is_cuda:
        # first pass minimizes along each row (x-axis, length w) ...
        intermed = _edt_pass_cuda(data, bs, h, w, diam2, block)
        # ... second pass along each column (y-axis, length h) after transposing
        intermed = intermed.permute(0, 2, 1).contiguous()
        out = _edt_pass_cuda(intermed, bs, w, h, diam2, block)
    else:
        # non-CUDA devices: the same two passes in plain torch
        intermed = _edt_pass_torch(data, diam2)
        out = _edt_pass_torch(intermed.transpose(1, 2), diam2)

    # out is (bs, w, h); transpose back before taking the root
    ans = out.permute(0, 2, 1).sqrt()
    if ans.dtype != odtype:
        ans = ans.type(odtype)
    if expand:
        ans = ans.unsqueeze(1)
    return ans

# Public API of this module: star-imports expose only batch_edt.
__all__ = ["batch_edt"]