File size: 1,902 Bytes
fab18b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Visualization helpers – save sample grids and plot loss curves."""

from pathlib import Path
import torch
import torch.nn as nn
from torchvision.utils import save_image, make_grid
import numpy as np


def save_sample_grid(
    generator: nn.Module,
    noise: torch.Tensor,
    path: str | Path,
    nrow: int = 4,
) -> None:
    """Generate images from fixed noise and save as a grid PNG."""
    generator.eval()
    with torch.no_grad():
        fake = generator(noise).cpu()
    # De-normalise from [-1, 1] → [0, 1]
    fake = (fake + 1) / 2
    save_image(fake, str(path), nrow=nrow, padding=2)
    generator.train()
    print(f"[Viz] Sample grid saved → {path}")


def tensor_to_pil(tensor: torch.Tensor):
    """Convert a single image tensor with values in [-1, 1] to a PIL Image.

    Accepted layouts:
        * ``(C, H, W)``;
        * ``(H, W)``;
        * ``(C, H, W)`` preceded by any number of size-1 leading dims,
          e.g. ``(1, C, H, W)``.

    A single-channel image is converted as a 2-D grayscale array.

    Args:
        tensor: Image tensor in [-1, 1]. It may be on any device and may
            require grad.

    Returns:
        The image as a ``PIL.Image.Image``.

    Raises:
        ValueError: If the tensor holds more than one image or does not have
            a 2-D or 3-D image layout after dropping size-1 leading dims.
    """
    from PIL import Image

    # detach(): .numpy() refuses tensors that require grad.
    # float(): .numpy() has no bfloat16 equivalent.
    img = tensor.detach().cpu().float()
    # Drop only *leading* size-1 dims. A blanket squeeze() would also remove
    # a size-1 channel, height or width and break the permute below.
    while img.dim() > 3 and img.size(0) == 1:
        img = img[0]

    if img.dim() == 3:
        img = img.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
        if img.size(-1) == 1:
            # Pillow rejects (H, W, 1) arrays; grayscale must be (H, W).
            img = img[..., 0]
    elif img.dim() != 2:
        raise ValueError(
            "expected a single (C, H, W) or (H, W) image tensor, "
            f"got shape {tuple(tensor.shape)}"
        )

    arr = (img.numpy() + 1) / 2
    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(arr)


def plot_loss_curves(history: list[dict], save_path: str | Path | None = None):
    """Plot generator and discriminator loss curves using matplotlib."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("[Viz] matplotlib not installed – skipping loss plot.")
        return

    epochs = [h["epoch"] for h in history]
    g_loss = [h["g_loss"] for h in history]
    d_loss = [h["d_loss"] for h in history]

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(epochs, g_loss, label="Generator", linewidth=1.5)
    ax.plot(epochs, d_loss, label="Discriminator", linewidth=1.5)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("GAN Training Losses")
    ax.legend()
    ax.grid(alpha=0.3)

    if save_path:
        fig.savefig(str(save_path), dpi=120, bbox_inches="tight")
        print(f"[Viz] Loss curve saved → {save_path}")
    else:
        plt.show()
    plt.close(fig)