| import json | |
| from argparse import ArgumentParser | |
| from collections import defaultdict | |
| import numpy as np | |
| from tqdm import tqdm | |
| from nltk.corpus import framenet as fn | |
| from sftp import SpanPredictor | |
def run(model_path, data_path, device, use_ontology=False):
    """Evaluate frame-identification accuracy of a trained span predictor.

    Reads a JSON-lines dataset (one sentence object per line, with ``tokens``
    and ``annotations`` fields), force-decodes each annotated span with the
    model, and picks the highest-probability frame label among the candidates.
    Prints the running and final accuracy as a percentage.

    Args:
        model_path: Path to the serialized SFTP model archive.
        data_path: Path to the JSON-lines evaluation file.
        device: CUDA device index; -1 for CPU.
        use_ontology: If True, restrict candidate frames to those licensed by
            the annotation's lexical unit (via FrameNet); otherwise consider
            every frame in FrameNet.
    """
    # Use a context manager so the data file is closed promptly (the original
    # left the handle dangling).
    with open(data_path) as data_file:
        data = [json.loads(line) for line in data_file]

    # Map each lexical unit name to the frames it can evoke.
    lu2frame = defaultdict(list)
    for lu in fn.lus():
        lu2frame[lu.name].append(lu.frame.name)

    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    # Index of each frame label in the model's output distribution.
    frame2idx = predictor._model.vocab.get_token_to_index_vocabulary('span_label')
    all_frames = [fr.name for fr in fn.frames()]

    n_positive = n_total = 0
    with tqdm(total=len(data)) as bar:
        for sent in data:
            bar.update()
            for point in sent['annotations']:
                # Probability distribution over frame labels for this span.
                model_output = predictor.force_decode(
                    sent['tokens'], child_spans=[(point['span'][0], point['span'][-1])]
                ).distribution[0]
                if use_ontology:
                    candidate_frames = lu2frame[point['lu']]
                else:
                    candidate_frames = all_frames
                # -1.0 is a sentinel below any real probability, so frames
                # unknown to the model vocabulary can never win argmax.
                candidate_prob = [-1.0 for _ in candidate_frames]
                for idx_can, fr in enumerate(candidate_frames):
                    if fr in frame2idx:
                        candidate_prob[idx_can] = model_output[frame2idx[fr]]
                if candidate_prob:
                    pred_frame = candidate_frames[int(np.argmax(candidate_prob))]
                    if pred_frame == point['label']:
                        n_positive += 1
                # Every annotation counts toward the total, even when the
                # ontology offers no candidate (an unavoidable miss).
                n_total += 1
            # Guard against ZeroDivisionError before any annotation is seen.
            if n_total > 0:
                bar.set_description(f'acc={n_positive/n_total*100:.3f}')
    if n_total > 0:
        print(f'acc={n_positive/n_total*100:.3f}')
    else:
        print('acc=N/A (no annotations found)')
if __name__ == '__main__':
    # Command-line entry point: MODEL and DATA are positional; -d selects the
    # CUDA device (-1 = CPU) and -o restricts candidates to the LU ontology.
    arg_parser = ArgumentParser()
    arg_parser.add_argument('model', metavar='MODEL')
    arg_parser.add_argument('data', metavar='DATA')
    arg_parser.add_argument('-d', default=-1, type=int, help='Device')
    arg_parser.add_argument('-o', action='store_true', help='Flag to use ontology.')
    cmd_args = arg_parser.parse_args()
    run(cmd_args.model, cmd_args.data, cmd_args.d, cmd_args.o)