| |
| import argparse |
| import os.path |
| import pickle |
| import unicodedata |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| import NER_medNLP as ner |
| import utils |
| from EntityNormalizer import EntityNormalizer, EntityDictionary, DefaultDiseaseDict, DefaultDrugDict |
|
|
# Inference device: prefer the MPS backend, then CUDA, otherwise fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Memo used by value_to_key: attribute value -> attribute key.
dict_key = {}
|
|
|
|
| |
# Memo for _load_key_attr: pickle path -> loaded attribute-key mapping.
_KEY_ATTR_CACHE = {}


def _load_key_attr(path="key_attr.pkl"):
    """Load the attribute-key mapping pickle at ``path`` once and memoize it.

    NOTE(review): ``pickle.load`` can execute arbitrary code; this assumes the
    pickle is a trusted local artifact shipped with the model.
    """
    if path not in _KEY_ATTR_CACHE:
        with open(path, "rb") as tf:
            _KEY_ATTR_CACHE[path] = pickle.load(tf)
    return _KEY_ATTR_CACHE[path]


def to_xml(data, id_to_tags, key_attr=None):
    """Render one sentence with its predicted entities as inline XML-style tags.

    Each entity becomes ``<tag attr_key="attr_value" norm="...">...</tag>``.
    ``tag`` and ``attr_value`` are the first and second ``'_'``-separated parts
    of ``id_to_tags[type_id]``. The attribute is omitted when ``attr_value`` is
    empty or no key in ``key_attr`` lists it, and ``norm`` only when present.

    Args:
        data: dict with ``'text'`` (the sentence) and ``'entities_predicted'``
            (list of entity dicts holding ``'span'`` = (start, end) slice
            offsets into ``text``, ``'type_id'`` and optionally ``'norm'``).
        id_to_tags: mapping or sequence from type id to a label string.
        key_attr: mapping attribute key -> iterable of attribute values. When
            omitted it is loaded from ``key_attr.pkl`` (cached after first use).

    Returns:
        The tagged text. Entries that are ``""`` placeholders or whose
        ``type_id`` is not in ``id_to_tags`` are skipped.
    """
    if key_attr is None:
        key_attr = _load_key_attr()

    text = data['text']
    # Number of characters inserted so far; original spans shift by this much.
    # NOTE(review): assumes entities are ordered by span start and do not overlap.
    offset = 0
    for entity in data['entities_predicted']:
        if entity == "":
            # Skip placeholders. Returning here (as before) yielded None, which
            # broke the join in combine_sentences.
            continue
        span = entity['span']
        try:
            label = id_to_tags[entity['type_id']]
        except (KeyError, IndexError):
            print("out of range type_id", entity)
            continue

        parts = label.split('_')
        tag = parts[0]
        # Labels without an underscore carry no attribute value.
        attr_value = parts[1] if len(parts) > 1 else ""

        attr = ""
        if attr_value:
            attr_key = value_to_key(attr_value, key_attr)
            # An unknown value would have crashed the string concatenation;
            # emit the tag without the attribute instead.
            if attr_key is not None:
                attr = f' {attr_key}="{attr_value}"'

        if 'norm' in entity:
            norm = str(entity['norm'])
            attr += f' norm="{norm}"'

        open_tag = f"<{tag}{attr}>"
        pos = span[0] + offset
        text = text[:pos] + open_tag + text[pos:]
        offset += len(open_tag)

        close_tag = f"</{tag}>"
        pos = span[1] + offset
        text = text[:pos] + close_tag + text[pos:]
        offset += len(close_tag)
    return text
|
|
|
|
def predict_entities(model, tokenizer, sentences_list):
    """Run token classification over groups of sentences and decode entities.

    Args:
        model: token-classification model, called as ``model(**encoding)`` and
            expected to return an object with a ``.logits`` tensor.
        tokenizer: NER tokenizer providing ``encode_plus_untagged`` and
            ``convert_bert_output_to_entities``.
        sentences_list: iterable of groups (e.g. one per document), each an
            iterable of sentence strings.

    Returns:
        One list per group. Each list holds
        ``{'text': sentence, 'entities_predicted': entities}`` dicts in input order.
    """
    # NOTE(review): this does not call model.eval(); it assumes the caller
    # provides a model already in inference mode. TODO confirm.
    text_entities_set = []
    for dataset in sentences_list:
        text_entities = []
        for text in tqdm(dataset, desc='Predict', leave=False):
            encoding, spans = tokenizer.encode_plus_untagged(text, return_tensors='pt')
            # Move every input tensor to the module-level inference device.
            encoding = {k: v.to(device) for k, v in encoding.items()}

            with torch.no_grad():
                output = model(**encoding)
                # A single text was encoded, so row 0 holds its per-token scores.
                scores = output.logits[0].cpu().numpy().tolist()

            entities_predicted = tokenizer.convert_bert_output_to_entities(text, scores, spans)
            text_entities.append({'text': text, 'entities_predicted': entities_predicted})
        text_entities_set.append(text_entities)
    return text_entities_set
|
|
|
|
def combine_sentences(text_entities_set, id_to_tags, insert: str = '\n'):
    """Render each document's sentences as tagged text and join them.

    Args:
        text_entities_set: list of documents, each a list of sentence dicts
            (``{'text': ..., 'entities_predicted': ...}``) as built by
            ``predict_entities``.
        id_to_tags: type-id -> label mapping forwarded to ``to_xml``.
        insert: separator placed between the sentences of one document
            (a newline by default).

    Returns:
        One joined string per document, in input order.
    """
    documents = []
    for text_entities in text_entities_set:
        rendered = []
        for t in text_entities:
            xml = to_xml(t, id_to_tags)
            # Fall back to the untagged sentence so that one sentence that
            # to_xml could not render does not make the join fail.
            rendered.append(xml if xml is not None else t['text'])
        documents.append(insert.join(rendered))
    return documents
|
|
|
|
def value_to_key(value, key_attr):
    """Return the attribute key whose value list contains ``value``.

    Hits are memoized in the module-level ``dict_key`` dict. That dict is
    keyed by value only, so it assumes a single ``key_attr`` mapping per run.

    Args:
        value: attribute value to look up.
        key_attr: mapping attribute key -> iterable of allowed values.

    Returns:
        The first matching key in ``key_attr`` iteration order, or ``None`` if
        no key lists ``value``. Misses are not memoized.
    """
    cached = dict_key.get(value)
    if cached is not None:
        return cached
    for key, values in key_attr.items():
        # Element-wise equality (not ``in``) so a str "values" entry is
        # compared per element exactly as before, never as a substring.
        if any(value == v for v in values):
            dict_key[value] = key
            return key
    return None
|
|
|
|
| |
def _build_normalizer(dict_path, candidate_col, normalization_col, matching_threshold, default_dict_cls):
    """Create an EntityNormalizer over a user dictionary file, or ``default_dict_cls()`` if none is given."""
    if dict_path:
        dictionary = EntityDictionary(dict_path, candidate_col, normalization_col)
    else:
        dictionary = default_dict_cls()
    return EntityNormalizer(dictionary, matching_threshold=matching_threshold)


def normalize_entities(text_entities_set, id_to_tags, disease_dict=None, disease_candidate_col=None,
                       disease_normalization_col=None, disease_matching_threshold=None, drug_dict=None,
                       drug_candidate_col=None, drug_normalization_col=None, drug_matching_threshold=None):
    """Attach a dictionary normalization to every disease and drug entity, in place.

    Entities whose label prefix (the part of ``id_to_tags[type_id]`` before the
    first ``'_'``) is ``'d'`` use the disease normalizer. Prefix ``'m-key'``
    uses the drug normalizer. Each such entity dict gains a ``'norm'`` string.
    All other entities, ``""`` placeholders, and unknown type ids are left
    untouched.

    Args:
        text_entities_set: list of documents, each a list of
            ``{'text': ..., 'entities_predicted': [...]}`` dicts.
        id_to_tags: type-id -> label mapping.
        disease_dict / drug_dict: optional path to a custom dictionary file.
            When falsy, the built-in default dictionary is used.
        *_candidate_col / *_normalization_col: column selectors forwarded to
            ``EntityDictionary`` (only used with a custom dictionary).
        *_matching_threshold: forwarded to ``EntityNormalizer``.
    """
    disease_normalizer = _build_normalizer(disease_dict, disease_candidate_col, disease_normalization_col,
                                           disease_matching_threshold, DefaultDiseaseDict)
    drug_normalizer = _build_normalizer(drug_dict, drug_candidate_col, drug_normalization_col,
                                        drug_matching_threshold, DefaultDrugDict)

    # Label prefix -> normalizer; any other prefix is not normalized.
    normalizer_by_tag = {'m-key': drug_normalizer, 'd': disease_normalizer}

    for entry in tqdm(text_entities_set, desc='Normalization', leave=False):
        for text_entities in entry:
            for entity in text_entities['entities_predicted']:
                # Skip placeholders and unknown type ids (as to_xml does) rather
                # than aborting the whole file.
                if entity == "":
                    continue
                try:
                    label = id_to_tags[entity['type_id']]
                except (KeyError, IndexError):
                    continue

                normalizer = normalizer_by_tag.get(label.split('_')[0])
                if normalizer is None:
                    continue

                # presumably 'name' is the entity's surface string — confirm
                # against NER_medNLP.convert_bert_output_to_entities.
                normalization, _score = normalizer.normalize(entity['name'])
                entity['norm'] = str(normalization)
|
|
|
|
def _list_input_files(input_path):
    """Return the regular files directly inside ``input_path`` (sorted), or ``[input_path]`` for a file."""
    if os.path.isdir(input_path):
        return sorted(
            os.path.join(input_path, name)
            for name in os.listdir(input_path)
            if os.path.isfile(os.path.join(input_path, name))
        )
    return [input_path]


def _output_path_for(file, input_path, output):
    """Map an input file to its destination path.

    For a directory input, the file's path relative to that directory is
    recreated under ``output``. For a single-file input, the result is written
    into ``output`` when it is an existing directory, otherwise to ``output``
    itself.
    """
    if os.path.isdir(input_path):
        return os.path.join(output, os.path.relpath(file, input_path))
    if os.path.isdir(output):
        return os.path.join(output, os.path.basename(file))
    return output


def run(model, input, output=None, normalize=False, **kwargs):
    """Tag entities in a text file, or in every file of a directory, then print and optionally save.

    Args:
        model: checkpoint path passed to
            ``BertForTokenClassification_pl.from_pretrained_bin``.
        input: path to a text file, or to a directory whose direct child files
            are processed (subdirectories are ignored).
        output: optional destination file or directory for the tagged text.
            Missing parent directories are created.
        normalize: when true, call ``normalize_entities`` before rendering.
        **kwargs: forwarded to ``normalize_entities`` (dictionary paths,
            column selectors, matching thresholds).

    An error while processing one file is reported, and processing continues
    with the next file.
    """
    # NOTE(review): pickle.load can execute code; id_to_tags.pkl must be trusted.
    with open("id_to_tags.pkl", "rb") as tf:
        id_to_tags = pickle.load(tf)
    len_num_entity_type = len(id_to_tags)

    # 2 labels per entity type plus one shared label (presumably B/I + O for
    # the BIO tokenizer below).
    classification_model = ner.BertForTokenClassification_pl.from_pretrained_bin(
        model_path=model, num_labels=2 * len_num_entity_type + 1)
    bert_tc = classification_model.bert_tc.to(device)

    tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
        'tohoku-nlp/bert-base-japanese-whole-word-masking',
        num_entity_type=len_num_entity_type
    )

    for file in tqdm(_list_input_files(input), desc="Input file"):
        try:
            # Explicit UTF-8 instead of the platform-dependent locale default.
            with open(file, encoding='utf-8') as f:
                articles_raw = f.read()

            # The model sees NFKC-normalized text.
            article_norm = unicodedata.normalize('NFKC', articles_raw)
            sentences_raw = utils.split_sentences(articles_raw)
            sentences_norm = utils.split_sentences(article_norm)

            text_entities_set = predict_entities(bert_tc, tokenizer, [sentences_norm])

            # Show the original sentences where safe. Entity spans index the
            # NFKC text, and NFKC can change string length (and thus sentence
            # splits). So a raw sentence replaces its normalized counterpart
            # only when sentence counts match and the lengths are equal.
            if len(sentences_raw) == len(sentences_norm):
                for texts_ent, raw in zip(text_entities_set[0], sentences_raw):
                    if len(raw) == len(texts_ent['text']):
                        texts_ent['text'] = raw

            if normalize:
                normalize_entities(text_entities_set, id_to_tags, **kwargs)

            documents = combine_sentences(text_entities_set, id_to_tags, '\n')

            tqdm.write(f"File: {file}")
            tqdm.write(documents[0])
            tqdm.write("")

            if output:
                out_path = _output_path_for(file, input, output)
                out_dir = os.path.dirname(out_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                with open(out_path, 'w', encoding='utf-8') as f:
                    f.write(documents[0])

        except Exception as e:
            # Batch boundary: report the failure and move on to the next file.
            tqdm.write("Error while processing file: {}".format(file))
            tqdm.write(str(e))
            tqdm.write("")
|
|
|
|
def build_arg_parser():
    """Build the command-line parser for this script.

    The option destinations match the keyword parameters of ``run`` (and, via
    its ``**kwargs``, of ``normalize_entities``). This lets ``vars(args)`` be
    passed straight to ``run``.
    """
    parser = argparse.ArgumentParser(description='Predict entities from text')
    parser.add_argument('-m', '--model', type=str, default='pytorch_model.bin', help='Path to model checkpoint')
    parser.add_argument('-i', '--input', type=str, default='text.txt', help='Path to text file or directory')
    parser.add_argument('-o', '--output', type=str, default=None, help='Path to output file or directory')
    parser.add_argument('-n', '--normalize', action=argparse.BooleanOptionalAction, help='Enable entity normalization', default=False)

    # Drug dictionary options.
    parser.add_argument("--drug-dict", help="File path for overriding the default drug dictionary")
    parser.add_argument("--drug-candidate-col", type=int, help="Column index for drug candidates in the CSV file (required if --drug-dict is specified)")
    parser.add_argument("--drug-normalization-col", type=int, help="Column index for drug normalization in the CSV file (required if --drug-dict is specified)")
    parser.add_argument('--drug-matching-threshold', type=int, default=50, help='Matching threshold for drug dictionary')

    # Disease dictionary options.
    parser.add_argument("--disease-dict", help="File path for overriding the default disease dictionary")
    parser.add_argument("--disease-candidate-col", type=int, help="Column index for disease candidates in the CSV file (required if --disease-dict is specified)")
    parser.add_argument("--disease-normalization-col", type=int, help="Column index for disease normalization in the CSV file (required if --disease-dict is specified)")
    parser.add_argument('--disease-matching-threshold', type=int, default=50, help='Matching threshold for disease dictionary')
    return parser


def check_dictionary_args(parser, args):
    """Exit with a usage error when a custom dictionary is given without both column options."""
    for kind in ('drug', 'disease'):
        # Same truthiness test that normalize_entities uses to pick a custom dictionary.
        if not getattr(args, f'{kind}_dict'):
            continue
        missing = [
            f'--{kind}-{col}-col'
            for col in ('candidate', 'normalization')
            if getattr(args, f'{kind}_{col}_col') is None
        ]
        if missing:
            parser.error(f"{', '.join(missing)} required when --{kind}-dict is specified")


if __name__ == '__main__':
    parser = build_arg_parser()
    args = parser.parse_args()
    check_dictionary_args(parser, args)
    run(**vars(args))
|
|