|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def evaluation(outputs, targets_data, meta_info, mode='val', thres=0.5):
    """Gather contact precision / recall / F1 for the first item of a batch.

    Args:
        outputs: Mapping holding a ``'contact_out'`` tensor; only index 0
            along its first axis is evaluated.
        targets_data: Mapping where ``['contact_data']['contact_h']`` is the
            ground-truth tensor; only index 0 is evaluated.
        meta_info: Mapping with a ``'mano_valid'`` entry. Metrics are
            computed only when that entry is not ``None``.
        mode: Accepted for interface compatibility; not read in this function.
        thres: Threshold forwarded to ``compute_contact_metrics``.

    Returns:
        A dict with keys ``'cont_pre'``, ``'cont_rec'`` and ``'cont_f1'``
        when ``meta_info['mano_valid']`` is not ``None``; otherwise an
        empty dict.
    """
    has_mesh = meta_info['mano_valid'] is not None

    # NOTE(review): this sigmoid-activated copy is never read below; the
    # metrics receive outputs['contact_out'][0] without the sigmoid applied
    # here. Confirm which of the two is intended before relying on `thres`.
    contact_prob = outputs['contact_out'].sigmoid()[0].detach().cpu().numpy()

    if not has_mesh:
        return {}

    gt_contact = targets_data['contact_data']['contact_h'][0].detach().cpu().numpy()
    pred_contact = outputs['contact_out'][0].detach().cpu().numpy()
    precision, recall, f1 = compute_contact_metrics(
        gt_contact, pred_contact, has_mesh, thres=thres
    )

    return {
        'cont_pre': precision,
        'cont_rec': recall,
        'cont_f1': f1,
    }
|
|
|
|
|
|
|
|
def compute_contact_metrics(gt, pred, valid, thres=0.5):
    """Compute precision, recall, and F1 for thresholded predictions using NumPy.

    Args:
        gt: Ground-truth labels, array-like with the same shape as ``pred``.
            True positives are counted as the sum of ``gt`` over the
            predicted-positive positions, so the values are expected to be
            0/1 (assumption -- confirm against the caller).
        pred: Predicted scores, array-like; an entry counts as a positive
            prediction when it is ``>= thres``.
        valid: When falsy, no metric is computed and all three results are
            ``None``.
        thres: Decision threshold applied to ``pred``.

    Returns:
        Tuple ``(precision, recall, f1)``.
        * ``precision`` is ``None`` when nothing is predicted positive.
        * ``recall`` is ``None`` when ``gt`` sums to zero.
        * ``f1`` is ``None`` when either of the above is ``None``, and
          ``0.0`` when both are defined but equal to zero (no true
          positives).
    """
    if not valid:
        return None, None, None

    gt = np.asarray(gt)
    pred = np.asarray(pred)

    # Compute the decision mask once and reuse it for both counts.
    pred_positive = pred >= thres

    tp_num = np.sum(gt[pred_positive])
    precision_denominator = np.sum(pred_positive)
    recall_denominator = np.sum(gt)

    precision_ = tp_num / precision_denominator if precision_denominator > 0 else None
    recall_ = tp_num / recall_denominator if recall_denominator > 0 else None

    if precision_ is None or recall_ is None:
        # An undefined component leaves F1 undefined as well.
        f1_ = None
    elif (precision_ + recall_) > 0:
        f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
    else:
        # Both components are well defined and zero: F1 is 0, not "missing".
        # Returning None here would drop fully wrong predictions from any
        # average that skips None values.
        f1_ = 0.0

    return precision_, recall_, f1_