| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
def mesh_grid(B, H, W):
    """Return integer pixel coordinates for a batch of H x W images.

    :param B: batch size
    :param H: image height
    :param W: image width
    :return: int64 tensor of shape Bx2xHxW; channel 0 holds the column (x)
             index and channel 1 holds the row (y) index of every pixel.
    """
    # Broadcast 1-D ranges to the full (B, H, W) extent; torch.stack below
    # materializes them into a fresh contiguous tensor.
    cols = torch.arange(0, W).view(1, 1, W).expand(B, H, W)
    rows = torch.arange(0, H).view(1, H, 1).expand(B, H, W)
    return torch.stack((cols, rows), dim=1)
| |
|
| |
|
def norm_grid(v_grid):
    """Map pixel coordinates to the [-1, 1] range and move channels last.

    Pixel 0 maps to -1 and pixel ``size - 1`` maps to +1 along each axis. This
    is the ``align_corners=True`` convention of ``grid_sample``.

    :param v_grid: Bx2xHxW tensor; channel 0 is x (column), channel 1 is y (row),
                   both in pixel units.
    :return: BxHxWx2 tensor of normalized coordinates, as ``grid_sample`` expects.
    """
    _, _, H, W = v_grid.size()

    # For a width/height of 1, (size - 1) is 0 and the division would yield
    # NaN/inf coordinates; fall back to a denominator of 1 in that case.
    x_den = max(W - 1, 1)
    y_den = max(H - 1, 1)

    # zeros_like keeps the input's dtype and device.
    v_grid_norm = torch.zeros_like(v_grid)
    v_grid_norm[:, 0, :, :] = 2.0 * v_grid[:, 0, :, :] / x_den - 1.0
    v_grid_norm[:, 1, :, :] = 2.0 * v_grid[:, 1, :, :] / y_den - 1.0
    return v_grid_norm.permute(0, 2, 3, 1)
| |
|
| |
|
def get_corresponding_map(data):
    """Bilinearly splat unit mass from every source pixel onto its target location.

    Source pixel ``(i, j)`` is sent to ``(data[:, 0, i, j], data[:, 1, i, j])``
    (x then y, in pixel units). Its unit weight is spread with bilinear weights
    over the four surrounding integer pixels. Neighbours that lie outside the
    H x W image get zero weight. The result is the accumulated weight landing
    on each pixel.

    :param data: unnormalized coordinates Bx2xHxW (channel 0 = x, channel 1 = y)
    :return: Bx1xHxW map of accumulated bilinear weights
    """
    B, _, H, W = data.size()

    # reshape rather than view: for a non-contiguous input the per-channel
    # slice cannot always be flattened as a view (view would raise).
    x = data[:, 0, :, :].reshape(B, -1)
    y = data[:, 1, :, :].reshape(B, -1)

    # Integer neighbours of each target point. x1/y1 are the floors and
    # x0/y0 = floor + 1 (the "ceil" side). The clamped copies are used for
    # indexing; the unclamped ones detect neighbours that fall off the image.
    x1 = torch.floor(x)
    x_floor = x1.clamp(0, W - 1)
    y1 = torch.floor(y)
    y_floor = y1.clamp(0, H - 1)
    x0 = x1 + 1
    x_ceil = x0.clamp(0, W - 1)
    y0 = y1 + 1
    y_ceil = y0.clamp(0, H - 1)

    # A neighbour is invalid when clamping moved it, i.e. it is outside the image.
    # Without this, a clamped neighbour would receive a spurious weight.
    x_ceil_out = x0 != x_ceil
    y_ceil_out = y0 != y_ceil
    x_floor_out = x1 != x_floor
    y_floor_out = y1 != y_floor
    invalid = torch.cat([x_ceil_out | y_ceil_out,
                         x_ceil_out | y_floor_out,
                         x_floor_out | y_ceil_out,
                         x_floor_out | y_floor_out], dim=1)

    # Same dtype/device as the input, without a CPU allocation + copy.
    corresponding_map = data.new_zeros(B, H * W)
    # Flat (row-major) indices of the four neighbours, in the same order as
    # `invalid` and `values`.
    indices = torch.cat([x_ceil + y_ceil * W,
                         x_ceil + y_floor * W,
                         x_floor + y_ceil * W,
                         x_floor + y_floor * W], 1).long()
    # Bilinear weights: (1 - |dx|) * (1 - |dy|) for each neighbour.
    values = torch.cat([(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_floor)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_floor))],
                       1)

    values[invalid] = 0

    # Accumulate weights; several sources may land on the same pixel.
    corresponding_map.scatter_add_(1, indices, values)
    corresponding_map = corresponding_map.view(B, H, W)

    return corresponding_map.unsqueeze(1)
| |
|
| |
|
def flow_warp(x, flow12, pad='border', mode='bilinear', align_corners=True):
    """Backward-warp ``x`` by sampling it at ``pixel + flow12``.

    :param x: BxCxHxW tensor to sample from.
    :param flow12: Bx2xHxW flow in pixel units (channel 0 = x, channel 1 = y);
                   output pixel p is read from ``x`` at ``p + flow12[p]``.
    :param pad: ``grid_sample`` padding mode for out-of-image samples.
    :param mode: ``grid_sample`` interpolation mode.
    :param align_corners: forwarded to ``grid_sample``. Defaults to True because
                          ``norm_grid`` maps pixel 0 / pixel size-1 to -1 / +1,
                          which is the ``align_corners=True`` convention.
    :return: BxCxHxW warped tensor.
    """
    B, _, H, W = x.size()

    base_grid = mesh_grid(B, H, W).type_as(x)

    v_grid = norm_grid(base_grid + flow12)
    # grid_sample's default became align_corners=False in PyTorch 1.3. Relying
    # on the default with a (size - 1)-normalized grid would shift and scale
    # every sample, so zero flow would not reproduce x.
    im1_recons = nn.functional.grid_sample(x, v_grid, mode=mode, padding_mode=pad,
                                           align_corners=align_corners)
    return im1_recons
| |
|
| |
|
def get_occu_mask_bidirection(flow12, flow21, mask, scale=1, bias=0.5):
    """Forward-backward consistency check between two flow fields.

    ``flow21`` is backward-warped with ``flow12`` (zero padding outside the
    image). A pixel is flagged when the squared round-trip residual
    ``|flow12 + warped_flow21|^2`` exceeds
    ``scale * (|flow12|^2 + |warped_flow21|^2) + bias``.

    :param flow12: flow field, presumably Bx2xHxW (used as the warp flow).
    :param flow21: flow field to warp; must be addable to ``flow12``.
    :param mask: accepted for interface compatibility; NOT used here.
    :param scale: multiplier on the summed squared flow magnitudes.
    :param bias: constant added to the threshold.
    :return: Bx1xHxW boolean tensor, True where the consistency test fails.
    """
    def _sq_norm(t):
        # Squared L2 norm over the channel dimension, kept as a size-1 dim.
        return (t * t).sum(1, keepdim=True)

    back = flow_warp(flow21, flow12, pad='zeros')
    residual = _sq_norm(flow12 + back)
    threshold = scale * (_sq_norm(flow12) + _sq_norm(back)) + bias
    return residual > threshold
| |
|
| |
|
def get_occu_mask_backward(flow21, th=0.2):
    """Flag pixels that receive too little splatted mass when pixels move along ``flow21``.

    Every pixel coordinate is displaced by ``flow21``. ``get_corresponding_map``
    then bilinearly accumulates how much weight lands on each pixel. That
    count is clamped to [0, 1], and pixels below ``th`` are flagged.

    :param flow21: Bx2xHxW flow in pixel units (channel 0 = x, channel 1 = y).
    :param th: threshold on the clamped landing weight.
    :return: Bx1xHxW float tensor, 1.0 where flagged and 0.0 elsewhere.
    """
    batch, _, height, width = flow21.size()
    pixel_coords = mesh_grid(batch, height, width).type_as(flow21)
    landing = get_corresponding_map(pixel_coords + flow21)
    coverage = landing.clamp(min=0., max=1.)
    return (coverage < th).float()
| |
|
def get_ssv_weights(cycle_corres, input, mask, scale_value):
    """Turn resampling differences into per-pixel weights ``exp(-diff)``.

    ``cycle_corres * scale_value - 1`` is used directly as a ``grid_sample``
    grid (align_corners=True, border padding). This presumably assumes
    ``cycle_corres * scale_value`` lies in [0, 2]; confirm with the caller.

    The first 3 channels of ``input`` are differenced after dividing by 255
    (presumably 0-255 color). The remaining channels are differenced as-is
    (named depth here). Squared differences are summed over channels,
    multiplied by ``mask``, and mapped through ``exp(-x)``.

    :param cycle_corres: Bx2xHxW correspondence field (channel-first).
    :param input: BxCxHxW tensor, C >= 3.
    :param mask: tensor broadcastable to Bx1xHxW; converted to float.
    :param scale_value: multiplier applied to ``cycle_corres``.
    :return: weights in (0, 1]; exactly 1 where the masked difference is 0.
    """
    sample_grid = cycle_corres.mul(scale_value).sub(1.0).permute(0, 2, 3, 1)
    resampled = F.grid_sample(input, sample_grid, align_corners=True,
                              padding_mode='border')

    delta = input - resampled
    color_term = ((delta[:, :3] / 255.0) ** 2).sum(1, keepdim=True)
    depth_term = (delta[:, 3:] ** 2).sum(1, keepdim=True)

    penalty = mask.float() * (color_term + depth_term)
    return penalty.neg().exp()
| |
|