Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import Tuple | |
| from torchvision import datasets, transforms | |
| from torch.utils.data import DataLoader | |
# Target height and width (pixels); every image is resized to this square size.
IM_SIZE = 224

# Per-channel ImageNet mean / std consumed by transforms.Normalize.
# ImageNet normalization works well for ViT-based MAE backbones.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
def build_transforms(train: bool):
    """Return the torchvision preprocessing pipeline for one dataset split.

    Every split is resized to ``IM_SIZE`` x ``IM_SIZE``, converted to a
    tensor and normalized with ``NORM_MEAN`` / ``NORM_STD``. The training
    pipeline additionally applies a random horizontal flip, placed after
    the resize and before tensor conversion.

    Args:
        train: Truthy selects the training pipeline (with the flip);
            falsy selects the evaluation pipeline.

    Returns:
        A ``transforms.Compose`` that applies the steps in order.
    """
    steps = [
        transforms.Resize((IM_SIZE, IM_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
    if train:
        # Augmentation is training-only; slot it in right after Resize.
        steps.insert(1, transforms.RandomHorizontalFlip())
    return transforms.Compose(steps)
def get_loaders(root: str | Path, batch_size: int = 16, num_workers: int = 2) -> Tuple[DataLoader, DataLoader, list[str]]:
    """Build train/validation DataLoaders from an ImageFolder-style tree.

    Expects ``root/train/<class>/...`` and ``root/val/<class>/...``; the
    sub-folder names of each split become its class labels via
    ``torchvision.datasets.ImageFolder``.

    Args:
        root: Dataset root containing ``train`` and ``val`` directories.
        batch_size: Samples per batch, used for both loaders.
        num_workers: Number of loader worker processes, used for both loaders.

    Returns:
        ``(train_dl, val_dl, classes)``. ``train_dl`` shuffles every epoch,
        ``val_dl`` keeps dataset order, and ``classes`` is the list of label
        names indexed by class id.

    Raises:
        ValueError: If ``train`` and ``val`` do not contain the same set of
            class folders, which would make their label indices disagree.
    """
    root = Path(root)
    train_dir = root / "train"
    val_dir = root / "val"

    train_ds = datasets.ImageFolder(train_dir, transform=build_transforms(True))
    val_ds = datasets.ImageFolder(val_dir, transform=build_transforms(False))

    # ImageFolder assigns class indices per directory from its own sorted
    # sub-folder names. A class folder missing from (or extra in) one split
    # therefore silently shifts the labels of every later class, so val
    # targets would no longer match train's ``classes``. Fail loudly instead.
    if train_ds.class_to_idx != val_ds.class_to_idx:
        train_only = sorted(set(train_ds.classes) - set(val_ds.classes))
        val_only = sorted(set(val_ds.classes) - set(train_ds.classes))
        raise ValueError(
            f"Class folders differ between {train_dir} and {val_dir}: "
            f"only in train={train_only}, only in val={val_only}"
        )

    classes = train_ds.classes  # label names from folder names (same for val, checked above)

    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_dl, val_dl, classes