Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from collections import Counter | |
| import math | |
| import gradio as gr | |
| import pandas as pd | |
| import plotly.express as px | |
| import pycountry | |
| from datasets import load_dataset | |
# =========================
# Config
# =========================
# Location of the enriched visits log (JSONL); the VISITS_URL env var overrides it.
VISITS_URL = os.environ.get(
    "VISITS_URL",
    "https://huggingface.co/datasets/19arjun89/ai_recruiting_agent_usage/resolve/main/usage/visits_enriched.jsonl",
)

# Mapbox access token. Set this as a HF Space SECRET named MAPBOX_TOKEN.
# Surrounding whitespace is stripped; an unset variable yields "".
MAPBOX_TOKEN = os.environ.get("MAPBOX_TOKEN", "").strip()

# Path to the countries GeoJSON file (commit it into the Space repo).
GEOJSON_PATH = os.environ.get("GEOJSON_PATH", "countries.geojson")

# IMPORTANT: this must match the property name inside the GeoJSON features.
# Other GeoJSON files commonly use "properties.ISO_A3" or "properties.ADM0_A3".
GEOJSON_FEATURE_ID_KEY = "properties.ISO3166-1-Alpha-3"

# Upper bound on the number of rows read from the visits stream.
MAX_ROWS = int(os.environ.get("MAX_ROWS", "500000"))
| # ========================= | |
| # Helpers | |
| # ========================= | |
| def normalize_country_name(country: str | None) -> str | None: | |
| if not country or not isinstance(country, str): | |
| return None | |
| c = country.strip() | |
| if not c or c.lower() == "unknown": | |
| return None | |
| return c | |
| def iso2_to_iso3(country_code: str | None) -> str | None: | |
| """Convert ISO-2 -> ISO-3 for map matching.""" | |
| if not country_code or not isinstance(country_code, str): | |
| return None | |
| c2 = country_code.strip().upper() | |
| if len(c2) != 2: | |
| return None | |
| try: | |
| rec = pycountry.countries.get(alpha_2=c2) | |
| return rec.alpha_3 if rec else None | |
| except Exception: | |
| return None | |
def load_rows_streaming():
    """Yield at most ``MAX_ROWS`` rows from the JSON lines file at ``VISITS_URL``.

    The file is opened through ``datasets.load_dataset`` with
    ``streaming=True`` and rows are yielded one by one as they are read.

    Yields:
        One row per record, in stream order, as produced by the dataset
        iterator.
    """
    # A non-positive cap means "read nothing": return before opening the
    # remote stream instead of emitting one row and only then checking the cap.
    if MAX_ROWS <= 0:
        return

    ds = load_dataset(
        "json",
        data_files=VISITS_URL,
        split="train",
        streaming=True,
    )
    for rows_seen, row in enumerate(ds, start=1):
        yield row
        # Stop right after the last allowed row so no extra record is pulled
        # from the stream.
        if rows_seen >= MAX_ROWS:
            break
def load_geojson(path: str) -> dict:
    """Parse the UTF-8 encoded JSON file at *path* and return its content."""
    with open(path, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
# Values that mark an ISO code as absent in the GeoJSON properties.
_MISSING_ISO_CODES = ("", "-99")


def _resolve_country_record(iso2: str, iso3: str, name: str):
    """Return the pycountry record for a feature, or ``None`` if unresolved.

    A usable sibling code is tried first (alpha-3, then alpha-2) because an
    exact code lookup cannot pick the wrong country; the fuzzy name search is
    only the fallback. May raise whatever the fuzzy search raises when nothing
    matches.
    """
    for field, code in (("alpha_3", iso3), ("alpha_2", iso2)):
        if code in _MISSING_ISO_CODES:
            continue
        try:
            rec = pycountry.countries.get(**{field: code.upper()})
        except LookupError:
            # Treat a raising lookup the same as "no record for this code".
            rec = None
        if rec:
            return rec
    if not name:
        return None
    return pycountry.countries.search_fuzzy(name)[0]


def patch_geojson_iso_codes(countries_geojson: dict) -> dict:
    """
    Some GeoJSON files have ISO codes as '-99'. Patch them in place.

    A feature is patched when either code is empty or '-99'. The country is
    resolved from the other, still valid, code when there is one, and from a
    fuzzy search on the 'name' field otherwise. Features that cannot be
    resolved are left untouched.

    Updates:
        properties['ISO3166-1-Alpha-2']
        properties['ISO3166-1-Alpha-3']

    Args:
        countries_geojson: Parsed GeoJSON mapping; a missing or null
            "features" entry is treated as "no features".

    Returns:
        The same ``countries_geojson`` object (mutated in place).
    """
    fixed = 0
    for feat in countries_geojson.get("features") or []:
        props = feat.get("properties")
        if not isinstance(props, dict):
            # Nothing to patch without a properties mapping.
            continue

        iso3 = str(props.get("ISO3166-1-Alpha-3", "") or "").strip()
        iso2 = str(props.get("ISO3166-1-Alpha-2", "") or "").strip()
        name = str(props.get("name", "") or "").strip()

        if iso3 not in _MISSING_ISO_CODES and iso2 not in _MISSING_ISO_CODES:
            continue

        try:
            rec = _resolve_country_record(iso2, iso3, name)
            if rec is None:
                continue
            # Read both codes before writing so a failure never leaves the
            # feature half-updated.
            alpha3, alpha2 = rec.alpha_3, rec.alpha_2
        except Exception:
            # Deliberately best-effort: leave as-is if we can't resolve.
            continue

        props["ISO3166-1-Alpha-3"] = alpha3
        props["ISO3166-1-Alpha-2"] = alpha2
        fixed += 1

    print(f"DEBUG patched GeoJSON features: {fixed}")
    return countries_geojson
| # ========================= | |
| # Main report builder | |
| # ========================= | |
def _sorted_counts_frame(records: list, columns: list) -> pd.DataFrame:
    """Build a frame from *records* sorted by "usage events", descending.

    *columns* is passed explicitly so the expected columns exist even when
    *records* is empty; without it, sorting an empty frame fails on the
    missing "usage events" column.
    """
    return (
        pd.DataFrame(records, columns=columns)
        .sort_values("usage events", ascending=False)
        .reset_index(drop=True)
    )


def build_report():
    """Aggregate usage events by country and build the map, table and summary.

    Rows come from ``load_rows_streaming()``. Rows whose ``event`` is
    ``session_start`` are skipped. Every remaining row with a usable
    ``final_country`` is counted for the table; it is additionally counted
    for the map when ``final_country_code`` converts to an ISO-3 code.

    Returns:
        A ``(figure, table, summary)`` tuple:
        * figure: a Mapbox choropleth coloured by log10 of the event count,
          or an empty titled scatter figure when no row is mappable;
        * table: the top 50 countries by usage events (with a
          ``refreshed_at_utc`` column when the map could be built);
        * summary: multi-line text reconciling scanned rows against the
          skipped / missing / invalid / mapped counters.
    """
    countries_geojson = patch_geojson_iso_codes(load_geojson(GEOJSON_PATH))

    # Counters for clean reconciliation
    scanned = 0
    skipped_session_start = 0
    missing_country = 0
    invalid_country_code = 0

    # Table is keyed by country name, map by ISO-3 code.
    country_counts = Counter()
    iso3_counts = Counter()
    iso3_to_name = {}

    for row in load_rows_streaming():
        scanned += 1

        event_type = str(row.get("event", "") or "").strip().lower()
        if event_type == "session_start":
            skipped_session_start += 1
            continue

        country = normalize_country_name(row.get("final_country"))
        if not country:
            missing_country += 1
            continue

        # Count it for the table FIRST (all usage events with a valid country name)
        country_counts[country] += 1

        iso3 = iso2_to_iso3(row.get("final_country_code"))
        if not iso3:
            invalid_country_code += 1
            continue

        # Count it for the map only (requires ISO3)
        iso3_counts[iso3] += 1
        # The first country name seen for a code is the one shown on hover.
        iso3_to_name.setdefault(iso3, country)

    table_df = _sorted_counts_frame(
        [{"country": name, "usage events": cnt} for name, cnt in country_counts.items()],
        ["country", "usage events"],
    )
    map_df = _sorted_counts_frame(
        [
            {"iso3": code, "country": iso3_to_name.get(code, code), "usage events": cnt}
            for code, cnt in iso3_counts.items()
        ],
        ["iso3", "country", "usage events"],
    )

    # Reconciliation: every scanned row lands in exactly one of the four
    # buckets summed into `accounted`.
    mappable_rows_count = int(sum(iso3_counts.values()))
    table_rows_counted = int(sum(country_counts.values()))
    accounted = skipped_session_start + missing_country + invalid_country_code + mappable_rows_count
    total_usage_events = int(table_df["usage events"].sum()) if len(table_df) else 0

    # Nothing to draw: return a placeholder figure instead of a choropleth.
    if map_df.empty:
        fig = px.scatter(title="No mappable data found")
        fig.update_layout(height=740, margin=dict(l=0, r=0, t=40, b=0))
        summary = (
            f"Rows scanned: {scanned:,}\n"
            f"- Rows counted in table: {table_rows_counted:,}\n"
            f"- Rows mapped: {mappable_rows_count:,}\n"
            f"- Session starts skipped: {skipped_session_start:,}\n"
            f"- Missing country: {missing_country:,}\n"
            f"- Invalid country code: {invalid_country_code:,}\n\n"
            f"Accounted rows: {accounted:,} / {scanned:,}\n"
            f"Countries (table): {len(table_df):,}\n"
            f"Total usage events: {total_usage_events:,}"
        )
        return fig, table_df.head(50), summary

    # Log scale for nicer color spread (keeps small countries visible)
    map_df["usage_log"] = map_df["usage events"].clip(lower=1).apply(math.log10)

    # Only register a token when one is configured; the "open-street-map"
    # style used below does not require one, and a blank token is not a
    # valid value to hand to the figure layout.
    if MAPBOX_TOKEN:
        px.set_mapbox_access_token(MAPBOX_TOKEN)

    map_df["iso3"] = map_df["iso3"].astype(str).str.upper()

    # Mapbox choropleth using GeoJSON
    fig = px.choropleth_mapbox(
        map_df,
        geojson=countries_geojson,
        featureidkey=GEOJSON_FEATURE_ID_KEY,
        locations="iso3",
        color="usage_log",
        hover_name="country",
        hover_data={"usage events": True, "iso3": True},
        labels={"usage_log": "Usage intensity (log10)", "usage events": "Usage events"},
        mapbox_style="open-street-map",
        opacity=0.75,
        zoom=1.5,
        center={"lat": 15, "lon": 0},
    )
    fig.update_traces(
        # Use a clean hover card
        hovertemplate=(
            "<b>%{customdata[0]}</b><br>"
            "Usage events: %{customdata[1]:,}<br>"
            "<extra></extra>"
        ),
        # customdata lets us show real counts even though color is log-scaled
        customdata=map_df[["country", "usage events"]].to_numpy(),
    )
    fig.update_traces(
        marker_line_width=0.8,
        marker_line_color="rgba(255,255,255,0.85)",  # nice on light basemaps
    )

    # Full-bleed layout
    fig.update_layout(
        height=740,
        margin=dict(l=0, r=0, t=0, b=0),
    )

    # Dashboard title
    fig.add_annotation(
        text="Usage Events by Country",
        x=0.01,
        y=0.95,
        xref="paper",
        yref="paper",
        xanchor="left",
        yanchor="top",
        showarrow=False,
        font=dict(size=20),
    )
    fig.update_layout(coloraxis_showscale=False)

    # Summary text (clean math)
    summary = (
        f"Rows scanned: {scanned:,}\n"
        f"- Session starts skipped: {skipped_session_start:,}\n"
        f"- Missing country: {missing_country:,}\n"
        f"- Invalid country code: {invalid_country_code:,}\n"
        f"- Rows mapped: {mappable_rows_count:,}\n\n"
        f"Accounted rows: {accounted:,} / {scanned:,}\n"
        f"Countries (table): {len(table_df):,}\n"
        f"Countries (map): {map_df['iso3'].nunique():,}\n"
        f"Total usage events: {total_usage_events:,}"
    )

    table_out = table_df.head(50).copy()
    # Timezone-aware "now" replaces the deprecated Timestamp.utcnow().
    refreshed_at = pd.Timestamp.now(tz="UTC").strftime("%Y-%m-%d %H:%M:%S")
    table_out.insert(0, "refreshed_at_utc", refreshed_at)
    return fig, table_out, summary
| # ========================= | |
| # UI | |
| # ========================= | |
# Page layout: intro text, trigger button, then the three output widgets
# (summary text, map plot, top-countries table) in display order.
with gr.Blocks(title="AI Recruiting Agent — Usage Map") as demo:
    gr.Markdown(
        value=(
            "# AI Recruiting Agent — Usage by Country (Mapbox)\n"
            "This Space reads **only** `visits_enriched.jsonl`, excludes `event=session_start`, "
            "and plots **usage events** by country.\n\n"
        )
    )

    run = gr.Button(value="Generate map")
    summary = gr.Markdown()
    plot = gr.Plot()
    table = gr.Dataframe(interactive=False, label="Top countries")

    # build_report returns (figure, table, summary); the outputs list follows
    # that order rather than the on-page order.
    run.click(build_report, inputs=[], outputs=[plot, table, summary])

# Start the app as soon as the module is executed.
demo.launch()