PneumoniaAPI / src /utils.py
GitHub Actions
Auto-deploy from GitHub: 495db78a06be79166200269bb14d9e9b1e8906d6
af59988
"""
Utility functions for visualization and helpers.
"""
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np
import torch

from .config import CLASS_NAMES, IMAGENET_MEAN, IMAGENET_STD
def denormalize(
    tensor: torch.Tensor,
    mean: Optional[Sequence[float]] = None,
    std: Optional[Sequence[float]] = None,
) -> torch.Tensor:
    """Undo per-channel ``(x - mean) / std`` normalization.

    Args:
        tensor: Normalized image tensor whose channel axis is third from the
            end, e.g. (C, H, W) or (B, C, H, W). The statistics are reshaped
            to (C, 1, 1) and broadcast against it.
        mean: Per-channel means; defaults to ``IMAGENET_MEAN``.
        std: Per-channel standard deviations; defaults to ``IMAGENET_STD``.

    Returns:
        ``tensor * std + mean`` on the same device as ``tensor``.
    """
    if mean is None:
        mean = IMAGENET_MEAN
    if std is None:
        std = IMAGENET_STD
    # Build the statistics on the input's device so CUDA/MPS inputs do not
    # hit a device mismatch. Float inputs also keep their own precision
    # (e.g. float16 is not promoted to float32). Integer inputs fall back to
    # the default float dtype, so the result is still floating point.
    dtype = tensor.dtype if tensor.is_floating_point() else None
    mean_t = torch.as_tensor(mean, dtype=dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=tensor.device).view(-1, 1, 1)
    return tensor * std_t + mean_t
def show_batch(
    images: torch.Tensor,
    labels: torch.Tensor,
    predictions: Optional[torch.Tensor] = None,
    n_images: int = 8,
    save_path: Optional[str] = None
):
    """Plot up to ``n_images`` images from a batch in a 4-column grid.

    Each subplot is titled with its true class name. When ``predictions``
    is given, the predicted class name is added to the title. The title is
    coloured green when the prediction matches the true label and red
    otherwise.

    Args:
        images: Batch of normalized image tensors. Each ``images[i]`` is
            denormalized and permuted from (C, H, W) to (H, W, C) for
            ``imshow``. Tensors may live on any device and may require grad.
        labels: Class indices into ``CLASS_NAMES``, one per image (ints or
            0-dim tensors).
        predictions: Optional predicted class indices, aligned with ``labels``.
        n_images: Maximum number of images to draw (capped at ``len(images)``).
        save_path: If given, the figure is also saved there at 150 dpi.
    """
    n_images = min(n_images, len(images))
    if n_images <= 0:
        # plt.subplots(0, cols) would raise; there is nothing to draw.
        return

    cols = 4
    rows = (n_images + cols - 1) // cols  # ceiling division
    # squeeze=False always returns a 2-D array of Axes, so ravel() yields a
    # flat, indexable sequence whatever the grid shape.
    fig, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows), squeeze=False)
    axes = axes.ravel()

    for idx in range(n_images):
        ax = axes[idx]
        # Detach and move to CPU so .numpy() works for GPU tensors and for
        # tensors that are part of an autograd graph.
        img = denormalize(images[idx].detach().cpu()).permute(1, 2, 0).numpy()
        ax.imshow(np.clip(img, 0, 1))
        ax.axis('off')

        # int() accepts plain ints and 0-dim tensors alike, so the lookup
        # works whether CLASS_NAMES is a list or an int-keyed mapping.
        label = CLASS_NAMES[int(labels[idx])]
        title = f"True: {label}"
        if predictions is not None:
            pred = CLASS_NAMES[int(predictions[idx])]
            color = 'green' if pred == label else 'red'
            title += f"\nPred: {pred}"
            ax.set_title(title, color=color, fontsize=10)
        else:
            ax.set_title(title, fontsize=10)

    # Hide the unused cells in the last grid row.
    for ax in axes[n_images:]:
        ax.axis('off')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNGs.

    CUDA generators are seeded when CUDA is available, and the MPS generator
    is seeded when the installed PyTorch exposes an available MPS backend.

    Args:
        seed: Seed value applied to every generator.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Older PyTorch releases lack torch.backends.mps and/or
    # torch.mps.manual_seed. Probe for them instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_module = getattr(torch, "mps", None)
    if (
        mps_backend is not None
        and mps_backend.is_available()
        and hasattr(mps_module, "manual_seed")
    ):
        mps_module.manual_seed(seed)