Spaces:
Build error
Build error
| r""" Provides functions that manipulate boxes and points """ | |
| import math | |
| import torch.nn.functional as F | |
| import torch | |
class Geometry(object):
    r"""Class-level grid tensors plus keypoint normalisation/transfer helpers.

    All state is stored on the class itself by :meth:`initialize`, which must
    be called once before any other method is used.
    """

    # Sentinel marking padded (absent) keypoint coordinates.
    _PAD_VALUE = -2

    @classmethod
    def initialize(cls, img_size, device=None):
        r"""Precompute the normalised coordinate grids for a given image size.

        Args:
            img_size: side length of the (square) input image in pixels.
            device: optional torch device on which the grid tensors are
                created; ``None`` keeps torch's default device.
        """
        cls.img_size = img_size
        cls.spatial_side = int(img_size / 8)
        side = cls.spatial_side

        norm_grid1d = torch.linspace(-1, 1, side, device=device)
        # Row-major grids: x varies along columns, y varies along rows.
        grid_x = norm_grid1d.view(1, -1).repeat(side, 1)
        grid_y = norm_grid1d.view(-1, 1).repeat(1, side)

        cls.norm_grid_x = grid_x.view(1, 1, -1)
        cls.norm_grid_y = grid_y.view(1, 1, -1)
        # grid[i, j] == (x_j, y_i); same values as the reversed meshgrid
        # construction, without depending on meshgrid's indexing default.
        cls.grid = torch.stack([grid_x, grid_y], dim=2)
        cls.feat_idx = torch.arange(0, side, device=device).float()

    @classmethod
    def normalize_kps(cls, kps):
        r"""Map pixel keypoints to normalised coordinates, keeping pads at -2.

        Returns a detached copy; the input tensor is not modified.
        """
        kps = kps.clone().detach()
        half = cls.img_size // 2
        # The mask is taken once, on the input. Re-evaluating it after the
        # subtraction would treat a valid coordinate equal to ``half - 2``
        # as padding and leave it un-scaled.
        valid = kps != cls._PAD_VALUE
        kps[valid] -= half
        kps[valid] /= half
        return kps

    @classmethod
    def unnormalize_kps(cls, kps):
        r"""Inverse of :meth:`normalize_kps`; pads (-2) are left untouched.

        Returns a detached copy; the input tensor is not modified.
        """
        kps = kps.clone().detach()
        half = cls.img_size // 2
        # Single mask for the same reason as in normalize_kps: the scaled
        # intermediate value may coincide with the pad sentinel.
        valid = kps != cls._PAD_VALUE
        kps[valid] *= half
        kps[valid] += half
        return kps

    @classmethod
    def attentive_indexing(cls, kps, thres=0.1):
        r"""kps: normalized keypoints x, y (N, 2)
        returns attentive index map (N, spatial_side, spatial_side)

        Each map weights grid cells by ``thres - distance`` (clamped at 0)
        and is normalised to sum to one. If no grid cell lies within
        ``thres`` of a keypoint, that row sums to zero and the division
        yields NaN.
        """
        side = cls.spatial_side
        nkps = kps.size(0)
        kps = kps.reshape(nkps, 1, 1, 2)

        eps = 1e-5  # keeps the square root away from exactly zero
        sq_dist = (cls.grid.unsqueeze(0) - kps).pow(2).sum(dim=3)
        dist = (sq_dist + eps).pow(0.5)
        attmap = (thres - dist).clamp(min=0).reshape(nkps, side * side)
        attmap = attmap / attmap.sum(dim=1, keepdim=True)
        return attmap.reshape(nkps, side, side)

    @classmethod
    def apply_gaussian_kernel(cls, corr, sigma=17):
        r"""Weight each correlation row by a Gaussian centred on its argmax.

        Args:
            corr: tensor of shape (bsz, n_src, spatial_side ** 2).
            sigma: standard deviation of the Gaussian, in grid cells.

        Returns:
            Tensor with the same shape as ``corr``.
        """
        bsz, n_src, n_trg = corr.size()
        side = cls.spatial_side

        # Flat argmax over the last dimension, split into row/column indices.
        center = corr.max(dim=2)[1]
        center_y = (center // side).view(bsz, n_src, 1, 1)
        center_x = (center % side).view(bsz, n_src, 1, 1)

        # Offsets of every grid row/column from the peak, via broadcasting.
        y = cls.feat_idx.view(1, 1, side, 1) - center_y
        x = cls.feat_idx.view(1, 1, 1, side) - center_x

        gauss_kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
        filtered_corr = gauss_kernel * corr.reshape(bsz, n_src, side, side)
        return filtered_corr.reshape(bsz, n_src, n_trg)

    @classmethod
    def transfer_kps(cls, confidence_ts, src_kps, n_pts, normalized):
        r""" Transfer keypoints by weighted average

        Args:
            confidence_ts: correlation tensor passed to
                :meth:`apply_gaussian_kernel` and soft-maxed over dim 2.
            src_kps: per-sample source keypoints, each of shape (2, max_pts).
            n_pts: per-sample number of valid (non-padded) keypoints.
            normalized: if falsy, ``src_kps`` is normalised first.

        Returns:
            Stacked predictions of shape (batch, 2, max_pts); columns past
            the valid count are filled with -2.
        """
        if not normalized:
            src_kps = cls.normalize_kps(src_kps)
        confidence_ts = cls.apply_gaussian_kernel(confidence_ts)

        pdf = F.softmax(confidence_ts, dim=2)
        prd_x = (pdf * cls.norm_grid_x).sum(dim=2)
        prd_y = (pdf * cls.norm_grid_y).sum(dim=2)

        n_cells = cls.spatial_side ** 2
        prd_kps = []
        for x, y, src_kp, n_valid in zip(prd_x, prd_y, src_kps, n_pts):
            n_valid = int(n_valid)  # accepts Python ints and 0-dim tensors
            max_pts = src_kp.size(1)

            prd_xy = torch.stack([x, y]).t()
            valid_kp = src_kp[:, :n_valid].t()
            attmap = cls.attentive_indexing(valid_kp).reshape(n_valid, n_cells)
            prd_kp = (prd_xy.unsqueeze(0) * attmap.unsqueeze(-1)).sum(dim=1).t()

            # Pads follow the prediction's device and dtype so torch.cat
            # works regardless of where the tensors live.
            pads = prd_kp.new_full((2, max_pts - n_valid), cls._PAD_VALUE)
            prd_kps.append(torch.cat([prd_kp, pads], dim=1))

        return torch.stack(prd_kps)
def get_coord1d(coord4d, ksz):
    r"""Flatten a 4-element coordinate into a single base-``ksz`` index.

    The first component is the most significant digit and the last one the
    least significant: ``i * ksz^3 + j * ksz^2 + k * ksz + l``.
    """
    d3, d2, d1, d0 = coord4d
    strides = (ksz ** 3, ksz ** 2, ksz)
    return d3 * strides[0] + d2 * strides[1] + d1 * strides[2] + d0
def get_distance(coord1, coord2):
    r"""Return the sum of the squared differences of the first two components.

    Each squared difference is computed with ``math.pow`` and truncated to
    ``int`` before summing, so the result is an integer squared distance
    (no square root is taken).
    """
    return sum(
        int(math.pow(coord1[axis] - coord2[axis], 2)) for axis in (0, 1)
    )
def interpolate4d(tensor4d, size):
    r"""Bilinearly resize both spatial pairs of a 4D correlation tensor.

    Args:
        tensor4d: tensor of shape (bsz, h1, w1, h2, w2).
        size: target ``(height, width)`` applied to both spatial pairs; a
            single int is treated as a square size.

    Returns:
        Tensor of shape (bsz, height, width, height, width).
    """
    if isinstance(size, int):
        size = (size, size)
    out_h, out_w = size
    target = (out_h, out_w)

    bsz, h1, w1, h2, w2 = tensor4d.size()

    # Resize the first pair (h1, w1), treating the flattened second pair as
    # channels.
    tensor4d = tensor4d.reshape(bsz, h1, w1, -1).permute(0, 3, 1, 2)
    tensor4d = F.interpolate(tensor4d, target, mode='bilinear', align_corners=True)

    # Swap roles: the resized first pair becomes the channel dimension and
    # the second pair (h2, w2) is resized.
    tensor4d = tensor4d.reshape(bsz, h2, w2, -1).permute(0, 3, 1, 2)
    tensor4d = F.interpolate(tensor4d, target, mode='bilinear', align_corners=True)

    # Use height and width separately so non-square targets work; reusing
    # the height for every dimension only fits square sizes.
    return tensor4d.reshape(bsz, out_h, out_w, out_h, out_w)
def init_idx4d(ksz):
    r"""Enumerate every 4D index of a ``ksz``-sided hypercube.

    Returns a NumPy array of shape (ksz ** 4, 4) whose rows list all index
    tuples in lexicographic order (last column varies fastest).
    """
    base = torch.arange(0, ksz)
    # Column for digit p: each value held for ksz**p consecutive rows, and
    # that pattern tiled to fill all ksz**4 rows. Most significant first.
    columns = [
        base.repeat_interleave(ksz ** power).repeat(ksz ** (3 - power))
        for power in (3, 2, 1, 0)
    ]
    return torch.stack(columns).t().numpy()