File size: 5,389 Bytes
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import SimpleITK as sitk
import matplotlib.pyplot as plt
import os
## Load one patient's CT, mask and MR images with SimpleITK, then optionally print statistics and plot them
def _find_modality_file(folder, file_names, key):
    """Return the path of the first entry in ``file_names`` whose name contains ``key``.

    Raises:
        FileNotFoundError: If no entry of ``file_names`` contains ``key``.
    """
    for name in file_names:
        if key in name:
            return os.path.join(folder, name)
    raise FileNotFoundError(f"No file containing {key!r} in {folder}")


def get_image(data_path, image_idx=0, slice_idx=10, ifprint=True, ifvis=True):
    """Load the CT, mask and MR volumes of one patient; optionally print and plot them.

    Each patient is a sub-directory of ``data_path``. Inside it, the CT, mask and MR
    files are the first (name-sorted) entries whose names contain ``'ct'``,
    ``'mask'`` and ``'mr'`` respectively. Entries of ``data_path`` whose names
    contain ``'overview'`` are ignored, as are non-directory entries.

    Args:
        data_path: Folder holding one sub-directory per patient.
        image_idx: Position of the patient folder in name-sorted order.
        slice_idx: Index along the first array axis of the slice to plot.
        ifprint: If True, print the patient name, the CT/MR spacings, the array
            shapes and the CT/MR min/max values.
        ifvis: If True, plot slice ``slice_idx`` of the CT, mask and MR arrays.

    Returns:
        Tuple ``(ct_array, mask_array, mr_array, space_ct, space_mr, patient_names)``.
        ``patient_names`` is the selected patient folder name.

    Raises:
        IndexError: If ``image_idx`` is out of range.
        FileNotFoundError: If the patient folder has no ct, mask or mr file.
    """
    # os.listdir returns entries in arbitrary order. Sort so that image_idx always
    # selects the same patient. Keep directories only, so stray files such as
    # spreadsheets or .DS_Store are not mistaken for patients.
    patient_dirs = sorted(
        name for name in os.listdir(data_path)
        if 'overview' not in name and os.path.isdir(os.path.join(data_path, name))
    )
    patient_names = patient_dirs[image_idx]
    target_path = os.path.join(data_path, patient_names)
    target_files = sorted(os.listdir(target_path))

    # Locate every modality before reading anything, so a missing file fails fast.
    ct_file_path = _find_modality_file(target_path, target_files, 'ct')
    mask_file_path = _find_modality_file(target_path, target_files, 'mask')
    mr_file_path = _find_modality_file(target_path, target_files, 'mr')

    ct_image = sitk.ReadImage(ct_file_path)
    space_ct = ct_image.GetSpacing()
    ct_array = sitk.GetArrayFromImage(ct_image)

    mask_image = sitk.ReadImage(mask_file_path)
    mask_array = sitk.GetArrayFromImage(mask_image)

    mr_image = sitk.ReadImage(mr_file_path)
    space_mr = mr_image.GetSpacing()
    mr_array = sitk.GetArrayFromImage(mr_image)

    if ifprint:
        print(patient_names)
        print('spacing of ct image:', space_ct)
        print('spacing of mr image:', space_mr)
        print('shape of ct image:', ct_array.shape)
        print('shape of mask image:', mask_array.shape)
        print('shape of mr image:', mr_array.shape)
        # intensity ranges of the ct and mr volumes
        print('min of ct image:', ct_array.min())
        print('max of ct image:', ct_array.max())
        print('min of mr image:', mr_array.min())
        print('max of mr image:', mr_array.max())
    if ifvis:
        # one row per volume: CT, mask, MR
        plt.figure(figsize=(5, 5))
        plt.subplot(3, 1, 1)
        plt.imshow(ct_array[slice_idx, :, :], cmap='gray')
        plt.subplot(3, 1, 2)
        plt.imshow(mask_array[slice_idx, :, :], cmap='gray')
        plt.subplot(3, 1, 3)
        plt.imshow(mr_array[slice_idx, :, :], cmap='gray')
    return ct_array, mask_array, mr_array, space_ct, space_mr, patient_names

# Test the dataloader by inspecting only its first batch (monai.utils.first)
from monai.utils import first
def test_dataloader_first(train_loader):
    """Fetch the first batch of ``train_loader`` and print the shape of its "image" entry.

    The "label" entry is looked up too and then discarded, so a batch without
    that key still raises ``KeyError``.
    """
    first_batch = first(train_loader)
    images = first_batch["image"]
    _ = first_batch["label"]
    print('ct_test shape:', images.shape)
    
def test_dataloader_2d_enumerate(train_loader, *, max_batches=3, max_samples=10, show_every=10):
    """Plot selected 2-D image/label pairs from the first batches of ``train_loader``.

    Each batch is a dict with "image" and "label" entries, indexed as
    ``[sample, channel, row, col]``; channel 0 is plotted.

    Args:
        train_loader: Iterable of batch dicts.
        max_batches: Number of leading batches to inspect.
        max_samples: Only the first ``max_samples`` samples of a batch are considered.
        show_every: Plot samples ``show_every - 1``, ``2 * show_every - 1``, ...
            among the considered ones.
    """
    from itertools import islice

    # islice stops after max_batches without fetching an extra batch.
    for batch in islice(train_loader, max_batches):
        ct, mr = batch["image"], batch["label"]
        # Bound by the real batch length: the final batch can be shorter than
        # the nominal batch size.
        n_samples = min(ct.shape[0], max_samples)
        for j in range(show_every - 1, n_samples, show_every):
            plt.figure(figsize=(5, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(ct[j, 0, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(mr[j, 0, :, :], cmap='gray')

# test dataloader by enumerate()
def test_dataloader_3d_enumerate(train_loader, *, max_batches=3, max_samples=2, slice_idx=50):
    """Print the number of batches, then plot one slice per 3-D sample of the first batches.

    Each batch is a dict with "image" and "label" entries, indexed as
    ``[sample, channel, depth, row, col]``; channel 0 is plotted.

    Args:
        train_loader: Sized iterable of batch dicts (``len()`` is printed).
        max_batches: Number of leading batches to inspect.
        max_samples: Maximum number of samples plotted per batch.
        slice_idx: Index along the third axis of the slice to plot. It must be
            smaller than that axis' size.
    """
    from itertools import islice

    print(len(train_loader))
    # islice stops after max_batches without fetching an extra batch.
    for batch in islice(train_loader, max_batches):
        ct, mr = batch["image"], batch["label"]
        # Bound by the real batch length: the final batch can be shorter than
        # the nominal batch size.
        for j in range(min(ct.shape[0], max_samples)):
            plt.figure(figsize=(10, 10))
            plt.subplot(1, 2, 1)
            plt.imshow(ct[j, 0, slice_idx, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(mr[j, 0, slice_idx, :, :], cmap='gray')
def test_dataloader_2d_iter(train_loader, *, max_batches=2, max_samples=10, show_every=10):
    """Print tensor shapes and plot selected 2-D image/label pairs of the first batches.

    Each batch is a dict with "image" and "label" entries, indexed as
    ``[sample, channel, row, col]``; channel 0 is plotted. If the loader has
    fewer than ``max_batches`` batches, it stops early instead of leaking
    ``StopIteration``.

    Args:
        train_loader: Iterable of batch dicts.
        max_batches: Number of leading batches to inspect.
        max_samples: Only the first ``max_samples`` samples of a batch are considered.
        show_every: Plot samples ``show_every - 1``, ``2 * show_every - 1``, ...
            among the considered ones.
    """
    from itertools import islice

    for batch in islice(train_loader, max_batches):
        image, label = batch["image"], batch["label"]
        print(image.shape)
        print(label.shape)
        # Bound by the real batch length: the final batch can be shorter than
        # the nominal batch size.
        n_samples = min(image.shape[0], max_samples)
        for j in range(show_every - 1, n_samples, show_every):
            plt.figure(figsize=(10, 10))
            plt.subplot(1, 2, 1)
            plt.imshow(image[j, 0, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(label[j, 0, :, :], cmap='gray')

# test dataloader by iter and next
def test_dataloader_3d_iter(train_loader, *, max_batches=2, max_samples=2, slice_idx=50):
    """Print tensor shapes and plot one slice per 3-D sample of the first batches.

    Each batch is a dict with "image" and "label" entries, indexed as
    ``[sample, channel, depth, row, col]``; channel 0 is plotted. If the loader
    has fewer than ``max_batches`` batches, it stops early instead of leaking
    ``StopIteration``.

    Args:
        train_loader: Iterable of batch dicts.
        max_batches: Number of leading batches to inspect.
        max_samples: Maximum number of samples plotted per batch.
        slice_idx: Index along the third axis of the slice to plot. It must be
            smaller than that axis' size.
    """
    from itertools import islice

    for batch in islice(train_loader, max_batches):
        image, label = batch["image"], batch["label"]
        print(image.shape)
        print(label.shape)
        # Bound by the real batch length: the final batch can be shorter than
        # the nominal batch size.
        for j in range(min(image.shape[0], max_samples)):
            plt.figure(figsize=(10, 10))
            plt.subplot(1, 2, 1)
            plt.imshow(image[j, 0, slice_idx, :, :], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(label[j, 0, slice_idx, :, :], cmap='gray')

# Script entry point: inspect one pelvis patient and display the plots.
if __name__ == "__main__":
    import sys

    # The dataset folder may be passed as the first command-line argument.
    # Otherwise the original hard-coded location is used.
    data_pelvis_path = sys.argv[1] if len(sys.argv) > 1 else r'D:\Projects\Task1\pelvis'
    get_image(data_pelvis_path, image_idx=0, slice_idx=50, ifprint=True, ifvis=True)
    # When run as a plain script, figures are not displayed until plt.show() is called.
    plt.show()