File size: 5,496 Bytes
0ca4c93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Visualization helpers: image grids, denoising-trajectory GIFs, and
latent-interpolation grids.

All functions accept tensors in the [-1, 1] range (model output convention)
unless otherwise stated, and write/return uint8 arrays in [0, 255].
"""
from __future__ import annotations

import math
import os
from typing import List, Optional, Sequence

import numpy as np
import torch
from PIL import Image


# ---------------------------------------------------------------------------
# Small primitives
# ---------------------------------------------------------------------------
def to_uint8(x: torch.Tensor) -> np.ndarray:
    """Convert a [-1, 1] image tensor into a channels-last uint8 numpy array.

    Shapes:
        (C, H, W)    -> (H, W, C)
        (B, C, H, W) -> (B, H, W, C)

    Values are clamped to [-1, 1], mapped linearly onto [0, 255], rounded and
    cast to uint8. The tensor is detached and moved to the CPU first.

    Raises:
        ValueError: if ``x`` is neither 3-D nor 4-D.
    """
    t = x.detach().to(torch.float32).cpu()
    scaled = (t.clamp(-1.0, 1.0) + 1.0) * 127.5
    pixels = scaled.round().clamp(0, 255).to(torch.uint8)
    # Move the channel axis last while keeping any leading batch axis first.
    axis_order = {4: (0, 2, 3, 1), 3: (1, 2, 0)}.get(pixels.ndim)
    if axis_order is None:
        raise ValueError(f"unsupported shape {t.shape}")
    return pixels.permute(*axis_order).numpy()


def make_grid(images: torch.Tensor, nrow: Optional[int] = None, pad: int = 2,
              pad_value: float = 1.0) -> np.ndarray:
    """Tile a batch of images into a single padded grid.

    Args:
        images: (B, C, H, W) tensor with values in [-1, 1].
        nrow: number of images placed in each grid row. Defaults to
            ceil(sqrt(B)), giving a roughly square layout.
        pad: border width in pixels, applied around and between tiles.
        pad_value: fill value for the border, in the same [-1, 1] space as
            ``images`` (1.0 maps to 255 after conversion).

    Returns:
        uint8 array of shape (H_grid, W_grid, C), produced by ``to_uint8``.

    Raises:
        ValueError: if ``images`` is not 4-D or is empty, or if ``nrow`` < 1
            or ``pad`` < 0.
    """
    if images.ndim != 4:
        raise ValueError(f"expected (B,C,H,W), got {images.shape}")
    B, C, H, W = images.shape
    if B == 0:
        raise ValueError("cannot make a grid from an empty batch")
    if nrow is None:
        nrow = int(math.ceil(math.sqrt(B)))
    if nrow < 1:
        raise ValueError(f"nrow must be >= 1, got {nrow}")
    if pad < 0:
        raise ValueError(f"pad must be >= 0, got {pad}")

    # `nrow` counts images per row (i.e. grid columns); derive the row count.
    n_cols = nrow
    n_rows = int(math.ceil(B / n_cols))

    grid_h = n_rows * H + (n_rows + 1) * pad
    grid_w = n_cols * W + (n_cols + 1) * pad

    # Detach so copying tiles does not build an autograd graph, and allocate
    # on the input's device so tile copies stay device-local; to_uint8 does a
    # single host transfer at the end.
    images = images.detach()
    grid = torch.full((C, grid_h, grid_w), pad_value,
                      dtype=images.dtype, device=images.device)
    for i in range(B):
        r, c = divmod(i, n_cols)
        y = pad + r * (H + pad)
        x = pad + c * (W + pad)
        grid[:, y:y + H, x:x + W] = images[i]

    return to_uint8(grid)


def save_image_grid(images: torch.Tensor, path: str, nrow: Optional[int] = None) -> str:
    """Render ``images`` (B,C,H,W in [-1, 1]) as a grid and write it to ``path``.

    The grid comes from ``make_grid``; missing parent directories are created
    before saving via PIL. Returns ``path`` unchanged.
    """
    target_dir = os.path.dirname(path) or "."
    pixels = make_grid(images, nrow=nrow)
    os.makedirs(target_dir, exist_ok=True)
    Image.fromarray(pixels).save(path)
    return path


# ---------------------------------------------------------------------------
# Denoising trajectory GIF
# ---------------------------------------------------------------------------
def trajectory_to_gif(
    trajectory: Sequence[torch.Tensor],
    path: str,
    fps: int = 10,
    nrow: Optional[int] = None,
) -> str:
    """Save a sequence of image batches as an animated, looping GIF.

    Args:
        trajectory: frames in playback order. Each is a (B,C,H,W) batch or a
            single (C,H,W) image, with values in [-1, 1]. Every frame is laid
            out as a grid of its batch items via ``make_grid``.
        path: output file path; missing parent directories are created.
        fps: frames per second; values below 1 are treated as 1.
        nrow: images per grid row, forwarded to ``make_grid``.

    Returns:
        ``path``.

    Raises:
        ValueError: if ``trajectory`` contains no frames.
    """
    frames = [make_grid(x.unsqueeze(0) if x.ndim == 3 else x, nrow=nrow)
              for x in trajectory]
    if not frames:
        # Fail early with a clear message instead of handing imageio zero frames.
        raise ValueError("trajectory must contain at least one frame")

    import imageio.v2 as imageio       # local import; heavy dep

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    # NOTE(review): duration is passed in seconds per frame, which the legacy
    # imageio v2 GIF writer expects; newer pillow-backed writers interpret it
    # as milliseconds — confirm against the installed imageio version.
    duration = 1.0 / max(fps, 1)
    imageio.mimsave(path, frames, format="GIF", duration=duration, loop=0)
    return path


# ---------------------------------------------------------------------------
# Latent interpolation
# ---------------------------------------------------------------------------
def slerp(z1: torch.Tensor, z2: torch.Tensor, t: float) -> torch.Tensor:
    """Spherically interpolate from ``z1`` (t=0) to ``z2`` (t=1).

    Both tensors are treated as single flattened vectors. The angle between
    them comes from their normalized dot product, clamped to [-1, 1]. When the
    sine of that angle is below 1e-6 (vectors parallel or anti-parallel), the
    function falls back to linear interpolation to avoid dividing by ~0.
    """
    v1 = z1.flatten(start_dim=0)
    v2 = z2.flatten(start_dim=0)
    denom = v1.norm() * v2.norm() + 1e-12
    cos_omega = torch.clamp((v1 * v2).sum() / denom, -1.0, 1.0)
    omega = torch.acos(cos_omega)
    sin_omega = torch.sin(omega)
    if sin_omega.abs() < 1e-6:
        # Degenerate angle: the slerp weights would be ~0/0, so use lerp.
        return (1 - t) * z1 + t * z2
    w1, w2 = (torch.sin(s * omega) / sin_omega for s in (1 - t, t))
    return w1 * z1 + w2 * z2


def interpolate_latents(z1: torch.Tensor, z2: torch.Tensor, num_steps: int = 8,
                        method: str = "slerp") -> torch.Tensor:
    """Interpolate between two latents at evenly spaced points in [0, 1].

    Args:
        z1: latent at t=0.
        z2: latent at t=1; must have the same shape as ``z1``.
        num_steps: number of samples, including both endpoints (a single step
            yields only ``z1``).
        method: ``"slerp"`` (spherical, via ``slerp``) or ``"lerp"`` (linear).

    Returns:
        Tensor of shape (num_steps, *z1.shape).

    Raises:
        ValueError: if ``method`` is unknown or ``num_steps`` < 1.
    """
    # Validate up front so bad arguments fail even when no step would run.
    if method not in ("slerp", "lerp"):
        raise ValueError(
            f"unknown interpolation method {method!r}; expected 'slerp' or 'lerp'")
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    ts = torch.linspace(0.0, 1.0, num_steps)
    if method == "slerp":
        out = [slerp(z1, z2, t.item()) for t in ts]
    else:
        out = [(1 - t) * z1 + t * z2 for t in ts]
    return torch.stack(out, dim=0)


# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    torch.manual_seed(0)
    imgs = torch.randn(8, 3, 32, 32).clamp(-1, 1)

    grid = make_grid(imgs, nrow=4)
    assert grid.dtype == np.uint8 and grid.ndim == 3 and grid.shape[2] == 3
    # 8 images at 4 per row -> 2 rows; default pad=2 around and between tiles.
    assert grid.shape == (2 * 32 + 3 * 2, 4 * 32 + 5 * 2, 3)
    # The top border row is pure padding (pad_value=1.0 -> 255).
    assert (grid[0] == 255).all()

    with tempfile.TemporaryDirectory() as td:
        p1 = save_image_grid(imgs, os.path.join(td, "g.png"))
        assert os.path.exists(p1)

        traj = [torch.randn(4, 3, 16, 16).clamp(-1, 1) for _ in range(6)]
        p2 = trajectory_to_gif(traj, os.path.join(td, "t.gif"), fps=8, nrow=2)
        assert os.path.exists(p2) and os.path.getsize(p2) > 0

    z1 = torch.randn(1, 3, 16, 16)
    z2 = torch.randn(1, 3, 16, 16)
    interps = interpolate_latents(z1, z2, num_steps=5, method="slerp")
    assert interps.shape == (5, 1, 3, 16, 16)
    # endpoints recovered
    assert torch.allclose(interps[0], z1, atol=1e-5)
    assert torch.allclose(interps[-1], z2, atol=1e-5)

    # lerp path: the midpoint of 3 evenly spaced steps is the plain average.
    lerps = interpolate_latents(z1, z2, num_steps=3, method="lerp")
    assert lerps.shape == (3, 1, 3, 16, 16)
    assert torch.allclose(lerps[0], z1, atol=1e-6)
    assert torch.allclose(lerps[1], 0.5 * (z1 + z2), atol=1e-6)
    assert torch.allclose(lerps[-1], z2, atol=1e-6)

    print("visualize.py: all tests passed")