Spaces:
Running
Running
| from torch.utils.data import Dataset | |
| import os | |
| from torchvision.datasets.folder import default_loader | |
| import torchvision.transforms as T | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
class CommonDataset(Dataset):
    """Dataset of paired input images and label images.

    Item ``idx`` is built from ``images_path[idx]`` and ``labels_path[idx]``.
    The input image is loaded with torchvision's ``default_loader``. The
    label image is converted to single-channel PIL mode ``'L'``. Each is then
    passed through its own transform.

    Args:
        images_path: Sequence of file paths to the input images.
        labels_path: Sequence of file paths to the label images, aligned
            index-by-index with ``images_path``.
        x_transform: Callable applied to each loaded input image.
        y_transform: Callable applied to each loaded (mode ``'L'``) label
            image.

    Raises:
        ValueError: If ``images_path`` and ``labels_path`` differ in length.
    """

    def __init__(self, images_path, labels_path, x_transform, y_transform):
        # Fail fast on misaligned inputs. Otherwise extra labels are silently
        # ignored, or indexing fails partway through an epoch.
        if len(images_path) != len(labels_path):
            raise ValueError(
                f"images_path and labels_path must have the same length "
                f"(got {len(images_path)} and {len(labels_path)})"
            )
        self.imgs_path = images_path
        self.labels_path = labels_path
        self.x_transform = x_transform
        self.y_transform = y_transform

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.imgs_path)

    def __getitem__(self, idx):
        """Load, transform, and return the ``(x, y)`` pair at ``idx``."""
        x = default_loader(self.imgs_path[idx])
        # Use a context manager so the label file handle is always released.
        # convert() returns an independent, fully loaded image, so it remains
        # valid after the source image is closed.
        with Image.open(self.labels_path[idx]) as label_img:
            y = label_img.convert('L')
        return self.x_transform(x), self.y_transform(y)