| | import json |
| | import cv2 |
| | import numpy as np |
| |
|
| | from torch.utils.data import Dataset |
| | import sys |
# Insert the local MIMIC-CXR-JPG 2.0.0 directory into the module search path,
# directly after sys.path[0].
# NOTE(review): this absolute, user-specific path only exists on the original
# author's machine; consider making it configurable.
sys.path[1:1] = ['/home/gholipos/physionet.org/files/mimic-cxr-jpg/2.0.0/']
| |
|
class MyDataset(Dataset):
    """Paired-image dataset described by a JSON Lines index file.

    Each non-blank line of the index is a JSON object with the keys
    ``'source'`` (path of the image returned as ``hint``), ``'target'``
    (path of the image returned as ``jpg``) and ``'prompt'`` (text returned
    unchanged as ``txt``).
    """

    def __init__(self, json_path='./dataset/dataset_train.json', size=512):
        """Read the whole index file into memory.

        Args:
            json_path: Path of the JSON Lines index. The default is resolved
                against the current working directory (the previous
                hard-coded location).
            size: Side length in pixels that both images are resized to.

        Raises:
            OSError: If the index file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.size = size
        self.data = []
        # The index is JSON Lines (one object per line) despite its .json
        # extension. Decode it as UTF-8 explicitly so non-ASCII prompts are
        # read identically regardless of the platform's locale.
        with open(json_path, 'rt', encoding='utf-8') as f:
            for line in f:
                # Tolerate blank lines instead of failing in json.loads.
                if line.strip():
                    self.data.append(json.loads(line))

    def __len__(self):
        """Return the number of entries in the index."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, resize and normalize the image pair at position ``idx``.

        Returns:
            dict with keys:
              ``'jpg'``: target image as a float32 RGB array of shape
                ``(size, size, 3)``, scaled from [0, 255] to [-1, 1];
              ``'txt'``: the entry's prompt, unchanged;
              ``'hint'``: source image as a float32 RGB array of shape
                ``(size, size, 3)``, scaled from [0, 255] to [0, 1].

        Raises:
            KeyError: If the entry lacks 'source', 'target' or 'prompt'.
            OSError: If either image file cannot be read by OpenCV.
        """
        item = self.data[idx]

        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        source = self._load_rgb(source_filename, self.size)
        target = self._load_rgb(target_filename, self.size)

        # The hint image is scaled to [0, 1]; the target image to [-1, 1].
        source = source.astype(np.float32) / 255.0
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)

    @staticmethod
    def _load_rgb(path, size):
        """Read ``path`` with OpenCV, resize it to ``size`` x ``size`` and convert it to RGB."""
        image = cv2.imread(path)
        # cv2.imread does not raise for a missing, unreadable or undecodable
        # file; it returns None, which would otherwise surface later as an
        # opaque cv2.error from cv2.resize without naming the file.
        if image is None:
            raise OSError(f'cv2.imread could not read image file: {path!r}')
        image = cv2.resize(image, (size, size))
        # OpenCV decodes images in BGR channel order.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
| |
|
| |
|