Retina_Training / dataset.py
Habeeb Okunade
Develop model training
39ec591
from pathlib import Path
from typing import Tuple
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Side length, in pixels, of the square every image is resized to.
IM_SIZE: int = 224
# ImageNet normalization works well for ViT-based MAE backbones
# Per-channel mean and standard deviation handed to transforms.Normalize.
NORM_MEAN: list[float] = [0.485, 0.456, 0.406]
NORM_STD: list[float] = [0.229, 0.224, 0.225]
def build_transforms(train: bool):
    """Assemble the image preprocessing pipeline for one dataset split.

    Both splits resize to ``IM_SIZE`` x ``IM_SIZE``, convert to a tensor and
    normalize with ``NORM_MEAN`` / ``NORM_STD``. The training split additionally
    applies a random horizontal flip right after the resize.

    Args:
        train: Truthy for the training pipeline (with augmentation), falsy for
            the evaluation pipeline.

    Returns:
        A ``transforms.Compose`` holding the steps in application order.
    """
    steps = [transforms.Resize((IM_SIZE, IM_SIZE))]
    if train:
        # Augmentation is applied to training images only.
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=NORM_MEAN, std=NORM_STD))
    return transforms.Compose(steps)
def get_loaders(root: str | Path, batch_size: int = 16, num_workers: int = 2) -> Tuple[DataLoader, DataLoader, list[str]]:
    """Build the training and validation data loaders for an image-folder dataset.

    Reads ``root/train`` and ``root/val`` with ``datasets.ImageFolder``, using
    the training transforms for the former and the evaluation transforms for
    the latter.

    Args:
        root: Dataset root directory containing the ``train`` and ``val`` folders.
        batch_size: Number of samples per batch, used for both loaders.
        num_workers: Number of loader workers, used for both loaders.

    Returns:
        ``(train_dl, val_dl, classes)``: the shuffled training loader, the
        unshuffled validation loader, and the label names taken from the
        training folder names.

    Raises:
        ValueError: If a class folder under ``val`` is absent from ``train`` or
            is assigned a different label index than in ``train``.
    """
    root = Path(root)
    train_ds = datasets.ImageFolder(root / "train", transform=build_transforms(True))
    val_ds = datasets.ImageFolder(root / "val", transform=build_transforms(False))

    # Each ImageFolder derives label indices from its own folder listing, so a
    # differing set of class folders would make validation labels point at the
    # wrong training class without any error. Fail loudly instead.
    train_index = train_ds.class_to_idx
    mismatched = sorted(
        name
        for name, idx in val_ds.class_to_idx.items()
        if train_index.get(name) != idx
    )
    if mismatched:
        raise ValueError(
            f"Validation classes {mismatched} do not map to the same label "
            f"indices as the training classes {train_ds.classes}; "
            "train/ and val/ must use matching class folders."
        )

    classes = train_ds.classes  # label names from folder names
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_dl, val_dl, classes