import monai
import os
import numpy as np
from monai.transforms import (
Compose,
LoadImaged,
EnsureChannelFirstd,
SqueezeDimd,
CenterSpatialCropd,
)
from monai.data import Dataset
from torch.utils.data import DataLoader
import torch
from .checkdata import check_volumes, save_volumes, check_batch_data
from .basics import get_file_list, crop_volumes, load_volumes, get_transforms
from torchvision.utils import save_image
def load_batch_slices(train_volume_ds, val_volume_ds, train_batch_size=5, val_batch_size=1, window_width=1, ifcheck=True):
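    """Wrap volume datasets into batch loaders of 2D slices.

    Training volumes are iterated along the last axis with a window of
    ``window_width`` slices; when ``window_width == 1`` the singleton slice
    dimension is squeezed so each patch is 2D. Validation is left at volume
    level. Returns ``(train_loader, val_loader)``.
    """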
patch_func = monai.data.PatchIterd(
keys=["source", "target"],
patch_size=(None, None, window_width), # dynamic first two dimensions
start_pos=(0, 0, 0)
)
    if window_width == 1:
        # squeeze the singleton slice dimension so each patch is 2D
        patch_transform = Compose([SqueezeDimd(keys=["source", "target"], dim=-1)])
    else:
        patch_transform = None
# for training
train_patch_ds = monai.data.GridPatchDataset(
data=train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
train_loader = DataLoader(
train_patch_ds,
batch_size=train_batch_size,
num_workers=2,
pin_memory=torch.cuda.is_available(),
)
# for validation
val_loader = DataLoader(
val_volume_ds,
num_workers=1,
batch_size=val_batch_size,
pin_memory=torch.cuda.is_available())
if ifcheck:
        check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds, train_batch_size, val_batch_size)
    return train_loader, val_loader
def load_batch_slices3D(train_volume_ds, val_volume_ds, train_batch_size=5, val_batch_size=1, ifcheck=True):
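    """Wrap volume datasets into batch loaders of 32-slice 3D patches.

    Training volumes are tiled along the last axis into patches of depth 32;
    validation is left at volume level. Returns ``(train_loader, val_loader)``.
    """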
patch_func = monai.data.PatchIterd(
keys=["source", "target"],
        patch_size=(None, None, 32),  # dynamic in-plane dims, fixed depth of 32 slices
start_pos=(0, 0, 0)
)
# for training
train_patch_ds = monai.data.GridPatchDataset(
data=train_volume_ds, patch_iter=patch_func, with_coordinates=False)
train_loader = DataLoader(
train_patch_ds,
batch_size=train_batch_size,
num_workers=2,
pin_memory=torch.cuda.is_available(),
)
# for validation
val_loader = DataLoader(
val_volume_ds,
num_workers=1,
batch_size=val_batch_size,
pin_memory=torch.cuda.is_available())
if ifcheck:
        check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds, train_batch_size, val_batch_size)
    return train_loader, val_loader
def myslicesloader(configs, paths):
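    """Build slice-level training and volume-level validation loaders.

    Expects one folder per patient under ``configs.dataset.data_dir`` holding
    ``<source>.nii.gz``, ``<target>.nii.gz`` and ``mask.nii.gz``. Volumes are
    loaded, optionally center-cropped along the slice axis, transformed, and
    the training set is cut into single 2D slices. Returns
    ``(train_crop_ds, val_crop_ds, train_loader, val_loader,
    train_transforms, val_transforms)``.
    """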
    data_path = configs.dataset.data_dir
    train_number = configs.dataset.train_number
    val_number = configs.dataset.val_number
    train_batch_size = configs.dataset.batch_size
    val_batch_size = configs.dataset.val_batch_size
    saved_name_train = paths["saved_name_train"]
    saved_name_val = paths["saved_name_val"]
    center_crop = configs.dataset.center_crop
    source = configs.dataset.source
    target = configs.dataset.target
# volume-level transforms for both image and label
train_transforms = get_transforms(configs,mode='train')
val_transforms = get_transforms(configs,mode='val')
    # list all patient folders (skip the overview file)
    file_list = [i for i in os.listdir(data_path) if 'overview' not in i]
    file_list_path = [os.path.join(data_path, i) for i in file_list]
    # per-patient source, target and mask file paths
    mask = 'mask'
    source_file_list = [os.path.join(j, f'{source}.nii.gz') for j in file_list_path]
    target_file_list = [os.path.join(j, f'{target}.nii.gz') for j in file_list_path]
    mask_file_list = [os.path.join(j, f'{mask}.nii.gz') for j in file_list_path]
train_ds = [{'source': i, 'target': j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
val_ds = [{'source': i, 'target': j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
    print('all files in dataset:', len(file_list))
    # load volumes and optionally center-crop along the slice axis
    load_keys = ["source", "target", "mask"]
    crop_transform_list = [
        LoadImaged(keys=load_keys),
        EnsureChannelFirstd(keys=load_keys),
    ]
    if center_crop > 0:
        crop_transform_list.append(CenterSpatialCropd(keys=load_keys, roi_size=(-1, -1, center_crop)))
        print('center crop:', center_crop)
    crop = Compose(crop_transform_list)
    train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
    val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)
    # apply volume-level train/val transforms
train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=train_transforms)
val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=val_transforms)
    ifsave, ifcheck = False, False
if ifsave:
save_volumes(train_ds, val_ds, saved_name_train, saved_name_val)
if ifcheck:
check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds)
# batch-level slicer for both image and label
    window_width = 1  # single-slice (2D) training patches
patch_func = monai.data.PatchIterd(
keys=["source", "target", "mask"],
patch_size=(None, None, window_width), # dynamic first two dimensions
start_pos=(0, 0, 0)
)
    if window_width == 1:
        # squeeze the singleton slice dimension so each patch is 2D
        patch_transform = Compose([SqueezeDimd(keys=["source", "target", "mask"], dim=-1)])
    else:
        patch_transform = None
# for training
train_patch_ds = monai.data.GridPatchDataset(
data=train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
train_loader = DataLoader(
train_patch_ds,
batch_size=train_batch_size,
num_workers=2,
pin_memory=torch.cuda.is_available(),
)
# for validation
val_loader = DataLoader(
val_volume_ds,
num_workers=1,
batch_size=val_batch_size,
pin_memory=torch.cuda.is_available())
if ifcheck:
        check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds, train_batch_size, val_batch_size)
    return train_crop_ds, val_crop_ds, train_loader, val_loader, train_transforms, val_transforms
def mydataloader_3d(data_pelvis_path,
train_number,
val_number,
train_batch_size,
val_batch_size,
saved_name_train='./train_ds_2d.csv',
saved_name_val='./val_ds_2d.csv',
resized_size=(600,400,150),
div_size=(16,16,16),
ifcheck_volume=True,):
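    """Build 3D patch-based loaders: volumes are loaded, transformed
    (z-score normalization), and tiled into 32-slice patches via
    ``load_batch_slices3D``. Returns ``(train_loader, val_loader,
    train_transforms)``.
    """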
# volume-level transforms for both image and segmentation
    normalize = 'zscore'
    train_transforms = get_transforms(normalize, resized_size, div_size)
train_ds, val_ds = get_file_list(data_pelvis_path,
train_number,
val_number)
    train_volume_ds, val_volume_ds = load_volumes(train_transforms=train_transforms,
                                                  train_ds=train_ds,
                                                  val_ds=val_ds,
                                                  saved_name_train=saved_name_train,
                                                  saved_name_val=saved_name_val,
                                                  ifsave=True,
                                                  ifcheck=ifcheck_volume)
'''
train_loader = DataLoader(train_volume_ds, batch_size=train_batch_size)
val_loader = DataLoader(val_volume_ds, batch_size=val_batch_size)
'''
    ifcheck_slices = False
    train_loader, val_loader = load_batch_slices3D(train_volume_ds,
                                                   val_volume_ds,
                                                   train_batch_size,
                                                   val_batch_size=val_batch_size,
                                                   ifcheck=ifcheck_slices)
    return train_loader, val_loader, train_transforms
def save_dataset_as_png(train_ds, train_volume_ds, saved_img_folder, saved_label_folder):
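    """Write every slice of each training volume to PNG.

    Image and label slices go to ``saved_img_folder`` and
    ``saved_label_folder`` respectively, named ``<patient>_image_<i>.png``
    and ``<patient>_label_<i>.png``.
    """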
train_loader = DataLoader(train_volume_ds, batch_size=1)
for idx, train_check_data in enumerate(train_loader):
image_volume = train_check_data['image']
label_volume = train_check_data['label']
        current_item = train_ds[idx]
        # patient folder name, used as the filename prefix
        file_name_prefix = os.path.basename(os.path.dirname(current_item['image']))
        slices_num = image_volume.shape[-1]
        for i in range(slices_num):
            image_i = image_volume[0, 0, :, :, i]
            label_i = label_volume[0, 0, :, :, i]
            save_image(image_i, os.path.join(saved_img_folder, f'{file_name_prefix}_image_{i}.png'))
            save_image(label_i, os.path.join(saved_label_folder, f'{file_name_prefix}_label_{i}.png'))
def pre_dataset_for_stylegan(data_pelvis_path,
normalize,
train_number,
val_number,
saved_img_folder,
saved_label_folder,
saved_name_train='./train_ds_2d.csv',
saved_name_val='./val_ds_2d.csv',
resized_size=(600,400,None),
div_size=(16,16,None),):
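    """Prepare a 2D PNG dataset (e.g. for StyleGAN training): load and
    transform the training volumes, then dump every slice to PNG via
    ``save_dataset_as_png``. Returns ``(train_ds, train_volume_ds)``.
    """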
train_transforms = get_transforms(normalize,resized_size,div_size)
train_ds, val_ds = get_file_list(data_pelvis_path,
train_number,
val_number)
train_volume_ds, _ = load_volumes(train_transforms,
train_ds,
val_ds,
saved_name_train,
saved_name_val,
ifsave=False,
ifcheck=False)
    save_dataset_as_png(train_ds, train_volume_ds, saved_img_folder, saved_label_folder)
    return train_ds, train_volume_ds
def sum_slices(data_pelvis_path, num=180):
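    """Count and print the total number of 2D slices across ``num``
    patients (loaded as the validation split)."""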
    train_ds, val_ds = get_file_list(data_pelvis_path, 0, num)
    (train_ds_2d, val_ds_2d,
     all_slices_train, all_slices_val,
     shape_list_train, shape_list_val) = transform_datasets_to_2d(train_ds, val_ds,
                                                                  saved_name_train='./train_ds_2d.csv',
                                                                  saved_name_val='./val_ds_2d.csv',
                                                                  ifsave=False)
print(all_slices_val)
return all_slices_val
def transform_datasets_to_2d(train_ds, val_ds, saved_name_train, saved_name_val, ifsave=True):
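    """Flatten 3D volume datasets into lists of single-slice dictionaries.

    Returns the 2D train/val datasets, the total slice counts, and
    per-patient shape records (optionally written to CSV).
    """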
    # flatten each 3D volume into per-slice dictionaries
    train_ds_2d = []
    val_ds_2d = []
    shape_list_train = []
    shape_list_val = []
    all_slices_train = 0
    all_slices_val = 0
    # load 2D slices for training (keys unified to 'image'/'label' to match the
    # dictionaries built below and the downstream 2D pipeline)
    for sample in train_ds:
        train_ds_2d_image = LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=False, simple_keys=True)(sample)
        name = os.path.basename(os.path.dirname(sample['image']))
        num_slices = train_ds_2d_image["image"].shape[-1]
        shape_list_train.append({'patient': name, 'shape': train_ds_2d_image["image"].shape})
        for i in range(num_slices):
            train_ds_2d.append({'image': train_ds_2d_image['image'][:, :, i], 'label': train_ds_2d_image['label'][:, :, i]})
        all_slices_train += num_slices
    # load 2D slices for validation
    for sample in val_ds:
        val_ds_2d_image = LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=False, simple_keys=True)(sample)
        name = os.path.basename(os.path.dirname(sample['image']))
        shape_list_val.append({'patient': name, 'shape': val_ds_2d_image["image"].shape})
        num_slices = val_ds_2d_image["image"].shape[-1]
        for i in range(num_slices):
            val_ds_2d.append({'image': val_ds_2d_image['image'][:, :, i], 'label': val_ds_2d_image['label'][:, :, i]})
        all_slices_val += num_slices
    # save per-patient shape records to CSV
    if ifsave:
        np.savetxt(saved_name_train, shape_list_train, delimiter=',', fmt='%s', newline='\n')
        np.savetxt(saved_name_val, shape_list_val, delimiter=',', fmt='%s', newline='\n')
return train_ds_2d, val_ds_2d, all_slices_train, all_slices_val, shape_list_train, shape_list_val
def get_train_val_loaders(train_ds_2d, val_ds_2d, batch_size, val_batch_size, normalize, resized_size=(600, 400), div_size=(16, 16, None)):
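    """Create 2D DataLoaders from slice-level datasets; the same transform
    pipeline is applied to both splits. Returns ``(train_loader, val_loader,
    train_transforms_list, train_transforms)``.
    """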
    # define transforms (the same pipeline serves both splits)
    train_transforms = get_transforms(normalize, resized_size, div_size)
    train_transforms_list = train_transforms.transforms  # the underlying tuple of transforms
    # create training dataset and data loader
    train_dataset = Dataset(data=train_ds_2d, transform=train_transforms)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
    # create validation dataset and data loader
    val_dataset = Dataset(data=val_ds_2d, transform=train_transforms)
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=1, pin_memory=True)
    return train_loader, val_loader, train_transforms_list, train_transforms
def mydataloader(data_pelvis_path,
train_number,
val_number,
batch_size,
val_batch_size,
saved_name_train='./train_ds_2d.csv',
saved_name_val='./val_ds_2d.csv',
resized_size=(600,400)):
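    """End-to-end 2D pipeline: list patient files, flatten the volumes into
    slices, and build train/validation loaders."""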
train_ds, val_ds = get_file_list(data_pelvis_path,
train_number,
val_number)
    (train_ds_2d, val_ds_2d,
     all_slices_train, all_slices_val,
     shape_list_train, shape_list_val) = transform_datasets_to_2d(train_ds, val_ds,
                                                                  saved_name_train,
                                                                  saved_name_val, ifsave=True)
    (train_loader, val_loader,
     train_transforms_list, train_transforms) = get_train_val_loaders(train_ds_2d,
                                                                      val_ds_2d,
                                                                      batch_size=batch_size,
                                                                      val_batch_size=val_batch_size,
                                                                      normalize='zscore',
                                                                      resized_size=resized_size,
                                                                      div_size=(16, 16, None))
    return (train_loader, val_loader,
            train_transforms_list, train_transforms,
            all_slices_train, all_slices_val,
            shape_list_train, shape_list_val)
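

if __name__ == '__main__':
    # Minimal smoke-test sketch for the 2D pipeline. The data path below is a
    # hypothetical example; point it at a folder of per-patient NIfTI volumes
    # laid out the way get_file_list expects.
    train_loader, val_loader, *_ = mydataloader(
        data_pelvis_path='./data/pelvis',  # hypothetical example path
        train_number=2,
        val_number=1,
        batch_size=4,
        val_batch_size=1,
    )
    # pull one batch and report the tensor shapes
    first_batch = next(iter(train_loader))
    print({k: getattr(v, 'shape', type(v)) for k, v in first_batch.items()})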