| | import torch |
| | from torch.nn import functional as F |
| |
|
| |
|
def generate_edge_tensor(label, edge_width=3, ignore_index=255, device=None):
    """Build a binary edge map from a segmentation label tensor.

    A pixel is marked as an edge when its label differs from one of four
    neighbours (the pixel above it, to its right, diagonally down-right and
    diagonally down-left) and neither of the two compared labels equals
    ``ignore_index``. The marks are then dilated with an
    ``edge_width x edge_width`` square window so the edges are thickened.

    Args:
        label: Label map shaped ``(H, W)`` or ``(N, H, W)``. It is cast to
            ``float32`` before comparison.
        edge_width: Side length of the square dilation window (>= 1). The
            output keeps the input's spatial size for every width, odd or
            even.
        ignore_index: Label value excluded from edge detection.
        device: Device used for the computation and the result. ``None``
            keeps a CUDA input on its own device, moves a non-CUDA input to
            the current CUDA device when one is available, and otherwise
            stays on ``label.device``.

    Returns:
        A ``float32`` tensor of 0s and 1s shaped ``(H, W)`` for a 2-D input,
        or ``(N, H, W)`` for a 3-D input (the batch dim is kept even if N == 1).

    Raises:
        ValueError: If ``edge_width`` is smaller than 1.
    """
    if edge_width < 1:
        raise ValueError(f"edge_width must be >= 1, got {edge_width}")
    if device is None:
        # Results have always lived on CUDA; keep that when a GPU exists, but
        # never move a CUDA tensor to a different GPU, and work on CPU-only hosts.
        if label.is_cuda or not torch.cuda.is_available():
            device = label.device
        else:
            device = torch.device("cuda")

    was_2d = label.dim() == 2
    with torch.no_grad():
        label = label.to(device=device, dtype=torch.float)
        if was_2d:
            label = label.unsqueeze(0)
        n, h, w = label.shape
        edge = torch.zeros((n, h, w), dtype=torch.float, device=device)

        # Each entry: (rows, cols) of the pixel that gets marked,
        #             (rows, cols) of the neighbour it is compared against.
        neighbour_pairs = (
            # neighbour above
            ((slice(1, h), slice(None)), (slice(None, h - 1), slice(None))),
            # neighbour to the right
            ((slice(None), slice(None, w - 1)), (slice(None), slice(1, w))),
            # neighbour diagonally down-right
            ((slice(None, h - 1), slice(None, w - 1)), (slice(1, h), slice(1, w))),
            # neighbour diagonally down-left
            ((slice(None, h - 1), slice(1, w)), (slice(1, h), slice(None, w - 1))),
        )
        for (rows, cols), (n_rows, n_cols) in neighbour_pairs:
            here = label[:, rows, cols]
            there = label[:, n_rows, n_cols]
            boundary = (here != there) & (here != ignore_index) & (there != ignore_index)
            # Basic slicing yields a view, so the masked write lands in ``edge``.
            edge[:, rows, cols][boundary] = 1

        # Dilate with a box filter: any non-zero window sum means an edge lies
        # within reach. Pad asymmetrically so even widths also preserve H x W.
        before = (edge_width - 1) // 2
        after = edge_width - 1 - before
        padded = F.pad(edge.unsqueeze(1), (before, after, before, after))
        kernel = torch.ones(
            (1, 1, edge_width, edge_width), dtype=torch.float, device=device
        )
        edge = (F.conv2d(padded, kernel) != 0).float().squeeze(1)

    # Only drop the batch dim we added ourselves; never squeeze H or W.
    return edge.squeeze(0) if was_2d else edge
| |
|