"""Data-loading utilities for an ImageFolder-style train/val dataset.

Expects a root directory laid out as ``root/train/<class>/*`` and
``root/val/<class>/*``; class names are derived from the folder names.
"""

from pathlib import Path
from typing import Tuple

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Default square input resolution for the model.
IM_SIZE = 224

# ImageNet normalization works well for ViT-based MAE backbones
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]


def build_transforms(train: bool, image_size: int = IM_SIZE):
    """Build the preprocessing pipeline for one dataset split.

    Args:
        train: If True, include train-time augmentation (random horizontal
            flip); otherwise return the deterministic eval pipeline.
        image_size: Square side length to resize images to. Defaults to
            ``IM_SIZE`` so existing callers are unaffected.

    Returns:
        A ``transforms.Compose`` producing normalized float tensors.
    """
    # Eval pipeline: resize + tensor + normalize, no randomness.
    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
    if train:
        # Insert the augmentation after the resize, before tensor conversion,
        # matching the original ordering exactly.
        steps.insert(1, transforms.RandomHorizontalFlip())
    return transforms.Compose(steps)


def get_loaders(
    root: str | Path,
    batch_size: int = 16,
    num_workers: int = 2,
) -> Tuple[DataLoader, DataLoader, list[str]]:
    """Create train/val DataLoaders from an ImageFolder directory tree.

    Args:
        root: Dataset root containing ``train/`` and ``val/`` subdirectories.
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes per loader.

    Returns:
        ``(train_dl, val_dl, classes)`` where ``classes`` is the list of
        class names taken from the train split's folder names.

    Raises:
        FileNotFoundError: If ``root/train`` or ``root/val`` is missing
            (raised by ``datasets.ImageFolder``).
    """
    root = Path(root)
    train_dir = root / "train"
    val_dir = root / "val"

    train_ds = datasets.ImageFolder(train_dir, transform=build_transforms(True))
    val_ds = datasets.ImageFolder(val_dir, transform=build_transforms(False))

    classes = train_ds.classes  # label names from folder names

    # Shuffle only the training split; keep val deterministic for eval.
    train_dl = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_dl = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_dl, val_dl, classes