| | |
| | |
| | |
| |
|
| | import torch |
| | import numpy as np |
| | from itertools import product |
| |
|
def compute_prf1(count, miss, fp):
    """
    Compute precision, recall and F1 score, expressed as percentages.

    Code modified from https://github.com/Arthur151/ROMP/blob/4eebd3647f57d291d26423e51f0d514ff7197cb3/simple_romp/evaluation/RH_evaluation/evaluation.py#L90

    Args:
        count: total number of ground-truth instances (true positives are
            ``count - miss``).
        miss: number of unmatched ground-truth instances (false negatives).
        fp: number of unmatched predictions (false positives).

    Returns:
        Tuple ``(precision, recall, f1)`` of floats. Each ratio is quantized
        to two decimals before scaling by 100, so the values are whole-number
        percentages. Returns ``(0., 0., 0.)`` when ``count`` is 0 or there
        are no true positives.
    """
    if count == 0:
        return 0., 0., 0.
    tp = count - miss
    if tp == 0:
        return 0., 0., 0.

    f1 = tp / (tp + 0.5 * (fp + miss))
    recall = tp / (tp + miss)
    precision = tp / (tp + fp)

    def to_percent(ratio):
        # Quantize to two decimals first, then round again after scaling so
        # float noise (e.g. 100. * 0.29 == 28.999999999999996) is not returned.
        return round(100. * round(ratio, 2), 2)

    return to_percent(precision), to_percent(recall), to_percent(f1)
| |
|
def match_2d_greedy(
        pred_kps,
        gtkp,
        valid_mask,
        imgPath=None,
        baseline=None,
        iou_thresh=0.05,
        valid=None,
        ind=-1):
    """
    Greedily match predicted 2D keypoint sets to ground-truth keypoint sets.

    Code modified from: https://github.com/Arthur151/ROMP/blob/4eebd3647f57d291d26423e51f0d514ff7197cb3/simple_romp/trace2/evaluation/eval_3DPW.py#L232

    Every (prediction, ground-truth) pair is scored by a keypoint error.
    Pairs are then popped in increasing-error order, and each pair is
    popped at most once:

    * If neither side is matched yet and the IoU of the two keypoint
      bounding boxes (``get_bbx_overlap``) is >= ``iou_thresh``, the pair
      is accepted.
    * Otherwise, if the IoU is < ``iou_thresh``, the round ends as a
      "false positive" and the false-positive counter is incremented.
    * Any other pair is skipped.

    An accepted pair whose ground truth is flagged invalid (``valid`` given
    and ``valid[gt]`` falsy) retires that ground truth without recording a
    match. Rounds continue until every ground truth is retired, or until
    matched predictions plus counted false positives reach the number of
    predictions.

    Args:
        pred_kps: sequence of per-prediction keypoint arrays, indexed as
            ``pred_kps[i][mask, :2]`` (presumably x, y in the first two
            columns -- TODO confirm against callers).
        gtkp: sequence of per-ground-truth keypoint arrays, indexed the same
            way.
        valid_mask: ``valid_mask[j]`` selects the joints of ground truth
            ``j`` used for the error; it must select at least one joint.
        imgPath, baseline, ind: unused; kept for signature compatibility.
        iou_thresh: minimum bounding-box IoU for a pair to be matched.
        valid: optional per-ground-truth flags. Ground truths with a falsy
            flag are never matched and never reported as misses.

    Returns:
        best_match: ``np.ndarray`` of ``(idx_pred_kps, idx_gt_kps)`` rows,
            or an empty array of shape ``(0,)`` if nothing matched.
        false_positives: list of prediction indices that were not matched.
        misses: list of (valid) ground-truth indices that were not matched.

    Raises:
        AssertionError: if ``valid_mask`` selects no joints for a ground
            truth that is paired with at least one prediction; also
            propagated from ``get_bbx_overlap``.
        RuntimeError: if every candidate pair has been consumed while
            matching is still in progress.
    """
    num_pred = len(pred_kps)
    num_gt = len(gtkp)
    # All (pred_idx, gt_idx) candidate pairs, in row-major order.
    combs = list(product(np.arange(num_pred), np.arange(num_gt)))

    pair_errors = []
    for pred_idx, gt_idx in combs:
        vmask = valid_mask[gt_idx]
        assert vmask.sum() > 0, 'no valid points'
        diff = pred_kps[pred_idx][vmask, :2] - gtkp[gt_idx][vmask, :2]
        # NOTE: for a 2-D ``diff``, ord=2 is the matrix spectral norm (largest
        # singular value), not a sum/mean of per-joint distances.
        pair_errors.append(np.linalg.norm(diff, 2))
    pair_errors = np.array(pair_errors)

    gt_assigned = np.zeros((num_gt,), dtype=bool)
    pred_assigned = np.zeros((num_pred,), dtype=bool)

    best_match = []
    false_positive_counter = 0
    while (np.sum(gt_assigned) < num_gt
           and np.sum(pred_assigned) + false_positive_counter < num_pred):
        found = False
        false_positive = False
        while not found:
            if np.all(pair_errors == np.inf):
                # Nothing left to pop; re-popping an exhausted pair could
                # otherwise spin forever.
                raise RuntimeError(
                    'match_2d_greedy: all candidate pairs exhausted '
                    'before matching finished')

            min_idx = np.argmin(pair_errors)
            min_comb = combs[min_idx]
            pred_idx, gt_idx = min_comb
            iou = get_bbx_overlap(pred_kps[pred_idx], gtkp[gt_idx])
            # Each pair is popped at most once, whatever the outcome.
            pair_errors[min_idx] = np.inf

            if (not pred_assigned[pred_idx] and not gt_assigned[gt_idx]
                    and iou >= iou_thresh):
                found = True
            elif iou < iou_thresh:
                # NOTE: counted even when one side of the pair is already
                # matched; the outer loop's exit test relies on this counter.
                found = True
                false_positive = True
                false_positive_counter += 1

        if valid is not None and not valid[gt_idx]:
            # Invalid ground truth: retire it without recording a match.
            gt_assigned[gt_idx] = True
        elif not false_positive:
            best_match.append(min_comb)
            pred_assigned[pred_idx] = True
            gt_assigned[gt_idx] = True

    matched_preds = [p for p, _ in best_match]
    matched_gts = [g for _, g in best_match]
    best_match = np.array(best_match)

    false_positives = list(np.setdiff1d(np.arange(num_pred), matched_preds))

    unmatched_gts = np.setdiff1d(np.arange(num_gt), matched_gts)
    if valid is None:
        misses = list(unmatched_gts)
    else:
        # Unmatched ground truths flagged invalid are not counted as misses.
        misses = [g for g in unmatched_gts if valid[g]]

    return best_match, false_positives, misses
| |
|
def get_bbx_overlap(p1, p2):
    """
    Return the IoU of the axis-aligned bounding boxes enclosing two point sets.

    Code modified from https://github.com/Arthur151/ROMP/blob/4eebd3647f57d291d26423e51f0d514ff7197cb3/simple_romp/trace2/evaluation/eval_3DPW.py#L185

    Each box spans the column-wise min/max of its points. Only column 0 (x)
    and column 1 (y) are used. Widths and heights use the inclusive-pixel
    convention ``max - min + 1``.

    Raises:
        AssertionError: if either box is degenerate (min == max along x or y).
    """
    lo_a, lo_b = np.min(p1, axis=0), np.min(p2, axis=0)
    hi_a, hi_b = np.max(p1, axis=0), np.max(p2, axis=0)

    ax1, ay1, ax2, ay2 = lo_a[0], lo_a[1], hi_a[0], hi_a[1]
    bx1, by1, bx2, by2 = lo_b[0], lo_b[1], hi_b[0], hi_b[1]

    assert ax1 < ax2
    assert ay1 < ay2
    assert bx1 < bx2
    assert by1 < by2

    # Overlap rectangle (clamped to zero when the boxes do not intersect).
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1) + 1)
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1) + 1)
    overlap = overlap_w * overlap_h

    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)

    return overlap / float(area_a + area_b - overlap)
| |
|
| |
|
class AverageMeter(object):
    """
    Computes and stores the average and current value.

    Code modified from https://github.com/pytorch/examples/blob/main/imagenet/main.py#L423
    """

    def __init__(self, name, fmt=':f'):
        # name: label shown by __str__.
        # fmt: format spec (e.g. ':.4f') applied to both val and avg in __str__.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # weighted running mean, sum / count
        self.sum = 0    # weighted sum of all values seen
        self.count = 0  # total weight (sum of n over all updates)

    def update(self, val, n=1):
        """
        Record ``val`` with weight ``n`` and refresh the running average.

        Tensors, including subclasses such as ``torch.nn.Parameter``, are
        detached so the accumulated sum does not carry an autograd graph.
        The average is left unchanged while the total weight is zero, which
        avoids a division by zero.
        """
        if isinstance(val, torch.Tensor):
            val = val.detach()
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count != 0:
            self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
| | |
| |
|