# Source: Gunroar — "Update app.py", commit ee2a733 (verified).
import asyncio
import html
import json
import os
import re
import tempfile

import emoji
import folium
import joblib
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from streamlit_folium import st_folium
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from rnet_twitter import RnetTwitterClient
# ---------------------------------------------------------------------------
# Page configuration
# ---------------------------------------------------------------------------
# Kept as the very first st.* call: older Streamlit versions reject
# set_page_config once any other element has been rendered.
st.set_page_config(layout="wide", page_title="Malaria Discourse Classifier")
st.title(body="🦟 Malaria Discourse Classifier (Nigeria)")
st.markdown(
    body=(
        "Classifies tweets into 5 categories using an ensemble of "
        "BERTweet, SVM, and Logistic Regression."
    )
)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Hugging Face Hub repository holding the fine-tuned BERTweet weights.
MODEL_PATH = "Gunroar/ng_malaria_tweet_classifier_model"
# The TF-IDF vectoriser and SVM / LR baselines live in the same repository.
BASELINES_PATH = MODEL_PATH

# One row per class: (label, chart hex colour, folium pin colour).
# Row order is the class-index order: predictions are turned into names via
# LABELS[index], so reordering rows would mislabel every prediction.
_LABEL_TABLE = (
    ("Symptoms & Burden", "#4C72B0", "blue"),
    ("Treatment & Health System", "#55A868", "green"),
    ("Prevention & Awareness", "#C44E52", "red"),
    ("Misinformation", "#DD8452", "orange"),
    ("Irrelevant", "#8172B2", "purple"),
)
LABELS = [name for name, _hex, _pin in _LABEL_TABLE]
LABEL_COLORS = {name: hex_colour for name, hex_colour, _pin in _LABEL_TABLE}
PIN_COLORS = {name: pin_colour for name, _hex, pin_colour in _LABEL_TABLE}

# Approximate geographic centre of Nigeria as (latitude, longitude).
NGA_CENTER = (9.0820, 8.6753)
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
@st.cache_resource
def load_models():
    """Download and load every model used by the ensemble.

    Wrapped in ``st.cache_resource`` so the downloads and deserialisation
    happen once per server process instead of on every script rerun.

    Returns:
        Tuple ``(tokenizer, bert_model, svm, lr, tfidf)``:
        the BERTweet tokenizer, the BERTweet sequence classifier (switched to
        eval mode), the linear SVM, the logistic-regression model and the
        TF-IDF vectoriser shared by the two baselines.
    """
    # normalization=True enables BERTweet's built-in tweet normalisation.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, normalization=True)
    bert_model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    bert_model.eval()  # inference only: disable dropout

    def _load_baseline(filename):
        # Fetch one artefact from the baselines repo (BASELINES_PATH rather
        # than a duplicated literal) and deserialise it.
        # SECURITY NOTE: joblib.load unpickles the file, which can execute
        # arbitrary code; only point BASELINES_PATH at a trusted repository.
        return joblib.load(hf_hub_download(BASELINES_PATH, filename))

    svm = _load_baseline("linear_svc_model.joblib")
    lr = _load_baseline("logistic_regression_model.joblib")
    tfidf = _load_baseline("tfidf_vectorizer.joblib")
    return tokenizer, bert_model, svm, lr, tfidf


tokenizer, bert_model, svm_model, lr_model, tfidf_vec = load_models()
# ---------------------------------------------------------------------------
# Ensemble prediction
# ---------------------------------------------------------------------------
def _apply_ensemble_rules(bert_label, bert_conf, svm_label, lr_label):
    """Pick the ensemble's final label from the three individual votes.

    BERTweet's label wins unless one of three hand-written overrides fires.
    Rules 1 and 2 only concern a BERTweet "Prevention & Awareness" vote and
    Rule 3 only an "Irrelevant" vote, so at most one rule can apply.
    """
    # Label both baselines agree on, or None when they disagree.
    shared_baseline = svm_label if svm_label == lr_label else None

    if bert_label == "Prevention & Awareness":
        # Rule 1: both baselines say Treatment, BERTweet is unsure (< 0.40).
        if shared_baseline == "Treatment & Health System" and bert_conf < 0.40:
            return "Treatment & Health System"
        # Rule 2: both baselines say Symptoms, BERTweet is confident (> 0.70).
        if shared_baseline == "Symptoms & Burden" and bert_conf > 0.70:
            return "Symptoms & Burden"
        return bert_label

    # Rule 3: an unsure (< 0.70) "Irrelevant" vote defers to Logistic Regression.
    if bert_label == "Irrelevant" and bert_conf < 0.70:
        return lr_label
    return bert_label


def predict_ensemble(text: str) -> dict:
    """Classify one tweet with BERTweet, the SVM and LR, then combine the votes.

    Emojis are first rewritten as text via ``emoji.demojize`` (space
    delimited) so both the tokenizer and the TF-IDF features see them.

    Returns:
        dict with keys
        ``"final"`` – the ensemble label;
        ``"bert"``  – ``(label, probability)`` of BERTweet's top class;
        ``"svm"`` / ``"lr"`` – the baseline labels;
        ``"probs"`` – BERTweet's softmax probabilities, in LABELS order.
    """
    demojized = emoji.demojize(text, delimiters=(" ", " "))

    # --- BERTweet (input truncated to 128 tokens) ---
    encoded = tokenizer(
        demojized,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    with torch.no_grad():
        logits = bert_model(**encoded).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
    top_index = int(torch.argmax(logits))
    bert_label, bert_conf = LABELS[top_index], probabilities[top_index]

    # --- TF-IDF baselines; their predictions are integer indices into LABELS ---
    tfidf_features = tfidf_vec.transform([demojized])
    svm_label = LABELS[int(svm_model.predict(tfidf_features)[0])]
    lr_label = LABELS[int(lr_model.predict(tfidf_features)[0])]

    return {
        "final": _apply_ensemble_rules(bert_label, bert_conf, svm_label, lr_label),
        "bert": (bert_label, bert_conf),
        "svm": svm_label,
        "lr": lr_label,
        "probs": probabilities,
    }
# ---------------------------------------------------------------------------
# Batch helper
# ---------------------------------------------------------------------------
def classify_dataframe(
    df: pd.DataFrame,
    text_column: str = "text",
    confidence_threshold: float = 0.5,
    predict_fn=None,
) -> pd.DataFrame:
    """Classify every row of ``df[text_column]`` and return an annotated copy.

    Args:
        df: Input frame; it is not modified.
        text_column: Column holding the tweet text. Missing values (e.g. empty
            CSV cells, read as NaN) are classified as ``""``, and non-string
            values are converted with ``str``. Without this, the emoji and
            tokeniser step would fail on NaN.
        confidence_threshold: Rows whose confidence is below this value get
            ``predicted_label`` overwritten with "Low Confidence".
        predict_fn: Callable mapping one text to a result dict shaped like
            ``predict_ensemble``'s output. Defaults to ``predict_ensemble``.

    Returns:
        A copy of ``df`` with added columns ``predicted_label``,
        ``confidence``, ``bert_label``, ``svm_label`` and ``lr_label``.

    NOTE: ``confidence`` is BERTweet's probability for its *own* top label,
    not for the ensemble's final label. An ensemble override can therefore
    still be replaced by "Low Confidence". For example, the Treatment
    override only fires when BERTweet is below 0.40.
    """
    if predict_fn is None:
        predict_fn = predict_ensemble

    df = df.copy()
    texts = df[text_column].fillna("").astype(str)
    results = [predict_fn(text) for text in texts]

    df["predicted_label"] = [r["final"] for r in results]
    df["confidence"] = [r["bert"][1] for r in results]
    df["bert_label"] = [r["bert"][0] for r in results]
    df["svm_label"] = [r["svm"] for r in results]
    df["lr_label"] = [r["lr"] for r in results]

    low_conf_mask = df["confidence"] < confidence_threshold
    df.loc[low_conf_mask, "predicted_label"] = "Low Confidence"
    return df
# ---------------------------------------------------------------------------
# Twitter scraper — uses os.environ instead of st.secrets
# ---------------------------------------------------------------------------
@st.cache_resource
def load_scraper_client():
    """Create an RnetTwitterClient with cookies loaded, or return None.

    Credential sources, tried in order:

    1. The ``X_AUTH_TOKEN`` and ``ct0`` environment variables (both must be
       set). They are written as a JSON cookie list to a *separate* private
       temporary file and loaded from there. The user's ``cookies.json`` is
       never overwritten, which previously made the fallback below reload
       the same failing env credentials.
    2. A pre-existing ``cookies.json`` in the working directory.

    Each failure is reported as a sidebar warning.

    Returns:
        The client on success, otherwise None.
    """
    client = RnetTwitterClient()
    fallback_path = "cookies.json"

    auth_token = os.environ.get("X_AUTH_TOKEN")
    ct0 = os.environ.get("ct0")
    if auth_token and ct0:
        cookies_data = [
            {"name": "auth_token", "value": auth_token},
            {"name": "ct0", "value": ct0},
        ]
        try:
            # mkstemp creates the file readable/writable by the current user
            # only, so the session token is not exposed with default umask
            # permissions. The file is kept in case the client re-reads it
            # later. TODO(review): delete it after loading if load_cookies
            # is confirmed to read eagerly.
            fd, env_cookie_path = tempfile.mkstemp(prefix="x_cookies_", suffix=".json")
            with os.fdopen(fd, "w") as f:
                json.dump(cookies_data, f)
            client.load_cookies(env_cookie_path)
            return client
        except Exception as e:
            st.sidebar.warning(f"Could not load cookies from env vars: {e}")

    if os.path.exists(fallback_path):
        try:
            client.load_cookies(fallback_path)
            return client
        except Exception as e:
            st.sidebar.warning(f"Could not load cookies from cookies.json: {e}")

    return None


scraper_client = load_scraper_client()
# ---------------------------------------------------------------------------
# Session-state defaults
# ---------------------------------------------------------------------------
# Seed per-session storage on the first run only; later reruns keep whatever
# the live-scrape tab stored. Each value is built by calling its factory.
_SESSION_DEFAULTS = (
    ("raw_scraped_tweets", list),
    ("classified_live_tweets", pd.DataFrame),
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# ===========================================================================
# UI TABS
# ===========================================================================
tab1, tab2, tab3 = st.tabs(["Single Tweet", "Batch CSV", "Live Scrape"])
# ---------------------------------------------------------------------------
# Tab 1 — Single Tweet
# ---------------------------------------------------------------------------
with tab1:
    st.subheader("Classify one tweet")
    tweet_text = st.text_area(
        "Enter a tweet:",
        value="I have fever and chills. Is it malaria?",
        height=100,
    )
    classify_clicked = st.button("Classify", key="single_classify")

    if classify_clicked and not tweet_text.strip():
        # Whitespace-only input: nothing to classify.
        st.warning("Please enter some text to classify.")
    elif classify_clicked:
        prediction = predict_ensemble(tweet_text)
        bert_name, bert_score = prediction["bert"]

        st.success(f"**Final Decision: {prediction['final']}**")

        # One metric card per model so the individual votes stay visible.
        bert_card, svm_card, lr_card = st.columns(3)
        bert_card.metric("BERTweet", bert_name, f"{bert_score:.1%}")
        svm_card.metric("SVM", prediction["svm"])
        lr_card.metric("Logistic Regression", prediction["lr"])

        # Horizontal bar chart of BERTweet's per-class probabilities.
        probability_chart = px.bar(
            pd.DataFrame({"Category": LABELS, "Confidence": prediction["probs"]}),
            x="Confidence",
            y="Category",
            orientation="h",
            color="Category",
            color_discrete_map=LABEL_COLORS,
            title="BERTweet Confidence Scores",
        )
        probability_chart.update_layout(showlegend=False, xaxis_tickformat=".0%")
        st.plotly_chart(probability_chart, use_container_width=True)
# ---------------------------------------------------------------------------
# Tab 2 — Batch CSV
# ---------------------------------------------------------------------------
@st.cache_data(show_spinner=False)
def _classify_batch_cached(batch: pd.DataFrame, threshold: float) -> pd.DataFrame:
    """Run classify_dataframe on an uploaded batch, memoised on (content, threshold).

    Streamlit re-executes the whole script on every interaction, including a
    click on the download button. Without this cache, each rerun would push
    the entire CSV through BERTweet again.
    """
    return classify_dataframe(batch, text_column="text", confidence_threshold=threshold)


with tab2:
    st.subheader("Batch classify from CSV")
    threshold = st.slider(
        "Confidence threshold",
        min_value=0.0, max_value=1.0, value=0.5, step=0.05,
        help="Predictions with confidence below this value will be labelled 'Low Confidence'.",
        key="batch_threshold",
    )
    uploaded = st.file_uploader(
        "Upload a CSV file with a 'text' column", type=["csv"]
    )
    if uploaded is not None:
        # Top-level UI boundary: any parsing/classification failure is shown
        # to the user instead of crashing the page.
        try:
            batch_df = pd.read_csv(uploaded)
            if "text" not in batch_df.columns:
                st.error("The CSV must contain a column named 'text'.")
            else:
                with st.spinner("Classifying tweets…"):
                    result_df = _classify_batch_cached(batch_df, threshold)
                st.success(f"Classified {len(result_df)} tweets.")
                st.dataframe(
                    result_df[
                        ["text", "predicted_label", "confidence",
                         "bert_label", "svm_label", "lr_label"]
                    ].head(20)
                )

                # Distribution pie chart ("Low Confidence" gets a default colour).
                dist = result_df["predicted_label"].value_counts().reset_index()
                dist.columns = ["Category", "Count"]
                fig_dist = px.pie(
                    dist,
                    names="Category", values="Count",
                    color="Category", color_discrete_map=LABEL_COLORS,
                    title="Category Distribution",
                )
                st.plotly_chart(fig_dist, use_container_width=True)

                # Per-model agreement breakdown: how often each model chose each label.
                st.subheader("Model Agreement Breakdown")
                agree_df = pd.DataFrame([
                    {
                        "Category": label,
                        "BERTweet": (result_df["bert_label"] == label).sum(),
                        "SVM": (result_df["svm_label"] == label).sum(),
                        "LR": (result_df["lr_label"] == label).sum(),
                        "Ensemble": (result_df["predicted_label"] == label).sum(),
                    }
                    for label in LABELS
                ])
                fig_agree = px.bar(
                    agree_df.melt(id_vars="Category", var_name="Model", value_name="Count"),
                    x="Category", y="Count", color="Model", barmode="group",
                    title="Predictions per Model per Category",
                )
                st.plotly_chart(fig_agree, use_container_width=True)

                # Download
                csv_out = result_df.to_csv(index=False).encode("utf-8")
                st.download_button(
                    label="⬇ Download Results CSV",
                    data=csv_out,
                    file_name="classified_tweets.csv",
                    mime="text/csv",
                )
        except Exception as e:
            st.error(f"Error processing file: {e}")
# ---------------------------------------------------------------------------
# Tab 3 — Live Scrape
# ---------------------------------------------------------------------------
# Known Nigerian place names -> (latitude, longitude, location tag).
# Within each region, cities come before the bare state-name fallbacks so finer
# matches are tried first. The first key in dict order whose whole word appears
# in a tweet wins. Tags ending in " Centroid" stand for a whole state.
CITY_COORDINATES = {
    # --- South-East (Cities & States) ---
    "owerri": (5.4856, 7.0358, "Imo State"),
    "enugu": (6.4403, 7.4942, "Enugu State"),
    "onitsha": (6.1421, 6.7909, "Anambra State"),
    "awka": (6.2107, 7.0736, "Anambra State"),
    "aba": (5.1066, 7.3697, "Abia State"),
    "umuahia": (5.5267, 7.4896, "Abia State"),
    "abakaliki": (6.3249, 8.1132, "Ebonyi State"),
    # Standalone States Fallbacks
    "imo": (5.5486, 7.1400, "Imo State Centroid"),
    "anambra": (6.2105, 6.9458, "Anambra State Centroid"),
    "abia": (5.4167, 7.5000, "Abia State Centroid"),
    "ebonyi": (6.2500, 8.0833, "Ebonyi State Centroid"),
    # --- South-South (Cities & States) ---
    "port harcourt": (4.7655, 7.0163, "Rivers State"),
    "phc": (4.7655, 7.0163, "Rivers State"),
    "benin": (6.3176, 5.6145, "Edo State"),
    "calabar": (4.9757, 8.3417, "Cross River State"),
    "uyo": (5.0333, 7.9266, "Akwa Ibom State"),
    "warri": (5.5167, 5.7500, "Delta State"),
    "asaba": (6.1983, 6.7320, "Delta State"),
    "yenagoa": (4.9267, 6.2642, "Bayelsa State"),
    # Standalone States Fallbacks
    "rivers state": (4.8500, 6.8500, "Rivers State Centroid"),
    "edo": (6.5244, 5.8987, "Edo State Centroid"),
    "cross river": (5.7500, 8.5000, "Cross River State Centroid"),
    "akwa ibom": (5.0000, 7.8333, "Akwa Ibom State Centroid"),
    "delta state": (5.7000, 6.0000, "Delta State Centroid"),
    "bayelsa": (4.7500, 6.0833, "Bayelsa State Centroid"),
    # --- South-West (Cities & States) ---
    "lagos": (6.4561, 3.3936, "Lagos State"),
    "ikeja": (6.5920, 3.3422, "Lagos State"),
    "ibadan": (7.3775, 3.9058, "Oyo State"),
    "abeokuta": (7.1611, 3.3483, "Ogun State"),
    "akure": (7.2526, 5.1931, "Ondo State"),
    "osogbo": (7.7710, 4.5624, "Osun State"),
    "ado ekiti": (7.6233, 5.2208, "Ekiti State"),
    # Standalone States Fallbacks
    "oyo": (8.0000, 4.0000, "Oyo State Centroid"),
    "ogun": (7.0000, 3.5000, "Ogun State Centroid"),
    "ondo": (7.1667, 5.0833, "Ondo State Centroid"),
    "osun": (7.5000, 4.5000, "Osun State Centroid"),
    "ekiti": (7.6667, 5.2500, "Ekiti State Centroid"),
    # --- North-Central (Cities & States) ---
    "abuja": (9.0556, 7.4914, "FCT"),
    "fct": (9.0556, 7.4914, "FCT"),
    "ilorin": (8.5000, 4.5500, "Kwara State"),
    "jos": (9.8965, 8.8583, "Plateau State"),
    "minna": (9.6139, 6.5569, "Niger State"),
    "lokoja": (7.8024, 6.7418, "Kogi State"),
    "makurdi": (7.7325, 8.5214, "Benue State"),
    "lafia": (8.4917, 8.5167, "Nasarawa State"),
    # Standalone States Fallbacks
    "kwara": (8.5000, 4.7500, "Kwara State Centroid"),
    "plateau": (9.2500, 9.5000, "Plateau State Centroid"),
    "niger state": (10.0000, 6.0000, "Niger State Centroid"),
    "kogi": (7.7500, 6.7500, "Kogi State Centroid"),
    "benue": (7.3333, 8.7500, "Benue State Centroid"),
    "nasarawa": (8.3167, 8.1667, "Nasarawa State Centroid"),
    # --- North-West (Cities & States) ---
    "kano": (12.0022, 8.5920, "Kano State"),
    "kaduna": (10.5105, 7.4165, "Kaduna State"),
    "sokoto": (13.0622, 5.2339, "Sokoto State"),
    "zaria": (11.0855, 7.7196, "Kaduna State"),
    "katsina": (12.9894, 7.6006, "Katsina State"),
    "gusau": (12.1628, 6.6614, "Zamfara State"),
    "birnin kebbi": (12.4539, 4.1975, "Kebbi State"),
    # Standalone States Fallbacks
    "zamfara": (12.1667, 6.2500, "Zamfara State Centroid"),
    "kebbi": (11.5000, 4.0000, "Kebbi State Centroid"),
    "jigawa": (12.2500, 9.7500, "Jigawa State Centroid"),
    # --- North-East (Cities & States) ---
    "maiduguri": (11.8333, 13.1500, "Borno State"),
    "bauchi": (10.3158, 9.8442, "Bauchi State"),
    "gombe": (10.2897, 11.1673, "Gombe State"),
    "yola": (9.2035, 12.4954, "Adamawa State"),
    "damaturu": (11.7470, 11.9608, "Yobe State"),
    "jalingo": (8.8922, 11.3636, "Taraba State"),
    # Standalone States Fallbacks
    "borno": (11.5000, 13.0000, "Borno State Centroid"),
    "adamawa": (9.3333, 12.5000, "Adamawa State Centroid"),
    "yobe": (12.0000, 11.5000, "Yobe State Centroid"),
    "taraba": (8.0000, 10.5000, "Taraba State Centroid"),
}

# Whole-word patterns compiled once (so "imo" does not match "immobilize").
_PLACE_PATTERNS = [
    (re.compile(r"\b" + re.escape(place) + r"\b"), place, info)
    for place, info in CITY_COORDINATES.items()
]

# Twitter search operator geocode:LAT,LON[,RADIUS(km|mi)].
_GEOCODE_RE = re.compile(r"geocode:([\d.-]+),([\d.-]+)(?:,([\d.]+)(km|mi))?")
_METRES_PER_UNIT = {"km": 1000.0, "mi": 1609.344}


def _parse_geocode(query_str):
    """Return ``(lat, lon, radius_m, radius_label)`` from a query's geocode operator.

    Falls back to NGA_CENTER and a 650 km radius when the query has no
    geocode operator or its numbers cannot be parsed.
    """
    lat, lon = NGA_CENTER
    radius_value, radius_unit = "650", "km"
    match = _GEOCODE_RE.search(query_str)
    if match:
        try:
            lat, lon = float(match.group(1)), float(match.group(2))
        except ValueError:  # e.g. "9.0.1" – keep the default centre
            lat, lon = NGA_CENTER
        if match.group(3):
            radius_value, radius_unit = match.group(3), match.group(4)
    try:
        radius_m = float(radius_value) * _METRES_PER_UNIT[radius_unit]
    except ValueError:  # malformed radius – keep the default
        radius_value, radius_unit, radius_m = "650", "km", 650_000.0
    return lat, lon, radius_m, f"{radius_value}{radius_unit}"


def _locate_tweet(text, rng, center_lat, center_lon):
    """Return ``(lat, lon, description)`` for a pin, inferred from place names in *text*.

    The fallback jitter around the search centre is drawn for every tweet,
    even matched ones, so the seeded random stream advances identically per
    row. A matched place gets a small extra jitter so that pins for the same
    place do not overlap exactly.
    """
    fallback_lat = center_lat + rng.uniform(-1.2, 1.2)
    fallback_lon = center_lon + rng.uniform(-1.5, 1.5)
    lowered = text.lower()
    for pattern, place, (place_lat, place_lon, tag) in _PLACE_PATTERNS:
        if pattern.search(lowered):
            if "Centroid" in tag:
                description = tag.replace(" Centroid", "")
            else:
                description = f"{place.title()} ({tag})"
            return (
                place_lat + rng.uniform(-0.06, 0.06),
                place_lon + rng.uniform(-0.06, 0.06),
                description,
            )
    return fallback_lat, fallback_lon, "Detected inside Geocode Range"


def _author_of(row):
    """Return the row's author for display, or "X User" when missing/NaN."""
    author = row.get("author")
    if pd.api.types.is_scalar(author) and pd.isna(author):
        return "X User"
    return str(author)


def _popup_html(label, location, author, text, confidence):
    """Build a pin's popup HTML, escaping all tweet-derived (untrusted) text."""
    snippet = text if len(text) <= 140 else text[:140] + "..."
    return (
        f"<b>Category: {html.escape(label)}</b><br>"
        f"<b>Source Hub:</b> {html.escape(location)}<br>"
        f"<b>Author:</b> @{html.escape(author)}<br><br>"
        f"<i>\"{html.escape(snippet)}\"</i><br><br>"
        f"Confidence: {confidence:.2%}"
    )


async def _scrape_tweets(client, query_str, tweet_count, start_date, end_date):
    """Search recent tweets (retweets excluded) in an optional date window.

    Returns the client's results, or ``[]`` after showing the error in the UI.
    """
    full_query = f"{query_str} -is:retweet"
    if start_date:
        full_query += f" since:{start_date.isoformat()}"
    if end_date:
        # Twitter's until: is exclusive (tweets sent *before* that date), so
        # ask for the following day to include tweets posted on end_date.
        full_query += f" until:{(end_date + pd.Timedelta(days=1)).isoformat()}"
    try:
        return await client.search_tweets(
            full_query, count=tweet_count, product="Latest"
        )
    except Exception as e:
        st.error(f"Scraping error: {e}")
        return []


def _clear_live_results():
    """Button callback: drop stored scrape results before the next rerun renders."""
    st.session_state["classified_live_tweets"] = pd.DataFrame()
    st.session_state["raw_scraped_tweets"] = []
    st.session_state["live_data_cleared"] = True


with tab3:
    st.subheader("Live Tweet Scraper")
    st.markdown(
        "Search for recent tweets, classify them with the ensemble, "
        "and view their simulated geographic distribution within Nigeria."
    )
    query = st.text_input(
        "Search Query",
        value='(malaria OR "mosquito net") geocode:9.0820,8.6753,650km',
        help="Use Twitter search operators. Retweets are automatically excluded.",
    )
    count = st.slider("Number of tweets to fetch", min_value=10, max_value=200, value=50)
    today = pd.to_datetime("today").date()
    default_start_date = today - pd.Timedelta(days=7)
    scrape_start_date = st.date_input("Start Date (optional)", value=default_start_date)
    scrape_end_date = st.date_input("End Date (optional)", value=today)
    scrape_threshold = st.slider(
        "Confidence threshold for live scrape",
        min_value=0.0, max_value=1.0, value=0.5, step=0.05,
        key="live_threshold",
    )

    if st.button("Scrape and Classify", key="live_scrape"):
        if scraper_client is None:
            st.error(
                "Scraper client not initialised. "
                "Set X_AUTH_TOKEN and ct0 env vars or provide cookies.json."
            )
        elif scrape_start_date and scrape_end_date and scrape_start_date > scrape_end_date:
            st.error("Start Date must be on or before End Date.")
        else:
            with st.spinner("Scraping tweets… This may take a moment."):
                raw = asyncio.run(
                    _scrape_tweets(
                        scraper_client, query, count, scrape_start_date, scrape_end_date
                    )
                )
            st.session_state["raw_scraped_tweets"] = raw
            if not raw:
                st.warning("No tweets fetched. Try a different query or check authentication.")
                st.session_state["classified_live_tweets"] = pd.DataFrame()
            else:
                st.success(f"Fetched {len(raw)} tweets.")
                live_df = pd.DataFrame(raw)
                if "text" not in live_df.columns:
                    st.warning("Scraped data has no 'text' column. Showing raw data:")
                    st.dataframe(live_df.head())
                    st.session_state["classified_live_tweets"] = pd.DataFrame()
                else:
                    # Keep the author next to the text so map popups can show it.
                    keep_cols = [c for c in ("text", "author") if c in live_df.columns]
                    with st.spinner("Classifying tweets…"):
                        classified = classify_dataframe(
                            live_df[keep_cols].copy(),
                            text_column="text",
                            confidence_threshold=scrape_threshold,
                        )
                    st.session_state["classified_live_tweets"] = classified

    # ---- Display stored results ----
    if st.session_state.pop("live_data_cleared", False):
        st.success("Scraped data cleared.")

    classified_live = st.session_state["classified_live_tweets"]
    if not classified_live.empty:
        st.subheader("Classified Tweets (sample)")
        st.dataframe(
            classified_live[
                ["text", "predicted_label", "confidence",
                 "bert_label", "svm_label", "lr_label"]
            ].head(10)
        )

        # Distribution pie chart
        dist_live = classified_live["predicted_label"].value_counts().reset_index()
        dist_live.columns = ["Category", "Count"]
        fig_pie = px.pie(
            dist_live,
            names="Category", values="Count",
            color="Category", color_discrete_map=LABEL_COLORS,
            title="Live Scrape — Category Distribution",
        )
        st.plotly_chart(fig_pie, use_container_width=True)

        # Per-model agreement for live data
        st.subheader("Model Agreement (Live)")
        agree_live = [
            {
                "Category": label,
                "BERTweet": (classified_live["bert_label"] == label).sum(),
                "SVM": (classified_live["svm_label"] == label).sum(),
                "LR": (classified_live["lr_label"] == label).sum(),
                "Ensemble": (classified_live["predicted_label"] == label).sum(),
            }
            for label in LABELS
        ]
        fig_agree_live = px.bar(
            pd.DataFrame(agree_live).melt(
                id_vars="Category", var_name="Model", value_name="Count"
            ),
            x="Category", y="Count", color="Model", barmode="group",
            title="Live Predictions per Model per Category",
        )
        st.plotly_chart(fig_agree_live, use_container_width=True)

        # Download
        csv_live = classified_live.to_csv(index=False).encode("utf-8")
        st.download_button(
            label="⬇ Download Scraped & Classified Results",
            data=csv_live,
            file_name="scraped_classified_tweets.csv",
            mime="text/csv",
        )

        # ---- Geographic distribution (place names found in tweet text) ----
        st.subheader("Geographic Distribution Analysis")
        st.info(
            "Mapping based on query geocode boundaries. Tweets mentioning known cities "
            "are placed contextually; others are mapped to the search area center."
        )
        center_lat, center_lon, radius_m, radius_label = _parse_geocode(query)
        m = folium.Map(location=[center_lat, center_lon], zoom_start=6, tiles="CartoDB positron")
        # Search radius overlay, sized from the query's geocode operator.
        folium.Circle(
            location=[center_lat, center_lon],
            radius=radius_m,
            color="#2980b9",
            fill=True,
            fill_opacity=0.04,
            popup=f"Active Search Radius ({radius_label})",
        ).add_to(m)

        # Fixed seed: the same results produce the same pin layout on every rerun.
        rng = np.random.default_rng(seed=42)
        for _, row in classified_live.iterrows():
            tweet_text = str(row["text"])
            pin_lat, pin_lon, detected_loc = _locate_tweet(
                tweet_text, rng, center_lat, center_lon
            )
            folium.CircleMarker(
                location=[pin_lat, pin_lon],
                radius=6,
                color=PIN_COLORS.get(row["predicted_label"], "gray"),
                fill=True,
                fill_opacity=0.85,
                popup=folium.Popup(
                    _popup_html(
                        str(row["predicted_label"]),
                        detected_loc,
                        _author_of(row),
                        tweet_text,
                        row["confidence"],
                    ),
                    max_width=280,
                ),
            ).add_to(m)
        st_folium(m, width=700, height=500)

        # on_click runs before the next rerun, so the cleared state takes
        # effect immediately instead of one interaction later.
        st.button("Clear Scraped Data", key="clear_live_data", on_click=_clear_live_results)
# ===========================================================================
# Sidebar
# ===========================================================================
st.sidebar.markdown("## About")
st.sidebar.info(
    "**Model:** BERTweet fine-tuned on Nigerian malaria-related tweets.\n\n"
    "**Categories:** Symptoms & Burden, Treatment & Health System, "
    "Prevention & Awareness, Misinformation, Irrelevant.\n\n"
    "**Ensemble:** BERTweet + SVM + Logistic Regression.\n\n"
    "**Purpose:** Final Year Project — Infodemiology.\n\n"
    "**Author:** Chime Ugochukwu Chiziri Kevin"
)

st.sidebar.markdown("## Auth Status")
auth_from_env = bool(os.environ.get("X_AUTH_TOKEN") and os.environ.get("ct0"))
cookie_file_present = os.path.exists("cookies.json")
if auth_from_env:
    st.sidebar.success("✅ Twitter credentials loaded from environment variables (X_AUTH_TOKEN + ct0).")
elif cookie_file_present:
    st.sidebar.success("✅ cookies.json found in application directory.")
else:
    st.sidebar.error(
        "❌ No Twitter credentials found. "
        "Set X_AUTH_TOKEN and ct0 env vars or provide cookies.json."
    )
# Credentials being present does not mean they loaded: load_scraper_client
# returns None when every credential source failed.
if scraper_client is None and (auth_from_env or cookie_file_present):
    st.sidebar.warning(
        "⚠️ Credentials were found, but the Twitter client could not be initialised."
    )

st.sidebar.markdown("## Notes")
st.sidebar.markdown(
    "- Batch CSV must contain a `text` column.\n"
    "- Predictions below the confidence threshold are labelled *Low Confidence*.\n"
    "- Live scraping requires `X_AUTH_TOKEN` and `ct0` env vars or a `cookies.json` file.\n"
    "- Map locations are approximate: Twitter APIs do not provide geo-coordinates, "
    "so pins are placed at Nigerian places named in the tweet text, or scattered "
    "randomly near the search centre when none is named.\n"
    "- All three models' decisions are shown in every tab for transparency."
)