"""Instance-segmentation AP evaluation (ScanNet-benchmark style).

Compares predicted instance masks (.npz per scan) against ground-truth
instance ids (.txt per scan, vertex-aligned, encoded as label_id*1000 +
instance_index) and reports AP / AP@50% / AP@25% per class and averaged.

NOTE(review): the reviewed source was garbled mid-file; the tail of
evaluate_matches, compute_averages, read_prediction_npz and get_gt_tensor
were reconstructed (the first two from the canonical ScanNet benchmark
script, the helpers from their call sites) -- confirm against the original.
"""
import argparse
import os
import sys
from copy import deepcopy

import numpy as np
import torch

from evaluation.utils_3d import get_instances

parser = argparse.ArgumentParser()
parser.add_argument('--pred_path', required=True, help='path to directory of predicted .txt files')
parser.add_argument('--gt_path', required=True, help='path to directory of ground truth .txt files')
parser.add_argument('--dataset', required=True, help='type of dataset, e.g. matterport3d, scannet, etc.')
parser.add_argument('--output_file', default='', help='path to output file')
parser.add_argument('--no_class', action='store_true', help='class agnostic evaluation')
opt = parser.parse_args()

# ---------- Label info ---------- #
from evaluation.constants import MATTERPORT_LABELS, MATTERPORT_IDS, SCANNET_LABELS, SCANNET_IDS, SCANNETPP_LABELS, SCANNETPP_IDS

if opt.dataset == 'matterport3d':
    CLASS_LABELS = MATTERPORT_LABELS
    VALID_CLASS_IDS = MATTERPORT_IDS
elif opt.dataset == 'scannet':
    CLASS_LABELS = SCANNET_LABELS
    VALID_CLASS_IDS = SCANNET_IDS
elif opt.dataset == 'scannetpp':
    CLASS_LABELS = SCANNETPP_LABELS
    VALID_CLASS_IDS = SCANNETPP_IDS
else:
    # Fail fast with a clear CLI error instead of a NameError further down.
    parser.error('unsupported --dataset %r (expected matterport3d, scannet or scannetpp)' % opt.dataset)

if opt.output_file == '':
    # Default output location is derived from the dataset and prediction dir name.
    opt.output_file = os.path.join(f'data/evaluation/{opt.dataset}', opt.pred_path.split('/')[-1] + '.txt')
    os.makedirs(os.path.dirname(opt.output_file), exist_ok=True)
if opt.no_class and 'class_agnostic' not in opt.output_file:
    opt.output_file = opt.output_file.replace('.txt', '_class_agnostic.txt')

# Bidirectional class-id <-> label-name lookup tables.
ID_TO_LABEL = {}
LABEL_TO_ID = {}
for i in range(len(VALID_CLASS_IDS)):
    LABEL_TO_ID[CLASS_LABELS[i]] = VALID_CLASS_IDS[i]
    ID_TO_LABEL[VALID_CLASS_IDS[i]] = CLASS_LABELS[i]

# ---------- Evaluation params ---------- #
# IoU overlap thresholds for evaluation: [0.5, 0.55, ..., 0.9] plus 0.25.
opt.overlaps = np.append(np.arange(0.5, 0.95, 0.05), 0.25)
# minimum region size for evaluation [verts]
opt.min_region_sizes = np.array([100])
# distance thresholds [m]
opt.distance_threshes = np.array([float('inf')])
# distance confidences
opt.distance_confs = np.array([-float('inf')])


def evaluate_matches(matches):
    """Compute the AP tensor from per-scan gt<->pred associations.

    Args:
        matches: dict keyed by scan, each entry holding 'gt' (gt2pred) and
            'pred' (pred2gt) dicts as produced by assign_instances_for_scan.

    Returns:
        np.ndarray of shape (num_dist_threshes, num_classes, num_overlaps)
        with the average precision per (distance threshold, class, IoU).
    """
    overlaps = opt.overlaps
    min_region_sizes = [opt.min_region_sizes[0]]
    dist_threshes = [opt.distance_threshes[0]]
    dist_confs = [opt.distance_confs[0]]

    # results: distance thresh x class x overlap
    ap = np.zeros((len(dist_threshes), len(CLASS_LABELS), len(overlaps)), float)
    for di, (min_region_size, distance_thresh, distance_conf) in enumerate(
            zip(min_region_sizes, dist_threshes, dist_confs)):
        for oi, overlap_th in enumerate(overlaps):
            # Reset per-overlap: each prediction may be greedily assigned once.
            pred_visited = {}
            for m in matches:
                for label_name in CLASS_LABELS:
                    for p in matches[m]['pred'][label_name]:
                        if 'filename' in p:
                            pred_visited[p['filename']] = False
            for li, label_name in enumerate(CLASS_LABELS):
                y_true = np.empty(0)
                y_score = np.empty(0)
                hard_false_negatives = 0
                has_gt = False
                has_pred = False
                for m in matches:
                    pred_instances = matches[m]['pred'][label_name]
                    gt_instances = matches[m]['gt'][label_name]
                    # Filter out annotation groups (instance_id < 1000) and
                    # instances below the size/distance thresholds.
                    gt_instances = [
                        gt for gt in gt_instances
                        if gt['instance_id'] >= 1000
                        and gt['vert_count'] >= min_region_size
                        and gt['med_dist'] <= distance_thresh
                        and gt['dist_conf'] >= distance_conf
                    ]
                    if gt_instances:
                        has_gt = True
                    if pred_instances:
                        has_pred = True

                    cur_true = np.ones(len(gt_instances))
                    cur_score = np.ones(len(gt_instances)) * (-float('inf'))
                    cur_match = np.zeros(len(gt_instances), dtype=bool)
                    # collect matches
                    for gti, gt in enumerate(gt_instances):
                        found_match = False
                        for pred in gt['matched_pred']:
                            # greedy assignments: each prediction matches once
                            if pred_visited[pred['filename']]:
                                continue
                            overlap = float(pred['intersection']) / (
                                gt['vert_count'] + pred['vert_count'] - pred['intersection'])
                            if overlap > overlap_th:
                                confidence = pred['confidence']
                                # if we already have a prediction for this gt,
                                # the lower-scored one becomes a false positive
                                if cur_match[gti]:
                                    max_score = max(cur_score[gti], confidence)
                                    min_score = min(cur_score[gti], confidence)
                                    cur_score[gti] = max_score
                                    # append false positive
                                    cur_true = np.append(cur_true, 0)
                                    cur_score = np.append(cur_score, min_score)
                                    cur_match = np.append(cur_match, True)
                                # otherwise set score
                                else:
                                    found_match = True
                                    cur_match[gti] = True
                                    cur_score[gti] = confidence
                                    pred_visited[pred['filename']] = True
                        if not found_match:
                            hard_false_negatives += 1
                    # remove non-matched ground truth instances
                    cur_true = cur_true[cur_match == True]
                    cur_score = cur_score[cur_match == True]

                    # collect non-matched predictions as false positives
                    for pred in pred_instances:
                        found_gt = False
                        for gt in pred['matched_gt']:
                            overlap = float(gt['intersection']) / (
                                gt['vert_count'] + pred['vert_count'] - gt['intersection'])
                            if overlap > overlap_th:
                                found_gt = True
                                break
                        if not found_gt:
                            # Overlap with void / ignored regions does not
                            # count against the prediction.
                            num_ignore = pred['void_intersection']
                            for gt in pred['matched_gt']:
                                # group?
                                if gt['instance_id'] < 1000:
                                    num_ignore += gt['intersection']
                                # small ground truth instances
                                if (gt['vert_count'] < min_region_size
                                        or gt['med_dist'] > distance_thresh
                                        or gt['dist_conf'] < distance_conf):
                                    num_ignore += gt['intersection']
                            proportion_ignore = float(num_ignore) / pred['vert_count']
                            # if not mostly ignored, append as false positive
                            if proportion_ignore <= overlap_th:
                                cur_true = np.append(cur_true, 0)
                                cur_score = np.append(cur_score, pred['confidence'])

                    # append to overall results
                    y_true = np.append(y_true, cur_true)
                    y_score = np.append(y_score, cur_score)

                # compute average precision
                if has_gt and has_pred:
                    # sort by score and cumulate true positives
                    score_arg_sort = np.argsort(y_score)
                    y_score_sorted = y_score[score_arg_sort]
                    y_true_sorted = y_true[score_arg_sort]
                    y_true_sorted_cumsum = np.cumsum(y_true_sorted)

                    # unique thresholds
                    (thresholds, unique_indices) = np.unique(y_score_sorted, return_index=True)
                    num_prec_recall = len(unique_indices) + 1

                    # prepare precision recall
                    num_examples = len(y_score_sorted)
                    num_true_examples = y_true_sorted_cumsum[-1] if len(y_true_sorted_cumsum) > 0 else 0
                    precision = np.zeros(num_prec_recall)
                    recall = np.zeros(num_prec_recall)

                    # append zero so idx_scores-1 == -1 reads 0 for the first point
                    y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
                    for idx_res, idx_scores in enumerate(unique_indices):
                        cumsum = y_true_sorted_cumsum[idx_scores - 1]
                        tp = num_true_examples - cumsum
                        fp = num_examples - idx_scores - tp
                        fn = cumsum + hard_false_negatives
                        precision[idx_res] = float(tp) / (tp + fp)
                        recall[idx_res] = float(tp) / (tp + fn)

                    # last point in the curve is artificial (p=1, r=0)
                    precision[-1] = 1.
                    recall[-1] = 0.

                    # integrate the PR curve: trapezoid rule via convolution
                    recall_for_conv = np.copy(recall)
                    recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
                    recall_for_conv = np.append(recall_for_conv, 0.)
                    step_widths = np.convolve(recall_for_conv, [-0.5, 0, 0.5], 'valid')
                    ap_current = np.dot(precision, step_widths)
                elif has_gt:
                    # predictions missing entirely for a present class -> AP 0
                    ap_current = 0.0
                else:
                    # class absent from gt: excluded from averaging via NaN
                    ap_current = float('nan')
                ap[di, li, oi] = ap_current
    return ap


def compute_averages(aps):
    """Average the AP tensor into overall and per-class AP / AP50 / AP25.

    NOTE(review): reconstructed from the canonical ScanNet benchmark script;
    the original definition was lost in the reviewed chunk -- confirm.
    """
    d_inf = 0  # single distance threshold (inf)
    o50 = np.where(np.isclose(opt.overlaps, 0.5))
    o25 = np.where(np.isclose(opt.overlaps, 0.25))
    o_all_but25 = np.where(np.logical_not(np.isclose(opt.overlaps, 0.25)))
    avg_dict = {}
    avg_dict['all_ap'] = np.nanmean(aps[d_inf, :, o_all_but25])
    avg_dict['all_ap_50%'] = np.nanmean(aps[d_inf, :, o50])
    avg_dict['all_ap_25%'] = np.nanmean(aps[d_inf, :, o25])
    avg_dict['classes'] = {}
    for li, label_name in enumerate(CLASS_LABELS):
        avg_dict['classes'][label_name] = {}
        avg_dict['classes'][label_name]['ap'] = np.average(aps[d_inf, li, o_all_but25])
        avg_dict['classes'][label_name]['ap50%'] = np.average(aps[d_inf, li, o50])
        avg_dict['classes'][label_name]['ap25%'] = np.average(aps[d_inf, li, o25])
    return avg_dict


def read_prediction_npz(pred_file):
    """Load one scan's predictions into {key: {'label_id', 'conf', 'mask'}}.

    NOTE(review): the original definition was not visible in the reviewed
    chunk. This reconstruction assumes the .npz stores parallel arrays
    'pred_masks' (num_verts x num_instances), 'pred_classes' and
    'pred_scores' -- confirm against the code that exports the masks.
    """
    data = np.load(pred_file)
    masks = data['pred_masks']
    label_ids = data['pred_classes'] if 'pred_classes' in data else np.full(masks.shape[1], VALID_CLASS_IDS[0])
    confs = data['pred_scores'] if 'pred_scores' in data else np.ones(masks.shape[1])
    pred_info = {}
    base = os.path.basename(pred_file)
    for i in range(masks.shape[1]):
        pred_info[f'{base}_{i}'] = {
            'label_id': label_ids[i],
            'conf': confs[i],
            'mask': masks[:, i],
        }
    return pred_info


# Backward-compatible alias: the original name carried a typo ("pridiction").
read_pridiction_npz = read_prediction_npz


def get_gt_tensor(gt_ids, gt_instances):
    """Build per-label boolean membership tensors for the gt instances.

    Returns a dict label_name -> (num_verts, num_gt_instances) bool tensor on
    the GPU, with column order matching gt_instances[label_name].

    NOTE(review): reconstructed from the call site in
    assign_instances_for_scan -- confirm against the original helper.
    """
    gt_ids_tensor = torch.from_numpy(np.asarray(gt_ids)).cuda().reshape(-1, 1)
    gt_tensor_dict = {}
    for label_name, instances in gt_instances.items():
        if not instances:
            gt_tensor_dict[label_name] = torch.zeros(
                (len(gt_ids), 0), dtype=torch.bool, device=gt_ids_tensor.device)
            continue
        inst_ids = torch.tensor(
            [inst['instance_id'] for inst in instances],
            dtype=gt_ids_tensor.dtype, device=gt_ids_tensor.device).reshape(1, -1)
        gt_tensor_dict[label_name] = gt_ids_tensor == inst_ids
    return gt_tensor_dict


def assign_instances_for_scan(pred_file, gt_file):
    """Associate predicted and ground-truth instances for a single scan.

    If the intersection between a prediction mask and a gt instance is > 0,
    the prediction is considered a (candidate) match.

    Args:
        pred_file: path to the scan's prediction .npz file.
        gt_file: path to the scan's ground-truth .txt file (one encoded
            instance id per mesh vertex).

    Returns:
        (gt2pred, pred2gt): per-label dicts of gt instances with their
        'matched_pred' lists, and of prediction instances with 'matched_gt'.

    Raises:
        ValueError: if a prediction mask length does not match the mesh.
    """
    pred_info = read_prediction_npz(pred_file)
    gt_ids = np.loadtxt(gt_file)
    if opt.no_class:
        # Collapse all classes onto the first valid class id.
        gt_ids = gt_ids % 1000 + VALID_CLASS_IDS[0] * 1000

    # get gt instances
    gt_instances = get_instances(gt_ids, VALID_CLASS_IDS, CLASS_LABELS, ID_TO_LABEL)

    # associate
    gt2pred = deepcopy(gt_instances)
    for label in gt2pred:
        for gt in gt2pred[label]:
            gt['matched_pred'] = []
    pred2gt = {}
    for label in CLASS_LABELS:
        pred2gt[label] = []
    num_pred_instances = 0
    # mask of void labels in the ground truth
    bool_void = np.logical_not(np.in1d(gt_ids // 1000, VALID_CLASS_IDS))
    gt_tensor_dict = get_gt_tensor(gt_ids, gt_instances)

    # go through all prediction masks
    for pred_mask_file in pred_info:
        if opt.no_class:
            label_id = VALID_CLASS_IDS[0]
        else:
            label_id = int(pred_info[pred_mask_file]['label_id'])
        conf = pred_info[pred_mask_file]['conf']
        if label_id not in ID_TO_LABEL:
            continue
        label_name = ID_TO_LABEL[label_id]

        # read and validate the mask
        pred_mask = pred_info[pred_mask_file]['mask']
        if len(pred_mask) != len(gt_ids):
            raise ValueError(
                'wrong number of lines in %s (%d) vs #mesh vertices (%d), '
                'please double check and/or re-download the mesh'
                % (pred_mask_file, len(pred_mask), len(gt_ids)))
        # convert to binary
        pred_mask = np.not_equal(pred_mask, 0)
        num = np.count_nonzero(pred_mask)
        if num < opt.min_region_sizes[0]:
            continue  # skip if empty or below the minimum region size

        pred_instance = {}
        pred_instance['filename'] = pred_mask_file
        pred_instance['pred_id'] = num_pred_instances
        pred_instance['label_id'] = label_id
        pred_instance['vert_count'] = num
        pred_instance['confidence'] = conf
        pred_instance['void_intersection'] = np.count_nonzero(np.logical_and(bool_void, pred_mask))

        # matched gt instances: one GPU reduction gives the intersection of
        # this mask with every gt instance of the same label at once
        matched_gt = []
        gt_tensor = gt_tensor_dict[label_name]
        intersection = torch.sum(
            gt_tensor & torch.from_numpy(pred_mask).cuda().reshape(-1, 1), dim=0)
        intersect_ids = torch.nonzero(intersection).cpu().numpy().reshape(-1)
        for gt_id in intersect_ids:
            gt_copy = gt_instances[label_name][gt_id].copy()
            pred_copy = pred_instance.copy()
            intersection_num = intersection[gt_id].item()
            gt_copy['intersection'] = intersection_num
            pred_copy['intersection'] = intersection_num
            matched_gt.append(gt_copy)
            gt2pred[label_name][gt_id]['matched_pred'].append(pred_copy)
        pred_instance['matched_gt'] = matched_gt
        num_pred_instances += 1
        pred2gt[label_name].append(pred_instance)

    return gt2pred, pred2gt


def print_results(avgs):
    """Pretty-print the per-class and averaged AP table to stdout."""
    sep = ""
    col1 = ":"
    line_len = 64

    print("")
    print("#" * line_len)
    line = ""
    line += "{:<15}".format("what") + sep + col1
    line += "{:>15}".format("AP") + sep
    line += "{:>15}".format("AP_50%") + sep
    line += "{:>15}".format("AP_25%") + sep
    print(line)
    print("#" * line_len)

    for li, label_name in enumerate(CLASS_LABELS):
        ap_avg = avgs["classes"][label_name]["ap"]
        # skip classes that never appear in the ground truth
        if np.isnan(ap_avg):
            continue
        ap_50o = avgs["classes"][label_name]["ap50%"]
        ap_25o = avgs["classes"][label_name]["ap25%"]
        line = "{:<15}".format(label_name) + sep + col1
        line += sep + "{:>15.3f}".format(ap_avg) + sep
        line += sep + "{:>15.3f}".format(ap_50o) + sep
        line += sep + "{:>15.3f}".format(ap_25o) + sep
        print(line)

    all_ap_avg = avgs["all_ap"]
    all_ap_50o = avgs["all_ap_50%"]
    all_ap_25o = avgs["all_ap_25%"]

    print("-" * line_len)
    line = "{:<15}".format("average") + sep + col1
    line += "{:>15.3f}".format(all_ap_avg) + sep
    line += "{:>15.3f}".format(all_ap_50o) + sep
    line += "{:>15.3f}".format(all_ap_25o) + sep
    print(line)
    print("")


def write_result_file(avgs, filename):
    """Write per-class AP rows plus an overall-averages row as CSV."""
    _SPLITTER = ','
    with open(filename, 'w') as f:
        f.write(_SPLITTER.join(['class', 'class id', 'ap', 'ap50', 'ap25']) + '\n')
        for i in range(len(VALID_CLASS_IDS)):
            class_name = CLASS_LABELS[i]
            class_id = VALID_CLASS_IDS[i]
            ap = avgs["classes"][class_name]["ap"]
            ap50 = avgs["classes"][class_name]["ap50%"]
            ap25 = avgs["classes"][class_name]["ap25%"]
            f.write(_SPLITTER.join([str(x) for x in [class_name, class_id, ap, ap50, ap25]]) + '\n')
        # trailing row: overall averages (no class/class-id columns)
        f.write(_SPLITTER.join([str(x) for x in [avgs["all_ap"], avgs["all_ap_50%"], avgs["all_ap_25%"]]]) + '\n')


def evaluate(pred_files, gt_files, pred_path, output_file):
    """Run the full evaluation over matched (prediction, gt) file pairs."""
    print('evaluating', len(pred_files), 'scans...')
    matches = {}
    for i in range(len(pred_files)):
        matches_key = os.path.abspath(gt_files[i])
        # assign gt to predictions
        gt2pred, pred2gt = assign_instances_for_scan(pred_files[i], gt_files[i])
        matches[matches_key] = {}
        matches[matches_key]['gt'] = gt2pred
        matches[matches_key]['pred'] = pred2gt
        sys.stdout.write("\rscans processed: {}".format(i + 1))
        sys.stdout.flush()
    ap_scores = evaluate_matches(matches)
    avgs = compute_averages(ap_scores)

    # report
    print_results(avgs)
    write_result_file(avgs, output_file)


def main():
    """Collect prediction/gt file pairs and run the evaluation."""
    print('start evaluating:', opt.pred_path.split('/')[-1])
    pred_files = [f for f in sorted(os.listdir(opt.pred_path))
                  if f.endswith('.npz') and not f.startswith('semantic_instance_evaluation')]
    gt_files = []
    for i in range(len(pred_files)):
        gt_file = os.path.join(opt.gt_path, pred_files[i].replace('.npz', '.txt'))
        if not os.path.isfile(gt_file):
            raise FileNotFoundError(
                'Result file {} does not match any gt file'.format(pred_files[i]))
        gt_files.append(gt_file)
        pred_files[i] = os.path.join(opt.pred_path, pred_files[i])

    evaluate(pred_files, gt_files, opt.pred_path, opt.output_file)
    print('save results to', opt.output_file)


if __name__ == '__main__':
    main()