# Mirrored from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation
# (mirror by aliensmn, revision 61029c7).
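############### DISTANCE TRANSFORM ###############
# img tensor: (bs,h,w) or (bs,1,h,w)
# returns a tensor of the same shape holding, for each pixel, the Euclidean
# distance to the nearest white pixel
# expects white (1) lines on a black (0) background
# images with no white pixels get the image diagonal, sqrt(h**2 + w**2), everywhere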
from .utils import cuda_kernel, cuda_launch, cuda_int32, cuda_float32
import torch
# CUDA source for one pass of the separable squared distance transform, stored
# as a (kernel name, source) pair that batch_edt unpacks into cuda_launch(...).
#
# kernel_dt indexes `data` and `output` as contiguous row-major (bs, h, w)
# float arrays and assigns one thread per element (threads with
# idx >= bs*h*w return immediately). The thread for (pb, pi, pj) writes
#     output[pb, pi, pj] = min(diam2, min_j data[pb, pi, j] + (pj - j)^2)
# i.e. a brute-force 1-D min-convolution along the last axis, O(w) work per
# element. batch_edt runs it once, transposes, and runs it again, which gives
# the squared distance over both axes.
# NOTE(review): idx, bs*h*w and (pj-j)*(pj-j) use 32-bit int arithmetic; the
# square overflows once w exceeds ~46k and bs*h*w once the tensor exceeds
# 2**31 elements -- confirm inputs stay well below that.
_batch_edt_kernel = (
"kernel_dt",
"""
extern "C" __global__ void kernel_dt(
const int bs,
const int h,
const int w,
const float diam2,
float* data,
float* output
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= bs*h*w) {
return;
}
int pb = idx / (h*w);
int pi = (idx - h*w*pb) / w;
int pj = (idx - h*w*pb - w*pi);
float cost;
float mincost = diam2;
for (int j = 0; j < w; j++) {
cost = data[h*w*pb + w*pi + j] + (pj-j)*(pj-j);
if (cost < mincost) {
mincost = cost;
}
}
output[idx] = mincost;
return;
}
""",
)
# Kernel handle produced by cuda_launch(*_batch_edt_kernel). It is created on
# the first CUDA call rather than at import because cuda/cupy must be
# initialized after forking.
_batch_edt = None


def _edt_pass_cuda(data, diam2, block):
    """Run ``kernel_dt`` once over the last axis of ``data`` on the GPU.

    Args:
        data: contiguous float32 CUDA tensor of shape (bs, n, m); the kernel
            indexes it as a row-major (bs, n, m) array.
        diam2: value every output is capped at (the squared image diagonal).
        block: threads per CUDA block (CUDA caps this at 1024).

    Returns:
        New tensor of the same shape with
        out[b, i, k] = min(diam2, min_j data[b, i, j] + (k - j)**2).
    """
    global _batch_edt
    if data.numel() == 0:
        # A zero-sized grid is an invalid launch configuration; nothing to do.
        return torch.empty_like(data)
    # must initialize cuda/cupy after forking, hence creating the kernel on first use
    if _batch_edt is None:
        _batch_edt = cuda_launch(*_batch_edt_kernel)
    bs, n, m = data.shape
    out = torch.zeros_like(data)
    grid = (data.numel() + block - 1) // block
    _batch_edt(
        grid=(grid, 1, 1),
        block=(block, 1, 1),
        args=[
            cuda_int32(bs),
            cuda_int32(n),
            cuda_int32(m),
            cuda_float32(diam2),
            data.data_ptr(),
            out.data_ptr(),
        ],
    )
    return out


def _edt_pass_torch(data, diam2):
    """Pure-torch equivalent of ``kernel_dt``: one pass over the last axis.

    Returns a tensor shaped like ``data`` with
    out[..., k] = min(diam2, min_j data[..., j] + (k - j)**2).
    It loops over the last axis so that peak memory stays O(data.numel()),
    instead of materialising an (..., m, m) cost tensor. Total work is
    O(data.numel() * m).
    """
    m = data.shape[-1]
    pos = torch.arange(m, dtype=data.dtype, device=data.device)
    out = torch.full_like(data, float(diam2))
    for j in range(m):
        # cost of reaching every position k from source column j
        out = torch.minimum(out, data[..., j : j + 1] + (pos - j) ** 2)
    return out


def batch_edt(img, block=1024):
    """Euclidean distance transform of a batch of images.

    For a binary image (1 = white line, 0 = background), each output pixel
    holds the Euclidean distance to the nearest white pixel. Images with no
    white pixels get the diagonal length sqrt(h**2 + w**2) everywhere. The
    result is computed by two separable 1-D passes: one along width, then one
    along height.

    Args:
        img: tensor of shape (bs, h, w) or (bs, 1, h, w).
        block: CUDA threads per block for the kernel launch. It is only used
            for CUDA tensors and must be at most 1024.

    Returns:
        Tensor with the same shape as ``img``, cast to ``img.dtype``.
        NOTE: the cast truncates distances for integer or bool inputs.

    CUDA tensors use the ``kernel_dt`` kernel. Tensors on any other device
    use an equivalent pure-torch implementation.
    """
    # bookkeeping: accept (bs, 1, h, w) by dropping the channel axis, restore it at the end
    expand = img.dim() == 4
    if expand:
        assert img.shape[1] == 1, f"expected (bs, 1, h, w), got {tuple(img.shape)}"
        img = img.squeeze(1)
    _, h, w = img.shape
    diam2 = h**2 + w**2
    odtype = img.dtype

    # White (1) pixels cost 0. Black (0) pixels start at the squared diagonal,
    # which is never smaller than any in-image squared distance.
    data = ((1 - img.type(torch.float32)) * diam2).contiguous()

    def edt_pass(x):
        # one separable pass along the last axis, on the matching backend
        if x.is_cuda:
            return _edt_pass_cuda(x, diam2, block)
        return _edt_pass_torch(x, diam2)

    # The first pass minimizes along the width axis. After transposing, the
    # second pass minimizes along the height axis.
    intermed = edt_pass(data).permute(0, 2, 1).contiguous()
    ans = edt_pass(intermed).permute(0, 2, 1).sqrt()
    ans = ans.to(odtype)

    if expand:
        ans = ans.unsqueeze(1)
    return ans
# Public API of this module: only batch_edt is exported by `from ... import *`.
__all__ = ["batch_edt"]