# Snapshot of commit 2a60f29 (Maslionok):
# "Use bundled exploration dataset by default"
import json
import os
import random
from pathlib import Path
from urllib.parse import urlparse
from urllib.request import urlopen
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
# boto3 is optional: it is only needed when the data source is an s3:// URI.
# When it is missing, boto3 is set to None and load_json_from_s3 raises an
# ImportError at call time instead of the whole app failing at import.
try:
    import boto3
except ImportError:
    boto3 = None
# -------------------------------------------------------------------
# Load data
# -------------------------------------------------------------------
APP_DIR = Path(__file__).resolve().parent
DEFAULT_DATA_FILENAME = (
    "content-item-classification-base-multilingual_v1-0-0_aggregated_for_exploration.json"
)
DEFAULT_LOCAL_DATA_SOURCE = APP_DIR / DEFAULT_DATA_FILENAME
# The remote copy uses the same filename under the sandbox bucket prefix.
DEFAULT_REMOTE_DATA_SOURCE = (
    "s3://140-processed-data-sandbox/content-item-classification/"
    + DEFAULT_DATA_FILENAME
)


def _default_data_source() -> str:
    """Return the bundled dataset path if it exists, else the S3 URI."""
    if DEFAULT_LOCAL_DATA_SOURCE.exists():
        return str(DEFAULT_LOCAL_DATA_SOURCE)
    return DEFAULT_REMOTE_DATA_SOURCE


# An unset or empty DATA_SOURCE environment variable falls back to the default.
DATA_SOURCE = os.environ.get("DATA_SOURCE") or _default_data_source()
def format_source_label(source: str) -> str:
    """Return a short, human-readable label for a data source.

    Remote sources (``s3://``, ``http://``, ``https://``) and relative paths
    are returned unchanged. Absolute paths inside ``APP_DIR`` are shown
    relative to it (POSIX separators); other absolute paths are returned
    as-is.
    """
    is_remote = source.startswith(("s3://", "http://", "https://"))
    if is_remote or not Path(source).is_absolute():
        return source
    absolute = Path(source)
    try:
        return absolute.relative_to(APP_DIR).as_posix()
    except ValueError:
        # Not located under APP_DIR.
        return str(absolute)
# Display label for the configured data source, shown in the UI header.
SOURCE_LABEL = format_source_label(source=DATA_SOURCE)
def load_json_from_s3(source: str):
    """Download the object at an ``s3://bucket/key`` URI and parse it as JSON.

    Optional environment overrides (the ``AWS_*`` name wins over its
    ``S3_*`` alias):

    * ``AWS_PROFILE`` / ``S3_PROFILE`` -- boto3 session profile name.
    * ``AWS_ENDPOINT_URL`` / ``S3_ENDPOINT_URL`` -- custom S3 endpoint.
    * ``AWS_REGION`` / ``S3_REGION`` -- client region.

    Raises:
        ImportError: boto3 is not installed.
        ValueError: the URI is missing a bucket or a key.
    """
    if boto3 is None:
        raise ImportError("boto3 is required to load data from S3.")
    parsed = urlparse(source)
    bucket, key = parsed.netloc, parsed.path.lstrip("/")
    if not (bucket and key):
        raise ValueError(f"Invalid S3 path: {source}")

    def _env(primary, alias):
        return os.environ.get(primary) or os.environ.get(alias)

    profile = _env("AWS_PROFILE", "S3_PROFILE")
    session_kwargs = {"profile_name": profile} if profile else {}

    client_kwargs = {}
    optional_client_settings = (
        ("endpoint_url", _env("AWS_ENDPOINT_URL", "S3_ENDPOINT_URL")),
        ("region_name", _env("AWS_REGION", "S3_REGION")),
    )
    for option, value in optional_client_settings:
        if value:
            client_kwargs[option] = value

    s3_client = boto3.Session(**session_kwargs).client("s3", **client_kwargs)
    payload = s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
    return json.loads(payload.decode("utf-8"))
def load_data(source: str):
    """Load and parse JSON from an S3 URI, an HTTP(S) URL, or a local path.

    Relative local paths are resolved against ``APP_DIR`` (the directory
    containing this file), not the current working directory.
    """
    if source.startswith("s3://"):
        return load_json_from_s3(source)
    if source.startswith(("http://", "https://")):
        with urlopen(source) as remote:
            return json.load(remote)
    local_path = Path(source)
    if not local_path.is_absolute():
        local_path = APP_DIR / local_path
    return json.loads(local_path.read_text(encoding="utf-8"))
def _extract_raw_rows(payload):
    """Return the list of row records held in a loaded JSON payload.

    Accepts a bare list, or a dict wrapping the rows under ``rows``,
    ``data`` or ``items`` (first non-empty key wins). Anything else yields
    an empty list.
    """
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        return (
            payload.get("rows") or payload.get("data") or payload.get("items") or []
        )
    return []


def _clean_alias(value) -> str:
    """Return *value* as a stripped string, treating ``None`` as empty.

    ``str(None)`` would otherwise turn a null alias into the literal "None".
    """
    return "" if value is None else str(value).strip()


def _normalize_row(row):
    """Convert one raw record into the dict shape used for the dataframe.

    Returns ``None`` for unusable records: non-dict entries, missing
    provider/newspaper aliases, or a missing/non-integer ``year``.
    ``total_count`` defaults to ``ad_count + non_ad_count`` and ``ad_share``
    to ``ad_count / total_count`` (0.0 when there are no items).
    """
    if not isinstance(row, dict):
        return None
    provider = _clean_alias(row.get("provider_alias"))
    newspaper = _clean_alias(row.get("newspaper_alias"))
    if not provider or not newspaper:
        return None
    try:
        year = int(row["year"])
    except (KeyError, TypeError, ValueError, OverflowError):
        return None
    ad_count = int(row.get("ad_count", 0) or 0)
    non_ad_count = int(row.get("non_ad_count", 0) or 0)
    total_count = row.get("total_count")
    total_count = (
        int(total_count) if total_count is not None else ad_count + non_ad_count
    )
    ad_share = row.get("ad_share")
    if ad_share is None:
        ad_share = ad_count / total_count if total_count else 0.0
    return {
        "provider": provider,
        "provider_name": row.get("provider_name") or provider,
        "newspaper": newspaper,
        "newspaper_title": row.get("newspaper_title") or newspaper,
        "year": year,
        "ad_count": ad_count,
        "non_ad_count": non_ad_count,
        "total_count": total_count,
        "ad_share": float(ad_share),
        "issue_count": row.get("issue_count"),
    }


data = load_data(DATA_SOURCE)
raw_rows = _extract_raw_rows(data)
rows = [record for record in map(_normalize_row, raw_rows) if record is not None]
# Check before building the frame: an empty DataFrame has no columns, so
# sort_values would raise KeyError and hide this message.
if not rows:
    raise ValueError("No yearly ad-classification data found.")
df = pd.DataFrame(rows).sort_values(["provider", "newspaper", "year"])
df = df[df["total_count"] > 0].copy()
# Everything below (label maps, rankings, plots) assumes at least one row.
if df.empty:
    raise ValueError("No yearly ad-classification data found.")
# -------------------------------------------------------------------
# Labels
# -------------------------------------------------------------------
def _first_value_by_key(key_column, value_column):
    """Map each distinct ``key_column`` value to ``value_column`` of its first row.

    "First" follows ``df``'s current row order (sorted by provider,
    newspaper, year).
    """
    first_rows = df.drop_duplicates(subset=key_column)
    return first_rows.set_index(key_column)[value_column].to_dict()


# newspaper alias -> display title
media_title_map = _first_value_by_key("newspaper", "newspaper_title")
# provider alias -> display name
provider_name_map = _first_value_by_key("provider", "provider_name")
def newspaper_label(alias: str) -> str:
    """Return ``"<title> [<alias>]"``, or just the alias when no distinct title exists."""
    key = alias.strip()
    title = media_title_map.get(key, key)
    if title == key:
        return key
    return f"{title} [{key}]"
def provider_label(alias: str) -> str:
    """Return a display label ``"<name> [<alias>]"`` for a provider alias.

    If the known provider name already ends with ``"(<alias>)"``, that
    suffix is removed so the alias is not shown twice. When no distinct name
    is left (the name is missing, equals the alias, or was only the suffix),
    the bare alias is returned, matching ``newspaper_label``.
    """
    alias = alias.strip()
    name = provider_name_map.get(alias, alias)
    suffix = f"({alias})"
    if isinstance(name, str) and name.endswith(suffix):
        name = name[: -len(suffix)].strip()
    # Avoid redundant "acme [acme]" or empty-name " [acme]" labels.
    if name in ("", alias):
        return alias
    return f"{name} [{alias}]"
# Dropdown options as (label, value) pairs: "All" first, then providers
# sorted alphabetically by their display label.
provider_options = [("All", "All")]
provider_options += sorted(
    ((provider_label(alias), alias) for alias in df["provider"].dropna().unique()),
    key=lambda option: option[0],
)
# -------------------------------------------------------------------
# Rankings
# -------------------------------------------------------------------
def _aggregate_ad_share(group_keys):
    """Sum item counts per group and add a pooled ``mean_ad_share`` column.

    ``mean_ad_share`` is total ads / total items over all years of the
    group; a zero total is replaced by 1 to avoid division by zero.
    """
    totals = df.groupby(group_keys, as_index=False)[
        ["ad_count", "non_ad_count", "total_count"]
    ].sum()
    safe_total = totals["total_count"].where(totals["total_count"] > 0, 1)
    totals["mean_ad_share"] = totals["ad_count"] / safe_total
    return totals


# Per (provider, newspaper) pair, and per newspaper across all providers.
ranking_by_provider = _aggregate_ad_share(["provider", "newspaper"])
ranking_global = _aggregate_ad_share("newspaper")
def get_ranked_df(provider="All", query=""):
    """Return newspapers ranked by pooled ad share, highest first.

    Ties are broken by newspaper alias (ascending). For ``provider == "All"``
    the global ranking (all count columns) is used; otherwise only the
    ``newspaper`` and ``mean_ad_share`` columns for that provider.

    ``query`` is stripped and matched case-sensitively as a substring of
    either the newspaper alias or its title; an empty query keeps all rows.
    """
    if provider == "All":
        ranked = ranking_global.copy()
    else:
        belongs = ranking_by_provider["provider"] == provider
        ranked = ranking_by_provider.loc[belongs, ["newspaper", "mean_ad_share"]].copy()

    ranked = ranked.sort_values(by=["mean_ad_share", "newspaper"], ascending=[False, True])
    ranked = ranked.reset_index(drop=True)
    if not query:
        return ranked

    needle = query.strip()

    def _matches(alias: str) -> bool:
        title = media_title_map.get(alias.strip(), "")
        return needle in alias or needle in title

    keep = ranked["newspaper"].apply(_matches)
    return ranked[keep].reset_index(drop=True)
def choose_newspapers(ranked, n_best, n_worst, n_random, seed=13):
    """Pick which newspapers to pre-select from a ranked table.

    Args:
        ranked: DataFrame with a ``newspaper`` column, ordered from highest
            to lowest ad share.
        n_best: number of top-ranked newspapers to select.
        n_worst: number of bottom-ranked newspapers to select.
        n_random: number of extra newspapers drawn at random from those not
            already picked as best/worst (capped at what is available).
        seed: seed for the random draw, so the selection is reproducible.

    Returns:
        ``(choices, selected)``: every ranked newspaper name, and the
        de-duplicated selection (best, then worst, then random picks).
    """
    ranked_names = ranked["newspaper"].tolist()
    # Coerce once so the guards and the slices agree (sliders may deliver
    # floats). Previously 0 < n_worst < 1 passed the "> 0" guard but sliced
    # ``[-0:]``, which selected every newspaper.
    n_best = max(int(n_best), 0)
    n_worst = max(int(n_worst), 0)
    n_random = max(int(n_random), 0)

    best = ranked_names[:n_best]
    worst = ranked_names[-n_worst:] if n_worst else []
    # Build the exclusion set once instead of once per candidate.
    already_picked = set(best) | set(worst)
    remaining = [name for name in ranked_names if name not in already_picked]

    rng = random.Random(seed)
    random_pick = rng.sample(remaining, min(n_random, len(remaining)))

    selected = list(dict.fromkeys(best + worst + random_pick))
    return ranked_names, selected
def update_newspapers(provider, query, n_best, n_worst, n_random):
    """Recompute the newspaper dropdown's labeled choices and default selection."""
    choices, selected = choose_newspapers(
        get_ranked_df(provider, query), n_best, n_worst, n_random
    )
    return gr.update(
        choices=[(newspaper_label(name), name) for name in choices],
        value=selected,
    )
def _empty_figure(title):
    """Return a blank figure with the standard axes, used for placeholder states."""
    fig = go.Figure()
    fig.update_layout(
        title=title,
        xaxis_title="Year",
        yaxis_title="Ad share",
        yaxis=dict(range=[0, 1.05]),
        template="plotly_white",
        height=650,
    )
    return fig


def make_plot(provider, selected_newspapers):
    """Plot yearly ad share for the selected newspapers as one marker trace each.

    Args:
        provider: provider alias, or ``"All"`` to include every provider.
        selected_newspapers: newspaper aliases to plot.

    Returns:
        A plotly Figure. Traces are ordered by the provider's ad-share ranking.
        The x range is padded by one year and widened to at least a decade.
        A placeholder figure is returned when nothing is selected or no data
        matches.
    """
    if not selected_newspapers:
        return _empty_figure("Select one or more newspapers")

    # Only filtering happens below (no mutation), so no defensive copy is needed.
    subset = df if provider == "All" else df[df["provider"] == provider]
    subset = subset[subset["newspaper"].isin(selected_newspapers)]
    if subset.empty:
        return _empty_figure("No data for the current selection")

    # Build the membership set once rather than once per ranked row.
    selected_set = set(selected_newspapers)
    ranked_order = [
        name
        for name in get_ranked_df(provider, "")["newspaper"].tolist()
        if name in selected_set
    ]
    # Split the subset once instead of re-filtering it for every newspaper.
    frames_by_newspaper = {
        name: frame for name, frame in subset.groupby("newspaper")
    }

    fig = go.Figure()
    for newspaper in ranked_order:
        frame = frames_by_newspaper.get(newspaper)
        if frame is None or frame.empty:
            continue
        dfn = frame.sort_values("year")
        fig.add_trace(
            go.Scatter(
                x=dfn["year"],
                y=dfn["ad_share"],
                mode="markers",
                name=newspaper_label(newspaper),
                customdata=dfn[["ad_count", "non_ad_count", "total_count"]].values,
                hovertemplate=(
                    "<b>%{fullData.name}</b><br>"
                    "Year: %{x}<br>"
                    "Ad share: %{y:.1%}<br>"
                    "Ads: %{customdata[0]}<br>"
                    "Non-ads: %{customdata[1]}<br>"
                    "Articles: %{customdata[2]}"
                    "<extra></extra>"
                ),
            )
        )

    year_min = subset["year"].min()
    year_max = subset["year"].max()
    # Keep at least a ten-year window so sparse selections stay readable.
    if year_max - year_min < 10:
        mid = (year_min + year_max) / 2
        year_min = int(mid - 5)
        year_max = int(mid + 5)

    provider_display = provider if provider == "All" else provider_label(provider)
    fig.update_layout(
        title=f"Ad share by newspaper — provider: {provider_display}",
        xaxis_title="Year",
        xaxis=dict(range=[year_min - 1, year_max + 1]),
        yaxis_title="Ad share",
        yaxis=dict(range=[0, 1.05]),
        template="plotly_white",
        height=650,
    )
    return fig
# -------------------------------------------------------------------
# Initial state
# -------------------------------------------------------------------
# Start on the global ranking with no text filter, showing the ten
# newspapers with the highest ad share.
initial_provider, initial_query = "All", ""
initial_best, initial_worst, initial_random = 10, 0, 0
initial_ranked = get_ranked_df(initial_provider, initial_query)
initial_choices, initial_selected = choose_newspapers(
    initial_ranked,
    n_best=initial_best,
    n_worst=initial_worst,
    n_random=initial_random,
)
# -------------------------------------------------------------------
# UI
# -------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Ad classification exploration")
    gr.Markdown(
        "Explore yearly ad-share distributions by provider and newspaper. "
        f"Source: `{SOURCE_LABEL}`"
    )
    with gr.Row():
        provider = gr.Dropdown(
            choices=provider_options,
            value=initial_provider,
            label="Provider",
        )
        query = gr.Textbox(
            value=initial_query,
            label="Filter newspapers (case-sensitive)",
            placeholder="Type part of a newspaper title",
        )
    with gr.Row():
        n_best = gr.Slider(
            minimum=0,
            maximum=400,
            value=initial_best,
            step=1,
            label="Highest ad share",
        )
        n_worst = gr.Slider(
            minimum=0,
            maximum=400,
            value=initial_worst,
            step=1,
            label="Lowest ad share",
        )
        n_random = gr.Slider(
            minimum=0,
            maximum=400,
            value=initial_random,
            step=1,
            label="Random",
        )
    newspaper = gr.Dropdown(
        choices=[(newspaper_label(n), n) for n in initial_choices],
        value=initial_selected,
        multiselect=True,
        label="Newspapers (filtered and ranked)",
    )
    plot = gr.Plot()

    selector_inputs = [provider, query, n_best, n_worst, n_random]
    for trigger in selector_inputs:
        # Refresh the selection first, then redraw with the *updated*
        # selection. A separate parallel handler would read the stale
        # dropdown value and could overwrite the correct plot. The chained
        # redraw also covers provider changes that leave the selection as is.
        trigger.change(
            fn=update_newspapers,
            inputs=selector_inputs,
            outputs=newspaper,
        ).then(
            fn=make_plot,
            inputs=[provider, newspaper],
            outputs=plot,
        )
    newspaper.change(
        fn=make_plot,
        inputs=[provider, newspaper],
        outputs=plot,
    )
    demo.load(
        fn=make_plot,
        inputs=[provider, newspaper],
        outputs=plot,
    )
demo.launch()