"""Streamlit Space: Meme vs Real Event tweet classifier.
Loads a fine-tuned bert-base-uncased from the Hugging Face Hub and exposes:
- Single-tweet tab: live prediction + probability bar chart
- Batch CSV tab: upload a CSV with a `text` column, download predictions
Matching preprocessing (same regex as the training notebook) is reapplied
so results mirror what the notebook produces locally.
"""
from __future__ import annotations
import io
import os
import re
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Hub repository to load; the MODEL_ID environment variable overrides the default.
MODEL_ID = os.getenv("MODEL_ID", "Aryan047/Dynamic-event-detector")

# Truncation limit (in tokens) passed to the tokenizer.
MAX_LENGTH = 128

# Class index -> human-readable label.
LABELS = {
    0: "meme",
    1: "real_event",
}

# Patterns applied, in this order, by clean_tweet(). Each match is replaced
# by a single space; _WS_RE then collapses the resulting whitespace runs.
_URL_RE = re.compile(r"https?://\S+|www\.\S+")
_MENTION_RE = re.compile(r"@\w+")
_HASHTAG_RE = re.compile(r"#")
_NON_WORD_RE = re.compile(r"[^a-z0-9\s]")
_WS_RE = re.compile(r"\s+")
def clean_tweet(text: str) -> str:
    """Normalise a raw tweet the same way for every prediction path.

    The text is lower-cased, then URLs, @mentions, ``#`` characters and any
    character outside ``[a-z0-9]``/whitespace are each replaced by a space.
    Whitespace runs are finally collapsed and the ends trimmed.

    Args:
        text: Raw tweet text. Anything that is not a ``str`` is treated as
            missing.

    Returns:
        The cleaned text, or ``""`` for non-string input.
    """
    if not isinstance(text, str):
        return ""
    cleaned = text.lower()
    # Order matters: URLs and mentions must go before punctuation is
    # stripped, otherwise their remnants would survive as ordinary words.
    for pattern in (_URL_RE, _MENTION_RE, _HASHTAG_RE, _NON_WORD_RE, _WS_RE):
        cleaned = pattern.sub(" ", cleaned)
    return cleaned.strip()
@st.cache_resource(show_spinner="Loading model from Hugging Face Hub...")
def load_model(model_id: str):
    """Load the tokenizer and sequence-classification model for *model_id*.

    Wrapped in ``st.cache_resource`` and keyed on ``model_id``.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` tuple; ``eval()`` has been called on the model.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_id)
    classifier.eval()
    return tok, classifier
@torch.no_grad()
def predict_one(tokenizer, model, text: str) -> dict:
    """Classify a single tweet.

    Args:
        tokenizer: Tokenizer called on the cleaned text.
        model: Classifier whose output exposes ``.logits``.
        text: Raw tweet text; it is cleaned with ``clean_tweet`` first.

    Returns:
        A dict with keys ``label``, ``confidence``, ``prob_meme``,
        ``prob_real_event`` and ``clean_text``. When cleaning leaves nothing,
        the model is not run and a fixed ``meme`` result with confidence
        ``0.0`` is returned.
    """
    normalized = clean_tweet(text)

    if normalized:
        encoded = tokenizer(
            normalized,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        logits = model(**encoded).logits[0]
        distribution = F.softmax(logits, dim=-1).numpy()
        winner = int(np.argmax(distribution))
        label = LABELS[winner]
        confidence = float(distribution[winner])
        p_meme = float(distribution[0])
        p_event = float(distribution[1])
    else:
        # Nothing left to classify: report the fixed fallback values.
        label, confidence, p_meme, p_event = "meme", 0.0, 1.0, 0.0

    return {
        "label": label,
        "confidence": confidence,
        "prob_meme": p_meme,
        "prob_real_event": p_event,
        "clean_text": normalized,
    }
@torch.no_grad()
def predict_many(tokenizer, model, texts: list[str], batch_size: int = 32) -> pd.DataFrame:
    """Classify many tweets in mini-batches, showing a Streamlit progress bar.

    Args:
        tokenizer: Tokenizer called on each batch of cleaned texts.
        model: Classifier whose output exposes ``.logits``.
        texts: Raw tweet texts. Each is cleaned with ``clean_tweet``.
        batch_size: Number of texts per forward pass; must be at least 1.

    Returns:
        A DataFrame with one row per input and the columns ``text``,
        ``clean_text``, ``label``, ``confidence``, ``prob_meme`` and
        ``prob_real_event``. Rows whose cleaned text is empty get the fixed
        values ``meme`` / ``0.0`` / ``1.0`` / ``0.0``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    cleaned = [clean_tweet(t) for t in texts]
    labels, confs, p0s, p1s = [], [], [], []

    progress_text = "Running predictions..."
    progress = st.progress(0.0, text=progress_text)
    total = max(len(cleaned), 1)  # avoid dividing by zero for empty input

    try:
        for start in range(0, len(cleaned), batch_size):
            chunk = cleaned[start : start + batch_size]
            # Empty rows are sent through as a placeholder so the batch keeps
            # one row per input; their model output is discarded below.
            model_inputs = [c if c else "empty" for c in chunk]
            enc = tokenizer(
                model_inputs,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt",
            )
            probs = F.softmax(model(**enc).logits, dim=-1).numpy()

            for text_clean, p in zip(chunk, probs):
                if not text_clean:
                    labels.append("meme")
                    confs.append(0.0)
                    p0s.append(1.0)
                    p1s.append(0.0)
                    continue
                pred = int(np.argmax(p))
                labels.append(LABELS[pred])
                confs.append(float(p[pred]))
                p0s.append(float(p[0]))
                p1s.append(float(p[1]))

            # Re-send the caption on every update; an update without `text`
            # would drop it from the progress bar.
            progress.progress(
                min((start + batch_size) / total, 1.0),
                text=progress_text,
            )
    finally:
        # Remove the bar even when tokenization or inference raises.
        progress.empty()

    return pd.DataFrame(
        {
            "text": texts,
            "clean_text": cleaned,
            "label": labels,
            "confidence": confs,
            "prob_meme": p0s,
            "prob_real_event": p1s,
        }
    )
def render_single_tab(tokenizer, model) -> None:
    """Render the "Single tweet" tab: a text box, a button and the result.

    Nothing is predicted until the "Predict" button is pressed. Blank input
    produces a warning instead of a prediction.

    Args:
        tokenizer: Passed through to ``predict_one``.
        model: Passed through to ``predict_one``.
    """
    st.subheader("Classify a single tweet")
    st.caption("Paste any tweet-style text. Labels: `meme` or `real_event`.")

    tweet = st.text_area(
        "Tweet text",
        value="Massive 6.5 earthquake just rocked Istanbul, buildings swaying",
        height=120,
    )

    if not st.button("Predict", type="primary"):
        return
    if not tweet.strip():
        st.warning("Please enter some text.")
        return

    outcome = predict_one(tokenizer, model, tweet)

    label_col, confidence_col = st.columns(2)
    label_col.metric("Predicted label", outcome["label"])
    confidence_col.metric("Confidence", f"{outcome['confidence']:.2%}")

    st.markdown("**Class probabilities**")
    chart_data = pd.DataFrame(
        {"probability": [outcome["prob_meme"], outcome["prob_real_event"]]},
        index=["meme", "real_event"],
    )
    st.bar_chart(chart_data)

    with st.expander("Details"):
        st.write({"cleaned_text": outcome["clean_text"]})
def render_batch_tab(tokenizer, model, max_rows: int = 5000) -> None:
    """Render the "Batch CSV" tab: upload, predict, preview and download.

    The uploaded CSV must contain a ``text`` column. All other input columns
    are carried through to the output. An input column whose name matches a
    prediction column (for example an existing ``label``) is renamed with an
    ``_original`` suffix, so the merged table never has duplicate column names.

    Args:
        tokenizer: Passed through to ``predict_many``.
        model: Passed through to ``predict_many``.
        max_rows: Maximum number of rows classified; longer files are
            truncated to their first ``max_rows`` rows with a warning.
    """
    st.subheader("Classify a CSV of tweets")
    st.caption("Upload a CSV with a `text` column. Predictions are added as new columns.")

    uploaded = st.file_uploader("CSV file", type=["csv"])
    if uploaded is None:
        st.info("Waiting for a CSV upload...")
        return

    try:
        df = pd.read_csv(uploaded)
    except Exception as exc:  # UI boundary: show any parse failure to the user
        st.error(f"Could not read CSV: {exc}")
        return

    if "text" not in df.columns:
        st.error(f"CSV must contain a `text` column. Found: {list(df.columns)}")
        return

    if len(df) > max_rows:
        st.warning(f"CSV has {len(df)} rows. Truncating to first {max_rows} for the demo.")
        df = df.head(max_rows).copy()

    st.write(f"Loaded {len(df)} rows. Preview:")
    st.dataframe(df.head(5))

    # NOTE(review): the download button below lives inside this button's
    # branch; a rerun in which the button reads False would hide the results.
    # Confirm whether they should be kept in st.session_state.
    if not st.button("Run batch prediction", type="primary"):
        return

    out = predict_many(tokenizer, model, df["text"].tolist()).reset_index(drop=True)
    passthrough = df.reset_index(drop=True).drop(columns=["text"])

    # Rename input columns that clash with prediction columns; otherwise
    # concat would yield duplicate names and merged["label"] would no longer
    # be a single Series.
    reserved = set(out.columns)
    taken = set(passthrough.columns) | reserved
    renames = {}
    for col in passthrough.columns:
        if col not in reserved:
            continue
        candidate = f"{col}_original"
        while candidate in taken:
            candidate = f"{candidate}_"
        taken.add(candidate)
        renames[col] = candidate
    passthrough = passthrough.rename(columns=renames)

    merged = pd.concat([passthrough, out], axis=1)

    st.success(f"Classified {len(merged)} tweets.")
    st.dataframe(merged.head(50))

    counts = merged["label"].value_counts().reindex(["meme", "real_event"], fill_value=0)
    st.markdown("**Label distribution**")
    st.bar_chart(counts)

    buf = io.StringIO()
    merged.to_csv(buf, index=False)
    st.download_button(
        label="Download predictions CSV",
        data=buf.getvalue(),
        file_name="meme_vs_event_predictions.csv",
        mime="text/csv",
    )
def main() -> None:
    """Build the page: header, cached model load and the three tabs."""
    st.set_page_config(
        page_title="Meme vs Real Event Classifier",
        page_icon="",
        layout="centered",
    )
    st.title("Meme vs Real Event Tweet Classifier")

    hub_link = f"[`{MODEL_ID}`](https://huggingface.co/{MODEL_ID})"
    st.caption(f"Fine-tuned `bert-base-uncased` loaded from {hub_link}.")

    tokenizer, model = load_model(MODEL_ID)

    # The literal is kept flush-left so the markdown text is unchanged.
    about_text = """
**Pipeline**: tweets were embedded with `all-mpnet-base-v2`, clustered with
BERTopic, cross-checked against the GDELT DOC 2.0 API with a lifespan-aware
rule, and the resulting `(tweet, label)` pairs were used to fine-tune
`bert-base-uncased`.
- **Input**: raw tweet text
- **Preprocessing**: lowercase, strip URLs / mentions / hashtag chars / non-word
- **Max length**: 128 tokens
- **Labels**: `0 = meme`, `1 = real_event`
"""

    tabs = st.tabs(["Single tweet", "Batch CSV", "About"])
    with tabs[0]:
        render_single_tab(tokenizer, model)
    with tabs[1]:
        render_batch_tab(tokenizer, model)
    with tabs[2]:
        st.markdown(about_text)


if __name__ == "__main__":
    main()
|