"""Aggregate per-run results and plot image grids and metric curves.

Every directory in the current working directory whose name starts with
``run`` is expected to contain an ``all_results.npy`` file. For each such
run the script saves a grid of rendered images and collects the mean and
standard error of every metric across seeds; the runs named in ``labels``
are then plotted, one figure per metric and dataset.
"""

import json
import os
import os.path as osp

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np


def plot_image_grid(image_dirs, datasets, folder):
    """Save a 5x4 grid of images sampled from each dataset's image directory.

    The 20 grid slots are split as evenly as possible between the datasets
    (the first datasets receive the leftover slots). For each dataset the
    image indices are evenly spaced integers between 0 and 199, and the file
    names are those indices zero-padded to three digits plus ``.png``. Files
    that do not exist are skipped, and unused grid cells are hidden.

    Args:
        image_dirs: One image directory per dataset, in the order of
            ``datasets``.
        datasets: Dataset names; only the count is used here.
        folder: Run name, used in the output file name
            ``images_all_{folder}.png`` (written to the working directory).

    Raises:
        ValueError: If ``datasets`` and ``image_dirs`` differ in length.
        ZeroDivisionError: If ``datasets`` is empty.
    """
    num_datasets = len(datasets)
    if num_datasets != len(image_dirs):
        raise ValueError("Number of datasets must match number of image directories")

    total_images = 20
    base_images_per_dataset, remainder = divmod(total_images, num_datasets)

    # Distribute remainder images: the first `remainder` datasets get one extra.
    images_per_dataset = [
        base_images_per_dataset + (1 if i < remainder else 0)
        for i in range(num_datasets)
    ]

    # Create a 5x4 grid of subplots and flatten it for linear indexing.
    _, axs = plt.subplots(5, 4, figsize=(15, 12))
    axs = axs.ravel()

    current_idx = 0
    for image_dir, num_images in zip(image_dirs, images_per_dataset):
        nums = np.linspace(0, 199, num=num_images, dtype=int)
        img_names = [str(num).zfill(3) + ".png" for num in nums]

        # Plot images for this dataset; a slot is only consumed when the
        # file exists, so missing images leave hidden cells at the end.
        for img_name in img_names:
            img_path = osp.join(image_dir, img_name)
            if osp.exists(img_path):
                img = plt.imread(img_path)
                axs[current_idx].imshow(img)
                axs[current_idx].axis("off")
                current_idx += 1

    # Turn off any remaining empty subplots
    for ax in axs[current_idx:total_images]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(f"images_all_{folder}.png")
    plt.close()


datasets = ["chair", "drums"]
folders = os.listdir("./")
final_results = {}
results_info = {}
metrics = [
    "train/loss",
    "train/PSNR",
    "train/mse",
    "test/PSNR",
    "iters",
    "train/reg",
    "train/reg_l1",
    "train/reg_tv_density",
    "train/reg_tv_app",
]

# Load results and compute metrics
for folder in folders:
    if folder.startswith("run") and osp.isdir(folder):
        # NOTE(security): allow_pickle=True unpickles arbitrary objects and
        # can execute code; only load result files produced by trusted runs.
        results_dict = np.load(
            osp.join(folder, "all_results.npy"), allow_pickle=True
        ).item()
        run_info = {}

        image_dirs = [results_dict[dataset][0]["imgs"] for dataset in datasets]
        plot_image_grid(image_dirs, datasets, folder)

        for dataset in datasets:
            dset_curr = results_dict[dataset]
            # The training iteration axis is taken from entry 0.
            iters = dset_curr[0]["iters"]
            run_info[dataset] = {}
            for metric in metrics:
                # One series per entry of the dataset's results.
                losses = [dset_curr[int(i)][metric] for i in dset_curr.keys()]

                # A metric that was never logged (or a dataset without any
                # entries) is reported as all zeros on the training axis.
                if not losses or len(losses[0]) == 0:
                    run_info[dataset][metric] = {
                        "iters": iters,
                        "mean": [0] * len(iters),
                        "stderr": [0] * len(iters),
                    }
                    continue

                losses = np.array(losses)
                mean_losses = np.mean(losses, axis=0)
                sterr_losses = np.std(losses, axis=0) / np.sqrt(len(losses))

                # Test metrics are plotted against their own 0..n-1 index
                # rather than the training iteration axis.
                if metric.startswith("test"):
                    metric_iters = list(range(len(losses[0])))
                else:
                    metric_iters = iters
                run_info[dataset][metric] = {
                    "iters": metric_iters,
                    "mean": mean_losses,
                    "stderr": sterr_losses,
                }
        results_info[folder] = run_info


def generate_color_palette(n):
    """Return ``n`` hex colour strings sampled evenly from the ``tab20`` colormap."""
    cmap = plt.get_cmap('tab20')
    return [mcolors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, n)]


labels = {
    "run_0": "Baseline",
}

runs = list(labels.keys())
colors = generate_color_palette(len(runs))


# Function to plot metrics
def plot_metric(metric_name, datasets, results_info, runs, colors, labels):
    """Plot one metric for every run, saving one figure per dataset.

    Each run is drawn as its mean curve with a shaded band of one standard
    error on either side. Figures are written to
    ``{metric_name with '/' replaced by '_'}_{dataset}.png``.

    Args:
        metric_name: Key of the metric inside each run's per-dataset results.
        datasets: Dataset names to plot.
        results_info: Mapping ``run -> dataset -> metric -> {"iters", "mean",
            "stderr"}``.
        runs: Run names to draw, in the order matching ``colors``.
        colors: One colour per entry of ``runs``.
        labels: Mapping from run name to its legend label.
    """
    for dataset in datasets:
        plt.figure(figsize=(10, 6))
        plotted_any = False
        for i, run in enumerate(runs):
            # A labelled run may not have been loaded (missing folder or
            # dataset); skip it instead of failing with a KeyError.
            run_results = results_info.get(run)
            if run_results is None or dataset not in run_results:
                print(f"Skipping {run}: no results for {dataset} ({metric_name})")
                continue

            metric_info = run_results[dataset].get(metric_name, {})
            iters = metric_info.get("iters", [])
            mean = np.array(metric_info.get("mean", []))
            stderr = np.array(metric_info.get("stderr", []))

            # len() rather than truthiness: `iters` may be a NumPy array,
            # whose truth value is ambiguous and raises ValueError.
            if len(iters) > 0:
                plt.plot(iters, mean, label=f"{labels[run]} ({metric_name})", color=colors[i])
                plt.fill_between(iters, mean - stderr, mean + stderr, color=colors[i], alpha=0.2)
                plotted_any = True

        plt.title(f"{metric_name.capitalize()} Across Runs for {dataset} Dataset")
        plt.xlabel("Iteration")
        plt.ylabel(metric_name.capitalize())
        # An empty legend only produces a matplotlib warning, so skip it.
        if plotted_any:
            plt.legend()
        plt.grid(True, which="both", ls="-", alpha=0.2)
        plt.tight_layout()
        plt.savefig(f"{metric_name.replace('/', '_')}_{dataset}.png")
        plt.close()


# Plotting metrics for all datasets
metrics = [
    "train/PSNR",
    "train/mse",
    "test/PSNR",
    "train/reg",
    "train/reg_l1",
    "train/reg_tv_density",
    "train/reg_tv_app",
]

for metric in metrics:
    plot_metric(metric, datasets, results_info, runs, colors, labels)