File size: 5,217 Bytes
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import monai
from monai.transforms import (
    Compose,
    LoadImaged,
    Rotate90d,
    ScaleIntensityd,
    EnsureChannelFirstd,
    ResizeWithPadOrCropd,
    DivisiblePadd,
    ThresholdIntensityd,
    NormalizeIntensityd,
    SqueezeDimd,
    Identityd,
    CenterSpatialCropd,
)
from monai.data import Dataset
from torch.utils.data import DataLoader
import torch

from .basics import get_file_list, check_batch_data, get_transforms, load_volumes, crop_volumes
##### slices #####
def load_batch_slices(train_volume_ds,val_volume_ds, train_batch_size=8,val_batch_size=1,window_width=1,ifcheck=True):
    """Turn train/val volume datasets into slice-wise DataLoaders.

    The "source" and "target" entries of every volume are cut by
    ``monai.data.PatchIterd`` with ``patch_size=(None, None, window_width)``
    starting at ``(0, 0, 0)``, i.e. windows of ``window_width`` positions along
    the last patched axis (the ``None`` axes are left to PatchIterd to take in
    full). When ``window_width == 1`` that trailing axis of size 1 is squeezed
    away so each patch is 2D; otherwise patches are passed through unchanged.

    Args:
        train_volume_ds: dataset of training volumes (dicts with "source"/"target").
        val_volume_ds: dataset of validation volumes (same layout).
        train_batch_size: batch size of the training loader.
        val_batch_size: batch size of the validation loader.
        window_width: number of consecutive slices per patch.
        ifcheck: if True, run ``check_batch_data`` on the result before returning.

    Returns:
        ``(train_loader, val_loader)``; both use ``num_workers=0`` and pin
        memory only when CUDA is available.
    """
    # One patch iterator shared by both splits.
    slicer = monai.data.PatchIterd(
        keys=["source", "target"],
        patch_size=(None, None, window_width),  # dynamic first two dimensions
        start_pos=(0, 0, 0),
    )
    # A single-slice window leaves a trailing axis of size 1; drop it.
    squeeze = (
        Compose([SqueezeDimd(keys=["source", "target"], dim=-1)])
        if window_width == 1
        else None
    )

    def _slice_loader(volume_ds, batch_size):
        """Wrap one volume dataset into a patch dataset plus its DataLoader."""
        patch_ds = monai.data.GridPatchDataset(
            data=volume_ds,
            patch_iter=slicer,
            transform=squeeze,
            with_coordinates=False,
        )
        loader = DataLoader(
            patch_ds,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
        )
        return patch_ds, loader

    train_patch_ds, train_loader = _slice_loader(train_volume_ds, train_batch_size)
    _, val_loader = _slice_loader(val_volume_ds, val_batch_size)

    if ifcheck:
        # NOTE(review): the training side is checked with its *patch* dataset but
        # the validation side with its *volume* dataset (the val patch dataset is
        # not passed) — confirm this asymmetry is what check_batch_data expects.
        check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds,
                         train_batch_size, val_batch_size)
    return train_loader, val_loader
def myslicesloader(data_pelvis_path,
                   normalize='minmax',
                   pad='minimum',
                   train_number=1,
                   val_number=1,
                   train_batch_size=8,
                   val_batch_size=1,
                   saved_name_train='./train_ds_2d.csv',
                   saved_name_val='./val_ds_2d.csv',
                   resized_size=(512,512,None),
                   div_size=(16,16,None),
                   center_crop=20,
                   ifcheck_volume=True,
                   ifcheck_sclices=False,):
    """Build volume datasets and single-slice DataLoaders from a dataset folder.

    Steps, in order:
      1. ``get_transforms`` builds the train (``mode='train'``, ``prob=0.8``) and
         val (``mode='val'``) volume transforms from the normalize/pad/size options.
      2. ``get_file_list`` selects ``train_number`` / ``val_number`` cases.
      3. ``crop_volumes`` is applied with ``center_crop``.
      4. ``load_volumes`` builds the volume datasets (called with ``ifsave=True``
         and the two CSV names; ``ifcheck_volume`` is forwarded as ``ifcheck``).
      5. ``load_batch_slices`` wraps them into loaders with ``window_width=1``
         (``ifcheck_sclices`` is forwarded as ``ifcheck``).

    Returns:
        ``(train_ds, val_ds, train_loader, val_loader, train_transforms, val_transforms)``.
    """
    # Volume-level transforms shared by image and label.
    train_transforms = get_transforms(normalize, pad, resized_size, div_size, mode='train', prob=0.8)
    val_transforms = get_transforms(normalize, pad, resized_size, div_size, mode='val')

    train_files, val_files = get_file_list(data_pelvis_path, train_number, val_number)
    train_cropped, val_cropped = crop_volumes(train_files, val_files, center_crop)
    train_volumes, val_volumes = load_volumes(
        train_transforms, val_transforms,
        train_cropped, val_cropped,
        train_files, val_files,
        saved_name_train, saved_name_val,
        ifsave=True,
        ifcheck=ifcheck_volume,
    )

    train_loader, val_loader = load_batch_slices(
        train_volumes,
        val_volumes,
        train_batch_size,
        val_batch_size=val_batch_size,
        window_width=1,
        ifcheck=ifcheck_sclices,
    )
    return train_volumes, val_volumes, train_loader, val_loader, train_transforms, val_transforms

def len_patchloader(train_volume_ds,train_batch_size):
    """Count the 2D slices and batches produced by the training slice loader.

    Intended to match the loader built by ``load_batch_slices`` with
    ``window_width=1``: every position along the last axis of a volume's
    ``'source'`` becomes one slice. That loader is a ``DataLoader`` (no
    ``drop_last``) over a ``GridPatchDataset``, which is an iterable dataset.
    Batches are therefore filled with consecutive slices across volume
    boundaries, and only the final batch can be partial.

    Args:
        train_volume_ds: indexable dataset whose items are dicts with a
            ``'source'`` entry exposing ``.shape``; the last axis is the slice axis.
        train_batch_size: batch size of the training DataLoader.

    Returns:
        ``(slice_number, batch_number)``; both values are also printed.
    """
    import math

    # Index each volume exactly once. If train_volume_ds is a non-caching
    # transform-backed dataset, every __getitem__ re-runs its loading pipeline,
    # so the previous two passes loaded every volume twice.
    depths = [train_volume_ds[i]['source'].shape[-1] for i in range(len(train_volume_ds))]

    slice_number = sum(depths)
    print('total slices in training set:',slice_number)

    # Batches span volume boundaries (see docstring), so the count depends only on
    # the total number of slices, not on per-volume ceil(depth / batch_size).
    batch_number = math.ceil(slice_number / train_batch_size)
    print('total batches in training set:',batch_number)
    return slice_number,batch_number

if __name__ == '__main__':
    # Smoke run: build slice loaders for a small subset of the dataset and append
    # the shape of every training batch to a text file.
    dataset_path = r"F:\yang_Projects\Datasets\Task1\pelvis"
    train_volume_ds, _, train_loader, _, _, _ = myslicesloader(
        dataset_path,
        normalize='none',
        train_number=2,
        val_number=1,
        train_batch_size=4,
        val_batch_size=1,
        saved_name_train='./train_ds_2d.csv',
        saved_name_val='./val_ds_2d.csv',
        resized_size=(512, 512, None),
        div_size=(16, 16, None),
        ifcheck_volume=False,
        ifcheck_sclices=False,
    )
    from tqdm import tqdm
    parameter_file = r'.\test.txt'
    # Open the log once for the whole loop rather than re-opening it per batch.
    with open(parameter_file, 'a') as f:
        for data in tqdm(train_loader):
            # Batches are dicts keyed "source"/"target" (the keys PatchIterd slices
            # in load_batch_slices); reading "image"/"label" raised KeyError.
            f.write('image batch:' + str(data["source"].shape) + '\n')
            f.write('label batch:' + str(data["target"].shape) + '\n')
            f.write('\n')