| import torch | |
| from torch.utils.data import DataLoader | |
| from torchvision import datasets,transforms | |
| from src.path import DATA_DIR | |
def get_dataloader(config):
    """Build CIFAR-10 train and test DataLoaders.

    Images are resized to a square ``image_size`` x ``image_size`` and
    normalized per channel. The training split also gets random horizontal
    flips. Both splits are fetched with ``download=True`` into ``data_path``.

    Args:
        config: Mapping of optional settings. ``None`` is treated as an empty
            mapping. Recognized keys and their defaults:
            ``batch_size`` (64), ``data_path`` (``DATA_DIR``),
            ``num_workers`` (4), ``image_size`` (224),
            ``mean`` ([0.485, 0.456, 0.406]), ``std`` ([0.229, 0.224, 0.225]),
            ``pin_memory`` (True).

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles.
    """
    config = config or {}
    batch_size = config.get('batch_size', 64)
    data_path = config.get('data_path', DATA_DIR)
    num_workers = config.get('num_workers', 4)
    image_size = config.get('image_size', 224)
    # Default mean/std are the standard ImageNet channel statistics. Together
    # with the 224x224 default resize, this presumably targets an
    # ImageNet-pretrained backbone — NOTE(review): confirm with the model code.
    mean = config.get('mean', [0.485, 0.456, 0.406])
    std = config.get('std', [0.229, 0.224, 0.225])
    pin_memory = config.get('pin_memory', True)

    train_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    # Evaluation uses the same resize/normalization, minus the augmentation.
    val_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.CIFAR10(root=data_path, train=False, download=True, transform=val_transform)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # DataLoader raises ValueError if persistent_workers=True while
        # num_workers == 0, so only keep workers alive when there are any.
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)
    return train_loader, test_loader