Spaces:
Runtime error
Runtime error
File size: 305 Bytes
e90b704 |
1 2 3 4 5 6 7 8 9 10 11 12 |
import torch
import torch.nn as nn
import torch.nn.functional as F
def heatmap2coord(heatmap, topk=9):
    """Convert per-channel heatmaps to sub-pixel (x, y) coordinates.

    For each (batch, channel) heatmap, the ``topk`` highest-valued pixels are
    selected. Their (x, y) locations are then averaged, weighted by a softmax
    over those pixels' heatmap values (the raw values act as logits).

    Args:
        heatmap: Tensor of shape ``(N, C, H, W)``. It does not need to be
            contiguous.
        topk: Number of highest-scoring pixels to blend. Clamped to ``H * W``
            so heatmaps with fewer than ``topk`` pixels still work.

    Returns:
        Tensor of shape ``(N, C, 2)`` holding ``(x, y)`` per channel, where
        ``x`` is the column index and ``y`` the row index, in pixel units.
    """
    N, C, H, W = heatmap.shape
    # Tensor.topk raises if k exceeds the size of the reduced dimension.
    k = min(topk, H * W)
    # reshape (not view) so non-contiguous inputs, e.g. the result of
    # permute/transpose, are accepted instead of raising.
    score, index = heatmap.reshape(N, C, 1, H * W).topk(k, dim=-1)
    # Flat pixel index -> (column, row). Both parts are (N, C, 1, k), so
    # concatenating along dim 2 gives (N, C, 2, k).
    coord = torch.cat([index % W, index // W], dim=2)
    weights = F.softmax(score, dim=-1)
    return (coord * weights).sum(-1)
|