Spaces:
Build error
Build error
File size: 1,645 Bytes
84d0c9e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def get_transforms(image_size=224):
    """Build the preprocessing pipelines for training and for evaluation.

    Both pipelines resize to a square of ``image_size`` x ``image_size``,
    convert to a tensor and normalize three channels with mean 0.5 and
    std 0.5. The training pipeline additionally applies a random horizontal
    flip, a random rotation (``RandomRotation(10)``) and brightness/contrast
    jitter of 0.2 between the resize and the tensor conversion.

    Args:
        image_size: Side length used for both height and width of the resize.

    Returns:
        A ``(train_transforms, val_test_transforms)`` tuple of
        ``transforms.Compose`` objects.
    """
    target_size = (image_size, image_size)

    def _normalize():
        # A fresh Normalize (with fresh lists) per pipeline, so the two
        # pipelines never share transform instances.
        return transforms.Normalize([0.5] * 3, [0.5] * 3)

    # Training: resize, then augmentation, then tensor conversion + normalize.
    augmenting_steps = [
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        _normalize(),
    ]
    training_pipeline = transforms.Compose(augmenting_steps)

    # Validation/test: same resize and normalization, no random augmentation.
    plain_steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        _normalize(),
    ]
    evaluation_pipeline = transforms.Compose(plain_steps)

    return training_pipeline, evaluation_pipeline
def get_dataloaders(data_dir, batch_size=32, image_size=224, num_workers=2):
    """Create train/val/test data loaders from an image directory tree.

    ``data_dir`` is joined with ``train``, ``val`` and ``test`` to locate the
    three splits, each of which is opened with ``datasets.ImageFolder``. The
    training split uses the augmenting pipeline from ``get_transforms`` and is
    shuffled; the validation and test splits use the non-augmenting pipeline
    and are not shuffled.

    Args:
        data_dir: Root directory containing the ``train``, ``val`` and
            ``test`` subdirectories.
        batch_size: Batch size passed to every ``DataLoader``.
        image_size: Side length forwarded to ``get_transforms``.
        num_workers: Worker count passed to every ``DataLoader``.

    Returns:
        A ``(train_loader, val_loader, test_loader, class_names)`` tuple, where
        ``class_names`` is the ``classes`` attribute of the training dataset.
    """
    augmented, deterministic = get_transforms(image_size)

    # Split name -> pipeline; insertion order fixes the train/val/test order.
    pipeline_for = {
        "train": augmented,
        "val": deterministic,
        "test": deterministic,
    }

    # Build every dataset before any loader, so a missing split directory is
    # reported before any DataLoader argument is validated.
    folders = {
        split: datasets.ImageFolder(os.path.join(data_dir, split), transform=pipeline)
        for split, pipeline in pipeline_for.items()
    }

    # Only the training split is shuffled.
    loaders = {
        split: DataLoader(
            folder,
            batch_size=batch_size,
            shuffle=(split == "train"),
            num_workers=num_workers,
        )
        for split, folder in folders.items()
    }

    return loaders["train"], loaders["val"], loaders["test"], folders["train"].classes
|