import itertools
import json
import math
import os
from collections import Counter
from collections.abc import Iterable, Iterator
from dataclasses import dataclass, field

import gradio as gr
import pandas as pd
import plotly.express as px
import pycountry
from datasets import load_dataset

# =========================
# Config
# =========================
VISITS_URL = os.getenv(
    "VISITS_URL",
    "https://huggingface.co/datasets/19arjun89/ai_recruiting_agent_usage/resolve/main/usage/visits_enriched.jsonl",
)

# Set this as a HF Space SECRET named MAPBOX_TOKEN
MAPBOX_TOKEN = os.getenv("MAPBOX_TOKEN", "").strip()

# Path to your GeoJSON (commit into the Space repo)
GEOJSON_PATH = os.getenv("GEOJSON_PATH", "countries.geojson")

# IMPORTANT: must match the property path inside your GeoJSON features that
# holds the ISO-3 code (e.g. "properties.ISO_A3" or "properties.ADM0_A3").
# Overridable via the GEOJSON_FEATURE_ID_KEY environment variable.
GEOJSON_FEATURE_ID_KEY = os.getenv(
    "GEOJSON_FEATURE_ID_KEY", "properties.ISO3166-1-Alpha-3"
)

# Upper bound on rows pulled from the streamed dataset.
MAX_ROWS = int(os.getenv("MAX_ROWS", "500000"))


# =========================
# Helpers
# =========================
def normalize_country_name(country: str | None) -> str | None:
    """Return a stripped country name, or None for empty/"unknown"/non-str input."""
    if not country or not isinstance(country, str):
        return None
    c = country.strip()
    if not c or c.lower() == "unknown":
        return None
    return c


def iso2_to_iso3(country_code: str | None) -> str | None:
    """Convert an ISO 3166-1 alpha-2 code to alpha-3 for map matching.

    Returns None for non-string input, codes that are not exactly two
    characters after stripping, or codes pycountry does not know.
    """
    if not country_code or not isinstance(country_code, str):
        return None
    c2 = country_code.strip().upper()
    if len(c2) != 2:
        return None
    try:
        rec = pycountry.countries.get(alpha_2=c2)
    except LookupError:
        # Older pycountry releases raise KeyError (a LookupError) for unknown
        # codes instead of returning None.
        return None
    return rec.alpha_3 if rec else None


def load_rows_streaming() -> Iterator[dict]:
    """Stream rows from the visits JSONL, yielding at most MAX_ROWS of them."""
    ds = load_dataset(
        "json",
        data_files=VISITS_URL,
        split="train",
        streaming=True,
    )
    # islice stops consuming the stream at the cap. The clamp makes a zero or
    # negative MAX_ROWS yield nothing (islice rejects negative stops).
    yield from itertools.islice(ds, max(MAX_ROWS, 0))


def load_geojson(path: str) -> dict:
    """Read and parse a UTF-8 GeoJSON file."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def patch_geojson_iso_codes(countries_geojson: dict) -> dict:
    """
    Some GeoJSON files have ISO codes as '-99'. Patch them using the 'name' field.

    Updates (in place, on each feature's existing properties dict):
      properties['ISO3166-1-Alpha-2']
      properties['ISO3166-1-Alpha-3']

    Returns the same dict that was passed in.
    """
    fixed = 0
    for feat in countries_geojson.get("features", []):
        props = feat.get("properties", {}) or {}
        iso3 = str(props.get("ISO3166-1-Alpha-3", "") or "").strip()
        iso2 = str(props.get("ISO3166-1-Alpha-2", "") or "").strip()
        name = str(props.get("name", "") or "").strip()

        needs_fix = iso3 == "-99" or iso2 == "-99" or not iso3 or not iso2
        if not needs_fix or not name:
            continue

        try:
            # NOTE(review): takes the first fuzzy match; ambiguous names could
            # resolve to the wrong country.
            rec = pycountry.countries.search_fuzzy(name)[0]
        except LookupError:
            # search_fuzzy raises LookupError when nothing matches; leave the
            # feature as-is.
            continue
        props["ISO3166-1-Alpha-3"] = rec.alpha_3
        props["ISO3166-1-Alpha-2"] = rec.alpha_2
        fixed += 1

    print(f"DEBUG patched GeoJSON features: {fixed}")
    return countries_geojson


# =========================
# Main report builder
# =========================
@dataclass
class _UsageTally:
    """Counters collected while scanning visit rows (used for reconciliation)."""

    scanned: int = 0
    skipped_session_start: int = 0
    missing_country: int = 0
    invalid_country_code: int = 0
    # Table: country name -> usage events (valid name, any country code).
    country_counts: Counter[str] = field(default_factory=Counter)
    # Map: ISO-3 code -> usage events (requires a resolvable ISO-3 code).
    iso3_counts: Counter[str] = field(default_factory=Counter)
    # First country name seen for each ISO-3 code (used as the map label).
    iso3_to_name: dict[str, str] = field(default_factory=dict)

    @property
    def table_rows_counted(self) -> int:
        return sum(self.country_counts.values())

    @property
    def mappable_rows_count(self) -> int:
        return sum(self.iso3_counts.values())

    @property
    def accounted(self) -> int:
        # Each scanned row lands in exactly one of these buckets.
        return (
            self.skipped_session_start
            + self.missing_country
            + self.invalid_country_code
            + self.mappable_rows_count
        )


def _tally_rows(rows: Iterable[dict]) -> _UsageTally:
    """Count usage events per country name and per ISO-3 code, skipping session starts."""
    tally = _UsageTally()
    for row in rows:
        tally.scanned += 1

        event_type = str(row.get("event", "") or "").strip().lower()
        if event_type == "session_start":
            tally.skipped_session_start += 1
            continue

        country = normalize_country_name(row.get("final_country"))
        if not country:
            tally.missing_country += 1
            continue

        # Count for the table FIRST: every usage event with a valid country name.
        tally.country_counts[country] += 1

        iso3 = iso2_to_iso3(row.get("final_country_code"))
        if not iso3:
            tally.invalid_country_code += 1
            continue

        # Count for the map only (requires ISO3).
        tally.iso3_counts[iso3] += 1
        tally.iso3_to_name.setdefault(iso3, country)
    return tally


def _build_frames(tally: _UsageTally) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Build the country table and map DataFrames, both sorted by usage descending."""
    # Explicit columns keep the frames well-formed (and sortable) when empty.
    table_df = (
        pd.DataFrame(
            list(tally.country_counts.items()),
            columns=["country", "usage events"],
        )
        .sort_values("usage events", ascending=False)
        .reset_index(drop=True)
    )
    map_df = (
        pd.DataFrame(
            [
                (iso3, tally.iso3_to_name.get(iso3, iso3), cnt)
                for iso3, cnt in tally.iso3_counts.items()
            ],
            columns=["iso3", "country", "usage events"],
        )
        .sort_values("usage events", ascending=False)
        .reset_index(drop=True)
    )
    return table_df, map_df


def _build_usage_map(map_df: pd.DataFrame, countries_geojson: dict):
    """Render a Mapbox choropleth of usage events keyed by ISO-3 code (map_df must be non-empty)."""
    plot_df = map_df.assign(
        iso3=map_df["iso3"].astype(str).str.upper(),
        # Log scale for nicer color spread (keeps small countries visible).
        usage_log=map_df["usage events"].clip(lower=1).map(math.log10),
    )

    # The "open-street-map" style does not need a Mapbox token; only register
    # one when it has been provided.
    if MAPBOX_TOKEN:
        px.set_mapbox_access_token(MAPBOX_TOKEN)

    fig = px.choropleth_mapbox(
        plot_df,
        geojson=countries_geojson,
        featureidkey=GEOJSON_FEATURE_ID_KEY,
        locations="iso3",
        color="usage_log",
        hover_name="country",
        hover_data={"usage events": True, "iso3": True},
        labels={"usage_log": "Usage intensity (log10)", "usage events": "Usage events"},
        mapbox_style="open-street-map",
        opacity=0.75,
        zoom=1.5,
        center={"lat": 15, "lon": 0},
    )

    fig.update_traces(
        # Clean hover card; <extra></extra> hides the secondary trace-name box.
        hovertemplate=(
            "%{customdata[0]}<br>"
            "Usage events: %{customdata[1]:,}"
            "<extra></extra>"
        ),
        # customdata shows the real counts even though the color is log-scaled.
        customdata=plot_df[["country", "usage events"]].to_numpy(),
        marker_line_width=0.8,
        marker_line_color="rgba(255,255,255,0.85)",  # nice on light basemaps
    )

    # Full-bleed layout, no color bar.
    fig.update_layout(
        height=740,
        margin=dict(l=0, r=0, t=0, b=0),
        coloraxis_showscale=False,
    )

    # Dashboard title
    fig.add_annotation(
        text="Usage Events by Country",
        x=0.01,
        y=0.95,
        xref="paper",
        yref="paper",
        xanchor="left",
        yanchor="top",
        showarrow=False,
        font=dict(size=20),
    )
    return fig


def _format_summary(tally: _UsageTally, table_df: pd.DataFrame, map_df: pd.DataFrame) -> str:
    """Format the reconciliation summary shown above the map."""
    t = tally
    if map_df.empty:
        return (
            f"Rows scanned: {t.scanned:,}\n"
            f"- Rows counted in table: {t.table_rows_counted:,}\n"
            f"- Rows mapped: {t.mappable_rows_count:,}\n"
            f"- Session starts skipped: {t.skipped_session_start:,}\n"
            f"- Missing country: {t.missing_country:,}\n"
            f"- Invalid country code: {t.invalid_country_code:,}\n\n"
            f"Accounted rows: {t.accounted:,} / {t.scanned:,}\n"
            f"Countries (table): {len(table_df):,}\n"
            f"Total usage events: {t.table_rows_counted:,}"
        )
    return (
        f"Rows scanned: {t.scanned:,}\n"
        f"- Session starts skipped: {t.skipped_session_start:,}\n"
        f"- Missing country: {t.missing_country:,}\n"
        f"- Invalid country code: {t.invalid_country_code:,}\n"
        f"- Rows mapped: {t.mappable_rows_count:,}\n\n"
        f"Accounted rows: {t.accounted:,} / {t.scanned:,}\n"
        f"Countries (table): {len(table_df):,}\n"
        f"Countries (map): {len(map_df):,}\n"
        f"Total usage events: {t.table_rows_counted:,}"
    )


def build_report():
    """Scan the visits log and return (figure, top-50 country table, summary markdown).

    Session-start events are excluded. Events with a country name count toward
    the table. Events whose country code also maps to ISO-3 count toward the map.
    """
    # Load the GeoJSON first so a missing file fails before the (slow) scan.
    countries_geojson = patch_geojson_iso_codes(load_geojson(GEOJSON_PATH))

    tally = _tally_rows(load_rows_streaming())
    table_df, map_df = _build_frames(tally)

    if map_df.empty:
        fig = px.scatter(title="No mappable data found")
        fig.update_layout(height=740, margin=dict(l=0, r=0, t=40, b=0))
    else:
        fig = _build_usage_map(map_df, countries_geojson)

    summary = _format_summary(tally, table_df, map_df)

    table_out = table_df.head(50).copy()
    table_out.insert(
        0,
        "refreshed_at_utc",
        pd.Timestamp.now(tz="UTC").strftime("%Y-%m-%d %H:%M:%S"),
    )
    return fig, table_out, summary


# =========================
# UI
# =========================
with gr.Blocks(title="AI Recruiting Agent — Usage Map") as demo:
    gr.Markdown(
        "# AI Recruiting Agent — Usage by Country (Mapbox)\n"
        "This Space reads **only** `visits_enriched.jsonl`, excludes `event=session_start`, "
        "and plots **usage events** by country.\n\n"
    )

    run = gr.Button("Generate map")
    summary = gr.Markdown()
    plot = gr.Plot()
    table = gr.Dataframe(label="Top countries", interactive=False)

    run.click(
        fn=build_report,
        inputs=[],
        outputs=[plot, table, summary],
    )

demo.launch()