# CarGANDemo / src/utils/visualization.py
# Uploaded by Parsa2025AI using huggingface_hub (commit fab18b7, verified).
"""Visualization helpers – save sample grids and plot loss curves."""
from pathlib import Path
import torch
import torch.nn as nn
from torchvision.utils import save_image, make_grid
import numpy as np
def save_sample_grid(
generator: nn.Module,
noise: torch.Tensor,
path: str | Path,
nrow: int = 4,
) -> None:
"""Generate images from fixed noise and save as a grid PNG."""
generator.eval()
with torch.no_grad():
fake = generator(noise).cpu()
# De-normalise from [-1, 1] β†’ [0, 1]
fake = (fake + 1) / 2
save_image(fake, str(path), nrow=nrow, padding=2)
generator.train()
print(f"[Viz] Sample grid saved β†’ {path}")
def tensor_to_pil(tensor: torch.Tensor):
    """Convert a single image tensor with values in [-1, 1] to a PIL Image.

    Accepted shapes are ``(C, H, W)``, the same with leading size-1 batch
    dimensions (e.g. ``(1, C, H, W)``), or a 2-D ``(H, W)`` grayscale
    tensor. A single-channel image becomes a 2-D array, so PIL builds a
    grayscale image. Values are mapped from [-1, 1] to [0, 255], rounded to
    the nearest integer (as torchvision's ``save_image`` does) and clipped.

    Raises:
        ValueError: If the tensor does not describe exactly one image.
    """
    # detach(): allow tensors that require grad; float(): allow half and
    # bfloat16 tensors, which .numpy() cannot convert directly.
    t = tensor.detach().cpu().float()
    # Drop only leading singleton (batch) dims. A blanket squeeze() would
    # also collapse a 1-channel or 1-pixel-high image and break the permute.
    while t.dim() > 3 and t.shape[0] == 1:
        t = t[0]
    if t.dim() == 3:
        t = t.permute(1, 2, 0)  # (C, H, W) -> (H, W, C), the layout PIL expects
        if t.shape[2] == 1:
            t = t[..., 0]  # single channel -> 2-D grayscale array
    elif t.dim() != 2:
        raise ValueError(
            "tensor_to_pil expects a single image of shape (C, H, W), "
            f"(1, C, H, W) or (H, W); got {tuple(tensor.shape)}"
        )
    img = (t.numpy() + 1) / 2
    # +0.5 before the (truncating) uint8 cast rounds to nearest.
    img = np.clip(img * 255 + 0.5, 0, 255).astype(np.uint8)

    from PIL import Image

    return Image.fromarray(img)
def plot_loss_curves(history: list[dict], save_path: str | Path | None = None):
"""Plot generator and discriminator loss curves using matplotlib."""
try:
import matplotlib.pyplot as plt
except ImportError:
print("[Viz] matplotlib not installed – skipping loss plot.")
return
epochs = [h["epoch"] for h in history]
g_loss = [h["g_loss"] for h in history]
d_loss = [h["d_loss"] for h in history]
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(epochs, g_loss, label="Generator", linewidth=1.5)
ax.plot(epochs, d_loss, label="Discriminator", linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("GAN Training Losses")
ax.legend()
ax.grid(alpha=0.3)
if save_path:
fig.savefig(str(save_path), dpi=120, bbox_inches="tight")
print(f"[Viz] Loss curve saved β†’ {save_path}")
else:
plt.show()
plt.close(fig)