Spaces:
Runtime error
Runtime error
| import torch | |
class decoder_default:
    """Decode keypoint heatmaps into normalized (x, y) coordinates.

    Each keypoint's location is the heatmap-weighted centroid of a pixel
    grid whose coordinates span [-1, +1] along each axis (first pixel at
    -1, last pixel at +1; a 1-pixel axis sits at 0).

    Args:
        weight: Multiplier applied to the heatmap before decoding when
            ``use_weight_map`` is True. Must broadcast against a
            ``(batch, npoints, h, w)`` tensor -- a scalar, or presumably a
            per-pixel weight map (NOTE(review): confirm with callers).
        use_weight_map: If True, multiply the heatmap by ``weight`` first.
    """

    def __init__(self, weight=1, use_weight_map=False):
        self.weight = weight
        self.use_weight_map = use_weight_map

    def _make_grid(self, h, w):
        """Return ``(yy, xx)`` coordinate grids, each of shape ``(h, w)``.

        ``yy[i, j]`` is the normalized coordinate of row ``i`` and
        ``xx[i, j]`` the normalized coordinate of column ``j`` ("ij"
        indexing), both in [-1, +1]. An axis of length 1 maps to 0 rather
        than the NaN that ``0 / 0`` would otherwise produce.
        """

        def _axis(n):
            # Evenly spaced points from -1 to +1; a single point is centred.
            if n == 1:
                return torch.zeros(1, dtype=torch.float32)
            return torch.arange(n).float() / (n - 1) * 2 - 1

        # Broadcasting gives the same result as torch.meshgrid(..., indexing="ij")
        # without relying on the `indexing` kwarg (absent in old torch,
        # warned about when omitted in newer torch).
        yy = _axis(h).view(h, 1).expand(h, w)
        xx = _axis(w).view(1, w).expand(h, w)
        return yy, xx

    def get_coords_from_heatmap(self, heatmap):
        """Decode each heatmap channel to its weighted-centroid coordinate.

        inputs:
            - heatmap: batch x npoints x h x w
        outputs:
            - coords: batch x npoints x 2, ordered (x, y); within [-1, +1]
              when the heatmap is non-negative
        """
        # Unpacking also enforces a 4-D input.
        _, _, h, w = heatmap.shape
        if self.use_weight_map:
            heatmap = heatmap * self.weight

        yy, xx = self._make_grid(h, w)
        # .to(heatmap) moves the grids to the heatmap's dtype and device.
        yy = yy.reshape(1, 1, h, w).to(heatmap)
        xx = xx.reshape(1, 1, h, w).to(heatmap)

        # The clamp avoids division by zero for an all-zero heatmap; the
        # numerators are then 0, so such a keypoint decodes to (0, 0).
        heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
        yy_coord = (yy * heatmap).sum([2, 3]) / heatmap_sum  # batch x npoints
        xx_coord = (xx * heatmap).sum([2, 3]) / heatmap_sum  # batch x npoints
        return torch.stack([xx_coord, yy_coord], dim=-1)