zy7_oldserver
1
fd601de
from mydataloader.basics import get_transforms, get_file_list, load_volumes, crop_volumes
from torch.utils.data import DataLoader
import torch
import os
from monai.transforms.utils import allow_missing_keys_mode
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
# Define function to save images
def save_image(image, filename, idx, title, dpi=300):
    """Render ``image`` in grayscale and write it to ``{filename}_{idx}.png``.

    Args:
        image: Image data accepted by ``plt.imshow``.
        filename: Output path without the index suffix and extension.
        idx: Value appended to ``filename`` after an underscore.
        title: Title drawn above the image.
        dpi: Resolution passed to ``plt.savefig``.
    """
    out_path = f"{filename}_{idx}.png"
    plt.figure()
    plt.imshow(image, cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.savefig(out_path, format='png', bbox_inches='tight', pad_inches=0, dpi=dpi)
    # Close the current figure so repeated calls do not accumulate open figures.
    plt.close()
# Define function to plot histograms
def plot_histogram(data, title, ax, color='blue', alpha=0.7, bins=256):
    """Draw an intensity histogram of ``data`` on the axes ``ax``.

    Args:
        data: Array-like object exposing ``flatten()``; all elements are
            pooled into a single histogram.
        title: Title set on ``ax``.
        ax: Axes-like object providing ``hist``, ``set_title``,
            ``set_xlabel`` and ``set_ylabel``.
        color: Bar color forwarded to ``ax.hist``.
        alpha: Bar opacity forwarded to ``ax.hist``.
        bins: Number of histogram bins (previously fixed at 256).
    """
    ax.hist(data.flatten(), bins=bins, color=color, alpha=alpha)
    ax.set_title(title)
    ax.set_xlabel('Pixel intensity')
    ax.set_ylabel('Frequency')
# Arrange three histograms
def arrange_histograms(original, transformed, reversed, mode='ct', *, out_dir=None, index=None):
    """Plot three stacked histograms and save them as one PNG.

    The figure is written to ``{out_dir}/{index}_histograms_{mode}.png``.

    Args:
        original: Data for the top (red) histogram.
        transformed: Data for the middle (green) histogram.
        reversed: Data for the bottom (blue) histogram.
        mode: Label used in the subplot titles and in the file name.
        out_dir: Destination folder. Defaults to the module-level
            ``save_folder``, which the function previously read implicitly.
        index: Prefix of the output file name. Defaults to the module-level
            ``idx``, which the function previously read implicitly.
    """
    # Keep the historical behaviour for callers that rely on the globals.
    if out_dir is None:
        out_dir = save_folder
    if index is None:
        index = idx

    fig, axs = plt.subplots(3, 1, figsize=(10, 8))
    panels = (
        ('original', original, 'red'),
        ('transformed', transformed, 'green'),
        ('reversed', reversed, 'blue'),
    )
    for ax, (label, data, color) in zip(axs, panels):
        plot_histogram(data, f'Histogram for {label} {mode}', ax, color=color)

    # Save the histogram figure and release it.
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"{index}_histograms_{mode}.png"), dpi=300)
    plt.close(fig)
# Define function to normalize data
def normalize_data(tensor, mean=None, std=None, min_val=None, max_val=None, mode='zscore'):
    """Normalize ``tensor`` according to ``mode``.

    Args:
        tensor: Value(s) to normalize; anything supporting arithmetic with
            the supplied statistics.
        mean: Mean used by ``'zscore'``.
        std: Standard deviation used by ``'zscore'``.
        min_val: Lower bound used by ``'minmax'``.
        max_val: Upper bound used by ``'minmax'``.
        mode: ``'zscore'`` computes ``(tensor - mean) / std``; ``'minmax'``
            rescales ``[min_val, max_val]`` linearly to ``[-1, 1]``; any other
            value (including ``'none'``) leaves the input untouched.

    Returns:
        The normalized value, or ``tensor`` unchanged when the mode does not
        apply or the statistics it needs are missing.
    """
    if mode == 'zscore':
        if mean is None or std is None:
            return tensor
        return (tensor - mean) / std
    if mode == 'minmax':
        if min_val is None or max_val is None:
            return tensor
        # Map to [-1, 1] so that reverse_normalize_data's
        # (x + 1) / 2 * (max - min) + min is the exact inverse.
        return 2 * (tensor - min_val) / (max_val - min_val) - 1
    return tensor
# Define function to reverse normalization
def reverse_normalize_data(tensor, mean=None, std=None, min_val=None, max_val=None, mode='zscore'):
    """Undo a normalization of ``tensor`` according to ``mode``.

    Args:
        tensor: Normalized value(s).
        mean: Mean used by ``'zscore'``.
        std: Standard deviation used by ``'zscore'``.
        min_val: Lower bound used by ``'minmax'``.
        max_val: Upper bound used by ``'minmax'``.
        mode: ``'zscore'`` computes ``tensor * std + mean``; ``'minmax'``
            maps ``[-1, 1]`` linearly back to ``[min_val, max_val]``; any
            other value leaves the input untouched.

    Returns:
        The de-normalized value, or ``tensor`` unchanged when the mode does
        not apply or the statistics it needs are missing.
    """
    has_moments = mean is not None and std is not None
    has_range = min_val is not None and max_val is not None

    if mode == 'zscore' and has_moments:
        return tensor * std + mean
    if mode == 'minmax' and has_range:
        return (tensor + 1) / 2 * (max_val - min_val) + min_val
    # 'none', unknown modes, and missing statistics all pass the input through.
    return tensor
# Normalization settings
def _identity(tensor, **_unused_stats):
    """Return ``tensor`` unchanged, ignoring any normalization keyword arguments."""
    return tensor


# Dispatch table keyed by the value of ``normalize``. Every 'apply' and
# 'reverse' entry accepts the same call shape:
#     fn(tensor, mean=..., std=..., min_val=..., max_val=..., mode=...)
# The identity entries previously were ``lambda x: x`` and raised TypeError
# when called with those keyword arguments.
normalization_methods = {
    'zscore': {'apply': normalize_data, 'reverse': reverse_normalize_data},
    'minmax': {'apply': normalize_data, 'reverse': reverse_normalize_data},
    'inputonly': {'apply': _identity, 'reverse': reverse_normalize_data},
    'none': {'apply': _identity, 'reverse': reverse_normalize_data},
}
# Other settings
# --- Dataset location and preprocessing options -----------------------------
dataset_path = r'D:\Projects\data\Task1\pelvis'
normalize = 'zscore'  # key into normalization_methods; also part of save_folder
pad = 'minimum'
center_crop = 0
resized_size = (512, 512, None)
div_size = (16, 16, None)

# --- Number of volumes and batch sizes --------------------------------------
train_number, val_number = 1, 1
train_batch_size, val_batch_size = 8, 1

# --- CSV file names handed to load_volumes ----------------------------------
saved_name_train, saved_name_val = './train_ds_2d.csv', './val_ds_2d.csv'

# --- Debug switches ---------------------------------------------------------
ifcheck_volume = False
ifcheck_sclices = False

# --- Output folder for the images and histograms written by this script -----
save_folder = f'./logs/test_{normalize}'
os.makedirs(save_folder, exist_ok=True)
# Define your transforms, file lists, dataloaders, etc...
# volume-level transforms for both image and label
# Transform pipeline built by the project helper for the 'train' mode.
train_transforms = get_transforms(normalize, pad, resized_size, div_size, mode='train')

# File lists for the requested number of training / validation cases.
train_ds, val_ds = get_file_list(
    dataset_path,
    train_number,
    val_number,
    source='mr',
    target='ct',
)

# Cropped datasets; the training one is iterated as-is to obtain the
# untransformed reference volumes.
train_crop_ds, val_crop_ds = crop_volumes(train_ds, val_ds, center_crop)
untransformed_loader = DataLoader(train_crop_ds, batch_size=1)

# NOTE: train_ds / val_ds are rebound here from the get_file_list results to
# whatever load_volumes returns; the old values are passed in as arguments.
train_ds, val_ds = load_volumes(
    train_transforms,
    train_crop_ds,
    val_crop_ds,
    train_ds,
    val_ds,
    saved_name_train,
    saved_name_val,
    ifsave=False,
    ifcheck=ifcheck_volume,
)
transformed_loader = DataLoader(train_ds, batch_size=1)
# Per-volume records of the untransformed data. Entry k of every list belongs
# to the k-th batch yielded by untransformed_loader.
ct_data_list, mri_data_list = [], []
mean_list_ct, std_list_ct = [], []
mean_list_mri, std_list_mri = [], []
ct_shape_list, mri_shape_list = [], []
untransformed_CT_min_list, untransformed_CT_max_list = [], []
untransformed_MRI_min_list, untransformed_MRI_max_list = [], []

# calculate the mean and std of the original data
for idx, checkdata in enumerate(untransformed_loader):
    untransformed_CT = checkdata['target']
    untransformed_MRI = checkdata['source']

    # Mean / std are computed on a float copy; min / max use the raw tensor.
    ct_as_float = untransformed_CT.float()
    mean_ct = torch.mean(ct_as_float)
    std_ct = torch.std(ct_as_float)
    mri_as_float = untransformed_MRI.float()
    mean_mri = torch.mean(mri_as_float)
    std_mri = torch.std(mri_as_float)

    # CT records
    ct_data_list.append(untransformed_CT)
    ct_shape_list.append(untransformed_CT.shape)
    mean_list_ct.append(mean_ct)
    std_list_ct.append(std_ct)
    untransformed_CT_min_list.append(torch.min(untransformed_CT))
    untransformed_CT_max_list.append(torch.max(untransformed_CT))

    # MRI records
    mri_data_list.append(untransformed_MRI)
    mri_shape_list.append(untransformed_MRI.shape)
    mean_list_mri.append(mean_mri)
    std_list_mri.append(std_mri)
    untransformed_MRI_min_list.append(torch.min(untransformed_MRI))
    untransformed_MRI_max_list.append(torch.max(untransformed_MRI))
# Process datasets
# Inclusive range of last-axis slice indices exported as PNGs.
first_saved_slice, last_saved_slice = 46, 50

for idx, checkdata in enumerate(transformed_loader):
    # Transformed batch (the loader uses batch_size=1).
    transformed_CT = checkdata['target']
    transformed_MRI = checkdata['source']

    # Index away the batch dimension and ask the pipeline for its inverse.
    # (The mapping was previously named ``dict``, shadowing the builtin.)
    sample = {'target': transformed_CT[0, :, :, :, :], "source": transformed_MRI[0, :, :, :, :]}
    with allow_missing_keys_mode(train_transforms):
        reversed_dict = train_transforms.inverse(sample)
    reversed_ct = reversed_dict['target']
    reversed_mri = reversed_dict["source"]

    print(f"{idx} original CT data shape:", ct_shape_list[idx])
    print(f"{idx} transformed CT data shape:", transformed_CT.shape)
    print (f"{idx} reversed CT shape:", reversed_ct.shape)
    print(f"{idx} original MRI data shape:", mri_shape_list[idx])
    print(f"{idx} transformed MRI data shape:", transformed_MRI.shape)
    # Was a copy-paste duplicate of the line above; report the reversed MRI.
    print(f"{idx} reversed MRI shape:", reversed_mri.shape)

    # Drop singleton dimensions and swap the first two axes,
    # e.g. [452, 315, 104] -> [315, 452, 104].
    reversed_ct = reversed_ct.squeeze().permute(1, 0, 2)
    transformed_CT = transformed_CT.squeeze().permute(1, 0, 2)
    reversed_mri = reversed_mri.squeeze().permute(1, 0, 2)
    transformed_MRI = transformed_MRI.squeeze().permute(1, 0, 2)

    # Reverse the intensity normalization using the statistics collected from
    # the untransformed volumes.
    # NOTE(review): these assignments replace the inverse-transform results
    # obtained above with the de-normalized *transformed* volumes, so
    # everything saved as "reversed" below has the transformed geometry.
    # Kept as in the original; confirm whether reverse normalization should
    # be applied to the inverse-transformed tensors instead.
    norm_method = normalization_methods[normalize]
    reversed_ct = norm_method['reverse'](
        transformed_CT,
        mean=mean_list_ct[idx], std=std_list_ct[idx],
        min_val=untransformed_CT_min_list[idx], max_val=untransformed_CT_max_list[idx],
        mode=normalize,
    )
    reversed_mri = norm_method['reverse'](
        transformed_MRI,
        mean=mean_list_mri[idx], std=std_list_mri[idx],
        min_val=untransformed_MRI_min_list[idx], max_val=untransformed_MRI_max_list[idx],
        mode=normalize,
    )

    ## compare the min and max value of the original and reversed data
    print(f"{idx} untransformed ct data min and max: {untransformed_CT_min_list[idx]}, {untransformed_CT_max_list[idx]}")
    print(f"{idx} transformed ct data min and max: {torch.min(transformed_CT)}, {torch.max(transformed_CT)}")
    print(f"{idx} reversed ct min and max: {torch.min(reversed_ct)}, {torch.max(reversed_ct)}")
    print(f"{idx} untransformed mri data min and max: {untransformed_MRI_min_list[idx]}, {untransformed_MRI_max_list[idx]}")
    print(f"{idx} transformed mri data min and max: {torch.min(transformed_MRI)}, {torch.max(transformed_MRI)}")
    print(f"{idx} reversed mri min and max: {torch.min(reversed_mri)}, {torch.max(reversed_mri)}")

    # Save images for the selected slices only. Iterating the wanted range
    # directly replaces scanning every slice and filtering with an ``if``.
    depth = reversed_ct.shape[-1]
    for i in range(first_saved_slice, min(last_saved_slice + 1, depth)):
        save_image(ct_data_list[idx][0, 0, :, :, i], os.path.join(save_folder, f"original_ct_{i}"), idx, "Original CT")
        save_image(transformed_CT[:, :, i], os.path.join(save_folder, f"transformed_ct_{i}"), idx, "Transformed CT")
        save_image(reversed_ct[:, :, i], os.path.join(save_folder, f"reversed_ct_{i}"), idx, "Reversed CT")
        # Same three views for the MRI volume.
        save_image(mri_data_list[idx][0, 0, :, :, i], os.path.join(save_folder, f"original_mri_{i}"), idx, "Original MRI")
        save_image(transformed_MRI[:, :, i], os.path.join(save_folder, f"transformed_mri_{i}"), idx, "Transformed MRI")
        save_image(reversed_mri[:, :, i], os.path.join(save_folder, f"reversed_mri_{i}"), idx, "Reversed MRI")

    # One histogram figure per modality and volume; arrange_histograms reads
    # the current ``idx`` and ``save_folder`` from module scope.
    arrange_histograms(ct_data_list[idx].numpy(), transformed_CT.numpy(), reversed_ct.numpy(), mode='ct')
    arrange_histograms(mri_data_list[idx].numpy(), transformed_MRI.numpy(), reversed_mri.numpy(), mode='mri')