"""Streamlit Space: Meme vs Real Event tweet classifier. Loads a fine-tuned bert-base-uncased from the Hugging Face Hub and exposes: - Single-tweet tab: live prediction + probability bar chart - Batch CSV tab: upload a CSV with a `text` column, download predictions Matching preprocessing (same regex as the training notebook) is reapplied so results mirror what the notebook produces locally. """ from __future__ import annotations import io import os import re import numpy as np import pandas as pd import streamlit as st import torch import torch.nn.functional as F from transformers import AutoModelForSequenceClassification, AutoTokenizer MODEL_ID = os.environ.get("MODEL_ID", "Aryan047/Dynamic-event-detector") MAX_LENGTH = 128 LABELS = {0: "meme", 1: "real_event"} _URL_RE = re.compile(r"https?://\S+|www\.\S+") _MENTION_RE = re.compile(r"@\w+") _HASHTAG_RE = re.compile(r"#") _NON_WORD_RE = re.compile(r"[^a-z0-9\s]") _WS_RE = re.compile(r"\s+") def clean_tweet(text: str) -> str: if not isinstance(text, str): return "" t = text.lower() t = _URL_RE.sub(" ", t) t = _MENTION_RE.sub(" ", t) t = _HASHTAG_RE.sub(" ", t) t = _NON_WORD_RE.sub(" ", t) t = _WS_RE.sub(" ", t).strip() return t @st.cache_resource(show_spinner="Loading model from Hugging Face Hub...") def load_model(model_id: str): tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForSequenceClassification.from_pretrained(model_id) model.eval() return tokenizer, model @torch.no_grad() def predict_one(tokenizer, model, text: str) -> dict: cleaned = clean_tweet(text) if not cleaned: return { "label": "meme", "confidence": 0.0, "prob_meme": 1.0, "prob_real_event": 0.0, "clean_text": "", } enc = tokenizer(cleaned, truncation=True, max_length=MAX_LENGTH, return_tensors="pt") probs = F.softmax(model(**enc).logits[0], dim=-1).numpy() pred = int(np.argmax(probs)) return { "label": LABELS[pred], "confidence": float(probs[pred]), "prob_meme": float(probs[0]), "prob_real_event": float(probs[1]), "clean_text": cleaned, } @torch.no_grad() def predict_many(tokenizer, model, texts: list[str], batch_size: int = 32) -> pd.DataFrame: cleaned = [clean_tweet(t) for t in texts] labels, confs, p0s, p1s = [], [], [], [] progress = st.progress(0.0, text="Running predictions...") total = max(len(cleaned), 1) for i in range(0, len(cleaned), batch_size): chunk = cleaned[i : i + batch_size] empty_mask = [len(c) == 0 for c in chunk] model_inputs = [c if c else "empty" for c in chunk] enc = tokenizer( model_inputs, truncation=True, padding=True, max_length=MAX_LENGTH, return_tensors="pt", ) probs = F.softmax(model(**enc).logits, dim=-1).numpy() for j, p in enumerate(probs): if empty_mask[j]: labels.append("meme") confs.append(0.0) p0s.append(1.0) p1s.append(0.0) else: pred = int(np.argmax(p)) labels.append(LABELS[pred]) confs.append(float(p[pred])) p0s.append(float(p[0])) p1s.append(float(p[1])) progress.progress(min((i + batch_size) / total, 1.0)) progress.empty() return pd.DataFrame( { "text": texts, "clean_text": cleaned, "label": labels, "confidence": confs, "prob_meme": p0s, "prob_real_event": p1s, } ) def render_single_tab(tokenizer, model) -> None: st.subheader("Classify a single tweet") st.caption("Paste any tweet-style text. Labels: `meme` or `real_event`.") default_example = "Massive 6.5 earthquake just rocked Istanbul, buildings swaying" text = st.text_area("Tweet text", value=default_example, height=120) if st.button("Predict", type="primary"): if not text.strip(): st.warning("Please enter some text.") return result = predict_one(tokenizer, model, text) col1, col2 = st.columns(2) col1.metric("Predicted label", result["label"]) col2.metric("Confidence", f"{result['confidence']:.2%}") st.markdown("**Class probabilities**") st.bar_chart( pd.DataFrame( {"probability": [result["prob_meme"], result["prob_real_event"]]}, index=["meme", "real_event"], ) ) with st.expander("Details"): st.write({"cleaned_text": result["clean_text"]}) def render_batch_tab(tokenizer, model) -> None: st.subheader("Classify a CSV of tweets") st.caption("Upload a CSV with a `text` column. Predictions are added as new columns.") uploaded = st.file_uploader("CSV file", type=["csv"]) if uploaded is None: st.info("Waiting for a CSV upload...") return try: df = pd.read_csv(uploaded) except Exception as exc: st.error(f"Could not read CSV: {exc}") return if "text" not in df.columns: st.error(f"CSV must contain a `text` column. Found: {list(df.columns)}") return max_rows = 5000 if len(df) > max_rows: st.warning(f"CSV has {len(df)} rows. Truncating to first {max_rows} for the demo.") df = df.head(max_rows).copy() st.write(f"Loaded {len(df)} rows. Preview:") st.dataframe(df.head(5)) if st.button("Run batch prediction", type="primary"): out = predict_many(tokenizer, model, df["text"].tolist()) merged = pd.concat( [df.reset_index(drop=True).drop(columns=["text"]), out.reset_index(drop=True)], axis=1, ) st.success(f"Classified {len(merged)} tweets.") st.dataframe(merged.head(50)) counts = merged["label"].value_counts().reindex(["meme", "real_event"], fill_value=0) st.markdown("**Label distribution**") st.bar_chart(counts) buf = io.StringIO() merged.to_csv(buf, index=False) st.download_button( label="Download predictions CSV", data=buf.getvalue(), file_name="meme_vs_event_predictions.csv", mime="text/csv", ) def main() -> None: st.set_page_config( page_title="Meme vs Real Event Classifier", page_icon="", layout="centered", ) st.title("Meme vs Real Event Tweet Classifier") st.caption( f"Fine-tuned `bert-base-uncased` loaded from " f"[`{MODEL_ID}`](https://huggingface.co/{MODEL_ID})." ) tokenizer, model = load_model(MODEL_ID) single_tab, batch_tab, about_tab = st.tabs(["Single tweet", "Batch CSV", "About"]) with single_tab: render_single_tab(tokenizer, model) with batch_tab: render_batch_tab(tokenizer, model) with about_tab: st.markdown( """ **Pipeline**: tweets were embedded with `all-mpnet-base-v2`, clustered with BERTopic, cross-checked against the GDELT DOC 2.0 API with a lifespan-aware rule, and the resulting `(tweet, label)` pairs were used to fine-tune `bert-base-uncased`. - **Input**: raw tweet text - **Preprocessing**: lowercase, strip URLs / mentions / hashtag chars / non-word - **Max length**: 128 tokens - **Labels**: `0 = meme`, `1 = real_event` """ ) if __name__ == "__main__": main()