"""Streamlit app: ensemble classifier for Nigerian malaria-related tweets.

The app exposes three tabs (single tweet, batch CSV, live scrape) and labels
each tweet with one of five categories using BERTweet plus two TF-IDF
baselines (linear SVM and logistic regression) combined by hand-written rules.
"""

import asyncio
import html
import json
import os
import re

import emoji
import folium
import joblib
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from rnet_twitter import RnetTwitterClient
from streamlit_folium import st_folium
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# ---------------------------------------------------------------------------
# Page configuration
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Malaria Discourse Classifier", layout="wide")
st.title("🦟 Malaria Discourse Classifier (Nigeria)")
st.markdown(
    "Classifies tweets into 5 categories using an ensemble of BERTweet, SVM, and Logistic Regression."
)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
MODEL_PATH = "Gunroar/ng_malaria_tweet_classifier_model"
BASELINES_PATH = MODEL_PATH  # same repo

LABELS = [
    "Symptoms & Burden",
    "Treatment & Health System",
    "Prevention & Awareness",
    "Misinformation",
    "Irrelevant",
]

LABEL_COLORS = {
    "Symptoms & Burden": "#4C72B0",
    "Treatment & Health System": "#55A868",
    "Prevention & Awareness": "#C44E52",
    "Misinformation": "#DD8452",
    "Irrelevant": "#8172B2",
}

PIN_COLORS = {
    "Symptoms & Burden": "blue",
    "Treatment & Health System": "green",
    "Prevention & Awareness": "red",
    "Misinformation": "orange",
    "Irrelevant": "purple",
}

NGA_CENTER = (9.0820, 8.6753)

# Radius drawn on the map when the query carries no usable geocode radius.
DEFAULT_RADIUS_KM = 650.0

# Columns shown in the result tables of the batch and live tabs.
DISPLAY_COLUMNS = [
    "text",
    "predicted_label",
    "confidence",
    "bert_label",
    "svm_label",
    "lr_label",
]

# Lower-cased place name -> (latitude, longitude, tag). A tag ending in
# "Centroid" marks a state-level fallback rather than a city.
CITY_COORDINATES = {
    # --- South-East (Cities & States) ---
    "owerri": (5.4856, 7.0358, "Imo State"),
    "enugu": (6.4403, 7.4942, "Enugu State"),
    "onitsha": (6.1421, 6.7909, "Anambra State"),
    "awka": (6.2107, 7.0736, "Anambra State"),
    "aba": (5.1066, 7.3697, "Abia State"),
    "umuahia": (5.5267, 7.4896, "Abia State"),
    "abakaliki": (6.3249, 8.1132, "Ebonyi State"),
    # Standalone States Fallbacks
    "imo": (5.5486, 7.1400, "Imo State Centroid"),
    "anambra": (6.2105, 6.9458, "Anambra State Centroid"),
    "abia": (5.4167, 7.5000, "Abia State Centroid"),
    "ebonyi": (6.2500, 8.0833, "Ebonyi State Centroid"),
    # --- South-South (Cities & States) ---
    "port harcourt": (4.7655, 7.0163, "Rivers State"),
    "phc": (4.7655, 7.0163, "Rivers State"),
    "benin": (6.3176, 5.6145, "Edo State"),
    "calabar": (4.9757, 8.3417, "Cross River State"),
    "uyo": (5.0333, 7.9266, "Akwa Ibom State"),
    "warri": (5.5167, 5.7500, "Delta State"),
    "asaba": (6.1983, 6.7320, "Delta State"),
    "yenagoa": (4.9267, 6.2642, "Bayelsa State"),
    # Standalone States Fallbacks
    "rivers state": (4.8500, 6.8500, "Rivers State Centroid"),
    "edo": (6.5244, 5.8987, "Edo State Centroid"),
    "cross river": (5.7500, 8.5000, "Cross River State Centroid"),
    "akwa ibom": (5.0000, 7.8333, "Akwa Ibom State Centroid"),
    "delta state": (5.7000, 6.0000, "Delta State Centroid"),
    "bayelsa": (4.7500, 6.0833, "Bayelsa State Centroid"),
    # --- South-West (Cities & States) ---
    "lagos": (6.4561, 3.3936, "Lagos State"),
    "ikeja": (6.5920, 3.3422, "Lagos State"),
    "ibadan": (7.3775, 3.9058, "Oyo State"),
    "abeokuta": (7.1611, 3.3483, "Ogun State"),
    "akure": (7.2526, 5.1931, "Ondo State"),
    "osogbo": (7.7710, 4.5624, "Osun State"),
    "ado ekiti": (7.6233, 5.2208, "Ekiti State"),
    # Standalone States Fallbacks
    "oyo": (8.0000, 4.0000, "Oyo State Centroid"),
    "ogun": (7.0000, 3.5000, "Ogun State Centroid"),
    "ondo": (7.1667, 5.0833, "Ondo State Centroid"),
    "osun": (7.5000, 4.5000, "Osun State Centroid"),
    "ekiti": (7.6667, 5.2500, "Ekiti State Centroid"),
    # --- North-Central (Cities & States) ---
    "abuja": (9.0556, 7.4914, "FCT"),
    "fct": (9.0556, 7.4914, "FCT"),
    "ilorin": (8.5000, 4.5500, "Kwara State"),
    "jos": (9.8965, 8.8583, "Plateau State"),
    "minna": (9.6139, 6.5569, "Niger State"),
    "lokoja": (7.8024, 6.7418, "Kogi State"),
    "makurdi": (7.7325, 8.5214, "Benue State"),
    "lafia": (8.4917, 8.5167, "Nasarawa State"),
    # Standalone States Fallbacks
    "kwara": (8.5000, 4.7500, "Kwara State Centroid"),
    "plateau": (9.2500, 9.5000, "Plateau State Centroid"),
    "niger state": (10.0000, 6.0000, "Niger State Centroid"),
    "kogi": (7.7500, 6.7500, "Kogi State Centroid"),
    "benue": (7.3333, 8.7500, "Benue State Centroid"),
    "nasarawa": (8.3167, 8.1667, "Nasarawa State Centroid"),
    # --- North-West (Cities & States) ---
    "kano": (12.0022, 8.5920, "Kano State"),
    "kaduna": (10.5105, 7.4165, "Kaduna State"),
    "sokoto": (13.0622, 5.2339, "Sokoto State"),
    "zaria": (11.0855, 7.7196, "Kaduna State"),
    "katsina": (12.9894, 7.6006, "Katsina State"),
    "gusau": (12.1628, 6.6614, "Zamfara State"),
    "birnin kebbi": (12.4539, 4.1975, "Kebbi State"),
    # Standalone States Fallbacks
    "zamfara": (12.1667, 6.2500, "Zamfara State Centroid"),
    "kebbi": (11.5000, 4.0000, "Kebbi State Centroid"),
    "jigawa": (12.2500, 9.7500, "Jigawa State Centroid"),
    # --- North-East (Cities & States) ---
    "maiduguri": (11.8333, 13.1500, "Borno State"),
    "bauchi": (10.3158, 9.8442, "Bauchi State"),
    "gombe": (10.2897, 11.1673, "Gombe State"),
    "yola": (9.2035, 12.4954, "Adamawa State"),
    "damaturu": (11.7470, 11.9608, "Yobe State"),
    "jalingo": (8.8922, 11.3636, "Taraba State"),
    # Standalone States Fallbacks
    "borno": (11.5000, 13.0000, "Borno State Centroid"),
    "adamawa": (9.3333, 12.5000, "Adamawa State Centroid"),
    "yobe": (12.0000, 11.5000, "Yobe State Centroid"),
    "taraba": (8.0000, 10.5000, "Taraba State Centroid"),
}

# Word-bounded patterns (so "imo" does not match "immobilize"), compiled once.
# The stable sort puts every city ahead of every state centroid, so a state
# name only wins when no city is mentioned anywhere in the tweet.
PLACE_PATTERNS = [
    (re.compile(r"\b" + re.escape(place_key) + r"\b"), place_key, info)
    for place_key, info in sorted(
        CITY_COORDINATES.items(), key=lambda item: "Centroid" in item[1][2]
    )
]

# Matches `geocode:<lat>,<lon>` with an optional `,<radius>km|mi` suffix.
GEOCODE_PATTERN = re.compile(
    r"geocode:(-?\d+(?:\.\d+)?),(-?\d+(?:\.\d+)?)(?:,(\d+(?:\.\d+)?)(km|mi))?"
)


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
@st.cache_resource
def load_models():
    """Load the tokenizer, BERTweet classifier and the joblib baselines.

    Returns:
        Tuple ``(tokenizer, bert_model, svm, lr, tfidf)``. The BERT model is
        put in eval mode; the three baseline artefacts are downloaded from
        ``BASELINES_PATH`` and deserialised with joblib.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, normalization=True)
    bert_model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    bert_model.eval()

    # NOTE: joblib.load unpickles the files, so the hub repo must be trusted.
    svm = joblib.load(hf_hub_download(BASELINES_PATH, "linear_svc_model.joblib"))
    lr = joblib.load(
        hf_hub_download(BASELINES_PATH, "logistic_regression_model.joblib")
    )
    tfidf = joblib.load(hf_hub_download(BASELINES_PATH, "tfidf_vectorizer.joblib"))
    return tokenizer, bert_model, svm, lr, tfidf


tokenizer, bert_model, svm_model, lr_model, tfidf_vec = load_models()


# ---------------------------------------------------------------------------
# Ensemble prediction
# ---------------------------------------------------------------------------
def predict_ensemble(text: str) -> dict:
    """Return ensemble label plus per-model details and BERTweet probabilities.

    Args:
        text: Raw tweet text; emojis are converted to their text names first.

    Returns:
        Dict with keys ``final`` (ensemble label), ``bert`` (a
        ``(label, confidence)`` tuple), ``svm``, ``lr`` (baseline labels) and
        ``probs`` (BERTweet softmax probabilities in ``LABELS`` order).
    """
    clean_text = emoji.demojize(text, delimiters=(" ", " "))

    # BERTweet
    inputs = tokenizer(
        clean_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    with torch.no_grad():
        logits = bert_model(**inputs).logits
    bert_probs = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
    bert_idx = int(torch.argmax(logits))
    bert_label = LABELS[bert_idx]
    bert_conf = bert_probs[bert_idx]

    # Baseline models
    features = tfidf_vec.transform([clean_text])
    svm_label = LABELS[int(svm_model.predict(features)[0])]
    lr_label = LABELS[int(lr_model.predict(features)[0])]

    # Ensemble rules: BERTweet wins unless one of the overrides below fires.
    final_label = bert_label
    baselines_agree = svm_label == lr_label
    bert_says_prevention = bert_label == "Prevention & Awareness"

    # Rule 1: both baselines → Treatment, BERT says Prevention with low confidence
    if (
        baselines_agree
        and svm_label == "Treatment & Health System"
        and bert_says_prevention
        and bert_conf < 0.40
    ):
        final_label = "Treatment & Health System"
    # Rule 2: both baselines → Symptoms, BERT says Prevention with high confidence
    elif (
        baselines_agree
        and svm_label == "Symptoms & Burden"
        and bert_says_prevention
        and bert_conf > 0.70
    ):
        final_label = "Symptoms & Burden"
    # Rule 3: BERT unsure about Irrelevant → defer to LR
    elif bert_label == "Irrelevant" and bert_conf < 0.70:
        final_label = lr_label

    return {
        "final": final_label,
        "bert": (bert_label, bert_conf),
        "svm": svm_label,
        "lr": lr_label,
        "probs": bert_probs,
    }


# ---------------------------------------------------------------------------
# Batch helper
# ---------------------------------------------------------------------------
def classify_dataframe(
    df: pd.DataFrame,
    text_column: str = "text",
    confidence_threshold: float = 0.5,
) -> pd.DataFrame:
    """Classify every row of ``df`` and return an annotated copy.

    Args:
        df: Input frame; it is not modified.
        text_column: Name of the column holding the tweet text.
        confidence_threshold: Rows whose BERTweet confidence is below this
            value get ``predicted_label`` replaced by ``"Low Confidence"``.

    Returns:
        Copy of ``df`` with ``predicted_label``, ``confidence``,
        ``bert_label``, ``svm_label`` and ``lr_label`` columns added.
    """
    # Blank CSV cells arrive as NaN floats, which emoji.demojize cannot
    # handle; treat them as empty strings and coerce everything else to str.
    texts = df[text_column].fillna("").astype(str)
    results = [predict_ensemble(text) for text in texts]

    df = df.copy()
    df["predicted_label"] = [res["final"] for res in results]
    df["confidence"] = pd.Series(
        [res["bert"][1] for res in results], index=df.index, dtype=float
    )
    df["bert_label"] = [res["bert"][0] for res in results]
    df["svm_label"] = [res["svm"] for res in results]
    df["lr_label"] = [res["lr"] for res in results]

    low_conf_mask = df["confidence"] < confidence_threshold
    df.loc[low_conf_mask, "predicted_label"] = "Low Confidence"
    return df


# ---------------------------------------------------------------------------
# Chart helpers (shared by the batch and live tabs)
# ---------------------------------------------------------------------------
def plot_label_distribution(df: pd.DataFrame, title: str):
    """Return a pie chart of the ``predicted_label`` counts in ``df``."""
    dist = df["predicted_label"].value_counts().reset_index()
    dist.columns = ["Category", "Count"]
    return px.pie(
        dist,
        names="Category",
        values="Count",
        color="Category",
        color_discrete_map=LABEL_COLORS,
        title=title,
    )


def plot_model_agreement(df: pd.DataFrame, title: str):
    """Return a grouped bar chart of per-model prediction counts per label."""
    rows = [
        {
            "Category": label,
            "BERTweet": int((df["bert_label"] == label).sum()),
            "SVM": int((df["svm_label"] == label).sum()),
            "LR": int((df["lr_label"] == label).sum()),
            "Ensemble": int((df["predicted_label"] == label).sum()),
        }
        for label in LABELS
    ]
    melted = pd.DataFrame(rows).melt(
        id_vars="Category", var_name="Model", value_name="Count"
    )
    return px.bar(
        melted,
        x="Category",
        y="Count",
        color="Model",
        barmode="group",
        title=title,
    )


# ---------------------------------------------------------------------------
# Map helpers
# ---------------------------------------------------------------------------
def parse_geocode(query: str) -> tuple:
    """Extract ``(lat, lon, radius_km)`` from a ``geocode:`` search operator.

    Falls back to ``NGA_CENTER`` when the query has no geocode operator and
    to ``DEFAULT_RADIUS_KM`` when the operator carries no radius. A radius
    given in miles is converted to kilometres.
    """
    match = GEOCODE_PATTERN.search(query)
    if match is None:
        return NGA_CENTER[0], NGA_CENTER[1], DEFAULT_RADIUS_KM

    lat_text, lon_text, radius_text, unit = match.groups()
    radius_km = DEFAULT_RADIUS_KM
    if radius_text is not None:
        radius_km = float(radius_text)
        if unit == "mi":
            radius_km *= 1.609344
    return float(lat_text), float(lon_text), radius_km


def locate_tweet(text: str, center_lat: float, center_lon: float, rng) -> tuple:
    """Pick a map position for a tweet from place names in its text.

    Returns ``(lat, lon, description)``. When a known place is mentioned the
    position is that place plus a small jitter; otherwise it is a random
    point around the search centre.
    """
    lowered = text.lower()

    # Default fallback assignment (drawn first, as the original script did).
    lat = center_lat + rng.uniform(-1.2, 1.2)
    lon = center_lon + rng.uniform(-1.5, 1.5)
    description = "Detected inside Geocode Range"

    for pattern, place_key, (place_lat, place_lon, location_tag) in PLACE_PATTERNS:
        if not pattern.search(lowered):
            continue
        # Apply small visual jitter so pins for the same place do not overlap.
        lat = place_lat + rng.uniform(-0.06, 0.06)
        lon = place_lon + rng.uniform(-0.06, 0.06)
        if "Centroid" in location_tag:
            description = f"{location_tag.replace(' Centroid', '')}"
        else:
            description = f"{place_key.title()} ({location_tag})"
        break
    return lat, lon, description


def build_popup_html(row, detected_loc: str, author_name) -> str:
    """Build the HTML shown in a map pin popup.

    Tweet text, author and labels come from scraped content, so every
    interpolated value is HTML-escaped before being placed in the popup.
    """
    tweet_text = str(row["text"])
    snippet = tweet_text[:140]
    if len(tweet_text) > 140:
        snippet += "..."
    return (
        f"Category: {html.escape(str(row['predicted_label']))}<br>"
        f"Source Hub: {html.escape(detected_loc)}<br>"
        f"Author: @{html.escape(str(author_name))}<br><br>"
        f"\"{html.escape(snippet)}\"<br><br>"
        f"Confidence: {row['confidence']:.2%}"
    )


def build_tweet_map(classified: pd.DataFrame, query: str):
    """Return a folium map with one pin per classified tweet."""
    center_lat, center_lon, radius_km = parse_geocode(query)

    tweet_map = folium.Map(
        location=[center_lat, center_lon],
        zoom_start=6,
        tiles="CartoDB positron",
    )

    # Search radius overlay
    folium.Circle(
        location=[center_lat, center_lon],
        radius=radius_km * 1000,
        color="#2980b9",
        fill=True,
        fill_opacity=0.04,
        popup=f"Active Search Radius ({radius_km:g}km)",
    ).add_to(tweet_map)

    # Fixed seed keeps pin positions stable across Streamlit reruns.
    rng = np.random.default_rng(seed=42)

    for _, row in classified.iterrows():
        # NOTE(review): the live tab classifies only the "text" column, so
        # "author" is never present here and the fallback is always used.
        author_name = row.get("author", "X User")
        lat, lon, detected_loc = locate_tweet(
            str(row["text"]), center_lat, center_lon, rng
        )
        folium.CircleMarker(
            location=[lat, lon],
            radius=6,
            color=PIN_COLORS.get(row["predicted_label"], "gray"),
            fill=True,
            fill_opacity=0.85,
            popup=folium.Popup(
                build_popup_html(row, detected_loc, author_name),
                max_width=280,
            ),
        ).add_to(tweet_map)
    return tweet_map


# ---------------------------------------------------------------------------
# Twitter scraper — uses os.environ instead of st.secrets
# ---------------------------------------------------------------------------
@st.cache_resource
def load_scraper_client():
    """Initialise RnetTwitterClient from env vars or an existing cookies.json.

    When both ``X_AUTH_TOKEN`` and ``ct0`` are set, they are written to
    ``cookies.json`` and loaded from there. Otherwise (or if that fails) a
    pre-existing ``cookies.json`` is tried.

    Returns:
        The client with cookies loaded, or ``None`` if neither source works.
    """
    client = RnetTwitterClient()
    cookie_path = "cookies.json"

    auth_token = os.environ.get("X_AUTH_TOKEN")
    ct0 = os.environ.get("ct0")

    if auth_token and ct0:
        try:
            cookies_data = [
                {"name": "auth_token", "value": auth_token},
                {"name": "ct0", "value": ct0},
            ]
            with open(cookie_path, "w", encoding="utf-8") as f:
                json.dump(cookies_data, f)
            client.load_cookies(cookie_path)
            return client
        except Exception as e:
            st.sidebar.warning(f"Could not load cookies from env vars: {e}")

    if os.path.exists(cookie_path):
        try:
            client.load_cookies(cookie_path)
            return client
        except Exception as e:
            st.sidebar.warning(f"Could not load cookies from cookies.json: {e}")

    return None


async def perform_scrape(client, query_str, tweet_count, start_date, end_date):
    """Search tweets with ``client``; return ``[]`` and show an error on failure."""
    # NOTE(review): confirm the search backend honours `-is:retweet`.
    full_query = f"{query_str} -is:retweet"
    if start_date:
        full_query += f" since:{start_date.isoformat()}"
    if end_date:
        full_query += f" until:{end_date.isoformat()}"
    try:
        return await client.search_tweets(
            full_query, count=tweet_count, product="Latest"
        )
    except Exception as e:
        st.error(f"Scraping error: {e}")
        return []


scraper_client = load_scraper_client()

# ---------------------------------------------------------------------------
# Session-state defaults
# ---------------------------------------------------------------------------
if "raw_scraped_tweets" not in st.session_state:
    st.session_state["raw_scraped_tweets"] = []
if "classified_live_tweets" not in st.session_state:
    st.session_state["classified_live_tweets"] = pd.DataFrame()

# ===========================================================================
# UI TABS
# ===========================================================================
tab1, tab2, tab3 = st.tabs(["Single Tweet", "Batch CSV", "Live Scrape"])

# ---------------------------------------------------------------------------
# Tab 1 — Single Tweet
# ---------------------------------------------------------------------------
with tab1:
    st.subheader("Classify one tweet")
    user_input = st.text_area(
        "Enter a tweet:",
        "I have fever and chills. Is it malaria?",
        height=100,
    )

    if st.button("Classify", key="single_classify"):
        if user_input.strip():
            res = predict_ensemble(user_input)

            st.success(f"**Final Decision: {res['final']}**")

            col1, col2, col3 = st.columns(3)
            col1.metric("BERTweet", res["bert"][0], f"{res['bert'][1]:.1%}")
            col2.metric("SVM", res["svm"])
            col3.metric("Logistic Regression", res["lr"])

            # Confidence bar chart (BERTweet probabilities)
            conf_df = pd.DataFrame({"Category": LABELS, "Confidence": res["probs"]})
            fig = px.bar(
                conf_df,
                x="Confidence",
                y="Category",
                orientation="h",
                color="Category",
                color_discrete_map=LABEL_COLORS,
                title="BERTweet Confidence Scores",
            )
            fig.update_layout(showlegend=False, xaxis_tickformat=".0%")
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.warning("Please enter some text to classify.")

# ---------------------------------------------------------------------------
# Tab 2 — Batch CSV
# ---------------------------------------------------------------------------
with tab2:
    st.subheader("Batch classify from CSV")

    threshold = st.slider(
        "Confidence threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.05,
        help="Predictions with confidence below this value will be labelled 'Low Confidence'.",
        key="batch_threshold",
    )

    uploaded = st.file_uploader(
        "Upload a CSV file with a 'text' column", type=["csv"]
    )

    if uploaded is not None:
        try:
            batch_df = pd.read_csv(uploaded)
            if "text" not in batch_df.columns:
                st.error("The CSV must contain a column named 'text'.")
            else:
                with st.spinner("Classifying tweets…"):
                    result_df = classify_dataframe(
                        batch_df.copy(),
                        text_column="text",
                        confidence_threshold=threshold,
                    )

                st.success(f"Classified {len(result_df)} tweets.")
                st.dataframe(result_df[DISPLAY_COLUMNS].head(20))

                # Distribution pie chart
                st.plotly_chart(
                    plot_label_distribution(result_df, "Category Distribution"),
                    use_container_width=True,
                )

                # Per-model agreement breakdown
                st.subheader("Model Agreement Breakdown")
                st.plotly_chart(
                    plot_model_agreement(
                        result_df, "Predictions per Model per Category"
                    ),
                    use_container_width=True,
                )

                # Download
                csv_out = result_df.to_csv(index=False).encode("utf-8")
                st.download_button(
                    label="⬇ Download Results CSV",
                    data=csv_out,
                    file_name="classified_tweets.csv",
                    mime="text/csv",
                )
        except Exception as e:
            st.error(f"Error processing file: {e}")

# ---------------------------------------------------------------------------
# Tab 3 — Live Scrape
# ---------------------------------------------------------------------------
with tab3:
    st.subheader("Live Tweet Scraper")
    st.markdown(
        "Search for recent tweets, classify them with the ensemble, "
        "and view their simulated geographic distribution within Nigeria."
    )

    # Alternative query without a geocode filter:
    # "malaria Nigeria OR mosquito net Nigeria"
    query = st.text_input(
        "Search Query",
        value='(malaria OR "mosquito net") geocode:9.0820,8.6753,650km',
        help="Use Twitter search operators. Retweets are automatically excluded.",
    )
    count = st.slider("Number of tweets to fetch", min_value=10, max_value=200, value=50)

    today = pd.to_datetime("today").date()
    default_start_date = today - pd.Timedelta(days=7)
    scrape_start_date = st.date_input("Start Date (optional)", value=default_start_date)
    scrape_end_date = st.date_input("End Date (optional)", value=today)

    scrape_threshold = st.slider(
        "Confidence threshold for live scrape",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.05,
        key="live_threshold",
    )

    if st.button("Scrape and Classify", key="live_scrape"):
        if scraper_client is None:
            st.error(
                "Scraper client not initialised. "
                "Set X_AUTH_TOKEN and ct0 env vars or provide cookies.json."
            )
        else:
            with st.spinner("Scraping tweets… This may take a moment."):
                raw = asyncio.run(
                    perform_scrape(
                        scraper_client,
                        query,
                        count,
                        scrape_start_date,
                        scrape_end_date,
                    )
                )

            st.session_state["raw_scraped_tweets"] = raw

            if not raw:
                st.warning("No tweets fetched. Try a different query or check authentication.")
                st.session_state["classified_live_tweets"] = pd.DataFrame()
            else:
                st.success(f"Fetched {len(raw)} tweets.")
                live_df = pd.DataFrame(raw)

                if "text" not in live_df.columns:
                    st.warning("Scraped data has no 'text' column. Showing raw data:")
                    st.dataframe(live_df.head())
                    st.session_state["classified_live_tweets"] = pd.DataFrame()
                else:
                    with st.spinner("Classifying tweets…"):
                        classified = classify_dataframe(
                            live_df[["text"]].copy(),
                            text_column="text",
                            confidence_threshold=scrape_threshold,
                        )
                    st.session_state["classified_live_tweets"] = classified

    # ---- Display stored results ----
    classified_live = st.session_state["classified_live_tweets"]

    if not classified_live.empty:
        st.subheader("Classified Tweets (sample)")
        st.dataframe(classified_live[DISPLAY_COLUMNS].head(10))

        # Distribution pie chart
        st.plotly_chart(
            plot_label_distribution(
                classified_live, "Live Scrape — Category Distribution"
            ),
            use_container_width=True,
        )

        # Per-model agreement for live data
        st.subheader("Model Agreement (Live)")
        st.plotly_chart(
            plot_model_agreement(
                classified_live, "Live Predictions per Model per Category"
            ),
            use_container_width=True,
        )

        # Download
        csv_live = classified_live.to_csv(index=False).encode("utf-8")
        st.download_button(
            label="⬇ Download Scraped & Classified Results",
            data=csv_live,
            file_name="scraped_classified_tweets.csv",
            mime="text/csv",
        )

        # ---- Smart Geographic Distribution (City Text Mapping) ----
        st.subheader("Geographic Distribution Analysis")
        st.info(
            "Mapping based on query geocode boundaries. Tweets mentioning known cities "
            "are placed contextually; others are mapped to the search area center."
        )
        st_folium(build_tweet_map(classified_live, query), width=700, height=500)

        if st.button("Clear Scraped Data", key="clear_live_data"):
            st.session_state["classified_live_tweets"] = pd.DataFrame()
            st.session_state["raw_scraped_tweets"] = []
            st.success("Scraped data cleared.")

# ===========================================================================
# Sidebar
# ===========================================================================
st.sidebar.markdown("## About")
st.sidebar.info(
    "**Model:** BERTweet fine-tuned on Nigerian malaria-related tweets.\n\n"
    "**Categories:** Symptoms & Burden, Treatment & Health System, "
    "Prevention & Awareness, Misinformation, Irrelevant.\n\n"
    "**Ensemble:** BERTweet + SVM + Logistic Regression.\n\n"
    "**Purpose:** Final Year Project — Infodemiology.\n\n"
    "**Author:** Chime Ugochukwu Chiziri Kevin"
)

st.sidebar.markdown("## Auth Status")
auth_from_env = bool(os.environ.get("X_AUTH_TOKEN") and os.environ.get("ct0"))
if auth_from_env:
    st.sidebar.success("✅ Twitter credentials loaded from environment variables (X_AUTH_TOKEN + ct0).")
elif os.path.exists("cookies.json"):
    st.sidebar.success("✅ cookies.json found in application directory.")
else:
    st.sidebar.error(
        "❌ No Twitter credentials found. "
        "Set X_AUTH_TOKEN and ct0 env vars or provide cookies.json."
    )

st.sidebar.markdown("## Notes")
st.sidebar.markdown(
    "- Batch CSV must contain a `text` column.\n"
    "- Predictions below the confidence threshold are labelled *Low Confidence*.\n"
    "- Live scraping requires `X_AUTH_TOKEN` and `ct0` env vars or a `cookies.json` file.\n"
    "- Map locations are simulated; Twitter APIs do not provide geo-coordinates.\n"
    "- All three models' decisions are shown in every tab for transparency."
)