| from torch.utils.data import Dataset | |
class EuroSATDataset(Dataset):
    """Map-style PyTorch dataset wrapping an indexable split of image/label rows.

    Each item of the wrapped split is read as a mapping with ``"image"`` and
    ``"label"`` keys; ``__getitem__`` returns the (optionally transformed)
    image together with its label.
    """

    def __init__(self, split, transform=None):
        """Wrap ``split`` and remember the per-image transform.

        Args:
            split: Sized, integer-indexable collection whose items are mappings
                with ``"image"`` and ``"label"`` keys. NOTE(review): presumably a
                Hugging Face ``datasets`` split — confirm against the caller.
            transform: Callable applied to each image before it is returned.
                ``None`` (the default) returns the image unchanged.
        """
        self._data = split
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the wrapped split."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is passed through ``self.transform`` when one was given;
        the label is returned exactly as stored in the row.
        """
        row = self._data[idx]
        image = row["image"]
        label = row["label"]
        # Allow transform=None so the raw image can be inspected / used as-is
        # instead of crashing on a call to None.
        if self.transform is not None:
            image = self.transform(image)
        return image, label