| import torch
|
| from torch import nn as nn
|
| import numpy as np
|
|
|
| from . import initializer as initializer
|
| from ..utils.cython import get_dist_maps
|
|
|
|
|
def select_activation_function(activation):
    """Resolve an activation specification.

    Args:
        activation: either a case-insensitive name (``'relu'`` or
            ``'softplus'``) or an ``nn.Module`` instance.

    Returns:
        ``nn.ReLU`` or ``nn.Softplus`` (the classes themselves, not
        instances) for a recognised name; the very same object when an
        ``nn.Module`` instance is passed in.

    Raises:
        ValueError: for an unrecognised name or any other kind of argument.
    """
    if isinstance(activation, str):
        by_name = {'relu': nn.ReLU, 'softplus': nn.Softplus}
        resolved = by_name.get(activation.lower())
        if resolved is None:
            raise ValueError(f"Unknown activation type {activation}")
        return resolved

    if not isinstance(activation, nn.Module):
        raise ValueError(f"Unknown activation type {activation}")

    # Module instances are handed back untouched.
    return activation
|
|
|
|
|
class BilinearConvTranspose2d(nn.ConvTranspose2d):
    """Bias-free transposed convolution set up as an upsampling layer.

    The layer uses ``stride=scale``, ``padding=1`` and a kernel of size
    ``2 * scale - scale % 2``. After construction its weights are
    initialised by applying ``initializer.Bilinear`` to the module
    (presumably a bilinear interpolation kernel; see that initializer for
    the exact weights).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        scale: upsampling factor; used as the stride and to derive the
            kernel size.
        groups: number of convolution groups, forwarded both to the
            convolution and to the initializer.
    """

    def __init__(self, in_channels, out_channels, scale, groups=1):
        kernel_size = 2 * scale - scale % 2

        super().__init__(
            in_channels, out_channels,
            kernel_size=kernel_size,
            stride=scale,
            padding=1,
            groups=groups,
            bias=False)

        # Attributes are assigned only after nn.Module has set up its internal
        # state; PyTorch requires the parent __init__ to run before any
        # assignment on the child.
        self.scale = scale

        self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups))
|
|
|
|
|
class DistMaps(nn.Module):
    """Turn point coordinates into dense, saturating distance maps.

    For every pixel the distance to the nearest valid point of a group is
    computed, divided by ``norm_radius * spatial_scale`` and squashed with
    ``tanh(2 * d)``, so values lie in ``[0, 1]`` with ``0`` exactly at a
    point.

    Args:
        norm_radius: radius used to normalise distances.
        spatial_scale: factor applied to the point coordinates (and to the
            normalisation radius) before distances are computed.
        cpu_mode: when true, the squared distance maps are produced by the
            compiled ``get_dist_maps`` helper instead of tensor operations.
    """

    def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False):
        super(DistMaps, self).__init__()
        self.spatial_scale = spatial_scale
        self.norm_radius = norm_radius
        self.cpu_mode = cpu_mode

    def get_coord_features(self, points, batchsize, rows, cols):
        """Build the distance maps for a batch of point sets.

        Args:
            points: tensor indexed as ``(batch, 2 * num_points, 2)``. The last
                axis holds ``(row, col)``. The second axis is split into two
                consecutive groups of ``num_points`` entries, one output
                channel per group (presumably positive and negative clicks;
                confirm with the caller). A point whose two coordinates are
                both negative is treated as padding and ignored.
            batchsize: number of samples; only used by the ``cpu_mode`` loop.
            rows: height of the produced maps.
            cols: width of the produced maps.

        Returns:
            Float tensor of maps on ``points.device``. The tensor path
            yields shape ``(batch, 2, rows, cols)``; in ``cpu_mode`` the shape
            is whatever ``get_dist_maps`` returns, stacked over the batch.
        """
        if self.cpu_mode:
            # Loop-invariant, so computed once for the whole batch.
            norm_delimeter = self.spatial_scale * self.norm_radius
            # NOTE(review): get_dist_maps is expected to return squared,
            # normalised distances like the tensor path below; confirm.
            coords = [
                get_dist_maps(points[i].cpu().float().numpy(), rows, cols, norm_delimeter)
                for i in range(batchsize)
            ]
            coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
        else:
            num_points = points.shape[1] // 2
            # reshape (rather than view) also accepts non-contiguous inputs
            # such as a slice ``points[:, :k]``.
            points = points.reshape(-1, 2)
            invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0

            row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device)
            col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device)

            # Row and column offsets are kept separate and combined through
            # broadcasting; this avoids torch.meshgrid (which warns when no
            # ``indexing`` is passed) and a repeated full coordinate grid.
            # Casting to float32 keeps the maps in the grid's dtype whatever
            # the dtype of ``points`` is.
            norm = self.norm_radius * self.spatial_scale
            centers = (points * self.spatial_scale).to(torch.float32)
            row_dist = (row_array.view(1, rows, 1) - centers[:, 0].view(-1, 1, 1)) / norm
            col_dist = (col_array.view(1, 1, cols) - centers[:, 1].view(-1, 1, 1)) / norm

            # Squared normalised distance, one single-channel map per point.
            coords = (row_dist * row_dist + col_dist * col_dist).unsqueeze(1)

            # Padding points get a huge distance so they never win the min.
            coords[invalid_points, :, :, :] = 1e6

            # Nearest point within each group, then two groups per sample.
            coords = coords.view(-1, num_points, 1, rows, cols)
            coords = coords.min(dim=1)[0]
            coords = coords.view(-1, 2, rows, cols)

        # squared distance -> distance -> tanh(2 * d), done in place.
        coords.sqrt_().mul_(2).tanh_()

        return coords

    def forward(self, x, coords):
        """Return distance maps for ``coords`` sized like ``x``.

        ``x`` is only inspected for its batch size (``shape[0]``) and spatial
        size (``shape[2]``, ``shape[3]``); its values are not used.
        """
        return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])
|
|
|