# Spaces: Running on Zero
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| #Ref: https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py | |
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """Sample `img` at pixel-space `coords` (wrapper around grid sampling).

    `coords` carries (x, y) pixel coordinates in its last dimension; they
    are rescaled to the [-1, 1] range before sampling. When `mask` is True,
    also return a float mask flagging coordinates strictly inside the image.
    """
    height, width = img.shape[-2:]
    x_pix, y_pix = coords.split([1, 1], dim=-1)
    # Map pixel coordinates onto [-1, 1] as expected by grid sampling.
    x_norm = 2 * x_pix / (width - 1) - 1
    y_norm = 2 * y_pix / (height - 1) - 1
    sampling_grid = torch.cat([x_norm, y_norm], dim=-1)
    # img = F.grid_sample(img, sampling_grid, align_corners=True)
    # bilinear_grid_sample is the gather-based (ONNX-friendly) equivalent.
    img = bilinear_grid_sample(img, sampling_grid, align_corners=True)
    if not mask:
        return img
    in_bounds = (x_norm > -1) & (x_norm < 1) & (y_norm > -1) & (y_norm < 1)
    return img, in_bounds.float()
def coords_grid(batch, ht, wd, device):
    """Build a (batch, 2, ht, wd) float tensor of (x, y) pixel coordinates."""
    rows, cols = torch.meshgrid(
        torch.arange(ht, device=device),
        torch.arange(wd, device=device),
        indexing='ij',
    )
    # Channel 0 holds x (column index), channel 1 holds y (row index).
    xy = torch.stack((cols, rows), dim=0).float()
    return xy.unsqueeze(0).repeat(batch, 1, 1, 1)
def manual_pad(x, pady, padx):
    """Replicate-pad `x` by `padx` left/right and `pady` top/bottom."""
    # F.pad's last-two-dims padding order is (left, right, top, bottom).
    return F.pad(x.clone().detach(), (padx, padx, pady, pady), "replicate")
| # Ref: https://zenn.dev/pinto0309/scraps/7d4032067d0160 | |
def bilinear_grid_sample(im, grid, align_corners=False):
    """Bilinear grid sampling implemented with gather (ONNX-exportable).

    Mirrors ``torch.nn.functional.grid_sample`` with ``mode='bilinear'`` and
    ``padding_mode='zeros'``: out-of-range sample points read zeros.

    Args:
        im (torch.Tensor): Input feature map, shape (N, C, H, W).
        grid (torch.Tensor): Sampling points in [-1, 1], shape (N, Hg, Wg, 2),
            ordered (x, y) in the last dimension.
        align_corners (bool): If True, -1 and 1 refer to the centers of the
            input's corner pixels; if False, to their outer edges, making the
            sampling more resolution agnostic.

    Returns:
        torch.Tensor: Sampled values, shape (N, C, Hg, Wg).
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    # Un-normalize [-1, 1] coordinates into pixel space.
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    x = x.view(n, -1)
    y = y.view(n, -1)

    # Integer corner coordinates surrounding each sample point.
    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear weights of the four corners:
    # a=(x0,y0), b=(x0,y1), c=(x1,y0), d=(x1,y1).
    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # Zero-pad one pixel on every side so corners that fall just outside the
    # image read zeros, matching grid_sample's default padding_mode='zeros'.
    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2
    # Shift indices into the padded frame.
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Clamp indices to the padded image bounds. clamp() replaces the original
    # pairs of torch.where(..., torch.tensor(k, device=...)) calls — same
    # result, without allocating a scalar tensor per comparison.
    x0 = x0.clamp(0, padded_w - 1)
    x1 = x1.clamp(0, padded_w - 1)
    y0 = y0.clamp(0, padded_h - 1)
    y1 = y1.clamp(0, padded_h - 1)

    # Gather the four corner values per channel from the flattened image.
    im_flat = im_padded.view(n, c, -1)
    idx_a = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    idx_b = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    idx_c = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    idx_d = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    Ia = torch.gather(im_flat, 2, idx_a)
    Ib = torch.gather(im_flat, 2, idx_b)
    Ic = torch.gather(im_flat, 2, idx_c)
    Id = torch.gather(im_flat, 2, idx_d)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)