"""Gradio web UI for the viral content classifier.""" import os os.environ["USE_TF"] = "0" os.environ["TRANSFORMERS_NO_TF"] = "1" import gradio as gr import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer MODEL_REPO = "yrvelez/viral-classifier-roberta" MAX_LENGTH = 128 EXAMPLE_TEXTS = [ "Nobody is talking about the fact that we just normalized working 60 hour weeks", "My mass spec results came back. The contamination levels are beyond anything in the literature.", "Just mass-produced my 500th widget at the Topeka factory. Another Tuesday.", "I asked ChatGPT to write my wedding vows and nobody noticed", "I can't believe this actually worked! Just quit my job and tripled my income in 3 months!", "Does anyone know a good plumber in the Cincinnati area?", "BREAKING: leaked memo reveals everything we were told was a lie!", "Finally tried that viral pasta recipe and oh my god it's incredible!", "Updated my spreadsheet formulas. Feels good to stay organized.", ] print("Loading RoBERTa model from Hub...") _device = "cpu" _tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO) _model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO).to(_device) _model.eval() print(f"Model loaded on {_device}.") def predict(text: str): """Predict virality of a social media post.""" if not text or not text.strip(): return 0.5, "
Type something to get a prediction
" inputs = _tokenizer( text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH ).to(_device) with torch.no_grad(): logits = _model(**inputs).logits prob = torch.softmax(logits, dim=-1)[0][1].item() if prob >= 0.65: verdict = "Likely Viral" color = "#22c55e" icon = "^" bg = "rgba(34,197,94,0.08)" elif prob >= 0.4: verdict = "Could Go Either Way" color = "#eab308" icon = "~" bg = "rgba(234,179,8,0.08)" else: verdict = "Probably Won't Go Viral" color = "#6b7280" icon = "-" bg = "rgba(107,114,128,0.08)" pct = int(prob * 100) bar_color = f"linear-gradient(90deg, {color} {pct}%, #e5e7eb {pct}%)" html = f"""
{pct}%
{icon} {verdict}
Low engagement High engagement
""" return prob, html css = """ .gradio-container { max-width: 720px !important; margin: auto !important; } footer { display: none !important; } """ with gr.Blocks(css=css, title="Viral Classifier", theme=gr.themes.Soft()) as demo: gr.Markdown( """ # Viral Content Classifier Predict whether a social media post would go viral, based on patterns learned from **100K Reddit and Twitter posts** using a fine-tuned **RoBERTa** model. """ ) with gr.Row(): with gr.Column(): text_input = gr.Textbox( label="Your post", placeholder="Type or paste a post here...", lines=3, max_lines=6, ) submit_btn = gr.Button("Analyze", variant="primary", size="lg") result_html = gr.HTML( value="
Results will appear here
" ) viral_score = gr.Number(visible=False) submit_btn.click(fn=predict, inputs=text_input, outputs=[viral_score, result_html]) text_input.submit(fn=predict, inputs=text_input, outputs=[viral_score, result_html]) gr.Markdown("### Try these examples") with gr.Row(): for ex in EXAMPLE_TEXTS: gr.Button(ex, size="sm", variant="secondary").click( fn=lambda t=ex: t, outputs=text_input ).then(fn=predict, inputs=text_input, outputs=[viral_score, result_html]) gr.Markdown( "
" "Model: RoBERTa-base fine-tuned on Reddit + Twitter data | ROC AUC: 0.80" "
" ) if __name__ == "__main__": demo.launch()