anonymous0523ly commited on
Commit
0637094
·
verified ·
1 Parent(s): fc329a3

Remove weighted from main heatmap figures

Browse files
Files changed (1) hide show
  1. scripts/make_figures.py +4 -4
scripts/make_figures.py CHANGED
@@ -339,7 +339,7 @@ def fig1_allocation_geometry(suite: dict[str, dict], output_dir: Path):
339
 
340
  def fig1_disparity_heatmap(suite: dict[str, dict], output_dir: Path):
341
  """Heatmap of max disparity across regimes and methods."""
342
- methods = METHOD_ORDER
343
  matrix = np.full((len(methods), len(DGP_SPECS)), np.nan)
344
 
345
  for j, (stem, _) in enumerate(DGP_SPECS):
@@ -383,7 +383,7 @@ def fig1_disparity_heatmap(suite: dict[str, dict], output_dir: Path):
383
  path_effects=[pe.withStroke(linewidth=1.1, foreground=stroke_color)],
384
  )
385
 
386
- highlight_best_cells(ax, matrix, methods, exclude={"oracle", "weighted"})
387
 
388
  cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
389
  cbar.set_label("Max disparity")
@@ -529,7 +529,7 @@ def fig4_runtime_tradeoff(suite: dict[str, dict], output_dir: Path):
529
 
530
 
531
  def fig5_real_disparity_heatmap(real_suite: dict[str, dict], output_dir: Path):
532
- methods = ["global", "partition", "twostage", "fullcp", "jackknife_plus", "oneshot", "trainres", "weighted"]
533
  tasks = [task for _, task in REAL_SPECS if task in real_suite]
534
  if not tasks:
535
  print("Skipping Fig 5: no real-data results found")
@@ -584,7 +584,7 @@ def fig5_real_disparity_heatmap(real_suite: dict[str, dict], output_dir: Path):
584
  path_effects=[pe.withStroke(linewidth=1.1, foreground=stroke_color)],
585
  )
586
 
587
- highlight_best_cells(ax, matrix, methods, exclude={"weighted"})
588
 
589
  cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
590
  cbar.set_label("Max disparity")
 
339
 
340
  def fig1_disparity_heatmap(suite: dict[str, dict], output_dir: Path):
341
  """Heatmap of max disparity across regimes and methods."""
342
+ methods = [method for method in METHOD_ORDER if method != "weighted"]
343
  matrix = np.full((len(methods), len(DGP_SPECS)), np.nan)
344
 
345
  for j, (stem, _) in enumerate(DGP_SPECS):
 
383
  path_effects=[pe.withStroke(linewidth=1.1, foreground=stroke_color)],
384
  )
385
 
386
+ highlight_best_cells(ax, matrix, methods, exclude={"oracle"})
387
 
388
  cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
389
  cbar.set_label("Max disparity")
 
529
 
530
 
531
  def fig5_real_disparity_heatmap(real_suite: dict[str, dict], output_dir: Path):
532
+ methods = ["global", "partition", "twostage", "fullcp", "jackknife_plus", "oneshot", "trainres"]
533
  tasks = [task for _, task in REAL_SPECS if task in real_suite]
534
  if not tasks:
535
  print("Skipping Fig 5: no real-data results found")
 
584
  path_effects=[pe.withStroke(linewidth=1.1, foreground=stroke_color)],
585
  )
586
 
587
+ highlight_best_cells(ax, matrix, methods)
588
 
589
  cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
590
  cbar.set_label("Max disparity")