Spaces:
Build error
Build error
| import torch | |
| import streamlit as st | |
| from streamlit import components | |
| import pandas as pd | |
| from transformers import BartTokenizer, BartForConditionalGeneration | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
| import evaluate | |
| from datasets import load_dataset | |
| from transformers import AutoTokenizer, LongT5ForConditionalGeneration | |
| import numpy as np | |
| from math import ceil | |
| import en_core_web_lg | |
| from collections import Counter | |
| from string import punctuation | |
| # Gensim | |
| import gensim | |
| from gensim.summarization import summarize | |
| import spacy | |
# spaCy large English pipeline.
# NOTE(review): `nlp` is not referenced by the visible code; confirm before removing.
nlp = en_core_web_lg.load()

# Page chrome: wide layout, page title, and a short sidebar blurb.
st.set_page_config(
    page_title='Clinical Note Summarization',
    layout='wide',
)
st.title('Clinical Note Summarization')
st.sidebar.markdown('Using transformer model')
## Loading in dataset
df = pd.read_csv("demo_shpi_w_rouge25Nov.csv")
# Admission IDs can come back as floats ("123.0"). Drop only a *trailing* ".0"
# so the IDs display and compare as plain integer strings; a bare
# replace('.0', '') would also strip ".0" occurring anywhere in the value.
df['HADM_ID'] = df['HADM_ID'].astype(str).str.replace(r'\.0$', '', regex=True)
# Rename the source columns to the names used by the rest of the app.
df.rename(columns={'SUBJECT_ID': 'Patient_ID',
                   'HADM_ID': 'Admission_ID',
                   'hpi_input_text': 'Original_Text',
                   'hpi_reference_summary': 'Reference_text'}, inplace=True)

# Filter selection: pick a patient, then one of that patient's admissions.
st.sidebar.header("Search for Patient:")
# A patient with several admissions has several rows; list each ID only once.
patientid = df['Patient_ID'].unique()
patient = st.sidebar.selectbox('Select Patient ID:', patientid)
admissionid = df.loc[df['Patient_ID'] == patient, 'Admission_ID']
HospitalAdmission = st.sidebar.selectbox('Select Hospital Admission ID:', admissionid)
| #Another way to for filter selection | |
| #patient = st.sidebar.multiselect( | |
| # "Select Patient ID:", | |
| # options=df['Patient_ID'].unique(), | |
| # default= None | |
| #) | |
| #HospitalAdmission = st.sidebar.multiselect( | |
| # "Select Hospital Admission ID:", | |
| # options=df['Admission_ID'].unique(), | |
| # #default=df['Admission_ID'].unique() | |
| # default = None | |
| #) | |
# Summarization models offered in the sidebar; run_model() dispatches on this name.
model = st.sidebar.selectbox(
    'Select Model',
    ('BART', 'BERT', 'BertGPT2', 'Gensim', 'LexRank', 'Long T5', 'Luhn',
     'Pysummarization', 'SBERT Summary Tokenizer', 'T5', 'T5 Seq2Seq',
     'T5-Base', 'TextRank'),
)

# Beam-search settings shared by every model.
_num_beams = 4
_no_repeat_ngram_size = 3
_early_stopping = True
# BART gets shorter summaries and a milder length penalty than the rest.
if model == 'BART':
    _length_penalty, _min_length, _max_length = 1, 12, 128
else:
    _length_penalty, _min_length, _max_length = 2, 30, 200
# Echo the current selection. st.write returns None, so the result is not
# assigned back over the selection variables.
col3, col4 = st.columns(2)
col3.write(f"Patient ID: {patient} ")
col4.write(f"Admission ID: {HospitalAdmission} ")

# User-adjustable summary length bounds, seeded from the per-model defaults.
col1, col2 = st.columns(2)
_min_length = col1.number_input("Minimum Length", value=_min_length, min_value=1)
_max_length = col2.number_input("Maximum Length", value=_max_length, min_value=1)

# Query out the rows for the selected patient/admission pair.
original_text = df.query(
    "Patient_ID == @patient & Admission_ID == @HospitalAdmission"
)
original_text2 = original_text['Original_Text'].values
# str(ndarray) renders "['...']" with each element's repr, so brackets, quotes
# and literal "\n" escapes would end up in the editable note. Join the raw
# strings instead so the text area holds plain text.
runtext = st.text_area(
    'Input Clinical Note here:',
    '\n'.join(str(note) for note in original_text2),
    height=300,
)
reference_text = original_text['Reference_text'].values
def visualize(sentence_list, best_sentences):
    """Return HTML for the note with the "best" sentences highlighted.

    Args:
        sentence_list: Either the note as one string (split into sentences on
            ``.``, ``!`` or ``?`` followed by whitespace) or an iterable of
            sentences used as-is.
        best_sentences: Sentences to highlight. A string, or an iterable of
            strings (e.g. a pandas ``.values`` array); every entry is split
            into sentences the same way. ``None`` highlights nothing.

    Returns:
        A string where each sentence is prefixed by a space and HTML-escaped;
        highlighted sentences are wrapped in a yellow ``<span>``. Meant for
        ``st.markdown(..., unsafe_allow_html=True)``.
    """
    import html
    import re

    def _split(text):
        # Break after sentence-ending punctuation followed by any whitespace
        # (newlines included), dropping empty fragments.
        parts = re.split(r'(?<=[.!?])\s+', str(text).strip())
        return [part.strip() for part in parts if part.strip()]

    # Iterating a plain string would walk its characters, so split it first.
    if isinstance(sentence_list, str):
        sentences = _split(sentence_list)
    else:
        sentences = [str(sentence) for sentence in sentence_list]

    if best_sentences is None:
        best = set()
    elif isinstance(best_sentences, str):
        best = set(_split(best_sentences))
    else:
        best = {part for item in best_sentences for part in _split(item)}

    pieces = []
    for sentence in sentences:
        # Escape the note text: it is rendered with unsafe_allow_html.
        rendered = html.escape(sentence)
        if sentence.strip() in best:
            # Inline style so the highlight shows without extra page CSS.
            rendered = (
                "<span class='highlight yellow' style='background-color: #FFFF00'>"
                f"{rendered}</span>"
            )
        pieces.append(' ' + rendered)
    return ''.join(pieces)
| #===== Pysummarization ===== | |
| from pysummarization.nlpbase.auto_abstractor import AutoAbstractor | |
| from pysummarization.tokenizabledoc.simple_tokenizer import SimpleTokenizer | |
| from pysummarization.abstractabledoc.top_n_rank_abstractor import TopNRankAbstractor | |
| import regex as re | |
# Module-level pysummarization objects used by pysummarizer().
abstractable_doc = TopNRankAbstractor()
auto_abstractor = AutoAbstractor()
# Tokenizer and sentence delimiters (period and newline) for the abstractor.
auto_abstractor.tokenizable_doc = SimpleTokenizer()
auto_abstractor.delimiter_list = [".", "\n"]
def pysummarizer(input_text):
    """Return an extractive summary of *input_text* via pysummarization.

    Uses the module-level ``auto_abstractor`` and ``abstractable_doc`` objects.

    Args:
        input_text: Note text to summarize (converted with ``str``).

    Returns:
        The selected sentences with internal whitespace collapsed, joined by a
        single space; an empty string when the input is blank.
    """
    text = str(input_text)
    if not text.strip():
        return ''
    summary = auto_abstractor.summarize(text, abstractable_doc)
    # Collapse every run of whitespace (including newline delimiters) inside
    # each selected sentence to one space and trim the ends.
    best_sentences = [' '.join(sentence.split()) for sentence in summary['summarize_result']]
    # Separate sentences with a space; joining with '' fused the end of one
    # sentence onto the start of the next.
    return ' '.join(sentence for sentence in best_sentences if sentence)
| ##===== BERT Summary tokenizer ===== | |
def BertSummarizer(input_text, ratio=0.4):
    """Return an extractive BERT summary of *input_text*.

    Args:
        input_text: Note text to summarize.
        ratio: Forwarded unchanged to the ``summarizer.Summarizer`` call
            (presumably the fraction of sentences kept; defaults to 0.4).

    Returns:
        Whatever ``Summarizer()(input_text, ratio=ratio)`` returns.
    """
    from summarizer import Summarizer

    # Constructing Summarizer() builds its underlying model, so create exactly
    # one instance; a second, unused instance only doubled the load time.
    bert_model = Summarizer()
    return bert_model(input_text, ratio=ratio)
##===== SBERT =====
from summarizer.sbert import SBertSummarizer

# Sentence-BERT summarizer, created once at module level and reused by Sbert().
Sbertmodel = SBertSummarizer('paraphrase-MiniLM-L6-v2')


def Sbert(input_text):
    """Return the SBERT summary of *input_text* (called with ``ratio=0.4``)."""
    return Sbertmodel(input_text, ratio=0.4)
| ##===== T5 Seq2Seq ===== | |
def t5seq2seq(input_text, max_length=150, min_length=40, length_penalty=2.0, num_beams=4):
    """Return an abstractive t5-base summary of *input_text*.

    Args:
        input_text: Note text to summarize (converted with ``str``).
        max_length: Maximum generated length, in tokens.
        min_length: Minimum generated length, in tokens.
        length_penalty: Beam-search length penalty.
        num_beams: Beam width.

    Returns:
        The decoded summary string.
    """
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained("t5-base").to(device)
    # Prefix the input with the "summarize: " task string; keep at most 512 tokens.
    inputs = tokenizer(
        "summarize: " + str(input_text),
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(device)
    outputs = seq2seq.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        min_length=min_length,
        length_penalty=length_penalty,
        num_beams=num_beams,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def BertGPT2(input_text):
    """Return an abstractive summary from a BERT-encoder / GPT-2-decoder model.

    Encodes *input_text* with the Bio_ClinicalBERT tokenizer, generates with
    the ``patrickvonplaten/bert2gpt2-cnn_dailymail-fp16`` EncoderDecoderModel,
    and decodes with the GPT-2 tokenizer.

    Args:
        input_text: Note text to summarize (converted with ``str``).

    Returns:
        The generated summary string.
    """
    from transformers import AutoTokenizer, EncoderDecoderModel, GPT2Tokenizer

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    bert2gpt2 = EncoderDecoderModel.from_pretrained(
        "patrickvonplaten/bert2gpt2-cnn_dailymail-fp16"
    ).to(device)

    # NOTE(review): confirm the Bio_ClinicalBERT vocabulary matches the encoder
    # tokenizer this checkpoint was trained with.
    bert_tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    # CLS / SEP act as BOS / EOS for the encoder input.
    bert_tokenizer.bos_token = bert_tokenizer.cls_token
    bert_tokenizer.eos_token = bert_tokenizer.sep_token

    # The GPT-2 tokenizer is only used to decode here, so its special-token
    # *encoding* is not patched (that patch would mutate the class globally).
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token

    bert2gpt2.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id
    bert2gpt2.config.eos_token_id = gpt2_tokenizer.eos_token_id

    # Cut off at BERT's 512-token maximum.
    inputs = bert_tokenizer(str(input_text), truncation=True, max_length=512, return_tensors="pt")
    # Send tensors to the model's device rather than a hard-coded "cuda", which
    # fails on CPU-only hosts. Beam-search settings are passed to generate();
    # setting them as plain attributes on the model object has no effect.
    outputs = bert2gpt2.generate(
        inputs.input_ids.to(device),
        attention_mask=inputs.attention_mask.to(device),
        max_length=142,
        min_length=56,
        no_repeat_ngram_size=3,
        early_stopping=True,
        length_penalty=2.0,
        num_beams=4,
    )
    # All special tokens are removed from the decoded text.
    return gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _bart_summary(input_text, device):
    """Return a facebook/bart-base summary using the module-level beam settings."""
    bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    # Keep the model on the same device as its inputs; a CPU model fed CUDA
    # tensors fails inside generate() on GPU hosts.
    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(device)
    text = ' '.join(str(input_text).split())
    # Truncate to BART's position-embedding limit so long notes do not overflow it.
    input_ids = bart_tokenizer.encode(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=bart_model.config.max_position_embeddings,
    ).to(device)
    summary_ids = bart_model.generate(input_ids,
                                      num_beams=_num_beams,
                                      no_repeat_ngram_size=_no_repeat_ngram_size,
                                      length_penalty=_length_penalty,
                                      min_length=_min_length,
                                      max_length=_max_length,
                                      early_stopping=_early_stopping)
    return bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True,
                                 clean_up_tokenization_spaces=False)


def _t5_summary(input_text, device):
    """Return a t5-base summary using the module-level beam settings."""
    t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
    t5_model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    # Collapse all whitespace, newlines included, to single spaces (deleting
    # '\n' outright would glue together the words on either side of it).
    text = ' '.join(str(input_text).split())
    input_ids = t5_tokenizer.encode(text, return_tensors="pt").to(device)
    # Prepend fixed task-prefix token ids (presumably "summarize:" in the t5-base
    # vocabulary — verify).
    summary_task = torch.tensor([[21603, 10]], device=device)
    input_ids = torch.cat([summary_task, input_ids], dim=-1)
    summary_ids = t5_model.generate(input_ids,
                                    num_beams=_num_beams,
                                    no_repeat_ngram_size=_no_repeat_ngram_size,
                                    length_penalty=_length_penalty,
                                    min_length=_min_length,
                                    max_length=_max_length,
                                    early_stopping=_early_stopping)
    return t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True,
                               clean_up_tokenization_spaces=False)


def run_model(input_text):
    """Summarize *input_text* with the sidebar-selected model and display it.

    Dispatches on the module-level ``model`` selectbox value and writes the
    result with ``st.success``. Shows a warning when the selected model has no
    implementation, and an error when the summarizer raises ``ValueError``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    summarizers = {
        "BART": lambda text: _bart_summary(text, device),
        "T5": lambda text: _t5_summary(text, device),
        "Gensim": lambda text: summarize(str(text)),
        "Pysummarization": pysummarizer,
        "BERT": BertSummarizer,
        "SBERT Summary Tokenizer": Sbert,
        "T5 Seq2Seq": t5seq2seq,
        # NOTE(review): flagged by the author as "not working correctly".
        "BertGPT2": BertGPT2,
    }
    summarizer = summarizers.get(model)
    if summarizer is None:
        # Some sidebar choices have no implementation; say so instead of
        # silently doing nothing.
        st.warning(f"Summarization with '{model}' is not implemented yet.")
        return
    try:
        output = summarizer(input_text)
    except ValueError as err:
        # Some summarizers reject inputs they cannot handle (e.g. very short
        # notes); report it on the page instead of crashing the app.
        st.error(f"{model} could not summarize this note: {err}")
        return
    st.write('Summary')
    st.success(output)
if st.button('Submit'):
    run_model(runtext)

# Reference summaries may arrive as an array of row values; show them as plain
# text rather than as an array repr.
if isinstance(reference_text, str):
    reference_display = reference_text
else:
    reference_display = '\n'.join(str(ref) for ref in reference_text)
st.text_area('Reference text', reference_display)

# st.text_area has no unsafe_allow_html parameter (passing it raises a
# TypeError), so render the highlighted HTML with st.markdown.
st.markdown(visualize(runtext, reference_text), unsafe_allow_html=True)