import json
import altair as alt
import numpy as np
import pandas as pd
import pydeck as pdk
import streamlit as st
from collections import Counter
from pathlib import Path

# ──────────────────────────────────────────────────────────────────────────────
# Page configuration
# ──────────────────────────────────────────────────────────────────────────────
st.set_page_config(
    page_title="PLACES Dataset Explorer",
    page_icon="🌍",
    layout="wide",
)

# ──────────────────────────────────────────────────────────────────────────────
# Geocoded coordinates for each participant locale
# ──────────────────────────────────────────────────────────────────────────────
LOCALE_GEO = {
    "karnataka": {
        "lat": 12.9716,
        "lon": 77.5946,
        "label": "Karnataka, India",
        "city": "Bangalore",
        "color": [255, 87, 51],  # coral-red
    },
    "punjab": {
        "lat": 31.1471,
        "lon": 75.3412,
        "label": "Punjab, India",
        "city": "Amritsar",
        "color": [255, 195, 0],  # amber
    },
    "nigeria": {
        "lat": 12.0022,
        "lon": 8.5920,
        "label": "Kano, Nigeria",
        "city": "Kano",
        "color": [0, 168, 107],  # green
    },
    "ghana": {
        "lat": 5.1315,
        "lon": -1.2795,
        "label": "Cape Coast, Ghana",
        "city": "Cape Coast",
        "color": [30, 144, 255],  # dodger blue
    },
    "zindi": {
        "lat": -33.9249,
        "lon": 18.4241,
        "label": "Zindi (Online Platform)",
        "city": "Cape Town",
        "color": [155, 89, 182],  # purple
    },
}

# ──────────────────────────────────────────────────────────────────────────────
# Load data (cached so it's only read once)
# ──────────────────────────────────────────────────────────────────────────────
# Define the filepath relative to this script file, not cwd.
# __file__ = src/streamlit_app.py → .parent = src/ → .parent.parent = repo root
filepath = Path(__file__).parent.parent / "PLACES_full_dataset.json"


@st.cache_data
def load_data(filepath=filepath):
    # Optional: Add a debug check that prints to your HF Space logs if it fails
    if not filepath.exists():
        print(f"DEBUG: Current directory is {Path.cwd()}")
        print(f"DEBUG: Files here: {list(Path.cwd().iterdir())}")

    with open(filepath, "r", encoding="utf-8") as f:
        raw = json.load(f)
    return pd.DataFrame(raw)


df = load_data()


# ──────────────────────────────────────────────────────────────────────────────
# Helper: explode a list-valued column and count occurrences
# ──────────────────────────────────────────────────────────────────────────────
def count_list_column(series):
    counter = Counter()
    for val_list in series.dropna():
        if isinstance(val_list, list):
            for v in val_list:
                counter[v] += 1
        else:
            # Scalar values are counted as a single category
            counter[val_list] += 1
    return pd.DataFrame(
        {"Category": list(counter.keys()), "Count": list(counter.values())}
    ).sort_values("Count", ascending=False)


# ══════════════════════════════════════════════════════════════════════════════
# SIDEBAR: Place selector & prompt viewer
# ══════════════════════════════════════════════════════════════════════════════
with st.sidebar:
    st.markdown("## 📍 Select a Place")
    st.caption("Click a place below to see sample prompts from that region.")

    locale_options = {geo["label"]: key for key, geo in LOCALE_GEO.items()}
    selected_label = st.radio(
        "Choose a location",
        options=list(locale_options.keys()),
        index=0,
        label_visibility="collapsed",
    )
    selected_locale = locale_options[selected_label]
    place_df = df[df["participant_locale"] == selected_locale]

    st.markdown("---")
    st.markdown(
        f"### {selected_label}\n"
        f"**{len(place_df):,}** records · "
        f"**{place_df['prompt'].nunique():,}** unique prompts"
    )

    # Harm-type mini-breakdown for this place
    place_harm = count_list_column(place_df["harm_types"])
    st.markdown("#### Harm types at this place")
    for _, row in place_harm.iterrows():
        st.markdown(f"- **{row['Category']}**: {row['Count']:,}")

    st.markdown("---")
    n_sidebar = st.slider("Sample size", 1, 20, 5, key="sidebar_n")
    if st.button("🔄 Resample prompts", key="sidebar_resample"):
        pass  # forces re-run → new random draw

    st.markdown("#### 🎲 Sample Prompts")
    sample_place = place_df.sample(n=min(n_sidebar, len(place_df)))
    for i, (_, row) in enumerate(sample_place.iterrows()):
        with st.container():
            # Two trailing spaces before "\n" force a Markdown line break
            st.markdown(
                f"**Prompt:** {row['prompt']}\n\n"
                f"**Model revised to:** {row['revised_prompt']}\n\n"
                f"🛑 {', '.join(row['harm_types']) if isinstance(row['harm_types'], list) else row['harm_types']}  \n"
                f"⚔️ {', '.join(row['attack_modes']) if isinstance(row['attack_modes'], list) else row['attack_modes']}  \n"
                f"🎯 {', '.join(row['targeted_identity_attribute']) if isinstance(row['targeted_identity_attribute'], list) else row['targeted_identity_attribute']}"
            )
            # Separate samples, but skip the divider after the last one drawn
            if i < len(sample_place) - 1:
                st.markdown("---")

# ══════════════════════════════════════════════════════════════════════════════
# TITLE
# ══════════════════════════════════════════════════════════════════════════════
st.markdown(
    """

# 🌍 PLACES Dataset Explorer

Explore the PLACES adversarial-nibbler dataset, which covers prompts, harm annotations, attack modes, and targeted-identity attributes from participants across Sub-Saharan Africa and India.

    """,
    unsafe_allow_html=True,
)
st.divider()

# ══════════════════════════════════════════════════════════════════════════════
# TOP-LINE METRICS (big bold numbers)
# ══════════════════════════════════════════════════════════════════════════════
total_records = len(df)
unique_prompts = df["prompt"].nunique()
unique_places = df["participant_locale"].nunique()
unique_participants = df["participant_id"].nunique()

col1, col2, col3, col4 = st.columns(4)
col1.metric("📝 Total Records", f"{total_records:,}")
col2.metric("💬 Unique Prompts", f"{unique_prompts:,}")
col3.metric("📍 Unique Places", f"{unique_places:,}")
col4.metric("👥 Unique Participants", f"{unique_participants:,}")

st.divider()

# ══════════════════════════════════════════════════════════════════════════════
# INTERACTIVE MAP
# ══════════════════════════════════════════════════════════════════════════════
st.markdown("## 🗺️ Where the Data Comes From")
st.markdown(
    "Each pin represents a participant locale. **Select a place in the sidebar** "
    "to see sample prompts from that region."
)

# Build map dataframe
locale_counts = df["participant_locale"].value_counts().to_dict()

# Pre-sample up to 3 prompts per locale for the hover tooltip
locale_sample_cache = {}
for key in LOCALE_GEO:
    locale_rows = df[df["participant_locale"] == key]["prompt"].dropna()
    if len(locale_rows) > 0:
        samples = locale_rows.sample(n=min(3, len(locale_rows))).tolist()
        # Truncate long prompts to 100 characters; the tooltip is HTML,
        # so the bullets are joined with line-break tags
        locale_sample_cache[key] = "<br/>".join(
            f"• {p[:100]}{'…' if len(p) > 100 else ''}" for p in samples
        )
    else:
        locale_sample_cache[key] = "(no prompts)"

map_data = []
for key, geo in LOCALE_GEO.items():
    count = locale_counts.get(key, 0)
    map_data.append(
        {
            "locale": key,
            "label": geo["label"],
            "city": geo["city"],
            "lat": geo["lat"],
            "lon": geo["lon"],
            "records": count,
            "color": geo["color"],
            # scale radius by record count (min 40000, max 120000 for visibility)
            "radius": max(40000, min(120000, count * 8)),
            "sample_prompts": locale_sample_cache.get(key, ""),
        }
    )
map_df = pd.DataFrame(map_data)

# Highlight the selected place
map_df["is_selected"] = map_df["locale"] == selected_locale
map_df["elevation"] = map_df["is_selected"].apply(lambda s: 150000 if s else 30000)
map_df["opacity"] = map_df["is_selected"].apply(lambda s: 220 if s else 140)

# pydeck layers
scatter_layer = pdk.Layer(
    "ScatterplotLayer",
    data=map_df,
    get_position=["lon", "lat"],
    get_radius="radius",
    get_fill_color="color",
    get_line_color=[255, 255, 255],
    line_width_min_pixels=2,
    pickable=True,
    opacity=0.8,
    stroked=True,
)

column_layer = pdk.Layer(
    "ColumnLayer",
    data=map_df,
    get_position=["lon", "lat"],
    get_elevation="elevation",
    elevation_scale=1,
    radius=35000,
    get_fill_color="color",
    pickable=True,
    auto_highlight=True,
    opacity=0.6,
)

text_layer = pdk.Layer(
    "TextLayer",
    data=map_df,
    get_position=["lon", "lat"],
    get_text="label",
    get_size=14,
    get_color=[30, 30, 30],  # dark text for lighter basemap
    get_angle=0,
    get_text_anchor='"middle"',
    get_alignment_baseline='"bottom"',
    font_family='"Arial"',
    font_weight=700,
    billboard=True,
)

# View state centred roughly between Africa and India
view_state = pdk.ViewState(
    latitude=10,
    longitude=30,
    zoom=2.0,
    pitch=15,
    bearing=0,
)

deck = pdk.Deck(
    layers=[scatter_layer, column_layer, text_layer],
    initial_view_state=view_state,
    tooltip={
        "html": (
            "{label}<br/>"
            "📍 {city} · 📝 {records} records"
            "<br/>"
            "Sample prompts:<br/>"
            "{sample_prompts}"
        ),
        "style": {
            "backgroundColor": "white",
            "color": "#222",
            "fontSize": "13px",
            "padding": "10px 14px",
            "borderRadius": "8px",
            "maxWidth": "320px",
            "boxShadow": "0 2px 8px rgba(0,0,0,0.15)",
        },
    },
    # Token-free Carto Voyager: light, colourful, no API key needed
    map_style="https://basemaps.cartocdn.com/gl/voyager-gl-style/style.json",
)

st.pydeck_chart(deck, use_container_width=True, height=380)

# Small legend below the map
legend_cols = st.columns(len(LOCALE_GEO))
for col, (key, geo) in zip(legend_cols, LOCALE_GEO.items()):
    r, g, b = geo["color"]
    count = locale_counts.get(key, 0)
    # Colour swatch matching the locale's map pin, followed by its record count
    col.markdown(
        f"<span style='color: rgb({r}, {g}, {b});'>●</span> "
        f"**{geo['label']}** — {count:,}",
        unsafe_allow_html=True,
    )

st.divider()

# ══════════════════════════════════════════════════════════════════════════════
# INTERACTIVE DATA EXPLORER
# ══════════════════════════════════════════════════════════════════════════════
st.markdown("## 🔎 Explore the Data")

with st.expander("Filters", expanded=True):
    fcol1, fcol2 = st.columns(2)
    with fcol1:
        selected_locales = st.multiselect(
            "Filter by Place",
            options=sorted(df["participant_locale"].unique()),
            default=sorted(df["participant_locale"].unique()),
        )
    with fcol2:
        # Build a flat list of all harm types present in the dataset
        all_harms = sorted(
            {h for hlist in df["harm_types"].dropna() for h in hlist}
        )
        selected_harms = st.multiselect(
            "Filter by Harm Type",
            options=all_harms,
            default=all_harms,
        )
    search_query = st.text_input(
        "🔍 Search prompts (case-insensitive substring match)", ""
    )

# Apply filters
mask = df["participant_locale"].isin(selected_locales)

# Harm-type filter (row passes if it contains ANY selected harm)
mask &= df["harm_types"].apply(
    lambda hlist: bool(set(hlist) & set(selected_harms))
    if isinstance(hlist, list)
    else False
)

if search_query:
    # regex=False keeps this a literal substring match, so characters such as
    # "(" or "*" in the query are not interpreted as a regular expression
    mask &= df["prompt"].str.contains(
        search_query, case=False, na=False, regex=False
    )

filtered = df[mask]
st.markdown(f"**Showing {len(filtered):,} of {len(df):,} records**")

st.dataframe(
    filtered[
        [
            "participant_locale",
            "prompt",
            "revised_prompt",
            "harm_types",
            "attack_modes",
            "targeted_identity_attribute",
        ]
    ].reset_index(drop=True),
    use_container_width=True,
    height=500,
)

st.divider()

# ══════════════════════════════════════════════════════════════════════════════
# SAMPLE PROMPTS
# ══════════════════════════════════════════════════════════════════════════════
# st.markdown("## 🎲 Random Sample of Prompts")
# n_samples = st.slider("Number of samples", 1, 50, 10)
# if st.button("🔄 Resample"):
#     pass  # forces re-run
# sample = filtered.sample(n=min(n_samples, len(filtered)))
# for _, row in sample.iterrows():
#     with st.container():
#         st.markdown(
#             f"**Locale:** `{row['participant_locale']}`  \n"
#             f"**Prompt:** {row['prompt']}  \n"
#             f"**Revised Prompt:** {row['revised_prompt']}  \n"
#             f"**Harm Types:** {', '.join(row['harm_types']) if isinstance(row['harm_types'], list) else row['harm_types']}  \n"
#             f"**Attack Modes:** {', '.join(row['attack_modes']) if isinstance(row['attack_modes'], list) else row['attack_modes']}  \n"
#             f"**Targeted Identity:** {', '.join(row['targeted_identity_attribute']) if isinstance(row['targeted_identity_attribute'], list) else row['targeted_identity_attribute']}"
#         )
#         st.divider()