# Spaces: Runtime error
import os

import matplotlib.pyplot as plt
import SimpleITK as sitk
from monai.utils import first
# Load the CT, mask and MR images of one patient with SimpleITK and optionally plot them.
def _find_modality_file(target_path, file_names, keyword):
    """Return the path of the first name in *file_names* that contains *keyword*.

    Raises:
        IndexError: If no name contains *keyword*. IndexError is kept (rather
            than FileNotFoundError) because that is what the previous
            ``matches[0]`` lookup raised.
    """
    matches = [name for name in file_names if keyword in name]
    if not matches:
        raise IndexError(
            f"no file containing {keyword!r} found in {target_path}"
        )
    return os.path.join(target_path, matches[0])


def get_image(data_path, image_idx=0, slice_idx=10, ifprint=True, ifvis=True):
    """Load the CT, mask and MR images of one patient folder.

    Every entry of ``data_path`` whose name does not contain ``'overview'`` is
    treated as a patient folder. Inside the selected folder, the first file
    whose name contains ``'ct'``, ``'mask'`` and ``'mr'`` respectively is read
    with SimpleITK.

    Args:
        data_path: Directory holding the patient folders.
        image_idx: Index of the patient folder in the sorted folder listing.
        slice_idx: Index along the first array axis of the slice that is
            plotted when ``ifvis`` is true.
        ifprint: If true, print the patient name, spacings, array shapes and
            the min/max intensities of the CT and MR arrays.
        ifvis: If true, create a matplotlib figure showing slice ``slice_idx``
            of the CT, mask and MR arrays. The figure is not shown here; the
            caller decides when to call ``plt.show()``.

    Returns:
        Tuple ``(ct_array, mask_array, mr_array, space_ct, space_mr,
        patient_names)``: the three arrays from ``sitk.GetArrayFromImage``,
        the ``GetSpacing()`` results of the CT and MR images, and the name of
        the selected patient folder.

    Raises:
        IndexError: If ``image_idx`` is out of range, if a ct/mask/mr file is
            missing from the patient folder, or if ``slice_idx`` is out of
            range while plotting.
    """
    # os.listdir() returns entries in arbitrary order; sort so that a given
    # image_idx selects the same patient on every platform and run.
    file_list = sorted(i for i in os.listdir(data_path) if 'overview' not in i)
    patient_names = file_list[image_idx]
    target_path = os.path.join(data_path, patient_names)
    target_file = sorted(os.listdir(target_path))

    # Resolve all three paths before reading anything so that a missing file
    # is reported before any image is loaded.
    ct_file_path = _find_modality_file(target_path, target_file, 'ct')
    mask_file_path = _find_modality_file(target_path, target_file, 'mask')
    mr_file_path = _find_modality_file(target_path, target_file, 'mr')

    ct_image = sitk.ReadImage(ct_file_path)
    space_ct = ct_image.GetSpacing()
    ct_array = sitk.GetArrayFromImage(ct_image)

    mask_image = sitk.ReadImage(mask_file_path)
    mask_array = sitk.GetArrayFromImage(mask_image)

    mr_image = sitk.ReadImage(mr_file_path)
    mr_array = sitk.GetArrayFromImage(mr_image)
    space_mr = mr_image.GetSpacing()

    if ifprint:
        print(patient_names)
        print('spacing of ct image:', space_ct)
        print('spacing of mr image:', space_mr)
        print('shape of ct image:', ct_array.shape)
        print('shape of mask image:', mask_array.shape)
        print('shape of mr image:', mr_array.shape)
        # Intensity ranges of the CT and MR arrays.
        print('min of ct image:', ct_array.min())
        print('max of ct image:', ct_array.max())
        print('min of mr image:', mr_array.min())
        print('max of mr image:', mr_array.max())

    if ifvis:
        # One figure with CT, mask and MR stacked vertically.
        plt.figure(figsize=(5, 5))
        for row, volume in enumerate((ct_array, mask_array, mr_array), start=1):
            plt.subplot(3, 1, row)
            plt.imshow(volume[slice_idx, :, :], cmap='gray')

    return ct_array, mask_array, mr_array, space_ct, space_mr, patient_names
# Test the dataloader by inspecting its first batch.
| from monai.utils import first | |
def test_dataloader_first(train_loader):
    """Print the image and label shapes of the first batch of *train_loader*.

    Args:
        train_loader: Iterable of batches; each batch is indexed with the
            keys ``"image"`` and ``"label"``.
    """
    batch_test = first(train_loader)
    if batch_test is None:
        # first() falls back to None when the loader yields nothing; report
        # that instead of failing on None["image"].
        print('dataloader is empty')
        return
    ct_test, mr_test = batch_test["image"], batch_test["label"]
    print('ct_test shape:', ct_test.shape)
    print('mr_test shape:', mr_test.shape)


# Manual inspection helper, not a pytest test: opt out of collection so that
# pytest does not try to resolve `train_loader` as a fixture.
test_dataloader_first.__test__ = False
def test_dataloader_2d_enumerate(train_loader, max_batches=3, batch_size=10,
                                 show_every=10):
    """Plot image/label pairs from the first batches of a 2D dataloader.

    Args:
        train_loader: Iterable of batches; each batch is indexed with the
            keys ``"image"`` and ``"label"`` and the values are indexed as
            ``[sample, 0, :, :]``.
        max_batches: Number of batches to inspect before stopping.
        batch_size: Upper bound on the samples inspected per batch. Batches
            holding fewer samples are handled without error.
        show_every: Plot every ``show_every``-th sample of a batch
            (must be >= 1).
    """
    # zip() with range() stops after max_batches without fetching an extra
    # batch from the loader, and works for loaders without __len__.
    for _, batch in zip(range(max_batches), train_loader):
        ct, mr = batch["image"], batch["label"]
        # Never index past the samples actually present in this batch.
        n_samples = min(batch_size, len(ct), len(mr))
        for j in range(show_every - 1, n_samples, show_every):
            plt.figure(figsize=(5, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(ct[j, 0, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(mr[j, 0, :, :], cmap='gray')


# Manual inspection helper, not a pytest test: opt out of collection so that
# pytest does not try to resolve `train_loader` as a fixture.
test_dataloader_2d_enumerate.__test__ = False
# Test the 3D dataloader by iterating over it with enumerate().
def test_dataloader_3d_enumerate(train_loader, max_batches=3, batch_size=2,
                                 slice_idx=50):
    """Plot one slice of image/label volumes from the first batches.

    Prints ``len(train_loader)`` first, then creates one figure per inspected
    sample with the image slice on the left and the label slice on the right.

    Args:
        train_loader: Sized iterable of batches; each batch is indexed with
            the keys ``"image"`` and ``"label"`` and the values are indexed
            as ``[sample, 0, slice_idx, :, :]``.
        max_batches: Number of batches to inspect before stopping.
        batch_size: Upper bound on the samples plotted per batch. Batches
            holding fewer samples are handled without error.
        slice_idx: Index along the third axis of the slice to plot.
    """
    print(len(train_loader))
    # zip() with range() stops after max_batches without fetching an extra
    # batch from the loader.
    for _, batch in zip(range(max_batches), train_loader):
        ct, mr = batch["image"], batch["label"]
        # Never index past the samples actually present in this batch.
        n_samples = min(batch_size, len(ct), len(mr))
        for j in range(n_samples):
            plt.figure(figsize=(10, 10))
            plt.subplot(1, 2, 1)
            plt.imshow(ct[j, 0, slice_idx, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(mr[j, 0, slice_idx, :, :], cmap='gray')


# Manual inspection helper, not a pytest test: opt out of collection so that
# pytest does not try to resolve `train_loader` as a fixture.
test_dataloader_3d_enumerate.__test__ = False
def test_dataloader_2d_iter(train_loader, iter_num=2, batch_size=10,
                            show_every=10):
    """Step through a 2D dataloader with iter()/next() and plot samples.

    For each fetched batch the image and label shapes are printed, then every
    ``show_every``-th sample is plotted as an image/label pair.

    Args:
        train_loader: Iterable of batches; each batch is indexed with the
            keys ``"image"`` and ``"label"`` and the values are indexed as
            ``[sample, 0, :, :]``.
        iter_num: Number of batches to fetch. Fetching stops early when the
            loader runs out of batches.
        batch_size: Upper bound on the samples inspected per batch. Batches
            holding fewer samples are handled without error.
        show_every: Plot every ``show_every``-th sample of a batch
            (must be >= 1).
    """
    iterator = iter(train_loader)
    for _ in range(iter_num):
        try:
            batch = next(iterator)
        except StopIteration:
            # Loader holds fewer than iter_num batches; nothing left to show.
            break
        image, label = batch["image"], batch["label"]
        print(image.shape)
        print(label.shape)
        # Never index past the samples actually present in this batch.
        n_samples = min(batch_size, len(image), len(label))
        for j in range(show_every - 1, n_samples, show_every):
            plt.figure(figsize=(10, 10))
            plt.subplot(1, 2, 1)
            plt.imshow(image[j, 0, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(label[j, 0, :, :], cmap='gray')


# Manual inspection helper, not a pytest test: opt out of collection so that
# pytest does not try to resolve `train_loader` as a fixture.
test_dataloader_2d_iter.__test__ = False
# Test the 3D dataloader by stepping through it with iter() and next().
def test_dataloader_3d_iter(train_loader, iter_num=2, batch_size=2,
                            slice_idx=50):
    """Step through a 3D dataloader with iter()/next() and plot one slice.

    For each fetched batch the image and label shapes are printed, then one
    figure per inspected sample is created with the image slice on the left
    and the label slice on the right.

    Args:
        train_loader: Iterable of batches; each batch is indexed with the
            keys ``"image"`` and ``"label"`` and the values are indexed as
            ``[sample, 0, slice_idx, :, :]``.
        iter_num: Number of batches to fetch. Fetching stops early when the
            loader runs out of batches.
        batch_size: Upper bound on the samples plotted per batch. Batches
            holding fewer samples are handled without error.
        slice_idx: Index along the third axis of the slice to plot.
    """
    iterator = iter(train_loader)
    for _ in range(iter_num):
        try:
            batch = next(iterator)
        except StopIteration:
            # Loader holds fewer than iter_num batches; nothing left to show.
            break
        image, label = batch["image"], batch["label"]
        print(image.shape)
        print(label.shape)
        # Never index past the samples actually present in this batch.
        n_samples = min(batch_size, len(image), len(label))
        for j in range(n_samples):
            plt.figure(figsize=(10, 10))
            plt.subplot(1, 2, 1)
            plt.imshow(image[j, 0, slice_idx, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(label[j, 0, slice_idx, :, :], cmap='gray')


# Manual inspection helper, not a pytest test: opt out of collection so that
# pytest does not try to resolve `train_loader` as a fixture.
test_dataloader_3d_iter.__test__ = False
# Script entry point: load one patient of a local pelvis dataset copy, print
# its statistics and display slice 50 of the CT, mask and MR arrays.
if __name__ == "__main__":
    data_pelvis_path = r'D:\Projects\Task1\pelvis'
    get_image(data_pelvis_path, image_idx=0, slice_idx=50, ifprint=True, ifvis=True)
    # get_image() only creates the figure; when run as a script nothing is
    # displayed unless show() is called.
    plt.show()