import json
import os
import random
from pathlib import Path
from urllib.parse import urlparse
from urllib.request import urlopen
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
try:
import boto3
except ImportError:
boto3 = None
# -------------------------------------------------------------------
# Load data
# -------------------------------------------------------------------
# Directory that contains this script; relative data paths are resolved
# against it.
APP_DIR = Path(__file__).resolve().parent

# File name of the aggregated export the app looks for by default.
DEFAULT_DATA_FILENAME = (
    "content-item-classification-base-multilingual_v1-0-0_aggregated_for_exploration.json"
)

# Default local copy (next to this script) and default remote copy (S3).
DEFAULT_LOCAL_DATA_SOURCE = APP_DIR.joinpath(DEFAULT_DATA_FILENAME)
DEFAULT_REMOTE_DATA_SOURCE = (
    "s3://140-processed-data-sandbox/content-item-classification/"
    "content-item-classification-base-multilingual_v1-0-0_aggregated_for_exploration.json"
)

# Resolution order: a non-empty DATA_SOURCE environment variable wins, then
# the local file if it exists, and finally the remote S3 object.
DATA_SOURCE = os.environ.get("DATA_SOURCE")
if not DATA_SOURCE:
    if DEFAULT_LOCAL_DATA_SOURCE.exists():
        DATA_SOURCE = str(DEFAULT_LOCAL_DATA_SOURCE)
    else:
        DATA_SOURCE = DEFAULT_REMOTE_DATA_SOURCE
def format_source_label(source: str) -> str:
    """Return a short, display-friendly form of a data-source string.

    Remote locations (``s3://``, ``http://``, ``https://``) and relative
    paths are returned unchanged.  An absolute path is shown relative to
    ``APP_DIR`` (with forward slashes) when it lies inside that directory,
    and as the plain absolute path otherwise.
    """
    is_remote = source.startswith(("s3://", "http://", "https://"))
    if is_remote:
        return source

    candidate = Path(source)
    if candidate.is_absolute():
        try:
            inside_app = candidate.relative_to(APP_DIR)
        except ValueError:
            # The path is not located under APP_DIR.
            return str(candidate)
        return inside_app.as_posix()
    return source
# Display form of DATA_SOURCE; interpolated into the introductory Markdown
# of the UI.
SOURCE_LABEL = format_source_label(DATA_SOURCE)
def load_json_from_s3(source: str):
    """Download an ``s3://bucket/key`` object and parse it as UTF-8 JSON.

    Optional configuration is read from the environment; for each setting
    the first non-empty variable is used:

    * profile:  ``AWS_PROFILE``, then ``S3_PROFILE``
    * endpoint: ``AWS_ENDPOINT_URL``, then ``S3_ENDPOINT_URL``
    * region:   ``AWS_REGION``, then ``S3_REGION``

    Raises:
        ImportError: if ``boto3`` could not be imported.
        ValueError: if ``source`` lacks a bucket or a key.
    """
    if boto3 is None:
        raise ImportError("boto3 is required to load data from S3.")

    location = urlparse(source)
    bucket, key = location.netloc, location.path.lstrip("/")
    if not (bucket and key):
        raise ValueError(f"Invalid S3 path: {source}")

    def _first_env(*names):
        # Value of the first environment variable that is set and non-empty.
        for name in names:
            value = os.environ.get(name)
            if value:
                return value
        return None

    session_options = {}
    profile_name = _first_env("AWS_PROFILE", "S3_PROFILE")
    if profile_name:
        session_options["profile_name"] = profile_name

    client_options = {}
    endpoint = _first_env("AWS_ENDPOINT_URL", "S3_ENDPOINT_URL")
    if endpoint:
        client_options["endpoint_url"] = endpoint
    region_name = _first_env("AWS_REGION", "S3_REGION")
    if region_name:
        client_options["region_name"] = region_name

    s3 = boto3.Session(**session_options).client("s3", **client_options)
    payload = s3.get_object(Bucket=bucket, Key=key)["Body"].read()
    return json.loads(payload.decode("utf-8"))
def load_data(source: str):
    """Load and parse the JSON document designated by ``source``.

    ``s3://`` sources are delegated to ``load_json_from_s3``; ``http://`` and
    ``https://`` sources are fetched with ``urlopen``; anything else is read
    as a UTF-8 file, with relative paths resolved against ``APP_DIR``.
    """
    if source.startswith("s3://"):
        return load_json_from_s3(source)

    if source.startswith("http://") or source.startswith("https://"):
        with urlopen(source) as remote:
            document = json.load(remote)
        return document

    local = Path(source)
    target = local if local.is_absolute() else APP_DIR / local
    with target.open(encoding="utf-8") as stream:
        document = json.load(stream)
    return document
# Load the raw payload.  It may be a bare list of rows, or an object that
# stores the rows under "rows", "data" or "items" (first non-empty one wins).
data = load_data(DATA_SOURCE)
if isinstance(data, list):
    raw_rows = data
elif isinstance(data, dict):
    raw_rows = data.get("rows") or data.get("data") or data.get("items") or []
else:
    raw_rows = []

# Normalise each usable row into a flat record.  Rows that are not objects,
# that lack a provider or newspaper alias, or whose year is not an integer
# are skipped.
rows = []
for row in raw_rows:
    if not isinstance(row, dict):
        continue

    # Treat a JSON null like a missing alias; str(None) would otherwise
    # produce the bogus alias "None".
    provider_alias = row.get("provider_alias")
    newspaper_alias = row.get("newspaper_alias")
    provider = "" if provider_alias is None else str(provider_alias).strip()
    newspaper = "" if newspaper_alias is None else str(newspaper_alias).strip()
    if not provider or not newspaper:
        continue

    try:
        year = int(row["year"])
    except (KeyError, TypeError, ValueError, OverflowError):
        continue

    ad_count = int(row.get("ad_count", 0) or 0)
    non_ad_count = int(row.get("non_ad_count", 0) or 0)

    # Fall back to derived values when the export omits total or share.
    total_count = row.get("total_count")
    total_count = int(total_count) if total_count is not None else ad_count + non_ad_count
    ad_share = row.get("ad_share")
    if ad_share is None:
        ad_share = ad_count / total_count if total_count else 0.0

    rows.append(
        {
            "provider": provider,
            "provider_name": row.get("provider_name") or provider,
            "newspaper": newspaper,
            "newspaper_title": row.get("newspaper_title") or newspaper,
            "year": year,
            "ad_count": ad_count,
            "non_ad_count": non_ad_count,
            "total_count": total_count,
            "ad_share": float(ad_share),
            "issue_count": row.get("issue_count"),
        }
    )

# Check for emptiness *before* sorting: a frame built from an empty list has
# no columns, so sort_values would raise KeyError and mask this error.
if not rows:
    raise ValueError("No yearly ad-classification data found.")

df = pd.DataFrame(rows).sort_values(["provider", "newspaper", "year"])
# Keep only rows whose total_count is positive.
df = df[df["total_count"] > 0].copy()
# -------------------------------------------------------------------
# Labels
# -------------------------------------------------------------------
def _first_seen_mapping(key_column, value_column):
    """Map each distinct ``key_column`` value in ``df`` to the
    ``value_column`` value of the first row in which that key occurs."""
    pairs = df.loc[:, [key_column, value_column]]
    first_rows = pairs.drop_duplicates(subset=key_column)
    return first_rows.set_index(key_column)[value_column].to_dict()


# newspaper alias -> newspaper title
media_title_map = _first_seen_mapping("newspaper", "newspaper_title")
# provider alias -> provider name
provider_name_map = _first_seen_mapping("provider", "provider_name")
def newspaper_label(alias: str) -> str:
    """Return the display label for a newspaper alias.

    The label is ``"<title> [<alias>]"``; when no distinct title is known
    (the title equals the alias) only the alias is returned.
    """
    key = alias.strip()
    title = media_title_map.get(key, key)
    if title == key:
        return key
    return f"{title} [{key}]"
def provider_label(alias: str) -> str:
    """Return the display label ``"<name> [<alias>]"`` for a provider alias.

    A trailing ``"(<alias>)"`` in the provider name is removed first so the
    alias is not shown twice.  Unknown aliases use the alias as the name.
    """
    key = alias.strip()
    display_name = provider_name_map.get(key, key)
    redundant_tail = f"({key})"
    if isinstance(display_name, str) and display_name.endswith(redundant_tail):
        keep = len(display_name) - len(redundant_tail)
        display_name = display_name[:keep].strip()
    return f"{display_name} [{key}]"
# (label, value) pairs for the provider dropdown: the "All" entry first,
# followed by every provider in the data sorted by display label.
provider_options = [("All", "All")]
provider_options.extend(
    sorted(
        ((provider_label(code), code) for code in df["provider"].dropna().unique()),
        key=lambda pair: pair[0],
    )
)
# -------------------------------------------------------------------
# Rankings
# -------------------------------------------------------------------
def _aggregate_ad_share(group_keys):
    """Sum the count columns of ``df`` per ``group_keys`` and add a
    ``mean_ad_share`` column (summed ads divided by summed total)."""
    count_columns = ["ad_count", "non_ad_count", "total_count"]
    totals = df.groupby(group_keys, as_index=False)[count_columns].sum()
    # Replace non-positive totals by 1 so the division cannot hit zero.
    denominator = totals["total_count"].where(totals["total_count"] > 0, 1)
    totals["mean_ad_share"] = totals["ad_count"] / denominator
    return totals


# One row per (provider, newspaper) pair.
ranking_by_provider = _aggregate_ad_share(["provider", "newspaper"])
# One row per newspaper, pooled over all providers.
ranking_global = _aggregate_ad_share("newspaper")
def get_ranked_df(provider="All", query=""):
    """Return newspapers ranked by mean ad share, optionally filtered.

    Args:
        provider: Provider alias, or ``"All"`` to rank across all providers.
        query: Case-sensitive substring.  A newspaper is kept when the
            substring occurs in its alias or in its title.  Empty or
            whitespace-only queries keep everything.

    Returns:
        A DataFrame sorted by ``mean_ad_share`` (descending), ties broken by
        ``newspaper`` (ascending), with a fresh 0..n-1 index.  For ``"All"``
        it carries every column of ``ranking_global``; for a single provider
        only ``newspaper`` and ``mean_ad_share``.
    """
    if provider == "All":
        ranked = ranking_global.copy()
    else:
        ranked = ranking_by_provider.loc[
            ranking_by_provider["provider"] == provider,
            ["newspaper", "mean_ad_share"],
        ].copy()

    ranked = ranked.sort_values(
        ["mean_ad_share", "newspaper"], ascending=[False, True]
    ).reset_index(drop=True)

    needle = (query or "").strip()
    if not needle:
        return ranked

    def _matches(alias: str) -> bool:
        if needle in alias:
            return True
        # str() guards against a title that is not a string.
        return needle in str(media_title_map.get(alias.strip(), ""))

    # Force a real boolean mask and select rows through .loc.  On an empty
    # ranking the mapped Series has object dtype, and ``ranked[mask]`` would
    # then be interpreted as a column selection and drop every column.
    mask = ranked["newspaper"].map(_matches).astype(bool)
    return ranked.loc[mask].reset_index(drop=True)
def choose_newspapers(ranked, n_best, n_worst, n_random, seed=13):
    """Pick a default selection of newspapers from a ranking.

    Args:
        ranked: Object whose ``["newspaper"]`` column yields the names in
            rank order (best first) via ``.tolist()``.
        n_best: Number of names to take from the top of the ranking.
        n_worst: Number of names to take from the bottom of the ranking.
        n_random: Number of additional names sampled from those not already
            taken; capped at the number of remaining names.
        seed: Seed of the private random generator, so the same inputs
            always give the same sample.

    Returns:
        ``(choices, selected)``: all ranked names, and the selection in the
        order best, worst, random with duplicates removed.
    """
    ranked_names = ranked["newspaper"].tolist()

    # Truncate to whole, non-negative counts before using them.  Testing the
    # raw value would let 0 < n_worst < 1 through as ``ranked_names[-0:]``,
    # which is the whole list.
    best_count = max(int(n_best), 0)
    worst_count = max(int(n_worst), 0)
    best = ranked_names[:best_count]
    worst = ranked_names[-worst_count:] if worst_count else []

    # Build the exclusion set once instead of once per candidate.
    already_taken = set(best) | set(worst)
    remaining_for_random = [n for n in ranked_names if n not in already_taken]

    rng = random.Random(seed)
    random_count = min(int(n_random), len(remaining_for_random))
    random_pick = rng.sample(remaining_for_random, random_count) if random_count > 0 else []

    # dict.fromkeys removes duplicates (best and worst may overlap) while
    # keeping first-seen order.
    selected = list(dict.fromkeys(best + worst + random_pick))
    return ranked_names, selected
def update_newspapers(provider, query, n_best, n_worst, n_random):
    """Recompute the newspaper dropdown for the current selector values.

    Returns a ``gr.update`` carrying the filtered, ranked choices as
    ``(label, alias)`` pairs together with the default selection.
    """
    choices, selected = choose_newspapers(
        get_ranked_df(provider, query), n_best, n_worst, n_random
    )
    return gr.update(
        choices=[(newspaper_label(name), name) for name in choices],
        value=selected,
    )
def _empty_figure(title):
    """Return a trace-less figure with the standard axes, theme and size."""
    fig = go.Figure()
    fig.update_layout(
        title=title,
        xaxis_title="Year",
        yaxis_title="Ad share",
        yaxis=dict(range=[0, 1.05]),
        template="plotly_white",
        height=650,
    )
    return fig


def make_plot(provider, selected_newspapers):
    """Build the yearly ad-share scatter plot for the selected newspapers.

    Args:
        provider: Provider alias, or ``"All"`` for every provider.
        selected_newspapers: Newspaper aliases to draw.

    Returns:
        A Plotly figure with one marker trace per newspaper, added in
        ranking order (highest mean ad share first).  A placeholder figure
        with an explanatory title is returned when nothing is selected or
        the selection has no data.
    """
    if not selected_newspapers:
        return _empty_figure("Select one or more newspapers")

    # Boolean filtering already produces new frames, so no copy is needed.
    in_scope = df if provider == "All" else df[df["provider"] == provider]
    subset = in_scope[in_scope["newspaper"].isin(selected_newspapers)]
    if subset.empty:
        return _empty_figure("No data for the current selection")

    wanted = set(selected_newspapers)
    ranked = get_ranked_df(provider, "")
    ranked_order = [name for name in ranked["newspaper"].tolist() if name in wanted]

    # Split the subset once instead of re-filtering it for every newspaper.
    rows_by_newspaper = {
        name: group for name, group in subset.groupby("newspaper", sort=False)
    }

    fig = go.Figure()
    for alias in ranked_order:
        dfn = rows_by_newspaper.get(alias)
        if dfn is None:
            continue
        dfn = dfn.sort_values("year")
        fig.add_trace(
            go.Scatter(
                x=dfn["year"],
                y=dfn["ad_share"],
                mode="markers",
                name=newspaper_label(alias),
                customdata=dfn[["ad_count", "non_ad_count", "total_count"]].values,
                # "<br>" separates hover lines; the empty "<extra></extra>"
                # hides the secondary hover box.
                hovertemplate=(
                    "%{fullData.name}<br>"
                    "Year: %{x}<br>"
                    "Ad share: %{y:.1%}<br>"
                    "Ads: %{customdata[0]}<br>"
                    "Non-ads: %{customdata[1]}<br>"
                    "Articles: %{customdata[2]}"
                    "<extra></extra>"
                ),
            )
        )

    # Widen spans shorter than ten years to a window of about ten years
    # around the midpoint of the data.
    year_min = subset["year"].min()
    year_max = subset["year"].max()
    if year_max - year_min < 10:
        mid = (year_min + year_max) / 2
        year_min = int(mid - 5)
        year_max = int(mid + 5)

    provider_display = provider if provider == "All" else provider_label(provider)
    fig.update_layout(
        title=f"Ad share by newspaper — provider: {provider_display}",
        xaxis_title="Year",
        xaxis=dict(range=[year_min - 1, year_max + 1]),
        yaxis_title="Ad share",
        yaxis=dict(range=[0, 1.05]),
        template="plotly_white",
        height=650,
    )
    return fig
# -------------------------------------------------------------------
# Initial state
# -------------------------------------------------------------------
# Selector values shown when the page first loads.
initial_provider, initial_query = "All", ""
initial_best, initial_worst, initial_random = 10, 0, 0

# Newspaper choices and default selection that correspond to those values.
initial_ranked = get_ranked_df(initial_provider, initial_query)
initial_choices, initial_selected = choose_newspapers(
    initial_ranked,
    n_best=initial_best,
    n_worst=initial_worst,
    n_random=initial_random,
)
# -------------------------------------------------------------------
# UI
# -------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Ad classification exploration")
    gr.Markdown(
        "Explore yearly ad-share distributions by provider and newspaper. "
        f"Source: `{SOURCE_LABEL}`"
    )

    with gr.Row():
        provider = gr.Dropdown(
            choices=provider_options,
            value=initial_provider,
            label="Provider",
        )
        query = gr.Textbox(
            value=initial_query,
            label="Filter newspapers (case-sensitive)",
            placeholder="Type part of a newspaper title",
        )

    with gr.Row():
        n_best = gr.Slider(
            minimum=0,
            maximum=400,
            value=initial_best,
            step=1,
            label="Highest ad share",
        )
        n_worst = gr.Slider(
            minimum=0,
            maximum=400,
            value=initial_worst,
            step=1,
            label="Lowest ad share",
        )
        n_random = gr.Slider(
            minimum=0,
            maximum=400,
            value=initial_random,
            step=1,
            label="Random",
        )

    newspaper = gr.Dropdown(
        choices=[(newspaper_label(n), n) for n in initial_choices],
        value=initial_selected,
        multiselect=True,
        label="Newspapers (filtered and ranked)",
    )
    plot = gr.Plot()

    selector_inputs = [provider, query, n_best, n_worst, n_random]
    plot_inputs = [provider, newspaper]

    # Any selector change first refreshes the newspaper dropdown and only
    # then redraws the plot.  Chaining with .then() makes the plot read the
    # refreshed selection; two independent change listeners would let the
    # plot be drawn from the selection that was current before the refresh.
    for trigger in selector_inputs:
        trigger.change(
            fn=update_newspapers,
            inputs=selector_inputs,
            outputs=newspaper,
        ).then(
            fn=make_plot,
            inputs=plot_inputs,
            outputs=plot,
        )

    # Manual edits of the newspaper selection redraw the plot directly.
    newspaper.change(
        fn=make_plot,
        inputs=plot_inputs,
        outputs=plot,
    )

    # Draw the initial plot when the page loads.
    demo.load(
        fn=make_plot,
        inputs=plot_inputs,
        outputs=plot,
    )

demo.launch()