|
|
import torch.nn.functional as F |
|
|
|
|
|
from util.util import compute_tensor_iu |
|
|
|
|
|
def get_new_iou_hook(values, size):
    """
    Build the (name, value) logging pair for the new IoU at one size.

    values: mapping holding 'iou/new_i_<size>' and 'iou/new_u_<size>'.
    size: suffix substituted into the key names (e.g. '224').

    Returns ('iou/new_iou_<size>', new_i / new_u).
    """
    name = 'iou/new_iou_%s' % size
    intersection = values['iou/new_i_%s' % size]
    union = values['iou/new_u_%s' % size]
    return name, intersection / union
|
|
|
|
|
def get_orig_iou_hook(values):
    """
    Build the (name, value) logging pair for the original IoU.

    values: mapping holding 'iou/orig_i' and 'iou/orig_u'.

    Returns ('iou/orig_iou', orig_i / orig_u).
    """
    intersection = values['iou/orig_i']
    union = values['iou/orig_u']
    return 'iou/orig_iou', intersection / union
|
|
|
|
|
def get_iou_gain(values, size):
    """
    Build the (name, value) logging pair for the IoU gain at one size.

    values: mapping that already holds 'iou/new_iou_<size>' and
        'iou/orig_iou'.
    size: suffix substituted into the key names (e.g. '224').

    Returns ('iou/iou_gain_<size>', new_iou - orig_iou).
    """
    name = 'iou/iou_gain_%s' % size
    new_iou = values['iou/new_iou_%s' % size]
    return name, new_iou - values['iou/orig_iou']
|
|
|
|
|
# Each hook takes the values mapping and returns a (name, value) pair.
# NOTE(review): get_iou_gain reads 'iou/orig_iou' and 'iou/new_iou_224', so the
# consumer presumably stores each hook's result back into the mapping in list
# order -- confirm against the caller before reordering.
iou_hooks_to_be_used = [
    get_orig_iou_hook,
    lambda values: get_new_iou_hook(values, '224'),
    lambda values: get_iou_gain(values, '224'),
]
|
|
|
|
|
# Hooks covering only the 224 output; each takes the values mapping and
# returns a (name, value) pair.
# NOTE(review): get_iou_gain reads keys named after the two earlier hooks'
# outputs, so list order presumably matters to the consumer -- confirm.
iou_hooks_final_only = [
    get_orig_iou_hook,
    lambda values: get_new_iou_hook(values, '224'),
    lambda values: get_iou_gain(values, '224'),
]
|
|
|
|
|
|
|
|
def compute_loss_and_metrics(images, para, detailed=True, need_loss=True, has_lower_res=True):
    """
    Compute the loss terms and IoU statistics for the generator.

    images: mapping read for 'gt', 'seg' and 'pred_224'; when need_loss is
        true, also for 'out_224', 'gt_sobel' and 'pred_sobel'.
    para: mapping read (only when need_loss is true) for 'ce_weight',
        'l1_weight', 'l2_weight' and 'grad_weight'.
    detailed: accepted for interface compatibility; not used in this body.
    need_loss: when false, only the IoU statistics are produced.
    has_lower_res: accepted for interface compatibility; not used in this body.

    Returns a dict that always holds 'iou/orig_i', 'iou/orig_u',
    'iou/new_i_224' and 'iou/new_u_224'. When need_loss is true it also holds
    'grad_loss', 'total_loss', 'ce_loss', 'l1_loss', 'l2_loss' and 'loss'.
    """
    gt = images['gt']
    seg = images['seg']
    pred_224 = images['pred_224']

    loss_and_metrics = {}

    if need_loss:
        # Read every weight up front, before any loss is evaluated.
        ce_weights = para['ce_weight']
        l1_weights = para['l1_weight']
        l2_weights = para['l2_weight']

        # Cross-entropy is taken on the raw 'out_224' output against the
        # binarised ground truth; L1/L2 compare 'pred_224' with gt directly.
        binary_gt = (gt > 0.5).float()
        ce_loss = F.binary_cross_entropy_with_logits(images['out_224'], binary_gt)
        l1_loss = F.l1_loss(pred_224, gt)
        l2_loss = F.mse_loss(pred_224, gt)

        grad_loss = F.l1_loss(images['gt_sobel'], images['pred_sobel'])
        loss_and_metrics['grad_loss'] = grad_loss

        # Weighted sum of the three terms, then the gradient term on top.
        loss = ce_loss * ce_weights + l1_loss * l1_weights + l2_loss * l2_weights
        loss += grad_loss * para['grad_weight']

    # IoU statistics: intersection/union of the thresholded input segmentation
    # and of the thresholded prediction, each against the thresholded gt.
    loss_and_metrics['iou/orig_i'], loss_and_metrics['iou/orig_u'] = \
        compute_tensor_iu(seg > 0.5, gt > 0.5)
    loss_and_metrics['iou/new_i_224'], loss_and_metrics['iou/new_u_224'] = \
        compute_tensor_iu(pred_224 > 0.5, gt > 0.5)

    # Gather the loss terms for logging; insertion order is kept as-is.
    if need_loss:
        loss_and_metrics.update({
            'total_loss': 0 + loss,
            'ce_loss': ce_loss,
            'l1_loss': l1_loss,
            'l2_loss': l2_loss,
            'loss': loss,
        })

    return loss_and_metrics
|
|
|
|
|
|