Patcharapron committed on
Commit
e541c73
·
verified ·
1 Parent(s): fd75082

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +295 -0
app.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pprint as pp
3
+ import logging
4
+ import time
5
+ import gradio as gr
6
+ import torch
7
+ from transformers import pipeline
8
+
9
+ from utils import make_mailto_form, postprocess, clear, make_email_link
10
+
11
# Emit timestamp + level with every log message, at INFO and above.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Run inference on GPU (device 0) when CUDA is available, else CPU.
use_gpu = torch.cuda.is_available()
16
+
17
+
18
def generate_text(
    prompt: str,
    gen_length=64,
    penalty_alpha=0.6,
    top_k=6,
    length_penalty=1.0,
    # perma params (not set by user)
    abs_max_length=512,
    verbose=False,
):
    """
    generate_text - generate text using the text generation pipeline

    :param str prompt: the prompt to use for the text generation pipeline
    :param int gen_length: the number of new tokens to generate
    :param float penalty_alpha: the penalty alpha for the text generation pipeline (contrastive search)
    :param int top_k: the top k for the text generation pipeline (contrastive search)
    :param float length_penalty: length penalty applied during generation
    :param int abs_max_length: input-length threshold above which a warning is logged
    :param bool verbose: verbose output
    :return tuple: (mailto-form HTML for the generated email, formatted email text)
    """
    global generator  # pipeline loaded at startup / swapped by load_emailgen_model
    if verbose:
        # lazy %-style args so the message is only built when INFO is enabled
        logging.info("Generating text from prompt:\n\n%s", prompt)
        logging.info(
            pp.pformat(
                f"params:\tmax_length={gen_length}, penalty_alpha={penalty_alpha}, top_k={top_k}, length_penalty={length_penalty}"
            )
        )
    st = time.perf_counter()

    # Warn (but do not truncate) when the prompt exceeds the advertised limit.
    input_tokens = generator.tokenizer(prompt)
    input_len = len(input_tokens["input_ids"])
    if input_len > abs_max_length:
        logging.info(
            "Input too long %s > %s, may cause errors", input_len, abs_max_length
        )
    result = generator(
        prompt,
        max_length=gen_length + input_len,  # old API: max_length includes the prompt
        min_length=input_len + 4,
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        length_penalty=length_penalty,
    )  # generate
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st
    if verbose:
        logging.info("Generated text: %s", response)
        logging.info("Generation time: %.2fs", rt)

    formatted_email = postprocess(response)
    return make_mailto_form(body=formatted_email), formatted_email
70
+
71
+
72
def load_emailgen_model(model_tag: str):
    """
    load_emailgen_model - load a text generation pipeline for email generation

    Args:
        model_tag (str): the huggingface model tag to load

    Returns:
        transformers.pipelines.TextGenerationPipeline: the text generation pipeline
    """
    global generator
    # Pick the target device up front: GPU 0 when available, otherwise CPU (-1).
    target_device = 0 if use_gpu else -1
    generator = pipeline("text-generation", model_tag, device=target_device)
88
+
89
+
90
def get_parser():
    """
    get_parser - a helper function for the argparse module

    Returns:
        argparse.ArgumentParser: CLI parser for the text-generation demo
            (model tag, default max length, contrastive-search params, verbosity)
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )

    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen-V2",
        # fixed grammar: "an different" -> "a different"
        help="Pass a different huggingface model tag to use a custom model",
    )
    parser.add_argument(
        "-l",
        "--max_length",
        required=False,
        type=int,
        default=40,
        help="default max length of the generated text",
    )
    parser.add_argument(
        "-a",
        "--penalty_alpha",
        type=float,
        default=0.6,
        help="The penalty alpha for the text generation pipeline (contrastive search) - default 0.6",
    )

    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        default=6,
        help="The top k for the text generation pipeline (contrastive search) - default 6",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Verbose output",
    )
    return parser
137
+
138
+
139
# Seed text pre-filled in the prompt box when the demo loads.
default_prompt = """
Hello,

Following up on last week's bubblegum shipment, I"""

# HuggingFace model tags offered in the "Choose a model" dropdown.
available_models = [
    "postbot/distilgpt2-emailgen-V2",
    "postbot/distilgpt2-emailgen",
    "postbot/gpt2-medium-emailgen",
    "postbot/pythia-160m-hq-emails",
]
150
+
151
if __name__ == "__main__":

    logging.info("\n\n\nStarting new instance of app.py")
    # CLI flags set the initial model and the UI's default generation params.
    args = get_parser().parse_args()
    logging.info(f"received args:\t{args}")
    model_tag = args.model
    verbose = args.verbose
    max_length = args.max_length
    top_k = args.top_k
    alpha = args.penalty_alpha

    # NOTE(review): assert is stripped under `python -O`; raising ValueError
    # would validate unconditionally — confirm intent.
    assert top_k > 0, "top_k must be greater than 0"
    assert alpha >= 0.0 and alpha <= 1.0, "penalty_alpha must be between 0 and 1"

    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    # Initial pipeline; the "Load Model" button swaps it via load_emailgen_model.
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,
    )

    demo = gr.Blocks()

    logging.info("launching interface...")

    # All components and event wiring must be created inside the Blocks context.
    with demo:
        gr.Markdown("# Auto-Complete Emails - Demo")
        gr.Markdown(
            "Enter part of an email, and a text-gen model will complete it! See details below. "
        )
        gr.Markdown("---")

        with gr.Column():

            # --- prompt entry + generation controls ---
            gr.Markdown("## Generate Text")
            gr.Markdown("Edit the prompt and parameters and press **Generate**!")
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )

            with gr.Row():
                clear_button = gr.Button(
                    value="Clear Prompt",
                )
                num_gen_tokens = gr.Slider(
                    label="Generation Tokens",
                    value=max_length,
                    maximum=96,
                    minimum=16,
                    step=8,
                )

            generate_button = gr.Button(
                value="Generate!",
                variant="primary",
            )
            gr.Markdown("---")
            gr.Markdown("### Results")
            # put a large HTML placeholder here
            generated_email = gr.Textbox(
                label="Generated Text",
                placeholder="This is where the generated text will appear",
                interactive=False,
            )
            # Filled by generate_text with a clickable mailto: form.
            email_mailto_button = gr.HTML(
                "<i>a clickable email button will appear here</i>"
            )

            # --- contrastive-search tuning + model selection ---
            gr.Markdown("---")
            gr.Markdown("## Advanced Options")
            gr.Markdown(
                "This demo generates text via the new [contrastive search](https://huggingface.co/blog/introducing-csearch). See the csearch blog post for details on the parameters or [here](https://huggingface.co/blog/how-to-generate), for general decoding."
            )
            with gr.Row():
                model_name = gr.Dropdown(
                    choices=available_models,
                    label="Choose a model",
                    value=model_tag,
                )
                load_model_button = gr.Button(
                    "Load Model",
                    variant="secondary",
                )
            with gr.Row():
                contrastive_top_k = gr.Radio(
                    choices=[2, 4, 6, 8],
                    label="Top K",
                    value=top_k,
                )

                penalty_alpha = gr.Slider(
                    label="Penalty Alpha",
                    value=alpha,
                    maximum=1.0,
                    minimum=0.0,
                    step=0.1,
                )
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="Length Penalty",
                    value=1.0,
                    step=0.1,
                )
            gr.Markdown("---")

        with gr.Column():

            gr.Markdown("## About")
            gr.Markdown(
                "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 100k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage."
            )
            gr.Markdown(
                "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something."
            )
            gr.Markdown("---")

        # --- event wiring ---
        clear_button.click(
            fn=clear,
            inputs=[prompt_text],
            outputs=[prompt_text],
        )
        # Slider/radio values map positionally onto generate_text's parameters.
        generate_button.click(
            fn=generate_text,
            inputs=[
                prompt_text,
                num_gen_tokens,
                penalty_alpha,
                contrastive_top_k,
                length_penalty,
            ],
            outputs=[email_mailto_button, generated_email],
        )

        load_model_button.click(
            fn=load_emailgen_model,
            inputs=[model_name],
            outputs=[],
        )
    # NOTE(review): `enable_queue` was removed in gradio 4.x — confirm the
    # pinned gradio version still accepts it. `share=True` creates a public link.
    demo.launch(
        enable_queue=True,
        share=True,  # for local testing
    )