|
|
import numpy as np
|
|
|
import torch
|
|
|
import torch.distributed as dist
|
|
|
|
|
|
|
|
|
def reduce(tensor, world_size):
    """Average a tensor across all distributed processes.

    Sums the tensor over every rank with ``all_reduce`` and divides by
    ``world_size``. The reduction runs on a clone, so the caller's tensor
    is left untouched. Anything that is not a ``torch.Tensor`` is returned
    unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    averaged = tensor.clone()
    dist.all_reduce(averaged, dist.ReduceOp.SUM)
    averaged.div_(world_size)
    return averaged
|
|
|
|
|
|
|
|
|
def expand(mask, num=1):
    """Dilate a boolean mask in-place by ``num`` pixels.

    Each iteration ORs every pixel with its four axis-aligned neighbours
    along the last two dimensions. Because the vertical and horizontal
    passes are applied sequentially, one iteration grows True regions by
    a Chebyshev (square) radius of one. Returns the mutated mask.
    """
    for _ in range(num):
        top, bottom = mask[..., :-1, :], mask[..., 1:, :]
        # Vertical pass: RHS is materialised before assignment, so each
        # statement only propagates by a single pixel.
        mask[..., 1:, :] = bottom | top
        mask[..., :-1, :] = top | bottom
        left, right = mask[..., :, :-1], mask[..., :, 1:]
        # Horizontal pass.
        mask[..., :, 1:] = right | left
        mask[..., :, :-1] = left | right
    return mask
|
|
|
|
|
|
|
|
|
def differentiate(mask):
    """Mark pixels adjacent to a value change in ``mask``.

    Compares each pixel with its vertical and horizontal neighbour along
    the last two dimensions and flags *both* sides of every differing
    pair, yielding a two-pixel-wide boundary map. The input mask is not
    modified; a new boolean tensor of the same shape is returned.
    """
    edges = torch.zeros_like(mask).bool()
    vert = mask[..., 1:, :] != mask[..., :-1, :]
    horiz = mask[..., :, 1:] != mask[..., :, :-1]
    # vert/horiz are fresh tensors, so in-place ORs on edges are safe.
    edges[..., 1:, :] |= vert
    edges[..., :-1, :] |= vert
    edges[..., :, 1:] |= horiz
    edges[..., :, :-1] |= horiz
    return edges
|
|
|
|
|
|
|
|
|
def sample_points(step, boundaries, num_samples):
    """Sample ``num_samples`` (t, x, y) points from a boundary mask.

    For a batched 3-D mask the sampling recurses per item and the
    results are stacked. Otherwise up to ``num_samples // 2`` points are
    drawn from nonzero (boundary) pixels and the remaining budget is
    filled with uniform random points over the H x W grid.
    """
    if boundaries.ndim != 3:
        H, W = boundaries.shape
        # Boundary half of the budget; may come back short if the mask
        # has fewer nonzero pixels than requested.
        on_boundary, _ = sample_mask_points(step, boundaries, num_samples // 2)
        n_fill = num_samples - on_boundary.shape[0]
        filler = sample_random_points(step, H, W, n_fill).to(on_boundary.device)
        return torch.cat((on_boundary, filler), dim=0)
    per_item = [sample_points(step, item, num_samples) for item in boundaries]
    return torch.stack(per_item)
|
|
|
|
|
|
|
|
|
def sample_mask_points(step, mask, num_points):
    """Pick up to ``num_points`` nonzero pixels of ``mask`` as points.

    When the mask has more nonzero pixels than requested, a uniform
    subset is drawn without replacement via numpy's global RNG;
    otherwise every nonzero pixel is used. Returns the float (t, x, y)
    points plus the (row, col) index tensors they were built from.
    """
    rows, cols = torch.nonzero(mask, as_tuple=True)
    total = int(mask.sum())
    if num_points < total:
        keep = np.random.choice(total, size=num_points, replace=False)
        rows, cols = rows[keep], cols[keep]
    t = torch.ones_like(rows) * step
    # Points are (t, x, y): x is the column index, y the row index.
    pts = torch.stack((t, cols, rows), dim=-1)
    return pts.float(), (rows, cols)
|
|
|
|
|
|
|
|
|
def sample_random_points(step, height, width, num_points):
    """Draw ``num_points`` uniform random (t, x, y) points.

    x is uniform over [0, width), y over [0, height), and t is the
    constant ``step`` for every point. Returns a float tensor of shape
    (num_points, 3).
    """
    xs = torch.randint(width, size=[num_points])
    ys = torch.randint(height, size=[num_points])
    ts = torch.ones(num_points) * step
    return torch.stack((ts, xs, ys), dim=-1).float()
|
|
|
|
|
|
|
|
|
def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
    """Build an (x, y) coordinate grid of shape [*shape, H, W, 2].

    With ``align_corners=True`` the coordinates span [0, 1] inclusive;
    otherwise they sit at pixel centres 0.5/W .. 1 - 0.5/W (and likewise
    for height). With ``normalize=False`` the values are rescaled to
    pixel units instead of [0, 1]. ``shape`` prepends broadcast (e.g.
    batch) dimensions via ``expand``. ``dtype="numpy"`` converts the
    result to an ndarray.
    """
    H, W = height, width
    lead = list(shape) if shape else []
    if align_corners:
        xs = torch.linspace(0, 1, W, device=device)
        ys = torch.linspace(0, 1, H, device=device)
        if not normalize:
            xs, ys = xs * (W - 1), ys * (H - 1)
    else:
        xs = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        ys = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        if not normalize:
            xs, ys = xs * W, ys * H
    ones = [1] * len(lead)
    # Broadcast the 1-D ramps across rows/columns and any leading dims.
    xs = xs.view(*ones, 1, -1).expand(*lead, H, W)
    ys = ys.view(*ones, -1, 1).expand(*lead, H, W)
    grid = torch.stack([xs, ys], dim=-1)
    return grid.numpy() if dtype == "numpy" else grid
|
|
|
|
|
|
|
|
|
def get_sobel_kernel(kernel_size):
    """Build a (2, 1, K, K) Sobel-style derivative kernel pair.

    Channel 0 varies along the first (row) axis and channel 1 along the
    second (column) axis; each tap is the signed offset divided by its
    squared Euclidean distance from the kernel centre.
    """
    K = kernel_size
    offsets = torch.tensor(list(range(K))) - K // 2
    col = offsets.view(-1, 1)
    row = offsets.view(1, -1)
    dist2 = col ** 2 + row ** 2
    # The centre tap has zero distance; clamp to 1 to avoid div-by-zero
    # (its numerator is zero anyway).
    dist2[dist2 == 0] = 1
    gx = (col / dist2).unsqueeze(0)
    gy = (row / dist2).unsqueeze(0)
    return torch.stack([gx, gy], dim=0)
|
|
|
|
|
|
|
|
|
def to_device(data, device):
    """Return a new dict with every value of ``data`` moved to ``device``.

    Assumes every value supports ``.to(device)`` (tensors or modules);
    keys are preserved as-is.
    """
    return {key: value.to(device) for key, value in data.items()}
|
|
|
|
|
|
|
|
|
def get_alpha_consistency(bflow, fflow, thresh_1=0.01, thresh_2=0.5, thresh_mul=1):
    """Forward-backward flow consistency mask.

    Warps ``fflow`` by the displacement in ``bflow`` (both (B, H, W, 2),
    apparently in pixel units — the /(W-1), /(H-1) normalization suggests
    so) and flags pixels where the round trip ``bflow + warped fflow``
    nearly cancels. Returns a float (B, H, W) mask: 1 where the flows are
    consistent, 0 otherwise (commonly read as a visibility/occlusion
    estimate).

    The per-pixel tolerance scales with flow magnitude:
    ``thresh_mul * (thresh_1 * (|fflow| + |bflow|) + thresh_2)``.
    """
    magnitude = lambda t: t.pow(2).sum(dim=-1).sqrt()
    B, H, W, _ = bflow.shape

    flow_mag = magnitude(fflow) + magnitude(bflow)

    # Sampling grid in [0, 1], displaced by the backward flow, then
    # rescaled to the [-1, 1] range grid_sample expects.
    grid = get_grid(H, W, shape=[B], device=fflow.device)
    grid[..., 0] += bflow[..., 0] / (W - 1)
    grid[..., 1] += bflow[..., 1] / (H - 1)
    grid = grid * 2 - 1

    warped = torch.nn.functional.grid_sample(
        fflow.permute(0, 3, 1, 2), grid, mode="bilinear", align_corners=True
    )
    round_trip = bflow + warped.permute(0, 2, 3, 1)

    tolerance = (thresh_1 * flow_mag + thresh_2) * thresh_mul
    return (magnitude(round_trip) < tolerance).float()