| import os | |
| import sys | |
| import json | |
| from pprint import pprint | |
| from collections import defaultdict | |
| from sftp.metrics.exact_match import ExactMatch | |
def _load_jsonl(path):
    """Return the list of JSON objects in the JSON-lines file at *path*.

    Blank (whitespace-only) lines are skipped so a stray empty line does not
    abort the whole evaluation. The file is closed before returning.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _pred_to_spans(pred_sent):
    """Flatten one sentence's predicted frame records into a span list.

    Returns all frame spans first (``parent`` 0), then all frame-element
    spans, whose ``parent`` is the 1-based position of their frame in
    *pred_sent*. Each span is a dict with keys start_idx, end_idx, label,
    parent.
    """
    keys = ("start_idx", "end_idx", "label")
    frames, fes = [], []
    for frame_no, frame in enumerate(pred_sent, start=1):
        frame_span = {k: frame[k] for k in keys}
        frame_span["parent"] = 0
        frames.append(frame_span)
        for fe in frame["children"]:
            fe_span = {k: fe[k] for k in keys}
            fe_span["parent"] = frame_no
            fes.append(fe_span)
    return frames + fes


def _gold_to_spans(gold_sent):
    """Flatten one gold sentence record into a span list.

    Uses the same layout as ``_pred_to_spans``: frame spans first
    (``parent`` 0), then FE spans pointing at their frame's 1-based
    position. A frame's span runs from the first to the last element of
    its ``target`` list.
    """
    # NOTE(review): assumes ``target`` holds ordered token indices and each
    # ``fe`` entry is a (start, end, name) triple — confirm against the data.
    frames, fes = [], []
    for frame_no, frame in enumerate(gold_sent["frame"], start=1):
        frames.append({
            "start_idx": frame["target"][0],
            "end_idx": frame["target"][-1],
            "label": frame["name"],
            "parent": 0,
        })
        for start_idx, end_idx, fe_name in frame["fe"]:
            fes.append({
                "start_idx": start_idx,
                "end_idx": end_idx,
                "label": fe_name,
                "parent": frame_no,
            })
    return frames + fes


def evaluate():
    """Score frame/FE predictions against gold annotations and print metrics.

    Command line: ``<script> GOLD_JSONL PRED_JSONL``.

    Gold lines are keyed by ``meta["sentence ID"]``; a later duplicate ID
    overrides an earlier one. Prediction lines (one record per predicted
    frame) are grouped by the same ID. Prediction records for IDs absent
    from the gold file are ignored. A gold sentence with no predictions is
    scored against an empty prediction list.

    Each sentence is fed to two ``ExactMatch`` metrics, constructed with
    ``True`` and ``False`` and printed as "EM" and "SM". Their final values
    are pretty-printed to stdout.

    Raises:
        SystemExit: with a usage message if exactly two paths are not given.
    """
    args = sys.argv[1:]
    if len(args) != 2:
        sys.exit("usage: {} GOLD_JSONL PRED_JSONL".format(sys.argv[0]))
    gold_file, pred_file = args

    em = ExactMatch(True)
    sm = ExactMatch(False)

    test_sentences = {
        record["meta"]["sentence ID"]: record for record in _load_jsonl(gold_file)
    }
    pred_sentences = defaultdict(list)
    for one_pred in _load_jsonl(pred_file):
        pred_sentences[one_pred["meta"]["sentence ID"]].append(one_pred)

    for sent_id, gold_sent in test_sentences.items():
        # .get avoids inserting empty entries into the defaultdict.
        pred_to_eval = _pred_to_spans(pred_sentences.get(sent_id, []))
        gold_to_eval = _gold_to_spans(gold_sent)
        em(pred_to_eval, gold_to_eval)
        sm(pred_to_eval, gold_to_eval)

    print('EM')
    pprint(em.get_metric(True))
    print('SM')
    pprint(sm.get_metric(True))
if __name__ == "__main__":
    # Evaluate only when run as a script; importing this module has no side effects.
    evaluate()