Spaces:
Runtime error
Runtime error
| from torch.utils import data | |
| from PIL import Image | |
| import os | |
class Dataset(data.Dataset):
    """Unlabeled image dataset built from the .jpg/.png files under a directory.

    The directory tree is walked once at construction time; indexing the
    dataset loads the corresponding file as an RGB image and returns it
    (after the optional transform). No label is returned.
    """

    # File extensions (lower-case, including the dot) accepted as images.
    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, path, transform=None):
        """Collect the image paths found under ``path``.

        Args:
            path: Root directory; it is searched recursively.
            transform: Optional callable applied to every loaded image
                before it is returned by ``__getitem__``.
        """
        self.file_names = self.get_filenames(path)
        self.transform = transform

    def __len__(self):
        """Return the total number of samples."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Load sample ``index`` as an RGB image and apply the transform, if any."""
        # Open inside a context manager so the underlying file handle is
        # released promptly; convert() returns an image that no longer
        # depends on the open file.
        with Image.open(self.file_names[index]) as source:
            img = source.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_filenames(self, data_path):
        """Return the paths of all image files below ``data_path``.

        A file qualifies when its extension, compared case-insensitively,
        is one of ``IMAGE_EXTENSIONS``. Matching on the real extension
        (rather than on a substring of the name) keeps files such as
        ``jpg_notes.txt`` out of the dataset.

        Args:
            data_path: Root directory to walk recursively.

        Returns:
            A list of file paths in ``os.walk`` order.
        """
        images = []
        for path, _subdirs, files in os.walk(data_path):
            for name in files:
                extension = os.path.splitext(name)[1].lower()
                if extension not in self.IMAGE_EXTENSIONS:
                    continue
                filename = os.path.join(path, name)
                # Skip entries that are not regular files (e.g. broken symlinks).
                if os.path.isfile(filename):
                    images.append(filename)
        return images