| import os | |
| import argparse | |
| from xml.etree import ElementTree | |
| import copy | |
| from operator import attrgetter | |
| import json | |
| import logging | |
| from sftp import SpanPredictor | |
def predict_kairos(model_archive, source_folder, onto_map):
    """Run span prediction over every XML document under ``source_folder``.

    Walks the folder recursively for ``.xml`` files, loads a FrameNet-label ->
    KAIROS-event mapping from the ``onto_map`` JSON, loads a ``SpanPredictor``
    from ``model_archive``, and predicts frames for every segment of every
    document, remapping FrameNet labels to KAIROS events where a mapping
    exists.

    :param model_archive: Path to the serialized SpanPredictor archive.
    :param source_folder: Root folder searched recursively for ``.xml`` files.
    :param onto_map: Path to a JSON file mapping each KAIROS event to a list
        of FrameNet frames under its ``framenet`` key.
    :return: List of per-segment prediction dicts; each carries a ``meta``
        dict (document attributes plus the segment attributes under ``seg``)
        and a ``prediction`` list of frames with remapped labels.
    """
    # Collect every XML file under the source folder (recursive walk).
    xml_files = list()
    for root, _, files in os.walk(source_folder):
        for f in files:
            if f.endswith('.xml'):
                xml_files.append(os.path.join(root, f))
    logging.info(f'{len(xml_files)} files are found:')
    for fn in xml_files:
        logging.info(' - ' + fn)
    logging.info('Loading ontology from ' + onto_map)
    # Build FrameNet-label -> KAIROS-event lookup. Duplicate frame labels are
    # logged and the later entry wins (same behavior as before).
    k_map = dict()
    with open(onto_map) as onto_fp:  # FIX: close the file instead of leaking the handle
        onto = json.load(onto_fp)
    for kairos_event, content in onto.items():
        for fr in content['framenet']:
            if fr['label'] in k_map:
                logging.info("Duplicate frame: " + fr['label'])
            k_map[fr['label']] = kairos_event
    logging.info('Loading model from ' + model_archive + ' ...')
    predictor = SpanPredictor.from_path(model_archive)
    predictions = list()
    for fn in xml_files:
        logging.info('Now processing ' + os.path.basename(fn))
        tree = ElementTree.parse(fn).getroot()
        for doc in tree:
            doc_meta = copy.deepcopy(doc.attrib)
            # First child of the doc element is taken as the text container —
            # TODO(review): confirm against the XML schema.
            text = list(doc)[0]
            for seg in text:
                seg_meta = copy.deepcopy(doc_meta)
                seg_meta['seg'] = copy.deepcopy(seg.attrib)
                tokens = [child for child in seg if child.tag == 'TOKEN']
                # FIX: start_char is an XML attribute *string*; sorting
                # lexicographically puts '10' before '2'. Sort numerically so
                # token order is correct for segments with offsets >= 10.
                tokens.sort(key=lambda t: int(t.attrib['start_char']))
                words = list(map(attrgetter('text'), tokens))
                one_pred = predictor.predict_sentence(words)
                one_pred['meta'] = seg_meta
                # Remap FrameNet labels to KAIROS events where a mapping
                # exists; unmapped frames are kept with their original label.
                new_frames = list()
                for fr in one_pred['prediction']:
                    if fr['label'] in k_map:
                        fr['label'] = k_map[fr['label']]
                    new_frames.append(fr)
                one_pred['prediction'] = new_frames
                predictions.append(one_pred)
    logging.info('Finished Prediction.')
    return predictions
def do_task(input_dir, model_archive, onto_map):
    """
    Entry point invoked by the KAIROS infrastructure code for each
    TASK1 input; a thin wrapper delegating to ``predict_kairos``.
    """
    return predict_kairos(
        model_archive=model_archive,
        source_folder=input_dir,
        onto_map=onto_map,
    )
def run():
    """CLI entry point: parse arguments, run prediction, write a JSONL file.

    Expects four positional arguments (model archive, source folder, ontology
    JSON, destination path) and writes one JSON object per line to the
    destination.
    """
    parser = argparse.ArgumentParser(description='Span Finder for KAIROS Quizlet4\n')
    parser.add_argument('model_archive', metavar='MODEL_ARCHIVE', type=str, help='Path to model archive file.')
    parser.add_argument('source_folder', metavar='SOURCE_FOLDER', type=str, help='Path to the folder that contains the XMLs.')
    parser.add_argument('onto_map', metavar='ONTO_MAP', type=str, help='Path to the ontology JSON.')
    parser.add_argument('destination', metavar='DESTINATION', type=str, help='Output path. (jsonl file path)')
    args = parser.parse_args()
    logging.basicConfig(level='INFO', format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
    predictions = predict_kairos(model_archive=args.model_archive,
                                 source_folder=args.source_folder,
                                 onto_map=args.onto_map)
    logging.info('Saving to ' + args.destination + ' ...')
    # FIX: os.path.dirname() is '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError — only create the parent dir when one exists.
    dest_dir = os.path.dirname(args.destination)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    with open(args.destination, 'w') as fp:
        fp.write('\n'.join(map(json.dumps, predictions)))
    logging.info('Done.')
# Script entry point: invoke the CLI only when executed directly,
# not when imported as a module (e.g. by the KAIROS infrastructure).
if __name__ == '__main__':
    run()