zy7_oldserver
1
fd601de
import nibabel as nib
from monai.transforms import (
Compose,
EnsureChannelFirst,
Rotate90,
ResizeWithPadOrCrop,
)
from monai.transforms import SaveImage
import numpy as np
import os
import torch
# save validation images
'''nib.save(
nib.Nifti1Image(val_outputs.astype(np.uint8), original_affine), os.path.join(output_directory, img_name)
)'''
## some functions for GAN training
# output_train_log: to save training loss log to a text file every epoch
# output_val_log: to save validation metrics to a text file every epoch
import monai
from torch import nn
from torchmetrics import MeanAbsoluteError
from torchmetrics.image import StructuralSimilarityIndexMeasure,PeakSignalNoiseRatio
import numpy as np
import os
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from monai.transforms.utils import allow_missing_keys_mode
# group consecutive samples by target shape (presumably one group per patient)
def group_labels(test_labels):
    """Split a sequence of samples into runs of consecutive equal-shaped targets.

    Consecutive items whose ``item['target'].shape`` compare equal are placed
    in the same group; a new group starts whenever the shape changes. Two
    adjacent runs that happen to share a shape are therefore merged into one
    group.

    Args:
        test_labels: iterable of dict-like items, each with a ``'target'``
            entry exposing ``.shape`` (e.g. a tensor or ndarray).

    Returns:
        ``(labels_groups, size_of_labels)``: ``labels_groups`` is a list of
        lists holding the original items, and ``size_of_labels[k]`` is the
        shape shared by every item in ``labels_groups[k]``. Both lists are
        empty when ``test_labels`` is empty.
    """
    from itertools import groupby

    labels_groups = []
    size_of_labels = []
    # groupby only merges *consecutive* equal keys, exactly like the
    # "compare with the current group's shape" logic this replaces.
    for size, run in groupby(test_labels, key=lambda label: label['target'].shape):
        labels_groups.append(list(run))
        size_of_labels.append(size)
    return labels_groups, size_of_labels
# split val_outputs into per-patient volumes and save each one via MONAI SaveImage
def write_nifti(val_outputs, output_dir=r'.\logs', filename='val'):
    """Stack grouped per-slice outputs into volumes and save each with SaveImage.

    ``val_outputs`` is split by ``group_labels`` into runs of consecutive items
    whose ``'target'`` tensors share a shape (presumably one run per patient).
    Each run's ``'target'`` tensors are stacked along a new dimension 3 and
    written with ``output_postfix=f'{filename}_{k}'``, where ``k`` is the run
    index. The stacked shape is printed for every run.

    Args:
        val_outputs: sequence of dicts holding a ``'target'`` tensor.
        output_dir: directory passed to ``SaveImage``. The default is the
            Windows-style relative "logs" folder; it is a raw string because
            the previous non-raw literal relied on an invalid escape sequence.
        filename: prefix for the SaveImage output postfix.
    """
    labels_groups, _ = group_labels(val_outputs)
    for patient_idx, val_output in enumerate(labels_groups):
        # unsqueeze(3) adds a new axis at position 3; cat stacks the slices along it
        concatenated_outputs = torch.cat([label['target'].unsqueeze(3) for label in val_output], dim=3)
        print(concatenated_outputs.shape)
        saver = SaveImage(output_dir=output_dir, output_postfix=f'{filename}_{patient_idx}', resample=True)
        saver(concatenated_outputs.detach().cpu())
def write_nifti_volume(val_outputs, output_dir=r'.\logs', filename='val'):
    """Save one tensor with MONAI ``SaveImage`` (resampling enabled).

    Args:
        val_outputs: tensor to save (presumably a MONAI MetaTensor carrying
            its metadata); it is detached and moved to CPU first.
        output_dir: directory passed to ``SaveImage``. The default is the
            Windows-style relative "logs" folder; it is a raw string because
            the previous non-raw literal relied on an invalid escape sequence.
        filename: used as the SaveImage output postfix.
    """
    saver = SaveImage(output_dir=output_dir, output_postfix=f'{filename}', resample=True)
    saver(val_outputs.detach().cpu())
def reverse_transforms(output_images, orig_images, transforms):
    """Invert ``transforms`` on batch element 0 of ``output_images``.

    The transform history (``applied_operations``) of ``orig_images`` is
    assigned onto ``output_images`` (mutating the caller's object) so that
    ``transforms.inverse`` knows which operations to undo. Only batch index 0
    is used, which assumes a validation batch size of 1; the
    ``[0, :, :, :, :]`` indexing requires a 5-D tensor.

    Returns:
        The ``"target"`` entry of the inverted dictionary.
    """
    output_images.applied_operations = orig_images.applied_operations
    first_item = output_images[0, :, :, :, :]  # always set val_batch_size=1
    with allow_missing_keys_mode(transforms):
        inverted = transforms.inverse({"target": first_item})
    return inverted["target"]
def calculate_ssim(pred, target):
    """Return torchmetrics SSIM (default settings) between ``pred`` and ``target``.

    A fresh metric object is created on ``pred``'s device for every call.
    """
    metric = StructuralSimilarityIndexMeasure()
    metric = metric.to(pred.device)
    return metric(pred, target)
def calculate_mae(pred, target):
    """Return torchmetrics mean absolute error between ``pred`` and ``target``.

    A fresh metric object is created on ``pred``'s device for every call.
    """
    metric = MeanAbsoluteError()
    metric = metric.to(pred.device)
    return metric(pred, target)
def calculate_psnr(pred, target):
    """Return torchmetrics PSNR (default settings) between ``pred`` and ``target``.

    A fresh metric object is created on ``pred``'s device for every call.
    """
    metric = PeakSignalNoiseRatio()
    metric = metric.to(pred.device)
    return metric(pred, target)
def val_log(epoch, step, gen_image, orig_image, saved_path):
    """Compute SSIM/MAE/PSNR of ``gen_image`` vs ``orig_image``, print and log them.

    The metrics are appended to ``<saved_path>/infer_log.txt`` through
    ``output_val_log``.

    Returns:
        ``(metrics, log_path)`` where ``metrics`` has keys 'ssim', 'mae', 'psnr'.
    """
    ssim_value = calculate_ssim(gen_image, orig_image)
    mae_value = calculate_mae(gen_image, orig_image)
    psnr_value = calculate_psnr(gen_image, orig_image)
    print(f"val_ssim: {ssim_value}, val_mae: {mae_value}, val_psnr: {psnr_value}.")
    metrics = {'ssim': ssim_value, 'mae': mae_value, 'psnr': psnr_value}
    log_path = os.path.join(saved_path, "infer_log.txt")
    output_val_log(epoch, step, log_path, metrics)
    return metrics, log_path
def output_val_log(epoch, val_step, val_log_file=r'.\logs\val_log.txt', val_metrices=None):
    """Append one line of validation metrics to a text log.

    Args:
        epoch: epoch number written to the line.
        val_step: validation-set/step identifier written to the line.
        val_log_file: path of the log file; opened in append mode.
        val_metrices: mapping that may hold 'ssim', 'mae' and 'psnr'. Missing
            keys, or ``None`` for the whole mapping, are logged as 0.
    """
    # None replaces the former shared mutable-dict default; the logged
    # output for omitted metrics (0) is unchanged.
    metrics = val_metrices if val_metrices is not None else {}
    ssim = metrics.get('ssim', 0)
    mae = metrics.get('mae', 0)
    psnr = metrics.get('psnr', 0)
    with open(val_log_file, 'a') as f:  # append mode
        f.write(f'epoch {epoch}, val set {val_step}, SSIM: {ssim}, MAE: {mae}, PSNR: {psnr}\n')
def calculate_reverse_info(untransformed_loader):
    """Collect per-item intensity statistics of the untransformed data.

    For every item yielded by ``untransformed_loader`` (dict-like, with the
    CT tensor under 'target' and the MRI tensor under 'source'), this records
    the mean and std (computed on a float32 copy), the shape, min, max and the
    raw tensor itself.

    Returns:
        dict of lists (one entry per loader item) keyed by "CT_mean", "CT_std",
        "MRI_mean", "MRI_std", "CT_shape", "MRI_shape", "CT_min", "CT_max",
        "MRI_min", "MRI_max", "CT_data", "MRI_data".
    """
    stats = {key: [] for key in (
        "CT_mean", "CT_std", "MRI_mean", "MRI_std",
        "CT_shape", "MRI_shape",
        "CT_min", "CT_max", "MRI_min", "MRI_max",
        "CT_data", "MRI_data",
    )}
    for batch in untransformed_loader:
        ct, mri = batch['target'], batch['source']
        ct_float, mri_float = ct.float(), mri.float()
        stats["CT_mean"].append(torch.mean(ct_float))
        stats["CT_std"].append(torch.std(ct_float))
        stats["MRI_mean"].append(torch.mean(mri_float))
        stats["MRI_std"].append(torch.std(mri_float))
        stats["CT_shape"].append(ct.shape)
        stats["MRI_shape"].append(mri.shape)
        stats["CT_min"].append(torch.min(ct))
        stats["CT_max"].append(torch.max(ct))
        stats["MRI_min"].append(torch.min(mri))
        stats["MRI_max"].append(torch.max(mri))
        stats["CT_data"].append(ct)
        stats["MRI_data"].append(mri)
    return stats
# Define function to reverse normalization
def reverse_normalize_data(tensor,
                           mean=None,
                           std=None,
                           min_val=None,
                           max_val=None,
                           mode='zscore'):
    """Map normalized values back to the original intensity scale.

    Modes:
        'zscore': ``tensor * std + mean`` (requires both ``mean`` and ``std``).
        'minmax': maps [-1, 1] back to [min_val, max_val] via
            ``(tensor + 1) / 2 * (max_val - min_val) + min_val``
            (requires both bounds).
        anything else (e.g. 'inputonlyminmax', 'inputonlyzscore', 'none'):
            ``tensor`` is returned unchanged.

    When the statistics a mode needs are missing, ``tensor`` is returned as is.
    """
    if mode == 'zscore':
        if mean is None or std is None:
            return tensor
        return tensor * std + mean
    if mode == 'minmax':
        if min_val is None or max_val is None:
            return tensor
        return (tensor + 1) / 2 * (max_val - min_val) + min_val
    # 'inputonlyminmax', 'inputonlyzscore', 'none' and unknown modes: identity
    return tensor
# Define function to normalize data (counterpart of reverse_normalize_data)
def normalize_data(tensor, mean=None, std=None, min_val=None, max_val=None, mode='zscore'):
    """Normalize ``tensor``; the counterpart of ``reverse_normalize_data``.

    Modes:
        'zscore': ``(tensor - mean) / std`` (requires both ``mean`` and ``std``).
        'minmax': linearly maps [min_val, max_val] to [-1, 1]
            (requires both bounds).
        anything else: ``tensor`` is returned unchanged.

    When the statistics a mode needs are missing, ``tensor`` is returned as is.
    """
    if mode == 'zscore':
        if mean is None or std is None:
            return tensor
        return (tensor - mean) / std
    if mode == 'minmax':  # for minmax to -1 and 1
        if min_val is None or max_val is None:
            return tensor
        # Scale to [0, 1], then stretch to [-1, 1] so that
        # reverse_normalize_data(mode='minmax') is the exact inverse; stopping
        # at [0, 1] contradicted the documented range and broke the round trip.
        return (tensor - min_val) / (max_val - min_val) * 2 - 1
    return tensor
def save_val_images(val_outputs, val_slice_num, val_names, epoch, saved_img_folder):
    """Split stacked validation slices per patient and save each patient's volume.

    ``val_outputs`` holds the slices of all patients concatenated along dim 0,
    and ``val_slice_num[i]`` is the number of slices belonging to patient ``i``
    (e.g. ``[200, 200, 150, 230]``). Each patient's chunk is passed to
    ``write_nifti`` under the name ``pred_{val_names[i]}_epoch_{epoch+1}``.
    If the slice counts do not add up to ``val_outputs.shape[0]``, nothing is
    saved and a diagnostic message is printed instead.

    Args:
        val_outputs: tensor whose dim 0 enumerates slices of all patients.
        val_slice_num: per-patient slice counts, in the order of dim 0.
        val_names: per-patient names used in the output file names.
        epoch: zero-based epoch (written as ``epoch + 1``).
        saved_img_folder: output directory forwarded to ``write_nifti``.
    """
    total_slices = sum(val_slice_num)
    if val_outputs.shape[0] != total_slices:
        print(val_outputs.shape[0])
        print(total_slices)
        print("something wrong with validation set, please check")
        return

    # Consecutive runs along dim 0 belong to the same patient.
    val_data_list = []
    start = 0
    for count in val_slice_num:
        val_data_list.append(val_outputs[start:start + count])
        start += count

    for i, patient_slices in enumerate(val_data_list):
        file_name = f'pred_{val_names[i]}_epoch_{epoch+1}'
        # write_nifti (via group_labels) expects a sequence of {'target': slice}
        # dicts, not a raw tensor; passing the tensor directly made it index a
        # tensor with the string 'target' and crash.
        write_nifti([{'target': s} for s in patient_slices], saved_img_folder, file_name)
def compare_imgs(input_imgs, target_imgs, fake_imgs,
                 saved_name,
                 imgformat='jpg',
                 dpi=500,
                 model_name='DDPM',):
    """Save an MRI / CT / model-output comparison panel plus individual images.

    Each input tensor is squeezed, moved to CPU, clipped to [0, 1] and
    converted to an 8-bit grayscale PIL image. A 1x3 panel is written to
    ``saved_name``; single images are written next to it with ``_mri``,
    ``_ct`` and ``_fake`` inserted before the ``.{imgformat}`` extension.

    Args:
        input_imgs, target_imgs, fake_imgs: tensors that squeeze to 2-D
            images, presumably with intensities in [0, 1].
        saved_name: path of the panel image.
        imgformat: format passed to ``savefig`` and used to derive file names.
        dpi: output resolution.
        model_name: title of the third panel.
    """
    from PIL import Image

    def _to_uint8_image(img):
        # Clip first: casting out-of-range floats to uint8 wraps around (or is
        # undefined), producing garbage pixels instead of saturated ones.
        arr = np.clip(img.squeeze().cpu().numpy(), 0, 1)
        return Image.fromarray((arr * 255).astype(np.uint8))

    panels = [_to_uint8_image(input_imgs), _to_uint8_image(target_imgs), _to_uint8_image(fake_imgs)]
    titles = ['MRI', 'CT', model_name]

    fig, axs = plt.subplots(1, 3, figsize=(12, 5))
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0.1)
    plt.margins(0, 0)
    for ax, panel, title in zip(axs, panels, titles):
        ax.imshow(panel, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(saved_name, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    plt.close(fig)

    # Individual images, using the same borderless layout as save_single_image.
    for panel, suffix in zip(panels, ('_mri', '_ct', '_fake')):
        single_name = saved_name.replace(f'.{imgformat}', f'{suffix}.{imgformat}')
        save_single_image(panel, single_name, imgformat, dpi=dpi)
# Define function to save images
def save_single_image(input_imgs, filename, imgformat, dpi=300):
    """Save one grayscale image to ``filename`` with axes hidden and no padding.

    Args:
        input_imgs: image data accepted by ``plt.imshow``.
        filename: output path.
        imgformat: format string forwarded to ``savefig`` (e.g. 'jpg').
        dpi: output resolution.
    """
    fig = plt.figure()
    ax = plt.gca()
    ax.set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.imshow(input_imgs, cmap='gray')
    plt.savefig(filename, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    plt.close(fig)
def arrange_images(input_imgs,
                   label_imgs,
                   fake_imgs,
                   model_name,
                   saved_name,
                   imgformat='jpg',
                   dpi=500):
    """Save a 1x3 grayscale panel (MRI | CT | ``model_name``) to ``saved_name``.

    Args:
        input_imgs, label_imgs, fake_imgs: image data accepted by ``imshow``.
        model_name: title of the third panel.
        saved_name: output path.
        imgformat: format passed to ``savefig``.
        dpi: output resolution.
    """
    fig, axs = plt.subplots(1, 3, figsize=(12, 5))
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0.1)
    plt.margins(0, 0)
    panels = ((input_imgs, 'MRI'), (label_imgs, 'CT'), (fake_imgs, model_name))
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(saved_name, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    plt.close(fig)
# Define function to plot histograms
def plot_histogram(data, title, ax, color='blue', alpha=0.7,
                   x_lower_limit=-1, x_upper_limit=3, y_lower_limit=0, y_upper_limit=15000):
    """Draw a 256-bin intensity histogram of ``data`` on the axes ``ax``.

    Values are binned over [x_lower_limit, x_upper_limit] and the y-axis is
    fixed to [y_lower_limit, y_upper_limit].

    Args:
        data: array-like with a ``flatten()`` method (ndarray or tensor).
        title: axes title.
        ax: matplotlib axes to draw on.
        color, alpha: bar style.
    """
    flat_values = data.flatten()
    value_range = (x_lower_limit, x_upper_limit)
    ax.hist(flat_values, bins=256, range=value_range, color=color, alpha=alpha)
    ax.set_ylim([y_lower_limit, y_upper_limit])
    ax.set_title(title)
    ax.set_xlabel('Pixel intensity')
    ax.set_ylabel('Frequency')
# Arrange two histograms (original vs. reversed)
def arrange_histograms(original, reversed, saved_name, titles=('original', 'reversed'), dpi=300, mode='ct'):
    """Save two stacked histograms, e.g. original vs. reverse-normalized data.

    Args:
        original: array-like plotted in the top panel (red).
        reversed: array-like plotted in the bottom panel (green). The name
            shadows the builtin but is kept for keyword-argument compatibility.
        saved_name: output image path.
        titles: two labels used in the panel titles (a tuple default avoids
            the shared-mutable-default pitfall; any indexable pair works).
        dpi: output resolution.
        mode: unused; kept for backward compatibility.
    """
    fig, axs = plt.subplots(2, 1, figsize=(10, 8))
    plot_histogram(original, f'Histogram for {titles[0]}', axs[0], color='red')
    plot_histogram(reversed, f'Histogram for {titles[1]}', axs[1], color='green')
    plt.tight_layout()
    plt.savefig(saved_name, dpi=dpi)
    plt.close(fig)
# Arrange three histograms
def arrange_3_histograms(source, target, output, saved_name, dpi=300,
                         x_lower_limit=-1, x_upper_limit=3, y_lower_limit=0, y_upper_limit=15000):
    """Save stacked source/target/output histograms plus a companion box plot.

    The three histograms (shared axis limits) are written to ``saved_name``.
    The box plot is written to ``saved_name`` with 'histogram' replaced by
    'boxplot'; if ``saved_name`` does not contain 'histogram', '_boxplot' is
    inserted before the extension instead so the histogram is not overwritten.

    Args:
        source, target, output: array-likes with a ``flatten()`` method.
        saved_name: output path (str) of the histogram figure.
        dpi: output resolution of both figures.
        x_lower_limit, x_upper_limit, y_lower_limit, y_upper_limit: histogram
            axis limits forwarded to ``plot_histogram``.
    """
    limits = dict(x_lower_limit=x_lower_limit, x_upper_limit=x_upper_limit,
                  y_lower_limit=y_lower_limit, y_upper_limit=y_upper_limit)
    fig, axs = plt.subplots(3, 1, figsize=(10, 8))
    plot_histogram(source, 'Histogram for source', axs[0], color='red', **limits)
    plot_histogram(target, 'Histogram for target', axs[1], color='green', **limits)
    plot_histogram(output, 'Histogram for output', axs[2], color='blue', **limits)
    plt.tight_layout()
    plt.savefig(saved_name, dpi=dpi)
    plt.close(fig)

    boxplot_name = saved_name.replace('histogram', 'boxplot')
    if boxplot_name == saved_name:
        # Without this the box plot would silently overwrite the histogram image.
        root, ext = os.path.splitext(saved_name)
        boxplot_name = f'{root}_boxplot{ext}'

    box_fig = plt.figure()
    plt.boxplot([source.flatten(), target.flatten(), output.flatten()])
    plt.xticks([1, 2, 3], ['Source', 'Target', 'Fake'])
    plt.title('Pixel Value Distribution')
    plt.xlabel('Image Type')
    plt.ylabel('Pixel Values')
    plt.tight_layout()
    plt.savefig(boxplot_name, dpi=dpi)
    plt.close(box_fig)
def arrange_4_histograms(real1, fake1, real2, fake2, saved_name, dpi=300):
    """Save four stacked histograms: real1/fake1 in red, real2/fake2 in green.

    Args:
        real1, fake1, real2, fake2: array-likes with a ``flatten()`` method.
        saved_name: output image path.
        dpi: output resolution.
    """
    fig, axs = plt.subplots(4, 1, figsize=(10, 8))
    panels = [(real1, 'real1', 'red'), (fake1, 'fake1', 'red'),
              (real2, 'real2', 'green'), (fake2, 'fake2', 'green')]
    for ax, (values, label, color) in zip(axs, panels):
        plot_histogram(values, f'Histogram for {label}', ax, color=color)
    plt.tight_layout()
    plt.savefig(saved_name, dpi=dpi)
    plt.close(fig)
# save output images
def sample_images(model, input, label, slice_idx, epoch, batch_i, saved_folder, model_name='model'):
    """Run ``model`` on ``input`` and save MRI / CT / prediction previews as JPEGs.

    Channel 0 of batch element ``slice_idx`` is shown for each tensor; the
    ``[slice_idx, 0, :, :]`` indexing requires 4-D tensors (presumably
    batch, channel, H, W). Files written into ``saved_folder`` (created if
    missing): ``{epoch}_{batch_i}.jpg`` with a 1x3 panel, plus
    ``{epoch}_{batch_i}_mri.jpg``, ``_ct.jpg`` and ``_fake.jpg`` single images.

    Args:
        model: callable applied to ``input``; its output is indexed like it.
        input: source (MRI) tensor.
        label: target (CT) tensor.
        slice_idx: batch index to visualize.
        epoch, batch_i: used to build the file names.
        saved_folder: output directory.
        model_name: title of the standalone prediction image.
    """
    # Preview only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        fake = model(input)
    images = [t.detach().cpu().numpy()[slice_idx, 0, :, :].squeeze() for t in (input, label, fake)]

    os.makedirs(saved_folder, exist_ok=True)
    stem = os.path.join(saved_folder, f"{epoch}_{batch_i}")

    fig, axs = plt.subplots(1, 3, figsize=(20, 4))
    for ax, image, title in zip(axs, images, ('MRI', 'CT', 'Translated')):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(f"{stem}.jpg")
    plt.close(fig)

    # Standalone copy of each image.
    for image, title, suffix in zip(images, ('MRI', 'CT', model_name), ('_mri', '_ct', '_fake')):
        single_fig, single_ax = plt.subplots(1, 1)
        single_ax.imshow(image, cmap='gray')
        single_ax.set_title(title)
        single_ax.axis('off')
        single_fig.savefig(f"{stem}{suffix}.jpg")
        plt.close(single_fig)
def save_images(input_imgs, label_imgs, fake_imgs,
                slice_idx,
                saved_name='./test.jpg',
                imgformat='jpg',
                dpi=1000,
                model_name='model'):
    """Save slice ``slice_idx`` (indexed along axis 2) of MRI / CT / prediction volumes.

    A 1x3 comparison panel is written to ``saved_name``, and single images are
    written with ``_mri``, ``_ct`` and ``_fake`` inserted before the
    ``.{imgformat}`` extension. Every slice is squeezed, so singleton
    dimensions are dropped consistently for all three volumes.

    Args:
        input_imgs, label_imgs, fake_imgs: arrays indexable as ``[:, :, i]``.
        slice_idx: index along axis 2 to display.
        saved_name: path of the panel image.
        imgformat: format passed to ``savefig`` and used to derive file names.
        dpi: output resolution.
        model_name: title of the third panel.
    """
    # The CT panel previously skipped .squeeze(), unlike the other panels and
    # the standalone CT image, so volumes with singleton dims failed only there.
    slices = [volume[:, :, slice_idx].squeeze() for volume in (input_imgs, label_imgs, fake_imgs)]
    titles = ['MRI', 'CT', model_name]

    fig, axs = plt.subplots(1, 3, figsize=(12, 5))
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0.1)
    plt.margins(0, 0)
    for ax, image, title in zip(axs, slices, titles):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(saved_name, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    plt.close(fig)

    # Individual images, using the same borderless layout as save_single_image.
    for image, suffix in zip(slices, ('_mri', '_ct', '_fake')):
        single_name = saved_name.replace(f'.{imgformat}', f'{suffix}.{imgformat}')
        save_single_image(image, single_name, imgformat, dpi=dpi)
# save output images
def sample_images2(model, input, label, slice_idx, epoch, batch_i, saved_folder):
    """Run ``model`` on ``input`` and save a 1x3 MRI / CT / prediction panel.

    Channel 0 of batch element ``slice_idx`` is shown for each tensor (the
    indexing requires 4-D tensors, presumably batch, channel, H, W). The panel
    is written to ``saved_folder/{epoch}_{batch_i}.jpg``; the folder is
    created if missing.
    """
    os.makedirs(saved_folder, exist_ok=True)
    saved_name = f"{epoch}_{batch_i}.jpg"
    # Preview only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        fake = model(input)
    images = [t.detach().cpu().numpy()[slice_idx, 0, :, :].squeeze() for t in (input, label, fake)]

    fig, axs = plt.subplots(1, 3, figsize=(20, 4))
    for ax, image, title in zip(axs, images, ('MRI', 'CT', 'Translated')):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(os.path.join(saved_folder, saved_name))
    plt.close(fig)
def sample_images_3D(model, input, label, epoch, batch_i, saved_folder):
    """Run ``model`` on a 3-D batch and save one MRI / CT / prediction slice panel.

    Channel 0 of batch element 0 is shown at index 50 of the last axis,
    falling back to index 10 when any volume is too short for index 50 (the
    ``[0, 0, :, :, k]`` indexing requires 5-D tensors). The panel is written
    to ``saved_folder/{epoch}_{batch_i}.jpg``; the folder is created if missing.
    """
    # Preview only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        fake = model(input)
    volumes = [t.detach().cpu().numpy() for t in (input, label, fake)]
    try:
        images = [v[0, 0, :, :, 50].squeeze() for v in volumes]
    except IndexError:
        # A volume has 50 or fewer entries along the last axis. Only IndexError
        # is caught: a bare except also hid unrelated errors and Ctrl-C.
        images = [v[0, 0, :, :, 10].squeeze() for v in volumes]

    fig, axs = plt.subplots(1, 3, figsize=(20, 4))
    for ax, image, title in zip(axs, images, ('MRI', 'CT', 'Translated')):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    os.makedirs(saved_folder, exist_ok=True)
    saved_name = f"{epoch}_{batch_i}.jpg"
    fig.savefig(os.path.join(saved_folder, saved_name))
    plt.close(fig)