zy7_oldserver
1
fd601de
import SimpleITK as sitk
import matplotlib.pyplot as plt
import os
## get images using SimpleITK and plot them
def _read_modality(folder, file_names, key):
    """Read, with SimpleITK, the first entry of *file_names* whose name contains *key*.

    Raises IndexError when no entry matches (same as indexing an empty list).
    """
    matches = [name for name in file_names if key in name]
    return sitk.ReadImage(os.path.join(folder, matches[0]))


def get_image(data_path, image_idx=0, slice_idx=10, ifprint=True, ifvis=True):
    """Load the CT, mask and MR volumes of one patient folder and optionally report/plot them.

    ``data_path`` is listed and every entry whose name contains ``'overview'``
    is skipped; the remaining entries are sorted so that ``image_idx`` always
    selects the same folder regardless of the filesystem's listing order.
    Inside that folder (also sorted), the first file whose name contains
    ``'ct'``, ``'mask'`` and ``'mr'`` respectively is read with SimpleITK.

    Args:
        data_path: Folder holding one sub-folder per patient.
        image_idx: Index of the patient folder in the sorted listing.
        slice_idx: Index along the first array axis of the slice to plot.
        ifprint: If true, print the folder name, spacings, array shapes and
            min/max intensities of CT and MR.
        ifvis: If true, draw the selected slice of CT, mask and MR in a new
            figure (``plt.show()`` is not called here).

    Returns:
        ``(ct_array, mask_array, mr_array, space_ct, space_mr, patient_names)``:
        arrays from ``sitk.GetArrayFromImage``, spacings from
        ``Image.GetSpacing()``, and the selected folder name.

    Raises:
        IndexError: if ``image_idx`` is out of range or a modality file is missing.
    """
    # os.listdir order is arbitrary; sort so indices are reproducible.
    file_list = sorted(i for i in os.listdir(data_path) if 'overview' not in i)
    patient_names = file_list[image_idx]
    target_path = os.path.join(data_path, patient_names)
    target_file = sorted(os.listdir(target_path))

    ct_image = _read_modality(target_path, target_file, 'ct')
    space_ct = ct_image.GetSpacing()
    ct_array = sitk.GetArrayFromImage(ct_image)

    mask_image = _read_modality(target_path, target_file, 'mask')
    mask_array = sitk.GetArrayFromImage(mask_image)

    mr_image = _read_modality(target_path, target_file, 'mr')
    mr_array = sitk.GetArrayFromImage(mr_image)
    space_mr = mr_image.GetSpacing()

    if ifprint:
        print(patient_names)  # the selected patient folder, e.g. 1PA001
        print('spacing of ct image:', space_ct)
        print('spacing of mr image:', space_mr)
        print('shape of ct image:', ct_array.shape)
        print('shape of mask image:', mask_array.shape)
        print('shape of mr image:', mr_array.shape)
        # intensity ranges of the ct and mr volumes
        print('min of ct image:', ct_array.min())
        print('max of ct image:', ct_array.max())
        print('min of mr image:', mr_array.min())
        print('max of mr image:', mr_array.max())

    if ifvis:
        # visualize the selected slice of each volume, stacked vertically
        plt.figure(figsize=(5, 5))
        plt.subplot(3, 1, 1)
        plt.imshow(ct_array[slice_idx, :, :], cmap='gray')
        plt.subplot(3, 1, 2)
        plt.imshow(mask_array[slice_idx, :, :], cmap='gray')
        plt.subplot(3, 1, 3)
        plt.imshow(mr_array[slice_idx, :, :], cmap='gray')

    return ct_array, mask_array, mr_array, space_ct, space_mr, patient_names
# Quick sanity checks for MONAI data loaders.
# `first` returns the first batch yielded by the loader.
from monai.utils import first
def test_dataloader_first(train_loader):
    """Grab the first batch of *train_loader* and print the shape of its ``"image"`` entry.

    The ``"label"`` entry is looked up as well, so a batch lacking it raises KeyError.
    """
    sample = first(train_loader)
    image_batch = sample["image"]
    _label_batch = sample["label"]  # fetched but not used here
    print('ct_test shape:', image_batch.shape)
def test_dataloader_2d_enumerate(train_loader, num_batches=3, batch_size=10, show_every=10):
    """Plot image/label pairs from the first few batches of a 2D *train_loader*.

    For each of the first ``num_batches`` batches (at least one is always
    processed), every ``show_every``-th sample (1-based: the 10th, 20th, ...
    with the defaults) among the first ``batch_size`` samples is drawn in a
    new figure: ``batch["image"][j, 0]`` on the left and
    ``batch["label"][j, 0]`` on the right. Batches shorter than
    ``batch_size`` (e.g. a final partial batch) no longer index past their end.

    Args:
        train_loader: Iterable of dicts with ``"image"`` and ``"label"``
            entries indexable as ``[sample, channel, H, W]`` (assumed
            channel-first 2D batches — TODO confirm against the transforms).
        num_batches: Number of batches to inspect.
        batch_size: Maximum number of samples per batch to consider.
        show_every: Plot one sample out of every ``show_every``; must be >= 1.
    """
    for i, batch in enumerate(train_loader):
        ct, mr = batch["image"], batch["label"]
        # Clamp to the real batch length so a short last batch doesn't crash.
        n_samples = min(batch_size, ct.shape[0])
        for j in range(show_every - 1, n_samples, show_every):
            plt.figure(figsize=(5, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(ct[j, 0, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(mr[j, 0, :, :], cmap='gray')
        # Stop after num_batches batches without fetching another one.
        if i + 1 >= num_batches:
            break
# test dataloader by enumerate()
def test_dataloader_3d_enumerate(train_loader, num_batches=3, batch_size=2, slice_idx=50):
    """Plot one slice of each image/label volume from the first few batches of a 3D loader.

    Prints ``len(train_loader)`` first. Then, for each of the first
    ``num_batches`` batches (at least one is always processed), each of the
    first ``batch_size`` samples gets a new figure showing
    ``batch["image"][j, 0, slice_idx]`` next to ``batch["label"][j, 0, slice_idx]``.
    A final partial batch no longer indexes past its end.

    Args:
        train_loader: Sized iterable of dicts with ``"image"`` and ``"label"``
            entries indexable as ``[sample, channel, D, H, W]`` (assumed
            channel-first 3D batches — TODO confirm).
        num_batches: Number of batches to inspect.
        batch_size: Maximum number of samples per batch to plot.
        slice_idx: Index along the first spatial axis of the slice to show;
            must be smaller than that axis' length.
    """
    print(len(train_loader))
    for i, batch in enumerate(train_loader):
        ct, mr = batch["image"], batch["label"]
        # Clamp to the real batch length so a short last batch doesn't crash.
        for j in range(min(batch_size, ct.shape[0])):
            plt.figure(figsize=(10, 10))
            plt.subplot(1, 2, 1)
            plt.imshow(ct[j, 0, slice_idx, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(mr[j, 0, slice_idx, :, :], cmap='gray')
        # Stop after num_batches batches without fetching another one.
        if i + 1 >= num_batches:
            break
def test_dataloader_2d_iter(train_loader, iter_num=2, batch_size=10, show_every=10):
    """Pull up to ``iter_num`` batches with ``iter``/``next``, print their shapes and plot samples.

    For every fetched batch the shapes of ``batch["image"]`` and
    ``batch["label"]`` are printed, then every ``show_every``-th sample
    (1-based) among the first ``batch_size`` is drawn in a new figure, image
    left and label right. If the loader runs out early the loop simply
    stops instead of raising StopIteration, and short batches no longer
    index past their end.

    Args:
        train_loader: Iterable of dicts with ``"image"`` and ``"label"``
            entries indexable as ``[sample, channel, H, W]`` (assumed
            channel-first 2D batches — TODO confirm).
        iter_num: Maximum number of batches to fetch.
        batch_size: Maximum number of samples per batch to consider.
        show_every: Plot one sample out of every ``show_every``; must be >= 1.
    """
    exhausted = object()  # sentinel: distinguishes "no more batches" from any real batch
    iterator = iter(train_loader)
    for _ in range(iter_num):
        batch = next(iterator, exhausted)
        if batch is exhausted:
            break
        # get image and label from batch
        image, label = batch["image"], batch["label"]
        print(image.shape)
        print(label.shape)
        for j in range(show_every - 1, min(batch_size, image.shape[0]), show_every):
            plt.figure(figsize=(10, 10))
            plt.subplot(1, 2, 1)
            plt.imshow(image[j, 0, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(label[j, 0, :, :], cmap='gray')
# test dataloader by iter and next
def test_dataloader_3d_iter(train_loader, iter_num=2, batch_size=2, slice_idx=50):
    """Pull up to ``iter_num`` 3D batches with ``iter``/``next``, print shapes and plot one slice per sample.

    For every fetched batch the shapes of ``batch["image"]`` and
    ``batch["label"]`` are printed, then each of the first ``batch_size``
    samples gets a new figure showing ``image[j, 0, slice_idx]`` next to
    ``label[j, 0, slice_idx]``. If the loader runs out early the loop stops
    instead of raising StopIteration, and short batches are clamped.

    Args:
        train_loader: Iterable of dicts with ``"image"`` and ``"label"``
            entries indexable as ``[sample, channel, D, H, W]`` (assumed
            channel-first 3D batches — TODO confirm).
        iter_num: Maximum number of batches to fetch.
        batch_size: Maximum number of samples per batch to plot.
        slice_idx: Index along the first spatial axis of the slice to show;
            must be smaller than that axis' length.
    """
    exhausted = object()  # sentinel: distinguishes "no more batches" from any real batch
    iterator = iter(train_loader)
    for _ in range(iter_num):
        batch = next(iterator, exhausted)
        if batch is exhausted:
            break
        # get image and label from batch
        image, label = batch["image"], batch["label"]
        print(image.shape)
        print(label.shape)
        for j in range(min(batch_size, image.shape[0])):
            plt.figure(figsize=(10, 10))
            plt.subplot(1, 2, 1)
            plt.imshow(image[j, 0, slice_idx, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(label[j, 0, slice_idx, :, :], cmap='gray')
# Script entry point: load, print and display one pelvis case.
if __name__ == "__main__":
    import sys

    # Default dataset location; pass another folder as the first CLI argument.
    data_pelvis_path = sys.argv[1] if len(sys.argv) > 1 else r'D:\Projects\Task1\pelvis'
    get_image(data_pelvis_path, image_idx=0, slice_idx=50, ifprint=True, ifvis=True)
    # When run as a plain script, figures are only displayed once show() is called.
    plt.show()