davidlsan's picture
Upload dataset.py
9363cd3 verified
raw
history blame contribute delete
392 Bytes
from torch.utils.data import Dataset
class EuroSATDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from a split.

    Each item of the wrapped split is looked up by index and must expose
    ``"image"`` and ``"label"`` keys. The image is passed through
    ``transform`` (when one is given) before being returned; the label is
    returned as stored.

    NOTE(review): the row schema presumably comes from a Hugging Face
    ``datasets`` split -- confirm against the caller.
    """

    def __init__(self, split, transform=None):
        """Store the underlying split and the optional image transform.

        Args:
            split: Indexable, sized collection of rows; each row supports
                ``row["image"]`` and ``row["label"]``.
            transform: Callable applied to every image. ``None`` (the
                default) returns images unchanged.
        """
        super().__init__()
        self._data = split
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of rows in the wrapped split."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at ``idx``.

        The image is transformed only when a transform was supplied, so a
        dataset built with ``transform=None`` does not fail at access time.
        """
        row = self._data[idx]
        image = row["image"]
        label = row["label"]
        if self.transform is not None:
            image = self.transform(image)
        return image, label