import os

from torchvision import datasets, transforms
from torch.utils.data import DataLoader


def get_transforms(image_size=224):
    """Build the training and evaluation image preprocessing pipelines.

    Args:
        image_size: Side length in pixels; every image is resized to a
            square ``(image_size, image_size)``.

    Returns:
        A ``(train_transforms, val_test_transforms)`` tuple.

        The training pipeline adds random augmentation:
        - horizontal flip;
        - rotation of up to 10 degrees;
        - brightness/contrast jitter of 0.2.

        The evaluation pipeline is deterministic. Both pipelines end with
        ``ToTensor`` followed by a 3-channel ``Normalize`` with mean 0.5
        and std 0.5 per channel.
    """
    # Identical statistics for every split, so train and eval inputs are
    # scaled the same way.
    mean = [0.5] * 3
    std = [0.5] * 3

    train_transforms = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    val_test_transforms = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_transforms, val_test_transforms


def _check_class_mapping(reference, other, split_name):
    """Raise ValueError if *other* assigns class indices differently from *reference*."""
    # ImageFolder derives class_to_idx independently per directory, so a
    # split with a missing or extra class folder would otherwise get
    # shifted labels without any error.
    if other.class_to_idx != reference.class_to_idx:
        missing = sorted(set(reference.classes) - set(other.classes))
        unexpected = sorted(set(other.classes) - set(reference.classes))
        raise ValueError(
            f"Class folders in the '{split_name}' split do not match the "
            f"'train' split (missing: {missing}, unexpected: {unexpected}); "
            f"labels would be inconsistent across splits."
        )


def get_dataloaders(data_dir, batch_size=32, image_size=224, num_workers=2):
    """Create train/val/test DataLoaders from an ImageFolder-style directory.

    Expects ``data_dir/train``, ``data_dir/val`` and ``data_dir/test``, each
    laid out as ``ImageFolder`` requires (one sub-directory per class).

    Args:
        data_dir: Root directory containing the three split directories.
        batch_size: Batch size used by all three loaders.
        image_size: Passed to ``get_transforms`` for resizing.
        num_workers: Worker processes used by each DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``.
        ``class_names`` is ``train_dataset.classes``; index ``i`` in it
        matches label ``i`` in every split.

    Raises:
        ValueError: If the val or test split does not have exactly the same
            class folders as the train split.
    """
    train_transforms, val_test_transforms = get_transforms(image_size)

    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')
    test_dir = os.path.join(data_dir, 'test')

    train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
    val_dataset = datasets.ImageFolder(val_dir, transform=val_test_transforms)
    test_dataset = datasets.ImageFolder(test_dir, transform=val_test_transforms)

    _check_class_mapping(train_dataset, val_dataset, 'val')
    _check_class_mapping(train_dataset, test_dataset, 'test')

    # Only training data is shuffled; evaluation order stays deterministic.
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)

    class_names = train_dataset.classes
    return train_loader, val_loader, test_loader, class_names