File size: 1,453 Bytes
39ec591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from pathlib import Path
from typing import Tuple
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Square input resolution used by both the train and eval transforms below.
IM_SIZE = 224

# ImageNet normalization works well for ViT-based MAE backbones
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD  = [0.229, 0.224, 0.225]


def build_transforms(train: bool) -> transforms.Compose:
    """Build the image preprocessing pipeline for one dataset split.

    Args:
        train: if True, include random horizontal flip augmentation;
            the eval pipeline is deterministic.

    Returns:
        A torchvision ``Compose`` that resizes to (IM_SIZE, IM_SIZE),
        converts PIL images to tensors, and normalizes with the
        ImageNet statistics defined at module level.
    """
    # Build the shared pipeline once instead of duplicating it per branch.
    steps = [transforms.Resize((IM_SIZE, IM_SIZE))]
    if train:
        # Augmentation only applies at training time.
        steps.append(transforms.RandomHorizontalFlip())
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
    return transforms.Compose(steps)


def get_loaders(
    root: str | Path,
    batch_size: int = 16,
    num_workers: int = 2,
    pin_memory: bool = False,
) -> Tuple[DataLoader, DataLoader, list[str]]:
    """Create train/val DataLoaders from an ImageFolder-style directory.

    Expects ``root`` to contain ``train/`` and ``val/`` subdirectories,
    each holding one subfolder per class (torchvision ImageFolder layout).

    Args:
        root: dataset root containing the ``train`` and ``val`` splits.
        batch_size: samples per batch for both loaders.
        num_workers: worker processes per DataLoader.
        pin_memory: if True, loaders allocate page-locked host memory,
            which speeds up host-to-GPU transfer. Off by default so
            existing callers see identical behavior.

    Returns:
        ``(train_loader, val_loader, classes)`` where ``classes`` is the
        list of class names derived from the training folder names.

    Raises:
        FileNotFoundError: if either split directory is missing.
    """
    root = Path(root)
    train_dir = root / "train"
    val_dir = root / "val"

    # Fail fast with an explicit path instead of ImageFolder's generic error.
    for split_dir in (train_dir, val_dir):
        if not split_dir.is_dir():
            raise FileNotFoundError(f"Expected dataset split directory: {split_dir}")

    train_ds = datasets.ImageFolder(train_dir, transform=build_transforms(True))
    val_ds   = datasets.ImageFolder(val_dir,   transform=build_transforms(False))

    classes = train_ds.classes  # label names from folder names

    # Only the training loader shuffles; eval order stays deterministic.
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return train_dl, val_dl, classes