# Image-to-Video
# Uploaded by zzwustc with huggingface_hub (commit ef296aa, verified).
import numpy as np
import torch
import torch.distributed as dist
def reduce(tensor, world_size):
    """Average a tensor over all distributed ranks.

    Non-tensor inputs are passed through untouched. For tensors, a copy is
    summed across ranks with ``dist.all_reduce`` and then divided by
    ``world_size``; the caller's tensor is never modified.

    Args:
        tensor: Value to average; only ``torch.Tensor`` instances are reduced.
        world_size: Divisor applied after the cross-rank sum.

    Returns:
        The averaged copy for tensors, otherwise ``tensor`` itself.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    # all_reduce works in place, so operate on a clone.
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged.div_(world_size)
def expand(mask, num=1):
    """Dilate ``mask`` in place with a 3x3 square neighbourhood, ``num`` times.

    Each iteration ORs every element with its neighbours over the last two
    axes. This is done separably: first along rows (second-to-last axis), then
    along columns (last axis).

    Args:
        mask: Array/tensor whose last two axes are (H, W). It must support
            ``|`` (e.g. a boolean mask) and is modified in place.
        num: Number of dilation iterations.

    Returns:
        The same ``mask`` object, after dilation.
    """
    rows_tail = (Ellipsis, slice(1, None), slice(None))
    rows_head = (Ellipsis, slice(None, -1), slice(None))
    cols_tail = (Ellipsis, slice(None), slice(1, None))
    cols_head = (Ellipsis, slice(None), slice(None, -1))
    for _ in range(num):
        for tail, head in ((rows_tail, rows_head), (cols_tail, cols_head)):
            # Each right-hand side is materialised before assignment, so this
            # pair spreads values by exactly one position in both directions.
            mask[tail] = mask[tail] | mask[head]
            mask[head] = mask[head] | mask[tail]
    return mask
def differentiate(mask):
    """Mark positions whose value differs from a 4-connected neighbour.

    Comparison is over the last two axes (H, W). When two adjacent elements
    disagree, both of them are flagged.

    Args:
        mask: Tensor whose last two axes are (H, W).

    Returns:
        Boolean tensor with the same shape as ``mask``.
    """
    edges = torch.zeros_like(mask, dtype=torch.bool)
    vertical = mask[..., 1:, :] != mask[..., :-1, :]
    horizontal = mask[..., :, 1:] != mask[..., :, :-1]
    # Flag both members of every disagreeing vertical pair...
    edges[..., 1:, :] |= vertical
    edges[..., :-1, :] |= vertical
    # ...and of every disagreeing horizontal pair.
    edges[..., :, 1:] |= horizontal
    edges[..., :, :-1] |= horizontal
    return edges
def sample_points(step, boundaries, num_samples):
    """Sample ``(t, x, y)`` query points, about half of them on a boundary mask.

    For a 2-D ``boundaries`` (H, W), up to ``num_samples // 2`` points are
    drawn from its nonzero entries via ``sample_mask_points``. The rest are
    drawn uniformly over the H x W frame via ``sample_random_points``, so the
    two parts together have ``num_samples`` rows. This assumes
    ``sample_mask_points`` returns at most the requested count.

    A 3-D ``boundaries`` is processed slice by slice along its first axis, and
    the per-slice results are stacked.

    Args:
        step: Value placed in the first (t) column of every point.
        boundaries: 2-D (H, W) or 3-D (K, H, W) mask.
        num_samples: Number of points to produce per (H, W) slice.

    Returns:
        Tensor of shape ``[num_samples, 3]``, or ``[K, num_samples, 3]`` for
        3-D input.
    """
    if boundaries.ndim == 3:
        return torch.stack([sample_points(step, plane, num_samples) for plane in boundaries])

    height, width = boundaries.shape
    on_boundary, _ = sample_mask_points(step, boundaries, num_samples // 2)
    # Top up with uniformly random locations.
    remaining = num_samples - on_boundary.shape[0]
    uniform = sample_random_points(step, height, width, remaining).to(on_boundary.device)
    return torch.cat((on_boundary, uniform), dim=0)
def sample_mask_points(step, mask, num_points):
    """Sample up to ``num_points`` pixel locations where ``mask`` is nonzero.

    Args:
        step: Value written into the first (t) column of every point.
        mask: 2-D tensor (H, W); every nonzero entry is a candidate.
        num_points: Maximum number of points to return. If the mask has this
            many candidates or fewer, all of them are returned in row-major
            order.

    Returns:
        ``(points, (i, j))``. ``points`` is a float tensor of shape ``[n, 3]``
        with rows ``(step, x, y)``, where ``x`` is the column index and ``y``
        is the row index. ``(i, j)`` are the row and column indices of the
        selected pixels.
    """
    i, j = torch.nonzero(mask, as_tuple=True)
    # Count the candidates directly. ``mask.sum()`` equals this count only for
    # 0/1 masks; for other values it over- or under-counts. That makes
    # np.random.choice index past the end of i/j (IndexError) or sample
    # with bias.
    num_nonzero = i.shape[0]
    if num_points < num_nonzero:
        # Uniform subset without replacement, drawn from NumPy's global RNG.
        sample = np.random.choice(num_nonzero, size=num_points, replace=False)
        i, j = i[sample], j[sample]
    t = torch.ones_like(i) * step
    x, y = j, i
    points = torch.stack((t, x, y), dim=-1)  # [num_points, 3]
    return points.float(), (i, j)
def sample_random_points(step, height, width, num_points):
    """Draw ``num_points`` uniformly random integer pixel locations.

    Uses torch's global RNG; the x coordinates are drawn before the y
    coordinates.

    Args:
        step: Value stored in the first (t) column of every point.
        height: Exclusive upper bound for y (row index).
        width: Exclusive upper bound for x (column index).
        num_points: Number of points to draw.

    Returns:
        Float tensor of shape ``[num_points, 3]`` with rows ``(step, x, y)``.
    """
    cols = torch.randint(width, size=[num_points])
    rows = torch.randint(height, size=[num_points])
    times = torch.ones(num_points) * step
    return torch.stack([times, cols, rows], dim=1).float()
def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
    """Build a dense ``(x, y)`` coordinate grid of ``height`` x ``width`` samples.

    Args:
        height: Number of rows H.
        width: Number of columns W.
        shape: Optional leading dimensions (any sequence of ints: list, tuple
            or ``torch.Size``). The grid is expanded over them. ``None`` or an
            empty sequence means no leading dimensions.
        dtype: ``"numpy"`` returns a NumPy array; any other value returns a
            torch tensor.
        device: Device on which the coordinates are created.
        align_corners: If True, the first and last samples lie on 0 and 1
            (normalized) or on 0 and W-1 / H-1 (pixels). If False, samples sit
            at pixel centres ``(i + 0.5) / W``, or ``i + 0.5`` in pixels.
        normalize: If True, coordinates are in [0, 1]; otherwise pixel units.

    Returns:
        Grid of shape ``[*shape, H, W, 2]`` whose last axis is ``(x, y)``.
    """
    H, W = height, width
    # Copy into a list so tuples / torch.Size work too. S is concatenated with
    # list literals, which raised TypeError for non-list sequences.
    S = list(shape) if shape else []
    if align_corners:
        x = torch.linspace(0, 1, W, device=device)
        y = torch.linspace(0, 1, H, device=device)
        if not normalize:
            x = x * (W - 1)
            y = y * (H - 1)
    else:
        x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        if not normalize:
            x = x * W
            y = y * H
    # x varies along columns, y along rows; both are broadcast over S.
    x_view, y_view, exp = [1 for _ in S] + [1, -1], [1 for _ in S] + [-1, 1], S + [H, W]
    x = x.view(*x_view).expand(*exp)
    y = y.view(*y_view).expand(*exp)
    grid = torch.stack([x, y], dim=-1)
    if dtype == "numpy":
        # .cpu() is a no-op for CPU tensors and lets non-CPU grids convert too.
        grid = grid.cpu().numpy()
    return grid
def get_sobel_kernel(kernel_size):
    """Build a pair of ``K x K`` gradient filters shaped ``[2, 1, K, K]``.

    With offsets ``dr = r - K // 2`` and ``dc = c - K // 2``:

    - channel 0 at ``(r, c)`` is ``dr / (dr**2 + dc**2)``;
    - channel 1 at ``(r, c)`` is ``dc / (dr**2 + dc**2)``.

    The centre tap is 0 in both channels.

    NOTE(review): channel 0 (``sobel_x`` in the original naming) varies along
    the first spatial axis (rows). Confirm the x/y convention expected by the
    callers.
    """
    offsets = torch.arange(kernel_size) - kernel_size // 2
    row_off = offsets[:, None]  # (K, 1)
    col_off = offsets[None, :]  # (1, K)
    dist_sq = row_off ** 2 + col_off ** 2
    # Only the centre has zero distance; its numerator is 0, so use 1 to avoid 0/0.
    dist_sq[dist_sq == 0] = 1
    kernel = torch.stack((row_off / dist_sq, col_off / dist_sq))  # (2, K, K)
    return kernel.unsqueeze(1)  # (2, 1, K, K)
def to_device(data, device):
    """Return a new dict whose values have each been moved with ``.to(device)``.

    Args:
        data: Mapping whose values all provide a ``.to`` method (e.g. tensors).
        device: Argument forwarded to every value's ``.to``.

    Returns:
        A fresh ``dict`` with the same keys; ``data`` itself is not modified.
    """
    return {name: value.to(device) for name, value in data.items()}
def get_alpha_consistency(bflow, fflow, thresh_1=0.01, thresh_2=0.5, thresh_mul=1):
    """Forward/backward flow consistency mask.

    For each pixel, ``fflow`` is sampled at the location the pixel is displaced
    to by ``bflow``, using bilinear ``grid_sample`` with ``align_corners=True``.
    The pixel is marked consistent when

        ||bflow + warped_fflow|| < (thresh_1 * (||fflow|| + ||bflow||) + thresh_2) * thresh_mul

    Args:
        bflow: Tensor of shape (B, H, W, C). Channels 0 and 1 are x and y
            displacements. They are divided by W-1 / H-1 to match the
            normalized grid, so they appear to be in pixels.
        fflow: Tensor of shape (B, H, W, C) sampled via the backward flow.
        thresh_1: Relative tolerance, scaled by the summed flow magnitudes.
        thresh_2: Absolute tolerance.
        thresh_mul: Multiplier applied to the whole threshold.

    Returns:
        Float tensor of shape (B, H, W): 1.0 where consistent, 0.0 elsewhere.

    NOTE(review): the magnitude term uses the unwarped ``fflow``, and H == 1 or
    W == 1 divides by zero. Confirm both are intended.
    """

    def _length(vectors):
        # Euclidean norm over the last (channel) axis.
        return vectors.pow(2).sum(dim=-1).sqrt()

    batch, height, width, _ = bflow.shape
    magnitude = _length(fflow) + _length(bflow)

    # Normalized [0, 1] positions shifted by the backward flow, then mapped to
    # grid_sample's [-1, 1] coordinate convention.
    sample_at = get_grid(height, width, shape=[batch], device=fflow.device)
    sample_at[..., 0] = sample_at[..., 0] + bflow[..., 0] / (width - 1)
    sample_at[..., 1] = sample_at[..., 1] + bflow[..., 1] / (height - 1)
    sample_at = sample_at * 2 - 1

    fflow_nchw = fflow.permute(0, 3, 1, 2)
    warped = torch.nn.functional.grid_sample(fflow_nchw, sample_at, mode="bilinear", align_corners=True)
    residual = bflow + warped.permute(0, 2, 3, 1)

    limit = (thresh_1 * magnitude + thresh_2) * thresh_mul
    return (_length(residual) < limit).float()