# PLACES / src/streamlit_app.py
# Last commit by CharviRastogi:
#   "Update src/streamlit_app.py" (7675bd2, verified)
import html
import json
from collections import Counter
from pathlib import Path

import altair as alt
import numpy as np
import pandas as pd
import pydeck as pdk
import streamlit as st
# ──────────────────────────────────────────────────────────────────────────────
# Page configuration
# ──────────────────────────────────────────────────────────────────────────────
# Browser-tab title and favicon, plus the full-width page layout.
# Kept ahead of every other st.* call in this file.
st.set_page_config(
    layout="wide",
    page_icon="🌍",
    page_title="PLACES Dataset Explorer",
)
# ──────────────────────────────────────────────────────────────────────────────
# Geocoded coordinates for each participant locale
# ──────────────────────────────────────────────────────────────────────────────
# Static geocoding keyed by the dataset's ``participant_locale`` values.
# Per locale:
#   lat / lon -- pin position on the map (decimal degrees)
#   label     -- display name used by the sidebar radio, map text and legend
#   city      -- city name shown in the map tooltip
#   color     -- RGB triple shared by the map pin, column and legend dot
LOCALE_GEO = dict(
    karnataka=dict(
        lat=12.9716,
        lon=77.5946,
        label="Karnataka, India",
        city="Bangalore",
        color=[255, 87, 51],  # coral-red
    ),
    punjab=dict(
        # NOTE(review): these coordinates look like Punjab's geographic
        # centre rather than Amritsar itself -- confirm which the pin marks.
        lat=31.1471,
        lon=75.3412,
        label="Punjab, India",
        city="Amritsar",
        color=[255, 195, 0],  # amber
    ),
    nigeria=dict(
        lat=12.0022,
        lon=8.5920,
        label="Kano, Nigeria",
        city="Kano",
        color=[0, 168, 107],  # green
    ),
    ghana=dict(
        lat=5.1315,
        lon=-1.2795,
        label="Cape Coast, Ghana",
        city="Cape Coast",
        color=[30, 144, 255],  # dodger blue
    ),
    zindi=dict(
        lat=-33.9249,
        lon=18.4241,
        label="Zindi (Online Platform)",
        city="Cape Town",
        color=[155, 89, 182],  # purple
    ),
)
# ──────────────────────────────────────────────────────────────────────────────
# Load data (cached so it's only read once)
# ──────────────────────────────────────────────────────────────────────────────
# Resolve the dataset path relative to this script file, not the cwd:
# __file__ = src/streamlit_app.py → .parent = src/ → .parent.parent = repo root
filepath = Path(__file__).parent.parent / "PLACES_full_dataset.json"


@st.cache_data
def load_data(filepath=filepath):
    """Load the PLACES JSON dataset into a DataFrame.

    Args:
        filepath: Path (or path string) of the JSON file. Defaults to
            ``PLACES_full_dataset.json`` one directory above this script's
            directory.

    Returns:
        ``pd.DataFrame`` built directly from the parsed JSON document.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist. Diagnostic lines
            (expected path, working directory and its contents) are printed
            to stdout first so they show up in the server / Space logs.
    """
    filepath = Path(filepath)
    if not filepath.exists():
        # print(), not st.*, so the diagnostics land in the deployment logs.
        print(f"DEBUG: Expected dataset at {filepath.resolve()}")
        print(f"DEBUG: Current directory is {Path.cwd()}")
        print(f"DEBUG: Files here: {list(Path.cwd().iterdir())}")
        raise FileNotFoundError(f"PLACES dataset not found: {filepath}")
    # Read as UTF-8 explicitly; the platform default encoding may differ and
    # the prompts are free text that can contain non-ASCII characters.
    with open(filepath, "r", encoding="utf-8") as f:
        raw = json.load(f)
    return pd.DataFrame(raw)


df = load_data()
# ──────────────────────────────────────────────────────────────────────────────
# Helper: explode a list-valued column and count occurrences
# ──────────────────────────────────────────────────────────────────────────────
def count_list_column(series):
    """Count how often each value occurs in a column whose cells may be lists.

    A list cell contributes each of its elements; any other non-null cell
    contributes itself once. Null cells are ignored.

    Returns:
        DataFrame with ``Category`` and ``Count`` columns, ordered by
        ``Count`` descending.
    """
    tally = Counter()
    for cell in series.dropna():
        # Wrap scalars so update() counts the whole value, not its characters.
        tally.update(cell if isinstance(cell, list) else [cell])
    frame = pd.DataFrame(
        {"Category": list(tally), "Count": list(tally.values())}
    )
    return frame.sort_values("Count", ascending=False)
# ══════════════════════════════════════════════════════════════════════════════
# SIDEBAR — Place selector & prompt viewer
# ══════════════════════════════════════════════════════════════════════════════
def _format_multi(value):
    """Join a list-valued cell with ", " for display; return scalars unchanged."""
    if isinstance(value, list):
        return ", ".join(str(v) for v in value)
    return value


with st.sidebar:
    st.markdown("## 📍 Select a Place")
    st.caption("Click a place below to see sample prompts from that region.")

    # Radio shows the human-readable label; map it back to the locale key.
    locale_options = {geo["label"]: key for key, geo in LOCALE_GEO.items()}
    selected_label = st.radio(
        "Choose a location",
        options=list(locale_options.keys()),
        index=0,
        label_visibility="collapsed",
    )
    selected_locale = locale_options[selected_label]
    place_df = df[df["participant_locale"] == selected_locale]

    st.markdown("---")
    st.markdown(
        f"### {selected_label}\n"
        f"**{len(place_df):,}** records · "
        f"**{place_df['prompt'].nunique():,}** unique prompts"
    )

    # Harm-type mini-breakdown for this place
    place_harm = count_list_column(place_df["harm_types"])
    st.markdown("#### Harm types at this place")
    for _, row in place_harm.iterrows():
        st.markdown(f"- **{row['Category']}**: {row['Count']:,}")

    st.markdown("---")
    n_sidebar = st.slider("Sample size", 1, 20, 5, key="sidebar_n")
    if st.button("🔄 Resample prompts", key="sidebar_resample"):
        pass  # clicking forces a re-run → new random draw

    st.markdown("#### 🎲 Sample Prompts")
    sample_place = place_df.sample(n=min(n_sidebar, len(place_df)))
    n_drawn = len(sample_place)
    for i, (_, row) in enumerate(sample_place.iterrows()):
        with st.container():
            # Two trailing spaces before "\n" = Markdown hard line break.
            st.markdown(
                f"**Prompt:** {row['prompt']}\n\n"
                f"**Model revised to:** {row['revised_prompt']}\n\n"
                f"🛑 {_format_multi(row['harm_types'])}  \n"
                f"⚔️ {_format_multi(row['attack_modes'])}  \n"
                f"🎯 {_format_multi(row['targeted_identity_attribute'])}"
            )
        # Divider only *between* samples: compare with the number actually
        # drawn, which is below n_sidebar when the place has few records.
        if i < n_drawn - 1:
            st.markdown("---")
# ══════════════════════════════════════════════════════════════════════════════
# TITLE
# ══════════════════════════════════════════════════════════════════════════════
# Centred page title and one-paragraph dataset description (raw HTML).
# HTML lines stay unindented inside the string so Markdown doesn't treat
# them as an indented code block.
st.markdown(
    """
<h1 style='text-align:center;'>🌍 PLACES Dataset Explorer</h1>
<p style='text-align:center; font-size:1.1rem; color:gray;'>
Explore the <b>PLACES</b> adversarial-nibbler dataset — covering prompts,
harm annotations, attack modes, and targeted-identity attributes from
participants across Sub-Saharan Africa &amp; India.
</p>
""",
    unsafe_allow_html=True,
)
st.divider()
# ══════════════════════════════════════════════════════════════════════════════
# TOP-LINE METRICS (big bold numbers)
# ══════════════════════════════════════════════════════════════════════════════
# Dataset-wide headline numbers, rendered as four side-by-side metric tiles.
total_records = len(df)
unique_prompts = df["prompt"].nunique()
unique_places = df["participant_locale"].nunique()
unique_participants = df["participant_id"].nunique()

col1, col2, col3, col4 = st.columns(4)
col1.metric("📝 Total Records", f"{total_records:,}")
col2.metric("💬 Unique Prompts", f"{unique_prompts:,}")
col3.metric("📍 Unique Places", f"{unique_places:,}")
col4.metric("👥 Unique Participants", f"{unique_participants:,}")
st.divider()
# ══════════════════════════════════════════════════════════════════════════════
# INTERACTIVE MAP
# ══════════════════════════════════════════════════════════════════════════════
# Section heading and short instructions for the map.
st.markdown("## 🗺️ Where the Data Comes From")
st.markdown(
    "Each pin represents a participant locale. **Select a place in the sidebar** "
    "to see sample prompts from that region."
)
# Record count per locale key; feeds the pin radius, tooltip and legend.
locale_counts = df["participant_locale"].value_counts().to_dict()

# Pre-sample up to 3 prompts per locale for the hover tooltip. Each prompt
# becomes one bulleted line truncated to _TOOLTIP_PROMPT_CHARS characters.
# The text is HTML-escaped because the tooltip renders this string as HTML
# and prompts are free text that may contain '<', '&', quotes, etc.
_TOOLTIP_PROMPT_CHARS = 100
# Kept outside any f-string: a backslash escape inside an f-string
# expression is a SyntaxError on Python < 3.12.
_ELLIPSIS = "\u2026"
locale_sample_cache = {}
for key in LOCALE_GEO:
    locale_rows = df[df["participant_locale"] == key]["prompt"].dropna()
    if len(locale_rows) == 0:
        locale_sample_cache[key] = "(no prompts)"
        continue
    samples = locale_rows.sample(n=min(3, len(locale_rows))).tolist()
    bullets = []
    for prompt in samples:
        text = str(prompt)
        snippet = text[:_TOOLTIP_PROMPT_CHARS]
        if len(text) > _TOOLTIP_PROMPT_CHARS:
            snippet += _ELLIPSIS
        # Escape after truncating so an entity is never cut in half.
        bullets.append(f"&#8226; {html.escape(snippet)}")
    locale_sample_cache[key] = "<br/>".join(bullets)
def _map_row(key, geo):
    """Return the map_df row for one locale: position, styling and tooltip fields."""
    records = locale_counts.get(key, 0)
    return {
        "locale": key,
        "label": geo["label"],
        "city": geo["city"],
        "lat": geo["lat"],
        "lon": geo["lon"],
        "records": records,
        "color": geo["color"],
        # Radius grows with record count, clamped to [40000, 120000] so small
        # locales stay visible and large ones don't swamp the map.
        "radius": max(40000, min(120000, records * 8)),
        "sample_prompts": locale_sample_cache.get(key, ""),
    }


map_data = [_map_row(key, geo) for key, geo in LOCALE_GEO.items()]
map_df = pd.DataFrame(map_data)

# Highlight the place chosen in the sidebar: its column is drawn taller.
map_df["is_selected"] = map_df["locale"] == selected_locale
map_df["elevation"] = map_df["is_selected"].map({True: 150000, False: 30000})
# NOTE(review): no layer reads this column (each layer passes a fixed
# opacity), so it currently has no visual effect.
map_df["opacity"] = map_df["is_selected"].map({True: 220, False: 140})
# Three pydeck layers, all fed from map_df with positions as [lon, lat]:
#   scatter -- one outlined disc per locale, sized by the "radius" column
#   column  -- one extruded bar per locale, height from the "elevation" column
#   text    -- the locale label drawn at the pin
# pydeck treats bare strings as data accessors, which is why the literal
# string constants in the text layer carry an extra pair of quotes.
scatter_layer = pdk.Layer(
    "ScatterplotLayer",
    data=map_df,
    get_position=["lon", "lat"],
    pickable=True,
    stroked=True,
    opacity=0.8,
    get_radius="radius",
    get_fill_color="color",
    get_line_color=[255, 255, 255],  # white outline
    line_width_min_pixels=2,
)
column_layer = pdk.Layer(
    "ColumnLayer",
    data=map_df,
    get_position=["lon", "lat"],
    pickable=True,
    auto_highlight=True,
    opacity=0.6,
    radius=35000,
    elevation_scale=1,
    get_elevation="elevation",
    get_fill_color="color",
)
text_layer = pdk.Layer(
    "TextLayer",
    data=map_df,
    get_position=["lon", "lat"],
    billboard=True,
    get_text="label",
    get_size=14,
    get_angle=0,
    get_color=[30, 30, 30],  # dark text for lighter basemap
    get_text_anchor='"middle"',
    get_alignment_baseline='"bottom"',
    font_family='"Arial"',
    font_weight=700,
)
# Initial camera, centred roughly between Africa and India.
view_state = pdk.ViewState(
    latitude=10,
    longitude=30,
    zoom=2.0,
    pitch=15,
    bearing=0,
)
deck = pdk.Deck(
    layers=[scatter_layer, column_layer, text_layer],
    initial_view_state=view_state,
    # {field} placeholders are filled from the hovered map_df row;
    # sample_prompts is already an (escaped) HTML snippet.
    tooltip={
        "html": (
            "<b style='font-size:15px'>{label}</b><br/>"
            "<span style='color:#888'>📍 {city} &nbsp;·&nbsp; 📝 {records} records</span>"
            "<hr style='margin:6px 0; border-color:#ddd'/>"
            "<i style='font-size:12px'>Sample prompts:</i><br/>"
            "<span style='font-size:12px; line-height:1.6'>{sample_prompts}</span>"
        ),
        "style": {
            "backgroundColor": "white",
            "color": "#222",
            "fontSize": "13px",
            "padding": "10px 14px",
            "borderRadius": "8px",
            "maxWidth": "320px",
            "boxShadow": "0 2px 8px rgba(0,0,0,0.15)",
        },
    },
    # Token-free Carto Voyager — light, colourful, no API key needed
    map_style="https://basemaps.cartocdn.com/gl/voyager-gl-style/style.json",
)
st.pydeck_chart(deck, use_container_width=True, height=380)
# Small legend below the map: one column per locale showing its colour dot,
# label and record count.
legend_cols = st.columns(len(LOCALE_GEO))
for col, (key, geo) in zip(legend_cols, LOCALE_GEO.items()):
    r, g, b = geo["color"]
    count = locale_counts.get(key, 0)
    col.markdown(
        f"<span style='color:rgb({r},{g},{b}); font-size:1.5rem;'>●</span> "
        f"**{geo['label']}** — {count:,}",
        unsafe_allow_html=True,
    )
st.divider()
# ══════════════════════════════════════════════════════════════════════════════
# INTERACTIVE DATA EXPLORER
# ══════════════════════════════════════════════════════════════════════════════
def _harm_set(value):
    """Normalise a ``harm_types`` cell to a set of harm labels.

    A list contributes its elements, a missing value contributes nothing,
    and any other scalar (e.g. a bare string) counts as one label -- the same
    way ``count_list_column`` treats non-list cells. Without this, iterating
    a bare string would yield its individual characters.
    """
    if isinstance(value, list):
        return set(value)
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return set()
    return {value}


st.markdown("## 🔎 Explore the Data")
with st.expander("Filters", expanded=True):
    fcol1, fcol2 = st.columns(2)
    with fcol1:
        all_locales = sorted(df["participant_locale"].unique())
        selected_locales = st.multiselect(
            "Filter by Place",
            options=all_locales,
            default=all_locales,
        )
    with fcol2:
        # Flat, sorted list of every harm type present in the dataset.
        harm_sets = df["harm_types"].apply(_harm_set)
        all_harms = sorted(set().union(*harm_sets), key=str)
        selected_harms = st.multiselect(
            "Filter by Harm Type",
            options=all_harms,
            default=all_harms,
        )
    search_query = st.text_input(
        "🔍 Search prompts (case-insensitive substring match)", ""
    )

# Apply filters
mask = df["participant_locale"].isin(selected_locales)
# Harm-type filter: a row passes if it carries ANY selected harm.
selected_harm_set = set(selected_harms)
mask &= harm_sets.apply(lambda harms: bool(harms & selected_harm_set))
if search_query:
    # regex=False: the box promises a plain substring match, and input such
    # as "(" or "*" would otherwise be parsed as a (possibly invalid) regex.
    mask &= df["prompt"].str.contains(
        search_query, case=False, regex=False, na=False
    )
filtered = df[mask]

st.markdown(f"**Showing {len(filtered):,} of {len(df):,} records**")
st.dataframe(
    filtered[
        [
            "participant_locale",
            "prompt",
            "revised_prompt",
            "harm_types",
            "attack_modes",
            "targeted_identity_attribute",
        ]
    ].reset_index(drop=True),
    use_container_width=True,
    height=500,
)
st.divider()
# ══════════════════════════════════════════════════════════════════════════════
# SAMPLE PROMPTS
# ══════════════════════════════════════════════════════════════════════════════
# st.markdown("## 🎲 Random Sample of Prompts")
# n_samples = st.slider("Number of samples", 1, 50, 10)
# if st.button("πŸ”„ Resample"):
# pass # forces re-run
# sample = filtered.sample(n=min(n_samples, len(filtered)))
# for _, row in sample.iterrows():
# with st.container():
# st.markdown(
# f"**Locale:** `{row['participant_locale']}` \n"
# f"**Prompt:** {row['prompt']} \n"
# f"**Revised Prompt:** {row['revised_prompt']} \n"
# f"**Harm Types:** {', '.join(row['harm_types']) if isinstance(row['harm_types'], list) else row['harm_types']} \n"
# f"**Attack Modes:** {', '.join(row['attack_modes']) if isinstance(row['attack_modes'], list) else row['attack_modes']} \n"
# f"**Targeted Identity:** {', '.join(row['targeted_identity_attribute']) if isinstance(row['targeted_identity_attribute'], list) else row['targeted_identity_attribute']}"
# )
# st.divider()