ashleychen commited on
Commit
08f5884
·
1 Parent(s): 3834590

update app

Browse files
Files changed (2) hide show
  1. app.py +84 -4
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,7 +1,87 @@
 
 
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import gradio as gr
5
+ import os
6
 
7
+ model_name = 'ashleychen/bart-finetuning'
 
8
 
9
+ access_token = os.environ.get('private_token')
10
+
11
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=access_token)
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)
13
+
14
+ def create_prompt(
15
+ n_followers, n_friends, verified, retweets, favorites, in_US
16
+ ):
17
+ tweet = f"Generate review: in US: {in_US}, retweets: {retweets}, favorites: {favorites}, user followers: {n_followers}, user friends: {n_friends}, user verified: {verified}"
18
+ return tweet
19
+
20
+ def postprocess(review):
21
+ # dot = review.rfind('.')
22
+ # return review[:dot+1]
23
+ return review
24
+
25
+ def generate_reviews(n_followers, n_friends, verified, retweets, favourites, in_US):
26
+ text = create_prompt(n_followers, n_friends, verified, retweets, favourites, in_US)
27
+ inputs = tokenizer(text, return_tensors='pt')
28
+ out = model.generate(
29
+ input_ids=inputs.input_ids,
30
+ attention_mask=inputs.attention_mask,
31
+ do_sample=True,
32
+ num_return_sequences=5,
33
+ temperature=1.5,
34
+ top_p=0.9
35
+ )
36
+ reviews = []
37
+ for review in out:
38
+ reviews.append(postprocess(tokenizer.decode(review, skip_special_tokens=True)))
39
+
40
+ return reviews[0], reviews[1], reviews[2], reviews[3], reviews[4]
41
+
42
+ css = """
43
+ #ctr {text-align: center;}
44
+ #btn {color: white; background: linear-gradient(to right, #1CD8D2 0%, #93EDC7 51%, #1CD8D2 100%);}
45
+ """
46
+
47
+
48
+ md_text = """<h1 style='text-align: center; margin-bottom: 1rem'>Generating Pfizer vaccine tweets with BART-base</h1>
49
+ """
50
+
51
+ resources = """## Resources
52
+ - The Pfizer Vaccine Tweets dataset can be found [here](https://www.kaggle.com/datasets/gpreda/pfizer-vaccine-tweets)."""
53
+
54
+ demo = gr.Blocks(css=css)
55
+
56
+ with demo:
57
+ with gr.Row():
58
+ gr.Markdown(md_text)
59
+
60
+ with gr.Row():
61
+ n_followers = gr.inputs.Slider(minimum=0, maximum=2000, step=100, default=100, label="# user followers")
62
+ n_friends = gr.inputs.Slider(minimum=0, maximum=5000, step=100, default=2000, label="# user friends")
63
+ n_favourites = gr.inputs.Slider(minimum=0, maximum=5000, step=100, default=2000, label="# user favourites")
64
+ with gr.Row():
65
+ verified = gr.Radio(["True", "False"], label="user verified ✓")
66
+ retweets = gr.inputs.Slider(minimum=0, maximum=10, step=1, default=0, label="# retweets")
67
+ favourites = gr.inputs.Slider(minimum=0, maximum=5, step=1, default=0, label="# favourites")
68
+ with gr.Row():
69
+ button = gr.Button("Generate tweets !", elem_id='btn')
70
+
71
+ with gr.Row():
72
+ output1 = gr.Textbox(label="Review #1")
73
+ output2 = gr.Textbox(label="Review #2")
74
+ output3 = gr.Textbox(label="Review #3")
75
+ output4 = gr.Textbox(label="Review #4")
76
+ output5 = gr.Textbox(label="Review #5")
77
+
78
+ with gr.Row():
79
+ gr.Markdown(resources)
80
+
81
+ button.click(
82
+ fn=generate_reviews,
83
+ inputs=[n_followers, n_friends, n_favourites, verified, retweets, favourites],
84
+ outputs=[output1, output2, output3, output4, output5]
85
+ )
86
+
87
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ tensorflow
2
+ torch
3
+ transformers