"""
Visualization utilities for evaluation.

Functions:
- make_image_grid: Create and optionally save a grid of images
- visualize_denoising: Visualize the denoising process
- format_prompt_caption: Format prompts for display in image captions
"""
| |
|
| | import os |
| | from typing import List, Optional |
| |
|
| | import numpy as np |
| | import torch |
| | import torchvision |
| |
|
| |
|
def make_image_grid(
    images: torch.Tensor,
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    save_path: Optional[str] = None,
    normalize: bool = True,
    value_range: Optional[tuple] = None,
) -> torch.Tensor:
    """
    Create a grid of images and optionally save it.

    Args:
        images: Tensor of shape [B, C, H, W]
        rows: Number of rows (optional; ignored when ``cols`` is given)
        cols: Number of columns (optional; takes precedence over ``rows``)
        save_path: Path to save the grid image; parent directories are
            created as needed
        normalize: Whether to normalize images to [0, 1]
        value_range: Range of values in input images (min, max); passed
            through to torchvision's normalization

    Returns:
        Grid tensor produced by ``torchvision.utils.make_grid``
    """
    # torchvision's make_grid only takes images-per-row (nrow), so derive
    # it from whichever of rows/cols the caller supplied.
    if rows is None and cols is None:
        # Default: roughly square layout.
        nrow = int(np.ceil(np.sqrt(images.shape[0])))
    elif cols is not None:
        nrow = cols
    else:
        # rows given: columns = ceil(batch / rows).
        nrow = int(np.ceil(images.shape[0] / rows))

    grid = torchvision.utils.make_grid(
        images, nrow=nrow, normalize=normalize, value_range=value_range, padding=2
    )

    if save_path:
        # Bug fix: os.path.dirname returns "" for a bare filename
        # (e.g. "grid.png"), and os.makedirs("") raises FileNotFoundError.
        # Only create the parent directory when there is one.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torchvision.utils.save_image(grid, save_path)

    return grid
| |
|
| |
|
def visualize_denoising(
    intermediate_steps: List[torch.Tensor], save_path: str, num_steps_to_show: int = 10
):
    """
    Visualize the denoising process by selecting a subset of steps.

    Picks up to ``num_steps_to_show`` evenly spaced steps, takes the first
    sample from each selected batch, and writes them to ``save_path`` as a
    single-row image grid (values assumed in [-1, 1]).

    Args:
        intermediate_steps: List of tensors [B, C, H, W] from the sampling process
        save_path: Path to save the visualization
        num_steps_to_show: Number of steps to display
    """
    available = len(intermediate_steps)

    # Evenly spaced step indices; if there are fewer steps than requested,
    # simply show them all.
    if available < num_steps_to_show:
        chosen = list(range(available))
    else:
        chosen = np.linspace(0, available - 1, num_steps_to_show, dtype=int).tolist()

    # First sample of each chosen step, stacked into a [K, C, H, W] batch.
    frames = torch.stack([intermediate_steps[idx][0] for idx in chosen])

    make_image_grid(
        frames,
        rows=1,
        cols=len(chosen),
        save_path=save_path,
        normalize=True,
        value_range=(-1, 1),
    )
| |
|
| |
|
def format_prompt_caption(prompts: List[str], limit: int = 32) -> str:
    """
    Format a list of prompts for display as an image caption.

    Each prompt is flattened to a single line, numbered starting at 1, and
    the entries are joined with blank lines. Prompts beyond ``limit`` are
    summarized as a trailing "... (+N more)" line.

    Args:
        prompts: List of prompt strings
        limit: Maximum number of prompts to include

    Returns:
        Formatted caption string ("" for an empty prompt list)
    """
    if not prompts:
        return ""

    shown = prompts[:limit]
    numbered = []
    for position, raw in enumerate(shown, start=1):
        # Newlines would break the one-entry-per-line layout; flatten them.
        cleaned = raw.replace("\n", " ").strip()
        numbered.append(f"{position}. {cleaned}")

    hidden = len(prompts) - len(shown)
    if hidden > 0:
        numbered.append(f"... (+{hidden} more)")

    return "\n\n".join(numbered)
| |
|
| |
|