Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class SoftArgmax(nn.Module):
    """
    Overview:
        A neural network module that computes the SoftArgmax operation (essentially a 2-dimensional spatial
        softmax followed by an expectation over pixel indices). It is often used for location regression
        tasks, converting a feature map (such as a heatmap) into differentiable, sub-pixel coordinates.
    Interfaces:
        ``__init__``, ``forward``
    .. note::
        For more information on SoftArgmax, you can refer to <https://en.wikipedia.org/wiki/Softmax_function>
        and the paper <https://arxiv.org/pdf/1504.00702.pdf>.
    """

    def __init__(self):
        """
        Overview:
            Initialize the SoftArgmax module. The module has no parameters or buffers.
        """
        super(SoftArgmax, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Perform the forward pass of the SoftArgmax operation: normalize the heatmap with a softmax over
            all ``H * W`` positions, then return the expected row and column index under that distribution.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor, typically a heatmap representing predicted locations.
        Returns:
            - location (:obj:`torch.Tensor`): The predicted coordinates as a result of the SoftArgmax operation.
        Shapes:
            - x: :math:`(B, C, H, W)`, where `B` is the batch size, `C` is the number of channels (must be 1),
              and `H` and `W` represent height and width respectively.
            - location: :math:`(B, 2)`, where `B` is the batch size and 2 represents the coordinates
              (height, width).
        Raises:
            - AssertionError: If the input does not have exactly one channel.
        """
        # Unpack the dimensions of the input tensor
        B, C, H, W = x.shape
        device, dtype = x.device, x.dtype

        # Ensure the input tensor has a single channel. This is raised explicitly (rather than via an
        # ``assert`` statement) so the check still runs under ``python -O``; otherwise a multi-channel
        # input would silently yield the sum of the per-channel coordinates.
        if C != 1:
            raise AssertionError("Input tensor should have only one channel")

        # Row and column index vectors, shaped so that they broadcast against (B, C, H, W).
        # Broadcasting yields the same element-wise products as a fully materialized H x W meshgrid.
        h_kernel = torch.arange(0, H, device=device).to(dtype).view(1, 1, H, 1)
        w_kernel = torch.arange(0, W, device=device).to(dtype).view(1, 1, 1, W)

        # Apply the softmax function jointly across the spatial dimensions (height and width)
        prob = F.softmax(x.view(B, C, -1), dim=-1).view(B, C, H, W)

        # Expected row / column index under the probability map
        # (sum over the channel, height, and width dimensions)
        h = (prob * h_kernel).sum(dim=[1, 2, 3])
        w = (prob * w_kernel).sum(dim=[1, 2, 3])

        # Stack the height and width coordinates along a new dimension to form the (B, 2) output tensor
        return torch.stack([h, w], dim=1)