| from torch.utils import data |
| from PIL import Image |
| import os |
|
|
|
|
class Dataset(data.Dataset):
    """Map-style PyTorch dataset over the image files found under a directory.

    The directory tree is scanned once at construction time; every sample is
    the image at the corresponding path, decoded as RGB and optionally passed
    through ``transform``.
    """

    # File-name suffixes (lower case) that are treated as images. Matching is
    # done case-insensitively against the end of the file name.
    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, path, transform=None):
        """Collect the image paths under ``path``.

        Args:
            path: Root directory that is searched recursively for images.
            transform: Optional callable applied to each RGB image before it
                is returned by ``__getitem__``.
        """
        super().__init__()
        self.file_names = self.get_filenames(path)
        self.transform = transform

    def __len__(self):
        """Return the total number of samples (image files found)."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position of the sample in ``self.file_names``.

        Returns:
            The image converted to RGB, or the result of ``transform`` applied
            to it when a transform was given.
        """
        # ``Image.open`` is lazy and keeps the file open; the context manager
        # releases the handle as soon as the pixel data has been converted.
        with Image.open(self.file_names[index]) as source:
            img = source.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_filenames(self, data_path):
        """Return the sorted paths of all image files below ``data_path``.

        A file counts as an image when its name ends with one of
        ``IMAGE_EXTENSIONS`` (case-insensitive). Checking the suffix instead
        of searching for the substring anywhere in the name keeps files such
        as ``jpg_notes.txt`` out of the dataset.

        Args:
            data_path: Root directory to walk recursively.

        Returns:
            A list of file paths, sorted so that the index -> sample mapping
            does not depend on the directory listing order of the filesystem.
        """
        suffixes = tuple(self.IMAGE_EXTENSIONS)
        images = []
        for folder, _subdirs, files in os.walk(data_path):
            for name in files:
                if not name.lower().endswith(suffixes):
                    continue
                filename = os.path.join(folder, name)
                # Skip entries that are not regular files (e.g. broken links).
                if os.path.isfile(filename):
                    images.append(filename)
        images.sort()
        return images
|
|