| import numpy as np
|
| import torch
|
| import torch.distributed as dist
|
|
|
|
|
def reduce(tensor, world_size):
    """Sum a tensor over all distributed processes and divide by ``world_size``.

    Args:
        tensor: Value to reduce. Anything that is not a ``torch.Tensor`` is
            returned untouched.
        world_size: Divisor applied to the summed tensor (presumably the
            number of participating processes -- confirm against callers).

    Returns:
        A new tensor holding the summed-and-divided values, or the original
        object when it is not a tensor. The input tensor is never modified
        because the reduction runs on a clone.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor

    # Work on a copy: all_reduce and div_ both operate in place.
    averaged = tensor.clone()
    dist.all_reduce(averaged, dist.ReduceOp.SUM)
    averaged.div_(world_size)
    return averaged
|
|
|
|
|
def expand(mask, num=1):
    """Grow a mask in place along its last two axes.

    Each iteration ORs every element with its neighbours, first along the
    second-to-last axis (both directions) and then along the last axis (both
    directions). Because the second pass sees the already-grown mask, one
    iteration grows each set element into its full 3x3 neighbourhood.

    Args:
        mask: Array-like with at least two dimensions whose elements support
            ``|`` (e.g. a boolean or integer tensor). It is modified in place.
        num: Number of growth iterations.

    Returns:
        The same ``mask`` object, after modification.
    """
    everything = slice(None)
    drop_first = slice(1, None)
    drop_last = slice(None, -1)

    # (destination, source) pairs, applied strictly in this order.
    passes = (
        ((..., drop_first, everything), (..., drop_last, everything)),
        ((..., drop_last, everything), (..., drop_first, everything)),
        ((..., everything, drop_first), (..., everything, drop_last)),
        ((..., everything, drop_last), (..., everything, drop_first)),
    )

    for _ in range(num):
        for target, neighbour in passes:
            # The right-hand side is materialised before the write, so the
            # overlapping views do not interfere with each other.
            mask[target] = mask[target] | mask[neighbour]
    return mask
|
|
|
|
|
def differentiate(mask):
    """Mark every element that differs from a direct neighbour.

    Neighbouring elements are compared along the last two axes; whenever two
    adjacent elements are unequal, both of them are flagged.

    Args:
        mask: Tensor with at least two dimensions.

    Returns:
        A boolean tensor with the same shape as ``mask`` that is ``True`` on
        both sides of every value change.
    """
    edges = torch.zeros_like(mask).bool()

    # Changes between consecutive entries of the second-to-last / last axis.
    row_change = mask[..., 1:, :] != mask[..., :-1, :]
    col_change = mask[..., :, 1:] != mask[..., :, :-1]

    everything = slice(None)
    drop_first = slice(1, None)
    drop_last = slice(None, -1)

    # Each change is written to both elements of the pair that produced it.
    writes = (
        ((..., drop_first, everything), row_change),
        ((..., drop_last, everything), row_change),
        ((..., everything, drop_first), col_change),
        ((..., everything, drop_last), col_change),
    )
    for region, change in writes:
        edges[region] = edges[region] | change
    return edges
|
|
|
|
|
def sample_points(step, boundaries, num_samples):
    """Sample ``num_samples`` points, favouring boundary pixels.

    Up to half of the requested points are drawn from the nonzero entries of
    ``boundaries`` via ``sample_mask_points``; the remainder is filled with
    uniformly random positions from ``sample_random_points``.

    Args:
        step: Value passed through to the samplers (they store it as the
            first coordinate of every point).
        boundaries: 2-D mask, or a 3-D stack of masks that is processed one
            slice at a time along the first axis.
        num_samples: Total number of points per 2-D mask.

    Returns:
        A tensor of points with ``num_samples`` rows for a 2-D input, or the
        stacked per-slice results for a 3-D input.
    """
    if boundaries.ndim == 3:
        # Handle each mask of the stack independently, then re-stack.
        return torch.stack(
            [sample_points(step, single, num_samples) for single in boundaries]
        )

    height, width = boundaries.shape
    on_boundary, _ = sample_mask_points(step, boundaries, num_samples // 2)

    # Fewer boundary points than requested may exist; top up with random ones.
    num_missing = num_samples - on_boundary.shape[0]
    filler = sample_random_points(step, height, width, num_missing)
    filler = filler.to(on_boundary.device)
    return torch.cat((on_boundary, filler), dim=0)
|
|
|
|
|
def sample_mask_points(step, mask, num_points):
    """Sample up to ``num_points`` positions from the nonzero entries of a mask.

    Args:
        step: Value stored as the first coordinate of every returned point.
        mask: 2-D tensor; every nonzero element is a candidate position.
        num_points: Maximum number of points to return. When the mask has
            more candidates, a random subset is drawn without replacement
            using NumPy's global random state.

    Returns:
        A tuple ``(points, (i, j))`` where ``points`` is a float tensor of
        shape ``(N, 3)`` laid out as ``(step, column, row)`` and ``i`` / ``j``
        are the row / column indices of the selected positions.
    """
    rows, cols = torch.nonzero(mask, as_tuple=True)

    # Count the candidates from the index list itself. Summing the mask is
    # only correct when every nonzero value equals 1; otherwise the random
    # choice below would produce out-of-range indices.
    num_nonzero = rows.shape[0]

    if num_points < num_nonzero:
        keep = np.random.choice(num_nonzero, size=num_points, replace=False)
        rows, cols = rows[keep], cols[keep]

    t = torch.ones_like(rows) * step
    # Points are stored as (t, x, y) with x = column index, y = row index.
    points = torch.stack((t, cols, rows), dim=-1)
    return points.float(), (rows, cols)
|
|
|
|
|
def sample_random_points(step, height, width, num_points):
    """Draw uniformly random integer positions inside a ``height`` x ``width`` area.

    Args:
        step: Value stored as the first coordinate of every point.
        height: Exclusive upper bound for the sampled y coordinate.
        width: Exclusive upper bound for the sampled x coordinate.
        num_points: Number of points to draw.

    Returns:
        A float tensor of shape ``(num_points, 3)`` laid out as
        ``(step, x, y)``, created on the default (CPU) device.
    """
    # x is drawn before y; keep this order so seeded runs stay reproducible.
    xs = torch.randint(width, size=[num_points])
    ys = torch.randint(height, size=[num_points])
    ts = torch.ones(num_points) * step

    stacked = torch.stack((ts, xs, ys), dim=-1)
    return stacked.float()
|
|
|
|
|
def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
    """Build a coordinate grid whose last axis holds ``(x, y)``.

    Args:
        height: Number of rows ``H`` of the grid.
        width: Number of columns ``W`` of the grid.
        shape: Optional leading dimensions (list, tuple or ``torch.Size``)
            that the grid is broadcast over, e.g. ``[B]``.
        dtype: ``"numpy"`` returns a NumPy array; any other value returns a
            torch tensor.
        device: Device on which the coordinates are created.
        align_corners: If true, coordinates run from 0 to 1 inclusive (first
            and last samples sit on the borders); otherwise they sit at cell
            centres, from ``0.5 / N`` to ``1 - 0.5 / N``.
        normalize: If false, the coordinates are scaled from the unit range
            to index units (``0 .. N-1`` with ``align_corners``, else
            ``0.5 .. N-0.5``).

    Returns:
        A grid of shape ``(*shape, H, W, 2)`` where ``[..., 0]`` varies along
        the width axis (x) and ``[..., 1]`` along the height axis (y).
    """
    H, W = height, width
    # Normalise to a list so tuples and torch.Size can be concatenated below.
    lead = list(shape) if shape else []

    if align_corners:
        x = torch.linspace(0, 1, W, device=device)
        y = torch.linspace(0, 1, H, device=device)
        x_scale, y_scale = W - 1, H - 1
    else:
        x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        x_scale, y_scale = W, H

    if not normalize:
        x = x * x_scale
        y = y * y_scale

    # x varies along the last axis, y along the second-to-last; both are
    # broadcast over the leading dimensions.
    ones = [1] * len(lead)
    full_shape = lead + [H, W]
    x = x.view(*ones, 1, -1).expand(*full_shape)
    y = y.view(*ones, -1, 1).expand(*full_shape)
    grid = torch.stack([x, y], dim=-1)

    if dtype == "numpy":
        # Tensor.numpy() only works on CPU tensors; move first so that a
        # non-CPU `device` does not raise.
        grid = grid.cpu().numpy()
    return grid
|
|
|
|
|
def get_sobel_kernel(kernel_size):
    """Build a pair of Sobel-like gradient kernels of a given size.

    Each kernel entry is the signed offset from the kernel centre along one
    axis, divided by the squared distance to the centre (the centre itself
    is left at zero).

    Args:
        kernel_size: Side length ``K`` of the square kernels.

    Returns:
        A float tensor of shape ``(2, 1, K, K)``. Index 0 holds the kernel
        built from offsets along the first spatial axis, index 1 the kernel
        built from offsets along the second.
    """
    offsets = torch.arange(kernel_size) - kernel_size // 2
    along_rows = offsets.view(-1, 1)
    along_cols = offsets.view(1, -1)

    sq_dist = along_rows ** 2 + along_cols ** 2
    # Avoid dividing by zero at the centre; its numerator is zero anyway.
    sq_dist[sq_dist == 0] = 1

    kernels = [along_rows / sq_dist, along_cols / sq_dist]
    return torch.stack(kernels, dim=0).unsqueeze(1)
|
|
|
|
|
def to_device(data, device):
    """Return a new dict with every value moved to ``device``.

    Args:
        data: Mapping whose values all provide a ``.to(device)`` method
            (e.g. tensors).
        device: Target passed to each value's ``.to``.

    Returns:
        A new dict with the same keys, in the same order; the input mapping
        itself is left unchanged.
    """
    moved = {}
    for key, value in data.items():
        moved[key] = value.to(device)
    return moved
|
|
|
|
|
def get_alpha_consistency(bflow, fflow, thresh_1=0.01, thresh_2=0.5, thresh_mul=1):
    """Compute a forward/backward flow consistency mask.

    The forward flow is sampled at the positions the backward flow points
    to. Where the two flows agree, their sum is close to zero; a pixel is
    accepted when the magnitude of that sum stays below a threshold that
    grows with the flow magnitudes.

    Args:
        bflow: Backward flow, unpacked as ``(B, H, W, C)``; channel 0 is
            applied to x and channel 1 to y (presumably in pixel units --
            confirm against callers).
        fflow: Forward flow in the same layout.
        thresh_1: Factor applied to the summed flow magnitudes.
        thresh_2: Constant offset of the threshold.
        thresh_mul: Final multiplier applied to the whole threshold.

    Returns:
        A float tensor of shape ``(B, H, W)`` that is 1.0 where the flows
        are consistent and 0.0 elsewhere.
    """

    def magnitude(vectors):
        # Euclidean length over the last (channel) axis.
        return vectors.pow(2).sum(dim=-1).sqrt()

    B, H, W, _ = bflow.shape

    total_mag = magnitude(fflow) + magnitude(bflow)

    # Unit-range sampling positions displaced by the backward flow, then
    # mapped to the [-1, 1] range that grid_sample expects.
    base = get_grid(H, W, shape=[B], device=fflow.device)
    target_x = base[..., 0] + bflow[..., 0] / (W - 1)
    target_y = base[..., 1] + bflow[..., 1] / (H - 1)
    sample_grid = torch.stack([target_x, target_y], dim=-1) * 2 - 1

    # grid_sample wants channels first; convert back to channels last after.
    warped = torch.nn.functional.grid_sample(
        fflow.permute(0, 3, 1, 2), sample_grid, mode="bilinear", align_corners=True
    )
    residual = bflow + warped.permute(0, 2, 3, 1)

    limit = thresh_1 * total_mag + thresh_2
    limit = limit * thresh_mul

    consistent = magnitude(residual) < limit
    return consistent.float()