File size: 5,755 Bytes
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import monai
import os
import numpy as np
from monai.transforms import (
    Compose,
    LoadImaged,
    Rotate90d,
    ScaleIntensityd,
    EnsureChannelFirstd,
    ResizeWithPadOrCropd,
    DivisiblePadd,
    ThresholdIntensityd,
    NormalizeIntensityd,
    SqueezeDimd,
    Identityd,
    CenterSpatialCropd,
)

from monai.data import Dataset
from torch.utils.data import DataLoader
import torch
from .basics import get_file_list, get_transforms

def _slices_from_dataset(ds, batch_size):
    """Load each 3D image/label pair in *ds* and split it into 2D slices.

    Each sample dict must contain 'image' and 'label' file paths. The slice
    axis (last axis) is padded to a multiple of *batch_size* so later
    batching divides evenly.

    Returns:
        slices: list of {'image': 2D slice, 'label': 2D slice} dicts.
        shape_list: list of {'patient': name, 'shape': padded 3D shape} dicts.
        total_slices: total number of 2D slices produced.
    """
    loader = LoadImaged(keys=["image", "label"], image_only=True,
                        ensure_channel_first=False, simple_keys=True)
    # (-1, batch_size): leave the in-plane axes alone, pad only the slice axis.
    padder = DivisiblePadd(["image", "label"], (-1, batch_size), mode="minimum")
    slices, shape_list, total_slices = [], [], 0
    for sample in ds:
        volume = padder(loader(sample))
        # Patient id = name of the directory containing the image file.
        patient = os.path.basename(os.path.dirname(sample['image']))
        shape_list.append({'patient': patient, 'shape': volume["image"].shape})
        num_slices = volume["image"].shape[-1]
        for i in range(num_slices):
            slices.append({'image': volume['image'][:, :, i],
                           'label': volume['label'][:, :, i]})
        total_slices += num_slices
    return slices, shape_list, total_slices


def transform_datasets_to_2d(train_ds, val_ds, saved_name_train, saved_name_val, batch_size=8, ifsave=True):
    """Convert 3D train/val CT datasets into lists of 2D axial-slice dicts.

    Args:
        train_ds, val_ds: iterables of {'image': path, 'label': path} dicts.
        saved_name_train, saved_name_val: CSV paths for the per-patient shape
            lists (only written when ifsave is True).
        batch_size: the slice axis is padded to a multiple of this value.
        ifsave: whether to write the shape lists to the CSV paths above.

    Returns:
        (train_ds_2d, val_ds_2d, all_slices_train, all_slices_val,
         shape_list_train, shape_list_val)
    """
    train_ds_2d, shape_list_train, all_slices_train = _slices_from_dataset(train_ds, batch_size)
    val_ds_2d, shape_list_val, all_slices_val = _slices_from_dataset(val_ds, batch_size)

    # Persist per-patient shapes so volumes can be reassembled from slices later.
    if ifsave:
        np.savetxt(saved_name_train, shape_list_train, delimiter=',', fmt='%s', newline='\n')
        np.savetxt(saved_name_val, shape_list_val, delimiter=',', fmt='%s', newline='\n')
    return train_ds_2d, val_ds_2d, all_slices_train, all_slices_val, shape_list_train, shape_list_val
def get_train_val_loaders(train_ds_2d, val_ds_2d, batch_size, val_batch_size, resized_size=(256, 256)):
    """Build train/val DataLoaders over lists of 2D slice dicts.

    Args:
        train_ds_2d, val_ds_2d: lists of {'image': 2D array, 'label': 2D array}.
        batch_size: training loader batch size.
        val_batch_size: validation loader batch size.
        resized_size: spatial size for ResizeWithPadOrCropd.

    Returns:
        (train_loader, val_loader, train_transforms_list, train_transforms)
        — the same (deterministic) transform chain is used for both splits.
    """
    train_transforms = Compose(
        [
            EnsureChannelFirstd(keys=["image", "label"]),
            # NOTE(review): z-score normalization is applied to the label as
            # well as the image — fine for image-to-image translation, but
            # confirm this is intended if labels are segmentation masks.
            NormalizeIntensityd(keys=["image", "label"], nonzero=False, channel_wise=True),
            ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=resized_size, mode="minimum"),
            Rotate90d(keys=["image", "label"], k=3),
            DivisiblePadd(["image", "label"], 16, mode="minimum"),
        ]
    )
    # Use the public Compose.transforms attribute rather than reaching into
    # __dict__ (same object, supported API).
    train_transforms_list = train_transforms.transforms

    train_dataset = Dataset(data=train_ds_2d, transform=train_transforms)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=0, pin_memory=True)

    val_dataset = Dataset(data=val_ds_2d, transform=train_transforms)
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False,
                            num_workers=0, pin_memory=True)
    return train_loader, val_loader, train_transforms_list, train_transforms
def mydataloader(data_pelvis_path,
                normalize='zscore',
                pad='minimum',
                train_number=1,
                val_number=1,
                batch_size=8,
                val_batch_size=1,
                saved_name_train='./train_ds_2d.csv', saved_name_val='./val_ds_2d.csv',
                resized_size=(512, 512),
                div_size=(16, 16, None),
                center_crop=20,):
    """End-to-end 2D data-loading pipeline: list files, slice volumes, build loaders.

    Args:
        data_pelvis_path: root directory passed to get_file_list.
        normalize, pad, div_size, center_crop: currently unused here (kept for
            backward compatibility with callers; the transform chain is fixed
            inside get_train_val_loaders).
        train_number, val_number: number of train/val cases to list.
        batch_size, val_batch_size: loader batch sizes.
        saved_name_train, saved_name_val: shape-list CSV paths — note that
            ifsave=False below, so nothing is written; passed through in case
            saving is re-enabled.
        resized_size: spatial size used by the transforms.

    Returns:
        (train_loader, val_loader, train_transforms_list, train_transforms,
         all_slices_train, all_slices_val, shape_list_train, shape_list_val)
    """
    train_ds, val_ds = get_file_list(data_pelvis_path,
                                     train_number,
                                     val_number)
    train_ds_2d, val_ds_2d, \
    all_slices_train, all_slices_val, \
    shape_list_train, shape_list_val = transform_datasets_to_2d(train_ds, val_ds,
                                                            saved_name_train,
                                                            saved_name_val,
                                                            batch_size=batch_size,
                                                            ifsave=False)
    train_loader, val_loader, \
    train_transforms_list, train_transforms = get_train_val_loaders(train_ds_2d,
                                                                val_ds_2d,
                                                                batch_size=batch_size,
                                                                val_batch_size=val_batch_size,
                                                                resized_size=resized_size)
    return train_loader, val_loader, \
            train_transforms_list, train_transforms, \
            all_slices_train, all_slices_val, \
            shape_list_train, shape_list_val