HACO / lib /utils /eval_utils.py
dqj5182's picture
init
5732928
import torch
import numpy as np
def evaluation(outputs, targets_data, meta_info, mode='val', thres=0.5):
eval_out = {}
# GT
mesh_valid = meta_info['mano_valid'] is not None
# Pred
contact_pred = outputs['contact_out'].sigmoid()[0].detach().cpu().numpy()
# Error Calculate
if mesh_valid:
# Contact Metrics
cont_pre, cont_rec, cont_f1 = compute_contact_metrics(targets_data['contact_data']['contact_h'][0].detach().cpu().numpy(), outputs['contact_out'][0].detach().cpu().numpy(), mesh_valid, thres=thres)
eval_out['cont_pre'] = cont_pre
eval_out['cont_rec'] = cont_rec
eval_out['cont_f1'] = cont_f1
return eval_out
def compute_contact_metrics(gt, pred, valid, thres=0.5):
"""
Compute precision, recall, and f1 using NumPy
"""
if valid:
# True Positives
tp_num = np.sum(gt[pred >= thres])
# Denominators for precision and recall
precision_denominator = np.sum(pred >= thres)
recall_denominator = np.sum(gt)
# Compute precision, recall, and F1 score
precision_ = tp_num / precision_denominator if precision_denominator > 0 else None
recall_ = tp_num / recall_denominator if recall_denominator > 0 else None
if precision_ is not None and recall_ is not None and (precision_ + recall_) > 0:
f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
else:
f1_ = None
else:
# If not valid, return None for metrics
precision_ = None
recall_ = None
f1_ = None
return precision_, recall_, f1_