"""Visualization helpers: image grids, denoising-trajectory GIFs, and
latent-interpolation grids.

All functions accept tensors in the [-1, 1] range (model output convention)
unless otherwise stated, and write/return uint8 arrays in [0, 255].
GIFs are written with Pillow so frame timing is always in milliseconds.
"""
from __future__ import annotations

import math
import os
from typing import List, Optional, Sequence

import numpy as np
import torch
from PIL import Image


# ---------------------------------------------------------------------------
# Small primitives
# ---------------------------------------------------------------------------

def to_uint8(x: torch.Tensor) -> np.ndarray:
    """Convert a [-1, 1] tensor to a channels-last uint8 numpy array.

    Args:
        x: tensor of shape (B, C, H, W) or (C, H, W) with values in [-1, 1].
           Values outside the range are clamped; NaNs are mapped to 0
           (mid-gray) so the uint8 cast is deterministic.

    Returns:
        uint8 array of shape (B, H, W, C) or (H, W, C).

    Raises:
        ValueError: if ``x`` is not 3- or 4-dimensional.
    """
    x = x.detach().to(torch.float32).cpu()
    # Casting NaN to uint8 is platform-dependent; pin it to the range midpoint.
    x = torch.nan_to_num(x, nan=0.0)
    x = (x.clamp(-1.0, 1.0) + 1.0) * 127.5
    x = x.round().clamp(0, 255).to(torch.uint8)
    if x.ndim == 4:
        return x.permute(0, 2, 3, 1).numpy()  # (B,H,W,C)
    if x.ndim == 3:
        return x.permute(1, 2, 0).numpy()  # (H,W,C)
    raise ValueError(f"unsupported shape {x.shape}")


def make_grid(
    images: torch.Tensor,
    nrow: Optional[int] = None,
    pad: int = 2,
    pad_value: float = 1.0,
) -> np.ndarray:
    """Lay a batch of images out as a grid.

    Args:
        images: (B, C, H, W) tensor in [-1, 1]. Single-channel (C == 1)
            input is replicated to 3 channels so the result is RGB.
        nrow: number of images per grid row; defaults to ceil(sqrt(B)).
        pad: padding in pixels between tiles and around the border.
        pad_value: padding colour in the [-1, 1] input range (1.0 -> white).

    Returns:
        uint8 array of shape (grid_H, grid_W, 3) for 1- or 3-channel input
        (other channel counts are passed through unchanged).

    Raises:
        ValueError: if ``images`` is not 4-D, the batch is empty, or
            ``nrow`` < 1.
    """
    if images.ndim != 4:
        raise ValueError(f"expected (B,C,H,W), got {images.shape}")
    n_images, channels, height, width = images.shape
    if n_images == 0:
        raise ValueError("cannot build a grid from an empty batch")
    if nrow is None:
        nrow = int(math.ceil(math.sqrt(n_images)))
    if nrow < 1:
        raise ValueError(f"nrow must be >= 1, got {nrow}")

    # One device transfer for the whole batch instead of one per tile, and no
    # autograd tracking through the in-place tile copies below.
    images = images.detach().cpu()
    if channels == 1:
        images = images.expand(-1, 3, -1, -1)
        channels = 3

    n_tile_rows = int(math.ceil(n_images / nrow))  # `nrow` = tiles per row
    grid_h = n_tile_rows * height + (n_tile_rows + 1) * pad
    grid_w = nrow * width + (nrow + 1) * pad
    grid = torch.full((channels, grid_h, grid_w), pad_value, dtype=images.dtype)
    for i, img in enumerate(images):
        r, c = divmod(i, nrow)
        y = pad + r * (height + pad)
        x = pad + c * (width + pad)
        grid[:, y:y + height, x:x + width] = img
    return to_uint8(grid)


def save_image_grid(images: torch.Tensor, path: str, nrow: Optional[int] = None) -> str:
    """Render ``images`` with :func:`make_grid` and save it to ``path``.

    Parent directories are created as needed; the image format is inferred
    by Pillow from the file extension. Returns ``path``.
    """
    arr = make_grid(images, nrow=nrow)
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    Image.fromarray(arr).save(path)
    return path


# ---------------------------------------------------------------------------
# Denoising trajectory GIF
# ---------------------------------------------------------------------------

def trajectory_to_gif(
    trajectory: Sequence[torch.Tensor],
    path: str,
    fps: int = 10,
    nrow: Optional[int] = None,
) -> str:
    """Save a sequence of tensors as an infinitely looping animated GIF.

    Args:
        trajectory: frames, each (B, C, H, W) or (C, H, W) in [-1, 1]. Each
            frame is laid out as a grid of all its batch items.
        path: output file; written as GIF regardless of extension.
        fps: target frames per second (values < 1 are treated as 1). The GIF
            format stores delays in 1/100 s, so the effective rate is rounded.
        nrow: grid images per row, forwarded to :func:`make_grid`.

    Returns:
        ``path``.

    Raises:
        ValueError: if ``trajectory`` is empty.
    """
    frames = []
    for x in trajectory:
        if x.ndim == 3:
            x = x.unsqueeze(0)
        frames.append(Image.fromarray(make_grid(x, nrow=nrow)))
    if not frames:
        raise ValueError("trajectory is empty; nothing to write")

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    # Pillow's GIF `duration` is per-frame display time in milliseconds.
    duration_ms = int(round(1000.0 / max(fps, 1)))
    frames[0].save(
        path,
        format="GIF",
        save_all=True,
        append_images=frames[1:],
        duration=duration_ms,
        loop=0,  # loop forever
    )
    return path


# ---------------------------------------------------------------------------
# Latent interpolation
# ---------------------------------------------------------------------------

def slerp(z1: torch.Tensor, z2: torch.Tensor, t: float) -> torch.Tensor:
    """Spherical linear interpolation between two same-shape latents.

    The angle between the flattened latents is measured in at least float32
    (for stability with half-precision inputs); the result keeps the input
    dtype. Falls back to lerp when the vectors are nearly colinear, which
    avoids dividing by ~0.

    Args:
        z1, z2: latents of identical shape.
        t: interpolation weight; 0 returns ``z1`` and 1 returns ``z2``.

    Raises:
        ValueError: if the shapes differ.
    """
    if z1.shape != z2.shape:
        raise ValueError(
            f"slerp needs same-shape latents, got {tuple(z1.shape)} and {tuple(z2.shape)}"
        )
    work_dtype = torch.promote_types(torch.promote_types(z1.dtype, z2.dtype), torch.float32)
    flat1 = z1.flatten().to(work_dtype)
    flat2 = z2.flatten().to(work_dtype)
    cos_angle = (flat1 * flat2).sum() / (flat1.norm() * flat2.norm() + 1e-12)
    omega = torch.acos(cos_angle.clamp(-1.0, 1.0))
    sin_omega = torch.sin(omega)
    if sin_omega.abs() < 1e-6:
        return (1 - t) * z1 + t * z2
    a = torch.sin((1 - t) * omega) / sin_omega
    b = torch.sin(t * omega) / sin_omega
    return a * z1 + b * z2


def interpolate_latents(z1: torch.Tensor, z2: torch.Tensor,
                        num_steps: int = 8, method: str = "slerp") -> torch.Tensor:
    """Return ``num_steps`` evenly spaced interpolations from ``z1`` to ``z2``.

    Args:
        z1, z2: endpoint latents of identical shape.
        num_steps: number of samples, including both endpoints (a single step
            yields just ``z1``).
        method: ``"slerp"`` or ``"lerp"``.

    Returns:
        Tensor of shape ``(num_steps, *z1.shape)``.

    Raises:
        ValueError: for an unknown ``method`` or ``num_steps`` < 1.
    """
    if method not in ("slerp", "lerp"):
        raise ValueError(method)
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    ts = torch.linspace(0.0, 1.0, num_steps).tolist()
    if method == "slerp":
        out = [slerp(z1, z2, t) for t in ts]
    else:
        out = [(1 - t) * z1 + t * z2 for t in ts]
    return torch.stack(out, dim=0)


# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import tempfile

    torch.manual_seed(0)
    imgs = torch.randn(8, 3, 32, 32).clamp(-1, 1)
    grid = make_grid(imgs, nrow=4)
    assert grid.dtype == np.uint8 and grid.ndim == 3 and grid.shape[2] == 3
    assert grid.shape == (2 * 32 + 3 * 2, 4 * 32 + 5 * 2, 3)
    # Grayscale input is replicated to RGB.
    assert make_grid(imgs[:, :1], nrow=4).shape == grid.shape

    with tempfile.TemporaryDirectory() as td:
        p1 = save_image_grid(imgs, os.path.join(td, "g.png"))
        assert os.path.exists(p1)

        traj = [torch.randn(4, 3, 16, 16).clamp(-1, 1) for _ in range(6)]
        p2 = trajectory_to_gif(traj, os.path.join(td, "t.gif"), fps=10, nrow=2)
        assert os.path.exists(p2) and os.path.getsize(p2) > 0
        with Image.open(p2) as gif:
            assert gif.n_frames == len(traj)
            assert gif.info["duration"] == 100  # 1000 ms / 10 fps

    z1 = torch.randn(1, 3, 16, 16)
    z2 = torch.randn(1, 3, 16, 16)
    interps = interpolate_latents(z1, z2, num_steps=5, method="slerp")
    assert interps.shape == (5, 1, 3, 16, 16)
    # endpoints recovered
    assert torch.allclose(interps[0], z1, atol=1e-5)
    assert torch.allclose(interps[-1], z2, atol=1e-5)

    print("visualize.py: all tests passed")