File size: 4,815 Bytes
ef296aa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import numpy as np
import torch
import torch.distributed as dist
def reduce(tensor, world_size):
    """Average ``tensor`` over every rank of the default process group.

    The tensor is cloned, summed across ranks with ``dist.all_reduce`` and
    divided by ``world_size``; the caller's tensor is left untouched.
    Non-tensor inputs are returned unchanged without any communication.

    NOTE(review): ``div_`` is an in-place true division, so integer tensors
    will presumably fail here — confirm callers only pass floating tensors.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    averaged = tensor.clone()
    dist.all_reduce(averaged, dist.ReduceOp.SUM)
    averaged.div_(world_size)
    return averaged
def expand(mask, num=1):
    """Grow the set (nonzero) region of ``mask`` by ``num`` pixels, in place.

    Each iteration ORs every pixel with its vertical neighbours and then, on
    that result, with its horizontal neighbours, so one iteration is a
    dilation by a 3x3 square. The last two dimensions are (H, W); leading
    dimensions are carried along untouched.

    The argument itself is modified and also returned.
    """
    # (destination, source) index pairs, in the order the passes run.
    down = (Ellipsis, slice(1, None), slice(None))
    up = (Ellipsis, slice(None, -1), slice(None))
    right = (Ellipsis, slice(None), slice(1, None))
    left = (Ellipsis, slice(None), slice(None, -1))
    passes = ((down, up), (up, down), (right, left), (left, right))
    for _ in range(num):
        for dst, src in passes:
            # The OR builds a new array before the slice assignment, so the
            # overlapping source/destination regions do not interfere.
            mask[dst] = mask[dst] | mask[src]
    return mask
def differentiate(mask):
    """Flag pixels that sit on either side of a value change in ``mask``.

    A pixel is True when it differs from its upper, lower, left or right
    neighbour (4-connectivity). The last two dimensions are (H, W); each
    leading slice is handled independently.

    Returns a new boolean tensor with the same shape as ``mask``.
    """
    edges = torch.zeros_like(mask).bool()
    changes_y = mask[..., :-1, :] != mask[..., 1:, :]
    changes_x = mask[..., :, :-1] != mask[..., :, 1:]
    # A change between two neighbours marks both of them.
    targets = (
        ((Ellipsis, slice(1, None), slice(None)), changes_y),
        ((Ellipsis, slice(None, -1), slice(None)), changes_y),
        ((Ellipsis, slice(None), slice(1, None)), changes_x),
        ((Ellipsis, slice(None), slice(None, -1)), changes_x),
    )
    for region, change in targets:
        edges[region] = edges[region] | change
    return edges
def sample_points(step, boundaries, num_samples):
    """Sample query points from boundary masks, topped up with random pixels.

    Args:
        step: value written into the first (``t``) column of every point.
        boundaries: mask of shape ``(H, W)`` or ``(..., H, W)``. Any number of
            leading batch dimensions is accepted; each ``(H, W)`` slice is
            sampled independently.
        num_samples: number of points produced per ``(H, W)`` mask.

    Returns:
        Float tensor of shape ``(..., num_samples, 3)`` with rows ``(t, x, y)``.
        Up to ``num_samples // 2`` rows come from the nonzero pixels of the
        mask. The remaining rows are uniformly random pixel coordinates, moved
        to the boundary points' device.

    Raises:
        ValueError: if ``boundaries`` has fewer than two dimensions.
    """
    if boundaries.ndim < 2:
        raise ValueError(
            f"boundaries must have at least 2 dims (H, W), got shape {tuple(boundaries.shape)}"
        )
    if boundaries.ndim > 2:
        # Recurse over the leading dimension. Every slice yields exactly
        # num_samples rows (boundary count + complementary random count), so
        # the per-slice results always stack.
        return torch.stack([sample_points(step, b, num_samples) for b in boundaries])

    H, W = boundaries.shape
    boundary_points, _ = sample_mask_points(step, boundaries, num_samples // 2)
    num_random_points = num_samples - boundary_points.shape[0]
    random_points = sample_random_points(step, H, W, num_random_points)
    random_points = random_points.to(boundary_points.device)
    return torch.cat((boundary_points, random_points), dim=0)
def sample_mask_points(step, mask, num_points):
    """Pick up to ``num_points`` distinct nonzero pixels of a 2-D mask.

    Args:
        step: value stored in the ``t`` column of every point.
        mask: 2-D tensor ``(H, W)``; every nonzero entry is a candidate.
        num_points: maximum number of points. If the mask has at most this
            many nonzero pixels, all of them are returned in ``torch.nonzero``
            order. Otherwise a subset is drawn without replacement via
            ``np.random.choice``.

    Returns:
        ``(points, (i, j))``. ``points`` is a float tensor ``(N, 3)`` with rows
        ``(t, x, y)`` where ``x = j`` (column) and ``y = i`` (row). ``i`` and
        ``j`` are the row/column index tensors of the selected pixels.
    """
    i, j = torch.nonzero(mask, as_tuple=True)
    # Count candidates from nonzero() itself rather than mask.sum(). For
    # non-binary masks (e.g. values other than 0/1) the sum differs from the
    # number of nonzero pixels, and np.random.choice would then draw indices
    # past the end of i/j.
    num_nonzero = i.shape[0]
    if num_points < num_nonzero:
        sample = np.random.choice(num_nonzero, size=num_points, replace=False)
        i, j = i[sample], j[sample]
    t = torch.ones_like(i) * step
    x, y = j, i
    points = torch.stack((t, x, y), dim=-1)  # [num_points, 3]
    return points.float(), (i, j)
def sample_random_points(step, height, width, num_points):
    """Draw ``num_points`` uniformly random integer pixel coordinates.

    Columns are drawn before rows (this order fixes the RNG stream). The
    result is a CPU float tensor of shape ``(num_points, 3)`` with rows
    ``(t, x, y)``, where ``t`` equals ``step``, ``0 <= x < width`` and
    ``0 <= y < height``.
    """
    cols = torch.randint(width, size=[num_points])
    rows = torch.randint(height, size=[num_points])
    times = torch.ones(num_points) * step
    return torch.stack((times, cols.float(), rows.float()), dim=-1).float()
def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
    """Build a dense grid of ``(x, y)`` pixel coordinates.

    Args:
        height: number of rows H.
        width: number of columns W.
        shape: optional leading batch shape (list, tuple or ``torch.Size``).
            The grid is broadcast over these dimensions with ``expand``.
        dtype: ``"numpy"`` returns a NumPy array (copied to host memory first).
            Any other value returns a torch tensor.
        device: device on which the coordinates are created.
        align_corners: if True, the first and last pixel map to 0 and 1, or to
            0 and W-1 / H-1 when unnormalized. If False, pixel i maps to its
            center ``(i + 0.5) / W``, or ``i + 0.5`` when unnormalized.
        normalize: if True, coordinates lie in [0, 1]; otherwise in pixels.

    Returns:
        Tensor (or array) of shape ``(*shape, H, W, 2)`` whose last axis holds
        ``(x, y)``.
    """
    H, W = height, width
    # list() so tuples and torch.Size work too; `tuple + list` would raise.
    batch = list(shape) if shape else []
    if align_corners:
        x = torch.linspace(0, 1, W, device=device)
        y = torch.linspace(0, 1, H, device=device)
        if not normalize:
            x = x * (W - 1)
            y = y * (H - 1)
    else:
        x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        if not normalize:
            x = x * W
            y = y * H
    ones = [1] * len(batch)
    full_shape = batch + [H, W]
    x = x.view(*ones, 1, -1).expand(*full_shape)
    y = y.view(*ones, -1, 1).expand(*full_shape)
    grid = torch.stack([x, y], dim=-1)
    if dtype == "numpy":
        # Tensor.numpy() only accepts CPU tensors.
        grid = grid.cpu().numpy()
    return grid
def get_sobel_kernel(kernel_size):
    """Build two K x K gradient kernels stacked as a ``(2, 1, K, K)`` float tensor.

    With offsets ``d = k - K // 2`` for ``k`` in ``range(K)``, each entry is
    the offset along one axis divided by ``dy**2 + dx**2``. The center's zero
    denominator is replaced by 1. Channel 0 varies along rows (dim -2) and
    channel 1 along columns (dim -1). The kernels are antisymmetric only for
    odd K.

    The shape matches a conv weight with 2 output channels and 1 input
    channel; presumably it is fed to ``F.conv2d`` (confirm at call sites).

    NOTE(review): the original named channel 0 ``sobel_x``, yet it varies
    along rows (the y axis of an (H, W) image). Confirm which axis callers
    expect.
    """
    offsets = torch.arange(kernel_size) - kernel_size // 2
    along_rows = offsets.view(-1, 1)
    along_cols = offsets.view(1, -1)
    denom = along_rows ** 2 + along_cols ** 2
    denom = torch.where(denom == 0, torch.ones_like(denom), denom)
    kernels = torch.stack((along_rows / denom, along_cols / denom), dim=0)
    return kernels.unsqueeze(1)
def to_device(data, device):
    """Return a new dict with every movable value transferred to ``device``.

    Values that expose a ``.to`` method (tensors, modules, ...) are moved with
    ``v.to(device)``. Any other value (strings, numbers, other metadata) is
    passed through unchanged rather than raising ``AttributeError``. The input
    dict is not modified.
    """
    return {k: (v.to(device) if hasattr(v, "to") else v) for k, v in data.items()}
def get_alpha_consistency(bflow, fflow, thresh_1=0.01, thresh_2=0.5, thresh_mul=1):
    """Forward-backward flow consistency mask.

    For each pixel p, ``bflow`` is followed to q = p + bflow(p), and ``fflow``
    is bilinearly sampled at q. Pixel p is marked consistent when
    ``|bflow(p) + fflow(q)| < (thresh_1 * (|fflow(p)| + |bflow(p)|) + thresh_2) * thresh_mul``.
    Samples falling outside the frame read as zero flow (``grid_sample``'s
    default zero padding).

    Args:
        bflow: flow of shape ``(B, H, W, C)``. Assumes C == 2 holding
            ``(dx, dy)`` in pixels — TODO confirm against callers.
        fflow: flow with the same layout as ``bflow``.
        thresh_1: relative tolerance, scaled by the summed flow magnitudes.
        thresh_2: absolute tolerance.
        thresh_mul: overall multiplier applied to the threshold.

    Returns:
        Float tensor of shape ``(B, H, W)`` holding 1.0 where consistent and
        0.0 elsewhere.

    NOTE(review): the magnitude term uses ``fflow`` at p, not the warped
    ``fflow(q)``. Confirm this is intended.
    """
    def norm(v):
        return v.pow(2).sum(dim=-1).sqrt()

    B, H, W, _ = bflow.shape
    mag = norm(fflow) + norm(bflow)
    grid = get_grid(H, W, shape=[B], device=fflow.device)
    # Convert the pixel displacement into the [0, 1] units of the
    # align_corners grid. max(..., 1) avoids a division by zero for a single
    # row or column, where align_corners=True maps every coordinate to that
    # one pixel anyway.
    offset = torch.stack(
        (bflow[..., 0] / max(W - 1, 1), bflow[..., 1] / max(H - 1, 1)), dim=-1
    )
    grid = (grid + offset) * 2 - 1
    # get_grid always returns float32, but grid_sample requires the grid to
    # share the input's dtype (e.g. for float64 or half flows).
    grid = grid.to(fflow.dtype)
    fflow_warped = torch.nn.functional.grid_sample(
        fflow.permute(0, 3, 1, 2), grid, mode="bilinear", align_corners=True
    )
    flow_diff = bflow + fflow_warped.permute(0, 2, 3, 1)
    occ_thresh = (thresh_1 * mag + thresh_2) * thresh_mul
    return (norm(flow_diff) < occ_thresh).float()