Spaces:
Sleeping
Sleeping
File size: 758 Bytes
2222d7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
def get_dataloaders(data_dir, batch_size=32, val_split=0.2, *,
                    seed=None, image_size=(128, 128), num_workers=0):
    """Build train/validation DataLoaders from an ImageFolder-style directory.

    Every image is resized to ``image_size``, converted to a tensor and
    normalised with mean 0.5 / std 0.5 per channel (mapping ``ToTensor``'s
    [0, 1] range to [-1, 1]). The normalisation assumes three channels, which
    matches ImageFolder's default loader (it converts images to RGB).

    Args:
        data_dir: Root directory with one sub-directory per class.
        batch_size: Batch size used by both loaders.
        val_split: Fraction of the dataset held out for validation; must
            satisfy ``0 <= val_split < 1``.
        seed: Optional integer seed for the train/validation split. When
            given, the split is reproducible across runs; when ``None`` (the
            default) the global torch RNG is used, as before.
        image_size: Target ``(height, width)`` passed to ``transforms.Resize``.
        num_workers: Worker processes per DataLoader (0 = load in-process).

    Returns:
        ``(train_loader, val_loader, classes)`` where ``classes`` is the list
        of class names discovered by ``ImageFolder``.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
    """
    import torch  # local import: only needed to build a seeded generator

    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    transform = transforms.Compose([
        transforms.Resize(tuple(image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)

    val_len = int(len(dataset) * val_split)
    train_len = len(dataset) - val_len

    # A dedicated, seeded generator makes the split reproducible without
    # touching the global RNG state; None keeps the original behaviour.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [train_len, val_len], **split_kwargs)

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    # Validation order does not need shuffling.
    val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=num_workers)
    return train_loader, val_loader, dataset.classes
|