Remove weighted from main heatmap figures
Browse files- scripts/make_figures.py +4 -4
scripts/make_figures.py
CHANGED
|
@@ -339,7 +339,7 @@ def fig1_allocation_geometry(suite: dict[str, dict], output_dir: Path):
|
|
| 339 |
|
| 340 |
def fig1_disparity_heatmap(suite: dict[str, dict], output_dir: Path):
|
| 341 |
"""Heatmap of max disparity across regimes and methods."""
|
| 342 |
-
methods = METHOD_ORDER
|
| 343 |
matrix = np.full((len(methods), len(DGP_SPECS)), np.nan)
|
| 344 |
|
| 345 |
for j, (stem, _) in enumerate(DGP_SPECS):
|
|
@@ -383,7 +383,7 @@ def fig1_disparity_heatmap(suite: dict[str, dict], output_dir: Path):
|
|
| 383 |
path_effects=[pe.withStroke(linewidth=1.1, foreground=stroke_color)],
|
| 384 |
)
|
| 385 |
|
| 386 |
-
highlight_best_cells(ax, matrix, methods, exclude={"oracle"
|
| 387 |
|
| 388 |
cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
|
| 389 |
cbar.set_label("Max disparity")
|
|
@@ -529,7 +529,7 @@ def fig4_runtime_tradeoff(suite: dict[str, dict], output_dir: Path):
|
|
| 529 |
|
| 530 |
|
| 531 |
def fig5_real_disparity_heatmap(real_suite: dict[str, dict], output_dir: Path):
|
| 532 |
-
methods = ["global", "partition", "twostage", "fullcp", "jackknife_plus", "oneshot", "trainres"
|
| 533 |
tasks = [task for _, task in REAL_SPECS if task in real_suite]
|
| 534 |
if not tasks:
|
| 535 |
print("Skipping Fig 5: no real-data results found")
|
|
@@ -584,7 +584,7 @@ def fig5_real_disparity_heatmap(real_suite: dict[str, dict], output_dir: Path):
|
|
| 584 |
path_effects=[pe.withStroke(linewidth=1.1, foreground=stroke_color)],
|
| 585 |
)
|
| 586 |
|
| 587 |
-
highlight_best_cells(ax, matrix, methods
|
| 588 |
|
| 589 |
cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
|
| 590 |
cbar.set_label("Max disparity")
|
|
|
|
| 339 |
|
| 340 |
def fig1_disparity_heatmap(suite: dict[str, dict], output_dir: Path):
|
| 341 |
"""Heatmap of max disparity across regimes and methods."""
|
| 342 |
+
methods = [method for method in METHOD_ORDER if method != "weighted"]
|
| 343 |
matrix = np.full((len(methods), len(DGP_SPECS)), np.nan)
|
| 344 |
|
| 345 |
for j, (stem, _) in enumerate(DGP_SPECS):
|
|
|
|
| 383 |
path_effects=[pe.withStroke(linewidth=1.1, foreground=stroke_color)],
|
| 384 |
)
|
| 385 |
|
| 386 |
+
highlight_best_cells(ax, matrix, methods, exclude={"oracle"})
|
| 387 |
|
| 388 |
cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
|
| 389 |
cbar.set_label("Max disparity")
|
|
|
|
| 529 |
|
| 530 |
|
| 531 |
def fig5_real_disparity_heatmap(real_suite: dict[str, dict], output_dir: Path):
|
| 532 |
+
methods = ["global", "partition", "twostage", "fullcp", "jackknife_plus", "oneshot", "trainres"]
|
| 533 |
tasks = [task for _, task in REAL_SPECS if task in real_suite]
|
| 534 |
if not tasks:
|
| 535 |
print("Skipping Fig 5: no real-data results found")
|
|
|
|
| 584 |
path_effects=[pe.withStroke(linewidth=1.1, foreground=stroke_color)],
|
| 585 |
)
|
| 586 |
|
| 587 |
+
highlight_best_cells(ax, matrix, methods)
|
| 588 |
|
| 589 |
cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
|
| 590 |
cbar.set_label("Max disparity")
|