Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import sys | |
| import json | |
| import tempfile | |
| import gradio as gr | |
| from transformers import ( | |
| TrainingArguments, | |
| HfArgumentParser, | |
| ) | |
| from robust_deid.ner_datasets import DatasetCreator | |
| from robust_deid.sequence_tagging import SequenceTagger | |
| from robust_deid.sequence_tagging.arguments import ( | |
| ModelArguments, | |
| DataTrainingArguments, | |
| EvaluationArguments, | |
| ) | |
| from robust_deid.deid import TextDeid | |
class App(object):
    """De-identification pipeline wrapper.

    Builds NER datasets from clinical notes (DatasetCreator), runs a
    transformer sequence tagger over them (SequenceTagger), and converts the
    predictions back into de-identified text (TextDeid).

    NOTE(review): depends on the module-level robust_deid / transformers
    imports; constructing an instance downloads/loads model weights.
    """

    def __init__(
        self,
        model,
        threshold,
        span_constraint='super_strict',
        sentencizer='en_core_sci_sm',
        tokenizer='clinical',
        max_tokens=128,
        max_prev_sentence_token=32,
        max_next_sentence_token=32,
        default_chunk_size=32,
        ignore_label='NA'
    ):
        """Configure the dataset creator, argument parsers and tagger.

        Args:
            model: display name of the model — a key of ``_get_model_map()``.
            threshold: ``'No threshold'`` for plain argmax decoding, otherwise
                a recall-level key (``'99.5'`` / ``'99.7'``) into
                ``_get_threshold_map()`` for the chosen model.
            span_constraint: span-merging strategy forwarded to ``TextDeid``.
            sentencizer, tokenizer, max_tokens, max_prev_sentence_token,
            max_next_sentence_token, default_chunk_size, ignore_label:
                dataset-creation knobs forwarded to ``DatasetCreator``.
        """
        # Create the dataset creator object
        self._dataset_creator = DatasetCreator(
            sentencizer=sentencizer,
            tokenizer=tokenizer,
            max_tokens=max_tokens,
            max_prev_sentence_token=max_prev_sentence_token,
            max_next_sentence_token=max_next_sentence_token,
            default_chunk_size=default_chunk_size,
            ignore_label=ignore_label
        )
        parser = HfArgumentParser(
            (ModelArguments, DataTrainingArguments, EvaluationArguments, TrainingArguments)
        )
        model_config = App._get_model_config()
        model_config['model_name_or_path'] = App._get_model_map()[model]
        if threshold == 'No threshold':
            # Plain argmax decoding — no recall-biased cutoff.
            model_config['post_process'] = 'argmax'
            model_config['threshold'] = None
        else:
            model_config['post_process'] = 'threshold_max'
            model_config['threshold'] = \
                App._get_threshold_map()[model_config['model_name_or_path']][threshold]
        # HfArgumentParser reads JSON from a file path, so round-trip the
        # config through a temporary file. delete=False is required so the
        # file survives the `with` on all platforms; unlink it afterwards
        # (the original leaked one temp file per App construction).
        tmp_name = None
        try:
            with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
                tmp_name = tmp.name
                tmp.write(json.dumps(model_config) + '\n')
            self._model_args, self._data_args, self._evaluation_args, self._training_args = \
                parser.parse_json_file(json_file=tmp_name)
        finally:
            if tmp_name is not None:
                os.unlink(tmp_name)
        # Initialize the text deid object
        self._text_deid = TextDeid(notation=self._data_args.notation, span_constraint=span_constraint)
        # Initialize the sequence tagger
        self._sequence_tagger = SequenceTagger(
            task_name=self._data_args.task_name,
            notation=self._data_args.notation,
            ner_types=self._data_args.ner_types,
            model_name_or_path=self._model_args.model_name_or_path,
            config_name=self._model_args.config_name,
            tokenizer_name=self._model_args.tokenizer_name,
            post_process=self._model_args.post_process,
            cache_dir=self._model_args.cache_dir,
            model_revision=self._model_args.model_revision,
            use_auth_token=self._model_args.use_auth_token,
            threshold=self._model_args.threshold,
            do_lower_case=self._data_args.do_lower_case,
            fp16=self._training_args.fp16,
            seed=self._training_args.seed,
            local_rank=self._training_args.local_rank
        )
        # Load the required functions of the sequence tagger (model weights etc.)
        self._sequence_tagger.load()

    def get_ner_dataset(self, notes_file):
        """Split the notes in ``notes_file`` (JSONL) into NER-ready sentences.

        Returns an iterable of sentence records produced by the dataset
        creator in 'predict' mode.
        """
        ner_notes = self._dataset_creator.create(
            input_file=notes_file,
            mode='predict',
            notation=self._data_args.notation,
            token_text_key='text',
            metadata_key='meta',
            note_id_key='note_id',
            label_key='label',
            span_text_key='spans'
        )
        return ner_notes

    def get_predictions(self, ner_notes_file):
        """Run the sequence tagger over ``ner_notes_file`` and return predictions."""
        self._sequence_tagger.set_predict(
            test_file=ner_notes_file,
            max_test_samples=self._data_args.max_predict_samples,
            preprocessing_num_workers=self._data_args.preprocessing_num_workers,
            overwrite_cache=self._data_args.overwrite_cache
        )
        self._sequence_tagger.setup_trainer(training_args=self._training_args)
        return self._sequence_tagger.predict()

    def get_deid_text_removed(self, notes_file, predictions_file):
        """De-identify notes by removing every predicted PHI span."""
        return self._text_deid.run_deid(
            input_file=notes_file,
            predictions_file=predictions_file,
            deid_strategy='remove',
            keep_age=False,
            metadata_key='meta',
            note_id_key='note_id',
            tokens_key='tokens',
            predictions_key='predictions',
            text_key='text',
        )

    def get_deid_text_replaced(self, notes_file, predictions_file):
        """De-identify notes by replacing PHI spans with informative tags
        (``<<TYPE:...>>`` markers, later parsed by ``_get_highlights``)."""
        return self._text_deid.run_deid(
            input_file=notes_file,
            predictions_file=predictions_file,
            deid_strategy='replace_informative',
            keep_age=False,
            metadata_key='meta',
            note_id_key='note_id',
            tokens_key='tokens',
            predictions_key='predictions',
            text_key='text',
        )

    @staticmethod
    def _get_highlights(deid_text):
        """Yield ``(text, label)`` pairs for highlighted display.

        Unlabeled stretches are yielded as ``(text, None)``; each
        ``<<TYPE:surrogate>>`` marker is yielded as ``(surrogate, TYPE)``.
        Fix: the original referenced ``full_end`` after the loop, which raised
        NameError when the text contained no PHI markers; we track
        ``current_start`` instead so unmarked text yields a single pair.
        """
        pattern = re.compile('<<(PATIENT|STAFF|AGE|DATE|LOCATION|PHONE|ID|EMAIL|PATORG|HOSPITAL|OTHERPHI):(.)*?>>')
        tag_pattern = re.compile('<<(PATIENT|STAFF|AGE|DATE|LOCATION|PHONE|ID|EMAIL|PATORG|HOSPITAL|OTHERPHI):')
        current_start = 0
        for match in pattern.finditer(deid_text):
            full_start, full_end = match.span()
            sub_match = tag_pattern.search(deid_text[full_start:full_end])
            tag_start, tag_end = sub_match.span()
            # Text preceding this PHI marker is unlabeled.
            yield (deid_text[current_start:full_start], None)
            # Surrogate text (between '<<TYPE:' and the trailing '>>'),
            # labeled with TYPE (strip the leading '<<' and trailing ':').
            yield (
                deid_text[full_start + tag_end:full_end - 2],
                sub_match.string[tag_start + 2:tag_end - 1],
            )
            current_start = full_end
        # Trailing unlabeled text (the whole string when no markers matched).
        yield (deid_text[current_start:], None)

    @staticmethod
    def _get_model_map():
        """Map UI display names to Hugging Face model identifiers."""
        return {
            'OBI-RoBERTa De-ID': 'obi/deid_roberta_i2b2',
            'OBI-ClinicalBERT De-ID': 'obi/deid_bert_i2b2'
        }

    @staticmethod
    def _get_threshold_map():
        """Per-model probability cutoffs for the recall-biased
        'threshold_max' post-processing, keyed by recall level."""
        return {
            'obi/deid_bert_i2b2': {"99.5": 4.656325975101986e-06, "99.7": 1.8982457699258832e-06},
            'obi/deid_roberta_i2b2': {"99.5": 2.4362972672812125e-05, "99.7": 2.396420546444644e-06}
        }

    @staticmethod
    def _get_model_config():
        """Base HF argument config; 'post_process', 'threshold' and
        'model_name_or_path' are filled in by ``__init__``."""
        return {
            "post_process": None,
            "threshold": None,
            "model_name_or_path": None,
            "task_name": "ner",
            "notation": "BILOU",
            "ner_types": ["PATIENT", "STAFF", "AGE", "DATE", "PHONE", "ID", "EMAIL", "PATORG", "LOC", "HOSP", "OTHERPHI"],
            "truncation": True,
            "max_length": 512,
            "label_all_tokens": False,
            "return_entity_level_metrics": True,
            "text_column_name": "tokens",
            "label_column_name": "labels",
            "output_dir": "./run/models",
            "logging_dir": "./run/logs",
            "overwrite_output_dir": False,
            "do_train": False,
            "do_eval": False,
            "do_predict": True,
            "report_to": [],
            "per_device_train_batch_size": 0,
            "per_device_eval_batch_size": 16,
            "logging_steps": 1000
        }
| def deid(text, model, threshold): | |
| notes = [{"text": text, "meta": {"note_id": "note_1", "patient_id": "patient_1"}, "spans": []}] | |
| app = App(model, threshold) | |
| # Create temp notes file | |
| with tempfile.NamedTemporaryFile("w+", delete=False) as tmp: | |
| for note in notes: | |
| tmp.write(json.dumps(note) + '\n') | |
| tmp.seek(0) | |
| ner_notes = app.get_ner_dataset(tmp.name) | |
| # Create temp ner_notes file | |
| with tempfile.NamedTemporaryFile("w+", delete=False) as tmp: | |
| for ner_sentence in ner_notes: | |
| tmp.write(json.dumps(ner_sentence) + '\n') |