Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
| from PIL import Image | |
class AcneDataset(Dataset):
    """Map-style dataset of ``.jpg`` images whose label is encoded in the filename.

    The label is the last character of the first underscore-separated token
    of the filename (the part of the basename before its first dot), parsed
    as an integer. For example, ``grade2_0005.jpg`` yields label ``2``.
    Because only one character is read, labels are limited to 0-9.

    Args:
        dataDir: Directory scanned (non-recursively) for image files. Only
            entries whose path ends with lowercase ``.jpg`` are kept; the
            match is case-sensitive.
        limit: If true, keep only the ``limit_range`` slice of the sorted
            file list.
        transform: Optional callable applied to the loaded PIL image.
        limit_range: ``(start, stop)`` slice bounds used when ``limit`` is
            true. Defaults to ``(1000, 1200)``, the previously hard-coded range.
    """

    def __init__(self, dataDir, limit=True, transform=None, *,
                 limit_range=(1000, 1200)):
        self.dataDir = dataDir
        paths = (os.path.join(dataDir, name) for name in os.listdir(dataDir))
        # Sorted so that indices (and the `limit` slice) are deterministic
        # regardless of the order os.listdir returns.
        self.image_names = sorted(p for p in paths if p.endswith('.jpg'))
        self.transform = transform
        if limit:
            start, stop = limit_range
            self.image_names = self.image_names[start:stop]

    def __len__(self):
        """Return the number of images in the (possibly limited) file list."""
        return len(self.image_names)

    @staticmethod
    def _label_from_path(path):
        """Return the integer label encoded in the filename of *path*.

        Raises:
            ValueError: if the filename does not contain a parsable label.
        """
        # os.path.basename handles the platform's separator(s); a hard-coded
        # '/' split breaks on Windows paths built with os.path.join.
        stem = os.path.basename(path).split('.')[0]
        token = stem.split('_')[0]
        try:
            return int(token[-1])
        except (IndexError, ValueError) as err:
            # Re-raise as ValueError: a stray IndexError from __getitem__
            # would silently end `for x in dataset` iteration.
            raise ValueError(
                f"cannot parse a label from image filename {path!r}"
            ) from err

    def __getitem__(self, idx):
        """Return ``(img, label)`` for sample *idx*.

        ``img`` is the PIL image, or the output of ``transform`` if one was
        given. ``label`` is a 0-d ``float32`` numpy array.
        Raises ``IndexError`` for an out-of-range *idx*.
        """
        img_path = self.image_names[idx]
        label = np.array(self._label_from_path(img_path), dtype=np.float32)
        # Image.open is lazy and would keep the file open until garbage
        # collection; load the pixels inside the context manager so the
        # handle is released immediately.
        with Image.open(img_path) as img:
            img.load()
        if self.transform:
            img = self.transform(img)
        return img, label