# Scraped page header: Spaces (sleeping); file size: 5,496 bytes; revision 0ca4c93.
"""Visualization helpers: image grids, denoising-trajectory GIFs, and
latent-interpolation grids.
All functions accept tensors in the [-1, 1] range (model output convention)
unless otherwise stated, and write/return uint8 arrays in [0, 255].
"""
from __future__ import annotations
import math
import os
from typing import List, Optional, Sequence
import numpy as np
import torch
from PIL import Image
# ---------------------------------------------------------------------------
# Small primitives
# ---------------------------------------------------------------------------
def to_uint8(x: torch.Tensor) -> np.ndarray:
    """Convert a channel-first image tensor in [-1, 1] to a channel-last uint8 array.

    A (C,H,W) input yields (H,W,C) and a (B,C,H,W) input yields (B,H,W,C).
    Values outside [-1, 1] are clamped, then mapped affinely onto [0, 255]
    and rounded.

    Raises:
        ValueError: if ``x`` is neither 3- nor 4-dimensional.
    """
    # Work on a detached float32 CPU tensor so autograd state, device and
    # reduced-precision dtypes cannot influence the conversion.
    pixels = x.detach().to(torch.float32).cpu()
    pixels = (torch.clamp(pixels, -1.0, 1.0) + 1.0) * 127.5
    pixels = torch.clamp(torch.round(pixels), 0, 255).to(torch.uint8)

    # Channel-first -> channel-last axis order, keyed by tensor rank.
    channel_last_axes = {3: (1, 2, 0), 4: (0, 2, 3, 1)}
    axes = channel_last_axes.get(pixels.ndim)
    if axes is None:
        raise ValueError(f"unsupported shape {pixels.shape}")
    return pixels.permute(*axes).numpy()
def make_grid(images: torch.Tensor, nrow: Optional[int] = None, pad: int = 2,
              pad_value: float = 1.0) -> np.ndarray:
    """Lay a batch of images out as a single padded grid image.

    Args:
        images: batch of shape (B,C,H,W) with values in [-1, 1].
        nrow: number of images placed in each grid row (i.e. the number of
            columns). Defaults to ``ceil(sqrt(B))`` for a roughly square grid.
        pad: width in pixels of the border around every image.
        pad_value: fill value for the border and for unused cells, on the
            same [-1, 1] scale as ``images`` (1.0 maps to 255).

    Returns:
        uint8 array of shape (grid_h, grid_w, C), i.e. (H, W, 3) for
        three-channel input.

    Raises:
        ValueError: if ``images`` is not 4-D, the batch is empty, ``nrow`` is
            smaller than 1, or ``pad`` is negative.
    """
    if images.ndim != 4:
        raise ValueError(f"expected (B,C,H,W), got {images.shape}")
    B, C, H, W = images.shape
    if B == 0:
        # Without this guard the default nrow would be 0 and the row count
        # below would fail with an unhelpful ZeroDivisionError.
        raise ValueError("cannot build a grid from an empty batch")
    if nrow is None:
        nrow = int(math.ceil(math.sqrt(B)))
    if nrow < 1:
        raise ValueError(f"nrow must be >= 1, got {nrow}")
    if pad < 0:
        raise ValueError(f"pad must be >= 0, got {pad}")

    # Naming follows the ``nrow`` parameter: ``nrow`` images sit side by side
    # in one row, so ``ncol`` is the number of rows stacked vertically.
    ncol = int(math.ceil(B / nrow))
    grid_h = ncol * H + (ncol + 1) * pad
    grid_w = nrow * W + (nrow + 1) * pad

    # Detach so tiling never records autograd history, and build the canvas
    # on the input's device so each tile copy stays on one device.
    images = images.detach()
    grid = torch.full((C, grid_h, grid_w), pad_value,
                      dtype=images.dtype, device=images.device)
    for i, image in enumerate(images):
        r, c = divmod(i, nrow)
        y = pad + r * (H + pad)
        x = pad + c * (W + pad)
        grid[:, y:y + H, x:x + W] = image
    return to_uint8(grid)
def save_image_grid(images: torch.Tensor, path: str, nrow: Optional[int] = None) -> str:
    """Render ``images`` as a grid and write it to ``path``.

    The grid is built with ``make_grid`` (inputs in [-1, 1]); the parent
    directory of ``path`` is created if it does not exist. Returns ``path``.
    """
    grid_pixels = make_grid(images, nrow=nrow)

    # A bare filename has no directory component; fall back to the CWD.
    parent_dir = os.path.dirname(path)
    if not parent_dir:
        parent_dir = "."
    os.makedirs(parent_dir, exist_ok=True)

    grid_image = Image.fromarray(grid_pixels)
    grid_image.save(path)
    return path
# ---------------------------------------------------------------------------
# Denoising trajectory GIF
# ---------------------------------------------------------------------------
def trajectory_to_gif(
    trajectory: Sequence[torch.Tensor],
    path: str,
    fps: int = 10,
    nrow: Optional[int] = None,
) -> str:
    """Save a sequence of image batches as a looping animated GIF.

    Each entry of ``trajectory`` becomes one frame, laid out as a grid of
    all batch items.

    Args:
        trajectory: tensors of shape (B,C,H,W) in [-1, 1]; a single (C,H,W)
            image is also accepted and treated as a batch of one.
        path: output file path; its parent directory is created if needed.
        fps: frames per second. Values below 1 are treated as 1.
        nrow: images per grid row, forwarded to ``make_grid``.

    Returns:
        ``path``.

    Raises:
        ValueError: if ``trajectory`` contains no frames.
    """
    frames = [
        make_grid(step.unsqueeze(0) if step.ndim == 3 else step, nrow=nrow)
        for step in trajectory
    ]
    if not frames:
        # Fail early with a clear message instead of an error from deep
        # inside the GIF writer (and before importing the heavy dependency).
        raise ValueError("trajectory must contain at least one frame")

    import imageio  # local import; heavy dep
    import imageio.v2 as imageio_v2

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    seconds_per_frame = 1.0 / max(fps, 1)
    # imageio >= 2.28 writes GIFs through its Pillow plugin, whose
    # ``duration`` is in milliseconds; earlier releases expected seconds.
    try:
        major, minor = (
            int(part)
            for part in getattr(imageio, "__version__", "").split(".")[:2]
        )
        duration_in_ms = (major, minor) >= (2, 28)
    except ValueError:
        # Unparseable version string: assume a modern release.
        duration_in_ms = True
    duration = seconds_per_frame * 1000.0 if duration_in_ms else seconds_per_frame

    imageio_v2.mimsave(path, frames, format="GIF", duration=duration, loop=0)
    return path
# ---------------------------------------------------------------------------
# Latent interpolation
# ---------------------------------------------------------------------------
def slerp(z1: torch.Tensor, z2: torch.Tensor, t: float) -> torch.Tensor:
    """Spherically interpolate between two latents of identical shape.

    Both tensors are treated as single flat vectors when measuring the
    angle between them. ``t=0`` yields ``z1`` and ``t=1`` yields ``z2``.
    When the sine of that angle is (nearly) zero, plain linear interpolation
    is used instead so the weights never divide by ~0.
    """
    u = z1.flatten(start_dim=0)
    v = z2.flatten(start_dim=0)

    # Cosine of the angle between the vectors; the epsilon guards against
    # zero norms and the clamp keeps acos inside its domain.
    cos_angle = torch.sum(u * v) / (u.norm() * v.norm() + 1e-12)
    cos_angle = torch.clamp(cos_angle, -1.0, 1.0)
    angle = torch.acos(cos_angle)
    sin_angle = torch.sin(angle)

    if sin_angle.abs() < 1e-6:
        return (1 - t) * z1 + t * z2

    weight_start = torch.sin((1 - t) * angle) / sin_angle
    weight_end = torch.sin(t * angle) / sin_angle
    return weight_start * z1 + weight_end * z2
def interpolate_latents(z1: torch.Tensor, z2: torch.Tensor, num_steps: int = 8,
                        method: str = "slerp") -> torch.Tensor:
    """Interpolate between two latents at evenly spaced positions in [0, 1].

    Args:
        z1: start latent (position ``t=0``).
        z2: end latent (position ``t=1``), same shape as ``z1``.
        num_steps: number of interpolation points. With one step only the
            ``t=0`` point is produced; with zero steps the result is empty.
        method: ``"slerp"`` for spherical or ``"lerp"`` for linear
            interpolation.

    Returns:
        Tensor of shape (num_steps, *z1.shape).

    Raises:
        ValueError: if ``method`` is unknown or ``num_steps`` is negative.
    """
    # Validate up front so a bad method is reported even when the loop body
    # would never run (num_steps == 0).
    if method not in ("slerp", "lerp"):
        raise ValueError(method)
    if num_steps < 0:
        raise ValueError(f"num_steps must be >= 0, got {num_steps}")
    if num_steps == 0:
        # torch.stack rejects an empty list, so build the documented
        # (0, *z1.shape) result directly.
        return torch.empty((0, *z1.shape), dtype=torch.result_type(z1, z2),
                           device=z1.device)

    # Plain Python floats so both methods receive the same kind of ``t``.
    ts = torch.linspace(0.0, 1.0, num_steps).tolist()
    if method == "slerp":
        out = [slerp(z1, z2, t) for t in ts]
    else:
        out = [(1 - t) * z1 + t * z2 for t in ts]
    return torch.stack(out, dim=0)
# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    torch.manual_seed(0)

    # Grid layout: 8 random images tiled 4 per row must give a uint8 HxWx3 array.
    batch = torch.randn(8, 3, 32, 32).clamp(-1, 1)
    tiled = make_grid(batch, nrow=4)
    assert tiled.dtype == np.uint8
    assert tiled.ndim == 3
    assert tiled.shape[2] == 3

    # File outputs go to a throwaway directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as workdir:
        png_path = save_image_grid(batch, os.path.join(workdir, "g.png"))
        assert os.path.exists(png_path)

        steps = [torch.randn(4, 3, 16, 16).clamp(-1, 1) for _ in range(6)]
        gif_path = trajectory_to_gif(
            steps, os.path.join(workdir, "t.gif"), fps=8, nrow=2
        )
        assert os.path.exists(gif_path)
        assert os.path.getsize(gif_path) > 0

    # Latent interpolation: shape is (steps, *latent_shape) and both
    # endpoints are reproduced.
    start = torch.randn(1, 3, 16, 16)
    end = torch.randn(1, 3, 16, 16)
    between = interpolate_latents(start, end, num_steps=5, method="slerp")
    assert between.shape == (5, 1, 3, 16, 16)
    assert torch.allclose(between[0], start, atol=1e-5)
    assert torch.allclose(between[-1], end, atol=1e-5)

    print("visualize.py: all tests passed")
|