import os

import numpy as np
from torch.utils.data import DataLoader


def _patient_name(item):
    """Return the name of the directory that contains ``item['image']``."""
    return os.path.basename(os.path.dirname(item['image']))


def _check_loader(loader, file_ds):
    """Pull every batch from ``loader`` once and print its image/label shapes.

    ``file_ds[idx]['image']`` is used to name the entry being checked, so
    ``file_ds`` must be indexable in the same order the loader iterates.
    A failure while loading one entry is reported and the scan continues
    with the next entry.
    """
    iterator = iter(loader)
    for idx in range(len(loader)):
        # The loaders used here have batch_size=1, so batch idx <-> entry idx.
        name = _patient_name(file_ds[idx])
        try:
            batch = next(iterator)
            image_shape = batch['image'].shape
            label_shape = batch['label'].shape
        except Exception as exc:
            # Only ordinary errors are reported; KeyboardInterrupt/SystemExit
            # propagate so the user can still abort a long check.
            print('check data error! Check the input data:', name, '-', repr(exc))
            continue
        print(idx, name, 'image:', image_shape, 'label:', label_shape)


def check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds):
    """Load every training and validation volume once and print its shapes.

    Args:
        train_ds: indexable collection of dicts whose ``'image'`` value is a
            file path; used only to derive the printed name of each entry.
        train_volume_ds: dataset yielding dicts with ``'image'`` and
            ``'label'`` entries that expose ``.shape``.
        val_volume_ds: validation counterpart of ``train_volume_ds``.
        val_ds: validation counterpart of ``train_ds``.

    Entries that fail to load are reported (with the exception) instead of
    aborting the whole check.
    """
    # Use batch_size=1 because the input volumes can have different shapes
    # and therefore cannot be stacked into a larger batch.
    train_loader = DataLoader(train_volume_ds, batch_size=1)
    val_loader = DataLoader(val_volume_ds, batch_size=1)

    print('check training data:')
    _check_loader(train_loader, train_ds)
    print("checked all training data.")

    print('check validation data:')
    _check_loader(val_loader, val_ds)
    print("checked all validation data.")


def save_volumes(train_ds, val_ds, saved_name_train, saved_name_val):
    """Write the patient names of both splits to text files.

    For each sample the patient name is the directory that contains
    ``sample['image']``. Each output line is the string form of a
    ``{'patient': name}`` dict, one line per sample.

    Args:
        train_ds: iterable of dicts with an ``'image'`` path.
        val_ds: iterable of dicts with an ``'image'`` path.
        saved_name_train: output file for the training split.
        saved_name_val: output file for the validation split.
    """
    train_rows = [{'patient': _patient_name(sample)} for sample in train_ds]
    val_rows = [{'patient': _patient_name(sample)} for sample in val_ds]

    # fmt='%s' makes savetxt write str(row) for every entry.
    np.savetxt(saved_name_train, train_rows, delimiter=',', fmt='%s', newline='\n')
    np.savetxt(saved_name_val, val_rows, delimiter=',', fmt='%s', newline='\n')


def _print_batches(header, loader, dataset, batch_size):
    """Print ``header``, a dataset item and the batch shapes for every batch."""
    for batch_idx, batch in enumerate(loader):
        # Item at the batch's starting offset; this is the batch's first
        # sample only if the loader walks the dataset in order
        # (NOTE(review): confirm the loaders are not shuffled).
        first_item = dataset[batch_idx * batch_size]
        print(header)
        print(first_item, 'image:', batch['image'].shape, 'label:', batch['label'].shape)


def check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds,
                     train_batch_size, val_batch_size):
    """Iterate both loaders and print the shapes of every batch.

    Args:
        train_loader: iterable of training batches (dicts with ``'image'``
            and ``'label'`` entries exposing ``.shape``).
        val_loader: iterable of validation batches of the same form.
        train_patch_ds: indexable dataset behind ``train_loader``.
        val_volume_ds: indexable dataset behind ``val_loader``.
        train_batch_size: batch size used by ``train_loader``.
        val_batch_size: batch size used by ``val_loader``.
    """
    _print_batches('check train data:', train_loader, train_patch_ds, train_batch_size)
    _print_batches('check val data:', val_loader, val_volume_ds, val_batch_size)