| import json |
| from io import BytesIO |
| from pathlib import Path |
| from typing import Any, Dict, List |
|
|
| import keras |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import tyro |
| from keras import ops |
| from matplotlib.patches import PathPatch |
| from matplotlib.path import Path as pltPath |
| from PIL import Image |
| from skimage import measure |
| from zea import log |
| from zea.utils import save_to_gif |
| from zea.visualize import plot_image_grid |
|
|
| from utils import postprocess |
|
|
|
|
def add_shape_from_mask(ax, mask, **kwargs):
    """Add patches tracing the regions of a mask onto a matplotlib axis.

    The mask is zero-padded by one pixel so that regions touching the image
    border still yield closed contours. It is then traced at the 0.5
    iso-level with ``skimage.measure.find_contours``. Each contour becomes a
    ``matplotlib.patches.PathPatch`` that is added to ``ax``.

    Args:
        ax (matplotlib.axes.Axes): axis to draw on.
        mask (np.ndarray): 2D mask array; the contour level is 0.5, so for a
            binary mask the non-zero pixels form the region of interest.
        **kwargs: forwarded unchanged to ``PathPatch``, e.g. ``color``,
            ``edgecolor``, ``facecolor``, ``linewidth``, ``alpha``.

    Returns:
        list[matplotlib.patches.PathPatch]: the patches that were added, one
        per contour (empty when the mask contains no region).
    """
    padded_mask = np.pad(mask, pad_width=1, mode="constant", constant_values=0)
    patches = []
    for contour in measure.find_contours(padded_mask, 0.5):
        # Undo the one-pixel padding offset. Contours are (row, col) while
        # matplotlib paths expect (x, y) == (col, row), hence the column flip.
        path = pltPath((contour - 1)[:, ::-1])
        patches.append(ax.add_patch(PathPatch(path, **kwargs)))
    return patches
|
|
|
|
def matplotlib_figure_to_numpy(fig):
    """Render a matplotlib figure to an RGB numpy array.

    The figure is rasterized to an in-memory PNG (tight bounding box) and
    decoded with PIL, so the output size follows the figure's size and dpi.

    Args:
        fig (matplotlib.figure.Figure): figure to convert.

    Returns:
        np.ndarray: RGB pixel array of shape ``(height, width, 3)``.
    """
    # Context managers release the buffer and the decoded PIL image even if
    # savefig or decoding raises (the original skipped buf.close() on error).
    with BytesIO() as buf:
        fig.savefig(buf, format="png", bbox_inches="tight")
        buf.seek(0)
        with Image.open(buf) as image:
            # convert("RGB") already drops any alpha channel.
            return np.array(image.convert("RGB"))
|
|
|
|
def _drop_singleton_channel(array):
    """Convert to numpy and drop a trailing size-1 channel axis of a 4D array.

    Unlike ``np.squeeze`` this never removes the batch axis, so a batch of a
    single image keeps its leading dimension.
    """
    array = np.asarray(array)
    if array.ndim == 4 and array.shape[-1] == 1:
        return array[..., 0]
    return array


def _add_omega_colorbar(fig, column_axes):
    """Attach a vertical colorbar to the right of a column of axes.

    The bar spans the central 80% of the column's height and uses the first
    image drawn in the top axis as its mappable. Adds nothing and returns
    ``None`` when that axis holds no image.
    """
    drawn = column_axes[0].get_images()
    if not drawn:
        return None

    top_pos = column_axes[0].get_position()
    bottom_pos = column_axes[-1].get_position()
    full_y0 = bottom_pos.y0
    full_height = top_pos.y1 - full_y0
    height = full_height * 0.8
    y0 = full_y0 + (full_height - height) / 2  # center the bar vertically
    cax = fig.add_axes([top_pos.x1 + 0.015, y0, 0.015, height])

    cbar = fig.colorbar(drawn[0], cax=cax)
    cbar.set_label(r"Guidance weighting $\mathbf{p}$")
    cbar.ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6))
    cbar.ax.yaxis.set_tick_params(labelsize=7)
    cbar.ax.yaxis.label.set_size(8)
    return cbar


def plot_batch_with_named_masks(
    images, masks_dict, mask_colors=None, titles=None, **kwargs
):
    """Plot a batch as a grid: one row per sample, one column per mask.

    Each column shows the input image with that mask's outline filled in
    (via ``add_shape_from_mask``, alpha 0.3). The exception is a mask named
    ``"per_pixel_omega"``: it is displayed directly with the ``inferno``
    colormap (no overlay) and gets a colorbar on its right. Column titles
    are written on the first row.

    Args:
        images: array of shape ``(batch, height, width[, channels])``.
        masks_dict: ``{name: mask}``; each mask has a leading batch axis that
            matches ``images``.
        mask_colors: optional ``{name: color}``. Names not listed fall back to
            a default palette.
        titles: optional column titles; defaults to the mask names.
        **kwargs: forwarded to ``plot_image_grid``. ``figsize`` defaults to
            ``(8, 3.3)`` but may be overridden.

    Returns:
        matplotlib.figure.Figure: the grid figure.
    """
    images = np.asarray(images)
    mask_names = list(masks_dict.keys())
    n_cols = len(mask_names)
    batch_size = images.shape[0]

    # Defaults first, then user overrides, so a partial dict no longer
    # raises KeyError for the unlisted masks.
    default_colors = ["red", "green", "#33aaff", "yellow", "magenta", "cyan"]
    colors = {
        name: default_colors[i % len(default_colors)]
        for i, name in enumerate(mask_names)
    }
    colors.update(mask_colors or {})

    # Column content: the omega map itself, or the plain image (the mask
    # outline is drawn on top after plotting).
    columns = [
        _drop_singleton_channel(masks_dict[name])
        if name == "per_pixel_omega"
        else _drop_singleton_channel(images)
        for name in mask_names
    ]
    # Stacking on axis 1 gives (batch, n_cols, ...); flattening that is the
    # row-major order plot_image_grid expects (row = sample, col = mask).
    all_images = np.stack(columns, axis=1).reshape(
        batch_size * n_cols, *images.shape[1:]
    )
    flat_cmaps = [
        "inferno" if name == "per_pixel_omega" else "gray"
        for _ in range(batch_size)
        for name in mask_names
    ]

    kwargs.setdefault("figsize", (8, 3.3))
    fig, _ = plot_image_grid(
        all_images,
        ncols=n_cols,
        remove_axis=False,
        cmap=flat_cmaps,
        **kwargs,
    )

    # Overlay each (non-omega) mask on every row of its column.
    for col_idx, name in enumerate(mask_names):
        if name == "per_pixel_omega":
            continue
        column_axes = fig.axes[col_idx : batch_size * n_cols : n_cols]
        for ax, mask_img in zip(column_axes, np.asarray(masks_dict[name])):
            add_shape_from_mask(
                ax, mask_img.squeeze(), color=colors[name], alpha=0.3
            )

    # Titles go on the first row only; zip stops at the number of columns.
    column_titles = mask_names if titles is None else titles
    for ax, title in zip(fig.axes[:n_cols], column_titles):
        ax.set_title(title, fontsize=9, color="white")
        ax.set_facecolor("black")

    if "per_pixel_omega" in mask_names:
        col_idx = mask_names.index("per_pixel_omega")
        _add_omega_colorbar(fig, fig.axes[col_idx : batch_size * n_cols : n_cols])

    return fig
|
|
|
|
def plot_dehazed_results(
    hazy_images,
    pred_tissue_images,
    pred_haze_images,
    diffusion_model,
    titles=("Hazy", "Dehazed", "Haze"),
):
    """Plot hazy inputs, dehazed (tissue) predictions and haze predictions.

    The three batches are stacked as rows of an image grid, one column per
    sample, displayed with ``vmin=0, vmax=255``. The first axis of each row
    gets a y-label from ``titles``. The figure is returned, not saved.

    Args:
        hazy_images: batch of hazy input images.
        pred_tissue_images: predicted tissue images, same shape as
            ``hazy_images``.
        pred_haze_images: predicted haze images, same shape as ``hazy_images``.
        diffusion_model: model whose ``input_shape[0]`` and ``input_shape[1]``
            give the image height and width used to flatten the stack.
        titles: row labels, reused cyclically if there are more rows than
            titles.

    Returns:
        matplotlib.figure.Figure: the grid figure.
    """
    input_shape = diffusion_model.input_shape
    height, width = input_shape[0], input_shape[1]
    n_cols = len(hazy_images)

    stacked = ops.stack([hazy_images, pred_tissue_images, pred_haze_images])
    stacked = ops.reshape(stacked, (-1, height, width))

    fig, _ = plot_image_grid(
        stacked,
        ncols=n_cols,
        remove_axis=False,
        vmin=0,
        vmax=255,
    )

    # Axis i sits in row i // n_cols, so every n_cols-th axis starts a row.
    for row, ax in enumerate(fig.axes[::n_cols]):
        ax.set_ylabel(titles[row % len(titles)], fontsize=12)

    return fig
|
|
|
|
def plot_metrics(metrics, limits, out_path):
    """Plot one histogram per metric, side by side, with optional limit lines.

    Args:
        metrics (dict): ``{metric_name: values}``. Each ``values`` is
            histogrammed with 30 bins. The names ``CNR``, ``gCNR``, ``KS_A``
            and ``KS_B`` get LaTeX axis labels; others use their raw name.
        limits (dict): ``{metric_name: sequence}``. A dashed "min score" line
            is drawn at the first entry and a dotted "max score" line at the
            second, when present. Metrics not in ``limits`` get no lines.
        out_path: currently unused; the figure is returned, not saved. Kept
            for interface compatibility.

    Returns:
        matplotlib.figure.Figure: the figure (with a shared legend when any
        limit lines were drawn).

    Note:
        Calls ``plt.style.use(...)``, which changes the global matplotlib style.
    """
    import matplotlib.lines as mlines

    plt.style.use("seaborn-v0_8-darkgrid")
    # squeeze=False keeps ``axes`` a 2D array even for a single metric; with
    # the default squeeze a lone Axes is returned and zip() over it fails.
    fig, axes = plt.subplots(
        1, len(metrics), figsize=(7.2, 2.7), dpi=200, squeeze=False
    )
    colors = ["#0057b7", "#ffb300", "#008744", "#d62d20"]
    metric_labels = {
        "CNR": r"CNR $\uparrow$",
        "gCNR": r"gCNR $\uparrow$",
        "KS_A": r"KS$_{septum}$ $\downarrow$",
        "KS_B": r"KS$_{ventricle}$ $\uparrow$",
    }
    min_style = {
        "color": "crimson",
        "linestyle": "--",
        "lw": 1.2,
        "marker": "o",
        "markersize": 5,
    }
    max_style = {
        "color": "crimson",
        "linestyle": ":",
        "lw": 1.2,
        "marker": "s",
        "markersize": 5,
    }

    legend_handles = []
    for idx, (ax, (name, values)) in enumerate(zip(axes[0], metrics.items())):
        ax.hist(
            values,
            bins=30,
            color=colors[idx % len(colors)],
            alpha=0.85,
            edgecolor="black",
            linewidth=0.7,
        )
        ax.set_xlabel(metric_labels.get(name, name), fontsize=11)
        if idx == 0:
            ax.set_ylabel("Count", fontsize=10)

        if name in limits:
            lims = limits[name]
            # The figure-level legend is built once, from proxy artists.
            if not legend_handles:
                legend_handles = [
                    mlines.Line2D([], [], **min_style, label="min score"),
                    mlines.Line2D([], [], **max_style, label="max score"),
                ]
            if len(lims) > 0:
                ax.axvline(lims[0], **min_style)
            if len(lims) > 1:
                ax.axvline(lims[1], **max_style)

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.tick_params(axis="both", which="major", labelsize=9)

    # Skip the legend entirely when no metric had limits (it would be empty).
    if legend_handles:
        fig.legend(
            handles=legend_handles,
            loc="upper center",
            ncol=2,
            fontsize=10,
            frameon=False,
            bbox_to_anchor=(0.5, 1.02),
        )
    return fig
|
|
|
|
def plot_optimization_history_from_json(
    trials_data: List[Dict[str, Any]], output_path: Path, method: str
):
    """Plot each trial's objective value and the running best, then save it.

    Only trials with ``state == "COMPLETE"`` and a non-``None`` ``value`` are
    plotted; records missing either key are skipped rather than raising. The
    running-best curve assumes the objective is maximized.

    Args:
        trials_data: trial records with keys ``"number"``, ``"state"`` and
            ``"value"``.
        output_path: directory to write into (``str`` or ``Path``); it is
            expected to exist already.
        method: suffix used in the output file names.

    Returns:
        None. Writes ``optimization_history_{method}.png`` and ``.pdf`` into
        ``output_path`` (600 dpi) and applies a global matplotlib style. If no
        trial is complete, prints a message and returns without plotting.
    """
    from itertools import accumulate

    completed_trials = sorted(
        (
            trial
            for trial in trials_data
            if trial.get("state") == "COMPLETE" and trial.get("value") is not None
        ),
        key=lambda trial: trial["number"],
    )
    if not completed_trials:
        print("No completed trials found!")
        return

    output_path = Path(output_path)
    trial_numbers = [trial["number"] for trial in completed_trials]
    values = [trial["value"] for trial in completed_trials]
    # Running maximum: max(best, v) keeps ``best`` unless ``v`` is strictly
    # larger, exactly like a manual ``if v > best`` loop.
    best_values = list(accumulate(values, max))

    plt.style.use("seaborn-v0_8-darkgrid")
    fig, ax = plt.subplots(figsize=(5, 3), dpi=600)

    ax.scatter(
        trial_numbers,
        values,
        c="#0057b7",
        alpha=0.6,
        s=30,
        edgecolor="black",
        linewidth=0.5,
    )
    ax.plot(
        trial_numbers,
        best_values,
        color="#d62d20",
        linewidth=2.5,
        label="Best Value",
        marker="o",
        markersize=4,
        # Roughly 20 markers regardless of the number of trials.
        markevery=max(1, len(trial_numbers) // 20),
    )

    ax.set_xlabel("Trial", fontsize=11)
    ax.set_ylabel("Objective Value", fontsize=11)
    ax.legend(fontsize=10, frameon=False)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.tick_params(axis="both", which="major", labelsize=9)

    for suffix in ("png", "pdf"):
        fig.savefig(
            output_path / f"optimization_history_{method}.{suffix}",
            dpi=600,
            bbox_inches="tight",
        )
    plt.close(fig)
|
|
|
|
def create_animation_frame(
    hazy_images, tissue_frame, haze_frame, labels=("Hazy", "Tissue", "Haze")
):
    """Render one animation frame (hazy / tissue / haze rows) to an array.

    The three batches are stacked as rows of an image grid, one column per
    sample, displayed with ``vmin=0, vmax=255``. The first axis of each row
    is labeled.

    Args:
        hazy_images: batch of shape ``(batch, height, width)``.
        tissue_frame: tissue estimate, same shape as ``hazy_images``.
        haze_frame: haze estimate, same shape as ``hazy_images``.
        labels: row labels in stacking order (hazy, tissue, haze); reused
            cyclically if shorter than the number of rows.

    Returns:
        np.ndarray: the rendered frame from ``matplotlib_figure_to_numpy``.
    """
    _, height, width = ops.shape(hazy_images)
    n_cols = len(hazy_images)
    frame_stack = ops.reshape(
        ops.stack([hazy_images, tissue_frame, haze_frame]), (-1, height, width)
    )
    fig_frame, _ = plot_image_grid(
        frame_stack,
        ncols=n_cols,
        remove_axis=False,
        vmin=0,
        vmax=255,
    )
    try:
        # Label only the first axis of each row. Axis i sits in row
        # i // n_cols, and the label order matches the stack order above.
        for row, ax in enumerate(fig_frame.axes[::n_cols]):
            ax.set_ylabel(labels[row % len(labels)], fontsize=12)
        return matplotlib_figure_to_numpy(fig_frame)
    finally:
        # Always release the pyplot figure, even if rendering fails.
        plt.close(fig_frame)
|
|
|
|
def create_animation(hazy_images, diffusion_model, output_path, fps):
    """Save the tracked diffusion progress as an animated GIF.

    For every tracked tissue estimate in ``diffusion_model.track_progress``,
    a haze estimate is derived as ``hazy_images - tissue_frame - 1``. Hazy,
    tissue and haze images are postprocessed with
    ``diffusion_model.input_range`` and rendered into one frame.

    Args:
        hazy_images: batch of hazy inputs in the model's input range.
        diffusion_model: object exposing ``track_progress`` (sequence of
            intermediate tissue estimates) and ``input_range``.
        output_path: target path; its suffix is replaced by ``.gif`` and its
            parent directory is created if missing.
        fps: frames per second of the GIF.

    Returns:
        None. Logs a warning and returns without writing anything when
        fewer than two frames were tracked.
    """
    track_progress = diffusion_model.track_progress
    if len(track_progress) <= 1:
        log.warning(
            "Animation requested but no intermediate frames were tracked. "
            "Try reducing diffusion_steps or ensure progress tracking is enabled."
        )
        return

    log.info(f"Creating animation with {len(track_progress)} frames...")

    input_range = diffusion_model.input_range
    # The hazy input is identical for every frame, so postprocess it once
    # instead of once per frame (assumes postprocess is a pure function).
    display_hazy = postprocess(hazy_images, input_range)

    animation_frames = []
    progbar = keras.utils.Progbar(len(track_progress), unit_name="frame")
    for tissue_frame in track_progress:
        haze_frame = hazy_images - tissue_frame - 1
        animation_frames.append(
            create_animation_frame(
                display_hazy,
                postprocess(tissue_frame, input_range),
                postprocess(haze_frame, input_range),
            )
        )
        progbar.add(1)

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_to_gif(animation_frames, output_path.with_suffix(".gif"), fps=fps)
|
|
|
|
def main(json_file: str, output_dir: str = "plots", method: str = "semantic_dps"):
    """Load optimization trial records from JSON and plot their history.

    Args:
        json_file: path to a JSON file holding a list of trial records.
        output_dir: directory for the generated plots; created if missing.
        method: name used as a suffix in the output file names.

    Raises:
        FileNotFoundError: if ``json_file`` does not exist.
    """
    json_path = Path(json_file)
    if not json_path.exists():
        raise FileNotFoundError(f"JSON file not found: {json_file}")

    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which is what a bare open() would use.
    trials_data = json.loads(json_path.read_text(encoding="utf-8"))
    print(f"Loaded {len(trials_data)} trials from {json_file}")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print("Generating optimization history plot...")
    plot_optimization_history_from_json(trials_data, output_path, method)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: tyro builds a command-line interface from main()'s
    # signature and type hints, then calls main() with the parsed arguments.
    tyro.cli(main)
|
|