# zy7_oldserver 1 fd601de  -- stray upload/VCS metadata; commented out so the module parses
import monai
import os
import numpy as np
from monai.transforms import (
Activations,
EnsureChannelFirst,
AsDiscrete,
Compose,
LoadImage,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
NormalizeIntensity,
ResizeWithPadOrCrop,
Rotate90,
DivisiblePad,
CenterSpatialCrop,
SqueezeDim,
DivisiblePadd,
)
from torch.utils.data import DataLoader
import torch
class condDataset(torch.utils.data.Dataset):
    """Pairs an (already transformed) image dataset with a parallel label list.

    Length is driven by the label list; items are (image, label) tuples.
    """

    def __init__(self, image_files, labels):
        self.image_files = image_files
        self.labels = labels

    def __len__(self):
        # The label list defines how many pairs exist.
        return len(self.labels)

    def __getitem__(self, index):
        image = self.image_files[index]
        label = self.labels[index]
        return image, label
def convert_csv_to_cond(train_cond_src, manual_crop=None):
    """Read per-patient slice-condition CSV files into one flat label list.

    Args:
        train_cond_src: list of CSV paths; each file must have a 'slice' column.
        manual_crop: optional [lo, hi] inclusive slice-index window. When given
            (length 2), only rows whose 'slice' value falls inside the window
            are kept and their labels are renumbered 0..N-1 per file.
            Otherwise the raw 'slice' value itself becomes the label.

    Returns:
        list[int]: one label per kept slice, concatenated over all files.
    """
    import csv  # hoisted: was re-imported inside the per-file loop

    if manual_crop is None:  # avoid the mutable-default-argument anti-pattern
        manual_crop = []
    train_crop = []
    for cond_file in train_cond_src:
        with open(cond_file, 'r', newline='') as csvfile:
            csv_reader = csv.DictReader(csvfile)
            label_index = 0  # per-file renumbering starts at 0
            for row in csv_reader:
                slice_no = int(row['slice'])
                if len(manual_crop) == 2:
                    if manual_crop[0] <= slice_no <= manual_crop[1]:
                        train_crop.append(label_index)
                        label_index += 1
                else:
                    train_crop.append(slice_no)
    return train_crop
def manual_arrange_data(train_data, manual_crop=[]):
    """Load every 3D volume in *train_data* and unstack it into 2D slices.

    Args:
        train_data: list of image file paths (e.g. ``.../patient/ct.nii.gz``);
            the parent folder name is recorded as the patient name.
        manual_crop: optional [lo, hi] inclusive slice window; when it has
            length 2 only slices lo..hi along the last axis are kept,
            otherwise every slice is kept.

    Returns:
        (slices_2d, shape_records, total_slices): flat list of 2D slices, a
        per-patient shape record list, and the total slice count.
    """
    shape_records = []
    slices_2d = []
    total_slices = 0
    for sample in train_data:
        volume = LoadImage(image_only=True, ensure_channel_first=False, simple_keys=True)(sample)
        patient = os.path.basename(os.path.dirname(sample))
        shape_records.append({'patient': patient, 'shape': volume.shape})
        depth = volume.shape[-1]
        if len(manual_crop) == 2:
            keep = range(manual_crop[0], manual_crop[1] + 1)
        else:
            keep = range(depth)
        for i in keep:
            slices_2d.append(volume[:, :, i])
        total_slices += len(keep)
    print('length of set:', len(slices_2d))
    print('a slice example:', slices_2d[0].shape)
    return slices_2d, shape_records, total_slices
def get_dataset(data_pelvis_path,
                train_number,
                val_number,
                normalize='zscore',
                pad='minimum',
                resized_size=(512,512),
                div_size=(16,16),
                center_crop=0,
                manual_crop=[],
                train_batch_size=8,
                val_batch_size=1):
    """Build conditioned 2D-slice train/val DataLoaders from a patient folder tree.

    Each patient folder under ``data_pelvis_path`` is expected to contain
    ``ct.nii.gz`` (a 3D CT volume) and ``ct_slice_cond.csv`` (per-slice
    condition values with a 'slice' column).  The first ``train_number``
    folders form the training split, the last ``val_number`` folders the
    validation split.

    Args:
        data_pelvis_path: root folder; entries whose name contains 'overview'
            are skipped.
        train_number: number of leading patient folders used for training.
        val_number: number of trailing patient folders used for validation.
        normalize: 'zscore' | 'minmax' | 'none' intensity normalization.
        pad: pad mode forwarded to get_transforms.
        resized_size: target (H, W) after pad-or-crop.
        div_size: divisibility constraint forwarded to DivisiblePad.
        center_crop: if > 0, adds a CenterSpatialCrop of this size.
        manual_crop: optional [lo, hi] inclusive slice window applied to both
            the image slices and the condition labels.
        train_batch_size: batch size for the training loader.
        val_batch_size: batch size for the validation loader.

    Returns:
        (train_loader, val_loader): DataLoaders over condDataset yielding
        (transformed image slice, condition label) pairs.
    """
    #list all files in the folder
    file_list=[i for i in os.listdir(data_pelvis_path) if 'overview' not in i]
    file_list_path=[os.path.join(data_pelvis_path,i) for i in file_list]
    #list all ct and mr files in folder
    ct_file_list=[os.path.join(j,'ct.nii.gz') for j in file_list_path]
    cond_file_list=[os.path.join(j,'ct_slice_cond.csv') for j in file_list_path]
    ############# condition data preparation
    # NOTE(review): when train_number + val_number exceeds the number of
    # folders, the first-N train split and last-M val split overlap — confirm
    # this is intended.
    train_cond_src = cond_file_list[0:train_number] #{'label': j, 'cond_paths': j}
    val_cond_src = cond_file_list[-val_number:]
    # convert csv files to condition list
    train_cond = convert_csv_to_cond(train_cond_src, manual_crop)
    val_cond = convert_csv_to_cond(val_cond_src, manual_crop)
    print('total length: ', len(train_cond))
    #print(len(val_cond))
    # Number of condition classes: the crop-window size when cropping,
    # otherwise the largest condition value seen in either split, plus one.
    if len(manual_crop)==2:
        num_classes = manual_crop[1]-manual_crop[0]+1
    else:
        num_classes = max(max(train_cond), max(val_cond))+1
    print('number of classes:', num_classes)
    class_names = [i+1 for i in range(num_classes)]
    ############# image data preparation
    train_data = ct_file_list[0:train_number]
    val_data = ct_file_list[-val_number:]
    print('all files in dataset:',len(file_list))
    train_transforms = get_transforms(normalize,pad,resized_size,div_size,center_crop=center_crop)
    train_ds_2d, shape_list_train, all_slices_train = manual_arrange_data(train_data, manual_crop)
    val_ds_2d, shape_list_val, all_slices_val = manual_arrange_data(val_data, manual_crop)
    # check length
    # NOTE(review): only the training split is length-checked; consider also
    # asserting len(val_ds_2d) == len(val_cond).
    assert len(train_ds_2d)==len(train_cond), 'length of training set and condition set are not equal!'
    print('train data length check pass!')
    print('length of training set:', len(train_ds_2d))
    print('length of training cond:', len(train_cond))
    print('length of validation set:', len(val_ds_2d))
    print('length of validation cond:', len(val_cond))
    # Validation reuses the training transforms (the pipeline contains no
    # random augmentation, so this is deterministic).
    train_dataset = monai.data.Dataset(data=train_ds_2d, transform=train_transforms)
    val_dataset = monai.data.Dataset(data=val_ds_2d, transform=train_transforms)
    ############# combine image and condition data
    train_set=condDataset(train_dataset, train_cond)
    val_set=condDataset(val_dataset, val_cond)
    train_loader = DataLoader(train_set, batch_size=train_batch_size, num_workers=4, pin_memory=torch.cuda.is_available()) #
    val_loader = DataLoader(val_set,num_workers=4, batch_size=val_batch_size, pin_memory=torch.cuda.is_available()) #
    #val_volume_ds,
    return train_loader,val_loader
def get_transforms(normalize, pad, resized_size, div_size, center_crop=0):
    """Assemble the 2D-slice preprocessing pipeline.

    Args:
        normalize: 'zscore' | 'minmax' | 'none' intensity normalization.
        pad: pad mode for ResizeWithPadOrCrop / DivisiblePad — one of
            "constant", "edge", "linear_ramp", "maximum", "mean", "median",
            "minimum", "reflect", "symmetric", "wrap", "empty".
        resized_size: target spatial size after pad-or-crop.
        div_size: divisibility constraint for DivisiblePad.
        center_crop: if > 0, appends a CenterSpatialCrop of that size.

    Returns:
        A Compose of the selected transforms.
    """
    steps = [
        EnsureChannelFirst(),
        ResizeWithPadOrCrop(spatial_size=resized_size, mode=pad),
        Rotate90(k=3),
        DivisiblePad(k=div_size, mode=pad),
    ]
    if normalize == 'zscore':
        steps.append(NormalizeIntensity(nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize == 'minmax':
        steps.append(ScaleIntensity(minv=0, maxv=1.0))
        print('minmax normalization')
    elif normalize == 'none':
        print('no normalization')
    if center_crop > 0:
        steps.append(CenterSpatialCrop(roi_size=(-1, -1, center_crop)))
    # volume-level transforms for both image and label
    return Compose(steps)
def get_cond_transforms(num_class=150):
    """Return post-transforms for (prediction, target) pairs.

    Predictions get a softmax activation; targets are one-hot encoded with
    *num_class* classes.
    """
    pred_transform = Compose([Activations(softmax=True)])
    target_transform = Compose([AsDiscrete(to_onehot=num_class)])
    return pred_transform, target_transform
def get_length(dataset, patch_batch_size):
    """Return the number of patch batches needed to cover *dataset*.

    Iterates the dataset one volume at a time (batch_size=1, since volumes
    may have different shapes), sums the last-axis size of each item's
    'image', and returns ceil(total_slices / patch_batch_size).

    Args:
        dataset: dataset whose items are dicts with an 'image' tensor.
        patch_batch_size: number of slices per patch batch.
    """
    loader = DataLoader(dataset, batch_size=1)
    # Iterate the loader directly instead of range(len(...)) + next().
    total_slices = sum(batch['image'].shape[-1] for batch in loader)
    # Ceiling division without floating point.
    return -(-total_slices // patch_batch_size)
def check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds):
    """Load every volume once through a DataLoader to verify it is readable.

    Uses batch_size=1 because the input volumes have different shapes.  For
    each volume the patient name and the image/label batch shapes are
    printed; a volume that fails to load is reported by patient name instead
    of aborting the whole check.

    Args:
        train_ds / val_ds: index-able datasets whose items carry the source
            path under 'image' (used only to recover the patient name).
        train_volume_ds / val_volume_ds: the transformed datasets actually
            loaded and checked.
    """
    def _check_split(split_name, volume_ds, name_ds):
        # One pass over volume_ds; name_ds supplies the patient folder name.
        loader = DataLoader(volume_ds, batch_size=1)
        iterator = iter(loader)
        print('check %s data:' % split_name)
        for idx in range(len(loader)):
            # batch_size is 1, so the loader index equals the dataset index.
            patient = os.path.basename(os.path.dirname(name_ds[idx]['image']))
            try:
                check_data = next(iterator)
                print(idx, patient, 'image:', check_data['image'].shape,
                      'label:', check_data['label'].shape)
            # Was a bare `except:`; Exception keeps the check going without
            # also swallowing SystemExit/KeyboardInterrupt.
            except Exception:
                print('check data error! Check the input data:', patient)
        print('checked all %s data.' % split_name)

    _check_split('training', train_volume_ds, train_ds)
    _check_split('validation', val_volume_ds, val_ds)
def save_volumes(train_ds, val_ds, saved_name_train, saved_name_val):
    """Write one {'patient': name} record per sample to two text files.

    The patient name is the parent-folder name of each item's 'image' path;
    records are written one per line via numpy.savetxt.
    """
    train_records = [
        {'patient': os.path.basename(os.path.dirname(item['image']))}
        for item in train_ds
    ]
    val_records = [
        {'patient': os.path.basename(os.path.dirname(item['image']))}
        for item in val_ds
    ]
    # fmt='%s' stringifies each record dict on its own line.
    np.savetxt(saved_name_train, train_records, delimiter=',', fmt='%s', newline='\n')
    np.savetxt(saved_name_val, val_records, delimiter=',', fmt='%s', newline='\n')
def check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds,
                     train_batch_size, val_batch_size):
    """Print the first dataset item and the batch shapes for every batch.

    For each batch, the underlying dataset item at index
    batch_index * batch_size is printed alongside the batch's image/label
    tensor shapes, for both the train and validation loaders.
    """
    for batch_idx, batch in enumerate(train_loader):
        first_item = train_patch_ds[batch_idx * train_batch_size]
        print('check train data:')
        print(first_item, 'image:', batch['image'].shape, 'label:', batch['label'].shape)
    for batch_idx, batch in enumerate(val_loader):
        first_item = val_volume_ds[batch_idx * val_batch_size]
        print('check val data:')
        print(first_item, 'image:', batch['image'].shape, 'label:', batch['label'].shape)
def main():
    # Smoke test: build train/val loaders from a local pelvis CT dataset and
    # append every training batch's image shape and label tensor to a log file.
    dataset_path_razer=r'C:\Users\56991\Projects\Datasets\Task1\pelvis'  # laptop path (defined but unused here)
    dataset_path_server = r"F:\yang_Projects\Datasets\Task1\pelvis"  # server path actually used below
    train_loader,val_loader = get_dataset(dataset_path_server, train_number=5, val_number=1, manual_crop=[41,50])
    from tqdm import tqdm
    parameter_file=r'.\test.txt'  # log file; opened in append mode, so it grows across runs
    for image, label in tqdm(train_loader):
        with open(parameter_file, 'a') as f:
            f.write('image batch:' + str(image.shape)+'\n')
            f.write('label batch:' + str(label)+'\n')
            f.write('\n')
# Run the smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    main()