File size: 2,739 Bytes
233caeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""

Data loading and preprocessing for CIFAR-10 dataset

"""
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import config


def get_transforms(train=True):
    """
    Build the preprocessing pipeline for CIFAR-10 images.

    Every pipeline ends by converting the image to a tensor and normalizing
    it with fixed per-channel mean/std values. When ``train`` is True and
    ``config.USE_AUGMENTATION`` is enabled, a padded random 32x32 crop and a
    random horizontal flip are applied first.

    Args:
        train (bool): Request the training pipeline; augmentation is still
            gated by ``config.USE_AUGMENTATION``.

    Returns:
        torchvision.transforms.Compose: Composed transforms
    """
    pipeline = []

    if train and config.USE_AUGMENTATION:
        # Augmentation steps run on the raw image, before tensor conversion.
        pipeline += [
            transforms.RandomCrop(32, padding=config.RANDOM_CROP_PADDING),
            transforms.RandomHorizontalFlip(p=config.RANDOM_HORIZONTAL_FLIP),
        ]

    # Shared tail for both training and evaluation pipelines.
    pipeline += [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616],
        ),
    ]

    return transforms.Compose(pipeline)


def get_data_loaders():
    """
    Create train and test data loaders for CIFAR-10.

    Both splits are read from ``config.DATA_DIR`` (downloaded there if
    missing). The training split uses ``get_transforms(train=True)`` and is
    shuffled. The test split uses ``get_transforms(train=False)`` and keeps
    a fixed order. Batch size and worker count come from ``config``.

    Returns:
        tuple: (train_loader, test_loader)
    """
    train_transform = get_transforms(train=True)
    test_transform = get_transforms(train=False)

    train_dataset = datasets.CIFAR10(
        root=config.DATA_DIR,
        train=True,
        download=True,
        transform=train_transform,
    )
    test_dataset = datasets.CIFAR10(
        root=config.DATA_DIR,
        train=False,
        download=True,
        transform=test_transform,
    )

    # Pinned (page-locked) host memory only helps host-to-GPU copies, so it
    # is enabled only when the configured device is CUDA.
    loader_kwargs = {
        'batch_size': config.BATCH_SIZE,
        'num_workers': config.NUM_WORKERS,
        'pin_memory': config.DEVICE.type == 'cuda',
    }

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, test_loader


def denormalize(
    tensor,
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2470, 0.2435, 0.2616),
):
    """
    Denormalize a tensor image for visualization.

    Inverts per-channel normalization by computing ``tensor * std + mean``.
    The statistics are reshaped to ``(C, 1, 1)``, so they broadcast over a
    single ``(C, H, W)`` image or a batched ``(N, C, H, W)`` tensor.

    Args:
        tensor: Normalized tensor image with channels in dimension -3.
        mean: Per-channel means used for normalization. Defaults to the
            values used by ``get_transforms``.
        std: Per-channel standard deviations used for normalization.
            Defaults to the values used by ``get_transforms``.

    Returns:
        tensor: Denormalized tensor image, on the same device as ``tensor``.
    """
    # Create the statistics on the input's device: CPU-only constants would
    # raise a device-mismatch error when ``tensor`` lives on a GPU.
    mean = torch.as_tensor(mean, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, device=tensor.device).view(-1, 1, 1)
    return tensor * std + mean