| import torch |
| import torch.nn as nn |
|
|
|
|
|
|
# Module-wide default device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
class AddCoordsTh(nn.Module):
    """Append normalized coordinate channels to a feature map.

    The output is ``input_tensor`` with extra channels concatenated along
    dim 1, in this order:

    1. ``xx`` -- coordinate along dim 2, scaled to ``[-1, 1]``;
    2. ``yy`` -- coordinate along dim 3, scaled to ``[-1, 1]``;
    3. ``rr`` -- ``sqrt(xx**2 + yy**2)`` divided by its maximum
       (only when ``with_r`` is true);
    4. ``xx`` and ``yy`` masked by the boundary heatmap (only when
       ``with_boundary`` is true and a heatmap is passed to ``forward``).

    Args:
        x_dim: size of dim 2 of the tensors passed to ``forward``.
        y_dim: size of dim 3 of the tensors passed to ``forward``.
        with_r: also append the normalized radius channel.
        with_boundary: also append the boundary-masked coordinate channels
            when a heatmap is supplied.
    """

    # Heatmap values must exceed this to keep a coordinate in the
    # boundary-masked channels.
    _BOUNDARY_THRESHOLD = 0.05

    def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False):
        super(AddCoordsTh, self).__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.with_r = with_r
        self.with_boundary = with_boundary

    @staticmethod
    def _normalized_axis(size, device):
        """Return ``size`` float32 coordinates ``i / (size - 1) * 2 - 1``."""
        positions = torch.arange(size, dtype=torch.float32, device=device)
        # max(..., 1) keeps a one-pixel axis finite (it maps to -1) instead
        # of producing NaN from 0 / 0.
        return positions / max(size - 1, 1) * 2 - 1

    def forward(self, input_tensor, heatmap=None):
        """Concatenate the coordinate channels to ``input_tensor``.

        Args:
            input_tensor: tensor of shape (batch, c, x_dim, y_dim).
            heatmap: optional 4-D tensor; its last channel, clamped to
                ``[0, 1]``, is used as the boundary map. Ignored unless
                ``with_boundary`` is true.

        Returns:
            Tensor of shape (batch, c + extra, x_dim, y_dim), where
            ``extra`` is 2, plus 1 if ``with_r``, plus 2 if the boundary
            channels are added.
        """
        batch_size = input_tensor.shape[0]
        # Build the coordinates on the input's own device rather than a
        # module-level default, so CPU inputs on a CUDA-capable host and
        # inputs on a non-default GPU both concatenate correctly.
        target_device = input_tensor.device
        full_shape = (batch_size, 1, self.x_dim, self.y_dim)

        # xx varies along dim 2, yy along dim 3; broadcasting replaces the
        # explicit outer products and per-batch copies.
        xx_channel = self._normalized_axis(self.x_dim, target_device)
        xx_channel = xx_channel.view(1, 1, self.x_dim, 1).expand(full_shape)
        yy_channel = self._normalized_axis(self.y_dim, target_device)
        yy_channel = yy_channel.view(1, 1, 1, self.y_dim).expand(full_shape)

        channels = [input_tensor, xx_channel, yy_channel]

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
            channels.append(rr / torch.max(rr))

        if self.with_boundary and heatmap is not None:
            boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)
            on_boundary = boundary_channel > self._BOUNDARY_THRESHOLD
            # Keep coordinates only where the boundary map is active.
            channels.append(torch.where(on_boundary, xx_channel,
                                        torch.zeros_like(xx_channel)))
            channels.append(torch.where(on_boundary, yy_channel,
                                        torch.zeros_like(yy_channel)))

        return torch.cat(channels, dim=1)
|
|
|
|
class CoordConvTh(nn.Module):
    """CoordConv layer as in the paper.

    Runs ``AddCoordsTh`` on the input to append coordinate channels, then
    applies an ``nn.Conv2d`` to the augmented tensor.

    Args:
        x_dim, y_dim, with_r, with_boundary: forwarded to ``AddCoordsTh``.
        in_channels: channel count of the tensor passed to ``forward``
            (before any coordinate channels are added).
        first_one: when true, the convolution is not widened for the two
            boundary-masked channels even if ``with_boundary`` is set.
        *args, **kwargs: remaining ``nn.Conv2d`` arguments, starting with
            ``out_channels``.
    """

    def __init__(self, x_dim, y_dim, with_r, with_boundary,
                 in_channels, first_one=False, *args, **kwargs):
        super(CoordConvTh, self).__init__()
        self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r,
                                     with_boundary=with_boundary)
        # Widen the convolution input for the channels AddCoordsTh appends:
        # xx and yy always, rr when with_r, and the two boundary-masked
        # channels when with_boundary (except for the first layer).
        in_channels += 2
        if with_r:
            in_channels += 1
        if with_boundary and not first_one:
            in_channels += 2
        # in_channels is passed positionally: combining ``in_channels=...``
        # with a non-empty ``*args`` raises TypeError because the first
        # positional value would also bind to in_channels.
        self.conv = nn.Conv2d(in_channels, *args, **kwargs)

    def forward(self, input_tensor, heatmap=None):
        """Return ``(conv_output, last_channel)``.

        ``last_channel`` is the last two channels of the coordinate-augmented
        tensor, taken before the convolution is applied.
        """
        augmented = self.addcoords(input_tensor, heatmap)
        last_channel = augmented[:, -2:, :, :]
        return self.conv(augmented), last_channel
|
|
|
|
'''
An alternative PyTorch implementation that automatically infers the x-y dimensions.
'''
class AddCoords(nn.Module):
    """Append normalized coordinate channels, inferring sizes from the input.

    The output is ``input_tensor`` with an ``xx`` channel (coordinate along
    dim 2) and a ``yy`` channel (coordinate along dim 3) concatenated along
    dim 1, both scaled to ``[-1, 1]``. When ``with_r`` is true, a third
    channel ``sqrt((xx - 0.5)**2 + (yy - 0.5)**2)`` is appended as well.

    Args:
        with_r: also append the radius channel.
    """

    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)

        Returns:
            Tensor of shape (batch, channel + 2, x_dim, y_dim), or
            channel + 3 when ``with_r`` is true. The added channels are
            cast to the dtype of ``input_tensor``.
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()
        # Create the coordinates directly on the input's device so inputs on
        # any GPU (not only the default one) concatenate without a mismatch.
        target_device = input_tensor.device
        full_shape = (batch_size, 1, x_dim, y_dim)

        # max(..., 1) keeps a one-pixel axis finite instead of NaN (0 / 0).
        xx_channel = torch.arange(x_dim, device=target_device) / max(x_dim - 1, 1)
        yy_channel = torch.arange(y_dim, device=target_device) / max(y_dim - 1, 1)

        # Rescale from [0, 1] to [-1, 1].
        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        # xx varies along dim 2, yy along dim 3; broadcast over the rest.
        xx_channel = xx_channel.view(1, 1, x_dim, 1).expand(full_shape)
        yy_channel = yy_channel.view(1, 1, 1, y_dim).expand(full_shape)

        channels = [
            input_tensor,
            xx_channel.to(input_tensor.dtype),
            yy_channel.to(input_tensor.dtype),
        ]

        if self.with_r:
            # NOTE(review): the 0.5 offset is kept exactly as in the original
            # code even though the coordinates span [-1, 1] -- confirm intent
            # before changing, since it alters the channel's values.
            rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)
                            + torch.pow(yy_channel - 0.5, 2))
            # Cast like the xx/yy channels so the output dtype follows the input.
            channels.append(rr.to(input_tensor.dtype))

        return torch.cat(channels, dim=1)
|
|
|
|
class CoordConv(nn.Module):
    """Convolution applied after appending coordinate channels.

    Runs ``AddCoords`` on the input, then an ``nn.Conv2d`` on the result.

    Args:
        in_channels: channel count of the tensor passed to ``forward``
            (before the coordinate channels are added).
        out_channels: channel count produced by the convolution.
        with_r: forwarded to ``AddCoords``; adds one extra radius channel.
        **kwargs: remaining ``nn.Conv2d`` keyword arguments
            (e.g. ``kernel_size``).
    """

    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        # AddCoords appends xx and yy, plus rr when with_r is set; the
        # convolution must accept all of them or forward() fails with a
        # channel-count mismatch.
        extra_channels = 3 if with_r else 2
        self.conv = nn.Conv2d(in_channels + extra_channels, out_channels,
                              **kwargs)

    def forward(self, x):
        """Return the convolution of ``x`` augmented with coordinate channels."""
        return self.conv(self.addcoords(x))
|
|