Spaces:
Build error
Build error
| # code from: https://github.com/benjiebob/WLDO/blob/master/wldo_regressor/metrics.py | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
IMG_RES = 256  # in WLDO it is 224


class Metrics:
    """Keypoint (PCK) and silhouette (IoU) evaluation metrics, adapted from WLDO.

    All methods are static: call them as ``Metrics.PCK(...)``; calling them
    through an instance works as well.
    """

    @staticmethod
    def PCK_thresh(
            pred_keypoints, gt_keypoints,
            gtseg, has_seg,
            thresh, idxs, biggs=False, img_res=None):
        """Per-sample PCK at a single threshold, normalised by sqrt(silhouette area).

        Args:
            pred_keypoints: predicted keypoints; presumably (B, K, 2) pixel
                coordinates -- confirm against caller.
            gt_keypoints: ground-truth keypoints. The first two columns are
                coordinates; the last column is used as a per-keypoint
                visibility weight (presumably (B, K, 3) -- confirm).
            gtseg: ground-truth segmentation masks, one per sample. The sum of
                each mask is the area used to normalise distances.
            has_seg: boolean mask over the batch; only samples where it is
                true are evaluated.
            thresh: a keypoint is a hit when ``dist / sqrt(seg_area) < thresh``.
            idxs: optional subset of keypoint indices; ``None`` means all.
            biggs: if true, ground-truth coordinates are mapped from [-1, 1]
                to [0, img_res] and columns 0/1 are swapped before comparing.
            img_res: resolution used by the ``biggs`` mapping; ``None`` uses
                the module-level ``IMG_RES``.

        Returns:
            Tensor with one PCK value per evaluated sample. A sample with no
            visible keypoints yields NaN (0 / 0).
        """
        pred_keypoints = pred_keypoints[has_seg]
        gt_keypoints = gt_keypoints[has_seg]
        gtseg = gtseg[has_seg]
        if idxs is None:
            idxs = list(range(pred_keypoints.shape[1]))
        idxs = np.array(idxs).astype(int)
        pred_keypoints = pred_keypoints[:, idxs]
        gt_keypoints = gt_keypoints[:, idxs]
        if biggs:
            if img_res is None:
                img_res = IMG_RES
            keypoints_gt = ((gt_keypoints + 1.0) * 0.5) * img_res
            # Columns 0/1 of the ground truth are swapped so they line up with
            # the prediction's coordinate order.
            dist = torch.norm(pred_keypoints - keypoints_gt[:, :, [1, 0]], dim=-1)
        else:
            keypoints_gt = gt_keypoints  # (0 to IMG_SIZE)
            dist = torch.norm(pred_keypoints - keypoints_gt[:, :, :2], dim=-1)
        # Normalise distances by the square root of each sample's mask area.
        seg_area = torch.sum(gtseg.reshape(gtseg.shape[0], -1), dim=-1).unsqueeze(-1)
        hits = (dist / torch.sqrt(seg_area)) < thresh
        # Only visible keypoints (last column) contribute to the score.
        visibility = gt_keypoints[:, :, -1]
        total_visible = torch.sum(visibility, dim=-1)
        pck = torch.sum(hits.float() * visibility, dim=-1) / total_visible
        return pck

    @staticmethod
    def PCK(
            pred_keypoints, keypoints,
            gtseg, has_seg,
            thresh_range=(0.15,),
            idxs=None,
            biggs=False,
            img_res=None):
        """Calc PCK with same method as in eval, averaged over thresholds.

        Args:
            pred_keypoints, keypoints, gtseg, has_seg, idxs, biggs, img_res:
                forwarded unchanged to ``Metrics.PCK_thresh``.
            thresh_range: iterable of thresholds; one PCK is computed per
                threshold and the results are averaged. Must be non-empty
                (``torch.stack`` raises on an empty list).

        Returns:
            Tensor with the threshold-averaged PCK for each evaluated sample.
        """
        cumulative_pck = [
            Metrics.PCK_thresh(
                pred_keypoints, keypoints,
                gtseg, has_seg, thresh, idxs,
                biggs=biggs, img_res=img_res)
            for thresh in thresh_range
        ]
        return torch.stack(cumulative_pck, dim=0).mean(dim=0)

    @staticmethod
    def IOU(synth_silhouettes, gt_seg, img_border_mask, mask):
        """Per-sample intersection-over-union between predicted and GT silhouettes.

        Args:
            synth_silhouettes: predicted silhouettes, batch-first. The caller's
                tensor is not modified.
            gt_seg: ground-truth silhouettes, same shape as
                ``synth_silhouettes``.
            img_border_mask: where this is 0, the ground truth is replaced by
                the prediction, so pixels outside the image range are not
                penalised.
            mask: per-sample multiplier; ``synth_silhouettes[i]`` is scaled by
                ``mask[i]`` for each ``i < mask.shape[0]``.

        Returns:
            Tensor of IoU scores, one per sample. The intersection is the sum
            of the elementwise product; the union counts pixels where the sum
            of both silhouettes is > 0. A sample with an empty union yields
            NaN, and a RuntimeWarning is emitted.
        """
        import warnings

        # Work on a copy so the per-sample masking does not mutate the caller's tensor.
        synth_silhouettes = synth_silhouettes.clone()
        for i in range(mask.shape[0]):
            synth_silhouettes[i] *= mask[i]
        # Do not penalize parts of the segmentation outside the img range
        gt_seg = (gt_seg * img_border_mask) + synth_silhouettes * (1.0 - img_border_mask)
        batch_size = synth_silhouettes.shape[0]
        intersection = torch.sum(
            (synth_silhouettes * gt_seg).reshape(batch_size, -1), dim=-1)
        union = torch.sum(
            ((synth_silhouettes + gt_seg).reshape(batch_size, -1) > 0.0).float(), dim=-1)
        acc_IOU_SCORE = intersection / union
        if torch.isnan(acc_IOU_SCORE).any():
            # Warn instead of dropping into a debugger, so batch evaluation can
            # continue (e.g. when a sample's union is empty, 0 / 0).
            warnings.warn(
                "Metrics.IOU produced NaN score(s); check for empty silhouettes.",
                RuntimeWarning,
                stacklevel=2,
            )
        return acc_IOU_SCORE