Spaces:
Sleeping
Sleeping
| from albumentations.pytorch import ToTensorV2 | |
| from torch.utils.data import DataLoader, Dataset | |
class CustomDataset(Dataset):
    """Image classification dataset backed by a directory tree.

    Every entry directly under ``data_dir`` is treated as a class; its name is
    the class name and its position in the sorted listing is the integer
    label. Files inside a class entry whose extension is one of
    ``self.extensions`` (compared case-insensitively) become samples.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, data_dir: str, transforms=None):
        """Index all images found under ``data_dir``.

        Args:
            data_dir: Root directory holding one sub-directory per class.
            transforms: Optional callable invoked as ``transforms(image=img)``
                that returns a mapping with an ``"image"`` key.
        """
        self.data_dir = data_dir
        self.transforms = transforms
        # List of [image_path, class_name] pairs, one per sample.
        self.data = []
        # Maps class name -> integer label.
        self.class_map = {}
        # Accepted file extensions (lower-case, without the dot).
        self.extensions = ("jpeg", "jpg", "png")

        # Sorted so that label indices are reproducible across runs.
        class_paths = sorted(glob.glob(os.path.join(self.data_dir, "*")))
        for idx, class_path in enumerate(class_paths):
            # basename (rather than splitting on "/") also works with the
            # backslash-separated paths glob produces on Windows.
            class_name = os.path.basename(class_path)
            self.class_map[class_name] = idx
            # Sorted so that sample order does not depend on the filesystem.
            for img_path in sorted(glob.glob(os.path.join(class_path, "*"))):
                ext = os.path.splitext(img_path)[1].lstrip(".").lower()
                if ext in self.extensions:
                    self.data.append([img_path, class_name])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        Raises:
            OSError: If the image file cannot be read or decoded.
        """
        img_path, class_name = self.data[idx]
        # NOTE: cv2.imread yields channels in BGR order; no colour conversion
        # is applied here.
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image file: {img_path}")
        # Applying transforms on image
        if self.transforms:
            img = self.transforms(image=img)["image"]
        label = self.class_map[class_name]
        return img, label
def get_transform():
    """Build the preprocessing pipeline applied to every image.

    The returned ``A.Compose`` runs three steps in order: resize to 224x224,
    normalize each of the three channels with mean 0.5 and std 0.5, and
    apply ``ToTensorV2``.
    """
    steps = [
        A.Resize(224, 224),
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ToTensorV2(),
    ]
    return A.Compose(steps)
if __name__ == "__main__":
    # Smoke test: build the dataset from <two levels up>/data/train/, pull a
    # single shuffled batch and report its shape and size.
    project_root = os.path.abspath(__file__ + "/../../")
    data_dir = os.path.join(project_root, "data/train/")
    print(data_dir)

    dataset = CustomDataset(data_dir=data_dir, transforms=get_transform())
    data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

    total_imgs = 0
    # Only the first batch is inspected; an empty loader leaves the count at 0.
    first_batch = next(iter(data_loader), None)
    if first_batch is not None:
        imgs, labels = first_batch
        total_imgs += int(imgs.shape[0])
        print(imgs.shape)
    print(total_imgs)