| """Streamlit Space: Meme vs Real Event tweet classifier. | |
| Loads a fine-tuned bert-base-uncased from the Hugging Face Hub and exposes: | |
| - Single-tweet tab: live prediction + probability bar chart | |
| - Batch CSV tab: upload a CSV with a `text` column, download predictions | |
| Matching preprocessing (same regex as the training notebook) is reapplied | |
| so results mirror what the notebook produces locally. | |
| """ | |
from __future__ import annotations

import io
import os
import re

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
MODEL_ID = os.environ.get("MODEL_ID", "Aryan047/Dynamic-event-detector")
MAX_LENGTH = 128
LABELS = {0: "meme", 1: "real_event"}
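# MODEL_ID can be overridden at launch without editing this file, e.g.
# (the repo id below is a placeholder, not a real checkpoint):
#   MODEL_ID=your-user/your-model streamlit run app.py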
# Regexes mirroring the training notebook's preprocessing.
_URL_RE = re.compile(r"https?://\S+|www\.\S+")
_MENTION_RE = re.compile(r"@\w+")
_HASHTAG_RE = re.compile(r"#")
_NON_WORD_RE = re.compile(r"[^a-z0-9\s]")
_WS_RE = re.compile(r"\s+")
def clean_tweet(text: str) -> str:
    """Lowercase and strip URLs, mentions, hashtag chars, and punctuation."""
    if not isinstance(text, str):
        return ""
    t = text.lower()
    t = _URL_RE.sub(" ", t)
    t = _MENTION_RE.sub(" ", t)
    t = _HASHTAG_RE.sub(" ", t)
    t = _NON_WORD_RE.sub(" ", t)
    t = _WS_RE.sub(" ", t).strip()
    return t
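# Illustrative example (not taken from the training notebook):
#   clean_tweet("Check this out!! https://t.co/x #earthquake @user")
#   -> "check this out earthquake"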
@st.cache_resource
def load_model(model_id: str):
    """Load the tokenizer and model once per process.

    Streamlit reruns the whole script on every interaction, so without
    caching the checkpoint would be re-downloaded on each click.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    model.eval()
    return tokenizer, model
def predict_one(tokenizer, model, text: str) -> dict:
    """Classify a single tweet; returns label, confidence, and class probabilities."""
    cleaned = clean_tweet(text)
    if not cleaned:
        # Nothing left after cleaning; fall back to the meme class with zero confidence.
        return {
            "label": "meme",
            "confidence": 0.0,
            "prob_meme": 1.0,
            "prob_real_event": 0.0,
            "clean_text": "",
        }
    enc = tokenizer(cleaned, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
    with torch.no_grad():  # inference only; also lets .numpy() work on the logits
        probs = F.softmax(model(**enc).logits[0], dim=-1).numpy()
    pred = int(np.argmax(probs))
    return {
        "label": LABELS[pred],
        "confidence": float(probs[pred]),
        "prob_meme": float(probs[0]),
        "prob_real_event": float(probs[1]),
        "clean_text": cleaned,
    }
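# Sketch of calling the classifier outside the UI (illustrative only):
#   tok, mdl = load_model(MODEL_ID)
#   predict_one(tok, mdl, "BREAKING: flood warnings issued for the coast")
#   -> {"label": ..., "confidence": ..., "prob_meme": ..., "prob_real_event": ..., "clean_text": ...}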
def predict_many(tokenizer, model, texts: list[str], batch_size: int = 32) -> pd.DataFrame:
    """Classify a list of tweets in batches, with a Streamlit progress bar."""
    cleaned = [clean_tweet(t) for t in texts]
    labels, confs, p0s, p1s = [], [], [], []
    progress = st.progress(0.0, text="Running predictions...")
    total = max(len(cleaned), 1)
    for i in range(0, len(cleaned), batch_size):
        chunk = cleaned[i : i + batch_size]
        # Tweets that clean down to nothing still need a placeholder so the
        # batch stays aligned; their model output is discarded below.
        empty_mask = [len(c) == 0 for c in chunk]
        model_inputs = [c if c else "empty" for c in chunk]
        enc = tokenizer(
            model_inputs,
            truncation=True,
            padding=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        with torch.no_grad():  # inference only
            probs = F.softmax(model(**enc).logits, dim=-1).numpy()
        for j, p in enumerate(probs):
            if empty_mask[j]:
                labels.append("meme")
                confs.append(0.0)
                p0s.append(1.0)
                p1s.append(0.0)
            else:
                pred = int(np.argmax(p))
                labels.append(LABELS[pred])
                confs.append(float(p[pred]))
                p0s.append(float(p[0]))
                p1s.append(float(p[1]))
        progress.progress(min((i + batch_size) / total, 1.0))
    progress.empty()
    return pd.DataFrame(
        {
            "text": texts,
            "clean_text": cleaned,
            "label": labels,
            "confidence": confs,
            "prob_meme": p0s,
            "prob_real_event": p1s,
        }
    )
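# Note: the returned frame keeps the original `text` next to `clean_text` and
# the per-class probabilities, so it can be concatenated back onto the
# uploaded CSV (see render_batch_tab below).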
def render_single_tab(tokenizer, model) -> None:
    st.subheader("Classify a single tweet")
    st.caption("Paste any tweet-style text. Labels: `meme` or `real_event`.")
    default_example = "Massive 6.5 earthquake just rocked Istanbul, buildings swaying"
    text = st.text_area("Tweet text", value=default_example, height=120)
    if st.button("Predict", type="primary"):
        if not text.strip():
            st.warning("Please enter some text.")
            return
        result = predict_one(tokenizer, model, text)
        col1, col2 = st.columns(2)
        col1.metric("Predicted label", result["label"])
        col2.metric("Confidence", f"{result['confidence']:.2%}")
        st.markdown("**Class probabilities**")
        st.bar_chart(
            pd.DataFrame(
                {"probability": [result["prob_meme"], result["prob_real_event"]]},
                index=["meme", "real_event"],
            )
        )
        with st.expander("Details"):
            st.write({"cleaned_text": result["clean_text"]})
def render_batch_tab(tokenizer, model) -> None:
    st.subheader("Classify a CSV of tweets")
    st.caption("Upload a CSV with a `text` column. Predictions are added as new columns.")
    uploaded = st.file_uploader("CSV file", type=["csv"])
    if uploaded is None:
        st.info("Waiting for a CSV upload...")
        return
    try:
        df = pd.read_csv(uploaded)
    except Exception as exc:
        st.error(f"Could not read CSV: {exc}")
        return
    if "text" not in df.columns:
        st.error(f"CSV must contain a `text` column. Found: {list(df.columns)}")
        return
    max_rows = 5000
    if len(df) > max_rows:
        st.warning(f"CSV has {len(df)} rows. Truncating to first {max_rows} for the demo.")
        df = df.head(max_rows).copy()
    st.write(f"Loaded {len(df)} rows. Preview:")
    st.dataframe(df.head(5))
    if st.button("Run batch prediction", type="primary"):
        out = predict_many(tokenizer, model, df["text"].tolist())
        merged = pd.concat(
            [df.reset_index(drop=True).drop(columns=["text"]), out.reset_index(drop=True)],
            axis=1,
        )
        st.success(f"Classified {len(merged)} tweets.")
        st.dataframe(merged.head(50))
        counts = merged["label"].value_counts().reindex(["meme", "real_event"], fill_value=0)
        st.markdown("**Label distribution**")
        st.bar_chart(counts)
        buf = io.StringIO()
        merged.to_csv(buf, index=False)
        st.download_button(
            label="Download predictions CSV",
            data=buf.getvalue(),
            file_name="meme_vs_event_predictions.csv",
            mime="text/csv",
        )
def main() -> None:
    st.set_page_config(
        page_title="Meme vs Real Event Classifier",
        layout="centered",
    )
    st.title("Meme vs Real Event Tweet Classifier")
    st.caption(
        "Fine-tuned `bert-base-uncased` loaded from "
        f"[`{MODEL_ID}`](https://huggingface.co/{MODEL_ID})."
    )
    tokenizer, model = load_model(MODEL_ID)
    single_tab, batch_tab, about_tab = st.tabs(["Single tweet", "Batch CSV", "About"])
    with single_tab:
        render_single_tab(tokenizer, model)
    with batch_tab:
        render_batch_tab(tokenizer, model)
    with about_tab:
        # Markdown kept flush-left: indented lines inside the string would
        # otherwise render as a code block.
        st.markdown(
            """
**Pipeline**: tweets were embedded with `all-mpnet-base-v2`, clustered with
BERTopic, cross-checked against the GDELT DOC 2.0 API with a lifespan-aware
rule, and the resulting `(tweet, label)` pairs were used to fine-tune
`bert-base-uncased`.

- **Input**: raw tweet text
- **Preprocessing**: lowercase, strip URLs / mentions / hashtag chars / non-word
- **Max length**: 128 tokens
- **Labels**: `0 = meme`, `1 = real_event`
"""
        )
if __name__ == "__main__":
    main()
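# To run this app outside of Spaces (a sketch; assumes the standard PyPI
# packages, unpinned):
#   pip install streamlit torch transformers pandas numpy
#   streamlit run app.py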