# Jepa / dataset_loader.py
# Uploaded by Ananthusajeev190 in commit 046e256 (verified), "Upload 5 files"; 452 bytes.
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def get_dataloader(
    batch_size=16,
    img_size=128,
    data_path="./data",
    *,
    shuffle=True,
    num_workers=0,
    drop_last=False,
):
    """Build a DataLoader over an ImageFolder dataset of fixed-size images.

    Every image is resized to a fixed ``(height, width)`` and converted with
    ``transforms.ToTensor()``; no normalization or augmentation is applied.

    Args:
        batch_size: Number of samples per batch.
        img_size: Target size. An int ``s`` resizes to ``(s, s)`` (aspect
            ratio is not preserved); a two-element ``(height, width)``
            sequence resizes to exactly that shape.
        data_path: Root directory passed to ``datasets.ImageFolder`` (which,
            per torchvision's docs, expects one sub-directory per class).
        shuffle: Whether to reshuffle the samples every epoch. Defaults to
            ``True``, matching the previous behavior.
        num_workers: Number of loader worker processes; ``0`` (the
            ``DataLoader`` default) loads in the main process.
        drop_last: Whether to drop the final incomplete batch when the
            dataset size is not divisible by ``batch_size``.

    Returns:
        A ``DataLoader`` over the ImageFolder dataset.

    Raises:
        ValueError: If ``img_size`` is a sequence without exactly two entries.
    """
    if isinstance(img_size, int):
        size = (img_size, img_size)
    else:
        size = tuple(img_size)
        if len(size) != 2:
            raise ValueError(
                f"img_size must be an int or a (height, width) pair, got {img_size!r}"
            )

    transform = transforms.Compose([
        # An explicit (h, w) size (rather than a bare int, which would only
        # fix the shorter edge) gives every image the same shape, so the
        # default collate function can stack them into one batch tensor.
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    dataset = datasets.ImageFolder(root=data_path, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
    )