import json
import pkg_resources
from collections import OrderedDict

# Paths
data_path = pkg_resources.resource_filename('spiga', 'data/annotations')
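# Annotation files are expected per database and split under
# data/annotations/<database>/<data_type>.json, matching how anns_file is
# built in Evaluator.__init__ below.
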
def main():
    import argparse
    pars = argparse.ArgumentParser(description='Benchmark alignments evaluator')
    pars.add_argument('pred_file', nargs='+', type=str,
                      help='Absolute path to the prediction json file (accepts multiple files)')
    pars.add_argument('--eval', nargs='+', type=str, default=['lnd'],
                      choices=['lnd', 'pose'], help='Evaluation modes')
    pars.add_argument('-s', '--save', action='store_true', help='Save results')
    args = pars.parse_args()

    for pred_file in args.pred_file:
        benchmark = get_evaluator(pred_file, args.eval, args.save)
        benchmark.metrics()
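# Example invocation (hypothetical path; Evaluator derives the database and
# split from the file name, so predictions should follow the pattern
# '<prefix>_<database>_<data_type>.json'):
#
#   python evaluator.py /abs/path/preds_wflw_test.json --eval lnd pose --save
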
class Evaluator:

    def __init__(self, data_file, evals=(), save=True, process_err=True):
        # Inputs
        self.data_file = data_file
        self.evals = evals
        self.save = save

        # Paths
        data_name = data_file.split('/')[-1]
        self.data_dir = data_file.split(data_name)[0]

        # Information from name
        data_name = data_name.split('.')[0]
        data_name = data_name.split('_')
        self.data_type = data_name[-1]
        self.database = data_name[-2]
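        # e.g. data_file = '/abs/dir/preds_wflw_test.json' (hypothetical path)
        # gives data_dir = '/abs/dir/', database = 'wflw', data_type = 'test'.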
        # Load predictions and annotations
        anns_file = data_path + '/%s/%s.json' % (self.database, self.data_type)
        self.anns = self.load_files(anns_file)
        self.pred = self.load_files(data_file)

        # Compute errors
        self.error = OrderedDict()
        self.error_pimg = OrderedDict()
        self.metrics_log = OrderedDict()
        if process_err:
            self.compute_error(self.anns, self.pred)
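        # Note: pass process_err=False to skip this step here and call
        # compute_error() later (optionally restricting it via select_ids).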
    def compute_error(self, anns, pred, select_ids=None):
        database_ref = [self.database, self.data_type]
        for metric in self.evals:
            self.error[metric.name] = metric.compute_error(anns, pred, database_ref, select_ids)
            self.error_pimg = metric.get_pimg_err(self.error_pimg)
        return self.error
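    # Duck-typed interface assumed of every object in self.evals, inferred from
    # the calls in this class: a 'name' attribute, compute_error(anns, pred,
    # database_ref, select_ids), get_pimg_err(error_pimg), and metrics().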
    def metrics(self):
        for metric in self.evals:
            self.metrics_log[metric.name] = metric.metrics()
        if self.save:
            file_name = self.data_dir + '/metrics_%s_%s.txt' % (self.database, self.data_type)
            with open(file_name, 'w') as file:
                file.write(str(self))
        return self.metrics_log
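    # Note: with save=True, metrics() also writes the formatted log next to the
    # prediction file as 'metrics_<database>_<data_type>.txt'.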
    def load_files(self, input_file):
        with open(input_file) as jsonfile:
            data = json.load(jsonfile)
        return data
    def _dict2text(self, name, dictionary, num_tab=1):
        prev_tabs = '\t' * num_tab
        text = '%s {\n' % name
        for k, v in dictionary.items():
            if isinstance(v, dict):
                text += '{}{}'.format(prev_tabs, self._dict2text(k, v, num_tab=num_tab + 1))
            else:
                text += '{}{}: {}\n'.format(prev_tabs, k, v)
        text += (prev_tabs + '}\n')
        return text

    def __str__(self):
        state_dict = self.metrics_log
        text = self._dict2text('Metrics', state_dict)
        return text
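    # Illustrative __str__ output for metrics_log = {'lnd': {'nme': 4.06}}
    # (hypothetical keys and values; real indentation uses tabs):
    #   Metrics {
    #       lnd {
    #           nme: 4.06
    #           }
    #       }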
def get_evaluator(pred_file, evaluate=('lnd', 'pose'), save=False, process_err=True):
    eval_list = []
    if "lnd" in evaluate:
        import spiga.eval.benchmark.metrics.landmarks as mlnd
        eval_list.append(mlnd.MetricsLandmarks())
    if "pose" in evaluate:
        import spiga.eval.benchmark.metrics.pose as mpose
        eval_list.append(mpose.MetricsHeadpose())
    return Evaluator(pred_file, evals=eval_list, save=save, process_err=process_err)
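# Example programmatic use (hypothetical path): evaluate landmarks only and
# keep the results in memory instead of writing them to disk:
#
#   benchmark = get_evaluator('/abs/path/preds_wflw_test.json',
#                             evaluate=('lnd',), save=False)
#   results = benchmark.metrics()
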
if __name__ == '__main__':
    main()