import os

import matplotlib

matplotlib.use('Agg')  # select the Agg backend before pyplot is imported
import matplotlib.pyplot as plt
import monai
import nibabel as nib
import numpy as np
import torch
from monai.transforms import (
    Compose,
    EnsureChannelFirst,
    ResizeWithPadOrCrop,
    Rotate90,
    SaveImage,
)
from monai.transforms.utils import allow_missing_keys_mode
from torch import nn
from torchmetrics import MeanAbsoluteError
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

# Legacy snippet kept for reference: saving validation images directly with nibabel.
# nib.save(
#     nib.Nifti1Image(val_outputs.astype(np.uint8), original_affine),
#     os.path.join(output_directory, img_name)
# )

## some functions for GAN training
# output_train_log: to save training loss log to a text file every epoch
# output_val_log: to save validation metrics to a text file every epoch


def group_labels(test_labels):
    """Split a sequence of label dicts into runs of consecutive equal shape.

    Consecutive entries whose ``label['target'].shape`` are equal end up in
    the same group; a change of shape starts a new group.

    Args:
        test_labels: Iterable of mappings, each holding a ``'target'`` entry
            that exposes ``.shape``.

    Returns:
        A tuple ``(labels_groups, size_of_labels)`` where ``labels_groups`` is
        a list of lists of the original entries and ``size_of_labels`` holds
        the shape shared by each group (same order, same length). Both lists
        are empty when ``test_labels`` is empty.
    """
    labels_groups = []
    size_of_labels = []
    for label in test_labels:
        size = label['target'].shape
        if size_of_labels and size == size_of_labels[-1]:
            # Same shape as the previous entry: extend the current run.
            labels_groups[-1].append(label)
        else:
            # Shape changed (or first entry): open a new run.
            size_of_labels.append(size)
            labels_groups.append([label])
    return labels_groups, size_of_labels


def write_nifti(val_outputs, output_dir=r'.\logs', filename='val'):
    """Save each same-shape run of slices as one stacked image via SaveImage.

    The entries of ``val_outputs`` are grouped with ``group_labels``; within
    each group the ``'target'`` tensors get a new axis at position 3 and are
    concatenated along it before being written.

    Args:
        val_outputs: Sequence of mappings with a ``'target'`` tensor.
        output_dir: Directory handed to ``SaveImage``.
        filename: Prefix of the output postfix; the group index is appended.
    """
    labels_groups, _ = group_labels(val_outputs)
    for group_idx, group in enumerate(labels_groups):
        # unsqueeze adds a dimension at position 3, then cat combines the
        # slices along that new dimension
        concatenated_outputs = torch.cat(
            [label['target'].unsqueeze(3) for label in group], dim=3
        )
        print(concatenated_outputs.shape)
        saver = SaveImage(
            output_dir=output_dir,
            output_postfix=f'{filename}_{group_idx}',
            resample=True,
        )
        saver(concatenated_outputs.detach().cpu())


def write_nifti_volume(val_outputs, output_dir=r'.\logs', filename='val'):
    """Save a single tensor with SaveImage after detaching it to the CPU.

    Args:
        val_outputs: Tensor to save.
        output_dir: Directory handed to ``SaveImage``.
        filename: Used as the output postfix.
    """
    saver = SaveImage(output_dir=output_dir, output_postfix=f'{filename}', resample=True)
    saver(val_outputs.detach().cpu())


def reverse_transforms(output_images, orig_images, transforms):
    """Invert ``transforms`` on the first batch element of ``output_images``.

    Note that ``output_images.applied_operations`` is overwritten in place
    with the value taken from ``orig_images``.

    Args:
        output_images: Batched output; indexed with five indices, so it must
            have at least five dimensions. Only batch element 0 is used
            (always set val_batch_size=1).
        orig_images: Object whose ``applied_operations`` is copied over.
        transforms: Object providing ``inverse``; invoked inside
            ``allow_missing_keys_mode`` with a dict that only has ``"target"``.

    Returns:
        The ``"target"`` entry of the dict returned by ``transforms.inverse``.
    """
    output_images.applied_operations = orig_images.applied_operations
    val_output_dict = {"target": output_images[0, :, :, :, :]}
    with allow_missing_keys_mode(transforms):
        reversed_images_dict = transforms.inverse(val_output_dict)
    return reversed_images_dict["target"]


def calculate_ssim(pred, target):
    """Return StructuralSimilarityIndexMeasure of ``pred`` vs ``target``.

    A fresh metric object is created on ``pred.device`` for every call.
    """
    ssim = StructuralSimilarityIndexMeasure().to(pred.device)
    return ssim(pred, target)


def calculate_mae(pred, target):
    """Return MeanAbsoluteError of ``pred`` vs ``target``.

    A fresh metric object is created on ``pred.device`` for every call.
    """
    mae = MeanAbsoluteError().to(pred.device)
    return mae(pred, target)


def calculate_psnr(pred, target):
    """Return PeakSignalNoiseRatio of ``pred`` vs ``target``.

    A fresh metric object is created on ``pred.device`` for every call.
    """
    psnr = PeakSignalNoiseRatio().to(pred.device)
    return psnr(pred, target)


def val_log(epoch, step, gen_image, orig_image, saved_path):
    """Compute SSIM/MAE/PSNR, print them and append them to ``infer_log.txt``.

    Args:
        epoch: Epoch value written to the log line.
        step: Validation step value written to the log line.
        gen_image: Generated image passed as ``pred`` to the metrics.
        orig_image: Reference image passed as ``target`` to the metrics.
        saved_path: Directory in which ``infer_log.txt`` is written.

    Returns:
        A tuple ``(val_metrices, infer_log_file)``: the metrics dict with keys
        ``'ssim'``, ``'mae'``, ``'psnr'`` and the path of the log file.
    """
    val_ssim = calculate_ssim(gen_image, orig_image)
    val_mae = calculate_mae(gen_image, orig_image)
    val_psnr = calculate_psnr(gen_image, orig_image)
    print(f"val_ssim: {val_ssim}, val_mae: {val_mae}, val_psnr: {val_psnr}.")
    val_metrices = {'ssim': val_ssim, 'mae': val_mae, 'psnr': val_psnr}
    infer_log_file = os.path.join(saved_path, "infer_log.txt")
    output_val_log(epoch, step, infer_log_file, val_metrices)
    return val_metrices, infer_log_file


def output_val_log(epoch, val_step, val_log_file=r'.\logs\val_log.txt', val_metrices=None):
    """Append one line of validation metrics to a text file.

    Args:
        epoch: Epoch value written to the line.
        val_step: Validation step value written to the line.
        val_log_file: Path of the log file, opened in append mode. Its parent
            directory is created when missing.
        val_metrices: Mapping that may hold ``'ssim'``, ``'mae'`` and
            ``'psnr'``; missing keys (or ``None``) are logged as 0.
    """
    metrics = {} if val_metrices is None else val_metrices
    ssim = metrics.get('ssim', 0)
    mae = metrics.get('mae', 0)
    psnr = metrics.get('psnr', 0)
    log_dir = os.path.dirname(val_log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(val_log_file, 'a') as f:  # append mode
        f.write(f'epoch {epoch}, val set {val_step}, SSIM: {ssim}, MAE: {mae}, PSNR: {psnr}\n')
def calculate_reverse_info(untransformed_loader):
    """Collect per-batch statistics of the untransformed CT and MRI data.

    For every batch the ``'target'`` entry is treated as CT and the
    ``'source'`` entry as MRI. Mean and std are computed on the float-cast
    data; min and max on the data as given.

    Args:
        untransformed_loader: Iterable of mappings with ``'target'`` and
            ``'source'`` tensors.

    Returns:
        Dict of lists (one element per batch) keyed by ``CT_mean``,
        ``CT_std``, ``MRI_mean``, ``MRI_std``, ``CT_shape``, ``MRI_shape``,
        ``CT_min``, ``CT_max``, ``MRI_min``, ``MRI_max``, ``CT_data`` and
        ``MRI_data``.
    """
    keys = (
        "CT_mean", "CT_std", "MRI_mean", "MRI_std",
        "CT_shape", "MRI_shape",
        "CT_min", "CT_max", "MRI_min", "MRI_max",
        "CT_data", "MRI_data",
    )
    all_reverse_info = {key: [] for key in keys}
    for checkdata in untransformed_loader:
        modalities = (("CT", checkdata['target']), ("MRI", checkdata['source']))
        for prefix, data in modalities:
            data_float = data.float()
            all_reverse_info[f"{prefix}_mean"].append(torch.mean(data_float))
            all_reverse_info[f"{prefix}_std"].append(torch.std(data_float))
            all_reverse_info[f"{prefix}_shape"].append(data.shape)
            all_reverse_info[f"{prefix}_min"].append(torch.min(data))
            all_reverse_info[f"{prefix}_max"].append(torch.max(data))
            all_reverse_info[f"{prefix}_data"].append(data)
    return all_reverse_info


def reverse_normalize_data(tensor, mean=None, std=None, min_val=None, max_val=None, mode='zscore'):
    """Undo a normalization step.

    Args:
        tensor: Value to de-normalize.
        mean: Mean used by ``'zscore'`` mode.
        std: Standard deviation used by ``'zscore'`` mode.
        min_val: Lower bound used by ``'minmax'`` mode.
        max_val: Upper bound used by ``'minmax'`` mode.
        mode: ``'zscore'`` computes ``tensor * std + mean``; ``'minmax'``
            computes ``(tensor + 1) / 2 * (max_val - min_val) + min_val``.
            Any other mode (e.g. ``'inputonlyminmax'``, ``'inputonlyzscore'``,
            ``'none'``) returns ``tensor`` unchanged.

    Returns:
        The de-normalized value, or ``tensor`` itself when the parameters the
        mode needs are ``None``.
    """
    if mode == 'zscore':
        if mean is not None and std is not None:
            return tensor * std + mean
        return tensor
    if mode == 'minmax':
        if min_val is not None and max_val is not None:
            # Treats the input as spanning [-1, 1] before rescaling.
            return (tensor + 1) / 2 * (max_val - min_val) + min_val
        return tensor
    return tensor


def normalize_data(tensor, mean=None, std=None, min_val=None, max_val=None, mode='zscore'):
    """Normalize ``tensor`` by z-score or min-max scaling.

    Args:
        tensor: Value to normalize.
        mean: Mean used by ``'zscore'`` mode.
        std: Standard deviation used by ``'zscore'`` mode.
        min_val: Lower bound used by ``'minmax'`` mode.
        max_val: Upper bound used by ``'minmax'`` mode.
        mode: ``'zscore'`` computes ``(tensor - mean) / std``; ``'minmax'``
            computes ``(tensor - min_val) / (max_val - min_val)``. Any other
            mode returns ``tensor`` unchanged.

    Returns:
        The normalized value, or ``tensor`` itself when the parameters the
        mode needs are ``None``.
    """
    if mode == 'zscore':
        if mean is not None and std is not None:
            return (tensor - mean) / std
        return tensor
    if mode == 'minmax':
        if min_val is not None and max_val is not None:
            # This maps [min_val, max_val] onto [0, 1].
            # NOTE(review): reverse_normalize_data's 'minmax' branch inverts a
            # [-1, 1] scaling, so the two are not exact inverses — confirm
            # which range is intended.
            return (tensor - min_val) / (max_val - min_val)
        return tensor
    return tensor


def save_val_images(val_outputs, val_slice_num, val_names, epoch, saved_img_folder):
    """Split stacked validation outputs per patient and write each one out.

    Args:
        val_outputs: Tensor whose first dimension stacks the slices of all
            patients.
        val_slice_num: Number of slices per patient, e.g. [200, 200, 150, 230].
        val_names: Patient names, indexed in the same order as
            ``val_slice_num``.
        epoch: Zero-based epoch; ``epoch + 1`` is used in the file name.
        saved_img_folder: Output directory passed on to ``write_nifti``.
    """
    total_slices = sum(val_slice_num)
    if val_outputs.shape[0] != total_slices:
        print(val_outputs.shape[0])
        print(total_slices)
        print("something wrong with validation set, please check")
        return

    # isolate different patients' data
    val_data_list = []
    start = 0
    for count in val_slice_num:
        val_data_list.append(val_outputs[start:start + count, :, :, :])
        start += count

    # save validation images
    for i, patient_data in enumerate(val_data_list):
        file_name = f'pred_{val_names[i]}_epoch_{epoch+1}'
        # NOTE(review): write_nifti in this file forwards its first argument
        # to group_labels, which reads label['target'] from each element; a
        # plain tensor slice is passed here — confirm the intended writer.
        write_nifti(patient_data, saved_img_folder, file_name)


def _suffixed_path(saved_name, suffix):
    """Return ``saved_name`` with ``suffix`` inserted before its extension."""
    root, ext = os.path.splitext(saved_name)
    return f'{root}{suffix}{ext}'


def _save_three_panels(images, titles, saved_name, imgformat, dpi):
    """Draw three grayscale images side by side with titles and save them."""
    fig, axs = plt.subplots(1, 3, figsize=(12, 5))
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0.1)
    plt.margins(0, 0)
    for ax, image, title in zip(axs, images, titles):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(saved_name, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    plt.close(fig)


def compare_imgs(input_imgs, target_imgs, fake_imgs, saved_name, imgformat='jpg', dpi=500, model_name='DDPM',):
    """Save an MRI / CT / generated comparison figure plus the three images.

    The combined figure is written to ``saved_name``; the single images are
    written next to it with ``_mri``, ``_ct`` and ``_fake`` inserted before
    the file extension.

    Args:
        input_imgs: Tensor shown under the title 'MRI'.
        target_imgs: Tensor shown under the title 'CT'.
        fake_imgs: Tensor shown under the title ``model_name``.
        saved_name: Path of the combined figure.
        imgformat: Format passed to ``savefig``.
        dpi: Resolution passed to ``savefig``.
        model_name: Title of the third panel.
    """
    from PIL import Image

    def to_image(tensor):
        # NOTE(review): presumably the tensors are scaled to [0, 1]; values
        # outside that range will not survive the uint8 cast — confirm.
        array = tensor.squeeze().cpu().numpy()
        return Image.fromarray((array * 255).astype(np.uint8))

    images = [to_image(t) for t in (input_imgs, target_imgs, fake_imgs)]
    _save_three_panels(images, ['MRI', 'CT', model_name], saved_name, imgformat, dpi)

    # save output images individually (no titles, no axes)
    for image, suffix in zip(images, ('_mri', '_ct', '_fake')):
        save_single_image(image, _suffixed_path(saved_name, suffix), imgformat, dpi=dpi)


# Define function to save images
def save_single_image(input_imgs, filename, imgformat, dpi=300):
    """Save one grayscale image without axes or margins.

    Args:
        input_imgs: Image data passed to ``plt.imshow``.
        filename: Output path.
        imgformat: Format passed to ``savefig``.
        dpi: Resolution passed to ``savefig``.
    """
    fig = plt.figure()
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.imshow(input_imgs, cmap='gray')
    plt.savefig(filename, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    plt.close(fig)


def arrange_images(input_imgs, label_imgs, fake_imgs, model_name, saved_name, imgformat='jpg', dpi=500):
    """Save an MRI / CT / generated comparison figure.

    Args:
        input_imgs: Image shown under the title 'MRI'.
        label_imgs: Image shown under the title 'CT'.
        fake_imgs: Image shown under the title ``model_name``.
        model_name: Title of the third panel.
        saved_name: Output path.
        imgformat: Format passed to ``savefig``.
        dpi: Resolution passed to ``savefig``.
    """
    _save_three_panels(
        (input_imgs, label_imgs, fake_imgs),
        ['MRI', 'CT', model_name],
        saved_name,
        imgformat,
        dpi,
    )


# Define function to plot histograms
def plot_histogram(data, title, ax, color='blue', alpha=0.7, x_lower_limit=-1, x_upper_limit=3, y_lower_limit=0, y_upper_limit=15000):
    """Draw a 256-bin histogram of the flattened ``data`` on ``ax``.

    Args:
        data: Object exposing ``flatten()``.
        title: Axes title.
        ax: Matplotlib axes to draw on.
        color: Bar colour.
        alpha: Bar opacity.
        x_lower_limit: Lower end of the histogram range.
        x_upper_limit: Upper end of the histogram range.
        y_lower_limit: Lower y-axis limit.
        y_upper_limit: Upper y-axis limit.
    """
    bins = 256
    ax.hist(data.flatten(), bins=bins, range=(x_lower_limit, x_upper_limit), color=color, alpha=alpha)
    ax.set_ylim([y_lower_limit, y_upper_limit])
    ax.set_title(title)
    ax.set_xlabel('Pixel intensity')
    ax.set_ylabel('Frequency')


def arrange_histograms(original, reversed, saved_name, titles=('original', 'reversed'), dpi=300, mode='ct'):
    """Save a figure with two stacked histograms.

    Args:
        original: Data for the upper (red) histogram.
        reversed: Data for the lower (green) histogram. The name shadows the
            builtin but is kept because callers may pass it by keyword.
        saved_name: Output path.
        titles: Two labels inserted into the histogram titles.
        dpi: Resolution passed to ``savefig``.
        mode: Unused; kept for backward compatibility.
    """
    fig, axs = plt.subplots(2, 1, figsize=(10, 8))
    plot_histogram(original, f'Histogram for {titles[0]}', axs[0], color='red')
    plot_histogram(reversed, f'Histogram for {titles[1]}', axs[1], color='green')

    # Lay out and save the histogram figure
    plt.tight_layout()
    plt.savefig(saved_name, dpi=dpi)
    plt.close(fig)
def arrange_3_histograms(source, target, output, saved_name, dpi=300, x_lower_limit=-1, x_upper_limit=3, y_lower_limit=0, y_upper_limit=15000):
    """Save three stacked histograms and a companion boxplot.

    The histogram figure is written to ``saved_name``. The boxplot is written
    to ``saved_name`` with ``'histogram'`` replaced by ``'boxplot'``; when the
    name does not contain ``'histogram'``, ``_boxplot`` is inserted before the
    extension instead so that the histogram file is not overwritten.

    Args:
        source: Data for the first (red) histogram; must expose ``flatten()``.
        target: Data for the second (green) histogram.
        output: Data for the third (blue) histogram.
        saved_name: Output path of the histogram figure.
        dpi: Resolution passed to ``savefig``.
        x_lower_limit: Lower end of the histogram range.
        x_upper_limit: Upper end of the histogram range.
        y_lower_limit: Lower y-axis limit.
        y_upper_limit: Upper y-axis limit.
    """
    limits = {
        'x_lower_limit': x_lower_limit,
        'x_upper_limit': x_upper_limit,
        'y_lower_limit': y_lower_limit,
        'y_upper_limit': y_upper_limit,
    }
    panels = (
        (source, 'Histogram for source', 'red'),
        (target, 'Histogram for target', 'green'),
        (output, 'Histogram for output', 'blue'),
    )

    # Plot histograms
    fig, axs = plt.subplots(3, 1, figsize=(10, 8))
    for ax, (data, title, color) in zip(axs, panels):
        plot_histogram(data, title, ax, color=color, **limits)
    plt.tight_layout()
    plt.savefig(saved_name, dpi=dpi)
    plt.close(fig)

    # boxplot: drawn on its own figure so an unrelated open figure is never
    # reused and then closed by accident
    boxplot_name = saved_name.replace('histogram', 'boxplot')
    if boxplot_name == saved_name:
        root, ext = os.path.splitext(saved_name)
        boxplot_name = f'{root}_boxplot{ext}'
    box_fig = plt.figure()
    plt.boxplot([source.flatten(), target.flatten(), output.flatten()])
    plt.xticks([1, 2, 3], ['Source', 'Target', 'Fake'])
    plt.title('Pixel Value Distribution')
    plt.xlabel('Image Type')
    plt.ylabel('Pixel Values')
    plt.tight_layout()
    plt.savefig(boxplot_name, dpi=dpi)
    plt.close(box_fig)


def arrange_4_histograms(real1, fake1, real2, fake2, saved_name, dpi=300):
    """Save a figure with four stacked histograms.

    Args:
        real1: Data for the first histogram (red).
        fake1: Data for the second histogram (red).
        real2: Data for the third histogram (green).
        fake2: Data for the fourth histogram (green).
        saved_name: Output path.
        dpi: Resolution passed to ``savefig``.
    """
    panels = (
        (real1, 'Histogram for real1', 'red'),
        (fake1, 'Histogram for fake1', 'red'),
        (real2, 'Histogram for real2', 'green'),
        (fake2, 'Histogram for fake2', 'green'),
    )
    fig, axs = plt.subplots(4, 1, figsize=(10, 8))
    for ax, (data, title, color) in zip(axs, panels):
        plot_histogram(data, title, ax, color=color)
    plt.tight_layout()
    plt.savefig(saved_name, dpi=dpi)
    plt.close(fig)


def _append_to_stem(path, suffix):
    """Return ``path`` with ``suffix`` inserted before its extension."""
    root, ext = os.path.splitext(path)
    return f'{root}{suffix}{ext}'


def _predict_as_numpy(model, input, label):
    """Run ``model`` on ``input`` and return (input, label, fake) as arrays."""
    # Gradients are not needed for visualisation.
    with torch.no_grad():
        fake = model(input)
    return tuple(t.cpu().detach().numpy() for t in (input, label, fake))


def _save_titled_panels(panels, titles, path):
    """Save three grayscale panels side by side, each with a title."""
    fig, axs = plt.subplots(1, 3, figsize=(20, 4))
    for ax, image, title in zip(axs, panels, titles):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(path)
    plt.close(fig)


def _save_titled_image(image, title, path):
    """Save one grayscale image with a title and hidden axes."""
    fig, ax = plt.subplots(1, 1)
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
    ax.axis('off')
    fig.savefig(path)
    plt.close(fig)


def _save_axisless_image(image, path, imgformat, dpi):
    """Save one grayscale image without axes, title or margins."""
    fig = plt.figure()
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.imshow(image, cmap='gray')
    plt.savefig(path, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    plt.close(fig)


def sample_images(model, input, label, slice_idx, epoch, batch_i, saved_folder, model_name='model'):
    """Run the model and save a comparison figure plus the three single images.

    Files are written to ``saved_folder`` (created when missing) as
    ``{epoch}_{batch_i}.jpg`` and the ``_mri``, ``_ct`` and ``_fake`` variants.

    Args:
        model: Callable applied to ``input``.
        input: Batch fed to the model; indexed as ``[slice_idx, 0, :, :]``.
        label: Reference batch, indexed the same way.
        slice_idx: Index along the first dimension of the batches.
        epoch: Used in the file name.
        batch_i: Used in the file name.
        saved_folder: Output directory.
        model_name: Title of the individually saved generated image.
    """
    volumes = _predict_as_numpy(model, input, label)
    gen_imgs = np.stack([vol[slice_idx, 0, :, :].squeeze() for vol in volumes])

    os.makedirs(saved_folder, exist_ok=True)
    saved_name = os.path.join(saved_folder, f"{epoch}_{batch_i}.jpg")
    _save_titled_panels(gen_imgs, ['MRI', 'CT', 'Translated'], saved_name)

    # save output images individually
    singles = (('MRI', '_mri'), ('CT', '_ct'), (model_name, '_fake'))
    for image, (title, suffix) in zip(gen_imgs, singles):
        _save_titled_image(image.squeeze(), title, _append_to_stem(saved_name, suffix))


def save_images(input_imgs, label_imgs, fake_imgs, slice_idx, saved_name='./test.jpg', imgformat='jpg', dpi=1000, model_name='model'):
    """Save one slice of three volumes as a comparison figure and single images.

    The combined figure is written to ``saved_name``; the single images are
    written next to it with ``_mri``, ``_ct`` and ``_fake`` inserted before
    the file extension.

    Args:
        input_imgs: Volume shown as 'MRI'; indexed as ``[:, :, slice_idx]``.
        label_imgs: Volume shown as 'CT'; indexed the same way.
        fake_imgs: Volume shown under ``model_name``; indexed the same way.
        slice_idx: Index along the third dimension.
        saved_name: Path of the combined figure.
        imgformat: Format passed to ``savefig``.
        dpi: Resolution passed to ``savefig``.
        model_name: Title of the third panel.
    """
    slices = [vol[:, :, slice_idx].squeeze() for vol in (input_imgs, label_imgs, fake_imgs)]
    titles = ['MRI', 'CT', model_name]

    fig, axs = plt.subplots(1, 3, figsize=(12, 5))
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0.1)
    plt.margins(0, 0)
    for ax, image, title in zip(axs, slices, titles):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(saved_name, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    plt.close(fig)

    # save output images individually
    for image, suffix in zip(slices, ('_mri', '_ct', '_fake')):
        _save_axisless_image(image, _append_to_stem(saved_name, suffix), imgformat, dpi)


def sample_images2(model, input, label, slice_idx, epoch, batch_i, saved_folder):
    """Run the model and save an MRI / CT / Translated comparison figure.

    The figure is written to ``saved_folder`` (created when missing) as
    ``{epoch}_{batch_i}.jpg``.

    Args:
        model: Callable applied to ``input``.
        input: Batch fed to the model; indexed as ``[slice_idx, 0, :, :]``.
        label: Reference batch, indexed the same way.
        slice_idx: Index along the first dimension of the batches.
        epoch: Used in the file name.
        batch_i: Used in the file name.
        saved_folder: Output directory.
    """
    os.makedirs(saved_folder, exist_ok=True)
    saved_name = f"{epoch}_{batch_i}.jpg"

    volumes = _predict_as_numpy(model, input, label)
    gen_imgs = np.stack([vol[slice_idx, 0, :, :].squeeze() for vol in volumes])
    _save_titled_panels(
        gen_imgs, ['MRI', 'CT', 'Translated'], os.path.join(saved_folder, saved_name)
    )


def sample_images_3D(model, input, label, epoch, batch_i, saved_folder):
    """Run the model on a 3D batch and save one slice as a comparison figure.

    Slice 50 along the last of the five indexed dimensions is shown when all
    three volumes have more than 50 slices there; otherwise slice 10 is used.
    The figure is written to ``saved_folder`` (created when missing) as
    ``{epoch}_{batch_i}.jpg``.

    Args:
        model: Callable applied to ``input``.
        input: Batch fed to the model; indexed as ``[0, 0, :, :, k]``.
        label: Reference batch, indexed the same way.
        epoch: Used in the file name.
        batch_i: Used in the file name.
        saved_folder: Output directory.
    """
    volumes = _predict_as_numpy(model, input, label)

    # Explicit depth check instead of catching every exception.
    depth = min(vol.shape[4] for vol in volumes)
    slice_idx = 50 if depth > 50 else 10
    gen_imgs = np.stack([vol[0, 0, :, :, slice_idx].squeeze() for vol in volumes])

    os.makedirs(saved_folder, exist_ok=True)
    saved_name = f"{epoch}_{batch_i}.jpg"
    _save_titled_panels(
        gen_imgs, ['MRI', 'CT', 'Translated'], os.path.join(saved_folder, saved_name)
    )