Aryan047's picture
Deploy meme-vs-event Streamlit app
f9e8817 verified
"""Streamlit Space: Meme vs Real Event tweet classifier.
Loads a fine-tuned bert-base-uncased from the Hugging Face Hub and exposes:
- Single-tweet tab: live prediction + probability bar chart
- Batch CSV tab: upload a CSV with a `text` column, download predictions
Matching preprocessing (same regex as the training notebook) is reapplied
so results mirror what the notebook produces locally.
"""
from __future__ import annotations
import io
import os
import re
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Hub repository id of the model to load; the MODEL_ID environment variable
# overrides the default.
MODEL_ID = os.environ.get("MODEL_ID", "Aryan047/Dynamic-event-detector")
# Passed to the tokenizer as max_length (with truncation enabled) by the
# prediction functions.
MAX_LENGTH = 128
# Maps the argmax index of the model's output probabilities to a label name.
LABELS = {0: "meme", 1: "real_event"}
# Patterns used by clean_tweet, compiled once at import time.
_URL_RE = re.compile(r"https?://\S+|www\.\S+")  # http(s):// links and bare www. links
_MENTION_RE = re.compile(r"@\w+")  # @user mentions
_HASHTAG_RE = re.compile(r"#")  # only the '#' character; the tag word is kept
_NON_WORD_RE = re.compile(r"[^a-z0-9\s]")  # anything but lowercase ASCII letters, digits, whitespace
_WS_RE = re.compile(r"\s+")  # runs of whitespace


def clean_tweet(text: str) -> str:
    """Normalize a tweet before it is handed to the tokenizer.

    The text is lowercased, then URLs, mentions, ``#`` characters and every
    remaining character outside ``[a-z0-9]``/whitespace are each replaced by
    a space. Whitespace runs are finally collapsed to a single space and the
    result is stripped.

    Args:
        text: Raw tweet text. Any non-string value (for example ``None`` or
            a float NaN) is treated as missing.

    Returns:
        The cleaned text, or ``""`` when ``text`` is not a string or nothing
        survives the cleaning.
    """
    if not isinstance(text, str):
        return ""
    cleaned = text.lower()
    # Order matters: URLs and mentions must go before the generic symbol
    # pass, which would otherwise break them into stray word fragments.
    for pattern in (_URL_RE, _MENTION_RE, _HASHTAG_RE, _NON_WORD_RE, _WS_RE):
        cleaned = pattern.sub(" ", cleaned)
    return cleaned.strip()
@st.cache_resource(show_spinner="Loading model from Hugging Face Hub...")
def load_model(model_id: str):
    """Load the tokenizer and the sequence-classification model for ``model_id``.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    reuse the loaded objects rather than loading them again on each rerun.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been put in eval mode.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    # Module.eval() returns the module itself, so the call can be chained.
    classifier = AutoModelForSequenceClassification.from_pretrained(model_id).eval()
    return tok, classifier
@torch.no_grad()
def predict_one(tokenizer, model, text: str) -> dict:
    """Classify a single tweet.

    Args:
        tokenizer: Callable tokenizer returning keyword inputs for ``model``.
        model: Model whose call result exposes ``.logits``.
        text: Raw tweet text; it is passed through ``clean_tweet`` first.

    Returns:
        A dict with the keys ``label``, ``confidence``, ``prob_meme``,
        ``prob_real_event`` and ``clean_text``. When cleaning leaves nothing,
        the model is not called and a fixed placeholder is returned
        (label ``"meme"``, confidence ``0.0``).
    """
    cleaned = clean_tweet(text)
    if cleaned:
        encoded = tokenizer(
            cleaned,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        # A single input was encoded, so only row 0 of the logits is used.
        logits = model(**encoded).logits[0]
        probs = F.softmax(logits, dim=-1).numpy()
        winner = int(np.argmax(probs))
        return {
            "label": LABELS[winner],
            "confidence": float(probs[winner]),
            "prob_meme": float(probs[0]),
            "prob_real_event": float(probs[1]),
            "clean_text": cleaned,
        }
    # Nothing to classify: return the placeholder without touching the model.
    return {
        "label": "meme",
        "confidence": 0.0,
        "prob_meme": 1.0,
        "prob_real_event": 0.0,
        "clean_text": "",
    }
@torch.no_grad()
def predict_many(tokenizer, model, texts: list[str], batch_size: int = 32) -> pd.DataFrame:
    """Classify a list of tweets in batches, showing a Streamlit progress bar.

    Args:
        tokenizer: Callable tokenizer returning keyword inputs for ``model``.
        model: Model whose call result exposes ``.logits``.
        texts: Raw tweet texts; each one is passed through ``clean_tweet``.
        batch_size: Number of texts sent to the model per forward pass.

    Returns:
        A DataFrame with one row per input and the columns ``text``,
        ``clean_text``, ``label``, ``confidence``, ``prob_meme`` and
        ``prob_real_event``. Rows whose cleaned text is empty get the fixed
        placeholder (``"meme"``, ``0.0``, ``1.0``, ``0.0``).
    """
    cleaned = [clean_tweet(raw) for raw in texts]
    n_rows = len(cleaned)
    denominator = max(n_rows, 1)  # avoids dividing by zero for an empty input
    bar = st.progress(0.0, text="Running predictions...")

    # One (label, confidence, prob_meme, prob_real_event) tuple per input row.
    rows = []
    for start in range(0, n_rows, batch_size):
        stop = start + batch_size
        batch = cleaned[start:stop]
        # Empty strings are swapped for a dummy word so the batch can still be
        # tokenized; their model output is discarded below.
        encoded = tokenizer(
            [item or "empty" for item in batch],
            truncation=True,
            padding=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        batch_probs = F.softmax(model(**encoded).logits, dim=-1).numpy()
        for item, row in zip(batch, batch_probs):
            if item:
                winner = int(np.argmax(row))
                rows.append((LABELS[winner], float(row[winner]), float(row[0]), float(row[1])))
            else:
                rows.append(("meme", 0.0, 1.0, 0.0))
        bar.progress(min(stop / denominator, 1.0))
    bar.empty()

    return pd.DataFrame(
        {
            "text": texts,
            "clean_text": cleaned,
            "label": [r[0] for r in rows],
            "confidence": [r[1] for r in rows],
            "prob_meme": [r[2] for r in rows],
            "prob_real_event": [r[3] for r in rows],
        }
    )
def render_single_tab(tokenizer, model) -> None:
    """Render the single-tweet tab: text box, Predict button and result.

    When the button is pressed the text is classified with ``predict_one``
    and the predicted label, its confidence, a bar chart of both class
    probabilities and the cleaned text are displayed.

    A warning is shown instead of a result when the input is blank, or when
    preprocessing leaves nothing to classify. In the second case
    ``predict_one`` only returns a placeholder (label ``"meme"`` with
    confidence ``0.0``), which must not be presented as a real prediction.

    Args:
        tokenizer: Tokenizer forwarded to ``predict_one``.
        model: Model forwarded to ``predict_one``.
    """
    st.subheader("Classify a single tweet")
    st.caption("Paste any tweet-style text. Labels: `meme` or `real_event`.")
    default_example = "Massive 6.5 earthquake just rocked Istanbul, buildings swaying"
    text = st.text_area("Tweet text", value=default_example, height=120)

    if not st.button("Predict", type="primary"):
        return
    if not text.strip():
        st.warning("Please enter some text.")
        return

    result = predict_one(tokenizer, model, text)
    if not result["clean_text"]:
        # Non-blank input can still clean down to "" (only URLs, mentions,
        # punctuation, emoji or non-ASCII text). The model was not run then.
        st.warning(
            "Nothing is left to classify after preprocessing "
            "(URLs, mentions, symbols and non-ASCII characters are removed)."
        )
        return

    col1, col2 = st.columns(2)
    col1.metric("Predicted label", result["label"])
    col2.metric("Confidence", f"{result['confidence']:.2%}")
    st.markdown("**Class probabilities**")
    st.bar_chart(
        pd.DataFrame(
            {"probability": [result["prob_meme"], result["prob_real_event"]]},
            index=["meme", "real_event"],
        )
    )
    with st.expander("Details"):
        st.write({"cleaned_text": result["clean_text"]})
def render_batch_tab(tokenizer, model) -> None:
    """Render the batch tab: CSV upload, batch prediction and CSV download.

    The uploaded CSV must contain a ``text`` column. At most the first 5000
    rows are kept. After the button is pressed, the rows are classified with
    ``predict_many`` and the prediction columns are appended to the other
    uploaded columns.

    Uploaded columns whose name equals one of the prediction columns (for
    example an existing ``label`` column) are renamed to ``<name>_original``.
    Without this the merged frame would carry duplicate column labels and
    ``merged["label"]`` would return a DataFrame instead of a Series.

    Args:
        tokenizer: Tokenizer forwarded to ``predict_many``.
        model: Model forwarded to ``predict_many``.
    """
    st.subheader("Classify a CSV of tweets")
    st.caption("Upload a CSV with a `text` column. Predictions are added as new columns.")
    uploaded = st.file_uploader("CSV file", type=["csv"])
    if uploaded is None:
        st.info("Waiting for a CSV upload...")
        return

    try:
        df = pd.read_csv(uploaded)
    except Exception as exc:  # UI boundary: report any parse failure to the user
        st.error(f"Could not read CSV: {exc}")
        return

    if "text" not in df.columns:
        st.error(f"CSV must contain a `text` column. Found: {list(df.columns)}")
        return

    max_rows = 5000
    if len(df) > max_rows:
        st.warning(f"CSV has {len(df)} rows. Truncating to first {max_rows} for the demo.")
        df = df.head(max_rows).copy()

    st.write(f"Loaded {len(df)} rows. Preview:")
    st.dataframe(df.head(5))

    if not st.button("Run batch prediction", type="primary"):
        return

    out = predict_many(tokenizer, model, df["text"].tolist()).reset_index(drop=True)

    # `out` already carries the `text` column, so drop it from the upload and
    # keep every other uploaded column, renaming any that would clash.
    extra = df.reset_index(drop=True).drop(columns=["text"])
    clashes = {col: f"{col}_original" for col in extra.columns if col in out.columns}
    merged = pd.concat([extra.rename(columns=clashes), out], axis=1)

    st.success(f"Classified {len(merged)} tweets.")
    st.dataframe(merged.head(50))

    # Count from the prediction frame so the result is always a Series of
    # predicted labels, whatever columns the upload contained.
    counts = out["label"].value_counts().reindex(["meme", "real_event"], fill_value=0)
    st.markdown("**Label distribution**")
    st.bar_chart(counts)

    st.download_button(
        label="Download predictions CSV",
        # to_csv returns the CSV text when no path or buffer is given.
        data=merged.to_csv(index=False),
        file_name="meme_vs_event_predictions.csv",
        mime="text/csv",
    )
def main() -> None:
    """Build the Streamlit page: header, model loading and the three tabs."""
    st.set_page_config(
        page_title="Meme vs Real Event Classifier",
        page_icon="",
        layout="centered",
    )
    st.title("Meme vs Real Event Tweet Classifier")
    hub_link = f"[`{MODEL_ID}`](https://huggingface.co/{MODEL_ID})"
    st.caption(f"Fine-tuned `bert-base-uncased` loaded from {hub_link}.")

    tokenizer, model = load_model(MODEL_ID)

    # The markdown is kept flush-left on purpose: indented lines would be
    # rendered as a code block.
    about_markdown = """
**Pipeline**: tweets were embedded with `all-mpnet-base-v2`, clustered with
BERTopic, cross-checked against the GDELT DOC 2.0 API with a lifespan-aware
rule, and the resulting `(tweet, label)` pairs were used to fine-tune
`bert-base-uncased`.
- **Input**: raw tweet text
- **Preprocessing**: lowercase, strip URLs / mentions / hashtag chars / non-word
- **Max length**: 128 tokens
- **Labels**: `0 = meme`, `1 = real_event`
"""

    tab_single, tab_batch, tab_about = st.tabs(["Single tweet", "Batch CSV", "About"])
    with tab_single:
        render_single_tab(tokenizer, model)
    with tab_batch:
        render_batch_tab(tokenizer, model)
    with tab_about:
        st.markdown(about_markdown)


if __name__ == "__main__":
    main()