import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from src.path import DATA_DIR


def _make_loader(dataset, *, batch_size, num_workers, shuffle):
    """Wrap *dataset* in a DataLoader using the settings shared by train and test."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        # Pinned host memory only helps host-to-GPU copies; without CUDA it
        # just triggers a warning, so enable it only when CUDA is present.
        pin_memory=torch.cuda.is_available(),
        # DataLoader raises ValueError for persistent_workers=True with
        # num_workers == 0, so only keep workers alive when there are some.
        persistent_workers=num_workers > 0,
    )


def get_dataloader(config):
    """Build CIFAR-100 train and test DataLoaders.

    Args:
        config: Mapping read via ``.get`` for the optional keys
            ``batch_size`` (default 128), ``data_path`` (default
            ``DATA_DIR``) and ``num_workers`` (default 8).

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles
        and applies random horizontal flips; the test loader does neither.
        Both only convert images with ``ToTensor`` (no normalization).

    Both splits are requested with ``download=True``, so torchvision may
    fetch the dataset into ``data_path``.
    """
    batch_size = config.get('batch_size', 128)
    data_path = config.get('data_path', DATA_DIR)
    num_workers = config.get('num_workers', 8)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    print("📂Loading CIFAR100 dataset...")
    train_data = datasets.CIFAR100(
        root=data_path, train=True, download=True, transform=train_transform
    )
    test_data = datasets.CIFAR100(
        root=data_path, train=False, download=True, transform=val_transform
    )

    train_loader = _make_loader(
        train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True
    )
    test_loader = _make_loader(
        test_data, batch_size=batch_size, num_workers=num_workers, shuffle=False
    )
    return train_loader, test_loader