Spaces:
Paused
Paused
| """Visualization helpers β save sample grids and plot loss curves.""" | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| from torchvision.utils import save_image, make_grid | |
| import numpy as np | |
def save_sample_grid(
    generator: nn.Module,
    noise: torch.Tensor,
    path: str | Path,
    nrow: int = 4,
) -> None:
    """Generate images from fixed noise and save them as a grid PNG.

    The generator is put in eval mode for sampling. Afterwards it is restored
    to whatever train/eval mode it was in before the call, even if generation
    or saving raises.

    Args:
        generator: Module called as ``generator(noise)``. Its output is assumed
            to lie in [-1, 1] (implied by the de-normalisation below; confirm
            for your model).
        noise: Input batch passed straight to ``generator``.
        path: Destination file for the image grid.
        nrow: Number of images per row in the grid.
    """
    # Remember the caller's mode so we don't force train mode on a generator
    # that was deliberately in eval mode.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake = generator(noise).cpu()
            # De-normalise from [-1, 1] to [0, 1] before saving.
            fake = (fake + 1) / 2
            save_image(fake, str(path), nrow=nrow, padding=2)
    finally:
        generator.train(was_training)
    print(f"[Viz] Sample grid saved β {path}")
def tensor_to_pil(tensor: torch.Tensor):
    """Convert a single image tensor with values in [-1, 1] to a PIL Image.

    Accepted shapes:
        * ``(C, H, W)``, optionally with leading singleton batch dims such as
          ``(1, C, H, W)``. Channels-first; ``C == 1`` gives a grayscale
          image, ``C == 3`` an RGB image.
        * ``(H, W)``, a grayscale image.

    Values are mapped from [-1, 1] to [0, 255] and clipped, so out-of-range
    values saturate.

    Raises:
        ValueError: if the tensor holds more than one image or has an
            unsupported number of dimensions.
    """
    from PIL import Image

    # detach() so tensors that require grad (e.g. raw generator output) can be
    # converted to NumPy.
    img = tensor.detach().cpu()
    # Strip only leading singleton batch dims. A blanket squeeze() would also
    # drop a single channel or a height/width of 1.
    while img.dim() > 3 and img.size(0) == 1:
        img = img[0]
    if img.dim() == 3:
        # Single-channel input becomes (H, W): Image.fromarray rejects
        # (H, W, 1) uint8 arrays.
        img = img[0] if img.size(0) == 1 else img.permute(1, 2, 0)
    elif img.dim() != 2:
        raise ValueError(
            "expected a single image of shape (C, H, W) or (H, W), "
            f"got {tuple(tensor.shape)}"
        )
    arr = (img.numpy() + 1) / 2
    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(arr)
| def plot_loss_curves(history: list[dict], save_path: str | Path | None = None): | |
| """Plot generator and discriminator loss curves using matplotlib.""" | |
| try: | |
| import matplotlib.pyplot as plt | |
| except ImportError: | |
| print("[Viz] matplotlib not installed β skipping loss plot.") | |
| return | |
| epochs = [h["epoch"] for h in history] | |
| g_loss = [h["g_loss"] for h in history] | |
| d_loss = [h["d_loss"] for h in history] | |
| fig, ax = plt.subplots(figsize=(8, 4)) | |
| ax.plot(epochs, g_loss, label="Generator", linewidth=1.5) | |
| ax.plot(epochs, d_loss, label="Discriminator", linewidth=1.5) | |
| ax.set_xlabel("Epoch") | |
| ax.set_ylabel("Loss") | |
| ax.set_title("GAN Training Losses") | |
| ax.legend() | |
| ax.grid(alpha=0.3) | |
| if save_path: | |
| fig.savefig(str(save_path), dpi=120, bbox_inches="tight") | |
| print(f"[Viz] Loss curve saved β {save_path}") | |
| else: | |
| plt.show() | |
| plt.close(fig) | |