Spaces:
Runtime error
Runtime error
| import torch | |
| import transformers | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import gradio as gr | |
| import os | |
# Hugging Face Hub repository id of the checkpoint served by this demo.
model_name = 'ashleychen/bart-finetuning'

# Hub access token read from the `private_token` environment variable;
# this is None when the variable is not set.
access_token = os.getenv('private_token')

# Load the seq2seq model first and then its tokenizer, both from the same
# repository and authenticated with the same token.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    use_auth_token=access_token,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_auth_token=access_token,
)
def create_prompt(n_followers, n_friends, verified, retweets, favorites, in_US):
    """Build the text prompt that conditions generation on tweet metadata.

    Each argument is interpolated as-is (via ``str`` formatting) into a
    fixed template. Note that the order of fields in the template differs
    from the order of the parameters.

    Returns:
        The formatted prompt string, starting with ``"Generate review: "``.
    """
    return f"Generate review: in US: {in_US}, retweets: {retweets}, favorites: {favorites}, user followers: {n_followers}, user friends: {n_friends}, user verified: {verified}"
def postprocess(review):
    """Return the decoded text unchanged.

    This is the hook applied to every generated sequence before it is shown.
    It is currently a pass-through: a variant that cut the text after its
    last '.' (``review[:review.rfind('.') + 1]``) is disabled.
    """
    return review
def generate_reviews(n_followers, n_friends, verified, retweets, favourites, in_US):
    """Sample five texts from the model for the given tweet metadata.

    The arguments are forwarded positionally to ``create_prompt``; the
    resulting prompt is tokenized and passed to ``model.generate`` with
    sampling enabled.

    Returns:
        A 5-tuple of strings, one per sampled sequence, each decoded with
        special tokens skipped and then passed through ``postprocess``.
    """
    prompt = create_prompt(n_followers, n_friends, verified, retweets, favourites, in_US)
    encoded = tokenizer(prompt, return_tensors='pt')

    # Sampling settings: nucleus sampling at a high temperature, five
    # sequences for the single prompt.
    sampling_options = dict(
        do_sample=True,
        num_return_sequences=5,
        temperature=1.5,
        top_p=0.9,
    )
    generated = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **sampling_options,
    )

    # One prompt with num_return_sequences=5 yields exactly five rows, so
    # the tuple always has five entries (one per output textbox).
    return tuple(
        postprocess(tokenizer.decode(sequence, skip_special_tokens=True))
        for sequence in generated
    )
# Custom stylesheet handed to gr.Blocks; selectors target element ids
# (`#btn` is the id given to the generate button).
css = """
#ctr {text-align: center;}
#btn {color: white; background: linear-gradient(to right, #1CD8D2 0%, #93EDC7 51%, #1CD8D2 100%);}
"""

# Page heading, rendered through gr.Markdown as raw HTML.
md_text = """<h1 style='text-align: center; margin-bottom: 1rem'>Generating Pfizer vaccine tweets with BART-base</h1>
"""

# Footer section (Markdown) linking to the dataset.
resources = """## Resources
- The Pfizer Vaccine Tweets dataset can be found [here](https://www.kaggle.com/datasets/gpreda/pfizer-vaccine-tweets)."""
# Build the page. Components and the click handler are created inside the
# `with demo:` context so that they are attached to this Blocks app.
demo = gr.Blocks(css=css)
with demo:
    # Heading.
    with gr.Row():
        gr.Markdown(md_text)

    # Account-level metadata.
    # Sliders use `gr.Slider(..., value=...)`: the legacy `gr.inputs.Slider`
    # namespace with its `default=` keyword was removed in Gradio 4.
    with gr.Row():
        n_followers = gr.Slider(minimum=0, maximum=2000, step=100, value=100, label="# user followers")
        n_friends = gr.Slider(minimum=0, maximum=5000, step=100, value=2000, label="# user friends")
        # NOTE(review): `generate_reviews` has no parameter for this value,
        # so the slider is shown but not wired to the handler. Confirm
        # whether it should be removed or added to the prompt.
        n_favourites = gr.Slider(minimum=0, maximum=5000, step=100, value=2000, label="# user favourites")

    # Tweet-level metadata.
    with gr.Row():
        in_US = gr.Radio(["True", "False"], label="user location in US")
        verified = gr.Radio(["True", "False"], label="user verified ✓")
        retweets = gr.Slider(minimum=0, maximum=10, step=1, value=0, label="# retweets")
        favourites = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="# favourites")

    with gr.Row():
        button = gr.Button("Generate tweets !", elem_id='btn')

    # One textbox per sequence returned by `generate_reviews`.
    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")
        output4 = gr.Textbox(label="Review #4")
        output5 = gr.Textbox(label="Review #5")

    with gr.Row():
        gr.Markdown(resources)

    # Input values are passed to `fn` positionally, so this list must follow
    # the parameter order of
    # generate_reviews(n_followers, n_friends, verified, retweets, favourites, in_US).
    # The previous list put `n_favourites` in third place, which shifted
    # every later argument by one and left `in_US` out entirely.
    button.click(
        fn=generate_reviews,
        inputs=[n_followers, n_friends, verified, retweets, favourites, in_US],
        outputs=[output1, output2, output3, output4, output5],
    )

demo.launch()