import torch
import torch.nn as nn
import torch.nn.functional as F


def heatmap2coord(heatmap, topk=9):
    """Convert heatmaps to sub-pixel (x, y) coordinates via a top-k soft-argmax.

    For every (batch, channel) map the ``topk`` highest-scoring pixels are
    selected, their scores are turned into weights with a softmax, and the
    returned coordinate is the weighted average of those pixels' positions.

    Args:
        heatmap: Tensor of shape ``(N, C, H, W)``. It may be non-contiguous.
        topk: Number of highest-scoring pixels to average over. Values larger
            than ``H * W`` are clamped to ``H * W`` so small heatmaps work
            with the default.

    Returns:
        Tensor of shape ``(N, C, 2)``; the last axis is ``(x, y)``, where
        ``x`` is the column index (``flat_index % W``) and ``y`` is the row
        index (``flat_index // W``).
    """
    N, C, H, W = heatmap.shape
    num_pixels = H * W

    # reshape (not view) so non-contiguous inputs, e.g. transposed or
    # channels_last tensors, are accepted instead of raising.
    flat = heatmap.reshape(N, C, 1, num_pixels)

    # torch.topk raises if k exceeds the size of the reduced dimension.
    k = min(topk, num_pixels)
    score, index = flat.topk(k, dim=-1)

    # Decode flat indices into column (x) and row (y), stacked on dim 2:
    # the result has shape (N, C, 2, k).
    coord = torch.cat([index % W, index // W], dim=2)

    # Softmax over the k selected scores gives weights that sum to 1.
    weight = F.softmax(score, dim=-1)

    # Indices are integer tensors; cast explicitly to the weight dtype
    # before taking the weighted sum over the k candidates.
    return (coord.to(weight.dtype) * weight).sum(-1)