# vicca / CXRGen / tutorial_dataset.py
# sayehghp's picture
# Add application file
# 0f8411f
import json
import cv2
import numpy as np
from torch.utils.data import Dataset
import sys
# Adds a machine-specific absolute directory to the module search path at
# position 1. NOTE(review): nothing visible in this file imports from that
# directory - confirm it is still needed before relying on or removing it.
sys.path.insert(1, '/home/gholipos/physionet.org/files/mimic-cxr-jpg/2.0.0/')
class MyDataset(Dataset):
    """Dataset of (source image, target image, prompt) triples.

    The index is a JSON-lines file: every non-blank line is a JSON object
    with the keys ``source`` and ``target`` (image file paths) and
    ``prompt`` (text).

    Args:
        data_path: Path of the JSON-lines index file.
        image_size: Side length, in pixels, that both images are resized to.
    """

    def __init__(self, data_path='./dataset/dataset_train.json', image_size=512):
        super().__init__()
        self.image_size = image_size
        self.data = []
        # Explicit encoding so that decoding does not depend on the locale.
        with open(data_path, 'rt', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Tolerate blank lines (e.g. a stray empty line at the end).
                if line:
                    self.data.append(json.loads(line))

    def __len__(self):
        """Return the number of records read from the index file."""
        return len(self.data)

    def _load_rgb(self, filename):
        """Read ``filename`` and return it as a resized RGB uint8 array.

        Raises:
            OSError: If OpenCV cannot read the file. ``cv2.imread`` signals
                failure by returning ``None`` instead of raising, which
                would otherwise surface later as an opaque resize error.
        """
        image = cv2.imread(filename)
        if image is None:
            raise OSError(
                f'Could not read image (missing or undecodable): {filename!r}'
            )
        image = cv2.resize(image, (self.image_size, self.image_size))
        # Do not forget that OpenCV reads images in BGR order.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Load and normalize the record at position ``idx``.

        Returns:
            A dict with ``jpg`` (target image, float32 in [-1, 1]),
            ``txt`` (the prompt) and ``hint`` (source image, float32 in
            [0, 1]).

        Raises:
            OSError: If either image file cannot be read.
        """
        item = self.data[idx]
        prompt = item['prompt']

        source = self._load_rgb(item['source'])
        target = self._load_rgb(item['target'])

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)