drozdgk's picture
chore: vendor third_party (remove submodules, ignore artifacts)
352cafd
raw
history blame
1.87 kB
import cv2
import numpy as np
import torch
import torch.nn.functional as F
def edge(masks, thickness=5):
    """Return a binary boundary map for every mask in a batch.

    Each mask is zero-padded by ``thickness`` pixels on every side, so that
    foreground touching the image border still yields a boundary there.
    A Sobel filter (OpenCV default 3x3 kernel) is then applied along x and y.
    A pixel is marked as boundary when the blend
    ``cv2.addWeighted(|Gx|, 0.5, |Gy|, 0.5, 0)`` (computed in uint8) is
    non-zero. ``thickness`` only controls the padding; it does not widen the
    detected boundary.

    Args:
        masks: torch tensor of shape (N, 1, H, W) or (N, H, W). Values are
            cast to ``uint8``, so presumably binary 0/1 masks (confirm with
            callers). For a 4-D input only channel 0 is used.
        thickness: number of zero pixels padded on each side before
            filtering. Must be >= 0.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` and shape (N, H, W), holding 1 on
        boundary pixels and 0 elsewhere. An empty batch yields shape (0, H, W).

    Raises:
        ValueError: if ``masks`` is not 3-D or 4-D, or ``thickness`` < 0.
    """
    masks = masks.cpu().detach().numpy().astype(np.uint8)
    if masks.ndim == 4:
        # (N, C, H, W) -> (N, H, W): keep channel 0, as the original mask[0] did.
        masks = masks[:, 0]
    if masks.ndim != 3:
        raise ValueError(
            "edge() expects masks of shape (N, 1, H, W) or (N, H, W), "
            "got {} dims".format(masks.ndim))
    if thickness < 0:
        raise ValueError("thickness must be >= 0, got {}".format(thickness))

    num, height, width = masks.shape
    bounds = np.zeros((num, height, width), dtype=np.uint8)
    for i, mask in enumerate(masks):
        padded = np.pad(mask, thickness, 'constant', constant_values=0)
        abs_x = cv2.convertScaleAbs(cv2.Sobel(padded, cv2.CV_16S, 1, 0))
        abs_y = cv2.convertScaleAbs(cv2.Sobel(padded, cv2.CV_16S, 0, 1))
        bound = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
        # Crop back with explicit end indices: a `-thickness` stop would be
        # `-0` (an empty slice) when thickness == 0.
        bound = bound[thickness:thickness + height, thickness:thickness + width]
        bounds[i] = bound > 0
    return bounds
def mask_losses(mask_logits, gt_masks, mask_weight=1.0, edge_weight=1.0):
    """Per-pixel weighted binary cross-entropy between mask logits and targets.

    Every pixel gets weight ``mask_weight``. When ``edge_weight > 1.0``,
    pixels that ``edge(gt_masks)`` marks as mask boundary get weight
    ``edge_weight`` instead. An ``edge_weight`` <= 1.0 disables edge
    weighting entirely.

    Args:
        mask_logits: raw (pre-sigmoid) predictions with the same number of
            elements as ``gt_masks``; presumably shape (N, 1, H, W) (confirm
            with callers).
        gt_masks: target masks, (N, 1, H, W) or (N, H, W). Any dtype; values
            are converted to float32 targets.
        mask_weight: weight applied to every pixel.
        edge_weight: weight applied to boundary pixels when greater than 1.0.

    Returns:
        Scalar tensor: the mean over all pixels of the weighted
        ``binary_cross_entropy_with_logits``. For an empty batch, a zero that
        stays attached to the graph of ``mask_logits``.
    """
    if len(gt_masks) == 0:
        return mask_logits.sum() * 0

    # Force float32 explicitly: full_like() would inherit a bool/integer dtype
    # from gt_masks and silently truncate a fractional mask_weight.
    mask_weights = torch.full(gt_masks.squeeze(1).shape, mask_weight,
                              dtype=torch.float32, device=gt_masks.device)
    if edge_weight > 1.0:
        # edge() returns a CPU numpy array. Move it to wherever the weights
        # live, instead of assuming the default CUDA device.
        edges = torch.from_numpy(edge(gt_masks)).to(
            device=mask_weights.device, dtype=torch.bool)
        mask_weights[edges] = edge_weight

    # Flatten everything instead of assuming a fixed side length. BCE with the
    # default 'mean' reduction is element-wise, so the layout does not change
    # the result as long as all three tensors are flattened the same way.
    gt_flat = gt_masks.reshape(-1).to(dtype=torch.float32)
    logits_flat = mask_logits.reshape(-1).to(dtype=torch.float32)
    weights_flat = mask_weights.reshape(-1)
    return F.binary_cross_entropy_with_logits(logits_flat, gt_flat,
                                              weight=weights_flat)