| import logging |
| import shutil |
| from pathlib import Path |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
| |
|
|
|
|
| def download_and_extract_benchmark(name: str, url: Path, output: Path) -> None: |
| benchmark_dir = output / name |
| if not output.exists(): |
| output.mkdir(parents=True) |
|
|
| if benchmark_dir.exists(): |
| logger.info(f"Benchmark {name} already exists at {benchmark_dir}, skipping download.") |
| return |
|
|
| if name == "stanford2d3d": |
| |
| txt = "\n" + "#" * 108 + "\n\n" |
| txt += "To download the Stanford2D3D dataset, you must agree to the terms of use:\n\n" |
| txt += ( |
| "https://docs.google.com/forms/d/e/" |
| + "1FAIpQLScFR0U8WEUtb7tgjOhhnl31OrkEs73-Y8bQwPeXgebqVKNMpQ/viewform?c=0&w=1\n\n" |
| ) |
| txt += "#" * 108 + "\n\n" |
| txt += "Did you fill out the data sharing and usage terms? [y/n] " |
| choice = input(txt) |
| if choice.lower() != "y": |
| raise ValueError( |
| "You must agree to the terms of use to download the Stanford2D3D dataset." |
| ) |
|
|
| zip_file = output / f"{name}.zip" |
|
|
| if not zip_file.exists(): |
| logger.info(f"Downloading benchmark {name} to {zip_file} from {url}.") |
| torch.hub.download_url_to_file(url, zip_file) |
|
|
| logger.info(f"Extracting benchmark {name} in {output}.") |
| shutil.unpack_archive(zip_file, output, format="zip") |
| zip_file.unlink() |
|
|
|
|
| def check_keys_recursive(d, pattern): |
| if isinstance(pattern, dict): |
| {check_keys_recursive(d[k], v) for k, v in pattern.items()} |
| else: |
| for k in pattern: |
| assert k in d.keys() |
|
|
|
|
| def plot_scatter_grid( |
| results, x_keys, y_keys, name=None, diag=False, ax=None, line_idx=0, show_means=True |
| ): |
| if ax is None: |
| N, M = len(y_keys), len(x_keys) |
| fig, ax = plt.subplots(N, M, figsize=(M * 6, N * 5)) |
|
|
| if N == 1: |
| ax = np.array(ax) |
| ax = ax.reshape(1, -1) |
|
|
| if M == 1: |
| ax = np.array(ax) |
| ax = ax.reshape(-1, 1) |
| else: |
| fig = None |
|
|
| for j, kx in enumerate(x_keys): |
| for i, ky in enumerate(y_keys): |
| ax[i, j].scatter( |
| results[kx], |
| results[ky], |
| s=1, |
| alpha=0.5, |
| label=name or None, |
| ) |
|
|
| ax[i, j].set_xlabel(f"{' '.join(kx.split('_')).title()}") |
| ax[i, j].set_ylabel(f"{' '.join(ky.split('_')).title()}") |
|
|
| low = min(ax[i, j].get_xlim()[0], ax[i, j].get_ylim()[0]) |
| high = max(ax[i, j].get_xlim()[1], ax[i, j].get_ylim()[1]) |
| if diag == "all" or (i == j and diag): |
| ax[i, j].plot([low, high], [low, high], ls="--", c="red", label="y=x") |
|
|
| if name or diag == "all" or (i == j and diag): |
| ax[i, j].legend() |
|
|
| if not show_means: |
| return fig, ax |
|
|
| means = {"y": {}, "x": {}} |
| for kx in x_keys: |
| for ky in y_keys: |
| means["x"][kx] = np.mean(results[kx]) |
| means["y"][ky] = np.mean(results[ky]) |
|
|
| for j, kx in enumerate(x_keys): |
| for i, ky in enumerate(y_keys): |
| xlim = np.min(results[kx]), np.max(results[kx]) |
| ylim = np.min(results[ky]), np.max(results[ky]) |
| means_x = [means["x"][kx]] |
| means_y = [means["y"][ky]] |
| color = plt.cm.tab10(line_idx) |
| ax[i, j].vlines(means_x, *ylim, colors=[color]) |
| ax[i, j].hlines(means_y, *xlim, colors=[color]) |
|
|
| return fig, ax |
|
|