Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def heatmap2coord(heatmap, topk=9):
    """Decode per-channel (x, y) coordinates from a batch of heatmaps.

    For every (sample, channel) map, the ``topk`` highest-scoring locations
    are selected. Their pixel coordinates are averaged, weighted by a softmax
    over the selected scores (a soft-argmax restricted to the top-k peaks).

    Args:
        heatmap: Floating-point tensor of shape ``(N, C, H, W)``. It does not
            need to be contiguous.
        topk: Number of top-scoring locations to blend. Values larger than
            ``H * W`` are clamped to ``H * W``.

    Returns:
        Tensor of shape ``(N, C, 2)`` in the heatmap's dtype, holding
        ``(x, y)`` per map, where ``x`` is the column index and ``y`` the row
        index (possibly fractional because of the weighted average).
    """
    N, C, H, W = heatmap.shape
    # topk raises if k exceeds the number of locations, so clamp it.
    k = min(topk, H * W)
    # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs such as
    # permuted or sliced heatmaps.
    score, index = heatmap.reshape(N, C, 1, -1).topk(k, dim=-1)
    # Flat index -> (column, row), stacked along dim 2 -> (N, C, 2, k).
    # Cast explicitly instead of relying on int64 * float type promotion.
    coord = torch.cat([index % W, index // W], dim=2).to(score.dtype)
    # Softmax weights have shape (N, C, 1, k) and broadcast over the x/y axis.
    weights = F.softmax(score, dim=-1)
    return (coord * weights).sum(-1)