import torch import transformers from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import gradio as gr import os model_name = 'ashleychen/bart-finetuning' access_token = os.environ.get('private_token') model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=access_token) tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token) def create_prompt( n_followers, n_friends, verified, retweets, favorites, in_US ): tweet = f"Generate review: in US: {in_US}, retweets: {retweets}, favorites: {favorites}, user followers: {n_followers}, user friends: {n_friends}, user verified: {verified}" return tweet def postprocess(review): # dot = review.rfind('.') # return review[:dot+1] return review def generate_reviews(n_followers, n_friends, verified, retweets, favourites, in_US): text = create_prompt(n_followers, n_friends, verified, retweets, favourites, in_US) inputs = tokenizer(text, return_tensors='pt') out = model.generate( input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, do_sample=True, num_return_sequences=5, temperature=1.5, top_p=0.9 ) reviews = [] for review in out: reviews.append(postprocess(tokenizer.decode(review, skip_special_tokens=True))) return reviews[0], reviews[1], reviews[2], reviews[3], reviews[4] css = """ #ctr {text-align: center;} #btn {color: white; background: linear-gradient(to right, #1CD8D2 0%, #93EDC7 51%, #1CD8D2 100%);} """ md_text = """