DDIM_Image_Generation / utils /visualize.py
NoobNovel's picture
DDIM face generation — full project
0ca4c93
"""Visualization helpers: image grids, denoising-trajectory GIFs, and
latent-interpolation grids.
All functions accept tensors in the [-1, 1] range (model output convention)
unless otherwise stated, and write/return uint8 arrays in [0, 255].
"""
from __future__ import annotations
import math
import os
from typing import List, Optional, Sequence
import numpy as np
import torch
from PIL import Image
# ---------------------------------------------------------------------------
# Small primitives
# ---------------------------------------------------------------------------
def to_uint8(x: torch.Tensor) -> np.ndarray:
    """Convert a channel-first image tensor in [-1, 1] to channel-last uint8.

    Values outside [-1, 1] are clamped first, then mapped linearly onto
    [0, 255] and rounded.

    Args:
        x: Tensor shaped (C,H,W) or (B,C,H,W).

    Returns:
        A numpy uint8 array shaped (H,W,C) for 3-D input or (B,H,W,C) for
        4-D input.

    Raises:
        ValueError: If the tensor is neither 3-D nor 4-D.
    """
    # Work on a detached float32 CPU copy so the arithmetic below does not
    # depend on the caller's dtype, device or autograd state.
    values = x.detach().to(torch.float32).cpu()
    values = (values.clamp(-1.0, 1.0) + 1.0) * 127.5
    pixels = values.round().clamp(0, 255).to(torch.uint8)

    # Axis order that moves the channel dimension last, keyed by rank.
    channel_last = {3: (1, 2, 0), 4: (0, 2, 3, 1)}
    order = channel_last.get(pixels.ndim)
    if order is None:
        raise ValueError(f"unsupported shape {pixels.shape}")
    return pixels.permute(*order).numpy()
def make_grid(images: torch.Tensor, nrow: Optional[int] = None, pad: int = 2,
              pad_value: float = 1.0) -> np.ndarray:
    """Lay a batch of images out as a single grid image. Inputs in [-1, 1].

    Images are placed in row-major order, ``nrow`` images per grid row, with
    ``pad`` pixels of ``pad_value`` between tiles and around the border.
    Cells left over in the last row stay filled with ``pad_value``.

    Args:
        images: Tensor shaped (B,C,H,W).
        nrow: Number of images per grid row (i.e. the number of grid
            columns). Defaults to ``ceil(sqrt(B))``, giving a near-square grid.
        pad: Padding width in pixels; must be non-negative.
        pad_value: Fill value for the padding, in the same [-1, 1] range as
            the images.

    Returns:
        uint8 array shaped (grid_h, grid_w, C).

    Raises:
        ValueError: If ``images`` is not 4-D, the batch is empty, ``nrow`` is
            smaller than 1, or ``pad`` is negative.
    """
    if images.ndim != 4:
        raise ValueError(f"expected (B,C,H,W), got {images.shape}")
    B, C, H, W = images.shape
    # An empty batch would otherwise yield nrow == 0 and a ZeroDivisionError
    # in the row-count computation below.
    if B == 0:
        raise ValueError("cannot build a grid from an empty batch")
    if nrow is None:
        nrow = int(math.ceil(math.sqrt(B)))
    if nrow < 1:
        raise ValueError(f"nrow must be >= 1, got {nrow}")
    if pad < 0:
        raise ValueError(f"pad must be >= 0, got {pad}")

    # Detach and move to CPU once, so the per-tile copies below neither track
    # gradients nor trigger one device transfer per image.
    images = images.detach().cpu()

    n_rows = int(math.ceil(B / nrow))
    grid_h = n_rows * H + (n_rows + 1) * pad
    grid_w = nrow * W + (nrow + 1) * pad
    grid = torch.full((C, grid_h, grid_w), pad_value, dtype=images.dtype)
    for i in range(B):
        r, c = divmod(i, nrow)
        y = pad + r * (H + pad)
        x = pad + c * (W + pad)
        grid[:, y:y + H, x:x + W] = images[i]
    return to_uint8(grid)
def save_image_grid(images: torch.Tensor, path: str, nrow: Optional[int] = None) -> str:
    """Render a batch as a grid image and write it to ``path``.

    The destination directory is created if it does not exist yet.

    Args:
        images: Batch tensor forwarded to ``make_grid``.
        path: Output file path; it is passed to PIL's ``save`` without an
            explicit format argument.
        nrow: Images per grid row, forwarded to ``make_grid``.

    Returns:
        ``path``, unchanged.
    """
    grid_pixels = make_grid(images, nrow=nrow)

    # A bare filename has an empty dirname; use the current directory then.
    parent_dir = os.path.dirname(path)
    if not parent_dir:
        parent_dir = "."
    os.makedirs(parent_dir, exist_ok=True)

    picture = Image.fromarray(grid_pixels)
    picture.save(path)
    return path
# ---------------------------------------------------------------------------
# Denoising trajectory GIF
# ---------------------------------------------------------------------------
def trajectory_to_gif(
    trajectory: Sequence[torch.Tensor],
    path: str,
    fps: int = 10,
    nrow: Optional[int] = None,
) -> str:
    """Save a list of tensors (each (B,C,H,W) in [-1,1]) as an animated GIF.

    Each frame is laid out as a grid of all batch items. A 3-D (C,H,W) entry
    is treated as a batch of one. The GIF loops forever.

    Args:
        trajectory: Frames in playback order.
        path: Output file path; parent directories are created as needed.
        fps: Frames per second. Values below 1 are treated as 1.
        nrow: Images per grid row, forwarded to ``make_grid``.

    Returns:
        ``path``, unchanged.

    Raises:
        ValueError: If ``trajectory`` contains no frames.
    """
    frames = []
    for step in trajectory:
        if step.ndim == 3:
            step = step.unsqueeze(0)
        frames.append(Image.fromarray(make_grid(step, nrow=nrow)))
    if not frames:
        raise ValueError("trajectory is empty; nothing to write")

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    # Write with PIL directly: its GIF writer takes the per-frame duration in
    # milliseconds. imageio's GIF backends appear to disagree on the unit
    # (seconds in the legacy plugin, milliseconds in the newer Pillow-based
    # one), which makes playback speed depend on the installed version.
    frame_ms = int(round(1000.0 / max(fps, 1)))
    first, rest = frames[0], frames[1:]
    first.save(
        path,
        format="GIF",
        save_all=True,
        append_images=rest,
        duration=frame_ms,
        loop=0,  # 0 = repeat indefinitely
    )
    return path
# ---------------------------------------------------------------------------
# Latent interpolation
# ---------------------------------------------------------------------------
def slerp(z1: torch.Tensor, z2: torch.Tensor, t: float) -> torch.Tensor:
    """Spherically interpolate between two latents of identical shape.

    The angle between the latents is measured on their flattened forms; the
    blend weights are then applied to the original (unflattened) tensors.
    When the sine of that angle is below 1e-6 the slerp weights would divide
    by (almost) zero, so plain linear weights are used instead.

    Args:
        z1: Latent returned at ``t = 0``.
        z2: Latent returned at ``t = 1``.
        t: Interpolation position.

    Returns:
        The interpolated latent, shaped like the inputs.
    """
    u = z1.flatten()
    v = z2.flatten()

    # Cosine of the angle; the epsilon guards against zero-norm inputs and the
    # clamp keeps rounding error from pushing acos out of its domain.
    cos_angle = (u * v).sum() / (u.norm() * v.norm() + 1e-12)
    cos_angle = cos_angle.clamp(-1.0, 1.0)
    angle = torch.acos(cos_angle)
    sin_angle = torch.sin(angle)

    if sin_angle.abs() < 1e-6:
        # Degenerate angle: fall back to linear weights.
        w1, w2 = 1 - t, t
    else:
        w1 = torch.sin((1 - t) * angle) / sin_angle
        w2 = torch.sin(t * angle) / sin_angle
    return w1 * z1 + w2 * z2
def interpolate_latents(z1: torch.Tensor, z2: torch.Tensor, num_steps: int = 8,
                        method: str = "slerp") -> torch.Tensor:
    """Return a tensor of shape (num_steps, *z1.shape) of interpolated latents.

    Interpolation positions are evenly spaced over [0, 1], endpoints
    included, so the first entry corresponds to ``z1`` and (for
    ``num_steps >= 2``) the last to ``z2``.

    Args:
        z1: Start latent.
        z2: End latent, same shape as ``z1``.
        num_steps: Number of interpolation points; must be at least 1.
        method: ``"slerp"`` for spherical or ``"lerp"`` for linear
            interpolation.

    Returns:
        The interpolated latents stacked along a new leading dimension.

    Raises:
        ValueError: If ``method`` is not recognized or ``num_steps < 1``.
    """
    # Validate up front: previously an unknown method went unnoticed when the
    # loop body never ran, and num_steps == 0 failed inside torch.stack with
    # an unrelated error.
    if method not in ("slerp", "lerp"):
        raise ValueError(method)
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    # Plain Python floats, so both methods see the same kind of scalar.
    positions = torch.linspace(0.0, 1.0, num_steps).tolist()
    if method == "slerp":
        frames = [slerp(z1, z2, t) for t in positions]
    else:
        frames = [(1 - t) * z1 + t * z2 for t in positions]
    return torch.stack(frames, dim=0)
# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    # Fixed seed so the random fixtures are reproducible between runs.
    torch.manual_seed(0)

    # Grid layout: the result must be a uint8 image with a trailing RGB axis.
    batch = torch.randn(8, 3, 32, 32).clamp(-1, 1)
    tiled = make_grid(batch, nrow=4)
    assert tiled.dtype == np.uint8
    assert tiled.ndim == 3
    assert tiled.shape[2] == 3

    # File outputs go to a throwaway directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        png_path = save_image_grid(batch, os.path.join(tmp_dir, "g.png"))
        assert os.path.exists(png_path)

        steps = [torch.randn(4, 3, 16, 16).clamp(-1, 1) for _ in range(6)]
        gif_path = trajectory_to_gif(steps, os.path.join(tmp_dir, "t.gif"), fps=8, nrow=2)
        assert os.path.exists(gif_path)
        assert os.path.getsize(gif_path) > 0

    # Latent interpolation: shape check, then both endpoints must be recovered.
    start = torch.randn(1, 3, 16, 16)
    end = torch.randn(1, 3, 16, 16)
    path_latents = interpolate_latents(start, end, num_steps=5, method="slerp")
    assert path_latents.shape == (5, 1, 3, 16, 16)
    assert torch.allclose(path_latents[0], start, atol=1e-5)
    assert torch.allclose(path_latents[-1], end, atol=1e-5)

    print("visualize.py: all tests passed")