Spaces:
Running
Running
import html
import json
from collections import Counter
from pathlib import Path

import altair as alt
import numpy as np
import pandas as pd
import pydeck as pdk
import streamlit as st
# ---------------------------------------------------------------------------
# Page configuration: browser-tab title/icon and a full-width layout, set
# before any other Streamlit output is produced.
# ---------------------------------------------------------------------------
st.set_page_config(
    layout="wide",
    page_title="PLACES Dataset Explorer",
    page_icon="π",
)
# ---------------------------------------------------------------------------
# Geocoded coordinates for each participant locale
# ---------------------------------------------------------------------------
# Keys are compared against the dataset's ``participant_locale`` values.
# Each pin sits at the coordinates of the listed ``city``; ``label`` is the
# display name and ``color`` is an [R, G, B] triple shared by map and legend.
LOCALE_GEO = {
    "karnataka": {
        "lat": 12.9716,
        "lon": 77.5946,
        "label": "Karnataka, India",
        "city": "Bangalore",
        "color": [255, 87, 51],  # coral-red
    },
    "punjab": {
        # Amritsar, so the pin agrees with "city" like every other entry
        # (31.1471, 75.3412 was roughly the centre of the state instead).
        "lat": 31.6340,
        "lon": 74.8723,
        "label": "Punjab, India",
        "city": "Amritsar",
        "color": [255, 195, 0],  # amber
    },
    "nigeria": {
        "lat": 12.0022,
        "lon": 8.5920,
        "label": "Kano, Nigeria",
        "city": "Kano",
        "color": [0, 168, 107],  # green
    },
    "ghana": {
        "lat": 5.1315,
        "lon": -1.2795,
        "label": "Cape Coast, Ghana",
        "city": "Cape Coast",
        "color": [30, 144, 255],  # dodger blue
    },
    "zindi": {
        "lat": -33.9249,
        "lon": 18.4241,
        "label": "Zindi (Online Platform)",
        "city": "Cape Town",
        "color": [155, 89, 182],  # purple
    },
}
# ---------------------------------------------------------------------------
# Load data (cached so it's only read once)
# ---------------------------------------------------------------------------
# Resolve the dataset relative to this script file, not the working directory:
# __file__ = src/streamlit_app.py -> .parent = src/ -> .parent.parent = repo root
filepath = Path(__file__).parent.parent / "PLACES_full_dataset.json"


@st.cache_data
def load_data(filepath=filepath):
    """Read the PLACES JSON dataset into a DataFrame.

    Wrapped in ``st.cache_data`` so the file is not re-read and re-parsed on
    every Streamlit rerun (every widget interaction re-executes this script).

    Args:
        filepath: ``Path`` of the JSON file; defaults to the repo-root dataset.

    Returns:
        ``pd.DataFrame`` built from the parsed JSON (presumably a list of
        record dicts -- TODO confirm against the dataset file).

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.
    """
    if not filepath.exists():
        # Debug output for the HF Space logs, to diagnose path problems.
        print(f"DEBUG: Current directory is {Path.cwd()}")
        print(f"DEBUG: Files here: {list(Path.cwd().iterdir())}")
        raise FileNotFoundError(f"Dataset not found: {filepath}")
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        raw = json.load(f)
    return pd.DataFrame(raw)


df = load_data()
# ---------------------------------------------------------------------------
# Helper: explode a list-valued column and count occurrences
# ---------------------------------------------------------------------------
def count_list_column(series):
    """Tally how often each category appears in a column of lists or scalars.

    Every element of a list cell is counted once; any other non-null cell is
    counted as a single category. Missing values are skipped.

    Args:
        series: ``pd.Series`` whose cells are lists or scalar labels.

    Returns:
        ``pd.DataFrame`` with ``Category`` and ``Count`` columns, sorted by
        ``Count`` in descending order.
    """
    tally = Counter()
    for cell in series.dropna():
        # Wrap scalars so a string label is not iterated character by character.
        tally.update(cell if isinstance(cell, list) else [cell])
    table = pd.DataFrame(
        {"Category": list(tally), "Count": list(tally.values())}
    )
    return table.sort_values("Count", ascending=False)
# ---------------------------------------------------------------------------
# SIDEBAR -- Place selector & prompt viewer
# ---------------------------------------------------------------------------
def _join_if_list(value):
    """Render a list cell as a comma-separated string; return scalars as-is."""
    if isinstance(value, list):
        return ", ".join(map(str, value))
    return value


with st.sidebar:
    st.markdown("## π Select a Place")
    st.caption("Click a place below to see sample prompts from that region.")

    # Radio shows human-readable labels; map each back to its locale key.
    locale_options = {geo["label"]: key for key, geo in LOCALE_GEO.items()}
    selected_label = st.radio(
        "Choose a location",
        options=list(locale_options.keys()),
        index=0,
        label_visibility="collapsed",
    )
    selected_locale = locale_options[selected_label]
    place_df = df[df["participant_locale"] == selected_locale]

    st.markdown("---")
    st.markdown(
        f"### {selected_label}\n"
        f"**{len(place_df):,}** records Β· "
        f"**{place_df['prompt'].nunique():,}** unique prompts"
    )

    # Harm-type mini-breakdown for this place
    place_harm = count_list_column(place_df["harm_types"])
    st.markdown("#### Harm types at this place")
    for _, row in place_harm.iterrows():
        st.markdown(f"- **{row['Category']}**: {row['Count']:,}")

    st.markdown("---")
    n_sidebar = st.slider("Sample size", 1, 20, 5, key="sidebar_n")
    if st.button("π Resample prompts", key="sidebar_resample"):
        pass  # the click itself reruns the script, drawing a new random sample

    st.markdown("#### π² Sample Prompts")
    sample_place = place_df.sample(n=min(n_sidebar, len(place_df)))
    # Separators go between prompts only. Use the real sample size: it is
    # smaller than n_sidebar when this place has fewer records.
    last_index = len(sample_place) - 1
    for i, (_, row) in enumerate(sample_place.iterrows()):
        with st.container():
            st.markdown(
                f"**Prompt:** {row['prompt']}\n\n"
                f"**Model revised to:** {row['revised_prompt']}\n\n"
                f"π {_join_if_list(row['harm_types'])} \n"
                f"βοΈ {_join_if_list(row['attack_modes'])} \n"
                f"π― {_join_if_list(row['targeted_identity_attribute'])}"
            )
        if i < last_index:
            st.markdown("---")
# ---------------------------------------------------------------------------
# TITLE: centred heading plus a one-paragraph description of the dataset,
# rendered as raw HTML.
# ---------------------------------------------------------------------------
_TITLE_HTML = """
<h1 style='text-align:center;'>π PLACES Dataset Explorer</h1>
<p style='text-align:center; font-size:1.1rem; color:gray;'>
Explore the <b>PLACES</b> adversarial-nibbler dataset β covering prompts,
harm annotations, attack modes, and targeted-identity attributes from
participants across Sub-Saharan Africa & India.
</p>
"""
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
st.divider()
# ---------------------------------------------------------------------------
# TOP-LINE METRICS (big bold numbers)
# ---------------------------------------------------------------------------
total_records = len(df)
unique_prompts = df["prompt"].nunique()
unique_places = df["participant_locale"].nunique()
unique_participants = df["participant_id"].nunique()

# One st.metric tile per headline figure, laid out left to right with
# thousands separators.
headline_tiles = (
    ("π Total Records", total_records),
    ("π¬ Unique Prompts", unique_prompts),
    ("π Unique Places", unique_places),
    ("π₯ Unique Participants", unique_participants),
)
col1, col2, col3, col4 = st.columns(4)
for tile_column, (tile_label, tile_value) in zip(
    (col1, col2, col3, col4), headline_tiles
):
    tile_column.metric(tile_label, f"{tile_value:,}")
st.divider()
# ---------------------------------------------------------------------------
# INTERACTIVE MAP
# ---------------------------------------------------------------------------
st.markdown("## πΊοΈ Where the Data Comes From")
st.markdown(
    "Each pin represents a participant locale. **Select a place in the sidebar** "
    "to see sample prompts from that region."
)

# Record count per locale key; drives pin radius, tooltip and legend.
locale_counts = df["participant_locale"].value_counts().to_dict()

TOOLTIP_SAMPLE_SIZE = 3  # prompts shown per locale in the hover tooltip
TOOLTIP_MAX_CHARS = 100  # each tooltip prompt is truncated to this length


def _tooltip_line(prompt, max_chars=TOOLTIP_MAX_CHARS):
    """Return one bulleted, HTML-escaped, truncated prompt for the tooltip.

    The ellipsis is computed outside the f-string replacement field because
    a backslash escape inside ``{...}`` is a SyntaxError before Python 3.12.
    The text is escaped because the tooltip template below is rendered as
    HTML, and participant-written prompts may contain ``<`` or ``&``.
    """
    text = str(prompt)
    ellipsis = "\u2026" if len(text) > max_chars else ""
    # Truncate before escaping so an HTML entity is never cut in half.
    return f"• {html.escape(text[:max_chars])}{ellipsis}"


# Pre-sample up to TOOLTIP_SAMPLE_SIZE prompts per locale for the hover
# tooltip (redrawn on each rerun, so the samples change between interactions).
locale_sample_cache = {}
for key in LOCALE_GEO:
    locale_rows = df.loc[df["participant_locale"] == key, "prompt"].dropna()
    if locale_rows.empty:
        locale_sample_cache[key] = "(no prompts)"
        continue
    samples = locale_rows.sample(n=min(TOOLTIP_SAMPLE_SIZE, len(locale_rows)))
    locale_sample_cache[key] = "<br/>".join(_tooltip_line(p) for p in samples)

map_data = []
for key, geo in LOCALE_GEO.items():
    count = locale_counts.get(key, 0)
    map_data.append(
        {
            "locale": key,
            "label": geo["label"],
            "city": geo["city"],
            "lat": geo["lat"],
            "lon": geo["lon"],
            "records": count,
            "color": geo["color"],
            # Scale radius by record count, clamped to [40000, 120000] so
            # small locales stay visible and large ones don't swamp the map.
            "radius": max(40000, min(120000, count * 8)),
            "sample_prompts": locale_sample_cache.get(key, ""),
        }
    )
map_df = pd.DataFrame(map_data)

# Highlight the selected place (from the sidebar) with a taller column.
map_df["is_selected"] = map_df["locale"] == selected_locale
map_df["elevation"] = map_df["is_selected"].apply(lambda s: 150000 if s else 30000)

# pydeck layers
scatter_layer = pdk.Layer(
    "ScatterplotLayer",
    data=map_df,
    get_position=["lon", "lat"],
    get_radius="radius",
    get_fill_color="color",
    get_line_color=[255, 255, 255],
    line_width_min_pixels=2,
    pickable=True,
    opacity=0.8,
    stroked=True,
)
column_layer = pdk.Layer(
    "ColumnLayer",
    data=map_df,
    get_position=["lon", "lat"],
    get_elevation="elevation",
    elevation_scale=1,
    radius=35000,
    get_fill_color="color",
    pickable=True,
    auto_highlight=True,
    opacity=0.6,
)
text_layer = pdk.Layer(
    "TextLayer",
    data=map_df,
    get_position=["lon", "lat"],
    get_text="label",
    get_size=14,
    get_color=[30, 30, 30],  # dark text for lighter basemap
    get_angle=0,
    get_text_anchor='"middle"',
    get_alignment_baseline='"bottom"',
    font_family='"Arial"',
    font_weight=700,
    billboard=True,
)

# View state centred roughly between Africa and India
view_state = pdk.ViewState(
    latitude=10,
    longitude=30,
    zoom=2.0,
    pitch=15,
    bearing=0,
)

deck = pdk.Deck(
    layers=[scatter_layer, column_layer, text_layer],
    initial_view_state=view_state,
    tooltip={
        "html": (
            "<b style='font-size:15px'>{label}</b><br/>"
            "<span style='color:#888'>π {city} Β· π {records} records</span>"
            "<hr style='margin:6px 0; border-color:#ddd'/>"
            "<i style='font-size:12px'>Sample prompts:</i><br/>"
            "<span style='font-size:12px; line-height:1.6'>{sample_prompts}</span>"
        ),
        "style": {
            "backgroundColor": "white",
            "color": "#222",
            "fontSize": "13px",
            "padding": "10px 14px",
            "borderRadius": "8px",
            "maxWidth": "320px",
            "boxShadow": "0 2px 8px rgba(0,0,0,0.15)",
        },
    },
    # Token-free Carto Voyager basemap: light, colourful, no API key needed.
    map_style="https://basemaps.cartocdn.com/gl/voyager-gl-style/style.json",
)
st.pydeck_chart(deck, use_container_width=True, height=380)

# Small legend below the map
legend_cols = st.columns(len(LOCALE_GEO))
for col, (key, geo) in zip(legend_cols, LOCALE_GEO.items()):
    r, g, b = geo["color"]
    count = locale_counts.get(key, 0)
    col.markdown(
        f"<span style='color:rgb({r},{g},{b}); font-size:1.5rem;'>β</span> "
        f"**{geo['label']}** β {count:,}",
        unsafe_allow_html=True,
    )
st.divider()
# ---------------------------------------------------------------------------
# INTERACTIVE DATA EXPLORER
# ---------------------------------------------------------------------------
st.markdown("## π Explore the Data")


def _harms_in(cell):
    """Return the harm labels of one ``harm_types`` cell as a list.

    List cells are returned unchanged; any other value is treated as a single
    label (the same convention ``count_list_column`` uses), so a scalar string
    is never split into characters.
    """
    return cell if isinstance(cell, list) else [cell]


with st.expander("Filters", expanded=True):
    fcol1, fcol2 = st.columns(2)
    with fcol1:
        # Drop missing locales so sorting never compares str with float NaN.
        locale_choices = sorted(df["participant_locale"].dropna().unique())
        selected_locales = st.multiselect(
            "Filter by Place",
            options=locale_choices,
            default=locale_choices,
        )
    with fcol2:
        # Build a flat list of all harm types present in the dataset
        all_harms = sorted(
            {h for cell in df["harm_types"].dropna() for h in _harms_in(cell)}
        )
        selected_harms = st.multiselect(
            "Filter by Harm Type",
            options=all_harms,
            default=all_harms,
        )
    search_query = st.text_input(
        "π Search prompts (case-insensitive substring match)", ""
    )

# Apply filters
mask = df["participant_locale"].isin(selected_locales)

# Harm-type filter: a row passes if it contains ANY selected harm. The set is
# built once here rather than once per row. Missing cells never match.
selected_harm_set = set(selected_harms)
mask &= df["harm_types"].apply(
    lambda cell: not selected_harm_set.isdisjoint(_harms_in(cell))
)

if search_query:
    # regex=False: the box promises a plain substring match, and a regex
    # pattern would raise on input containing characters like "(" or "[".
    mask &= df["prompt"].str.contains(
        search_query, case=False, na=False, regex=False
    )

filtered = df[mask]

st.markdown(f"**Showing {len(filtered):,} of {len(df):,} records**")
st.dataframe(
    filtered[
        [
            "participant_locale",
            "prompt",
            "revised_prompt",
            "harm_types",
            "attack_modes",
            "targeted_identity_attribute",
        ]
    ].reset_index(drop=True),
    use_container_width=True,
    height=500,
)
st.divider()
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # SAMPLE PROMPTS | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # st.markdown("## π² Random Sample of Prompts") | |
| # n_samples = st.slider("Number of samples", 1, 50, 10) | |
| # if st.button("π Resample"): | |
| # pass # forces re-run | |
| # sample = filtered.sample(n=min(n_samples, len(filtered))) | |
| # for _, row in sample.iterrows(): | |
| # with st.container(): | |
| # st.markdown( | |
| # f"**Locale:** `{row['participant_locale']}` \n" | |
| # f"**Prompt:** {row['prompt']} \n" | |
| # f"**Revised Prompt:** {row['revised_prompt']} \n" | |
| # f"**Harm Types:** {', '.join(row['harm_types']) if isinstance(row['harm_types'], list) else row['harm_types']} \n" | |
| # f"**Attack Modes:** {', '.join(row['attack_modes']) if isinstance(row['attack_modes'], list) else row['attack_modes']} \n" | |
| # f"**Targeted Identity:** {', '.join(row['targeted_identity_attribute']) if isinstance(row['targeted_identity_attribute'], list) else row['targeted_identity_attribute']}" | |
| # ) | |
| # st.divider() | |