# Provenance (header of the hosting page this file was copied from):
#   commit 352cafd by drozdgk:
#   "chore: vendor third_party (remove submodules, ignore artifacts)"
#   original file size: 2.58 kB
import torch.nn.functional as F
from util.util import compute_tensor_iu
def get_new_iou_hook(values, size):
    """Return the IoU of the new prediction at resolution *size*.

    Reads the accumulated intersection 'iou/new_i_<size>' and union
    'iou/new_u_<size>' from *values* and returns the pair
    ('iou/new_iou_<size>', intersection / union). The division is not
    guarded, so the caller must ensure the union is non-zero.
    """
    metric_name = 'iou/new_iou_%s' % size
    intersection = values['iou/new_i_%s' % size]
    union = values['iou/new_u_%s' % size]
    return metric_name, intersection / union
def get_orig_iou_hook(values):
    """Return the IoU of the original segmentation.

    Reads the accumulated intersection 'iou/orig_i' and union 'iou/orig_u'
    from *values* and returns ('iou/orig_iou', intersection / union). The
    division is not guarded, so the caller must ensure the union is non-zero.
    """
    metric_name = 'iou/orig_iou'
    intersection = values['iou/orig_i']
    union = values['iou/orig_u']
    return metric_name, intersection / union
def get_iou_gain(values, size):
    """Return how much the new IoU at *size* improves on the original IoU.

    Reads 'iou/new_iou_<size>' and 'iou/orig_iou' from *values*. These are
    the names produced by get_new_iou_hook and get_orig_iou_hook. Returns
    ('iou/iou_gain_<size>', new_iou - orig_iou).
    """
    metric_name = 'iou/iou_gain_%s' % size
    new_iou = values['iou/new_iou_%s' % size]
    orig_iou = values['iou/orig_iou']
    return metric_name, new_iou - orig_iou
# IoU hooks: each takes the dict of accumulated values and returns a
# (name, value) pair. The order is presumably significant (confirm against
# the caller): get_iou_gain reads 'iou/new_iou_224' and 'iou/orig_iou',
# which are the names returned by the two hooks listed before it.
iou_hooks_to_be_used = [
    get_orig_iou_hook,
    lambda x: get_new_iou_hook(x, '224'),
    lambda x: get_iou_gain(x, '224'),
]
# The final-only evaluation uses the same hooks. It is a separate list
# object, so mutating one list does not affect the other.
iou_hooks_final_only = list(iou_hooks_to_be_used)
# Compute the common loss terms and IoU metrics, for the generator only.
def compute_loss_and_metrics(images, para, detailed=True, need_loss=True, has_lower_res=True):
    """Compute the generator's losses and IoU statistics for one batch.

    Args:
        images: dict of tensors. Always read: 'gt', 'seg' and 'pred_224'.
            Read only when ``need_loss`` is True: 'out_224' (passed as logits
            to ``binary_cross_entropy_with_logits``), 'gt_sobel' and
            'pred_sobel'.
        para: dict of loss weights, read only when ``need_loss`` is True:
            'ce_weight', 'l1_weight', 'l2_weight' and 'grad_weight'.
        detailed: unused here; kept for call-site compatibility.
        need_loss: when False, only the IoU statistics are computed.
        has_lower_res: unused here; kept for call-site compatibility.

    Returns:
        dict that always holds the intersection/union totals returned by
        ``compute_tensor_iu``: 'iou/orig_i', 'iou/orig_u' (``seg`` vs ``gt``)
        and 'iou/new_i_224', 'iou/new_u_224' (``pred_224`` vs ``gt``), with
        every input thresholded at 0.5. When ``need_loss`` is True it also
        holds 'grad_loss', 'total_loss', 'ce_loss', 'l1_loss', 'l2_loss' and
        the weighted sum 'loss' ('total_loss' has the same value as 'loss').
    """
    loss_and_metrics = {}
    gt = images['gt']
    seg = images['seg']
    pred_224 = images['pred_224']

    if need_loss:
        ce_weight = para['ce_weight']
        l1_weight = para['l1_weight']
        l2_weight = para['l2_weight']

        # CE uses the binarised ground truth; L1/L2 regress the raw values.
        ce_loss = F.binary_cross_entropy_with_logits(images['out_224'], (gt > 0.5).float())
        l1_loss = F.l1_loss(pred_224, gt)
        l2_loss = F.mse_loss(pred_224, gt)
        # L1 between the Sobel maps, in PyTorch's (input, target) order.
        # L1 is symmetric, so the value matches the previous argument order.
        loss_and_metrics['grad_loss'] = F.l1_loss(images['pred_sobel'], images['gt_sobel'])

        # Weighted sum of all loss terms.
        loss = ce_loss * ce_weight + l1_loss * l1_weight + l2_loss * l2_weight
        loss += loss_and_metrics['grad_loss'] * para['grad_weight']

    # IoU statistics: raw intersection/union totals, turned into ratios
    # later by the iou hooks.
    orig_total_i, orig_total_u = compute_tensor_iu(seg > 0.5, gt > 0.5)
    loss_and_metrics['iou/orig_i'] = orig_total_i
    loss_and_metrics['iou/orig_u'] = orig_total_u
    new_total_i, new_total_u = compute_tensor_iu(pred_224 > 0.5, gt > 0.5)
    loss_and_metrics['iou/new_i_224'] = new_total_i
    loss_and_metrics['iou/new_u_224'] = new_total_u

    if need_loss:
        # 'total_loss' is inserted first to keep the original key order.
        # ``0 + loss`` builds a separate tensor, so in-place accumulation
        # into 'total_loss' will not also change 'loss'.
        loss_and_metrics['total_loss'] = 0 + loss
        loss_and_metrics['ce_loss'] = ce_loss
        loss_and_metrics['l1_loss'] = l1_loss
        loss_and_metrics['l2_loss'] = l2_loss
        loss_and_metrics['loss'] = loss

    return loss_and_metrics