|
|
import cv2 |
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
def edge(masks, thickness=5):
    """Compute a binary boundary map for every mask in a batch.

    Each mask is zero-padded, filtered with a horizontal and a vertical
    Sobel operator, the two absolute responses are blended 50/50, the
    padding is cropped away again and every pixel with a non-zero blended
    response is marked as boundary.

    Args:
        masks: Torch tensor holding the masks. Each item is indexed as
            ``item[0]`` to obtain the 2-D mask, so the layout is presumably
            ``(N, 1, H, W)`` -- confirm against callers. Values are cast to
            ``uint8`` before filtering.
        thickness: Number of zero pixels padded around each mask before
            filtering and removed afterwards. It does not alter the Sobel
            kernel size. ``0`` disables the padding.

    Returns:
        ``uint8`` NumPy array of shape ``(N, H, W)`` containing 1 on
        boundary pixels and 0 elsewhere. An empty batch yields an empty
        array instead of raising.
    """
    masks = masks.detach().cpu().numpy().astype(np.uint8)

    # np.stack / np.concatenate reject an empty list, so short-circuit.
    if len(masks) == 0:
        return np.zeros((0,) + masks.shape[2:], dtype=np.uint8)

    bounds = []
    for mask in masks:
        plane = mask[0]
        height, width = plane.shape[:2]

        # Zero padding puts background around the mask, presumably so that
        # a mask touching the image border still produces a response there.
        padded = np.pad(plane, thickness, 'constant', constant_values=0)

        # CV_16S keeps the negative gradient values that uint8 would clip.
        sobel_x = cv2.Sobel(padded, cv2.CV_16S, 1, 0)
        sobel_y = cv2.Sobel(padded, cv2.CV_16S, 0, 1)
        abs_x = cv2.convertScaleAbs(sobel_x)
        abs_y = cv2.convertScaleAbs(sobel_y)
        bound = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)

        # Crop with explicit end indices: ``[t:-t]`` is empty when t == 0.
        bound = bound[thickness:thickness + height,
                      thickness:thickness + width]
        bounds.append((bound > 0).astype(np.uint8))

    return np.stack(bounds, axis=0)
|
|
|
|
|
def mask_losses(mask_logits, gt_masks, mask_weight=1.0, edge_weight=1.0):
    """Weighted binary cross-entropy between mask logits and target masks.

    Every pixel is weighted by ``mask_weight``. When ``edge_weight`` is
    greater than 1.0, pixels that ``edge`` marks as boundary of the target
    masks are weighted by ``edge_weight`` instead.

    Args:
        mask_logits: Raw (pre-sigmoid) predictions. Must hold the same
            number of elements as ``gt_masks``.
        gt_masks: Target masks. The edge path indexes the weights with the
            ``(N, H, W)`` output of ``edge``, so the layout is presumably
            ``(N, 1, H, W)`` -- confirm against callers. Any spatial size
            is accepted.
        mask_weight: Per-pixel weight for ordinary pixels.
        edge_weight: Per-pixel weight for boundary pixels; only applied
            when strictly greater than 1.0.

    Returns:
        Scalar tensor with the mean weighted loss. When ``gt_masks`` is
        empty, a zero that is still connected to ``mask_logits``.
    """
    if len(gt_masks) == 0:
        # Multiplying by zero keeps the result attached to mask_logits.
        return mask_logits.sum() * 0

    # Request float32 up front: full_like otherwise inherits the dtype of
    # gt_masks and would truncate a fractional weight for integer masks.
    mask_weights = torch.full_like(
        gt_masks.squeeze(1), mask_weight, dtype=torch.float32
    ).detach()

    if edge_weight > 1.0:
        # Place the edge map on the same device as the weights rather than
        # assuming CUDA is available.
        edges = torch.as_tensor(edge(gt_masks), device=mask_weights.device)
        mask_weights[edges == 1] = edge_weight

    # The default reduction is a mean over all elements, so a flat view
    # gives the same loss as any (N, H*W) layout, whatever the mask size.
    targets = gt_masks.reshape(-1).to(dtype=torch.float32)
    logits = mask_logits.reshape(-1).to(dtype=torch.float32)
    weights = mask_weights.reshape(-1)

    return F.binary_cross_entropy_with_logits(logits, targets, weight=weights)
|
|
|