| import csv | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| from abc import abstractmethod | |
| from itertools import islice | |
| from typing import List, Tuple, Dict, Any | |
| from torch.utils.data import DataLoader | |
| import PIL | |
| from torch.utils.data import Dataset | |
| import numpy as np | |
| import pandas as pd | |
| from torchvision import transforms | |
| from PIL import Image | |
| from dataset.randaugment import RandomAugment | |
class Chestxray14_Dataset(Dataset):
    """Map-style dataset yielding transformed images and their label rows.

    The CSV at ``csv_path`` is read with pandas. Column 0 supplies the image
    file paths, and every column from index 2 onward supplies the label
    vector of the corresponding row. Column 1 is not read by this class.

    Args:
        csv_path: Location of the CSV index file (anything ``pd.read_csv``
            accepts).
        is_train: If true, use the random crop / flip / ``RandomAugment``
            pipeline; otherwise only resize to a fixed square.
    """

    # Side length (pixels) of the square image handed to the model.
    _IMAGE_SIZE = 224

    # Operation names forwarded to RandomAugment for the training pipeline.
    _TRAIN_AUGS = (
        "Identity",
        "AutoContrast",
        "Equalize",
        "Brightness",
        "Sharpness",
        "ShearX",
        "ShearY",
        "TranslateX",
        "TranslateY",
        "Rotate",
    )

    def __init__(self, csv_path, is_train=True):
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 2:])

        # Per-channel mean / std (these are the widely used ImageNet values).
        normalize = transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        )
        size = self._IMAGE_SIZE

        if is_train:
            steps = [
                transforms.RandomResizedCrop(
                    size, scale=(0.2, 1.0), interpolation=Image.BICUBIC
                ),
                transforms.RandomHorizontalFlip(),
                # NOTE(review): the positional 2 and 7 are passed through
                # unchanged; their meaning is defined in dataset.randaugment.
                RandomAugment(2, 7, isPIL=True, augs=list(self._TRAIN_AUGS)),
                transforms.ToTensor(),
                normalize,
            ]
        else:
            steps = [
                transforms.Resize([size, size]),
                transforms.ToTensor(),
                normalize,
            ]
        self.transform = transforms.Compose(steps)

    def __getitem__(self, index):
        """Load, convert to RGB and transform the sample at ``index``.

        Returns:
            A dict with the transformed image under ``"image"`` and the
            matching row of ``class_list`` under ``"label"``.
        """
        img_path = self.img_path_list[index]
        class_label = self.class_list[index]
        # Close the underlying file deterministically; ``convert`` returns a
        # separate image object, so it stays usable after the ``with`` exits.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        image = self.transform(img)
        return {"image": image, "label": class_label}

    def __len__(self) -> int:
        """Return the number of samples (rows in the CSV)."""
        return len(self.img_path_list)
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one ``DataLoader`` per dataset from parallel per-dataset settings.

    The six arguments are iterated in lockstep: element ``i`` of each one
    configures loader ``i``.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler for each dataset, or ``None`` for the default.
        batch_size: Batch size for each dataset.
        num_workers: Worker-process count for each dataset.
        is_trains: Truthy entries mark training loaders, which drop the last
            incomplete batch and shuffle when no sampler is supplied.
            Evaluation loaders never shuffle and keep every batch.
        collate_fns: Collate function for each dataset, or ``None``.

    Returns:
        A list of ``DataLoader`` objects, in the order of ``datasets``.

    Raises:
        ValueError: If the arguments that expose a length do not all have the
            same length (plain ``zip`` would silently drop the extras).
    """
    per_loader_args = (
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    )
    # Only sized arguments are checked, so plain iterators keep working.
    sizes = [len(arg) for arg in per_loader_args if hasattr(arg, "__len__")]
    if len(set(sizes)) > 1:
        raise ValueError(
            "create_loader arguments must have matching lengths, got "
            f"{sizes}"
        )

    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(*per_loader_args):
        training = bool(is_train)
        loaders.append(
            DataLoader(
                dataset,
                batch_size=bs,
                num_workers=n_worker,
                pin_memory=True,
                sampler=sampler,
                # DataLoader rejects shuffle=True together with a sampler.
                shuffle=training and sampler is None,
                collate_fn=collate_fn,
                drop_last=training,
            )
        )
    return loaders