"""Gradio web UI for the viral content classifier."""
import os
os.environ["USE_TF"] = "0"
os.environ["TRANSFORMERS_NO_TF"] = "1"
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Hub repo the fine-tuned classifier and its tokenizer are pulled from.
MODEL_REPO = "yrvelez/viral-classifier-roberta"
# Tokenizer truncation length; posts longer than this are cut off.
MAX_LENGTH = 128
# One-click example posts shown in the UI — a mix of bait-style and mundane
# texts so the demo demonstrates both high- and low-scoring inputs.
EXAMPLE_TEXTS = [
"Nobody is talking about the fact that we just normalized working 60 hour weeks",
"My mass spec results came back. The contamination levels are beyond anything in the literature.",
"Just mass-produced my 500th widget at the Topeka factory. Another Tuesday.",
"I asked ChatGPT to write my wedding vows and nobody noticed",
"I can't believe this actually worked! Just quit my job and tripled my income in 3 months!",
"Does anyone know a good plumber in the Cincinnati area?",
"BREAKING: leaked memo reveals everything we were told was a lie!",
"Finally tried that viral pasta recipe and oh my god it's incredible!",
"Updated my spreadsheet formulas. Feels good to stay organized.",
]
# Model and tokenizer are loaded once at import time so every predict() call
# reuses them (this downloads from the Hub on first run).
print("Loading RoBERTa model from Hub...")
# Pinned to CPU — presumably a CPU-only deployment target; TODO confirm no
# GPU is ever expected here (no cuda availability check is made).
_device = "cpu"
_tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
_model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO).to(_device)
# Inference mode: disables dropout / other train-time behavior.
_model.eval()
print(f"Model loaded on {_device}.")
def predict(text: str):
    """Score a social media post's virality.

    Args:
        text: Post text to classify. Empty/whitespace input short-circuits
            with a neutral score and a placeholder message.

    Returns:
        A ``(probability, html)`` tuple: the model's P(viral) in [0, 1],
        and an HTML fragment rendering the verdict for a ``gr.HTML`` output.
    """
    if not text or not text.strip():
        # BUG FIX: the original literal spanned multiple source lines without
        # triple quotes (SyntaxError). Reconstructed as one placeholder string.
        return 0.5, (
            "<p style='color:#9ca3af; text-align:center;'>"
            "Type something to get a prediction</p>"
        )
    inputs = _tokenizer(
        text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH
    ).to(_device)
    with torch.no_grad():
        logits = _model(**inputs).logits
    # [0][1]: probability of the positive ("viral") class for the single input.
    prob = torch.softmax(logits, dim=-1)[0][1].item()
    if prob >= 0.65:
        verdict = "Likely Viral"
        color = "#22c55e"
        icon = "^"
        bg = "rgba(34,197,94,0.08)"
    elif prob >= 0.4:
        verdict = "Could Go Either Way"
        color = "#eab308"
        icon = "~"
        bg = "rgba(234,179,8,0.08)"
    else:
        verdict = "Probably Won't Go Viral"
        color = "#6b7280"
        icon = "-"
        bg = "rgba(107,114,128,0.08)"
    pct = int(prob * 100)
    # Two-stop gradient paints the left pct% in the verdict color (a CSS
    # progress bar without extra elements).
    bar_color = f"linear-gradient(90deg, {color} {pct}%, #e5e7eb {pct}%)"
    # BUG FIX: the original f-string used none of the computed style values
    # (its markup was stripped). Reconstructed card uses verdict/icon/color/
    # bg/bar_color and keeps the original axis labels.
    html = f"""
<div style="background:{bg}; border:1px solid {color}; border-radius:12px; padding:16px;">
  <div style="font-size:1.25em; font-weight:600; color:{color};">{icon} {verdict}</div>
  <div style="margin:10px 0 4px; height:10px; border-radius:5px; background:{bar_color};"></div>
  <div style="display:flex; justify-content:space-between; font-size:0.8em; color:#6b7280;">
    <span>Low engagement</span><span>{pct}%</span><span>High engagement</span>
  </div>
</div>
"""
    return prob, html
# Page-level CSS injected into gr.Blocks: center and narrow the app, and
# hide Gradio's default footer. The string content is passed verbatim to
# the browser, so it must remain exact CSS.
css = """
.gradio-container {
max-width: 720px !important;
margin: auto !important;
}
footer { display: none !important; }
"""
# Build the Gradio UI. Event wiring: the Analyze button and textbox submit
# both call predict(); each example button fills the textbox, then re-runs
# predict() on it. The hidden Number output carries the raw probability.
with gr.Blocks(css=css, title="Viral Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Viral Content Classifier
Predict whether a social media post would go viral, based on patterns
learned from **100K Reddit and Twitter posts** using a fine-tuned **RoBERTa** model.
"""
    )
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Your post",
                placeholder="Type or paste a post here...",
                lines=3,
                max_lines=6,
            )
            submit_btn = gr.Button("Analyze", variant="primary", size="lg")
    # BUG FIX: the original value= literal spanned two source lines without
    # triple quotes (SyntaxError). Reconstructed as one placeholder string.
    result_html = gr.HTML(
        value="<p style='color:#9ca3af; text-align:center;'>Results will appear here</p>"
    )
    # Hidden component so the raw probability is still exposed (e.g. via API)
    # without cluttering the page.
    viral_score = gr.Number(visible=False)
    submit_btn.click(fn=predict, inputs=text_input, outputs=[viral_score, result_html])
    text_input.submit(fn=predict, inputs=text_input, outputs=[viral_score, result_html])
    gr.Markdown("### Try these examples")
    with gr.Row():
        for ex in EXAMPLE_TEXTS:
            # t=ex binds the loop value as a default argument, avoiding the
            # late-binding-closure pitfall; .then() chains the prediction.
            gr.Button(ex, size="sm", variant="secondary").click(
                fn=lambda t=ex: t, outputs=text_input
            ).then(fn=predict, inputs=text_input, outputs=[viral_score, result_html])
    # BUG FIX: the footer Markdown was a broken multi-line concatenation with
    # unterminated string pieces (SyntaxError). Reconstructed as a single
    # centered paragraph; the visible text line is preserved verbatim.
    gr.Markdown(
        "<p style='text-align:center; color:#9ca3af; font-size:0.85em;'>"
        "Model: RoBERTa-base fine-tuned on Reddit + Twitter data | ROC AUC: 0.80"
        "</p>"
    )
# BUG FIX: demo.launch() had lost its indentation (IndentationError).
# Launch the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()