File size: 2,055 Bytes
af59988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Utility functions for image visualization and reproducibility helpers.
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional

from .config import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES


def denormalize(tensor: torch.Tensor) -> torch.Tensor:
    """Undo ImageNet mean/std normalization on an image tensor.

    Computes ``tensor * std + mean`` channel-wise, the inverse of
    ``(x - mean) / std``.

    Args:
        tensor: Normalized image tensor whose channel axis is third from the
            end, e.g. ``(3, H, W)`` or ``(B, 3, H, W)``. The statistics are
            reshaped to ``(3, 1, 1)`` and broadcast against it.

    Returns:
        A new tensor, on the same device as ``tensor``, with the
        normalization removed.
    """
    # Build the statistics on the input's device so GPU/MPS tensors do not hit
    # a device-mismatch error when combined with CPU-resident constants.
    mean = torch.tensor(IMAGENET_MEAN, device=tensor.device).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD, device=tensor.device).view(3, 1, 1)
    return tensor * std + mean


def show_batch(
    images: torch.Tensor,
    labels: torch.Tensor,
    predictions: Optional[torch.Tensor] = None,
    n_images: int = 8,
    save_path: Optional[str] = None,
    cols: int = 4,
):
    """Plot the first ``n_images`` images of a batch in a grid with their labels.

    Each image is denormalized, clipped to ``[0, 1]`` and titled with its true
    class name. When ``predictions`` is given, the predicted class name is
    appended to the title, coloured green when it matches the label and red
    otherwise.

    Args:
        images: Batch of normalized images; each ``images[i]`` is permuted
            with ``(1, 2, 0)``, so it is expected to be channel-first
            ``(C, H, W)``. May live on any device.
        labels: Class indices, one per image, used to index ``CLASS_NAMES``.
        predictions: Optional predicted class indices aligned with ``labels``.
        n_images: Maximum number of images to show; capped at
            ``len(images)``. Nothing is drawn if the result is <= 0.
        save_path: If given, the figure is also saved to this path at 150 dpi.
        cols: Number of grid columns (default 4).

    Raises:
        ValueError: If ``cols`` is less than 1.
    """
    n_images = min(n_images, len(images))
    if n_images <= 0:
        # plt.subplots() rejects a zero-row grid; there is nothing to show.
        return
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")
    rows = (n_images + cols - 1) // cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so flatten() works
    # uniformly for every grid shape (single row, single column, or both).
    fig, axes = plt.subplots(
        rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False
    )
    axes = axes.flatten()

    for idx in range(n_images):
        ax = axes[idx]
        # Move to CPU and drop autograd history: .numpy() fails on CUDA/MPS
        # tensors and on tensors that require grad.
        img = denormalize(images[idx].detach().cpu()).permute(1, 2, 0).numpy()
        ax.imshow(np.clip(img, 0, 1))
        ax.axis('off')

        # int() turns 0-d tensors into plain Python ints usable as indices/keys.
        true_idx = int(labels[idx])
        title = f"True: {CLASS_NAMES[true_idx]}"

        if predictions is not None:
            pred_idx = int(predictions[idx])
            title += f"\nPred: {CLASS_NAMES[pred_idx]}"
            # Compare class indices, not display names, so two classes that
            # share a name cannot hide a misprediction.
            color = 'green' if pred_idx == true_idx else 'red'
            ax.set_title(title, color=color, fontsize=10)
        else:
            ax.set_title(title, fontsize=10)

    # Hide the unused cells at the end of the grid.
    for ax in axes[n_images:]:
        ax.axis('off')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()


def set_seed(seed: int = 42):
    """Seed the random number generators used here for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's RNG.
    It also explicitly seeds all CUDA devices and the Apple MPS device when
    they are available.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # torch.backends.mps first appeared in PyTorch 1.12 and torch.mps only in
    # 2.0. Look both up defensively so older installs (e.g. 1.12/1.13 on Apple
    # Silicon) do not raise AttributeError here.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_module = getattr(torch, "mps", None)
    if (
        mps_backend is not None
        and mps_module is not None
        and mps_backend.is_available()
    ):
        mps_module.manual_seed(seed)