File size: 1,870 Bytes
352cafd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import cv2
import numpy as np

import torch
import torch.nn.functional as F

def edge(masks, thickness=5):
    """Compute a binary boundary map for each mask using Sobel gradients.

    Each mask is zero-padded by ``thickness`` pixels on every side before
    filtering, so foreground that touches the image border is compared
    against zeros; the padding is cropped away again afterwards.

    Args:
        masks: Torch tensor of masks. Only index 0 along the second axis of
            each item is used (assumes an ``(N, 1, H, W)`` layout -- TODO
            confirm against callers). Values are cast to ``uint8``.
        thickness: Width in pixels of the temporary zero padding. It is not
            passed to the Sobel filter itself.

    Returns:
        ``uint8`` NumPy array holding one 2-D map per input mask, with 1
        where the blended absolute Sobel response is greater than zero and
        0 elsewhere. An empty batch yields an array whose leading dimension
        is 0.
    """
    masks = masks.cpu().detach().numpy().astype(np.uint8)
    if len(masks) == 0:
        # Nothing to filter; stacking an empty list would raise.
        return np.zeros((0,) + masks.shape[2:], dtype=np.uint8)

    bounds = []
    for mask in masks:
        rows, cols = mask[0].shape[:2]
        padded = np.pad(mask[0], thickness, 'constant', constant_values=0)
        sobel_x = cv2.Sobel(padded, cv2.CV_16S, 1, 0)
        sobel_y = cv2.Sobel(padded, cv2.CV_16S, 0, 1)
        abs_x = cv2.convertScaleAbs(sobel_x)
        abs_y = cv2.convertScaleAbs(sobel_y)
        bound = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
        # Crop with explicit end indices: a ``[t:-t]`` slice is empty when
        # ``thickness`` is 0.
        bound = bound[thickness:thickness + rows, thickness:thickness + cols]
        bounds.append((bound > 0).astype(np.uint8))
    return np.stack(bounds, axis=0)

def mask_losses(mask_logits, gt_masks, mask_weight=1.0, edge_weight=1.0):
    """Compute a per-pixel weighted binary cross-entropy loss on mask logits.

    Every pixel is weighted by ``mask_weight``. When ``edge_weight`` is
    greater than 1.0, pixels flagged by ``edge(gt_masks)`` are weighted by
    ``edge_weight`` instead.

    Args:
        mask_logits: Predicted mask logits; must hold the same number of
            elements per mask as ``gt_masks``.
        gt_masks: Ground-truth masks (assumes an ``(N, 1, H, W)`` layout, as
            axis 1 is squeezed to build the weights -- TODO confirm against
            callers). Any spatial size is accepted.
        mask_weight: Weight applied to every pixel.
        edge_weight: Weight applied to edge pixels; only used when it is
            greater than 1.0.

    Returns:
        Scalar tensor with the mean weighted BCE-with-logits loss, or a zero
        that is still connected to ``mask_logits`` when ``gt_masks`` is
        empty.
    """
    if len(gt_masks) == 0:
        # Contribute no loss while keeping the result tied to mask_logits.
        return mask_logits.sum() * 0

    # Create the weights as float32 directly: filling a tensor that inherits
    # an integer/bool dtype from gt_masks would truncate fractional weights.
    mask_weights = torch.full_like(
        gt_masks.squeeze(1), mask_weight, dtype=torch.float32
    ).detach()

    if edge_weight > 1.0:
        edges = edge(gt_masks)
        # Keep the index on the same device as the weights rather than
        # forcing CUDA, so CPU and non-default GPUs work too.
        edge_index = torch.from_numpy(edges).to(mask_weights.device) == 1
        mask_weights[edge_index] = edge_weight

    # Flatten each mask to one row regardless of its spatial resolution.
    num_masks = len(gt_masks)
    gt_masks = gt_masks.reshape(num_masks, -1).to(dtype=torch.float32)
    mask_logits = mask_logits.reshape(num_masks, -1).to(dtype=torch.float32)
    mask_weights = mask_weights.reshape(num_masks, -1)

    return F.binary_cross_entropy_with_logits(
        mask_logits, gt_masks, weight=mask_weights
    )