| | import torch |
| | import torch.nn.functional as F |
| |
|
| |
|
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Build a batch of absolute pixel-coordinate grids.

    Args:
        b: batch size (number of copies of the grid).
        h: grid height.
        w: grid width.
        homogeneous: if True, append a third channel filled with ones.
        device: optional device to create the grid on; ``None`` uses torch's
            default device.

    Returns:
        float32 tensor of shape [b, 2, h, w] ([b, 3, h, w] when
        ``homogeneous``). Channel 0 holds the x (column) index and channel 1
        the y (row) index.
    """
    # Allocate directly on the target device instead of building on the
    # default device and copying. Broadcasting replaces torch.meshgrid, which
    # warns on newer torch when called without ``indexing=``.
    xs = torch.arange(w, dtype=torch.float32, device=device).expand(h, w)
    ys = torch.arange(h, dtype=torch.float32, device=device).unsqueeze(1).expand(h, w)

    channels = [xs, ys]
    if homogeneous:
        channels.append(torch.ones(h, w, dtype=torch.float32, device=device))

    grid = torch.stack(channels, dim=0)  # [C, h, w], contiguous

    # repeat (not expand) so each batch entry owns its memory, as before.
    return grid[None].repeat(b, 1, 1, 1)
| |
|
| |
|
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Build a regular (x, y) sampling grid over a rectangular window.

    Args:
        h_min, h_max: range of y values (inclusive), sampled at ``len_h`` points.
        w_min, w_max: range of x values (inclusive), sampled at ``len_w`` points.
        len_h: number of rows.
        len_w: number of columns.
        device: device to create the grid on (required).

    Returns:
        Contiguous float32 tensor of shape [len_h, len_w, 2], where
        ``grid[i, j] == (x_j, y_i)``.
    """
    assert device is not None, "generate_window_grid requires an explicit device"

    xs = torch.linspace(w_min, w_max, len_w, device=device)
    ys = torch.linspace(h_min, h_max, len_h, device=device)

    # Broadcast straight into [len_h, len_w] layout. The previous meshgrid +
    # transpose produced a non-contiguous view, so .view() on it failed.
    grid = torch.stack((xs.expand(len_h, len_w),
                        ys.unsqueeze(1).expand(len_h, len_w)), dim=-1).float()

    return grid
| |
|
| |
|
def normalize_coords(coords, h, w):
    """Rescale pixel coordinates so that 0 maps to -1 and size-1 maps to +1.

    Args:
        coords: tensor whose last dimension is (x, y) pixel coordinates.
        h: image height used for the y axis.
        w: image width used for the x axis.

    Returns:
        ``(coords - half) / half`` with ``half = ((w - 1) / 2, (h - 1) / 2)``,
        broadcast over the last dimension. A size of 1 makes ``half`` zero
        and yields inf/nan.
    """
    half_extent = torch.tensor([(w - 1) / 2., (h - 1) / 2.],
                               dtype=torch.float32, device=coords.device)
    centered = coords - half_extent
    return centered / half_extent
| |
|
| |
|
def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
    """Sample ``img`` at absolute pixel coordinates via ``F.grid_sample``.

    Args:
        img: input forwarded to ``F.grid_sample`` (presumably [B, C, H, W]).
        sample_coords: pixel coordinates, laid out as either [B, 2, H, W] or
            [B, H, W, 2]. The layout is detected from whether dim 1 has size
            2, so a channel-last grid with H == 2 is mistaken for channel-first.
        mode: interpolation mode forwarded to ``F.grid_sample``.
        padding_mode: padding mode forwarded to ``F.grid_sample``.
        return_mask: if True, also return a boolean in-bounds mask.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``return_mask`` is set.
        ``mask`` is [B, H, W] and is True where both normalised coordinates
        lie in [-1, 1].
    """
    coords = sample_coords
    if coords.size(1) != 2:
        # channel-last [B, H, W, 2] -> channel-first [B, 2, H, W]
        coords = coords.permute(0, 3, 1, 2)

    _, _, height, width = coords.shape

    # Pixel -> [-1, 1] under align_corners=True: 0 maps to -1, size-1 to +1.
    # NOTE(review): this uses the spatial size of sample_coords, not of img,
    # so the two are assumed to match. A size of 1 divides by zero.
    norm_x = 2 * coords[:, 0] / (width - 1) - 1
    norm_y = 2 * coords[:, 1] / (height - 1) - 1

    sampled = F.grid_sample(img, torch.stack([norm_x, norm_y], dim=-1),
                            mode=mode, padding_mode=padding_mode, align_corners=True)

    if not return_mask:
        return sampled

    inside = (norm_x >= -1) & (norm_x <= 1) & (norm_y >= -1) & (norm_y <= 1)
    return sampled, inside
| |
|
| |
|
def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
    """Backward-warp ``feature`` with a dense displacement field.

    Each output pixel (x, y) samples ``feature`` at ``(x, y) + flow[:, :, y, x]``.

    Args:
        feature: [B, C, H, W] tensor to warp.
        flow: displacement field with 2 channels (x, y); presumably
            [B, 2, H, W] matching ``feature``'s spatial size (TODO confirm).
        mask: if True, also return the in-bounds mask from ``bilinear_sample``.
        padding_mode: forwarded to ``bilinear_sample``.

    Returns:
        Whatever ``bilinear_sample`` returns: the warped tensor, or
        ``(warped, mask)``.
    """
    batch, _, height, width = feature.size()
    assert flow.size(1) == 2

    # Absolute sampling positions: the identity pixel grid displaced by flow.
    sample_positions = coords_grid(batch, height, width, device=flow.device) + flow

    return bilinear_sample(feature, sample_positions,
                           padding_mode=padding_mode, return_mask=mask)
| |
|
| |
|
def forward_backward_consistency_check(fwd_flow, bwd_flow,
                                       alpha=0.01,
                                       beta=0.5
                                       ):
    """Flag pixels where forward and backward flows disagree (occlusions).

    A pixel is marked occluded in the forward direction when
    ``|F + warp(B, F)| > alpha * (|F| + |B|) + beta``, and symmetrically for
    the backward direction. The magnitude term uses the un-warped flows.

    Args:
        fwd_flow: [B, 2, H, W] forward flow.
        bwd_flow: [B, 2, H, W] backward flow.
        alpha: weight of the flow-magnitude term in the threshold.
        beta: constant offset of the threshold.

    Returns:
        ``(fwd_occ, bwd_occ)``: float tensors of shape [B, H, W] holding 1.0
        where the consistency error exceeds the threshold and 0.0 elsewhere.
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    # Per-pixel threshold shared by both directions.
    magnitude = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)
    threshold = alpha * magnitude + beta

    # For consistent pixels, a flow plus the opposite flow warped into its
    # frame should cancel out.
    fwd_error = torch.norm(fwd_flow + flow_warp(bwd_flow, fwd_flow), dim=1)
    bwd_error = torch.norm(bwd_flow + flow_warp(fwd_flow, bwd_flow), dim=1)

    return (fwd_error > threshold).float(), (bwd_error > threshold).float()
| |
|