import os
import json
from collections import Counter
import math
import gradio as gr
import pandas as pd
import plotly.express as px
import pycountry
from datasets import load_dataset
# =========================
# Config
# =========================
# Location of the visits log (JSON Lines). Override with the VISITS_URL
# environment variable.
VISITS_URL = os.environ.get(
    "VISITS_URL",
    "https://huggingface.co/datasets/19arjun89/ai_recruiting_agent_usage/resolve/main/usage/visits_enriched.jsonl",
)

# Mapbox access token; configure it as a HF Space SECRET named MAPBOX_TOKEN.
# Surrounding whitespace is stripped; an unset variable yields "".
MAPBOX_TOKEN = os.environ.get("MAPBOX_TOKEN", "").strip()

# Path to the countries GeoJSON file (commit it into the Space repo).
GEOJSON_PATH = os.environ.get("GEOJSON_PATH", "countries.geojson")

# IMPORTANT: this must match the property name used inside the GeoJSON
# features. Other files commonly use "properties.ISO_A3" or
# "properties.ADM0_A3".
GEOJSON_FEATURE_ID_KEY = "properties.ISO3166-1-Alpha-3"

# Upper bound on the number of rows read from the visits log per report.
MAX_ROWS = int(os.environ.get("MAX_ROWS", "500000"))
# =========================
# Helpers
# =========================
def normalize_country_name(country: str | None) -> str | None:
if not country or not isinstance(country, str):
return None
c = country.strip()
if not c or c.lower() == "unknown":
return None
return c
def iso2_to_iso3(country_code: str | None) -> str | None:
"""Convert ISO-2 -> ISO-3 for map matching."""
if not country_code or not isinstance(country_code, str):
return None
c2 = country_code.strip().upper()
if len(c2) != 2:
return None
try:
rec = pycountry.countries.get(alpha_2=c2)
return rec.alpha_3 if rec else None
except Exception:
return None
def load_rows_streaming():
    """Yield rows from the visits log, stopping after ``MAX_ROWS`` rows.

    The file at ``VISITS_URL`` is opened with the ``json`` loader of
    ``datasets.load_dataset`` using ``streaming=True``, and rows are yielded
    one at a time.

    Yields:
        One row per record, as produced by the dataset iterator
        (``build_report`` reads them with ``row.get(...)``).
    """
    # A non-positive cap means "read nothing": do not open the remote dataset
    # and do not leak a single row past the limit.
    if MAX_ROWS <= 0:
        return

    ds = load_dataset(
        "json",
        data_files=VISITS_URL,
        split="train",
        streaming=True,
    )
    for seen, row in enumerate(ds, start=1):
        yield row
        # Stop right after the last allowed row so that no extra record is
        # pulled from the stream.
        if seen >= MAX_ROWS:
            break
def load_geojson(path: str) -> dict:
    """Read the UTF-8 encoded JSON file at *path* and return the parsed data."""
    with open(path, encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
def patch_geojson_iso_codes(countries_geojson: dict) -> dict:
    """Fill in missing or placeholder ISO codes on GeoJSON features, in place.

    Some GeoJSON files carry '-99' (or nothing) instead of real ISO codes.
    For each such feature that has a non-empty 'name' property, the name is
    resolved with pycountry's fuzzy search and both
    properties['ISO3166-1-Alpha-3'] and properties['ISO3166-1-Alpha-2'] are
    overwritten with the first match. Features that cannot be resolved are
    left untouched. The number of patched features is printed.

    Returns:
        The same ``countries_geojson`` object that was passed in.
    """
    alpha3_key = "ISO3166-1-Alpha-3"
    alpha2_key = "ISO3166-1-Alpha-2"
    unusable = ("", "-99")

    def _clean(properties, key):
        # Normalise a property to a stripped string; None/missing -> "".
        return str(properties.get(key, "") or "").strip()

    patched = 0
    for feature in countries_geojson.get("features", []):
        properties = feature.get("properties", {}) or {}
        code3 = _clean(properties, alpha3_key)
        code2 = _clean(properties, alpha2_key)
        label = _clean(properties, "name")

        if code3 not in unusable and code2 not in unusable:
            continue  # both codes already look valid
        if not label:
            continue  # nothing to search with

        try:
            match = pycountry.countries.search_fuzzy(label)[0]
            properties[alpha3_key] = match.alpha_3
            properties[alpha2_key] = match.alpha_2
            patched += 1
        except Exception:
            # Best effort: keep the feature unchanged when it can't be resolved.
            continue

    print(f"DEBUG patched GeoJSON features: {patched}")
    return countries_geojson
# =========================
# Main report builder
# =========================
def build_report():
    """Aggregate usage events per country and build the dashboard outputs.

    Rows from ``load_rows_streaming()`` are processed as follows:

    * rows whose ``event`` is ``session_start`` are skipped;
    * rows without a usable ``final_country`` are counted as missing;
    * every remaining row is counted in the table by country name;
    * rows whose ``final_country_code`` converts to ISO-3 are additionally
      counted for the map, the others are counted as invalid country codes.

    Returns:
        A ``(figure, table, summary)`` tuple: a Plotly figure (choropleth, or
        a titled empty figure when nothing is mappable), a DataFrame with at
        most 50 countries sorted by usage events, and a text summary
        reconciling the row counts.
    """
    countries_geojson = patch_geojson_iso_codes(load_geojson(GEOJSON_PATH))

    # Counters for clean reconciliation
    scanned = 0
    skipped_session_start = 0
    missing_country = 0
    invalid_country_code = 0

    # Table is keyed by country name, map by ISO-3 code.
    country_counts = Counter()
    iso3_counts = Counter()
    iso3_to_name = {}

    for row in load_rows_streaming():
        scanned += 1

        event_type = str(row.get("event", "") or "").strip().lower()
        if event_type == "session_start":
            skipped_session_start += 1
            continue

        country = normalize_country_name(row.get("final_country"))
        if not country:
            missing_country += 1
            continue

        # Count it for the table FIRST (all usage events with a valid country name)
        country_counts[country] += 1

        iso3 = iso2_to_iso3(row.get("final_country_code"))
        if not iso3:
            invalid_country_code += 1
            continue

        # Count it for the map only (requires ISO3)
        iso3_counts[iso3] += 1
        iso3_to_name.setdefault(iso3, country)

    # Explicit columns keep sort_values() working when a counter is empty;
    # a column-less DataFrame would raise KeyError on "usage events".
    table_df = (
        pd.DataFrame(
            [{"country": name, "usage events": cnt} for name, cnt in country_counts.items()],
            columns=["country", "usage events"],
        )
        .sort_values("usage events", ascending=False)
        .reset_index(drop=True)
    )
    map_df = (
        pd.DataFrame(
            [
                {"iso3": iso3, "country": iso3_to_name.get(iso3, iso3), "usage events": cnt}
                for iso3, cnt in iso3_counts.items()
            ],
            columns=["iso3", "country", "usage events"],
        )
        .sort_values("usage events", ascending=False)
        .reset_index(drop=True)
    )

    # Reconciliation: every scanned row lands in exactly one of the four
    # buckets summed into `accounted`.
    mappable_rows_count = int(sum(iso3_counts.values()))
    table_rows_counted = int(sum(country_counts.values()))
    accounted = skipped_session_start + missing_country + invalid_country_code + mappable_rows_count
    # Each table row holds one country's count, so the table total equals
    # the number of rows counted in the table.
    total_usage_events = table_rows_counted

    if map_df.empty:
        fig = px.scatter(title="No mappable data found")
        fig.update_layout(height=740, margin=dict(l=0, r=0, t=40, b=0))
        summary = (
            f"Rows scanned: {scanned:,}\n"
            f"- Rows counted in table: {table_rows_counted:,}\n"
            f"- Rows mapped: {mappable_rows_count:,}\n"
            f"- Session starts skipped: {skipped_session_start:,}\n"
            f"- Missing country: {missing_country:,}\n"
            f"- Invalid country code: {invalid_country_code:,}\n\n"
            f"Accounted rows: {accounted:,} / {scanned:,}\n"
            f"Countries (table): {len(table_df):,}\n"
            f"Total usage events: {total_usage_events:,}"
        )
        return fig, table_df.head(50), summary

    # Log scale for nicer color spread (keeps small countries visible)
    map_df["usage_log"] = map_df["usage events"].clip(lower=1).apply(math.log10)
    map_df["iso3"] = map_df["iso3"].astype(str).str.upper()

    # Only register a token when one is configured; the "open-street-map"
    # style used below is a token-free basemap.
    if MAPBOX_TOKEN:
        px.set_mapbox_access_token(MAPBOX_TOKEN)

    # Mapbox choropleth using GeoJSON
    fig = px.choropleth_mapbox(
        map_df,
        geojson=countries_geojson,
        featureidkey=GEOJSON_FEATURE_ID_KEY,
        locations="iso3",
        color="usage_log",
        hover_name="country",
        hover_data={"usage events": True, "iso3": True},
        labels={"usage_log": "Usage intensity (log10)", "usage events": "Usage events"},
        mapbox_style="open-street-map",
        opacity=0.75,
        zoom=1.5,
        center={"lat": 15, "lon": 0},
    )
    fig.update_traces(
        # Clean hover card; <extra></extra> hides the secondary trace box.
        hovertemplate=(
            "<b>%{customdata[0]}</b><br>"
            "Usage events: %{customdata[1]:,}<br>"
            "<extra></extra>"
        ),
        # customdata lets us show real counts even though color is log-scaled
        customdata=map_df[["country", "usage events"]].to_numpy(),
        marker_line_width=0.8,
        marker_line_color="rgba(255,255,255,0.85)",  # nice on light basemaps
    )

    # Full-bleed layout without the color bar
    fig.update_layout(
        height=740,
        margin=dict(l=0, r=0, t=0, b=0),
        coloraxis_showscale=False,
    )

    # Dashboard title
    fig.add_annotation(
        text="Usage Events by Country",
        x=0.01,
        y=0.95,
        xref="paper",
        yref="paper",
        xanchor="left",
        yanchor="top",
        showarrow=False,
        font=dict(size=20),
    )

    # Summary text (clean math)
    summary = (
        f"Rows scanned: {scanned:,}\n"
        f"- Session starts skipped: {skipped_session_start:,}\n"
        f"- Missing country: {missing_country:,}\n"
        f"- Invalid country code: {invalid_country_code:,}\n"
        f"- Rows mapped: {mappable_rows_count:,}\n\n"
        f"Accounted rows: {accounted:,} / {scanned:,}\n"
        f"Countries (table): {len(table_df):,}\n"
        f"Countries (map): {map_df['iso3'].nunique():,}\n"
        f"Total usage events: {total_usage_events:,}"
    )

    table_out = table_df.head(50).copy()
    # Timestamp.utcnow() is deprecated; now(tz="UTC") formats identically.
    refreshed_at = pd.Timestamp.now(tz="UTC").strftime("%Y-%m-%d %H:%M:%S")
    table_out.insert(0, "refreshed_at_utc", refreshed_at)
    return fig, table_out, summary
# =========================
# UI
# =========================
# Page layout, top to bottom: intro text, trigger button, summary, map, table.
with gr.Blocks(title="AI Recruiting Agent — Usage Map") as demo:
    gr.Markdown(
        "# AI Recruiting Agent — Usage by Country (Mapbox)\n"
        "This Space reads **only** `visits_enriched.jsonl`, "
        "excludes `event=session_start`, "
        "and plots **usage events** by country.\n\n"
    )

    run = gr.Button("Generate map")
    summary = gr.Markdown()
    plot = gr.Plot()
    table = gr.Dataframe(label="Top countries", interactive=False)

    # build_report returns (figure, table, summary), hence this output order.
    run.click(fn=build_report, inputs=[], outputs=[plot, table, summary])

# Start the app.
demo.launch()