File size: 3,239 Bytes
eb52c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Visualization utilities for evaluation.

Functions:
- make_image_grid: Create and optionally save a grid of images
- visualize_denoising: Visualize the denoising process
- format_prompt_caption: Format prompts for display in image captions
"""

import os
from typing import List, Optional

import numpy as np
import torch
import torchvision


def make_image_grid(
    images: torch.Tensor,
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    save_path: Optional[str] = None,
    normalize: bool = True,
    value_range: Optional[tuple] = None,
) -> torch.Tensor:
    """
    Create a grid of images and optionally save it.

    The grid layout is resolved as follows:
    - ``cols`` given: use it as the number of images per row (``rows`` is
      ignored in that case, even if it is also given).
    - only ``rows`` given: images per row is ``ceil(B / rows)``.
    - neither given: images per row is ``ceil(sqrt(B))`` (roughly square grid).

    Args:
        images: Tensor of shape [B, C, H, W]
        rows: Number of rows (optional)
        cols: Number of columns (optional); takes precedence over ``rows``
        save_path: Path to save the grid image. Missing parent directories
            are created. A bare filename (no directory part) is saved
            relative to the current working directory.
        normalize: Forwarded to ``torchvision.utils.make_grid``
        value_range: Forwarded to ``torchvision.utils.make_grid`` as the
            (min, max) range of values in the input images

    Returns:
        Grid tensor
    """
    num_images = images.shape[0]

    # `make_grid` calls its layout argument `nrow`, but this code feeds it the
    # column count (images per row), so everything is resolved to that.
    if cols is not None:
        nrow = cols
    elif rows is not None:
        nrow = int(np.ceil(num_images / rows))
    else:
        nrow = int(np.ceil(np.sqrt(num_images)))

    grid = torchvision.utils.make_grid(
        images, nrow=nrow, normalize=normalize, value_range=value_range, padding=2
    )

    if save_path:
        # os.path.dirname returns "" for a bare filename, and os.makedirs("")
        # raises FileNotFoundError, so only create a directory when one is named.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torchvision.utils.save_image(grid, save_path)

    return grid


def visualize_denoising(
    intermediate_steps: List[torch.Tensor], save_path: str, num_steps_to_show: int = 10
):
    """
    Save a single-row image strip showing how one sample evolves while denoising.

    At most ``num_steps_to_show`` steps are picked, evenly spaced from the first
    to the last recorded step. If fewer steps were recorded than requested,
    all of them are shown. Only the first sample of each batch is drawn.

    Args:
        intermediate_steps: List of tensors [B, C, H, W] from the sampling process
        save_path: Path to save the visualization
        num_steps_to_show: Number of steps to display
    """
    n_available = len(intermediate_steps)

    if n_available >= num_steps_to_show:
        # Evenly spaced integer indices that always include the first and last step.
        picks = np.linspace(
            0, n_available - 1, num_steps_to_show, dtype=int
        ).tolist()
    else:
        picks = list(range(n_available))

    # Keep only batch element 0 of every chosen step -> each frame is [C, H, W].
    frames = []
    for pick in picks:
        frames.append(intermediate_steps[pick][0])

    # [Num_steps, C, H, W], laid out as one row with one column per step.
    strip = torch.stack(frames)
    make_image_grid(
        strip,
        rows=1,
        cols=len(picks),
        save_path=save_path,
        normalize=True,
        value_range=(-1, 1),
    )


def format_prompt_caption(prompts: List[str], limit: int = 32) -> str:
    """
    Format a list of prompts for display as an image caption.

    Each included prompt is flattened to a single line (line breaks become
    spaces, surrounding whitespace is stripped) and prefixed with its 1-based
    position. Entries are separated by a blank line. If some prompts are left
    out, a final ``... (+N more)`` entry reports how many.

    Args:
        prompts: List of prompt strings
        limit: Maximum number of prompts to include. Negative values are
            treated as 0 (no prompts listed, only the "more" summary).

    Returns:
        Formatted caption string, or "" when ``prompts`` is empty
    """
    if not prompts:
        return ""

    # A negative limit would otherwise act as a slice-from-the-end
    # (prompts[:-1] drops the last item) rather than "include none".
    shown = prompts[: max(limit, 0)]

    # Handle "\r\n" first so a Windows line break becomes one space, then the
    # lone "\r" / "\n" forms; otherwise a stray "\r" survives inside the caption.
    trimmed = [
        p.replace("\r\n", " ").replace("\r", " ").replace("\n", " ").strip()
        for p in shown
    ]

    lines = [f"{idx}. {text}" for idx, text in enumerate(trimmed, start=1)]

    remaining = len(prompts) - len(trimmed)
    if remaining > 0:
        lines.append(f"... (+{remaining} more)")
    return "\n\n".join(lines)