amyyang committed on
Commit
1d7c7e8
·
1 Parent(s): 2496449

add app.py and requirements.txt

Files changed (2)
  1. app.py +228 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,228 @@
+ # Load the packages
+ import torch
+ import streamlit as st
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, BartTokenizer, BartForConditionalGeneration
+ import spacy
+ import spacy.cli
+ from spacy import displacy
+
+ # Download and load the spaCy English model (loading it once is enough)
+ spacy.cli.download("en_core_web_sm")
+ nlp = spacy.load("en_core_web_sm")
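+
+ # Note: spacy.cli.download() runs on every Streamlit rerun. A guarded sketch
+ # that only downloads the model when it is missing could look like:
+ # try:
+ #     nlp = spacy.load("en_core_web_sm")
+ # except OSError:
+ #     spacy.cli.download("en_core_web_sm")
+ #     nlp = spacy.load("en_core_web_sm")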
+
+ #----- Sidebar Design -----
+
+ st.sidebar.subheader("Select from the dropdown list")  # add the subheader of the sidebar
+ st.sidebar.text("")  # add line space
+
+ option_lang = st.sidebar.selectbox(
+     'What is your native language?',
+     ('Japanese', 'Mandarin'))  # add a dropdown list for native languages
+
+ st.sidebar.write('You selected:', option_lang)  # display the selected native language
+
+ st.sidebar.text("")  # add line space
+
+ option_model = st.sidebar.selectbox(
+     'Which language model would you like to use?',
+     ('GPT-2', 'BART'))  # add a dropdown list for language models
+
+ st.sidebar.write('You selected:', option_model)  # display the selected language model
+
+ #----- Main Body Design -----
+
+ st.title('Make Friends with English 🤝')  # add a title for the web app
+
+ st.text("")  # add line space
+
+ st.markdown('This web app is designed for ESL speakers who may have difficulty communicating in English.')
+ st.text("")  # add line space
+
+ st.markdown('<p style="font-size:20px;"><strong>Enter your sentence 👇</strong></p>', unsafe_allow_html=True)  # add a subtitle
+
+ original = st.text_input('Your sentence', '', label_visibility="collapsed")  # add a textbox for the original sentence
+
+ go = st.button('Generate')  # add a 'Generate' button to run the selected language model
+
+ # Define the model directory
+ if option_model == 'GPT-2':
+     output_dir = "7. Models/80K_GPT2_v2/"
+ else:
+     output_dir = "7. Models/80K_BART_v2/"
+
+ # Assign the device to use for inference
+ if torch.cuda.is_available():
+     dev = "cuda:0"
+     print("This model will run on CUDA")
+ # elif torch.backends.mps.is_available():
+ #     dev = "mps:0"
+ #     print("This model will run on MPS")
+ else:
+     dev = "cpu"
+     print("This model will run on CPU")
+ device = torch.device(dev)
+
+ # Define the function to generate a corrected sentence with the GPT-2 model
+ def generate_prediction(prompt, max_length=100, temperature=1.0, top_p=1.0):
+     model = GPT2LMHeadModel.from_pretrained(output_dir).to(device)
+     tokenizer = GPT2Tokenizer.from_pretrained(output_dir)
+     input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+     attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
+     with torch.no_grad():
+         output = model.generate(
+             input_ids,
+             attention_mask=attention_mask,
+             max_length=max_length,
+             num_return_sequences=1,
+             no_repeat_ngram_size=2,
+             do_sample=True,  # required for temperature and top_p to take effect
+             temperature=temperature,
+             top_p=top_p,
+         )
+     return tokenizer.decode(output[0], skip_special_tokens=True)
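+
+ # Note: the model and tokenizer above are reloaded from disk on every click.
+ # A minimal caching sketch (assuming st.cache_resource, which is available in
+ # the pinned streamlit==1.22.0) would load them once per session:
+ # @st.cache_resource
+ # def load_gpt2(model_dir):
+ #     return GPT2Tokenizer.from_pretrained(model_dir), GPT2LMHeadModel.from_pretrained(model_dir).to(device)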
+
+ # Define the function to run the selected model and extract the output (corrected sentence)
+ def model_running(model):
+     if go and model == 'GPT-2':
+         try:
+             tokenizer = GPT2Tokenizer.from_pretrained(output_dir)
+             prompt = f"input: {original} output:"
+             prompt_length = len(tokenizer.encode(prompt))
+             # Set max_length dynamically based on the length of the original text
+             dynamic_max_length = int(1.5 * len(original.split())) + prompt_length
+
+             # Generate the prediction
+             prediction = generate_prediction(prompt, max_length=dynamic_max_length, temperature=0.8, top_p=0.8)
+
+             # Extract the actual generated output
+             generated_output = prediction.split("output:")[1].strip()
+
+             return generated_output
+
+         except Exception as e:
+             st.exception(e)
+
+     elif go and model == 'BART':
+         try:
+             bart_model = BartForConditionalGeneration.from_pretrained(output_dir).to(device)  # renamed to avoid shadowing the 'model' argument
+             tokenizer = BartTokenizer.from_pretrained(output_dir)
+
+             # Tokenize the input text (placed on the same device as the model)
+             input_ids = tokenizer.encode(original, return_tensors='pt').to(device)
+
+             # Generate text with the fine-tuned BART model
+             output_ids = bart_model.generate(input_ids)
+
+             # Decode the output text
+             generated_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+             return generated_output
+
+         except Exception as e:
+             st.exception(e)
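+ # Note: prediction.split("output:")[1] assumes the generated text still
+ # contains the "output:" marker; a slightly more defensive sketch would be
+ # prediction.rpartition("output:")[2].strip().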
+
+ output = model_running(option_model)
+
+ # Add a warning message based on the output
+ if output is None:
+     st.markdown('<span style="color: #FF4500;">Note: Please enter your sentence and click the **Generate** button!</span>', unsafe_allow_html=True)
+ else:
+     st.text("")
+
+     st.markdown('<p style="font-size:20px;"><strong>Recommended sentence 💡</strong></p>', unsafe_allow_html=True)  # add a subtitle
+
+     st.text(output)  # display the corrected sentence
+
+     st.text("")  # add line space
+
+     st.markdown('<p style="font-size:20px;"><strong>Part-of-speech Tagging 🏷</strong></p>', unsafe_allow_html=True)  # add a subtitle
+
+     # Add the POS tags
+     if original != '' and output is not None:
+         doc = nlp(output)
+         for token in doc:
+             st.write(token.text, token.pos_)
+
+     st.text("")  # add line space
+
+     st.markdown('<p style="font-size:20px;"><strong>Dependency Tree 🌳</strong></p>', unsafe_allow_html=True)  # add a subtitle
+
+     # Add an HTML wrapper to hold the rendered dependency tree
+     HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
+
+     # Add the dependency tree
+     if original != '' and output is not None:
+         doc = nlp(output)
+         docs = [span.as_doc() for span in doc.sents]
+         html = displacy.render(docs, style='dep')
+         st.write(HTML_WRAPPER.format(html), unsafe_allow_html=True)
+
+ st.markdown('___')
+ st.markdown('by [A very beta ChatGPT-4.5](https://github.com/danish-sven/anlp-at2-gpt45/)')  # add the authors
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ streamlit==1.22.0
+ transformers
+ torch
+ spacy
+ pandas
+ # nltk==3.8.1
+ # re (part of the Python standard library, so it is not installed via pip)
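+ # For reproducibility, transformers and torch could be pinned as well, e.g.
+ # transformers==4.28.1 and torch==2.0.0 (illustrative versions, not tested).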