Spaces:
Sleeping
Sleeping
| """Gradio web UI for the viral content classifier.""" | |
| import os | |
| os.environ["USE_TF"] = "0" | |
| os.environ["TRANSFORMERS_NO_TF"] = "1" | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
# Hugging Face Hub repository holding the fine-tuned tokenizer and weights.
MODEL_REPO: str = "yrvelez/viral-classifier-roberta"

# Token budget per post; predict() tokenizes with truncation=True at this length.
MAX_LENGTH: int = 128

# Sample posts rendered as one-click example buttons beneath the input box.
EXAMPLE_TEXTS = [
    # Workplace / social commentary
    "Nobody is talking about the fact that we just normalized working 60 hour weeks",
    # Science teaser
    "My mass spec results came back. The contamination levels are beyond anything in the literature.",
    # Mundane update
    "Just mass-produced my 500th widget at the Topeka factory. Another Tuesday.",
    # AI anecdote
    "I asked ChatGPT to write my wedding vows and nobody noticed",
    # Success-story hype
    "I can't believe this actually worked! Just quit my job and tripled my income in 3 months!",
    # Local question
    "Does anyone know a good plumber in the Cincinnati area?",
    # Clickbait headline
    "BREAKING: leaked memo reveals everything we were told was a lie!",
    # Food reaction
    "Finally tried that viral pasta recipe and oh my god it's incredible!",
    # Personal productivity note
    "Updated my spreadsheet formulas. Feels good to stay organized.",
]
# Load the tokenizer and classifier once at import time so every request
# served by predict() reuses the same in-memory weights.
print("Loading RoBERTa model from Hub...")
# Use a GPU when one is visible, otherwise fall back to CPU (the previous
# hard-coded choice), so CPU-only hosts behave exactly as before.
# predict() moves its tokenized inputs to this same device.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
_model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO).to(_device)
# Inference only: eval() switches layers such as dropout to evaluation mode.
_model.eval()
print(f"Model loaded on {_device}.")
# Probability cut-offs for the verdict bands shown in the result card.
_VIRAL_THRESHOLD = 0.65
_MAYBE_THRESHOLD = 0.4

# Shown instead of a result card when the textbox is blank.
_PLACEHOLDER_HTML = "<div style='text-align:center;color:#888'>Type something to get a prediction</div>"


def _verdict_style(prob):
    """Return ``(verdict, color, icon)`` for the band containing ``prob``."""
    if prob >= _VIRAL_THRESHOLD:
        return "Likely Viral", "#22c55e", "^"
    if prob >= _MAYBE_THRESHOLD:
        return "Could Go Either Way", "#eab308", "~"
    return "Probably Won't Go Viral", "#6b7280", "-"


def _render_result(prob):
    """Build the HTML result card for a viral probability ``prob`` in [0, 1]."""
    verdict, color, icon = _verdict_style(prob)
    # int() truncates, so e.g. 0.649 is shown as 64% instead of rounding up
    # to 65%, the "Likely Viral" cut-off that the raw score has not reached.
    pct = int(prob * 100)
    # Two colour stops at the same position draw a hard-edged progress bar.
    bar_color = f"linear-gradient(90deg, {color} {pct}%, #e5e7eb {pct}%)"
    return f"""
<div style="font-family: system-ui, -apple-system, sans-serif; max-width: 400px; margin: 0 auto;">
    <div style="text-align:center; padding: 24px 0 16px;">
        <div style="font-size: 48px; font-weight: 700; color: {color};">{pct}%</div>
        <div style="font-size: 18px; font-weight: 600; color: {color}; margin-top: 4px;">
            {icon} {verdict}
        </div>
    </div>
    <div style="height: 8px; border-radius: 4px; background: {bar_color}; margin: 0 0 20px;"></div>
    <div style="display: flex; justify-content: space-between; font-size: 13px; color: #9ca3af;">
        <span>Low engagement</span>
        <span>High engagement</span>
    </div>
</div>
"""


def predict(text: str):
    """Predict virality of a social media post.

    Args:
        text: Post text from the UI. Empty, whitespace-only or ``None`` input
            is treated as "no post".

    Returns:
        A ``(prob, html)`` pair. ``prob`` is the softmax probability of class
        index 1, which the verdict bands treat as "viral". ``html`` is the
        rendered result card. Blank input returns ``0.5`` and a placeholder
        message without running the model.
    """
    if not text or not text.strip():
        return 0.5, _PLACEHOLDER_HTML
    # Posts longer than MAX_LENGTH tokens are truncated, not rejected.
    inputs = _tokenizer(
        text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH
    ).to(_device)
    with torch.no_grad():
        logits = _model(**inputs).logits
    # [0] selects the single post in the batch; [1] is the positive class.
    prob = torch.softmax(logits, dim=-1)[0][1].item()
    return prob, _render_result(prob)
# Stylesheet passed to gr.Blocks: keep the app in a centred column at most
# 720px wide and hide Gradio's default footer.
css = (
    "\n"
    ".gradio-container {\n"
    "    max-width: 720px !important;\n"
    "    margin: auto !important;\n"
    "}\n"
    "footer { display: none !important; }\n"
)
# Single-page UI: header text, input and result column, example posts, and a
# footer line with model details.
with gr.Blocks(css=css, title="Viral Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Viral Content Classifier
        Predict whether a social media post would go viral, based on patterns
        learned from **100K Reddit and Twitter posts** using a fine-tuned **RoBERTa** model.
        """
    )

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Your post",
                placeholder="Type or paste a post here...",
                lines=3,
                max_lines=6,
            )
            submit_btn = gr.Button("Analyze", variant="primary", size="lg")
            result_html = gr.HTML(
                value="<div style='text-align:center;color:#888;padding:40px 0'>Results will appear here</div>"
            )
            # Hidden sink for the raw probability, the first value predict() returns.
            viral_score = gr.Number(visible=False)

    # Every prediction event writes predict()'s (prob, html) into these components.
    prediction_outputs = [viral_score, result_html]

    # The "Analyze" button and pressing Enter in the textbox run the same
    # prediction; both are registered in that order.
    for trigger in (submit_btn.click, text_input.submit):
        trigger(fn=predict, inputs=text_input, outputs=prediction_outputs)

    gr.Markdown("### Try these examples")
    with gr.Row():
        for ex in EXAMPLE_TEXTS:
            example_btn = gr.Button(ex, size="sm", variant="secondary")
            # Bind the current example as a default argument. A plain closure
            # would see only the loop's final value when a button is clicked.
            fill_event = example_btn.click(fn=lambda t=ex: t, outputs=text_input)
            # After the textbox holds the example, score it like a manual submit.
            fill_event.then(fn=predict, inputs=text_input, outputs=prediction_outputs)

    gr.Markdown(
        "<div style='text-align:center;font-size:12px;color:#aaa;margin-top:24px'>"
        "Model: RoBERTa-base fine-tuned on Reddit + Twitter data | ROC AUC: 0.80"
        "</div>"
    )


if __name__ == "__main__":
    demo.launch()