import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from monai.transforms.utils import allow_missing_keys_mode
from torch.utils.data import DataLoader

from mydataloader.basics import get_transforms, get_file_list, load_volumes, crop_volumes


# Define function to save images
def save_image(image, filename, idx, title, dpi=300):
    """Render a 2-D image in grayscale and write it to ``{filename}_{idx}.png``.

    Args:
        image: Array-like accepted by ``plt.imshow``.
        filename: Output path without the index suffix or extension.
        idx: Index appended to ``filename`` before the ``.png`` extension.
        title: Title drawn above the image.
        dpi: Resolution passed to ``savefig``.
    """
    fig = plt.figure()
    plt.imshow(image, cmap='gray')
    plt.axis('off')
    plt.title(title)
    plt.savefig(f"{filename}_{idx}.png", format='png', bbox_inches='tight', pad_inches=0, dpi=dpi)
    # Close the figure explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)


# Define function to plot histograms
def plot_histogram(data, title, ax, color='blue', alpha=0.7):
    """Draw a 256-bin intensity histogram of ``data`` on the given axes.

    Args:
        data: Object exposing ``flatten()`` (e.g. a NumPy array).
        title: Axes title.
        ax: Matplotlib axes to draw on.
        color: Bar colour.
        alpha: Bar opacity.
    """
    bins = 256
    ax.hist(data.flatten(), bins=bins, color=color, alpha=alpha)
    ax.set_title(title)
    ax.set_xlabel('Pixel intensity')
    ax.set_ylabel('Frequency')


# Arrange three histograms
def arrange_histograms(original, transformed, reversed, mode='ct', index=None, out_dir=None):
    """Stack histograms of three volumes in one figure and save it as a PNG.

    The file is written to ``{out_dir}/{index}_histograms_{mode}.png``.

    Args:
        original: Data for the top (red) histogram.
        transformed: Data for the middle (green) histogram.
        reversed: Data for the bottom (blue) histogram. The name shadows the
            builtin but is kept so existing keyword callers keep working.
        mode: Label used in the titles and in the file name.
        index: Index used in the file name. When omitted, the module-level
            ``idx`` (the loop variable of the calling script) is used, which
            is what this function always did before the parameter existed.
        out_dir: Output directory. Defaults to the module-level ``save_folder``.
    """
    # Fall back to the script-level globals only when the caller did not
    # supply explicit values; this keeps the old call sites working.
    file_index = idx if index is None else index
    target_dir = save_folder if out_dir is None else out_dir

    fig, axs = plt.subplots(3, 1, figsize=(10, 8))
    panels = (
        (original, 'original', 'red'),
        (transformed, 'transformed', 'green'),
        (reversed, 'reversed', 'blue'),
    )
    for ax, (data, label, color) in zip(axs, panels):
        plot_histogram(data, f'Histogram for {label} {mode}', ax, color=color)

    # Save the histogram figure
    fig.tight_layout()
    fig.savefig(os.path.join(target_dir, f"{file_index}_histograms_{mode}.png"), dpi=300)
    plt.close(fig)


# Define function to normalize and reverse normalize
def normalize_data(tensor, mean=None, std=None, min_val=None, max_val=None, mode='zscore'):
    """Normalize ``tensor`` with the selected method.

    Args:
        tensor: Data to normalize.
        mean, std: Statistics used by ``mode='zscore'``.
        min_val, max_val: Range used by ``mode='minmax'``.
        mode: ``'zscore'``, ``'minmax'`` or anything else (no-op).

    Returns:
        ``(tensor - mean) / std`` for z-score, a linear rescale of
        ``[min_val, max_val]`` onto ``[-1, 1]`` for min-max, or ``tensor``
        unchanged when the mode is unknown or the required statistics are
        missing.
    """
    if mode == 'zscore':
        if mean is None or std is None:
            return tensor
        return (tensor - mean) / std
    if mode == 'minmax':
        if min_val is None or max_val is None:
            return tensor
        # Map onto [-1, 1] so this is the exact inverse of
        # reverse_normalize_data(mode='minmax'); the previous [0, 1] mapping
        # did not round-trip through the reverse function.
        return 2 * (tensor - min_val) / (max_val - min_val) - 1
    return tensor


# Define function to reverse normalization
def reverse_normalize_data(tensor, mean=None, std=None, min_val=None, max_val=None, mode='zscore'):
    """Undo the normalization performed by ``normalize_data``.

    Args:
        tensor: Normalized data.
        mean, std: Statistics used by ``mode='zscore'``.
        min_val, max_val: Range used by ``mode='minmax'``.
        mode: ``'zscore'``, ``'minmax'`` or anything else (no-op).

    Returns:
        ``tensor * std + mean`` for z-score, a linear rescale of ``[-1, 1]``
        onto ``[min_val, max_val]`` for min-max, or ``tensor`` unchanged when
        the mode is unknown or the required statistics are missing.
    """
    if mode == 'zscore':
        if mean is None or std is None:
            return tensor
        return tensor * std + mean
    if mode == 'minmax':
        if min_val is None or max_val is None:
            return tensor
        return (tensor + 1) / 2 * (max_val - min_val) + min_val
    return tensor


# Normalization settings
# Every 'apply' entry accepts the same keyword arguments as normalize_data so
# the table can be used uniformly; the identity entries simply ignore them.
normalization_methods = {
    'zscore': {'apply': normalize_data, 'reverse': reverse_normalize_data},
    'minmax': {'apply': normalize_data, 'reverse': reverse_normalize_data},
    'inputonly': {'apply': lambda x, **_: x, 'reverse': reverse_normalize_data},
    'none': {'apply': lambda x, **_: x, 'reverse': reverse_normalize_data},
}

# Other settings
dataset_path = r'D:\Projects\data\Task1\pelvis'
normalize = 'zscore'
pad = 'minimum'
train_number = 1
val_number = 1
train_batch_size = 8
val_batch_size = 1
saved_name_train = './train_ds_2d.csv'
saved_name_val = './val_ds_2d.csv'
resized_size = (512, 512, None)
div_size = (16, 16, None)
center_crop = 0
ifcheck_volume = False
ifcheck_sclices = False  # (sic) name kept as-is in case other code refers to it
save_folder = f'./logs/test_{normalize}'
os.makedirs(save_folder, exist_ok=True)

# Define your transforms, file lists, dataloaders, etc...
# volume-level transforms for both image and label
train_transforms = get_transforms(normalize, pad, resized_size, div_size, mode='train')
train_ds, val_ds = get_file_list(dataset_path, train_number, val_number, source='mr', target='ct')
train_crop_ds, val_crop_ds = crop_volumes(train_ds, val_ds, center_crop)
untransformed_loader = DataLoader(train_crop_ds, batch_size=1)
train_ds, val_ds = load_volumes(train_transforms, train_crop_ds, val_crop_ds, train_ds, val_ds,
                                saved_name_train, saved_name_val,
                                ifsave=False, ifcheck=ifcheck_volume)
transformed_loader = DataLoader(train_ds, batch_size=1)

ct_data_list = []
mri_data_list = []
mean_list_ct = []
std_list_ct = []
mean_list_mri = []
std_list_mri = []
ct_shape_list = []
mri_shape_list = []
untransformed_CT_min_list = []
untransformed_CT_max_list = []
untransformed_MRI_min_list = []
untransformed_MRI_max_list = []

# calculate the mean and std of the original data
for idx, checkdata in enumerate(untransformed_loader):
    untransformed_CT = checkdata['target']
    untransformed_MRI = checkdata['source']

    ct_as_float = untransformed_CT.float()
    mri_as_float = untransformed_MRI.float()
    mean_list_ct.append(torch.mean(ct_as_float))
    std_list_ct.append(torch.std(ct_as_float))
    mean_list_mri.append(torch.mean(mri_as_float))
    std_list_mri.append(torch.std(mri_as_float))

    ct_shape_list.append(untransformed_CT.shape)
    mri_shape_list.append(untransformed_MRI.shape)

    untransformed_CT_min_list.append(torch.min(untransformed_CT))
    untransformed_CT_max_list.append(torch.max(untransformed_CT))
    untransformed_MRI_min_list.append(torch.min(untransformed_MRI))
    untransformed_MRI_max_list.append(torch.max(untransformed_MRI))

    # Keep the raw volumes so they can be compared against the round trip below.
    ct_data_list.append(untransformed_CT)
    mri_data_list.append(untransformed_MRI)

# Slices (inclusive range along the last axis) exported as PNGs for inspection.
first_slice, last_slice = 46, 50

# Process datasets
# NOTE: the loop variable must stay named ``idx`` because arrange_histograms
# reads it from module scope when no explicit index is passed.
for idx, checkdata in enumerate(transformed_loader):
    # Get your data (untransformed and transformed)
    transformed_CT = checkdata['target']
    transformed_MRI = checkdata['source']

    # Drop the batch axis added by the DataLoader before inverting.
    # (Named ``sample`` rather than ``dict`` to avoid shadowing the builtin.)
    sample = {'target': transformed_CT[0, :, :, :, :], "source": transformed_MRI[0, :, :, :, :]}
    # presumably allow_missing_keys_mode lets the inverse run even if the
    # pipeline references keys absent from ``sample`` - confirm in MONAI docs.
    with allow_missing_keys_mode(train_transforms):
        reversed_dict = train_transforms.inverse(sample)
    reversed_ct = reversed_dict['target']
    reversed_mri = reversed_dict["source"]

    print(f"{idx} original CT data shape:", ct_shape_list[idx])
    print(f"{idx} transformed CT data shape:", transformed_CT.shape)
    print(f"{idx} reversed CT shape:", reversed_ct.shape)
    print(f"{idx} original MRI data shape:", mri_shape_list[idx])
    print(f"{idx} transformed MRI data shape:", transformed_MRI.shape)
    # Previously this line repeated the transformed MRI shape (copy-paste slip).
    print(f"{idx} reversed MRI shape:", reversed_mri.shape)

    # Swap the first two axes, e.g. [452, 315, 104] -> [315, 452, 104]
    reversed_ct = reversed_ct.squeeze().permute(1, 0, 2)
    transformed_CT = transformed_CT.squeeze().permute(1, 0, 2)
    reversed_mri = reversed_mri.squeeze().permute(1, 0, 2)
    transformed_MRI = transformed_MRI.squeeze().permute(1, 0, 2)

    # Reverse the intensity normalization
    # NOTE(review): the reverse is applied to the output of
    # train_transforms.inverse. The earlier code fed ``transformed_*`` here,
    # which overwrote - and therefore discarded - the inverted volumes
    # computed above. Confirm this matches the intended round-trip check.
    norm_method = normalization_methods[normalize]
    reversed_ct = norm_method['reverse'](
        reversed_ct,
        mean=mean_list_ct[idx], std=std_list_ct[idx],
        min_val=untransformed_CT_min_list[idx], max_val=untransformed_CT_max_list[idx],
        mode=normalize)
    reversed_mri = norm_method['reverse'](
        reversed_mri,
        mean=mean_list_mri[idx], std=std_list_mri[idx],
        min_val=untransformed_MRI_min_list[idx], max_val=untransformed_MRI_max_list[idx],
        mode=normalize)

    ## compare the min and max value of the original and reversed data
    print(f"{idx} untransformed ct data min and max: {untransformed_CT_min_list[idx]}, {untransformed_CT_max_list[idx]}")
    print(f"{idx} transformed ct data min and max: {torch.min(transformed_CT)}, {torch.max(transformed_CT)}")
    print(f"{idx} reversed ct min and max: {torch.min(reversed_ct)}, {torch.max(reversed_ct)}")
    print(f"{idx} untransformed mri data min and max: {untransformed_MRI_min_list[idx]}, {untransformed_MRI_max_list[idx]}")
    print(f"{idx} transformed mri data min and max: {torch.min(transformed_MRI)}, {torch.max(transformed_MRI)}")
    print(f"{idx} reversed mri min and max: {torch.min(reversed_mri)}, {torch.max(reversed_mri)}")

    # Save images
    # Only export slices that exist in every volume, so a shallow volume
    # cannot trigger an IndexError.
    depth = min(
        volume.shape[-1]
        for volume in (ct_data_list[idx], transformed_CT, reversed_ct,
                       mri_data_list[idx], transformed_MRI, reversed_mri)
    )
    for i in range(first_slice, min(last_slice + 1, depth)):
        save_image(ct_data_list[idx][0, 0, :, :, i], os.path.join(save_folder, f"original_ct_{i}"), idx, "Original CT")
        save_image(transformed_CT[:, :, i], os.path.join(save_folder, f"transformed_ct_{i}"), idx, "Transformed CT")
        save_image(reversed_ct[:, :, i], os.path.join(save_folder, f"reversed_ct_{i}"), idx, "Reversed CT")
        # ... Repeat for MRI images
        save_image(mri_data_list[idx][0, 0, :, :, i], os.path.join(save_folder, f"original_mri_{i}"), idx, "Original MRI")
        save_image(transformed_MRI[:, :, i], os.path.join(save_folder, f"transformed_mri_{i}"), idx, "Transformed MRI")
        save_image(reversed_mri[:, :, i], os.path.join(save_folder, f"reversed_mri_{i}"), idx, "Reversed MRI")

    arrange_histograms(ct_data_list[idx].numpy(), transformed_CT.numpy(), reversed_ct.numpy(), mode='ct')
    arrange_histograms(mri_data_list[idx].numpy(), transformed_MRI.numpy(), reversed_mri.numpy(), mode='mri')