PneumoniaAPI / src /utils.py
GitHub Actions
Auto-deploy from GitHub: 495db78a06be79166200269bb14d9e9b1e8906d6
af59988
raw
history blame
2.06 kB
"""
Utility functions for visualization and helpers.
"""
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np
import torch

from .config import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES
def denormalize(
    tensor: torch.Tensor,
    mean: Optional[Sequence[float]] = None,
    std: Optional[Sequence[float]] = None,
) -> torch.Tensor:
    """Undo per-channel normalization, returning ``tensor * std + mean``.

    Args:
        tensor: Normalized image(s) whose channel axis is third from the end,
            e.g. (C, H, W) or (B, C, H, W); the statistics are broadcast as
            (C, 1, 1).
        mean: Per-channel means used during normalization. Defaults to
            ``IMAGENET_MEAN``.
        std: Per-channel standard deviations used during normalization.
            Defaults to ``IMAGENET_STD``.

    Returns:
        The denormalized tensor, on the same device as ``tensor``.
    """
    if mean is None:
        mean = IMAGENET_MEAN
    if std is None:
        std = IMAGENET_STD
    # Build the statistics on the input's device: CPU constants would raise a
    # device-mismatch error when ``tensor`` lives on CUDA/MPS.
    mean_t = torch.tensor(mean, device=tensor.device).view(-1, 1, 1)
    std_t = torch.tensor(std, device=tensor.device).view(-1, 1, 1)
    return tensor * std_t + mean_t
def show_batch(
    images: torch.Tensor,
    labels: torch.Tensor,
    predictions: Optional[torch.Tensor] = None,
    n_images: int = 8,
    save_path: Optional[str] = None
):
    """Plot the first ``n_images`` images of a batch in a 4-column grid.

    Each image is passed through ``denormalize``, clipped to [0, 1] and shown
    with its true class name as the title. When ``predictions`` is given, the
    predicted class name is appended and the title is coloured green for a
    correct prediction and red otherwise.

    Args:
        images: Batch of normalized images indexed along the first dimension.
            Each image is permuted with (1, 2, 0) for display, so it is
            assumed to be channels-first (C, H, W).
        labels: True class indices, one per image, used to index CLASS_NAMES.
        predictions: Optional predicted class indices aligned with ``labels``.
        n_images: Maximum number of images to plot (capped at the batch size).
        save_path: If given, the figure is also saved to this path at 150 dpi.

    Raises:
        ValueError: If there is nothing to plot (empty batch or
            ``n_images`` < 1).
    """
    n_images = min(n_images, len(images))
    if n_images < 1:
        # plt.subplots would otherwise fail with an opaque "rows" error.
        raise ValueError(
            f"show_batch needs at least one image to plot, got n_images={n_images}"
        )

    cols = 4
    rows = (n_images + cols - 1) // cols
    # squeeze=False always returns a 2-D array of Axes, so flattening works
    # the same way for every grid shape.
    fig, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows), squeeze=False)
    axes = axes.flatten()

    for idx in range(n_images):
        ax = axes[idx]
        # Detach and move to CPU so tensors that require grad or live on an
        # accelerator can be converted to NumPy.
        img = denormalize(images[idx].detach().cpu()).permute(1, 2, 0).numpy()
        ax.imshow(np.clip(img, 0, 1))
        ax.axis('off')

        label = CLASS_NAMES[int(labels[idx])]
        title = f"True: {label}"
        if predictions is not None:
            pred = CLASS_NAMES[int(predictions[idx])]
            title += f"\nPred: {pred}"
            color = 'green' if pred == label else 'red'
            ax.set_title(title, color=color, fontsize=10)
        else:
            ax.set_title(title, fontsize=10)

    # Hide the unused cells at the end of the last row.
    for ax in axes[n_images:]:
        ax.axis('off')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
def set_seed(seed: int = 42):
    """Seed every random number generator used here so runs are repeatable.

    Python's ``random`` module, NumPy's global RNG and PyTorch's default
    generator are always seeded; the CUDA and MPS generators are seeded only
    when their backend reports itself as available.
    """
    import random as py_random

    for seed_fn in (py_random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Accelerator generators are touched only when the backend is present.
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
    mps_present = torch.backends.mps.is_available()
    if mps_present:
        torch.mps.manual_seed(seed)