File size: 3,145 Bytes
08f5884
 
 
1ec122c
08f5884
1ec122c
08f5884
1ec122c
08f5884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d284e46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08f5884
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import os

# Hugging Face Hub repository holding the fine-tuned BART checkpoint.
model_name = "ashleychen/bart-finetuning"

# Hub access token taken from the `private_token` environment variable
# (os.environ.get yields None when it is not set).
access_token = os.environ.get("private_token")

# Load the seq2seq model first, then its tokenizer, both from the same repo
# and with the same token.
# NOTE(review): newer transformers releases deprecate `use_auth_token` in
# favour of `token`; confirm against the pinned transformers version.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    use_auth_token=access_token,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_auth_token=access_token,
)

def create_prompt(
    n_followers, n_friends, verified, retweets, favorites, in_US
):
    """Build the text prompt fed to the seq2seq model.

    Each value is rendered with f-string formatting after its label, and the
    ``label: value`` pairs are joined with ", " behind a fixed
    ``"Generate review: "`` prefix. The field order below is part of the
    prompt text, so it must not be rearranged.

    Returns:
        The prompt string.
    """
    labelled_fields = (
        ("in US", in_US),
        ("retweets", retweets),
        ("favorites", favorites),
        ("user followers", n_followers),
        ("user friends", n_friends),
        ("user verified", verified),
    )
    body = ", ".join(f"{label}: {value}" for label, value in labelled_fields)
    return "Generate review: " + body

def postprocess(review, truncate_to_sentence=False):
    """Clean up one decoded model output before it is displayed.

    By default the text is returned unchanged.

    Args:
        review: Decoded text of a single generated sequence.
        truncate_to_sentence: When True, drop everything after the last
            ``'.'`` so a sample cut off mid-sentence ends cleanly. Text that
            contains no ``'.'`` is returned unchanged rather than emptied.

    Returns:
        The (optionally truncated) text.
    """
    if not truncate_to_sentence:
        return review
    dot = review.rfind('.')
    # rfind returns -1 when there is no '.'; slicing with review[:dot + 1]
    # would then yield "", which would throw the whole sample away.
    if dot == -1:
        return review
    return review[:dot + 1]

def generate_reviews(n_followers, n_friends, verified, retweets, favourites, in_US):
    """Sample five texts from the model for one set of tweet/author features.

    The arguments are forwarded positionally to ``create_prompt``
    (``favourites`` fills its ``favorites`` slot). The prompt is tokenized as
    PyTorch tensors, five sequences are sampled, and each one is decoded
    without special tokens and passed through ``postprocess``.

    Returns:
        A 5-tuple of strings, one per output textbox.
    """
    prompt = create_prompt(n_followers, n_friends, verified, retweets, favourites, in_US)
    encoded = tokenizer(prompt, return_tensors='pt')

    # Stochastic decoding: top-p (nucleus) sampling at p=0.9 with
    # temperature 1.5 (>1 flattens the distribution), drawing five
    # sequences from the single prompt.
    sequences = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        do_sample=True,
        num_return_sequences=5,
        temperature=1.5,
        top_p=0.9,
    )

    texts = [
        postprocess(tokenizer.decode(seq, skip_special_tokens=True))
        for seq in sequences
    ]
    return tuple(texts[i] for i in range(5))

# Page-level CSS handed to gr.Blocks(css=...). `#btn` styles the element
# created with elem_id='btn'.
# NOTE(review): no elem_id='ctr' is visible in this file; `#ctr` may be unused.
css = """
    #ctr {text-align: center;}
    #btn {color: white; background: linear-gradient(to right, #1CD8D2 0%, #93EDC7  51%, #1CD8D2  100%);}
"""


# Centered page title (raw HTML inside markdown) rendered via gr.Markdown.
md_text = """<h1 style='text-align: center; margin-bottom: 1rem'>Generating Pfizer vaccine tweets with BART-base</h1>
"""

# Footer markdown linking to the source dataset on Kaggle.
resources = """## Resources
- The Pfizer Vaccine Tweets dataset can be found [here](https://www.kaggle.com/datasets/gpreda/pfizer-vaccine-tweets)."""

# Build the Gradio Blocks UI: a title, two rows of input controls, a
# generate button, five output textboxes and a resources footer.
demo = gr.Blocks(css=css)

with demo:
    with gr.Row():
        gr.Markdown(md_text)

    # Author-level controls.
    with gr.Row():
        n_followers = gr.inputs.Slider(minimum=0, maximum=2000, step=100, default=100, label="# user followers")
        n_friends = gr.inputs.Slider(minimum=0, maximum=5000, step=100, default=2000, label="# user friends")
        # NOTE(review): create_prompt has no user-favourites field, so this
        # slider is not passed to generate_reviews. Previously it was wired
        # into the `verified` slot by mistake. Kept to leave the layout as-is.
        n_favourites = gr.inputs.Slider(minimum=0, maximum=5000, step=100, default=2000, label="# user favourites")
    # Location / verification and tweet-level controls.
    with gr.Row():
        in_US = gr.Radio(["True", "False"], label="user location in US")
        verified = gr.Radio(["True", "False"], label="user verified ✓")
        retweets = gr.inputs.Slider(minimum=0, maximum=10, step=1, default=0, label="# retweets")
        favourites = gr.inputs.Slider(minimum=0, maximum=5, step=1, default=0, label="# favourites")
    with gr.Row():
        button = gr.Button("Generate tweets !", elem_id='btn')

    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")
        output4 = gr.Textbox(label="Review #4")
        output5 = gr.Textbox(label="Review #5")

    with gr.Row():
        gr.Markdown(resources)

    # Gradio hands component values to `fn` positionally, in the order of
    # `inputs`, so this list must mirror generate_reviews' signature:
    # (n_followers, n_friends, verified, retweets, favourites, in_US).
    # The five returned strings fill the five textboxes in order.
    button.click(
        fn=generate_reviews,
        inputs=[n_followers, n_friends, verified, retweets, favourites, in_US],
        outputs=[output1, output2, output3, output4, output5],
    )

demo.launch()