Spaces:
Runtime error
Runtime error
| import native_rasterizer | |
| import torch | |
| import torch.nn as nn | |
| from torch.autograd import Function | |
# Names of the rasterization modes understood by this module.
MODE_BOUNDARY = "boundary"
MODE_MASK = "mask"
MODE_HARD_MASK = "hard_mask"

# Integer code for each mode name (boundary -> 0, mask -> 1, hard_mask -> 2);
# this is the value handed to the native rasterizer.
MODE_MAPPING = {
    mode_name: mode_code
    for mode_code, mode_name in enumerate((MODE_BOUNDARY, MODE_MASK, MODE_HARD_MASK))
}
class SoftPolygonFunction(Function):
    """Autograd wrapper around the ``native_rasterizer`` extension.

    ``forward`` rasterizes a batch of polygons into a ``(batch, height, width)``
    float map; ``backward`` asks the extension for the gradient with respect
    to the vertices. Use it through ``SoftPolygonFunction.apply(...)``.
    """

    @staticmethod
    def forward(ctx, vertices, width, height, inv_smoothness=1.0, mode=MODE_BOUNDARY):
        """Rasterize ``vertices`` with the native extension.

        Args:
            ctx: autograd context; receives the settings needed by ``backward``.
            vertices: polygon vertices; the first two dimensions are read as
                (batch, number of vertices) -- presumably ``N x P x 2``,
                TODO confirm against the native kernel.
            width: output width in pixels.
            height: output height in pixels.
            inv_smoothness: smoothness setting forwarded unchanged to the
                native kernel.
            mode: one of the keys of ``MODE_MAPPING``.

        Returns:
            The ``(batch, height, width)`` tensor produced by
            ``native_rasterizer.forward_rasterize``.

        Raises:
            KeyError: if ``mode`` is not a key of ``MODE_MAPPING``.
        """
        ctx.width = width
        ctx.height = height
        ctx.inv_smoothness = inv_smoothness
        ctx.mode = MODE_MAPPING[mode]

        # Work on a private copy so the caller's tensor is never handed to
        # the native code.
        vertices = vertices.clone()
        ctx.device = vertices.device
        ctx.batch_size, ctx.number_vertices = vertices.shape[:2]

        # Allocate the output buffers directly on the target device instead of
        # building them on the CPU with the legacy typed constructors and
        # copying them over afterwards.
        buffer_shape = (ctx.batch_size, ctx.height, ctx.width)
        rasterized = torch.zeros(buffer_shape, dtype=torch.float32, device=ctx.device)
        contribution_map = torch.zeros(buffer_shape, dtype=torch.int32, device=ctx.device)

        rasterized, contribution_map = native_rasterizer.forward_rasterize(
            vertices, rasterized, contribution_map, width, height, inv_smoothness, ctx.mode
        )
        ctx.save_for_backward(vertices, rasterized, contribution_map)

        # Only the raster is returned; the contribution map is kept for backward.
        return rasterized

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradient of the raster with respect to the vertices.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient flowing into the rasterized output.

        Returns:
            A 5-tuple matching the inputs of ``forward``: the
            ``(batch, number_vertices, 2)`` vertex gradient followed by
            ``None`` for the four non-tensor arguments.
        """
        vertices, rasterized, contribution_map = ctx.saved_tensors

        # Made contiguous before being handed to the native kernel.
        grad_output = grad_output.contiguous()

        grad_vertices = torch.zeros(
            (ctx.batch_size, ctx.number_vertices, 2), dtype=torch.float32, device=ctx.device
        )
        grad_vertices = native_rasterizer.backward_rasterize(
            vertices,
            rasterized,
            contribution_map,
            grad_output,
            grad_vertices,
            ctx.width,
            ctx.height,
            ctx.inv_smoothness,
            ctx.mode,
        )
        return grad_vertices, None, None, None, None
class SoftPolygon(nn.Module):
    """``nn.Module`` front end for ``SoftPolygonFunction``.

    Stores a smoothness setting and a rasterization mode at construction time
    and applies them on every forward call.
    """

    # Mode names accepted by the constructor.
    MODES = [MODE_BOUNDARY, MODE_MASK, MODE_HARD_MASK]

    def __init__(self, inv_smoothness=1.0, mode=MODE_BOUNDARY):
        """Store the settings.

        Raises:
            ValueError: if ``mode`` is not listed in ``SoftPolygon.MODES``.
        """
        super().__init__()
        self.inv_smoothness = inv_smoothness

        mode_is_known = mode in SoftPolygon.MODES
        if not mode_is_known:
            raise ValueError("invalid mode: {0}".format(mode))
        self.mode = mode

    def forward(self, vertices, width, height, p, color=False):
        """Rasterize ``vertices`` into a ``width`` x ``height`` image.

        ``p`` and ``color`` are accepted but not used by this method.
        """
        call_args = (vertices, width, height, self.inv_smoothness, self.mode)
        return SoftPolygonFunction.apply(*call_args)
def pnp(vertices, width, height):
    """Point-in-polygon test for every integer pixel of a ``height`` x ``width`` grid.

    Uses the even-odd (ray casting) rule: a pixel is inside when a ray cast
    from it towards +x crosses the polygon outline an odd number of times.

    Args:
        vertices: tensor indexed as ``[batch, vertex, 0]`` for x and
            ``[batch, vertex, 1]`` for y; consecutive vertices (and the last
            with the first) form the polygon edges.
        width: number of pixel columns (x runs over ``0 .. width - 1``).
        height: number of pixel rows (y runs over ``0 .. height - 1``).

    Returns:
        A ``(batch, height, width)`` tensor in the default floating dtype,
        holding ``1.0`` for pixels inside the polygon and ``-1.0`` for pixels
        outside it.
    """
    device = vertices.device
    batch_size = vertices.size(0)
    polygon_dimension = vertices.size(1)

    # Pixel coordinates kept as broadcastable row/column vectors rather than a
    # materialised (batch, height, width) grid.
    yp = torch.arange(0, height, device=device).float()[None, :, None]
    xp = torch.arange(0, width, device=device).float()[None, None, :]

    inside = torch.zeros((batch_size, height, width), dtype=torch.bool, device=device)

    previous = polygon_dimension - 1
    for current in range(polygon_dimension):
        # Edge end points, shaped (batch, 1, 1) so they broadcast over the grid.
        from_x = vertices[:, current, 0][:, None, None]
        from_y = vertices[:, current, 1][:, None, None]
        to_x = vertices[:, previous, 0][:, None, None]
        to_y = vertices[:, previous, 1][:, None, None]

        # The edge straddles the pixel's row ...
        straddles_row = (from_y > yp) != (to_y > yp)
        # ... and meets that row to the right of the pixel. For horizontal
        # edges the division yields inf/nan, but straddles_row is False there,
        # so those values never affect the result.
        edge_x_at_row = (to_x - from_x) * (yp - from_y) / (to_y - from_y) + from_x
        crossing = torch.logical_and(straddles_row, xp < edge_x_at_row)

        # Every crossing flips the parity of the pixel.
        inside = torch.logical_xor(inside, crossing)
        previous = current

    signed_result = torch.full((batch_size, height, width), -1.0, device=device)
    signed_result[inside] = 1.0
    return signed_result
class SoftPolygonPyTorch(nn.Module):
    """Pure-PyTorch soft polygon rasterizer; used for verification purposes only."""

    def __init__(self, inv_smoothness=1.0):
        super().__init__()
        self.inv_smoothness = inv_smoothness

    def forward(self, vertices, width, height, p, color=False):
        """Rasterize ``vertices`` (N x P x 2) into an N x height x width map.

        For each pixel, the squared distance to its closest polygon edge is
        multiplied by the +1/-1 inside/outside sign from ``pnp``, divided by
        ``self.inv_smoothness`` and passed through a sigmoid.

        ``p`` and ``color`` are accepted but not used by this method.
        """
        device = vertices.device
        num_polygons = vertices.size(0)
        num_vertices = vertices.size(1)

        sign_map = pnp(vertices, width, height)

        # Discrete pixel coordinates that every edge is evaluated against.
        row_index = torch.arange(0, height).to(device)
        col_index = torch.arange(0, width).to(device)
        rows, cols = torch.meshgrid(row_index, col_index)
        pixel_x = cols.float().unsqueeze(0).repeat(num_polygons, 1, 1)
        pixel_y = rows.float().unsqueeze(0).repeat(num_polygons, 1, 1)

        negated_sq_distances = []
        soft_values = []
        for start in range(num_vertices):
            # The last vertex connects back to the first one.
            end = (start + 1) % num_vertices

            # Edge end points, shaped (N, 1, 1) so they broadcast over the grid.
            start_x = vertices[:, start, 0].unsqueeze(-1).unsqueeze(-1)
            start_y = vertices[:, start, 1].unsqueeze(-1).unsqueeze(-1)
            end_x = vertices[:, end, 0].unsqueeze(-1).unsqueeze(-1)
            end_y = vertices[:, end, 1].unsqueeze(-1).unsqueeze(-1)

            edge_dx = end_x - start_x
            edge_dy = end_y - start_y
            # Small constant keeps the division below finite for zero-length edges.
            edge_length_sq = edge_dx * edge_dx + edge_dy * edge_dy + 0.00001

            from_start_x = pixel_x - start_x
            from_start_y = pixel_y - start_y
            from_end_x = pixel_x - end_x
            from_end_y = pixel_y - end_y

            # Normalised position of the pixel's projection along the edge:
            # below 0 is before the start vertex, above 1 is past the end vertex.
            t = ((from_start_x * edge_dx) + (from_start_y * edge_dy)) / edge_length_sq

            # Offset from the projected point on the edge's line to the pixel
            # (only used where the projection lands on the edge itself).
            offset_x = pixel_x - (start_x + t * edge_dx)
            offset_y = pixel_y - (start_y + t * edge_dy)

            nearest_is_start = t < 0
            nearest_is_end = t > 1
            nearest_is_interior = (t >= 0) & (t <= 1)

            sq_distance = torch.zeros((num_polygons, height, width)).to(device)
            sq_distance[nearest_is_start] = (
                from_start_x[nearest_is_start] ** 2 + from_start_y[nearest_is_start] ** 2
            )
            sq_distance[nearest_is_end] = (
                from_end_x[nearest_is_end] ** 2 + from_end_y[nearest_is_end] ** 2
            )
            sq_distance[nearest_is_interior] = (
                offset_x[nearest_is_interior] ** 2 + offset_y[nearest_is_interior] ** 2
            )

            # Stored negated so that the closest edge is the arg-max below.
            negated = -sq_distance
            negated_sq_distances.append(negated)
            soft_values.append(torch.sigmoid(-negated * sign_map / self.inv_smoothness))

        # Per pixel, keep the soft value belonging to the closest edge.
        _, closest_edge = torch.max(torch.stack(negated_sq_distances, dim=-1), dim=-1)
        stacked_values = torch.stack(soft_values, dim=-1)
        return torch.gather(stacked_values, dim=-1, index=closest_edge.unsqueeze(-1)).squeeze(-1)