# weismart1807's picture
# Upload folder using huggingface_hub
# e90b704 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
def heatmap2coord(heatmap, topk=9):
    """Convert heatmaps to sub-pixel (x, y) coordinates via a top-k soft-argmax.

    For each of the ``N * C`` heatmaps, the ``topk`` highest-scoring pixels are
    selected, their scores are turned into weights with a softmax, and the
    weighted mean of their (column, row) positions is returned.

    Args:
        heatmap: Tensor of shape ``(N, C, H, W)``. Non-contiguous tensors
            (e.g. permuted or transposed views) are accepted.
        topk: Number of highest-scoring pixels to average over. Values larger
            than ``H * W`` are clamped to ``H * W``.

    Returns:
        Tensor of shape ``(N, C, 2)`` holding ``(x, y)`` pairs, where ``x`` is
        the column position and ``y`` the row position.
    """
    N, C, H, W = heatmap.shape
    # Clamp so small heatmaps (H * W < topk) do not make topk raise.
    k = min(topk, H * W)
    # reshape (unlike view) also handles non-contiguous inputs.
    score, index = heatmap.reshape(N, C, 1, -1).topk(k, dim=-1)
    # Flat index -> (column, row), stacked along dim 2: coord is (N, C, 2, k).
    coord = torch.cat([index % W, index // W], dim=2)
    # Weights are (N, C, 1, k) and broadcast over the x/y axis.
    weights = F.softmax(score, dim=-1)
    return (coord * weights).sum(-1)