anonymous0523ly commited on
Commit
855fe54
·
verified ·
1 Parent(s): 0637094

Use data-driven heatmap color ranges

Browse files
Files changed (1) hide show
  1. scripts/make_figures.py +6 -4
scripts/make_figures.py CHANGED
@@ -353,7 +353,7 @@ def fig1_disparity_heatmap(suite: dict[str, dict], output_dir: Path):
353
  fig, ax = plt.subplots(figsize=(7.0, 3.6), constrained_layout=True)
354
  cmap = plt.cm.RdYlBu_r.copy()
355
  cmap.set_bad("#efefef")
356
- im = ax.imshow(np.ma.masked_invalid(matrix), aspect="auto", cmap=cmap, vmin=0.0, vmax=0.85)
357
 
358
  ax.set_xticks(range(len(DGP_SPECS)))
359
  ax.set_xticklabels([label for _, label in DGP_SPECS])
@@ -386,7 +386,8 @@ def fig1_disparity_heatmap(suite: dict[str, dict], output_dir: Path):
386
  highlight_best_cells(ax, matrix, methods, exclude={"oracle"})
387
 
388
  cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
389
- cbar.set_label("Max disparity")
 
390
 
391
  save_figure(fig, output_dir, "fig1_synthetic_disparity_heatmap.pdf")
392
  plt.close(fig)
@@ -554,7 +555,7 @@ def fig5_real_disparity_heatmap(real_suite: dict[str, dict], output_dir: Path):
554
  fig, ax = plt.subplots(figsize=(7.0, 4.0), constrained_layout=True)
555
  cmap = plt.cm.RdYlBu_r.copy()
556
  cmap.set_bad("#efefef")
557
- im = ax.imshow(np.ma.masked_invalid(matrix), aspect="auto", cmap=cmap, vmin=0.0, vmax=0.9)
558
 
559
  ax.set_xticks(range(len(tasks)))
560
  ax.set_xticklabels([task_labels.get(task, task) for task in tasks])
@@ -587,7 +588,8 @@ def fig5_real_disparity_heatmap(real_suite: dict[str, dict], output_dir: Path):
587
  highlight_best_cells(ax, matrix, methods)
588
 
589
  cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
590
- cbar.set_label("Max disparity")
 
591
 
592
  save_figure(fig, output_dir, "fig5_real_disparity_heatmap.pdf")
593
  plt.close(fig)
 
353
  fig, ax = plt.subplots(figsize=(7.0, 3.6), constrained_layout=True)
354
  cmap = plt.cm.RdYlBu_r.copy()
355
  cmap.set_bad("#efefef")
356
+ im = ax.imshow(np.ma.masked_invalid(matrix), aspect="auto", cmap=cmap)
357
 
358
  ax.set_xticks(range(len(DGP_SPECS)))
359
  ax.set_xticklabels([label for _, label in DGP_SPECS])
 
386
  highlight_best_cells(ax, matrix, methods, exclude={"oracle"})
387
 
388
  cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
389
+ cbar.set_label("Max disparity (low → high)")
390
+ cbar.set_ticks([])
391
 
392
  save_figure(fig, output_dir, "fig1_synthetic_disparity_heatmap.pdf")
393
  plt.close(fig)
 
555
  fig, ax = plt.subplots(figsize=(7.0, 4.0), constrained_layout=True)
556
  cmap = plt.cm.RdYlBu_r.copy()
557
  cmap.set_bad("#efefef")
558
+ im = ax.imshow(np.ma.masked_invalid(matrix), aspect="auto", cmap=cmap)
559
 
560
  ax.set_xticks(range(len(tasks)))
561
  ax.set_xticklabels([task_labels.get(task, task) for task in tasks])
 
588
  highlight_best_cells(ax, matrix, methods)
589
 
590
  cbar = fig.colorbar(im, ax=ax, shrink=0.95, pad=0.02)
591
+ cbar.set_label("Max disparity (low → high)")
592
+ cbar.set_ticks([])
593
 
594
  save_figure(fig, output_dir, "fig5_real_disparity_heatmap.pdf")
595
  plt.close(fig)