# Omini3D / Dataloader / dataloader_utils.py
# Last commit by maxmo2009 (2af0e94, verified):
#   "Sync from local: code + epoch-110 checkpoint, clean README"
import os
import torch
# from torch import nn, optim
# from torch.autograd.variable import Variable
# from torchvision import transforms, datasets
# from torchvision.utils import save_image
# import torch.nn.functional as F
# import scipy.ndimage as spimg
# import pyquaternion as quater
# import random
import numpy as np
from scipy.ndimage import gaussian_filter, binary_dilation, binary_erosion, generate_binary_structure
import pydicom
from scipy.ndimage import zoom
from einops import rearrange, reduce, repeat
def get_sizeRange_dict(roi=''):
    """Look up the field-of-view size range for a body region.

    Every entry maps a region-of-interest name to a ``[min_size, max_size]``
    pair (per the original author's note, sizes are in mm).

    Args:
        roi (str): Region name such as ``'thorax'`` or ``'head-neck'``.

    Returns:
        list | dict: The ``[min_size, max_size]`` pair when ``roi`` is a known
        region; otherwise (including the default ``''`` and any unknown name)
        the full mapping of every region to its pair.
    """
    size_ranges = {
        'whole-body': [420, 2048],
        'neck-thorax-abdomen-pelvis-leg': [400, 2048],
        'neck-thorax-abdomen-pelvis': [380, 2048],
        'thorax-abdomen-pelvis-leg': [360, 2048],
        'neck-thorax-abdomen': [320, 1024],
        'head-neck-thorax-abdomen': [360, 2048],
        'head-neck-thorax': [340, 1024],
        'thorax-abdomen-pelvis': [340, 1024],
        'abdomen-pelvis-leg': [320, 1024],
        'neck-thorax': [220, 1024],
        'thorax-abdomen': [260, 1024],
        'abdomen-pelvis': [260, 1024],
        'pelvis-leg': [240, 1024],
        'head-neck': [240, 1024],
        'head': [150, 1024],
        'brain': [128, 1024],
        'neck': [140, 1024],
        'abdomen': [240, 1024],
        'pelvis': [220, 1024],
        'thorax': [220, 1024],
        'arm': [100, 1024],
        'hand': [100, 1024],
        'leg': [100, 1024],
        'skeleton': [130, 1024],
    }
    # Unknown or empty ROI names fall back to the whole table.
    return size_ranges.get(roi, size_ranges)
def remove_background(img, replace_value=None, num_bin=256, dim_ch=0, sigma=None):
    """Replace the most frequent (background) intensity of an image.

    A pixel is treated as background when, in every channel, it equals that
    channel's modal histogram bin index. Background pixels are overwritten
    with a per-channel replacement value. When ``sigma > 0`` the background
    mask is dilated first and the result is blended with a Gaussian-blurred
    copy to soften the foreground/background transition.

    Args:
        img (np.ndarray): Image with channels along ``dim_ch``.
        replace_value (sequence | None): Per-channel replacement values,
            indexed by channel position. If None, each channel uses its
            minimum over non-background pixels (this raises ValueError from
            ``np.min`` if no non-background pixel remains).
        num_bin (int): Histogram bin edges are ``range(num_bin)``. If
            ``<= 0`` no histogram is taken and every pixel stays marked as
            background.
        dim_ch (int | None): Channel axis. ``None`` means ``img`` has no
            channel axis; one is inserted at axis 0 and kept in the output.
        sigma (float | None): Blur strength; ``None`` or ``<= 0`` disables
            mask dilation and blending.

    Returns:
        np.ndarray: The processed channels stacked back along ``dim_ch``.
        ``img`` itself is not modified.
    """
    if dim_ch is None:
        dim_ch = 0
        img = np.expand_dims(img, axis=dim_ch)
    # Copy each channel: np.split / np.squeeze return views of `img`, and the
    # in-place writes below would otherwise modify the caller's array.
    ims = [np.squeeze(im, axis=dim_ch).copy()
           for im in np.split(img, img.shape[dim_ch], axis=dim_ch)]

    # msk1 == 1 marks background: pixels equal to the modal value in every channel.
    msk1 = np.ones_like(ims[0])
    if num_bin > 0:
        for im in ims:
            hist, _ = np.histogram(im.flatten(), bins=range(num_bin))
            # The bin index equals the bin's lower edge, so for integer-valued
            # images this is the most frequent value.
            # NOTE(review): range(num_bin) yields num_bin - 1 bins with a closed
            # last bin, so values num_bin-2 and num_bin-1 share a bin, and values
            # outside [0, num_bin-1] are not counted. Kept for compatibility.
            common_value = np.argmax(hist)
            msk1[im != common_value] = 0

    use_blur = sigma is not None and sigma > 0
    if use_blur:
        # NOTE(review): int(sigma*4) is 0 for sigma < 0.25, which scipy treats as
        # "repeat until the mask stops changing"; truncate=sigma//4 is 0 for
        # sigma < 4, which reduces the Gaussian kernel to a single pixel.
        # Both are kept as-is to preserve existing outputs.
        msk1 = binary_dilation(msk1, iterations=int(sigma * 4)).astype(float)
        msk0 = binary_erosion(1 - msk1, iterations=int(sigma * 4)).astype(float)
        msk_blur = gaussian_filter(msk0, sigma=sigma * 4, truncate=sigma // 4, mode='nearest')

    for idx, im in enumerate(ims):
        if replace_value is None:
            replace_v = np.min(im[np.logical_not(msk1)])
        else:
            replace_v = replace_value[idx]
        if use_blur:
            # im_blur aliases im, so `im` also receives the replacement before blending.
            im_blur = im
            im_blur[msk1 == 1] = replace_v
            im_blur = gaussian_filter(im_blur, sigma=sigma * 4, truncate=sigma // 4, mode='nearest')
            # Keep the (eroded) foreground, fade into the blurred background elsewhere.
            im = im * msk_blur + im_blur * (1 - msk_blur)
        else:
            im[msk1 == 1] = replace_v
        ims[idx] = im
    return np.stack(ims, axis=dim_ch)
def thresh_img(img, thresh=None, EPS=10**-7):
    """Subtract a threshold from image tensor(s) and clamp the result.

    Args:
        img (torch.Tensor | list[torch.Tensor]): Image tensor, or a list of
            tensors that all receive the same threshold. A list is updated in
            place and also returned.
        thresh (float | list | tuple | None): Either a fixed threshold, or a
            ``[low, high]`` range from which ``threshold`` is drawn with
            ``np.random.uniform``. ``None`` returns ``img`` unchanged.
        EPS (float): Unused; kept for backward compatibility.

    Returns:
        Same type as ``img``: ``clamp(img - threshold, 0, upbound)``.

    With a ``[low, high]`` range a second, independent draw ``u`` from the
    same range sets ``upbound = 1 - u - threshold``. With a fixed threshold
    there is no upper bound: only values below 0 are clipped. (Previously a
    fixed threshold raised UnboundLocalError because ``upbound`` was never
    assigned.)
    """
    upbound = None
    if isinstance(thresh, (list, tuple)):
        low, high = thresh[0], thresh[1]
        # Draw order (threshold first, then the upper-bound offset) matches the
        # original, so seeded runs consume the RNG identically.
        threshold = float(np.random.uniform(low, high))
        upbound = float(1 - np.random.uniform(low, high) - threshold)
        # NOTE(review): if threshold + u > 1 then upbound < 0, and torch.clamp
        # with min > max sets every element to max.
    else:
        threshold = thresh
    if threshold is None:
        return img

    def _apply(t):
        # Python scalars as bounds avoid per-call tensor creation and device moves.
        return torch.clamp(t - threshold, min=0.0, max=upbound)

    if isinstance(img, list):
        for i, t in enumerate(img):
            img[i] = _apply(t)
        return img
    return _apply(img)
def clamp_img_tensor(img, clamp=(None, None)):
    """Clamp an image tensor to an optional ``[min, max]`` range.

    Args:
        img (torch.Tensor): Tensor to clamp.
        clamp (sequence): ``(min, max)`` pair; either bound may be ``None`` to
            leave that side unbounded. The default ``(None, None)`` (an
            immutable tuple, replacing the former mutable list default)
            returns ``img`` unchanged.

    Returns:
        torch.Tensor: The clamped tensor, or ``img`` itself when both bounds
        are ``None``.
    """
    low, high = clamp[0], clamp[1]
    if low is None and high is None:
        # torch.clamp rejects a call with neither bound set.
        return img
    device = img.device
    # Bounds become tensors on img's device, as in the original implementation.
    low_t = None if low is None else torch.as_tensor(low, device=device)
    high_t = None if high is None else torch.as_tensor(high, device=device)
    return torch.clamp(img, min=low_t, max=high_t)
def read_CT_volume(folder_path, target_res=128):
    """Load a DICOM series from a folder into a cubic, zero-padded volume.

    Every ``*.dcm`` file in ``folder_path`` is read, in reverse lexicographic
    filename order, and its ``pixel_array`` is stacked along the last axis,
    giving an ``(h, w, z)`` array. The volume is resampled with
    ``scipy.ndimage.zoom`` to isotropic voxels so that its longest physical
    extent spans roughly ``target_res`` voxels. It is then zero-padded
    symmetrically on every axis up to its largest resampled dimension.

    Args:
        folder_path (str): Directory containing the series' ``.dcm`` files.
        target_res (int): Number of voxels along the longest physical axis
            after resampling.

    Returns:
        tuple: ``(volume, slice_thickness, pixel_spacing)``. ``volume`` is the
        padded cube. The other two values are the ``SliceThickness`` and
        ``PixelSpacing`` attributes of the first DICOM file in the stacking
        order.

    Raises:
        ValueError: If ``folder_path`` contains no ``.dcm`` files.

    NOTE(review): slices are ordered by filename, not by
    ImagePositionPatient/InstanceNumber. Raw stored values are used, with no
    RescaleSlope/RescaleIntercept applied. SliceThickness is used as the
    inter-slice spacing. Confirm all three assumptions hold for the data.
    """
    dcm_names = sorted(
        (name for name in os.listdir(folder_path) if name.endswith(".dcm")),
        reverse=True,
    )
    if not dcm_names:
        raise ValueError(f"No .dcm files found in {folder_path!r}")

    slices = [pydicom.dcmread(os.path.join(folder_path, name)).pixel_array
              for name in dcm_names]
    dicom_volume = rearrange(np.array(slices), 'z h w -> h w z')

    # Spatial metadata comes from the first DICOM file of the stacked series.
    # (Previously os.listdir()[0] was used, which is arbitrary and may not
    # even be a .dcm file.)
    header = pydicom.dcmread(os.path.join(folder_path, dcm_names[0]),
                             stop_before_pixels=True)
    slice_thickness = header.SliceThickness
    pixel_spacing = header.PixelSpacing

    # Per-axis spacing in (h, w, z) order.
    spacing = (pixel_spacing[0], pixel_spacing[1], slice_thickness)
    longest_extent = max(s * n for s, n in zip(spacing, dicom_volume.shape))
    # Output voxel size: scaling each axis by spacing / voxel_size makes voxels
    # isotropic and maps the longest extent onto ~target_res voxels.
    voxel_size = longest_extent / target_res
    resized_volume = zoom(dicom_volume, tuple(s / voxel_size for s in spacing))

    # Pad every axis symmetrically (extra voxel goes to the end) to a cube.
    side = max(resized_volume.shape)
    padding = [((side - n) // 2, (side - n) - (side - n) // 2)
               for n in resized_volume.shape]
    padded_resized_volume = np.pad(resized_volume, padding, mode='constant')
    return padded_resized_volume, slice_thickness, pixel_spacing