Spaces:
Build error
Build error
| # Modified from: | |
| # https://github.com/anibali/pytorch-stacked-hourglass | |
| # https://github.com/bearpaw/pytorch-pose | |
| import math | |
| import torch | |
| from kornia.geometry.subpix import dsnt # kornia 0.4.0 | |
| import torch.nn.functional as F | |
| from .transforms import transform_preds | |
| __all__ = ['get_preds', 'get_preds_soft', 'calc_dists', 'dist_acc', 'accuracy', 'final_preds_untransformed', | |
| 'final_preds', 'AverageMeter'] | |
def get_preds(scores, return_maxval=False):
    ''' Locate the arg-max of every score map.

    scores: 4-dim tensor; the last two dimensions are flattened together
        and the last dimension is used as the row width.
    return_maxval: when True, also return the maximum of each map.

    Returns a float tensor shaped (scores.size(0), scores.size(1), 2)
    holding 1-based (x, y) positions. Entries whose maximum is not
    strictly positive are zeroed out. With return_maxval=True the result
    is the tuple (preds, maxval), maxval having a trailing dimension of 1.
    '''
    assert scores.dim() == 4, 'Score maps should be 4-dim'

    n_batch, n_maps, width = scores.size(0), scores.size(1), scores.size(3)

    flattened = scores.view(n_batch, n_maps, -1)
    maxval, flat_idx = torch.max(flattened, 2)
    maxval = maxval.view(n_batch, n_maps, 1)

    # Shift the flat index to 1-based before splitting it into x and y.
    one_based = flat_idx.view(n_batch, n_maps, 1) + 1
    preds = one_based.repeat(1, 1, 2).float()
    preds[:, :, 0] = (preds[:, :, 0] - 1) % width + 1
    preds[:, :, 1] = torch.floor((preds[:, :, 1] - 1) / width) + 1

    # Keep only predictions whose peak value is > 0.
    positive = maxval.gt(0).repeat(1, 1, 2).float()
    preds *= positive

    if return_maxval:
        return preds, maxval
    return preds
def _sum_3x3_at_prediction(scores_norm, preds_normalized):
    ''' Sum each normalized map over the 3x3 window around its prediction.

    Nine copies of every (zero-padded) map are stacked as channels, each
    shifted by one of the offsets in {-1, 0, 1} x {-1, 0, 1}. Sampling all
    nine channels at the pixel nearest to the predicted location and adding
    them up therefore yields the sum over the 3x3 neighbourhood of that
    pixel.

    scores_norm: 4-dim tensor of normalized maps.
    preds_normalized: coordinates reshaped to (-1, 2) and fed to
        F.grid_sample with align_corners=True (presumably in [-1, 1], as
        produced by kornia's normalized expectation -- confirm against the
        installed kornia version).

    Returns a tensor shaped (scores_norm.shape[0], scores_norm.shape[1], 1).
    '''
    height, width = scores_norm.shape[2], scores_norm.shape[3]
    single = scores_norm.reshape((-1, 1, height, width))

    half_pad = 2
    padded = F.pad(input=single, pad=(half_pad, half_pad, half_pad, half_pad, 0, 0, 0, 0),
                   mode='constant', value=0)

    # new_zeros keeps the buffer on the same device and in the same dtype
    # as the maps it is filled from.
    shifted = single.new_zeros((single.shape[0], 9, height, width))
    channel = 0
    for d_row in (-1, 0, 1):
        for d_col in (-1, 0, 1):
            shifted[:, channel, :, :] = padded[:, 0,
                                               half_pad + d_row:-half_pad + d_row,
                                               half_pad + d_col:-half_pad + d_col]
            channel += 1

    grid = preds_normalized.reshape((-1, 2))[:, None, None, :]
    sampled = F.grid_sample(shifted, grid, mode='nearest', padding_mode='zeros',
                            align_corners=True)
    sampled = sampled.reshape((shifted.shape[0], shifted.shape[1], 1))
    window_sum = sampled.sum(axis=1)
    return window_sum.reshape((scores_norm.shape[0], scores_norm.shape[1], 1))


def get_preds_soft(scores, return_maxval=False, norm_coords=False, norm_and_unnorm_coords=False):
    ''' Soft-argmax predictions from logit score maps.

    The maps are normalized with kornia's dsnt.spatial_softmax2d and the
    predicted location is the spatial expectation of the result.

    scores: logit score maps, 4-dim (the code indexes shape[0..3]).
    return_maxval: also return a confidence value per map. Despite the
        name this is not a maximum: it is the sum of the normalized map
        over the 3x3 window around the predicted location.
    norm_coords: return kornia's normalized coordinates instead of the
        unnormalized ones.
    norm_and_unnorm_coords: only honoured when return_maxval is True;
        return both coordinate variants.

    Unnormalized coordinates get +1 added, matching the 1-based convention
    of get_preds.

    Returns (float tensors):
        return_maxval=False, norm_coords=False -> preds
        return_maxval=False, norm_coords=True  -> preds_normalized
        return_maxval=True, norm_and_unnorm_coords=True
                                               -> (preds, preds_normalized, conf)
        return_maxval=True, norm_coords=True   -> (preds_normalized, conf)
        return_maxval=True otherwise           -> (preds, conf)
    '''
    # Work on logit predictions: turn every map into a spatial distribution.
    scores_norm = dsnt.spatial_softmax2d(scores, temperature=torch.tensor(1))

    # Relation between unnormalized and normalized coordinates in kornia
    # (see kornia/utils/grid.py):
    #   xs = (xs / (width - 1) - 0.5) * 2
    #   ys = (ys / (height - 1) - 0.5) * 2

    if not return_maxval:
        if norm_coords:
            return dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=True)
        return dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1

    preds_normalized = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=True)
    confidence = _sum_3x3_at_prediction(scores_norm, preds_normalized)

    # norm_and_unnorm_coords takes precedence over norm_coords.
    if norm_coords and not norm_and_unnorm_coords:
        return preds_normalized, confidence

    preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1
    if norm_and_unnorm_coords:
        return preds, preds_normalized, confidence
    return preds, confidence
def calc_dists(preds, target, normalize):
    ''' Normalized distance between every prediction and its target.

    preds, target: tensors indexed as [sample, joint, coordinate].
    normalize: per-sample divisor, indexed by sample.

    Returns a tensor shaped (joints, samples). An entry is -1 when the
    target is not strictly greater than 1 in both of its first two
    coordinates; otherwise it is the distance divided by normalize[sample].
    '''
    preds = preds.float()
    target = target.float()

    n_samples, n_joints = preds.size(0), preds.size(1)
    dists = torch.zeros(n_joints, n_samples)

    for sample in range(n_samples):
        for joint in range(n_joints):
            gt_x = target[sample, joint, 0]
            gt_y = target[sample, joint, 1]
            if not (gt_x > 1 and gt_y > 1):
                # Target is treated as absent; flag it for dist_acc.
                dists[joint, sample] = -1
                continue
            gap = torch.dist(preds[sample, joint, :], target[sample, joint, :])
            dists[joint, sample] = gap / normalize[sample]

    return dists
def dist_acc(dist, thr=0.5):
    ''' Fraction of distances below thr, skipping entries equal to -1.

    Returns -1 when no entry is left after dropping the -1 markers.
    '''
    kept = dist[dist != -1]
    total = len(kept)
    if total == 0:
        return -1
    below = (kept < thr).sum().item()
    return 1.0 * below / total
def accuracy(output, target, idxs=None, thr=0.5):
    ''' PCK-style accuracy, computed against ground-truth heatmaps rather than x,y locations.

    output: predicted score maps, passed to get_preds_soft.
    target: ground-truth heatmaps, passed to get_preds.
    idxs: indices of the maps to evaluate; defaults to all of them
        (range over target.shape[-3]).
    thr: distance threshold handed to dist_acc.

    Distances are normalized by a tenth of output.size(3).

    Returns a tensor of length len(idxs) + 1: element 0 is the mean over
    the evaluated maps that had at least one valid target, followed by the
    individual accuracies (-1 where none was valid).
    '''
    if idxs is None:
        idxs = list(range(target.shape[-3]))

    predicted = get_preds_soft(output)  # soft-argmax on the predictions
    ground_truth = get_preds(target)    # hard arg-max on the targets
    norm = torch.ones(predicted.size(0)) * output.size(3) / 10
    dists = calc_dists(predicted, ground_truth, norm)

    acc = torch.zeros(len(idxs) + 1)
    running_sum = 0
    n_valid = 0
    for slot, map_idx in enumerate(idxs, start=1):
        acc[slot] = dist_acc(dists[map_idx], thr=thr)
        # Negative values mark maps without any valid target; skip them.
        if acc[slot] >= 0:
            running_sum = running_sum + acc[slot]
            n_valid += 1

    if n_valid != 0:
        acc[0] = running_sum / n_valid
    return acc
def final_preds_untransformed(output, res):
    ''' Refined heatmap-space predictions in 0-based coordinates.

    output: score maps, indexed as output[sample][joint][row][col].
    res: px is bounded by res[0] and py by res[1] when deciding whether the
        neighbouring pixels can be read.

    Starting from the 1-based soft-argmax of get_preds_soft, every
    prediction that is not on the border is nudged a quarter pixel towards
    the larger neighbouring heatmap value in x and y. Afterwards 0.5 is
    added and the coordinates are shifted from 1-based to 0-based.

    Returns a float tensor with at least 3 dimensions.
    '''
    coords = get_preds_soft(output)  # float type

    # pose-processing
    for n in range(coords.size(0)):
        for p in range(coords.size(1)):
            hm = output[n][p]
            px = int(math.floor(coords[n][p][0]))
            py = int(math.floor(coords[n][p][1]))
            if 1 < px < res[0] and 1 < py < res[1]:
                # px/py are 1-based, so (py - 1, px - 1) is the predicted
                # pixel; compare its right/left and lower/upper neighbours.
                diff = torch.stack((hm[py - 1][px] - hm[py - 1][px - 2],
                                    hm[py][px - 1] - hm[py - 2][px - 1]))
                # Keep the offset on the same device as coords so the
                # in-place add also works for non-CPU outputs.
                offset = diff.detach().sign().to(coords.device)
                coords[n][p] += offset * .25
    coords += 0.5

    if coords.dim() < 3:
        coords = coords.unsqueeze(0)

    coords -= 1  # Convert from 1-based to 0-based coordinates

    return coords
def final_preds(output, center, scale, res):
    ''' Predictions mapped back through transform_preds.

    output, res: forwarded to final_preds_untransformed.
    center, scale: per-sample values, indexed by sample and handed to
        transform_preds together with res.

    Returns a tensor with at least 3 dimensions holding the transformed
    coordinates of every sample.
    '''
    heatmap_coords = final_preds_untransformed(output, res)
    preds = heatmap_coords.clone()

    # Transform back, one sample at a time.
    n_samples = heatmap_coords.size(0)
    for sample in range(n_samples):
        preds[sample] = transform_preds(heatmap_coords[sample], center[sample], scale[sample], res)

    if preds.dim() < 3:
        preds = preds.unsqueeze(0)
    return preds
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes (all reset to 0 by reset()):
        val: the value passed to the most recent update().
        sum: running sum of val * n over all updates.
        count: running sum of n over all updates.
        avg: sum / count, or 0 while count is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record val as observed n times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 on an empty meter must not divide by zero.
        self.avg = self.sum / self.count if self.count != 0 else 0