| """Generate SimplexUQ benchmark figures from saved results.""" |
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import matplotlib.patheffects as pe |
| from matplotlib.patches import Rectangle |
| import numpy as np |
|
|
# Global matplotlib style shared by every figure this script produces.
plt.rcParams.update({
    "font.size": 9,
    "font.family": "sans-serif",
    # Fonts are tried in list order.
    "font.sans-serif": ["DejaVu Sans", "Arial"],
    "axes.labelsize": 10,
    "axes.titlesize": 10,
    "legend.fontsize": 9,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "figure.dpi": 150,
    # Saved files use a higher resolution than the in-memory figure.
    "savefig.dpi": 300,
    # Crop saved output to the drawn content (this also keeps legends that
    # are anchored outside the axes area).
    "savefig.bbox": "tight",
    # Drop the top/right frame lines and the default grid.
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": False,
})
|
|
# Display names for legends and tick labels, keyed by the method identifiers
# used as keys in the loaded result summaries.
METHOD_LABELS = {
    "global": "Global",
    "partition": "Mondrian",
    "twostage": "TwoStage",
    "fullcp": "FullCP",
    "jackknife_plus": "Jackknife+",
    "weighted": "Weighted",
    "oracle": "Oracle",
    "oneshot": "OneShot",
    "trainres": "TrainRes",
}


# One fixed colour per method so a method looks the same in every figure.
# NOTE(review): the hex values look like the Okabe-Ito colour-blind-safe
# palette plus a neutral grey -- confirm before citing that in the paper.
METHOD_COLORS = {
    "global": "#D55E00",
    "partition": "#0072B2",
    "twostage": "#009E73",
    "fullcp": "#56B4E9",
    "jackknife_plus": "#CC79A7",
    "weighted": "#F0E442",
    "oracle": "#000000",
    "oneshot": "#7F7F7F",
    "trainres": "#E69F00",
}


# Canonical method ordering: row order of the synthetic disparity heatmap and
# the order returned by available_methods().
METHOD_ORDER = [
    "global",
    "partition",
    "twostage",
    "fullcp",
    "jackknife_plus",
    "oneshot",
    "trainres",
    "weighted",
    "oracle",
]
|
|
# Directory one level above the folder that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Every saved figure is additionally copied here by save_figure().
PAPER_FIG_DIR = REPO_ROOT / "paper" / "rewrite_2026" / "latex" / "figures"
|
|
# Synthetic regimes as (result-file stem, axis label) pairs. load_suite()
# looks for "<stem>.json"; the label is used as a heatmap tick label and,
# with the newline replaced by a space, as a point annotation.
# NOTE(review): the dagger on D6 presumably refers to a footnote in the
# paper -- its meaning is not defined in this file.
DGP_SPECS = [
    ("d1_homogeneous", "D1\nHom."),
    ("d2_pure_scale", "D2\nScale"),
    ("d3_discrete_groups_aligned", "D3\nDiscrete"),
    ("d4_model_bias", "D4\nBias"),
    ("d5_heavy_tail", "D5\nTail"),
    ("d6_high_k", "D6$^{\\dagger}$\nHigh-K"),
]


# Optional supplementary result files per synthetic regime. load_suite()
# merges their summaries into the regime's summary in list order, so on a
# method-name collision the later file wins. Missing files are skipped.
SYNTH_EXTRA_FILES = {
    "d1_homogeneous": ["d1_homogeneous_exact.json"],
    "d3_discrete_groups_aligned": ["d3_discrete_groups_aux.json"],
    "d5_heavy_tail": ["d5_heavy_tail_aux.json"],
    "d6_high_k": ["d6_high_k_aux.json", "d6_high_k_exact_appendix.json"],
}
|
|
# Real-data tasks as (result filename, task name) pairs. The task name is
# the key used in the real-data suite and the panel/column title; the list
# order fixes the column and panel order of the real-data figures.
REAL_SPECS = [
    ("exp2_2_softmax_cifar10_strata_entropy_fixed.json", "CIFAR-10"),
    ("exp2_3_hyperspectral_samson_nmf_all_methods.json", "Samson"),
    ("exp2_5_topics_K10_all_methods.json", "Topics"),
    ("exp2_6_affective_text.json", "AffectiveText"),
    ("exp2_4_age_ldl_K10_image_knn_main.json", "UTKFace"),
    ("real_bulk_deconv.json", "PBMC"),
]


# Optional supplementary result files per real-data task (keyed by task
# name). load_real_suite() merges their summaries in list order, so on a
# method-name collision the later file wins. Missing files are skipped.
REAL_EXTRA_FILES = {
    "PBMC": [
        "real_bulk_deconv_fullcp.json",
        "real_bulk_deconv_aux.json",
        "real_bulk_deconv_trainres.json",
    ],
    "UTKFace": ["exp2_4_age_ldl_K10_image_knn_fullcp_2k.json"],
}
|
|
# Marker shapes for the scatter-style figures (runtime and real-data
# trade-off plots). Methods not listed fall back to "o" at the call sites.
REAL_MARKERS = {
    "global": "o",
    "partition": "s",
    "twostage": "^",
    "fullcp": "D",
    "jackknife_plus": "P",
    "trainres": "X",
}


# Marker shapes for the per-stratum coverage line plots. Methods not listed
# fall back to "o" at the call sites.
PROFILE_MARKERS = {
    "global": "o",
    "partition": "s",
    "twostage": "^",
    "fullcp": "D",
    "jackknife_plus": "P",
    "oracle": "X",
}
|
|
|
|
def load_json(path: Path) -> dict:
    """Read and parse the JSON document stored at *path*.

    The file is decoded as UTF-8 explicitly so that parsing does not depend
    on the platform's default locale encoding.

    Args:
        path: Location of the JSON file.

    Returns:
        The parsed JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
def save_figure(fig: plt.Figure, output_dir: Path, filename: str) -> None:
    """Write *fig* into ``output_dir`` and a second copy into ``PAPER_FIG_DIR``.

    Both directories are created if they do not exist yet. The figure is
    saved under the same *filename* in each place and both destinations are
    reported on stdout.
    """
    for folder in (output_dir, PAPER_FIG_DIR):
        folder.mkdir(parents=True, exist_ok=True)

    primary = output_dir / filename
    paper_copy = PAPER_FIG_DIR / filename
    for destination in (primary, paper_copy):
        fig.savefig(destination)

    print(f"Saved {primary}")
    print(f"Mirrored {paper_copy}")
|
|
|
|
def simplex_to_xy(U: np.ndarray) -> np.ndarray:
    """Project three-component barycentric weights onto the plane.

    Each row of *U* is combined with the corners of an equilateral triangle
    with unit side length: (0, 0), (1, 0) and (0.5, sqrt(3)/2).
    """
    apex_height = np.sqrt(3.0) / 2.0
    corners = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, apex_height]])
    return U @ corners
|
|
|
|
def extract_summary(data: dict) -> dict:
    """Return the per-method summary stored in a result dictionary.

    The ``"summary"`` key takes precedence; ``"aggregated"`` is used only
    when ``"summary"`` is absent.

    Raises:
        KeyError: If neither key is present.
    """
    for key in ("summary", "aggregated"):
        if key in data:
            return data[key]
    raise KeyError("Result file must contain 'summary' or 'aggregated'")
|
|
|
|
def metric_mean(summary: dict, method: str, metric: str) -> float:
    """Return ``summary[method][metric]["mean"]`` converted to ``float``."""
    stats = summary[method][metric]
    return float(stats["mean"])
|
|
|
|
def metric_std(summary: dict, method: str, metric: str) -> float:
    """Return ``summary[method][metric]["std"]`` converted to ``float``."""
    stats = summary[method][metric]
    return float(stats["std"])
|
|
|
|
def available_methods(summary: dict) -> list[str]:
    """List the methods present in *summary*, in ``METHOD_ORDER`` order.

    Keys of *summary* that are not in ``METHOD_ORDER`` are not returned.
    """
    present = []
    for name in METHOD_ORDER:
        if name in summary:
            present.append(name)
    return present
|
|
|
|
| def highlight_best_cells(ax, matrix: np.ndarray, methods: list[str], exclude: set[str] | None = None): |
| exclude = exclude or set() |
| for col in range(matrix.shape[1]): |
| best_row = None |
| best_val = None |
| for row, method in enumerate(methods): |
| val = matrix[row, col] |
| if method in exclude or np.isnan(val): |
| continue |
| if best_val is None or val < best_val: |
| best_val = val |
| best_row = row |
| if best_row is None: |
| continue |
| ax.add_patch( |
| Rectangle( |
| (col - 0.5, best_row - 0.5), |
| 1.0, |
| 1.0, |
| fill=False, |
| edgecolor="black", |
| linewidth=1.5, |
| ) |
| ) |
|
|
|
|
def load_suite(results_dir: Path) -> dict[str, dict]:
    """Load every available synthetic regime listed in ``DGP_SPECS``.

    For each stem whose ``<stem>.json`` exists in *results_dir* the returned
    dict holds an entry with:

    * ``"summary"``: the file's summary, updated with the summaries of any
      existing files from ``SYNTH_EXTRA_FILES`` (later files win on clashes),
    * ``"raw_data"``: the full content of the main file,
    * ``"extra_raw_data"``: full content of each merged extra file, keyed by
      filename (only present when at least one extra file was found).

    Regimes without a main result file are left out.
    """
    loaded = {}
    for stem, _label in DGP_SPECS:
        main_path = results_dir / f"{stem}.json"
        if not main_path.exists():
            continue

        main_data = load_json(main_path)
        entry = {"summary": extract_summary(main_data), "raw_data": main_data}

        for extra_name in SYNTH_EXTRA_FILES.get(stem, []):
            extra_path = results_dir / extra_name
            if not extra_path.exists():
                continue
            extra_data = load_json(extra_path)
            addition = extract_summary(extra_data)
            # Build a fresh dict so the summary inside raw_data is untouched.
            combined = dict(entry["summary"])
            combined.update(addition)
            entry["summary"] = combined
            entry.setdefault("extra_raw_data", {})[extra_name] = extra_data

        loaded[stem] = entry
    return loaded
|
|
|
|
def load_real_suite(results_dir: Path) -> dict[str, dict]:
    """Load every available real-data task listed in ``REAL_SPECS``.

    The result is keyed by task name. Each entry holds ``"summary"`` (the
    main file's summary overlaid with the summaries of existing files from
    ``REAL_EXTRA_FILES``, later files winning on clashes), ``"raw_data"``
    (the main file's full content) and, when any extra file was found,
    ``"extra_raw_data"`` keyed by filename. Tasks whose main file is missing
    are left out.
    """
    collected = {}
    for filename, task in REAL_SPECS:
        main_path = results_dir / filename
        if main_path.exists():
            main_data = load_json(main_path)
            record = {"summary": extract_summary(main_data), "raw_data": main_data}

            extras = REAL_EXTRA_FILES[task] if task in REAL_EXTRA_FILES else ()
            for extra_name in extras:
                extra_path = results_dir / extra_name
                if extra_path.exists():
                    extra_data = load_json(extra_path)
                    overlay = extract_summary(extra_data)
                    record["summary"] = {**record["summary"], **overlay}
                    if "extra_raw_data" not in record:
                        record["extra_raw_data"] = {}
                    record["extra_raw_data"][extra_name] = extra_data

            collected[task] = record
    return collected
|
|
|
|
def fig1_allocation_geometry(suite: dict[str, dict], output_dir: Path):
    """Illustrate simplex allocation failure on the smooth-scale synthetic regime.

    Builds a three-panel figure for the ``d2_pure_scale`` regime:

    * A: a random subsample (at most 2500 points, fixed seed) of the task's
      ``U`` points mapped to the plane and coloured by ``sigma_true``;
    * B: per-stratum mean coverage of the ``global`` method as bars;
    * C: per-stratum mean coverage curves for ``global``, ``partition`` and
      ``twostage`` (methods absent from the summary are skipped).

    The figure is skipped with a message when the regime is not in *suite*,
    the ``task.npz`` file does not exist, or the summary has no ``global``
    entry (panels A/B and the strata axis all depend on it).

    Args:
        suite: Synthetic suite as returned by ``load_suite``.
        output_dir: Directory passed on to ``save_figure``.
    """
    stem = "d2_pure_scale"
    task_path = REPO_ROOT / "release" / "simplextasks-12" / "synthetic" / stem / "task.npz"
    if (
        stem not in suite
        or not task_path.exists()
        or "global" not in suite[stem]["summary"]
    ):
        print("Skipping Fig 1 allocation geometry: D2 task or summary missing")
        return

    # np.load keeps the .npz archive open until it is closed; read the two
    # arrays we need inside a context manager so the file handle is released.
    with np.load(task_path) as task:
        U = task["U"]
        sigma_true = task["sigma_true"]
    xy = simplex_to_xy(U)
    summary = suite[stem]["summary"]
    # Strata keys are numeric strings; sort them numerically, not lexically.
    strata_keys = sorted(summary["global"]["stratified_coverage"].keys(), key=int)
    x = np.arange(len(strata_keys))

    # Fixed seed so the scatter subsample is identical between runs.
    rng = np.random.default_rng(2026)
    sample_idx = rng.choice(len(U), size=min(2500, len(U)), replace=False)
    xy_sample = xy[sample_idx]
    sigma_sample = sigma_true[sample_idx]

    fig = plt.figure(figsize=(8.0, 2.95), constrained_layout=True)
    gs = fig.add_gridspec(1, 3, width_ratios=[1.12, 0.92, 1.18])
    ax0 = fig.add_subplot(gs[0, 0])
    ax1 = fig.add_subplot(gs[0, 1])
    ax2 = fig.add_subplot(gs[0, 2])

    # Panel A: local scale over the simplex.
    sc = ax0.scatter(
        xy_sample[:, 0],
        xy_sample[:, 1],
        c=sigma_sample,
        cmap="viridis",
        s=8,
        alpha=0.8,
        linewidths=0.0,
    )
    # Closed outline of the triangle (first corner repeated at the end).
    triangle = np.array(
        [
            [0.0, 0.0],
            [1.0, 0.0],
            [0.5, np.sqrt(3.0) / 2.0],
            [0.0, 0.0],
        ]
    )
    ax0.plot(triangle[:, 0], triangle[:, 1], color="black", linewidth=1.0)
    ax0.text(-0.03, -0.03, r"$u_1$", ha="right", va="top")
    ax0.text(1.03, -0.03, r"$u_2$", ha="left", va="top")
    ax0.text(0.5, np.sqrt(3.0) / 2.0 + 0.04, r"$u_3$", ha="center", va="bottom")
    ax0.set_title("D2 local scale on the simplex")
    ax0.set_aspect("equal")
    ax0.set_xlim(-0.08, 1.08)
    ax0.set_ylim(-0.08, np.sqrt(3.0) / 2.0 + 0.1)
    ax0.axis("off")
    cbar = fig.colorbar(sc, ax=ax0, fraction=0.046, pad=0.02)
    cbar.set_label(r"True local scale $\sigma(u)$")

    # Panel B: coverage of the global method per stratum.
    # The reference line is hard-coded here rather than read from the config.
    target = 0.9
    global_cov = [summary["global"]["stratified_coverage"][k]["mean"] for k in strata_keys]
    ax1.bar(x, global_cov, color=METHOD_COLORS["global"], alpha=0.88, width=0.72)
    ax1.axhline(target, color="black", linestyle="--", linewidth=1)
    ax1.set_title("Global CP allocates poorly")
    ax1.set_xticks(x)
    ax1.set_xticklabels([f"S{k}" for k in strata_keys])
    ax1.set_xlabel(r"Boundary strata ($S0 \rightarrow S4$)")
    ax1.set_ylabel("Coverage by stratum")
    ax1.set_ylim(0.0, 1.02)
    ax1.grid(axis="y", color="#d9d9d9", linewidth=0.7, alpha=0.7)

    # Panel C: the same profile for the methods being compared.
    comparison_methods = ["global", "partition", "twostage"]
    for method in comparison_methods:
        if method not in summary:
            # Same policy as the other profile plots: draw what is available.
            continue
        vals = [summary[method]["stratified_coverage"][k]["mean"] for k in strata_keys]
        ax2.plot(
            x,
            vals,
            color=METHOD_COLORS[method],
            linewidth=1.8,
            marker=PROFILE_MARKERS.get(method, "o"),
            markersize=4.5,
            label=METHOD_LABELS[method],
        )
    ax2.axhline(target, color="black", linestyle="--", linewidth=1)
    ax2.set_title("Repair depends on the regime")
    ax2.set_xticks(x)
    ax2.set_xticklabels([f"S{k}" for k in strata_keys])
    ax2.set_xlabel(r"Boundary strata ($S0 \rightarrow S4$)")
    ax2.set_ylim(0.0, 1.02)
    ax2.grid(axis="y", color="#d9d9d9", linewidth=0.7, alpha=0.7)
    ax2.legend(loc="lower right", frameon=False)

    # Panel letters placed just outside the top-left corner of each axes.
    for ax, label in zip([ax0, ax1, ax2], ["A", "B", "C"]):
        ax.text(-0.12, 1.03, label, transform=ax.transAxes, fontsize=11, fontweight="bold")

    save_figure(fig, output_dir, "fig1_allocation_geometry.pdf")
    plt.close(fig)
|
|
|
|
def fig1_disparity_heatmap(suite: dict[str, dict], output_dir: Path):
    """Draw the methods-by-regimes heatmap of mean max disparity.

    Rows follow ``METHOD_ORDER`` and columns follow ``DGP_SPECS``. Cells with
    no data stay NaN and are shown in light grey. Every filled cell is
    annotated with its value, and the lowest value per column is outlined,
    ignoring the ``oracle`` and ``weighted`` rows.
    """
    row_methods = METHOD_ORDER
    n_rows, n_cols = len(row_methods), len(DGP_SPECS)
    grid = np.full((n_rows, n_cols), np.nan)

    for col, (stem, _label) in enumerate(DGP_SPECS):
        if stem in suite:
            regime_summary = extract_summary(suite[stem])
            for row, method in enumerate(row_methods):
                if method in regime_summary:
                    grid[row, col] = metric_mean(regime_summary, method, "max_disparity")

    fig, ax = plt.subplots(figsize=(7.0, 3.6), constrained_layout=True)
    # Copy the colormap before changing its "bad" colour so the shared
    # registered colormap is left alone.
    palette = plt.cm.RdYlBu_r.copy()
    palette.set_bad("#efefef")
    image = ax.imshow(np.ma.masked_invalid(grid), aspect="auto", cmap=palette, vmin=0.0, vmax=0.85)

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels([label for _, label in DGP_SPECS])
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels([METHOD_LABELS[m] for m in row_methods])
    # Minor ticks on the cell borders let a white minor grid act as cell
    # separators.
    ax.set_xticks(np.arange(-0.5, n_cols, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n_rows, 1), minor=True)
    ax.grid(which="minor", color="white", linewidth=1.0)
    ax.tick_params(which="minor", bottom=False, left=False)

    for (row, col), value in np.ndenumerate(grid):
        if np.isnan(value):
            continue
        # White text at both ends of the colour scale, black in the middle;
        # the stroke uses the opposite colour.
        if value < 0.16 or value > 0.58:
            ink, halo = "white", "black"
        else:
            ink, halo = "black", "white"
        ax.text(
            col,
            row,
            f"{value:.02f}",
            ha="center",
            va="center",
            color=ink,
            fontsize=7.5,
            fontweight="bold",
            path_effects=[pe.withStroke(linewidth=1.1, foreground=halo)],
        )

    highlight_best_cells(ax, grid, row_methods, exclude={"oracle", "weighted"})

    colorbar = fig.colorbar(image, ax=ax, shrink=0.95, pad=0.02)
    colorbar.set_label("Max disparity")

    save_figure(fig, output_dir, "fig1_synthetic_disparity_heatmap.pdf")
    plt.close(fig)
|
|
|
|
def _plot_strata_panel(ax, data: dict, methods: list[str], title: str):
    """Draw per-stratum mean coverage curves for *methods* on *ax*.

    Args:
        ax: Matplotlib axes to draw on.
        data: Either a suite entry (with ``"summary"`` and ``"raw_data"``)
            or a raw result dictionary.
        methods: Method identifiers to plot; those absent from the summary
            are skipped.
        title: Axes title.

    The nominal coverage line is ``1 - alpha``, where ``alpha`` is read from
    ``config["evaluation"]["alpha"]``, then ``config["alpha"]``, and finally
    defaults to 0.1.

    When *data* carries its own ``"summary"`` (as suite entries do, where it
    already includes methods merged in from supplementary files), that
    summary is used. This matches the other figures, which all read the
    merged summary; previously the summary was re-extracted from
    ``raw_data`` and methods from supplementary files were silently dropped.
    """
    raw = data.get("raw_data", data)
    config = raw.get("config", {})
    evaluation = config.get("evaluation", {})
    alpha = evaluation.get("alpha", config.get("alpha", 0.1))
    target = 1.0 - float(alpha)
    summary = data["summary"] if "summary" in data else extract_summary(raw)

    # The strata axis is taken from the first method listed in the summary.
    reference_method = next(iter(summary))
    strata_keys = sorted(summary[reference_method]["stratified_coverage"].keys(), key=int)
    x = np.arange(len(strata_keys))
    for method in methods:
        if method not in summary:
            continue
        coverage = summary[method].get("stratified_coverage", {})
        if any(k not in coverage for k in strata_keys):
            # A merged-in method may lack some strata; report it instead of
            # failing the whole figure.
            print(f"Skipping {method} in panel '{title}': stratified coverage incomplete")
            continue
        vals = [coverage[k]["mean"] for k in strata_keys]
        ax.plot(
            x,
            vals,
            color=METHOD_COLORS[method],
            linewidth=1.6,
            marker=PROFILE_MARKERS.get(method, "o"),
            markersize=4.2,
            label=METHOD_LABELS[method],
        )

    ax.axhline(target, color="black", linestyle="--", linewidth=1)
    ax.set_xticks(x)
    ax.set_xticklabels([f"S{k}" for k in strata_keys])
    ax.set_ylim(0.0, 1.02)
    ax.set_title(title)
    ax.set_ylabel("Coverage by stratum")
    ax.grid(axis="y", color="#d9d9d9", linewidth=0.7, alpha=0.7)
|
|
|
|
def fig2_stratified_profiles(suite: dict[str, dict], output_dir: Path):
    """Representative stratified coverage plots for key regimes.

    One panel is drawn per regime that is present in *suite* (D2, D3 and D6,
    each with its own list of methods). The figure is skipped with a message
    when none of them is available.

    The shared legend is assembled from all panels, not just the first one:
    the panels plot different method sets, so a legend built from the first
    axes alone would omit methods that appear only in later panels.

    Args:
        suite: Synthetic suite as returned by ``load_suite``.
        output_dir: Directory passed on to ``save_figure``.
    """
    panels = [
        ("d2_pure_scale", ["global", "twostage", "oracle"], "D2: Smooth Scale"),
        ("d3_discrete_groups_aligned", ["global", "partition", "oracle"], "D3: Aligned Discrete"),
        ("d6_high_k", ["global", "partition", "twostage", "oracle"], "D6: High-K"),
    ]
    available_panels = [(stem, methods, title) for stem, methods, title in panels if stem in suite]
    if not available_panels:
        print("Skipping Fig 2: no matching synthetic results found")
        return

    fig, axes = plt.subplots(
        1,
        len(available_panels),
        figsize=(3.4 * len(available_panels), 2.9),
        sharey=True,
        constrained_layout=True,
    )
    # With a single column plt.subplots returns a bare Axes, not an array.
    if len(available_panels) == 1:
        axes = [axes]

    for ax, (stem, methods, title) in zip(axes, available_panels):
        _plot_strata_panel(ax, suite[stem], methods, title)

    # Collect one handle per label across every panel, keeping the order in
    # which labels first appear.
    legend_entries = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    fig.legend(
        list(legend_entries.values()),
        list(legend_entries.keys()),
        ncol=max(1, min(4, len(legend_entries))),
        loc="upper center",
        bbox_to_anchor=(0.5, 1.12),
        frameon=False,
    )

    save_figure(fig, output_dir, "fig2_stratified_profiles.pdf")
    plt.close(fig)
|
|
|
|
def fig3_regime_scatter(suite: dict[str, dict], output_dir: Path):
    """Plot worst-stratum coverage against max disparity per regime and method.

    Each available (regime, method) pair becomes one point coloured by
    method and annotated with the regime label. A dashed line marks 0.9
    coverage. The legend lists only the methods that occur in at least one
    loaded regime.
    """
    shown_methods = ["global", "partition", "twostage", "fullcp", "jackknife_plus", "oracle"]
    fig, ax = plt.subplots(figsize=(7.2, 4.4))

    loaded_regimes = [(stem, label) for stem, label in DGP_SPECS if stem in suite]
    for stem, label in loaded_regimes:
        regime_summary = extract_summary(suite[stem])
        # Axis labels contain a line break; annotations are single-line.
        tag = label.replace("\n", " ")
        for method in shown_methods:
            if method in regime_summary:
                disparity = metric_mean(regime_summary, method, "max_disparity")
                worst_cov = metric_mean(regime_summary, method, "worst_stratum_coverage")
                ax.scatter(
                    disparity,
                    worst_cov,
                    s=65,
                    color=METHOD_COLORS[method],
                    alpha=0.9,
                    edgecolor="white",
                    linewidth=0.8,
                )
                ax.text(disparity + 0.01, worst_cov, tag, fontsize=7, alpha=0.9)

    ax.axhline(0.9, color="black", linestyle="--", linewidth=1)
    ax.set_xlabel("Max disparity")
    ax.set_ylabel("Worst-stratum coverage")
    ax.set_title("Synthetic Regimes: Fairness-Safety Tradeoff")

    # Proxy artists: one round marker per method that was actually loaded.
    legend_handles = []
    for method in shown_methods:
        occurs = any(method in suite[stem]["summary"] for stem, _ in loaded_regimes)
        if not occurs:
            continue
        legend_handles.append(
            plt.Line2D(
                [0],
                [0],
                marker="o",
                color="w",
                label=METHOD_LABELS[method],
                markerfacecolor=METHOD_COLORS[method],
                markeredgecolor="white",
                markersize=8,
            )
        )
    ax.legend(handles=legend_handles, ncol=3, loc="lower left")

    save_figure(fig, output_dir, "fig3_regime_tradeoff.pdf")
    plt.close(fig)
|
|
|
|
def fig4_runtime_tradeoff(suite: dict[str, dict], output_dir: Path):
    """Plot mean runtime against max disparity for the ``d2_pure_scale`` regime.

    One labelled point is drawn per method found in the regime's summary.
    The x axis uses a symmetric-log scale with a linear threshold of 0.01.
    The figure is skipped with a message when the regime is not in *suite*.
    """
    stem = "d2_pure_scale"
    if stem not in suite:
        print("Skipping Fig 4: d2_pure_scale.json not found")
        return

    regime_summary = extract_summary(suite[stem])

    fig, ax = plt.subplots(figsize=(6.0, 3.6), constrained_layout=True)
    for method in available_methods(regime_summary):
        runtime = metric_mean(regime_summary, method, "runtime_sec")
        disparity = metric_mean(regime_summary, method, "max_disparity")
        ax.scatter(
            runtime,
            disparity,
            s=55,
            color=METHOD_COLORS[method],
            edgecolor="white",
            linewidth=0.8,
            marker=REAL_MARKERS.get(method, "o"),
        )
        # Offset the label multiplicatively for positive runtimes and
        # additively otherwise (a factor would not move a zero runtime).
        if runtime > 0:
            label_x = runtime * 1.05
        else:
            label_x = runtime + 0.02
        ax.text(label_x, disparity, METHOD_LABELS[method], fontsize=7.3, va="center")

    ax.set_xlabel("Mean runtime per repetition (sec)")
    ax.set_ylabel("Max disparity")
    ax.set_xscale("symlog", linthresh=0.01)
    ax.grid(color="#d9d9d9", linewidth=0.7, alpha=0.7)

    save_figure(fig, output_dir, "fig4_runtime_tradeoff.pdf")
    plt.close(fig)
|
|
|
|
def fig5_real_disparity_heatmap(real_suite: dict[str, dict], output_dir: Path):
    """Draw the methods-by-tasks heatmap of mean max disparity on real data.

    Columns are the loaded tasks in ``REAL_SPECS`` order. A cell is filled
    only when the task's summary has the method and that method reports
    ``max_disparity``; other cells are shown in light grey. The lowest value
    per column is outlined, ignoring the ``weighted`` row. The figure is
    skipped with a message when no real-data task is loaded.
    """
    row_methods = ["global", "partition", "twostage", "fullcp", "jackknife_plus", "oneshot", "trainres", "weighted"]
    columns = [name for _, name in REAL_SPECS if name in real_suite]
    if len(columns) == 0:
        print("Skipping Fig 5: no real-data results found")
        return

    n_rows, n_cols = len(row_methods), len(columns)
    grid = np.full((n_rows, n_cols), np.nan)
    for col, name in enumerate(columns):
        task_summary = real_suite[name]["summary"]
        for row, method in enumerate(row_methods):
            if method not in task_summary:
                continue
            if "max_disparity" not in task_summary[method]:
                continue
            grid[row, col] = metric_mean(task_summary, method, "max_disparity")

    # Only this task name needs a line break to fit under its column; all
    # other tasks are labelled with their own name.
    wrapped_names = {"AffectiveText": "Affective\nText"}

    fig, ax = plt.subplots(figsize=(7.0, 4.0), constrained_layout=True)
    palette = plt.cm.RdYlBu_r.copy()
    palette.set_bad("#efefef")
    image = ax.imshow(np.ma.masked_invalid(grid), aspect="auto", cmap=palette, vmin=0.0, vmax=0.9)

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels([wrapped_names.get(name, name) for name in columns])
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels([METHOD_LABELS[m] for m in row_methods])
    # White minor grid on the cell borders separates the cells visually.
    ax.set_xticks(np.arange(-0.5, n_cols, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n_rows, 1), minor=True)
    ax.grid(which="minor", color="white", linewidth=1.0)
    ax.tick_params(which="minor", bottom=False, left=False)

    for (row, col), value in np.ndenumerate(grid):
        if np.isnan(value):
            continue
        light_text = value < 0.16 or value > 0.58
        ink = "white" if light_text else "black"
        halo = "black" if light_text else "white"
        ax.text(
            col,
            row,
            f"{value:.02f}",
            ha="center",
            va="center",
            color=ink,
            fontsize=7.5,
            fontweight="bold",
            path_effects=[pe.withStroke(linewidth=1.1, foreground=halo)],
        )

    highlight_best_cells(ax, grid, row_methods, exclude={"weighted"})

    colorbar = fig.colorbar(image, ax=ax, shrink=0.95, pad=0.02)
    colorbar.set_label("Max disparity")

    save_figure(fig, output_dir, "fig5_real_disparity_heatmap.pdf")
    plt.close(fig)
|
|
|
|
def fig6_real_tradeoff(real_suite: dict[str, dict], output_dir: Path):
    """Plot mean radius against max disparity, one panel per real-data task.

    Panels are laid out on a 2x3 grid in ``REAL_SPECS`` order; unused panels
    are switched off. A method is drawn in a panel only when the task's
    summary has it and it reports ``mean_radius``. Each point carries faint
    error bars built from the reported standard deviations. The legend lists
    the methods drawn in at least one panel. The figure is skipped with a
    message when no real-data task is loaded.
    """
    candidate_methods = ["global", "partition", "twostage", "jackknife_plus", "fullcp", "trainres"]
    panel_tasks = [name for _, name in REAL_SPECS if name in real_suite]
    if not panel_tasks:
        print("Skipping Fig 6: no real-data results found")
        return

    fig, grid = plt.subplots(2, 3, figsize=(7.1, 4.8), sharey=True)
    panels = grid.flatten()
    drawn = set()

    for position, ax in enumerate(panels):
        if position >= len(panel_tasks):
            ax.axis("off")
            continue

        name = panel_tasks[position]
        task_summary = real_suite[name]["summary"]
        radii = []
        for method in candidate_methods:
            has_radius = method in task_summary and "mean_radius" in task_summary[method]
            if not has_radius:
                continue
            radius = metric_mean(task_summary, method, "mean_radius")
            disparity = metric_mean(task_summary, method, "max_disparity")
            radius_sd = metric_std(task_summary, method, "mean_radius")
            disparity_sd = metric_std(task_summary, method, "max_disparity")
            radii.append(radius)
            # Error bars go underneath (zorder 1) the marker (zorder 2).
            ax.errorbar(
                radius,
                disparity,
                xerr=radius_sd,
                yerr=disparity_sd,
                fmt="none",
                ecolor=METHOD_COLORS[method],
                elinewidth=0.9,
                capsize=2.0,
                alpha=0.28,
                zorder=1,
            )
            ax.scatter(
                radius,
                disparity,
                s=42,
                marker=REAL_MARKERS.get(method, "o"),
                color=METHOD_COLORS[method],
                edgecolor="white",
                linewidth=0.7,
                alpha=0.96,
                zorder=2,
            )
            drawn.add(method)

        if radii:
            lowest, highest = min(radii), max(radii)
            width = highest - lowest
            # Pad by a fraction of the spread; when all points share one x
            # value fall back to a padding derived from that value.
            if width > 0:
                margin = 0.08 * width
            else:
                margin = max(0.05, 0.15 * highest)
            ax.set_xlim(max(0.0, lowest - margin), highest + margin)
        ax.set_ylim(0.0, 0.95)
        ax.set_title(name)
        ax.grid(color="#d9d9d9", linewidth=0.7, alpha=0.7)
        # y label on the first column only, x label on the second row only.
        if position % 3 == 0:
            ax.set_ylabel("Max disparity")
        if position >= 3:
            ax.set_xlabel("Mean radius")

    legend_handles = []
    for method in candidate_methods:
        if method not in drawn:
            continue
        legend_handles.append(
            plt.Line2D(
                [0], [0],
                marker=REAL_MARKERS.get(method, "o"),
                color=METHOD_COLORS[method],
                linestyle="None",
                label=METHOD_LABELS[method],
                markerfacecolor=METHOD_COLORS[method],
                markeredgecolor="white",
                markersize=6.5,
            )
        )
    # Leave room above the panels for the figure-level legend.
    fig.subplots_adjust(top=0.84, bottom=0.10, hspace=0.28, wspace=0.16)
    fig.legend(handles=legend_handles, ncol=3, loc="upper center", bbox_to_anchor=(0.5, 0.99), frameon=False)

    save_figure(fig, output_dir, "fig6_real_tradeoff.pdf")
    plt.close(fig)
|
|
|
|
def fig7_real_profiles(real_suite: dict[str, dict], output_dir: Path):
    """Plot per-stratum coverage profiles for selected real-data tasks.

    Each loaded task from the panel list gets one axes on a 2x3 grid, with
    its own list of methods (methods absent from the task's summary are
    skipped); unused axes are switched off. The dashed line marks the
    nominal coverage ``1 - alpha`` with ``alpha`` fixed at 0.1 here. The
    figure is skipped with a message when none of the tasks is loaded.

    The shared legend is assembled from all drawn panels, not just the first
    one: the panels plot different method sets, so a legend built from the
    first axes alone would omit methods that appear only in later panels.

    Args:
        real_suite: Real-data suite as returned by ``load_real_suite``.
        output_dir: Directory passed on to ``save_figure``.
    """
    panels = [
        ("CIFAR-10", ["global", "partition", "jackknife_plus"], "CIFAR-10"),
        ("Topics", ["global", "twostage", "jackknife_plus"], "Topics"),
        ("AffectiveText", ["global", "partition", "fullcp"], "AffectiveText"),
        ("UTKFace", ["global", "partition", "jackknife_plus"], "UTKFace"),
        ("PBMC", ["global", "partition", "twostage"], "PBMC"),
    ]
    available = [(task, methods, title) for task, methods, title in panels if task in real_suite]
    if not available:
        print("Skipping Fig 7: no matching real-data results found")
        return

    fig, axes = plt.subplots(2, 3, figsize=(7.2, 4.8), sharey=True)
    axes = axes.flatten()

    alpha = 0.1
    target = 1.0 - alpha
    for idx, (task, methods, title) in enumerate(available):
        ax = axes[idx]
        summary = real_suite[task]["summary"]
        # The strata axis is taken from the first method in the summary;
        # keys are numeric strings, so sort them numerically.
        reference_method = next(iter(summary))
        strata_keys = sorted(summary[reference_method]["stratified_coverage"].keys(), key=int)
        x = np.arange(len(strata_keys))

        for method in methods:
            if method not in summary:
                continue
            vals = [summary[method]["stratified_coverage"][k]["mean"] for k in strata_keys]
            ax.plot(
                x,
                vals,
                color=METHOD_COLORS[method],
                linewidth=1.5,
                marker=PROFILE_MARKERS.get(method, "o"),
                markersize=4.0,
                label=METHOD_LABELS[method],
            )

        ax.axhline(target, color="black", linestyle="--", linewidth=1)
        ax.set_xticks(x)
        ax.set_xticklabels([f"S{k}" for k in strata_keys])
        ax.set_ylim(0.0, 1.02)
        ax.set_title(title)
        ax.grid(axis="y", color="#d9d9d9", linewidth=0.7, alpha=0.7)
        # y label on the first column only.
        if idx % 3 == 0:
            ax.set_ylabel("Coverage by stratum")

    for idx in range(len(available), len(axes)):
        axes[idx].axis("off")

    # Collect one handle per label across every drawn panel, keeping the
    # order in which labels first appear.
    legend_entries = {}
    for ax in axes[: len(available)]:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    # Leave room above the panels for the figure-level legend.
    fig.subplots_adjust(top=0.84, bottom=0.10, hspace=0.25, wspace=0.15)
    fig.legend(
        list(legend_entries.values()),
        list(legend_entries.keys()),
        ncol=max(1, min(4, len(legend_entries))),
        loc="upper center",
        bbox_to_anchor=(0.5, 0.99),
        frameon=False,
    )

    save_figure(fig, output_dir, "fig7_real_profiles.pdf")
    plt.close(fig)
|
|
|
|
def main():
    """Command-line entry point: load saved results and render figures.

    ``--fig`` selects a single figure (``lead``, ``1`` .. ``7``), one group
    (``synthetic`` = lead and 1-4, ``real`` = 5-7) or ``all``. Any other
    value is rejected by argparse; without that check a typo would render
    nothing and still print "Done.".
    """
    # (selector, renderer) pairs in rendering order.
    synthetic_figures = (
        ("lead", fig1_allocation_geometry),
        ("1", fig1_disparity_heatmap),
        ("2", fig2_stratified_profiles),
        ("3", fig3_regime_scatter),
        ("4", fig4_runtime_tradeoff),
    )
    real_figures = (
        ("5", fig5_real_disparity_heatmap),
        ("6", fig6_real_tradeoff),
        ("7", fig7_real_profiles),
    )
    selectors = [key for key, _ in synthetic_figures + real_figures]
    selectors += ["synthetic", "real", "all"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--results-dir", default="results/tables")
    parser.add_argument("--output-dir", default="results/figures")
    parser.add_argument(
        "--fig",
        default="all",
        choices=selectors,
        help="Which figure: lead,1,2,3,4,5,6,7,synthetic,real,all",
    )
    args = parser.parse_args()

    results_dir = Path(args.results_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    suite = load_suite(results_dir)
    real_suite = load_real_suite(results_dir)

    for key, render in synthetic_figures:
        if args.fig in (key, "all", "synthetic"):
            render(suite, output_dir)
    for key, render in real_figures:
        if args.fig in (key, "all", "real"):
            render(real_suite, output_dir)

    print("Done.")
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
|
|