# Author: ashleychen
# Commit d284e46: "add in US checkbox"
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import os
# Fine-tuned BART checkpoint on the Hugging Face Hub.
model_name = 'ashleychen/bart-finetuning'

# Hub access token read from the `private_token` environment variable
# (None when unset). NOTE(review): presumably needed because the repo is
# private/gated — confirm.
access_token = os.environ.get('private_token')

# Both loads authenticate the same way, so share the keyword arguments.
_hub_auth_kwargs = {'use_auth_token': access_token}

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **_hub_auth_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_name, **_hub_auth_kwargs)
def create_prompt(
    n_followers, n_friends, verified, retweets, favorites, in_US
):
    """Build the text prompt that conditions the seq2seq model.

    Each argument is interpolated with plain ``str`` formatting into a fixed
    ``"Generate review: <label>: <value>, ..."`` template, so numbers, the
    radio strings ``"True"``/``"False"`` and ``None`` all render as-is.

    NOTE(review): the field order/labels presumably mirror the format used
    when fine-tuning the checkpoint — confirm before changing them.

    Returns:
        The prompt string.
    """
    # (label, value) pairs in the exact order the template emits them.
    labelled_fields = (
        ("in US", in_US),
        ("retweets", retweets),
        ("favorites", favorites),
        ("user followers", n_followers),
        ("user friends", n_friends),
        ("user verified", verified),
    )
    body = ", ".join(f"{label}: {value}" for label, value in labelled_fields)
    return "Generate review: " + body
def postprocess(review, *, truncate_to_sentence=False):
    """Clean up one decoded model output before it is displayed.

    By default the text is returned unchanged (the app's existing behaviour).

    Args:
        review: Decoded text for one generated sequence.
        truncate_to_sentence: If True, cut the text just after its last
            ``'.'`` to drop a trailing partial sentence. Text that contains
            no ``'.'`` is returned whole rather than emptied.

    Returns:
        The (optionally truncated) text.
    """
    if not truncate_to_sentence:
        return review
    last_dot = review.rfind('.')
    # rfind() yields -1 when there is no period; slicing [:0] would then
    # return "" and wipe out the whole output, so keep the text intact.
    if last_dot == -1:
        return review
    return review[:last_dot + 1]
def generate_reviews(n_followers, n_friends, verified, retweets, favourites, in_US):
    """Sample five generated texts for the given user/tweet statistics.

    The arguments are formatted into a prompt by ``create_prompt``, tokenized
    to PyTorch tensors, and decoded from five sampled sequences
    (temperature 1.5, nucleus top_p 0.9). Each decoded string, with special
    tokens skipped, is passed through ``postprocess``.

    Returns:
        A 5-tuple of strings, one per output textbox in the UI.
    """
    prompt = create_prompt(n_followers, n_friends, verified, retweets, favourites, in_US)
    encoded = tokenizer(prompt, return_tensors='pt')

    sampling_config = dict(
        do_sample=True,
        num_return_sequences=5,
        temperature=1.5,
        top_p=0.9,
    )
    sequences = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **sampling_config,
    )

    decoded = [
        postprocess(tokenizer.decode(seq, skip_special_tokens=True))
        for seq in sequences
    ]
    # Index explicitly (not slice) so a short result still raises IndexError.
    return tuple(decoded[i] for i in range(5))
# Stylesheet passed to gr.Blocks(css=...). '#btn' styles the button created
# with elem_id='btn'. NOTE(review): no component here sets elem_id='ctr'.
css = "\n".join([
    "",
    "#ctr {text-align: center;}",
    "#btn {color: white; background: linear-gradient(to right, #1CD8D2 0%, #93EDC7 51%, #1CD8D2 100%);}",
    "",
])

# Page heading (inline HTML) shown through gr.Markdown at the top of the app.
md_text = (
    "<h1 style='text-align: center; margin-bottom: 1rem'>"
    "Generating Pfizer vaccine tweets with BART-base</h1>\n"
)

# Markdown footer linking to the source dataset.
resources = (
    "## Resources\n"
    "- The Pfizer Vaccine Tweets dataset can be found "
    "[here](https://www.kaggle.com/datasets/gpreda/pfizer-vaccine-tweets)."
)
# Gradio Blocks UI: stat inputs -> button -> five generated outputs.
demo = gr.Blocks(css=css)
with demo:
    with gr.Row():
        gr.Markdown(md_text)
    with gr.Row():
        # gr.Slider(value=...) replaces the legacy gr.inputs.Slider(default=...);
        # it is the same component family as gr.Radio / gr.Button used below.
        n_followers = gr.Slider(minimum=0, maximum=2000, step=100, value=100, label="# user followers")
        n_friends = gr.Slider(minimum=0, maximum=5000, step=100, value=2000, label="# user friends")
        # NOTE(review): create_prompt has no field for the user's favourites
        # count, so this slider is not sent to the model — confirm intent.
        n_favourites = gr.Slider(minimum=0, maximum=5000, step=100, value=2000, label="# user favourites")
    with gr.Row():
        # NOTE(review): these radios have no default, so an untouched radio
        # reaches the prompt as "None" — consider a default value.
        in_US = gr.Radio(["True", "False"], label="user location in US")
        verified = gr.Radio(["True", "False"], label="user verified ✓")
        retweets = gr.Slider(minimum=0, maximum=10, step=1, value=0, label="# retweets")
        favourites = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="# favourites")
    with gr.Row():
        button = gr.Button("Generate tweets !", elem_id='btn')
    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")
        output4 = gr.Textbox(label="Review #4")
        output5 = gr.Textbox(label="Review #5")
    with gr.Row():
        gr.Markdown(resources)
    button.click(
        fn=generate_reviews,
        # Gradio passes these values positionally, so the order must match
        # generate_reviews(n_followers, n_friends, verified, retweets,
        # favourites, in_US). Previously n_favourites was passed in place of
        # `verified`, every later argument shifted by one, and in_US was
        # never sent.
        inputs=[n_followers, n_friends, verified, retweets, favourites, in_US],
        outputs=[output1, output2, output3, output4, output5],
    )

demo.launch()