simplexuq-code / scripts /make_figures.py
anonymous0523ly's picture
Use data-driven heatmap color ranges
855fe54 verified
raw
history blame
27.1 kB
"""Generate SimplexUQ benchmark figures from saved results."""
import argparse
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
from matplotlib.patches import Rectangle
import numpy as np
# Global matplotlib style shared by every figure in this script: small
# sans-serif fonts sized for paper columns, 300-dpi saves cropped to the
# drawn content ("savefig.bbox": "tight"), no top/right spines, and no grid by
# default (individual plots switch a grid on explicitly with ``ax.grid``).
plt.rcParams.update({
    "font.size": 9,
    "font.family": "sans-serif",
    "font.sans-serif": ["DejaVu Sans", "Arial"],
    "axes.labelsize": 10,
    "axes.titlesize": 10,
    "legend.fontsize": 9,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": False,
})
# Display name for each method key that appears in the result summaries.
METHOD_LABELS = {
    "global": "Global",
    "partition": "Mondrian",
    "twostage": "TwoStage",
    "fullcp": "FullCP",
    "jackknife_plus": "Jackknife+",
    "weighted": "Weighted",
    "oracle": "Oracle",
    "oneshot": "OneShot",
    "trainres": "TrainRes",
}
# One fixed colour per method so a method is recognisable across all figures.
# The hex codes match the Okabe-Ito colour-blind-safe palette, plus a neutral
# grey for OneShot.
METHOD_COLORS = {
    "global": "#D55E00",
    "partition": "#0072B2",
    "twostage": "#009E73",
    "fullcp": "#56B4E9",
    "jackknife_plus": "#CC79A7",
    "weighted": "#F0E442",
    "oracle": "#000000",
    "oneshot": "#7F7F7F",
    "trainres": "#E69F00",
}
# Canonical method order: it fixes the synthetic heatmap rows, and
# ``available_methods`` filters a summary's methods into this order.
METHOD_ORDER = [
    "global",
    "partition",
    "twostage",
    "fullcp",
    "jackknife_plus",
    "oneshot",
    "trainres",
    "weighted",
    "oracle",
]
# Repository root: the parent of the directory containing this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# ``save_figure`` also writes a copy of every figure into this directory.
PAPER_FIG_DIR = REPO_ROOT / "paper" / "rewrite_2026" / "latex" / "figures"
# Synthetic regimes as (result-file stem, two-line tick label). ``load_suite``
# reads ``<stem>.json``; the list order fixes the heatmap column order.
DGP_SPECS = [
    ("d1_homogeneous", "D1\nHom."),
    ("d2_pure_scale", "D2\nScale"),
    ("d3_discrete_groups_aligned", "D3\nDiscrete"),
    ("d4_model_bias", "D4\nBias"),
    ("d5_heavy_tail", "D5\nTail"),
    ("d6_high_k", "D6$^{\\dagger}$\nHigh-K"),
]
# Optional auxiliary result files per synthetic regime. When present, their
# summaries are merged over the main summary (later files win on duplicate
# method keys).
SYNTH_EXTRA_FILES = {
    "d1_homogeneous": ["d1_homogeneous_exact.json"],
    "d3_discrete_groups_aligned": ["d3_discrete_groups_aux.json"],
    "d5_heavy_tail": ["d5_heavy_tail_aux.json"],
    "d6_high_k": ["d6_high_k_aux.json", "d6_high_k_exact_appendix.json"],
}
# Real-data tasks as (result filename, task name); the list order fixes the
# column / panel order of the real-data figures.
REAL_SPECS = [
    ("exp2_2_softmax_cifar10_strata_entropy_fixed.json", "CIFAR-10"),
    ("exp2_3_hyperspectral_samson_nmf_all_methods.json", "Samson"),
    ("exp2_5_topics_K10_all_methods.json", "Topics"),
    ("exp2_6_affective_text.json", "AffectiveText"),
    ("exp2_4_age_ldl_K10_image_knn_main.json", "UTKFace"),
    ("real_bulk_deconv.json", "PBMC"),
]
# Optional auxiliary result files per real task (keyed by task name), merged
# over the main summary the same way as the synthetic extras.
REAL_EXTRA_FILES = {
    "PBMC": [
        "real_bulk_deconv_fullcp.json",
        "real_bulk_deconv_aux.json",
        "real_bulk_deconv_trainres.json",
    ],
    "UTKFace": ["exp2_4_age_ldl_K10_image_knn_fullcp_2k.json"],
}
# Scatter markers for the runtime and real-data tradeoff plots; methods not
# listed here fall back to "o" at the call sites.
REAL_MARKERS = {
    "global": "o",
    "partition": "s",
    "twostage": "^",
    "fullcp": "D",
    "jackknife_plus": "P",
    "trainres": "X",
}
# Line-plot markers for the stratified coverage profiles; methods not listed
# here fall back to "o" at the call sites.
PROFILE_MARKERS = {
    "global": "o",
    "partition": "s",
    "twostage": "^",
    "fullcp": "D",
    "jackknife_plus": "P",
    "oracle": "X",
}
def load_json(path: Path) -> dict:
    """Read and parse the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8, the encoding JSON interchange
    requires, so the result does not depend on the platform's default
    locale encoding.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def save_figure(fig: plt.Figure, output_dir: Path, filename: str) -> None:
    """Save ``fig`` to ``output_dir / filename`` and mirror it into the paper tree.

    Both ``output_dir`` and ``PAPER_FIG_DIR`` are created when missing. If
    ``output_dir`` already resolves to ``PAPER_FIG_DIR``, the figure is written
    only once instead of being saved twice to the same path.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    out = output_dir / filename
    fig.savefig(out)
    print(f"Saved {out}")
    mirror = PAPER_FIG_DIR / filename
    # Avoid re-rendering the same file when the caller targets the paper dir.
    if mirror.resolve() == out.resolve():
        return
    PAPER_FIG_DIR.mkdir(parents=True, exist_ok=True)
    fig.savefig(mirror)
    print(f"Mirrored {mirror}")
def simplex_to_xy(U: np.ndarray) -> np.ndarray:
    """Project barycentric points on the 2-simplex into the plane.

    Each row of ``U`` holds three weights ``(u1, u2, u3)``. The result is the
    matching weighted combination of the equilateral-triangle corners
    ``(0, 0)``, ``(1, 0)`` and ``(1/2, sqrt(3)/2)``, i.e. one ``(x, y)`` pair
    per input row.
    """
    apex_height = np.sqrt(3.0) / 2.0
    corners = np.column_stack(([0.0, 1.0, 0.5], [0.0, 0.0, apex_height]))
    return np.matmul(U, corners)
def extract_summary(data: dict) -> dict:
    """Return the per-method statistics block of a result payload.

    Result files keep their aggregates under either ``"summary"`` or
    ``"aggregated"``; when both exist, ``"summary"`` is preferred.

    Raises:
        KeyError: if neither key is present.
    """
    for candidate in ("summary", "aggregated"):
        if candidate in data:
            return data[candidate]
    raise KeyError("Result file must contain 'summary' or 'aggregated'")
def _metric_stat(summary: dict, method: str, metric: str, stat: str) -> float:
    """Fetch ``summary[method][metric][stat]`` as a Python ``float``."""
    per_metric = summary[method][metric]
    return float(per_metric[stat])


def metric_mean(summary: dict, method: str, metric: str) -> float:
    """Return the ``"mean"`` entry recorded for ``metric`` of ``method``."""
    return _metric_stat(summary, method, metric, "mean")


def metric_std(summary: dict, method: str, metric: str) -> float:
    """Return the ``"std"`` entry recorded for ``metric`` of ``method``."""
    return _metric_stat(summary, method, metric, "std")
def available_methods(summary: dict) -> list[str]:
    """List the methods that appear in ``summary``, ordered by ``METHOD_ORDER``.

    Keys of ``summary`` that are not in ``METHOD_ORDER`` are dropped.
    """
    return list(filter(summary.__contains__, METHOD_ORDER))
def highlight_best_cells(ax, matrix: np.ndarray, methods: list[str], exclude: set[str] | None = None):
exclude = exclude or set()
for col in range(matrix.shape[1]):
best_row = None
best_val = None
for row, method in enumerate(methods):
val = matrix[row, col]
if method in exclude or np.isnan(val):
continue
if best_val is None or val < best_val:
best_val = val
best_row = row
if best_row is None:
continue
ax.add_patch(
Rectangle(
(col - 0.5, best_row - 0.5),
1.0,
1.0,
fill=False,
edgecolor="black",
linewidth=1.5,
)
)
def load_suite(results_dir: Path) -> dict[str, dict]:
    """Load the synthetic-regime results found under ``results_dir``.

    For each ``stem`` in ``DGP_SPECS`` with an existing ``<stem>.json``, the
    returned mapping holds an entry with:

    * ``"summary"``: the main file's summary, with the summaries of any
      existing ``SYNTH_EXTRA_FILES[stem]`` merged on top (later files win on
      duplicate method keys);
    * ``"raw_data"``: the main file's full payload;
    * ``"extra_raw_data"``: present only if an extra file was read, mapping
      each extra filename to its payload.

    Missing main or extra files are skipped without error.
    """
    loaded: dict[str, dict] = {}
    for stem, _label in DGP_SPECS:
        main_path = results_dir / f"{stem}.json"
        if not main_path.exists():
            continue
        payload = load_json(main_path)
        entry = {"summary": extract_summary(payload), "raw_data": payload}
        for extra_name in SYNTH_EXTRA_FILES.get(stem, []):
            extra_path = results_dir / extra_name
            if not extra_path.exists():
                continue
            extra_payload = load_json(extra_path)
            entry["summary"] = {**entry["summary"], **extract_summary(extra_payload)}
            entry.setdefault("extra_raw_data", {})[extra_name] = extra_payload
        loaded[stem] = entry
    return loaded
def load_real_suite(results_dir: Path) -> dict[str, dict]:
    """Load the real-data results found under ``results_dir``, keyed by task name.

    For each ``(filename, task)`` in ``REAL_SPECS`` whose file exists, the
    entry holds:

    * ``"summary"``: the main summary, with the summaries of any existing
      ``REAL_EXTRA_FILES[task]`` merged on top (later files win on duplicate
      method keys);
    * ``"raw_data"``: the main payload;
    * ``"extra_raw_data"``: present only if an extra file was read, mapping
      each extra filename to its payload.

    Missing files are skipped without error.
    """
    by_task: dict[str, dict] = {}
    for filename, task in REAL_SPECS:
        main_path = results_dir / filename
        if not main_path.exists():
            continue
        payload = load_json(main_path)
        record = {"summary": extract_summary(payload), "raw_data": payload}
        for extra_name in REAL_EXTRA_FILES.get(task, ()):
            extra_path = results_dir / extra_name
            if extra_path.exists():
                extra_payload = load_json(extra_path)
                merged_summary = dict(record["summary"])
                merged_summary.update(extract_summary(extra_payload))
                record["summary"] = merged_summary
                record.setdefault("extra_raw_data", {})[extra_name] = extra_payload
        by_task[task] = record
    return by_task
def fig1_allocation_geometry(suite: dict[str, dict], output_dir: Path):
    """Illustrate simplex allocation failure on the smooth-scale synthetic regime.

    Builds three panels from the D2 (``d2_pure_scale``) task file and its
    results:

    A. up to 2,500 randomly sampled simplex inputs ``U`` (fixed seed 2026),
       coloured by the task's ``sigma_true``;
    B. Global CP's mean coverage in each stratum;
    C. per-stratum mean coverage of Global, Mondrian and TwoStage, for
       whichever of them report stratified coverage.

    The dashed reference line is ``1 - alpha``. ``alpha`` is read from the
    result config (``config.evaluation.alpha``, else ``config.alpha``, else
    0.1), so it is consistent with the other D2 coverage panels. Saves
    ``fig1_allocation_geometry.pdf``. Prints a message and returns early when
    the task file, the D2 results or Global's stratified coverage are missing.
    """
    stem = "d2_pure_scale"
    task_path = REPO_ROOT / "release" / "simplextasks-12" / "synthetic" / stem / "task.npz"
    if stem not in suite or not task_path.exists():
        print("Skipping Fig 1 allocation geometry: D2 task or summary missing")
        return
    summary = suite[stem]["summary"]
    if not summary.get("global", {}).get("stratified_coverage"):
        print("Skipping Fig 1 allocation geometry: D2 Global stratified coverage missing")
        return

    # Read the arrays eagerly and close the .npz archive immediately.
    with np.load(task_path) as task:
        U = task["U"]
        sigma_true = task["sigma_true"]
    xy = simplex_to_xy(U)

    config = suite[stem].get("raw_data", {}).get("config", {})
    alpha = config.get("evaluation", {}).get("alpha", config.get("alpha", 0.1))
    target = 1.0 - float(alpha)

    strata_keys = sorted(summary["global"]["stratified_coverage"].keys(), key=int)
    x = np.arange(len(strata_keys))
    tick_labels = [f"S{k}" for k in strata_keys]
    strata_xlabel = rf"Boundary strata ($S{strata_keys[0]} \rightarrow S{strata_keys[-1]}$)"

    rng = np.random.default_rng(2026)
    sample_idx = rng.choice(len(U), size=min(2500, len(U)), replace=False)
    xy_sample = xy[sample_idx]
    sigma_sample = sigma_true[sample_idx]

    fig = plt.figure(figsize=(8.0, 2.95), constrained_layout=True)
    gs = fig.add_gridspec(1, 3, width_ratios=[1.12, 0.92, 1.18])
    ax0 = fig.add_subplot(gs[0, 0])
    ax1 = fig.add_subplot(gs[0, 1])
    ax2 = fig.add_subplot(gs[0, 2])

    # Panel A: true local scale over the simplex, with the triangle outline.
    sc = ax0.scatter(
        xy_sample[:, 0],
        xy_sample[:, 1],
        c=sigma_sample,
        cmap="viridis",
        s=8,
        alpha=0.8,
        linewidths=0.0,
    )
    corners = simplex_to_xy(np.eye(3))
    outline = np.vstack([corners, corners[:1]])
    ax0.plot(outline[:, 0], outline[:, 1], color="black", linewidth=1.0)
    ax0.text(-0.03, -0.03, r"$u_1$", ha="right", va="top")
    ax0.text(1.03, -0.03, r"$u_2$", ha="left", va="top")
    ax0.text(0.5, np.sqrt(3.0) / 2.0 + 0.04, r"$u_3$", ha="center", va="bottom")
    ax0.set_title("D2 local scale on the simplex")
    ax0.set_aspect("equal")
    ax0.set_xlim(-0.08, 1.08)
    ax0.set_ylim(-0.08, np.sqrt(3.0) / 2.0 + 0.1)
    ax0.axis("off")
    cbar = fig.colorbar(sc, ax=ax0, fraction=0.046, pad=0.02)
    cbar.set_label(r"True local scale $\sigma(u)$")

    # Panel B: Global CP coverage per stratum.
    global_cov = [summary["global"]["stratified_coverage"][k]["mean"] for k in strata_keys]
    ax1.bar(x, global_cov, color=METHOD_COLORS["global"], alpha=0.88, width=0.72)
    ax1.axhline(target, color="black", linestyle="--", linewidth=1)
    ax1.set_title("Global CP allocates poorly")
    ax1.set_xticks(x)
    ax1.set_xticklabels(tick_labels)
    ax1.set_xlabel(strata_xlabel)
    ax1.set_ylabel("Coverage by stratum")
    ax1.set_ylim(0.0, 1.02)
    ax1.grid(axis="y", color="#d9d9d9", linewidth=0.7, alpha=0.7)

    # Panel C: coverage profiles of the comparison methods that are available.
    for method in ("global", "partition", "twostage"):
        coverage = summary.get(method, {}).get("stratified_coverage")
        if not coverage:
            continue
        vals = [coverage[k]["mean"] for k in strata_keys]
        ax2.plot(
            x,
            vals,
            color=METHOD_COLORS[method],
            linewidth=1.8,
            marker=PROFILE_MARKERS.get(method, "o"),
            markersize=4.5,
            label=METHOD_LABELS[method],
        )
    ax2.axhline(target, color="black", linestyle="--", linewidth=1)
    ax2.set_title("Repair depends on the regime")
    ax2.set_xticks(x)
    ax2.set_xticklabels(tick_labels)
    ax2.set_xlabel(strata_xlabel)
    ax2.set_ylim(0.0, 1.02)
    ax2.grid(axis="y", color="#d9d9d9", linewidth=0.7, alpha=0.7)
    ax2.legend(loc="lower right", frameon=False)

    for ax, label in zip([ax0, ax1, ax2], ["A", "B", "C"]):
        ax.text(-0.12, 1.03, label, transform=ax.transAxes, fontsize=11, fontweight="bold")
    save_figure(fig, output_dir, "fig1_allocation_geometry.pdf")
    plt.close(fig)
def fig1_disparity_heatmap(suite: dict[str, dict], output_dir: Path):
    """Heatmap of mean max disparity across synthetic regimes and methods.

    Rows follow ``METHOD_ORDER`` without ``weighted``; columns follow
    ``DGP_SPECS``. Missing method/regime combinations are drawn grey. The
    colour range spans the observed values (min to max). Each annotation's
    text colour is chosen from the luminance of its cell's fill, so labels
    stay legible wherever the data fall in that range. The lowest non-oracle
    value in each column is outlined. Saves
    ``fig1_synthetic_disparity_heatmap.pdf``. Prints a message and returns
    when no disparity value is available.
    """
    methods = [method for method in METHOD_ORDER if method != "weighted"]
    matrix = np.full((len(methods), len(DGP_SPECS)), np.nan)
    for j, (stem, _) in enumerate(DGP_SPECS):
        if stem not in suite:
            continue
        summary = extract_summary(suite[stem])
        for i, method in enumerate(methods):
            if method in summary and "max_disparity" in summary[method]:
                matrix[i, j] = metric_mean(summary, method, "max_disparity")
    if np.isnan(matrix).all():
        print("Skipping Fig 1 heatmap: no synthetic max-disparity results found")
        return

    fig, ax = plt.subplots(figsize=(7.0, 3.6), constrained_layout=True)
    cmap = plt.cm.RdYlBu_r.copy()
    cmap.set_bad("#efefef")
    # Data-driven colour range; the text contrast below uses the same norm.
    norm = plt.Normalize(vmin=np.nanmin(matrix), vmax=np.nanmax(matrix))
    im = ax.imshow(np.ma.masked_invalid(matrix), aspect="auto", cmap=cmap, norm=norm)
    ax.set_xticks(range(len(DGP_SPECS)))
    ax.set_xticklabels([label for _, label in DGP_SPECS])
    ax.set_yticks(range(len(methods)))
    ax.set_yticklabels([METHOD_LABELS[m] for m in methods])
    ax.set_xticks(np.arange(-0.5, len(DGP_SPECS), 1), minor=True)
    ax.set_yticks(np.arange(-0.5, len(methods), 1), minor=True)
    ax.grid(which="minor", color="white", linewidth=1.0)
    ax.tick_params(which="minor", bottom=False, left=False)

    for i in range(len(methods)):
        for j in range(len(DGP_SPECS)):
            val = matrix[i, j]
            if np.isnan(val):
                continue
            # White text on dark fills, black on light fills (perceived luminance).
            red, green, blue, _ = cmap(norm(val))
            luminance = 0.299 * red + 0.587 * green + 0.114 * blue
            txt_color = "white" if luminance < 0.5 else "black"
            stroke_color = "black" if txt_color == "white" else "white"
            ax.text(
                j,
                i,
                f"{val:.02f}",
                ha="center",
                va="center",
                color=txt_color,
                fontsize=7.5,
                fontweight="bold",
                path_effects=[pe.withStroke(linewidth=1.1, foreground=stroke_color)],
            )
    highlight_best_cells(ax, matrix, methods, exclude={"oracle"})
    cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
    cbar.set_label("Max disparity (low → high)")
    cbar.set_ticks([])
    save_figure(fig, output_dir, "fig1_synthetic_disparity_heatmap.pdf")
    plt.close(fig)
def _plot_strata_panel(ax, data: dict, methods: list[str], title: str):
    """Draw per-stratum mean coverage curves for ``methods`` on ``ax``.

    ``data`` may be a suite entry (``{"summary": ..., "raw_data": ...}`` as
    built by ``load_suite``) or a raw result payload. Coverage comes from the
    entry's merged summary, so methods contributed by auxiliary result files
    are drawn too. The config used for ``alpha`` comes from the raw payload.
    Methods absent from the summary are skipped. The dashed line marks
    ``1 - alpha``, where ``alpha`` is ``config.evaluation.alpha``, else
    ``config.alpha``, else 0.1.
    """
    raw = data.get("raw_data", data)
    config = raw.get("config", {})
    evaluation = config.get("evaluation", {})
    alpha = evaluation.get("alpha", config.get("alpha", 0.1))
    target = 1.0 - float(alpha)
    # Use the merged summary. The raw payload only holds the main file's
    # methods, and reading from it silently dropped auxiliary-file methods.
    summary = extract_summary(data)
    reference_method = next(iter(summary))
    strata_keys = sorted(summary[reference_method]["stratified_coverage"].keys(), key=int)
    x = np.arange(len(strata_keys))
    for method in methods:
        if method not in summary:
            continue
        vals = [summary[method]["stratified_coverage"][k]["mean"] for k in strata_keys]
        ax.plot(
            x,
            vals,
            color=METHOD_COLORS[method],
            linewidth=1.6,
            marker=PROFILE_MARKERS.get(method, "o"),
            markersize=4.2,
            label=METHOD_LABELS[method],
        )
    ax.axhline(target, color="black", linestyle="--", linewidth=1)
    ax.set_xticks(x)
    ax.set_xticklabels([f"S{k}" for k in strata_keys])
    ax.set_ylim(0.0, 1.02)
    ax.set_title(title)
    ax.set_ylabel("Coverage by stratum")
    ax.grid(axis="y", color="#d9d9d9", linewidth=0.7, alpha=0.7)
def fig2_stratified_profiles(suite: dict[str, dict], output_dir: Path):
    """Representative stratified coverage plots for key synthetic regimes.

    Draws one panel per available regime (D2, D3, D6), each with its own
    method subset. Because the subsets differ between panels, the shared
    legend is the union of entries across all panels (first occurrence kept),
    not just the first panel's. Saves ``fig2_stratified_profiles.pdf``, or
    prints a message and returns when none of the regimes were loaded.
    """
    panels = [
        ("d2_pure_scale", ["global", "twostage", "oracle"], "D2: Smooth Scale"),
        ("d3_discrete_groups_aligned", ["global", "partition", "oracle"], "D3: Aligned Discrete"),
        ("d6_high_k", ["global", "partition", "twostage", "oracle"], "D6: High-K"),
    ]
    available_panels = [(stem, methods, title) for stem, methods, title in panels if stem in suite]
    if not available_panels:
        print("Skipping Fig 2: no matching synthetic results found")
        return
    fig, axes = plt.subplots(
        1,
        len(available_panels),
        figsize=(3.4 * len(available_panels), 2.9),
        sharey=True,
        constrained_layout=True,
        squeeze=False,
    )
    axes = axes.ravel()
    for ax, (stem, methods, title) in zip(axes, available_panels):
        _plot_strata_panel(ax, suite[stem], methods, title)

    # Union of legend entries over all panels, de-duplicated by label.
    handles, labels = [], []
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            if label not in labels:
                handles.append(handle)
                labels.append(label)
    if handles:
        fig.legend(handles, labels, ncol=min(4, len(labels)), loc="upper center", bbox_to_anchor=(0.5, 1.12), frameon=False)
    save_figure(fig, output_dir, "fig2_stratified_profiles.pdf")
    plt.close(fig)
def fig3_regime_scatter(suite: dict[str, dict], output_dir: Path):
    """Scatter worst-stratum coverage against max disparity across regimes.

    Draws one dot per (synthetic regime, method) pair present in ``suite``,
    coloured by method and annotated with the regime's one-line label, plus a
    dashed line at 0.9 coverage. The legend lists only methods that were
    drawn. Saves ``fig3_regime_tradeoff.pdf``.
    """
    methods_to_show = ["global", "partition", "twostage", "fullcp", "jackknife_plus", "oracle"]
    fig, ax = plt.subplots(figsize=(7.2, 4.4))
    drawn_methods = set()
    for stem, regime_label in DGP_SPECS:
        if stem not in suite:
            continue
        summary = extract_summary(suite[stem])
        flat_label = regime_label.replace("\n", " ")
        for method in (m for m in methods_to_show if m in summary):
            disparity = metric_mean(summary, method, "max_disparity")
            worst_cov = metric_mean(summary, method, "worst_stratum_coverage")
            ax.scatter(
                disparity,
                worst_cov,
                s=65,
                color=METHOD_COLORS[method],
                alpha=0.9,
                edgecolor="white",
                linewidth=0.8,
            )
            ax.text(disparity + 0.01, worst_cov, flat_label, fontsize=7, alpha=0.9)
            drawn_methods.add(method)

    ax.axhline(0.9, color="black", linestyle="--", linewidth=1)
    ax.set_xlabel("Max disparity")
    ax.set_ylabel("Worst-stratum coverage")
    ax.set_title("Synthetic Regimes: Fairness-Safety Tradeoff")

    legend_handles = []
    for method in methods_to_show:
        if method not in drawn_methods:
            continue
        legend_handles.append(
            plt.Line2D(
                [0], [0],
                marker="o",
                color="w",
                label=METHOD_LABELS[method],
                markerfacecolor=METHOD_COLORS[method],
                markeredgecolor="white",
                markersize=8,
            )
        )
    ax.legend(handles=legend_handles, ncol=3, loc="lower left")
    save_figure(fig, output_dir, "fig3_regime_tradeoff.pdf")
    plt.close(fig)
def fig4_runtime_tradeoff(suite: dict[str, dict], output_dir: Path):
    """Runtime versus disparity on the smooth-scale (D2) benchmark.

    Plots one labelled marker per D2 method, in ``METHOD_ORDER``, at
    (mean runtime per repetition, mean max disparity). Methods whose summary
    lacks either metric are skipped instead of raising ``KeyError``. The
    x-axis is symlog (linear within ±0.01 s), so near-zero runtimes stay
    visible. Saves ``fig4_runtime_tradeoff.pdf``. Prints a message and returns
    when D2 results, or any plottable method, are missing.
    """
    stem = "d2_pure_scale"
    if stem not in suite:
        print("Skipping Fig 4: d2_pure_scale.json not found")
        return
    summary = extract_summary(suite[stem])
    required = ("runtime_sec", "max_disparity")
    methods = [m for m in available_methods(summary) if all(key in summary[m] for key in required)]
    if not methods:
        print("Skipping Fig 4: no D2 method reports both runtime_sec and max_disparity")
        return
    fig, ax = plt.subplots(figsize=(6.0, 3.6), constrained_layout=True)
    for method in methods:
        x = metric_mean(summary, method, "runtime_sec")
        y = metric_mean(summary, method, "max_disparity")
        ax.scatter(
            x,
            y,
            s=55,
            color=METHOD_COLORS[method],
            edgecolor="white",
            linewidth=0.8,
            marker=REAL_MARKERS.get(method, "o"),
        )
        # Multiplicative offset reads evenly on the log part of the axis.
        ax.text(x * 1.05 if x > 0 else x + 0.02, y, METHOD_LABELS[method], fontsize=7.3, va="center")
    ax.set_xlabel("Mean runtime per repetition (sec)")
    ax.set_ylabel("Max disparity")
    ax.set_xscale("symlog", linthresh=0.01)
    ax.grid(color="#d9d9d9", linewidth=0.7, alpha=0.7)
    save_figure(fig, output_dir, "fig4_runtime_tradeoff.pdf")
    plt.close(fig)
def fig5_real_disparity_heatmap(real_suite: dict[str, dict], output_dir: Path):
    """Heatmap of mean max disparity for each (method, real-data task) pair.

    Columns are the ``REAL_SPECS`` tasks present in ``real_suite``. Missing
    method/task combinations are drawn grey. The colour range spans the
    observed values (min to max). Each annotation's text colour is chosen from
    the luminance of its cell's fill, so labels stay legible wherever the data
    fall in that range. The lowest value in each column is outlined. Saves
    ``fig5_real_disparity_heatmap.pdf``. Prints a message and returns when no
    task or no disparity value is available.
    """
    methods = ["global", "partition", "twostage", "fullcp", "jackknife_plus", "oneshot", "trainres"]
    tasks = [task for _, task in REAL_SPECS if task in real_suite]
    if not tasks:
        print("Skipping Fig 5: no real-data results found")
        return
    matrix = np.full((len(methods), len(tasks)), np.nan)
    for j, task in enumerate(tasks):
        summary = real_suite[task]["summary"]
        for i, method in enumerate(methods):
            if method in summary and "max_disparity" in summary[method]:
                matrix[i, j] = metric_mean(summary, method, "max_disparity")
    if np.isnan(matrix).all():
        print("Skipping Fig 5: no max-disparity values in real-data results")
        return
    task_labels = {
        "CIFAR-10": "CIFAR-10",
        "Samson": "Samson",
        "Topics": "Topics",
        "AffectiveText": "Affective\nText",
        "UTKFace": "UTKFace",
        "PBMC": "PBMC",
    }
    fig, ax = plt.subplots(figsize=(7.0, 4.0), constrained_layout=True)
    cmap = plt.cm.RdYlBu_r.copy()
    cmap.set_bad("#efefef")
    # Data-driven colour range; the text contrast below uses the same norm.
    norm = plt.Normalize(vmin=np.nanmin(matrix), vmax=np.nanmax(matrix))
    im = ax.imshow(np.ma.masked_invalid(matrix), aspect="auto", cmap=cmap, norm=norm)
    ax.set_xticks(range(len(tasks)))
    ax.set_xticklabels([task_labels.get(task, task) for task in tasks])
    ax.set_yticks(range(len(methods)))
    ax.set_yticklabels([METHOD_LABELS[m] for m in methods])
    ax.set_xticks(np.arange(-0.5, len(tasks), 1), minor=True)
    ax.set_yticks(np.arange(-0.5, len(methods), 1), minor=True)
    ax.grid(which="minor", color="white", linewidth=1.0)
    ax.tick_params(which="minor", bottom=False, left=False)
    for i in range(len(methods)):
        for j in range(len(tasks)):
            val = matrix[i, j]
            if np.isnan(val):
                continue
            # White text on dark fills, black on light fills (perceived luminance).
            red, green, blue, _ = cmap(norm(val))
            luminance = 0.299 * red + 0.587 * green + 0.114 * blue
            txt_color = "white" if luminance < 0.5 else "black"
            stroke_color = "black" if txt_color == "white" else "white"
            ax.text(
                j,
                i,
                f"{val:.02f}",
                ha="center",
                va="center",
                color=txt_color,
                fontsize=7.5,
                fontweight="bold",
                path_effects=[pe.withStroke(linewidth=1.1, foreground=stroke_color)],
            )
    highlight_best_cells(ax, matrix, methods)
    cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
    cbar.set_label("Max disparity (low → high)")
    cbar.set_ticks([])
    save_figure(fig, output_dir, "fig5_real_disparity_heatmap.pdf")
    plt.close(fig)
def fig6_real_tradeoff(real_suite: dict[str, dict], output_dir: Path):
    """Per-task scatter of mean radius versus max disparity on real data.

    Lays out up to six tasks on a 2x3 grid and hides unused cells. Each method
    that reports both ``mean_radius`` and ``max_disparity`` gets a marker at its
    means, with faint ±1 std error bars. The shared y-axis runs from 0 to 0.95,
    extended only if a mean plus its std would otherwise be clipped. The
    "Mean radius" label goes on the lowest visible panel of each grid column.
    Saves ``fig6_real_tradeoff.pdf``.
    """
    selected_methods = ["global", "partition", "twostage", "jackknife_plus", "fullcp", "trainres"]
    tasks = [task for _, task in REAL_SPECS if task in real_suite]
    if not tasks:
        print("Skipping Fig 6: no real-data results found")
        return
    n_cols = 3
    fig, axes = plt.subplots(2, n_cols, figsize=(7.1, 4.8), sharey=True)
    axes = axes.flatten()
    used_methods = set()
    y_top = 0.95  # default upper limit, raised below only if data would be clipped
    for idx, ax in enumerate(axes):
        if idx >= len(tasks):
            ax.axis("off")
            continue
        task = tasks[idx]
        summary = real_suite[task]["summary"]
        xs = []
        for method in selected_methods:
            stats = summary.get(method, {})
            if "mean_radius" not in stats or "max_disparity" not in stats:
                continue
            x = metric_mean(summary, method, "mean_radius")
            y = metric_mean(summary, method, "max_disparity")
            xerr = metric_std(summary, method, "mean_radius")
            yerr = metric_std(summary, method, "max_disparity")
            xs.append(x)
            y_top = max(y_top, 1.05 * (y + yerr))
            ax.errorbar(
                x,
                y,
                xerr=xerr,
                yerr=yerr,
                fmt="none",
                ecolor=METHOD_COLORS[method],
                elinewidth=0.9,
                capsize=2.0,
                alpha=0.28,
                zorder=1,
            )
            ax.scatter(
                x,
                y,
                s=42,
                marker=REAL_MARKERS.get(method, "o"),
                color=METHOD_COLORS[method],
                edgecolor="white",
                linewidth=0.7,
                alpha=0.96,
                zorder=2,
            )
            used_methods.add(method)
        if xs:
            xmin = min(xs)
            xmax = max(xs)
            span = xmax - xmin
            # Single point (zero span): pad relative to its magnitude instead.
            pad = 0.08 * span if span > 0 else max(0.05, 0.15 * xmax)
            ax.set_xlim(max(0.0, xmin - pad), xmax + pad)
        ax.set_title(task)
        ax.grid(color="#d9d9d9", linewidth=0.7, alpha=0.7)
        if idx % n_cols == 0:
            ax.set_ylabel("Max disparity")
        # Label the lowest visible panel in each column (no visible panel below).
        if idx + n_cols >= len(tasks):
            ax.set_xlabel("Mean radius")
    for ax in axes[:len(tasks)]:
        ax.set_ylim(0.0, y_top)
    handles = [
        plt.Line2D(
            [0], [0],
            marker=REAL_MARKERS.get(method, "o"),
            color=METHOD_COLORS[method],
            linestyle="None",
            label=METHOD_LABELS[method],
            markerfacecolor=METHOD_COLORS[method],
            markeredgecolor="white",
            markersize=6.5,
        )
        for method in selected_methods
        if method in used_methods
    ]
    fig.subplots_adjust(top=0.84, bottom=0.10, hspace=0.28, wspace=0.16)
    fig.legend(handles=handles, ncol=3, loc="upper center", bbox_to_anchor=(0.5, 0.99), frameon=False)
    save_figure(fig, output_dir, "fig6_real_tradeoff.pdf")
    plt.close(fig)
def fig7_real_profiles(real_suite: dict[str, dict], output_dir: Path):
    """Stratified coverage profiles for selected real-data tasks.

    Draws up to five panels on a 2x3 grid and hides unused cells. Each panel
    plots per-stratum mean coverage for a hand-picked method subset, against a
    dashed nominal-coverage line at 0.9 (alpha = 0.1). Panels use different
    subsets, so the shared legend is the union of entries across all panels
    (first occurrence kept). Using only the first panel's entries would omit
    e.g. TwoStage and FullCP. Saves ``fig7_real_profiles.pdf``.
    """
    panels = [
        ("CIFAR-10", ["global", "partition", "jackknife_plus"], "CIFAR-10"),
        ("Topics", ["global", "twostage", "jackknife_plus"], "Topics"),
        ("AffectiveText", ["global", "partition", "fullcp"], "AffectiveText"),
        ("UTKFace", ["global", "partition", "jackknife_plus"], "UTKFace"),
        ("PBMC", ["global", "partition", "twostage"], "PBMC"),
    ]
    available = [(task, methods, title) for task, methods, title in panels if task in real_suite]
    if not available:
        print("Skipping Fig 7: no matching real-data results found")
        return
    alpha = 0.1
    target = 1.0 - alpha
    fig, axes = plt.subplots(2, 3, figsize=(7.2, 4.8), sharey=True)
    axes = axes.flatten()
    for idx, (task, methods, title) in enumerate(available):
        ax = axes[idx]
        summary = real_suite[task]["summary"]
        reference_method = next(iter(summary))
        strata_keys = sorted(summary[reference_method]["stratified_coverage"].keys(), key=int)
        x = np.arange(len(strata_keys))
        for method in methods:
            if method not in summary:
                continue
            vals = [summary[method]["stratified_coverage"][k]["mean"] for k in strata_keys]
            ax.plot(
                x,
                vals,
                color=METHOD_COLORS[method],
                linewidth=1.5,
                marker=PROFILE_MARKERS.get(method, "o"),
                markersize=4.0,
                label=METHOD_LABELS[method],
            )
        ax.axhline(target, color="black", linestyle="--", linewidth=1)
        ax.set_xticks(x)
        ax.set_xticklabels([f"S{k}" for k in strata_keys])
        ax.set_ylim(0.0, 1.02)
        ax.set_title(title)
        ax.grid(axis="y", color="#d9d9d9", linewidth=0.7, alpha=0.7)
        if idx % 3 == 0:
            ax.set_ylabel("Coverage by stratum")
    for ax in axes[len(available):]:
        ax.axis("off")

    # Union of legend entries over all drawn panels, de-duplicated by label.
    handles, labels = [], []
    for ax in axes[:len(available)]:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            if label not in labels:
                handles.append(handle)
                labels.append(label)
    fig.subplots_adjust(top=0.84, bottom=0.10, hspace=0.25, wspace=0.15)
    if handles:
        # Up to five methods can appear; keep them on one legend row.
        fig.legend(handles, labels, ncol=min(5, len(labels)), loc="upper center", bbox_to_anchor=(0.5, 0.99), frameon=False)
    save_figure(fig, output_dir, "fig7_real_profiles.pdf")
    plt.close(fig)
def main():
    """Command-line entry point: regenerate the requested benchmark figures.

    ``--fig`` selects one of:

    * a single figure: ``lead`` or ``1``-``7``;
    * a group: ``synthetic`` (lead and 1-4) or ``real`` (5-7);
    * ``all``.

    Unknown values are rejected by argparse instead of silently producing
    nothing. Results are read from ``--results-dir`` and figures are written
    to ``--output-dir``.
    """
    synthetic_jobs = [
        ("lead", fig1_allocation_geometry),
        ("1", fig1_disparity_heatmap),
        ("2", fig2_stratified_profiles),
        ("3", fig3_regime_scatter),
        ("4", fig4_runtime_tradeoff),
    ]
    real_jobs = [
        ("5", fig5_real_disparity_heatmap),
        ("6", fig6_real_tradeoff),
        ("7", fig7_real_profiles),
    ]
    fig_choices = [name for name, _ in synthetic_jobs + real_jobs] + ["synthetic", "real", "all"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--results-dir", default="results/tables")
    parser.add_argument("--output-dir", default="results/figures")
    parser.add_argument(
        "--fig",
        default="all",
        choices=fig_choices,
        help="Which figure: lead,1,2,3,4,5,6,7,synthetic,real,all",
    )
    args = parser.parse_args()
    results_dir = Path(args.results_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    suite = load_suite(results_dir)
    real_suite = load_real_suite(results_dir)
    for name, job in synthetic_jobs:
        if args.fig in (name, "all", "synthetic"):
            job(suite, output_dir)
    for name, job in real_jobs:
        if args.fig in (name, "all", "real"):
            job(real_suite, output_dir)
    print("Done.")
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()