|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .utils import cuda_kernel, cuda_launch, cuda_int32, cuda_float32 |
|
|
import torch |
|
|
|
|
|
# (name, CUDA source) pair handed to ``cuda_launch``.  ``kernel_dt`` performs
# one 1-D pass of a brute-force squared Euclidean distance transform: for each
# element (b, i, j) of a contiguous (bs, h, w) float32 buffer it writes
#     output[b, i, j] = min(diam2, min_k data[b, i, k] + (j - k)**2)
# i.e. O(w) work per element.  ``batch_edt`` runs it once along rows and once,
# on the transposed buffer, along columns.
_batch_edt_kernel = (
    "kernel_dt",
    """
extern "C" __global__ void kernel_dt(
    const int bs,
    const int h,
    const int w,
    const float diam2,
    const float* __restrict__ data,
    float* __restrict__ output
) {
    // 64-bit index arithmetic: bs*h*w (and the row offset) overflow a
    // 32-bit int for very large batches.
    const long long plane = (long long)h * w;
    const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= (long long)bs * plane) {
        return;
    }

    const long long pb = idx / plane;
    const long long rem = idx - pb * plane;
    const int pi = (int)(rem / w);
    const int pj = (int)(rem - (long long)pi * w);
    const float* row = data + pb * plane + (long long)pi * w;

    // The offset is squared in float: the int product (pj-j)*(pj-j) would
    // overflow once |pj-j| exceeds 46340.
    float mincost = diam2;
    for (int j = 0; j < w; j++) {
        const float d = (float)(pj - j);
        const float cost = row[j] + d * d;
        if (cost < mincost) {
            mincost = cost;
        }
    }
    output[idx] = mincost;
}
""",
)

# Compiled kernel handle; filled in lazily on the first CUDA call.
_batch_edt = None
|
|
|
|
|
|
|
|
def _edt_row_pass_cuda(data, diam2, block):
    """Run ``kernel_dt`` on a contiguous float32 CUDA tensor of shape (n, rows, cols).

    Returns a new tensor of the same shape holding, per element, the minimum
    over its row of ``data + (offset along the row)**2``, capped at ``diam2``.
    """
    global _batch_edt
    if _batch_edt is None:
        # Compile the kernel lazily, on the first CUDA call.
        _batch_edt = cuda_launch(*_batch_edt_kernel)

    n, rows, cols = data.shape
    # The kernel writes every element, so no zero-fill is needed.
    out = torch.empty_like(data)
    grid = (data.nelement() + block - 1) // block
    # NOTE(review): the launch goes through the project's ``cuda_launch``
    # wrapper; confirm it targets data's device/stream for non-default GPUs.
    _batch_edt(
        grid=(grid, 1, 1),
        block=(block, 1, 1),
        args=[
            cuda_int32(n),
            cuda_int32(rows),
            cuda_int32(cols),
            cuda_float32(diam2),
            data.data_ptr(),
            out.data_ptr(),
        ],
    )
    return out


def _edt_row_pass_cpu(data, diam2):
    """Pure-PyTorch equivalent of ``kernel_dt`` for tensors that are not on a GPU.

    Loops over the source column (``cols`` tensor ops on the whole buffer) and
    applies the same strict ``<`` update as the kernel.
    """
    cols = data.shape[-1]
    pos = torch.arange(cols, dtype=data.dtype, device=data.device)
    mincost = torch.full_like(data, float(diam2))
    for j in range(cols):
        # cost[..., p] = data[..., j] + (p - j)**2 for every target column p.
        cost = data[..., j : j + 1] + (pos - j) ** 2
        mincost = torch.where(cost < mincost, cost, mincost)
    return mincost


def _edt_row_pass(data, diam2, block):
    """Dispatch one 1-D squared-distance pass along the last axis."""
    if data.is_cuda:
        return _edt_row_pass_cuda(data, diam2, block)
    return _edt_row_pass_cpu(data, diam2)


def batch_edt(img, block=1024):
    """Exact Euclidean distance transform for a batch of 2-D images.

    For a binary mask, pixels equal to 1 are the targets (distance 0) and
    pixels equal to 0 get the Euclidean distance to the nearest target in the
    same image.  An image with no target is filled with its diagonal length
    ``sqrt(h**2 + w**2)``.  Non-binary values act as an additive squared cost
    ``(1 - img) * (h**2 + w**2)``; the squared result is capped at
    ``h**2 + w**2``.

    CUDA tensors use the ``kernel_dt`` kernel; other tensors use an equivalent
    pure-PyTorch loop.  Both are brute force: O(h + w) work per pixel.

    Args:
        img: Tensor of shape (B, H, W) or (B, 1, H, W).
        block: CUDA threads per block for the kernel launches (unused on CPU).

    Returns:
        Tensor with the same shape as ``img`` and dtype ``img.dtype``.  Integer
        or bool inputs therefore receive distances cast to that dtype.

    Raises:
        ValueError: if a 4-D ``img`` has a channel dimension other than 1.
    """
    if img.dim() == 4:
        if img.shape[1] != 1:
            raise ValueError(
                f"batch_edt expects a single channel, got shape {tuple(img.shape)}"
            )
        img = img.squeeze(1)
        expand = True
    else:
        expand = False

    bs, h, w = img.shape
    # Squared image diagonal: exceeds every squared in-image distance, so it is
    # both the cost of background pixels and the "no target" cap.
    diam2 = h**2 + w**2
    odtype = img.dtype

    # Target pixels (img == 1) cost 0, background pixels (img == 0) cost diam2.
    data = ((1 - img.type(torch.float32)) * diam2).contiguous()

    if data.numel() == 0:
        # Nothing to compute; also avoids launching a kernel with a zero grid.
        sq = data
    else:
        # Separable EDT: squared distances along each row, then along each
        # column of that result (processed as rows of the transposed buffer).
        along_rows = _edt_row_pass(data, diam2, block)
        along_cols = _edt_row_pass(
            along_rows.permute(0, 2, 1).contiguous(), diam2, block
        )
        sq = along_cols.permute(0, 2, 1)

    ans = sq.sqrt()
    if ans.dtype != odtype:
        ans = ans.type(odtype)
    if expand:
        ans = ans.unsqueeze(1)
    return ans
|
|
|
|
|
# Public API: names exported by ``from <module> import *``.
__all__ = [
    "batch_edt",
]