"""Gradio demo: sample Pfizer-vaccine tweets from a fine-tuned BART checkpoint.

On import this script loads the seq2seq model and tokenizer from the Hugging
Face Hub, builds a Blocks UI with controls for the tweet metadata that is
written into the prompt, and launches the app.
"""
import os

import gradio as gr
import torch  # not referenced directly below; kept from the original imports
import transformers  # not referenced directly below; kept from the original imports
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = 'ashleychen/bart-finetuning'
# Hub access token, read from the `private_token` environment variable
# (None when the variable is unset).
access_token = os.environ.get('private_token')

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=access_token)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)

# Number of samples drawn per click; must match the number of output
# textboxes wired to the button below.
NUM_REVIEWS = 5


def create_prompt(n_followers, n_friends, verified, retweets, favorites, in_US):
    """Build the conditioning prompt fed to the model.

    Args:
        n_followers: Follower count of the tweet's author.
        n_friends: Friend count of the tweet's author.
        verified: Whether the author is verified (rendered with ``str``).
        retweets: Retweet count of the tweet.
        favorites: Favourite count of the tweet.
        in_US: Whether the author is located in the US (rendered with ``str``).

    Returns:
        The prompt string. The wording and field order are reproduced exactly
        from the original script and should not be changed casually
        (presumably they match the fine-tuning data -- confirm before editing).
    """
    tweet = (
        f"Generate review: in US: {in_US}, retweets: {retweets}, "
        f"favorites: {favorites}, user followers: {n_followers}, "
        f"user friends: {n_friends}, user verified: {verified}"
    )
    return tweet


def postprocess(review):
    """Return the generated text unchanged.

    An earlier variant truncated the text after its last '.'; that step is
    disabled, so this is an identity hook kept as an extension point.
    """
    return review


def generate_reviews(n_followers, n_friends, verified, retweets, favourites, in_US):
    """Sample ``NUM_REVIEWS`` texts conditioned on the given metadata.

    Args:
        n_followers: Follower count of the author.
        n_friends: Friend count of the author.
        verified: Whether the author is verified.
        retweets: Retweet count.
        favourites: Favourite count.
        in_US: Whether the author is located in the US.

    Returns:
        A tuple of ``NUM_REVIEWS`` decoded strings, one per output textbox.
    """
    text = create_prompt(n_followers, n_friends, verified, retweets, favourites, in_US)
    inputs = tokenizer(text, return_tensors='pt')
    # Sampling (rather than greedy/beam search) so the returned sequences differ.
    out = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        do_sample=True,
        num_return_sequences=NUM_REVIEWS,
        temperature=1.5,
        top_p=0.9,
    )
    reviews = [
        postprocess(tokenizer.decode(sequence, skip_special_tokens=True))
        for sequence in out
    ]
    return tuple(reviews)


css = """
#ctr {text-align: center;}
#btn {color: white; background: linear-gradient(to right, #1CD8D2 0%, #93EDC7 51%, #1CD8D2 100%);}
"""

md_text = """

Generating Pfizer vaccine tweets with BART-base

"""

resources = """## Resources
- The Pfizer Vaccine Tweets dataset can be found [here](https://www.kaggle.com/datasets/gpreda/pfizer-vaccine-tweets)."""

demo = gr.Blocks(css=css)

with demo:
    with gr.Row():
        gr.Markdown(md_text)

    with gr.Row():
        n_followers = gr.Slider(minimum=0, maximum=2000, step=100, value=100, label="# user followers")
        n_friends = gr.Slider(minimum=0, maximum=5000, step=100, value=2000, label="# user friends")
        # NOTE(review): create_prompt has no "user favourites" field, so this
        # slider is displayed but not passed to generate_reviews. Previously it
        # was passed in the `verified` position by mistake.
        n_favourites = gr.Slider(minimum=0, maximum=5000, step=100, value=2000, label="# user favourites")

    with gr.Row():
        in_US = gr.Radio(["True", "False"], label="user location in US")
        verified = gr.Radio(["True", "False"], label="user verified ✓")
        retweets = gr.Slider(minimum=0, maximum=10, step=1, value=0, label="# retweets")
        favourites = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="# favourites")

    with gr.Row():
        button = gr.Button("Generate tweets !", elem_id='btn')

    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")
        output4 = gr.Textbox(label="Review #4")
        output5 = gr.Textbox(label="Review #5")

    with gr.Row():
        gr.Markdown(resources)

    # Inputs are passed positionally, so their order must mirror the
    # signature of generate_reviews exactly.
    button.click(
        fn=generate_reviews,
        inputs=[n_followers, n_friends, verified, retweets, favourites, in_US],
        outputs=[output1, output2, output3, output4, output5],
    )

demo.launch()