import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# ============================================================
# Configuration / constants
# ============================================================
# Default trace colors, keyed by exposure level. Callers may override any
# entry through the ``line_colors`` argument of ``create_exposure_plot``.
DEFAULT_LINE_COLORS: dict[int, str] = {
    1: "#1f77b4",  # blue
    2: "#d62728",  # red
    3: "#2ca02c",  # green
    4: "#9467bd",  # purple
    5: "#ff7f0e",  # orange
}
# Plotly hover templates. Lines are separated with "<br>" because Plotly
# renders hover text as pseudo-HTML; a raw newline inside a quoted literal is
# a syntax error. ``customdata[0]`` is the age group and ``customdata[1]`` the
# exposure level, as attached to each trace by ``create_exposure_plot``.

# Used when the plotted value is an index (one decimal place).
HOVER_TEMPLATE_INDEX = (
    "Age: %{customdata[0]}<br>"
    "Exposure Level: Level %{customdata[1]} (1 = least exposed, 5 = most exposed)<br>"
    "Year: %{x}<br>"
    "Index: %{y:.1f}"
)

# Used when the plotted value is a raw count (thousands separator).
HOVER_TEMPLATE_RAW = (
    "Age: %{customdata[0]}<br>"
    "Exposure Level: Level %{customdata[1]} (1 = least exposed, 5 = most exposed)<br>"
    "Year: %{x}<br>"
    "Number of Employed Persons: %{y:,}"
)
# ============================================================
# Helper functions
# ============================================================
def _build_palette(line_colors: dict[int, str] | None) -> dict[int, str]:
    """
    Return the effective level -> color mapping.

    Starts from a copy of ``DEFAULT_LINE_COLORS`` and layers any
    caller-supplied entries on top, so user colors win on key collisions.
    ``None`` or an empty mapping yields a plain copy of the defaults.
    """
    merged = dict(DEFAULT_LINE_COLORS)
    if line_colors:
        merged.update(line_colors)
    return merged
def _resolve_color(
exposure_level: int | str,
palette: dict[int, str],
) -> str | None:
"""
Get color for an exposure level, trying both raw and int-casted keys.
"""
# Try using the level directly (if it's already an int key).
color = palette.get(exposure_level) # type: ignore[arg-type]
if color is not None:
return color
# Fall back to int(exposure_level) if possible.
try:
level_int = int(exposure_level)
return palette.get(level_int)
except (TypeError, ValueError):
return None
# ============================================================
# Main plotting function
# ============================================================
def create_exposure_plot(
    df: pd.DataFrame,
    metric: str,
    metric_label: str,
    weighting_label: str,
    *,
    value_col: str = "employment",
    y_axis_label: str = "Employed Persons",
    is_index: bool = False,
    base_year: int | None = None,
    line_colors: dict[int, str] | None = None,
) -> go.Figure:
    """
    Generate a multi-row subplot figure for AI exposure by age group.

    Rows with a missing age, exposure level or value are dropped. For each
    remaining age group one subplot row is created, and within it one line
    per exposure level showing ``value_col`` summed per year.

    Parameters
    ----------
    df : pd.DataFrame
        Input data with columns 'age', 'year', value_col and
        'daioe_{metric}_exposure_level'.
    metric : str
        Metric name used in the exposure column suffix.
    metric_label : str
        Human-readable label for the metric (for titles).
    weighting_label : str
        Label describing the weighting approach (for titles).
    value_col : str, default "employment"
        Column used for the Y-axis (e.g., counts or indices).
    y_axis_label : str, default "Employed Persons"
        Y-axis title.
    is_index : bool, default False
        If True, use index-style hover text; otherwise value-style.
    base_year : int | None, default None
        Optional vertical reference line and annotation for a base year.
    line_colors : dict[int, str] | None, default None
        Optional mapping of exposure level -> hex color. Overrides defaults.

    Returns
    -------
    go.Figure
        A Plotly Figure with one subplot per age group, or an empty Figure
        when no rows survive the missing-value filter.
    """
    exposure_col = f"daioe_{metric}_exposure_level"

    # ------------------------------------------------------------------
    # 1. Clean and prepare data
    # ------------------------------------------------------------------
    # dropna already returns a new frame and it is never mutated below, so
    # no defensive copy is needed.
    df_clean = df.dropna(subset=["age", exposure_col, value_col])
    age_groups = sorted(df_clean["age"].unique())
    if not age_groups:
        # No valid data to plot
        return go.Figure()

    hover_template = HOVER_TEMPLATE_INDEX if is_index else HOVER_TEMPLATE_RAW
    palette = _build_palette(line_colors)
    n_rows = len(age_groups)

    # ------------------------------------------------------------------
    # 2. Create multi-row subplot scaffolding
    # ------------------------------------------------------------------
    # Titles are rendered by Plotly as pseudo-HTML, so the second line is
    # introduced with "<br>" rather than a newline character.
    subplot_titles = [
        (
            f"Employed Persons Aged {age} Years by AI Exposure Level<br>"
            f"{metric_label} - {weighting_label}"
        )
        for age in age_groups
    ]
    fig = make_subplots(
        rows=n_rows,
        cols=1,
        shared_xaxes=False,
        subplot_titles=subplot_titles,
        vertical_spacing=0.03,
    )

    # ------------------------------------------------------------------
    # 3. Add traces per age group and exposure level
    # ------------------------------------------------------------------
    # Names of levels that already own a legend entry. Each level gets exactly
    # one entry (on its first trace) and all of its traces share a legend
    # group, so a legend click toggles that level in every subplot and levels
    # missing from the first age group still show up in the legend.
    levels_in_legend: set[str] = set()

    for row, age in enumerate(age_groups, start=1):
        df_age = df_clean[df_clean["age"] == age]

        # Aggregate by year and exposure level
        df_plot = df_age.groupby(["year", exposure_col], as_index=False)[
            value_col
        ].sum()

        for exposure_level, sub in df_plot.groupby(exposure_col):
            color = _resolve_color(exposure_level, palette)
            trace_name = f"Level {exposure_level}"
            first_occurrence = trace_name not in levels_in_legend
            levels_in_legend.add(trace_name)

            fig.add_trace(
                go.Scatter(
                    x=sub["year"],
                    y=sub[value_col],
                    mode="lines+markers",
                    line=dict(width=3, color=color),
                    marker=dict(size=9, color=color),
                    name=trace_name,
                    legendgroup=trace_name,
                    showlegend=first_occurrence,
                    hovertemplate=hover_template,
                    # One (age, level) pair per point, read by the hover
                    # template as customdata[0] / customdata[1].
                    customdata=[(age, exposure_level)] * len(sub),
                ),
                row=row,
                col=1,
            )

        # Axes for this row. The grid is enabled here (per row) because
        # layout-level xaxis_/yaxis_ shortcuts only reach the first subplot.
        fig.update_xaxes(
            title_text="Year",
            tickmode="linear",
            dtick=1,
            showgrid=True,
            row=row,
            col=1,
        )
        fig.update_yaxes(
            title_text=y_axis_label,
            tickformat=",",
            rangemode="tozero",
            showgrid=True,
            row=row,
            col=1,
            automargin=True,
        )

    # ------------------------------------------------------------------
    # 4. Global layout tweaks
    # ------------------------------------------------------------------
    # Reserve left margin for an outside-left legend so subplot widths stay
    # consistent.
    base_plot_width = 1000
    left_legend_margin = 260

    # At this point the only annotations are the subplot titles; lift them
    # away from the plot area. Must run before base-year labels are added.
    fig.update_annotations(yshift=36)

    fig.update_layout(
        height=700 * n_rows,
        width=base_plot_width + left_legend_margin,  # preserve plot width
        legend=dict(
            title=dict(
                text=(
                    " Exposure level<br>"
                    " (1 = least exposed, 5 = most exposed) "
                ),
                side="top",
                font=dict(size=13),
            ),
            orientation="v",
            x=-0.1,  # left of the plotting area
            xanchor="right",  # legend sits just outside-left
            y=0.98,
            yanchor="top",
            itemsizing="constant",
            itemwidth=35,  # keeps items compact
            tracegroupgap=6,
            bordercolor="rgba(0,0,0,0.15)",
            borderwidth=1,
            bgcolor="rgba(255,255,255,0.85)",
            font=dict(size=12),
            indentation=10,
            yref="paper",
        ),
        margin=dict(
            t=170,
            l=left_legend_margin,
            r=60,
            b=60,
        ),
        template="plotly_white",
    )

    # ------------------------------------------------------------------
    # 5. Optional base-year line and annotation
    # ------------------------------------------------------------------
    if base_year is not None:
        fig.add_vline(
            x=base_year,
            line_width=2,
            line_dash="dash",
            line_color="black",
            opacity=0.8,
            row="all",
            col=1,
        )
        annotation_text = (
            "Base year 2022 — ChatGPT launch and generative AI takeoff"
            if base_year == 2022
            else f"Base year {base_year} (normalization anchor)"
        )
        for row in range(1, n_rows + 1):
            # Plotly axis ids: the first subplot uses "x"/"y", later ones
            # "x2"/"y2", "x3"/"y3", ... The " domain" suffix makes y a
            # fraction of that subplot's height instead of a data value.
            axis_suffix = "" if row == 1 else str(row)
            fig.add_annotation(
                x=base_year,
                xref=f"x{axis_suffix}",
                y=0.955,  # just below the top edge (1.0) of the plot area
                yref=f"y{axis_suffix} domain",
                text=annotation_text,
                showarrow=False,
                font=dict(color="black", size=11),
                bgcolor="rgba(255,255,255,0.7)",
                yshift=10,  # nudge upward, in pixels
            )

    return fig