| """Sampling utilities — generate and save image grids during training.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| from torchvision.utils import make_grid, save_image |
|
|
| from diffusion import GaussianDiffusion |
|
|
|
|
@torch.no_grad()
def sample_and_save(
    model: nn.Module,
    diffusion: GaussianDiffusion,
    step: int,
    save_dir: str,
    device: torch.device,
    fixed_noise: torch.Tensor,
    ema_model: nn.Module | None = None,
) -> None:
    """Generate a grid of samples from fixed noise and save it as a PNG.

    Uses the same fixed noise at every call so you can watch the same
    "seeds" evolve as training progresses.

    Args:
        model: Network used for sampling when ``ema_model`` is ``None``.
        diffusion: Object providing
            ``p_sample_loop(model, shape=..., device=..., noise=...)``.
        step: Training step; embedded zero-padded in the output file name.
        save_dir: Directory for the PNG; created if it does not exist.
        device: Device forwarded to ``diffusion.p_sample_loop``.
        fixed_noise: Starting noise; its shape is the sample shape and its
            first dimension determines the grid layout.
        ema_model: If given, used for sampling instead of ``model``.

    The sampled network is put in eval mode for the duration of sampling and
    afterwards returned to whatever mode it was in on entry.
    """
    net = ema_model if ema_model is not None else model

    # Remember the caller's mode: an EMA copy is commonly kept in eval mode,
    # and blindly calling .train() afterwards would flip it.
    was_training = net.training
    net.eval()
    try:
        samples = diffusion.p_sample_loop(
            net,
            shape=fixed_noise.shape,
            device=device,
            noise=fixed_noise,
        )
    finally:
        # Restore the original mode even if sampling raised.
        net.train(was_training)

    # Rescale to [0, 1] for image saving; assumes the sampler returns values
    # in roughly [-1, 1] (anything outside is clamped).
    samples = ((samples + 1.0) * 0.5).clamp(0.0, 1.0)

    # About sqrt(batch) images per row gives a roughly square grid; never
    # let the row length drop to 0.
    per_row = max(1, int(fixed_noise.shape[0] ** 0.5))
    grid = make_grid(samples, nrow=per_row)

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"sample_{step:07d}.png")
    save_image(grid, save_path)
|
|