Spaces:
Build error
Build error
# Load the packages
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, BartTokenizer, BartForConditionalGeneration
import spacy
import spacy.cli
from spacy import displacy

# Load the spaCy English pipeline, downloading it only when it is missing.
# The original code downloaded the model unconditionally on every Streamlit
# rerun and then loaded it twice back-to-back.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
# ---Sidebar Design-----
st.sidebar.subheader("Select from the dropdown list")  # subheader of the sidebar
st.sidebar.text("")  # add line space
option_lang = st.sidebar.selectbox(
    'What is your native language?',
    ('Japanese', 'Mandarin'))  # dropdown list for native languages
st.sidebar.write('You selected:', option_lang)  # display the selected native language
st.sidebar.text("")  # add line space
option_model = st.sidebar.selectbox(
    'Which language model would you like to use?',  # fixed grammar: missing "you"
    ('GPT-2', 'BART'))  # dropdown list for language models
st.sidebar.write('You selected:', option_model)  # display the selected language model

# ---Main Body Design-----
st.title('Make Friends with English ๐ค')  # title of the web app
st.text("")  # add line space
st.markdown('This web app is designed for ESL speakers who may face difficulty in communicating context in English.')
st.text("")  # add line space
st.markdown('<p style="font-size:20px;"><strong>Enter your sentence ๐</strong></p>',unsafe_allow_html=True)  # subtitle
original = st.text_input('', '',label_visibility="collapsed")  # textbox for the sentence to correct
go = st.button('Generate')  # button that triggers the selected language model
# Map the (model, native language) choice to the fine-tuned model and
# tokenizer repositories on the Hugging Face Hub.
_MODEL_REPOS = {
    ('GPT-2', 'Japanese'): ("amyyang/80K-GPT2-v2", "amyyang/token-80K-GPT2-v2"),
    ('GPT-2', 'Mandarin'): ("amyyang/40K-GPT2-MDN-v2", "amyyang/token-40K-GPT2-MDN-v2"),
    ('BART', 'Mandarin'): ("amyyang/60K-BART-MDN-v2", "amyyang/token-60K-BART-MDN-v2"),
}
# Any other combination (i.e. BART + Japanese) falls back to the default pair.
model_dir, token_dir = _MODEL_REPOS.get(
    (option_model, option_lang),
    ("amyyang/80K-BART-v2", "amyyang/token-80K-BART-v2"),
)
# Choose the compute device: prefer CUDA when available, otherwise CPU.
# (An MPS branch existed in the original but was commented out.)
if torch.cuda.is_available():
    dev = "cuda:0"
    print("This model will run on CUDA")
else:
    dev = "cpu"
    print("This model will run on CPU")
device = torch.device(dev)
# Define the function to generate corrected sentence using GPT-2 model
from functools import lru_cache  # stdlib; used to cache the loaded model


@lru_cache(maxsize=2)
def _load_gpt2(model_repo, token_repo):
    """Load (and cache) the fine-tuned GPT-2 model and its tokenizer.

    Caching prevents re-downloading and re-instantiating the model on every
    Streamlit rerun / button press, which the original code did.
    """
    model = GPT2LMHeadModel.from_pretrained(model_repo).to(device)
    model.eval()  # inference only — disables dropout
    tokenizer = GPT2Tokenizer.from_pretrained(token_repo)
    return model, tokenizer


def generate_prediction(prompt, max_length=100, temperature=1.0, top_p=1.0):
    """Generate a corrected sentence for *prompt* with the fine-tuned GPT-2.

    Parameters
    ----------
    prompt : str
        Text of the form "input: <sentence> output:".
    max_length : int
        Maximum total length, in tokens, of the generated sequence.
    temperature, top_p : float
        Sampling hyper-parameters.

    Returns
    -------
    str
        The decoded generation (includes the prompt text; the caller strips
        everything up to "output:").
    """
    model, tokenizer = _load_gpt2(model_dir, token_dir)
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,  # bug fix: without sampling, temperature/top_p are silently ignored
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; avoids a warning
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Define the function to extract the output (corrected sentence)
def model_running(model):
    """Run the selected language model on the user's sentence.

    Parameters
    ----------
    model : str
        Model chosen in the sidebar, either 'GPT-2' or 'BART'.

    Returns
    -------
    str or None
        The corrected sentence, or None when the Generate button has not
        been pressed or an exception was raised (and displayed).
    """
    if go and model == 'GPT-2':
        try:
            tokenizer = GPT2Tokenizer.from_pretrained(token_dir)
            prompt = f"input: {original} output:"
            prompt_length = len(tokenizer.encode(prompt))
            # Scale the generation budget with the input length.
            dynamic_max_length = int(1.5 * len(original.split())) + prompt_length
            # Generate prediction
            prediction = generate_prediction(prompt, max_length=dynamic_max_length, temperature=0.8, top_p=0.8)
            # Extract the text after the first "output:" marker. maxsplit=1
            # guards against the generation itself containing "output:" again
            # (the original split would then return only the middle segment).
            return prediction.split("output:", 1)[1].strip()
        except Exception as e:
            # Bug fix: st.exception expects the exception object, not a string.
            st.exception(e)
    elif go and model == 'BART':
        try:
            # Renamed to avoid shadowing the `model` parameter.
            bart_model = BartForConditionalGeneration.from_pretrained(model_dir)
            bart_tokenizer = BartTokenizer.from_pretrained(token_dir)
            # Tokenize the input text
            input_ids = bart_tokenizer.encode(original, return_tensors='pt')
            # Generate text with the fine-tuned BART model
            output_ids = bart_model.generate(input_ids)
            # Decode the output text
            return bart_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        except Exception as e:
            st.exception(e)
output = model_running(option_model)

# Show a hint until the user has entered a sentence and clicked Generate.
if output is None:
    st.markdown('<span style="color: #FF4500;">Note: Please enter your sentence and click **Generate** button!</span>',unsafe_allow_html=True)
else:
    st.text("")
    st.markdown('<p style="font-size:20px;"><strong>Recommended sentence ๐ก</strong></p>',unsafe_allow_html=True)  # subtitle
    st.text(output)  # display the corrected sentence
    st.text("")  # add line space
    st.markdown('<p style="font-size:20px;"><strong>Part-of-speech Tagging ๐ท</strong></p>',unsafe_allow_html=True)  # subtitle

# HTML wrapper that holds the rendered dependency-tree SVG.
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""

# POS tags and dependency tree share the same guard, so parse the output
# once and reuse the Doc (the original ran nlp(output) twice).
if original != '' and output is not None:
    doc = nlp(output)
    for token in doc:
        st.write(token, token.pos_)
    st.text("")  # add line space
    st.markdown('<p style="font-size:20px;"><strong>Dependency Tree ๐ณ</strong></p>',unsafe_allow_html=True)  # subtitle
    docs = [span.as_doc() for span in doc.sents]
    html = displacy.render(docs, style='dep')
    st.write(HTML_WRAPPER.format(html), unsafe_allow_html=True)

st.markdown('___')
st.markdown('by [A very beta ChatGPT-4.5](https://github.com/danish-sven/anlp-at2-gpt45/)')  # author credit
# NOTE: The commented-out code below is an earlier inline version of model_running()
# (it references an obsolete `output_dir` variable); kept for reference only.
| # if go and option_model=='GPT-2': | |
| # try: | |
| # model = GPT2LMHeadModel.from_pretrained(output_dir).to(device) | |
| # tokenizer = GPT2Tokenizer.from_pretrained(output_dir) | |
| # def generate_prediction(prompt, max_length=100, temperature=1.0, top_p=1.0): | |
| # input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device) | |
| # attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device) | |
| # with torch.no_grad(): | |
| # output = model.generate( | |
| # input_ids, | |
| # attention_mask=attention_mask, | |
| # max_length=max_length, | |
| # num_return_sequences=1, | |
| # no_repeat_ngram_size=2, | |
| # temperature=temperature, | |
| # top_p=top_p, | |
| # ) | |
| # return tokenizer.decode(output[0], skip_special_tokens=True) | |
| # # Set max_length dynamically based on the length of the original text | |
| # prompt = f"input: {original} output:" | |
| # prompt_length = len(tokenizer.encode(prompt)) | |
| # dynamic_max_length = int(1.5 * len(original.split())) + prompt_length | |
| # # Generate prediction | |
| # prediction = generate_prediction(prompt, max_length=dynamic_max_length, temperature=0.8, top_p=0.8) | |
| # # Extract the actual generated output | |
| # generated_output = prediction.split("output:")[1].strip() | |
| # st.text(generated_output) | |
| # except Exception as e: | |
| # st.exception("Exception: %s\n" % e) | |
| # elif go and option_model=='BART': | |
| # try: | |
| # model = BartForConditionalGeneration.from_pretrained(output_dir) | |
| # tokenizer = BartTokenizer.from_pretrained(output_dir) | |
| # # Tokenize the input text | |
| # input_ids = tokenizer.encode(original, return_tensors='pt') | |
| # # Generate text with the fine-tuned BART model | |
| # output_ids = model.generate(input_ids) | |
| # # Decode the output text | |
| # generated_output = tokenizer.decode(output_ids[0], skip_special_tokens=True) | |
| # st.text(generated_output) | |
| # except Exception as e: | |
| # st.exception("Exception: %s\n" % e) | |