| from torch.utils.data import Dataset | |
| import pandas as pd | |
| import cv2 as cv | |
| import os | |
| class myDataset(Dataset): | |
| def __init__(self, filelist, cxr_dir, bs_dir, | |
| transform=None): | |
| self.cxr_dir = cxr_dir | |
| self.bs_dir = bs_dir | |
| self.transform = transform | |
| self.filelist = pd.read_csv(filelist, sep="\t", header=None) | |
| def __len__(self): | |
| return len(self.filelist) | |
| def __getitem__(self, idx): | |
| file = self.filelist.iloc[idx, 0] | |
| cxr = cv.imread(os.path.join(self.cxr_dir, file)) | |
| bs = cv.imread(os.path.join(self.bs_dir, file)) | |
| if self.transform: | |
| cxr = self.transform(cxr) | |
| bs = self.transform(bs) | |
| return cxr, bs, file | |
| class myDiTDataset(Dataset): | |
| def __init__(self, filelist, cxr_dir, bs_dir, | |
| transform=None): | |
| self.cxr_dir = cxr_dir | |
| self.bs_dir = bs_dir | |
| self.transform = transform | |
| self.filelist = pd.read_csv(filelist, sep="\t", header=None) | |
| def __len__(self): | |
| return len(self.filelist) | |
| def __getitem__(self, idx): | |
| file = self.filelist.iloc[idx, 0] | |
| cxr = cv.imread(os.path.join(self.cxr_dir, file)) | |
| bs = cv.imread(os.path.join(self.bs_dir, file)) | |
| if self.transform: | |
| cxr, bs = self.transform(cxr, bs) | |
| return cxr, bs, file | |
| class mySingleDataset(Dataset): | |
| def __init__(self, filelist, cxr_dir, | |
| transform=None): | |
| self.cxr_dir = cxr_dir | |
| self.transform = transform | |
| self.filelist = pd.read_csv(filelist, sep="\t", header=None) | |
| def __len__(self): | |
| return len(self.filelist) | |
| def __getitem__(self, idx): | |
| file = self.filelist.iloc[idx, 0] | |
| cxr = cv.imread(os.path.join(self.cxr_dir, file)) | |
| if self.transform: | |
| cxr = self.transform(cxr) | |
| return cxr, file | |