| from typing import * | |
| import torch | |
| import json | |
| import argparse | |
| import os | |
| from tqdm import tqdm | |
| from sftp.predictor import SpanPredictor | |
| from sftp.models import SpanModel | |
| from sftp.data_reader import BetterDatasetReader | |
def predict_doc(predictor, json_path: str) -> dict:
    """Run span prediction over every entry of a JSON document file.

    The file at ``json_path`` must hold a top-level ``'entries'`` mapping
    (document name -> entry dict). Each entry is passed to
    ``predictor.predict_json`` and the predicted trigger spans, together with
    their argument (child) spans, are stored back on that entry under the key
    ``'trigger span'`` as a list of ``{"span": [start, end],
    "argument": [[start, end], ...]}`` dicts.

    Args:
        predictor: Object exposing ``predict_json(entry)``, which returns a
            dict with a ``'prediction'`` list. Each trigger in it has
            ``'start_idx'``, ``'end_idx'`` and ``'children'`` (each child
            with ``'start_idx'`` and ``'end_idx'``). Presumably a
            ``SpanPredictor`` -- confirm against callers.
        json_path: Path to the input JSON file.

    Returns:
        The loaded JSON object, with every entry annotated in place.
    """
    # Close the handle promptly, and read JSON as UTF-8 (per the JSON spec)
    # rather than relying on the platform's default locale encoding.
    with open(json_path, encoding='utf-8') as fp:
        src = json.load(fp)
    # Document names are not needed. Dict views support len(), so tqdm still
    # shows a total without materialising a list first.
    for entry in tqdm(src['entries'].values()):
        pred = predictor.predict_json(entry)
        entry['trigger span'] = [
            {
                "span": [trigger['start_idx'], trigger['end_idx']],
                "argument": [
                    [child['start_idx'], child['end_idx']]
                    for child in trigger['children']
                ],
            }
            for trigger in pred['prediction']
        ]
    return src
if __name__ == '__main__':
    # Script entry point: load the model archived under -a, run predict_doc on
    # every matching file found under -s, and write the annotated JSON into
    # <-d>/<archive directory name>/.
    parser = argparse.ArgumentParser(
        description='Run a trained span model over every JSON file in a directory tree.'
    )
    # The three paths have no usable defaults. Without required=True, a missing
    # flag only surfaces later as a TypeError from os.path.join(None, ...).
    parser.add_argument('-a', type=str, required=True, help='archive path')
    parser.add_argument('-s', type=str, required=True, help='source path')
    parser.add_argument('-d', type=str, required=True, help='destination path')
    parser.add_argument('-c', type=int, default=0, help='cuda device')
    args = parser.parse_args()

    predictor_ = SpanPredictor.from_path(
        os.path.join(args.a, 'model.tar.gz'), 'span', cuda_device=args.c
    )
    # normpath strips any trailing separator. Otherwise basename('runs/model/')
    # is '' and predictions land directly in args.d instead of a per-model dir.
    model_name = os.path.basename(os.path.normpath(args.a))
    tgt_path = os.path.join(args.d, model_name)
    os.makedirs(tgt_path, exist_ok=True)

    for root, _, files in os.walk(args.s):
        for fn in files:
            # Only file names ending in 'json' or 'valid' are processed.
            if not fn.endswith(('json', 'valid')):
                continue
            processed_json = predict_doc(predictor_, os.path.join(root, fn))
            # NOTE(review): the output layout is flat. Files with the same name in
            # different subdirectories of args.s overwrite each other here.
            with open(os.path.join(tgt_path, fn), 'w', encoding='utf-8') as fp:
                json.dump(processed_json, fp)