import os from torchvision import datasets, transforms from torch.utils.data import DataLoader, random_split def get_dataloaders(data_dir, batch_size=32, val_split=0.2): transform = transforms.Compose([ transforms.Resize((128, 128)), transforms.ToTensor(), transforms.Normalize([0.5]*3, [0.5]*3) ]) dataset = datasets.ImageFolder(data_dir, transform=transform) val_len = int(len(dataset) * val_split) train_len = len(dataset) - val_len train_set, val_set = random_split(dataset, [train_len, val_len]) train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val_set, batch_size=batch_size) return train_loader, val_loader, dataset.classes