| from argparse import ArgumentParser | |
| from collections import defaultdict | |
| from torch import nn | |
| from copy import deepcopy | |
| import torch | |
| import os | |
| import json | |
| from sftp import SpanPredictor | |
| import nltk | |
def shift_grid_cos_sim(mat: torch.Tensor):
    """Pairwise cosine similarity between the rows of ``mat``, shifted to [0, 1].

    Builds two broadcast views of ``mat`` so that entry (i, j) of the result
    compares row j with row i, then rescales cosine similarity from [-1, 1]
    into [0, 1] via (cos + 1) / 2.

    Returns an (n, n) tensor where n = mat.shape[0].
    """
    n_rows = mat.shape[0]
    as_rows = mat.unsqueeze(0).expand(n_rows, -1, -1)
    as_cols = mat.unsqueeze(1).expand(-1, n_rows, -1)
    cosine = nn.CosineSimilarity(2)(as_rows, as_cols)
    return (cosine + 1) / 2
def all_frames():
    """Ensure the FrameNet 1.7 corpus is available locally and return its frames.

    ``nltk.download`` is a no-op when the corpus is already present, so this
    is safe to call repeatedly.
    """
    nltk.download('framenet_v17')
    return nltk.corpus.framenet.frames()
def extract_relations(fr):
    """Collect frames directly related to ``fr`` through frame relations.

    Args:
        fr: a FrameNet frame record exposing ``.name`` and ``.frameRelations``,
            where each relation is indexable by 'subFrameName'/'superFrameName'.

    Returns:
        A list of ``(frame_name, relation_type)`` pairs with relation_type in
        {'subFrame', 'superFrame'}. The frame itself is excluded and each
        related frame is reported at most once (first relation wins).
    """
    ret = list()
    added = {fr.name}
    for rel in fr.frameRelations:
        for key in ['subFrameName', 'superFrameName']:
            rel_fr_name = rel[key]
            if rel_fr_name in added:
                continue
            # Bug fix: record the name so the same frame is not appended
            # again — previously `added` was checked but never updated,
            # so duplicate relations produced duplicate output entries.
            added.add(rel_fr_name)
            ret.append((rel_fr_name, key[:-4]))  # key[:-4] strips 'Name'
    return ret
def run():
    """CLI entry point: rank FrameNet frames by model-derived similarity.

    Loads a trained SpanPredictor archive, builds two frame-to-frame cosine
    similarity matrices (one from the label-embedding weights, one from the
    last span-typing MLP layer), ranks every frame's neighbours, and writes
    raw JSON plus similarity / ontology / KAIROS TSV sheets under
    DESTINATION/mlp and DESTINATION/emb.
    """
    parser = ArgumentParser()
    parser.add_argument('archive', metavar='ARCHIVE_PATH', type=str)
    parser.add_argument('dst', metavar='DESTINATION', type=str)
    parser.add_argument('kairos', metavar='KAIROS', type=str)
    parser.add_argument('--topk', metavar='TOPK', type=int, default=10)
    args = parser.parse_args()

    predictor = SpanPredictor.from_path(args.archive, cuda_device=-1)
    # Fix: close the handle instead of leaking it (was json.load(open(...))).
    with open(args.kairos) as fp:
        kairos_gold_mapping = json.load(fp)
    label_emb = predictor._model._span_typing.label_emb.weight.clone().detach()
    idx2label = predictor._model.vocab.get_index_to_token_vocabulary('span_label')
    emb_sim = shift_grid_cos_sim(label_emb)
    fr2definition = {fr.name: (fr.URL, fr.definition) for fr in all_frames()}
    last_mlp = predictor._model._span_typing.MLPs[-1].weight.detach().clone()
    mlp_sim = shift_grid_cos_sim(last_mlp)

    def rank_frame(sim):
        """Turn a square (label, label) similarity matrix into per-frame rankings.

        Returns {frame_name: {'similarity': [(other_frame, score), ...]
        sorted descending, 'ontology': ..., 'URL': ..., 'definition': ...}},
        restricted to vocabulary labels that are FrameNet frame names.
        """
        rank = sim.argsort(1, True)           # descending neighbour indices
        scores = sim.gather(1, rank)          # scores aligned with `rank`
        mapping = {
            fr.name: {
                'similarity': list(),
                'ontology': extract_relations(fr),
                'URL': fr.URL,
                'definition': fr.definition
            } for fr in all_frames()
        }
        for left_idx, (right_indices, match_scores) in enumerate(zip(rank, scores)):
            left_label = idx2label[left_idx]
            if left_label not in mapping:
                # Vocabulary entry that is not a FrameNet frame (e.g. padding
                # or special labels) — nothing to rank for it.
                continue
            for right_idx, s in zip(right_indices, match_scores):
                right_label = idx2label[int(right_idx)]
                # Skip non-frame labels and the frame itself.
                if right_label not in mapping or right_idx == left_idx:
                    continue
                mapping[left_label]['similarity'].append((right_label, float(s)))
        return mapping

    emb_map = rank_frame(emb_sim)
    mlp_map = rank_frame(mlp_sim)

    def dump(mapping, folder_path):
        """Write raw.json, similarity.tsv, ontology.tsv, kairos_sheet.tsv."""
        os.makedirs(folder_path, exist_ok=True)
        # Fix: context managers so every handle is flushed and closed
        # (previously raw.json and kairos_sheet.tsv handles were leaked).
        with open(os.path.join(folder_path, 'raw.json'), 'w') as fp:
            json.dump(mapping, fp)
        sim_lines, onto_lines = list(), list()
        for fr, values in mapping.items():
            sim_line = [
                fr,
                values['definition'],
                values['URL'],
            ]
            onto_line = deepcopy(sim_line)
            for rel_fr_name, rel_type in values['ontology']:
                onto_line.append(f'{rel_fr_name} ({rel_type})')
            onto_lines.append('\t'.join(onto_line))
            # An empty 'similarity' list simply yields no extra columns, so no
            # explicit length guard is needed.
            for sim_fr_name, score in values['similarity'][:args.topk]:
                sim_line.append(f'{sim_fr_name} ({score:.3f})')
            sim_lines.append('\t'.join(sim_line))
        with open(os.path.join(folder_path, 'similarity.tsv'), 'w') as fp:
            fp.write('\n'.join(sim_lines))
        with open(os.path.join(folder_path, 'ontology.tsv'), 'w') as fp:
            fp.write('\n'.join(onto_lines))
        # KAIROS sheet: for every gold FrameNet frame of every KAIROS event,
        # emit the gold row followed by its top-k most similar frames.
        kairos_dump = list()
        for kairos_event, kairos_content in kairos_gold_mapping.items():
            for gold_fr in kairos_content['framenet']:
                gold_fr = gold_fr['label']
                if gold_fr not in fr2definition:
                    continue
                kairos_dump.append([
                    'GOLD',
                    gold_fr,
                    kairos_event,
                    fr2definition[gold_fr][0],
                    fr2definition[gold_fr][1],
                    str(kairos_content['description']),
                    '1.00'
                ])
                for ass_fr, sim_score in mapping[gold_fr]['similarity'][:args.topk]:
                    kairos_dump.append([
                        '',
                        ass_fr,
                        kairos_event,
                        fr2definition[ass_fr][0],
                        fr2definition[ass_fr][1],
                        str(kairos_content['description']),
                        f'{sim_score:.2f}'
                    ])
        with open(os.path.join(folder_path, 'kairos_sheet.tsv'), 'w') as fp:
            fp.write('\n'.join('\t'.join(line) for line in kairos_dump))

    dump(mlp_map, os.path.join(args.dst, 'mlp'))
    dump(emb_map, os.path.join(args.dst, 'emb'))
# Standard script guard: execute the CLI only when run directly,
# not when this module is imported.
if __name__ == '__main__':
    run()