# artflow/evaluation/visualize.py
# (originally uploaded via huggingface_hub; header residue converted to a comment
# so the module parses)
"""
Visualization utilities for evaluation.
Functions:
- make_image_grid: Create and optionally save a grid of images
- visualize_denoising: Visualize the denoising process
- format_prompt_caption: Format prompts for display in image captions
"""
import os
from typing import List, Optional
import numpy as np
import torch
import torchvision
def make_image_grid(
    images: torch.Tensor,
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    save_path: Optional[str] = None,
    normalize: bool = True,
    value_range: Optional[tuple] = None,
) -> torch.Tensor:
    """
    Create a grid of images and optionally save it.

    Args:
        images: Tensor of shape [B, C, H, W]
        rows: Number of rows (used only when ``cols`` is not given)
        cols: Number of columns (takes precedence over ``rows``)
        save_path: Path to save the grid image; parent directories are
            created as needed
        normalize: Whether to normalize images to [0, 1]
        value_range: Range of values in input images (min, max)

    Returns:
        Grid tensor of shape [C, H', W']
    """
    if rows is None and cols is None:
        # Default to a roughly square layout.
        nrow = int(np.ceil(np.sqrt(images.shape[0])))
    elif cols is not None:
        nrow = cols
    else:
        # Derive the column count from the requested number of rows.
        nrow = int(np.ceil(images.shape[0] / rows))
    grid = torchvision.utils.make_grid(
        images, nrow=nrow, normalize=normalize, value_range=value_range, padding=2
    )
    if save_path:
        # Only create directories when the path actually has a directory
        # component: os.makedirs("") raises FileNotFoundError for bare
        # filenames such as "grid.png".
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torchvision.utils.save_image(grid, save_path)
    return grid
def visualize_denoising(
    intermediate_steps: List[torch.Tensor], save_path: str, num_steps_to_show: int = 10
):
    """
    Render a horizontal strip showing how the first batch sample denoises.

    Args:
        intermediate_steps: List of tensors [B, C, H, W], one per sampling step
        save_path: Destination file for the rendered visualization
        num_steps_to_show: How many evenly spaced steps to display
    """
    step_count = len(intermediate_steps)
    if step_count >= num_steps_to_show:
        # Pick evenly spaced step indices, always including first and last.
        chosen = np.linspace(0, step_count - 1, num_steps_to_show, dtype=int).tolist()
    else:
        chosen = list(range(step_count))
    # Keep only the first batch element of each chosen step ([C, H, W] each)
    # and stack into [num_chosen, C, H, W].
    frames = torch.stack([intermediate_steps[idx][0] for idx in chosen])
    # Lay the frames out left-to-right in a single row.
    make_image_grid(
        frames,
        rows=1,
        cols=len(chosen),
        save_path=save_path,
        normalize=True,
        value_range=(-1, 1),
    )
def format_prompt_caption(prompts: List[str], limit: int = 32) -> str:
    """
    Format a list of prompts for display as an image caption.

    Args:
        prompts: List of prompt strings
        limit: Maximum number of prompts to include

    Returns:
        Formatted caption string; prompts beyond ``limit`` are summarized
        as a trailing "... (+N more)" line.
    """
    if not prompts:
        return ""
    lines = []
    for position, prompt in enumerate(prompts[:limit], start=1):
        # Flatten newlines so each prompt occupies a single caption line.
        single_line = prompt.replace("\n", " ").strip()
        lines.append(f"{position}. {single_line}")
    overflow = len(prompts) - len(lines)
    if overflow > 0:
        lines.append(f"... (+{overflow} more)")
    return "\n\n".join(lines)