Commit
·
b10768a
1
Parent(s):
ed5ac4a
CPD scripts
Browse files- Centerline/__init__.py +0 -0
- Centerline/centerline.py +241 -0
- Centerline/cpd_utils.py +48 -0
- Centerline/evaluate_BayesianCPD_skeleton.py +160 -0
- Centerline/evaluate_CPD_dense.py +156 -0
- Centerline/evaluate_CPD_nodes.py +158 -0
- Centerline/evaluate_CPD_skeleton.py +164 -0
- Centerline/get_vessels.py +30 -0
- Centerline/graph_utils.py +85 -0
- Centerline/skeleton_to_graph.py +167 -0
- Centerline/skeletonization.py +817 -0
- Centerline/thinPlateSplines.py +82 -0
- Centerline/visualization_utils.py +161 -0
Centerline/__init__.py
ADDED
|
File without changes
|
Centerline/centerline.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
|
| 3 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 4 |
+
parentdir = os.path.dirname(currentdir)
|
| 5 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 6 |
+
|
| 7 |
+
# import tensorflow as tf
|
| 8 |
+
# tf.enable_eager_execution()
|
| 9 |
+
# import neurite.py.utils as neurite_utils
|
| 10 |
+
|
| 11 |
+
from skimage.morphology import skeletonize_3d, ball
|
| 12 |
+
from skimage.morphology import binary_closing, binary_opening
|
| 13 |
+
from skimage.filters import median
|
| 14 |
+
from skimage.measure import regionprops, label
|
| 15 |
+
from skimage.transform import warp
|
| 16 |
+
|
| 17 |
+
from scipy.ndimage import zoom
|
| 18 |
+
from scipy.interpolate import LinearNDInterpolator, Rbf
|
| 19 |
+
|
| 20 |
+
import h5py
|
| 21 |
+
import numpy as np
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
import re
|
| 24 |
+
import nibabel as nib
|
| 25 |
+
from nilearn.image import resample_img
|
| 26 |
+
|
| 27 |
+
from Centerline.graph_utils import graph_to_ndarray, deform_graph, get_bifurcation_nodes, subsample_graph, \
|
| 28 |
+
apply_displacement
|
| 29 |
+
from Centerline.skeleton_to_graph import get_graph_from_skeleton
|
| 30 |
+
from Centerline.visualization_utils import plot_skeleton, compare_graphs
|
| 31 |
+
|
| 32 |
+
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
| 33 |
+
from DeepDeformationMapRegistration.utils import constants as C
|
| 34 |
+
|
| 35 |
+
import cupy
|
| 36 |
+
from cupyx.scipy.ndimage import zoom as zoom_gpu
|
| 37 |
+
from cupyx.scipy.ndimage import map_coordinates
|
| 38 |
+
|
| 39 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
| 40 |
+
DATASET_NAMES = ['Affine', 'None', 'Translation']
|
| 41 |
+
DATASET_FILENAME = 'volume'
|
| 42 |
+
IMGS_FOLDER = '/home/jpdefrutos/workspace/DeepDeformationMapRegistration/Centerline/centerlines'
|
| 43 |
+
|
| 44 |
+
DATASTE_RAW_FILES = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/nifti3'
|
| 45 |
+
LITS_SEGMENTATION_FILE = 'segmentation'
|
| 46 |
+
LITS_CT_FILE = 'volume'
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def warp_volume(volume, disp_map, indexing='ij'):
|
| 50 |
+
assert indexing is 'ij' or 'xy', 'Invalid indexing option. Only "ij" or "xy"'
|
| 51 |
+
grid_i = np.linspace(0, disp_map.shape[0], disp_map.shape[0], endpoint=False)
|
| 52 |
+
grid_j = np.linspace(0, disp_map.shape[1], disp_map.shape[1], endpoint=False)
|
| 53 |
+
grid_k = np.linspace(0, disp_map.shape[2], disp_map.shape[2], endpoint=False)
|
| 54 |
+
grid_i, grid_j, grid_k = np.meshgrid(grid_i, grid_j, grid_k, indexing=indexing)
|
| 55 |
+
grid_i = (grid_i.flatten() + disp_map[..., 0].flatten())[..., np.newaxis]
|
| 56 |
+
grid_j = (grid_j.flatten() + disp_map[..., 1].flatten())[..., np.newaxis]
|
| 57 |
+
grid_k = (grid_k.flatten() + disp_map[..., 2].flatten())[..., np.newaxis]
|
| 58 |
+
coords = np.hstack([grid_i, grid_j, grid_k]).reshape([*disp_map.shape[:-1], -1])
|
| 59 |
+
coords = coords.transpose((-1, 0, 1, 2))
|
| 60 |
+
# The returned volume has indexing xy
|
| 61 |
+
return warp(volume, coords)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def keep_largest_segmentation(img):
|
| 65 |
+
label_img = label(img)
|
| 66 |
+
rp = regionprops(label_img) # Regions labeled with 0 (bg) are ignored
|
| 67 |
+
biggest_area = (0, 0)
|
| 68 |
+
for l in range(0, label_img.max()):
|
| 69 |
+
if rp[l].area > biggest_area[1]:
|
| 70 |
+
biggest_area = (l + 1, rp[l].area)
|
| 71 |
+
img[label_img != biggest_area[0]] = 0.
|
| 72 |
+
return img
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def preprocess_image(img, keep_largest=False):
|
| 76 |
+
ret = binary_closing(img, ball(1))
|
| 77 |
+
ret = binary_opening(ret, ball(1))
|
| 78 |
+
#ret = median(ret, ball(1), mode='constant')
|
| 79 |
+
if keep_largest:
|
| 80 |
+
ret = keep_largest_segmentation(ret)
|
| 81 |
+
return ret.astype(np.float)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def build_displacement_map_interpolator(disp_map, backwards=False, indexing='ij'):
|
| 85 |
+
grid_i = np.linspace(0, disp_map.shape[0], disp_map.shape[0], endpoint=False)
|
| 86 |
+
grid_j = np.linspace(0, disp_map.shape[1], disp_map.shape[1], endpoint=False)
|
| 87 |
+
grid_k = np.linspace(0, disp_map.shape[2], disp_map.shape[2], endpoint=False)
|
| 88 |
+
grid_i, grid_j, grid_k = np.meshgrid(grid_i, grid_j, grid_k, indexing=indexing)
|
| 89 |
+
grid_i = grid_i.flatten()
|
| 90 |
+
grid_j = grid_j.flatten()
|
| 91 |
+
grid_k = grid_k.flatten()
|
| 92 |
+
# To generate the moving image, we used backwards mapping were the input was the fix image
|
| 93 |
+
# Now we are doing direct mapping from the fix graph coordinates to the moving coordinates
|
| 94 |
+
# The application points of the displacement map are thus the transformed "moving image"-grid
|
| 95 |
+
# and the displacement vectors are reversed
|
| 96 |
+
if backwards:
|
| 97 |
+
coords = np.hstack([grid_i[..., np.newaxis], grid_j[..., np.newaxis], grid_k[..., np.newaxis]])
|
| 98 |
+
return LinearNDInterpolator(coords, np.reshape(disp_map, [-1, 3]))
|
| 99 |
+
else:
|
| 100 |
+
grid_i = (grid_i + disp_map[..., 0].flatten())
|
| 101 |
+
grid_j = (grid_j + disp_map[..., 1].flatten())
|
| 102 |
+
grid_k = (grid_k + disp_map[..., 2].flatten())
|
| 103 |
+
|
| 104 |
+
coords = np.hstack([grid_i[..., np.newaxis], grid_j[..., np.newaxis], grid_k[..., np.newaxis]])
|
| 105 |
+
return LinearNDInterpolator(coords, -np.reshape(disp_map, [-1, 3]))
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def resample_segmentation(img, output_shape, preserve_range, threshold=None, gpu=True):
|
| 109 |
+
# Preserve range can be a bool (keep or not the original dyn. range) or a list with a new dyn. range
|
| 110 |
+
zoom_f = np.divide(np.asarray(output_shape), np.asarray(img.shape))
|
| 111 |
+
|
| 112 |
+
if gpu:
|
| 113 |
+
out_img = zoom_gpu(cupy.asarray(img), zoom_f, order=1) # order = 0 or 1
|
| 114 |
+
else:
|
| 115 |
+
out_img = zoom(img, zoom_f)
|
| 116 |
+
if isinstance(preserve_range, bool):
|
| 117 |
+
if preserve_range:
|
| 118 |
+
range_min, range_max = np.amin(img), np.amax(img)
|
| 119 |
+
out_img = min_max_norm(out_img)
|
| 120 |
+
out_img = out_img * (range_max - range_min) + range_min
|
| 121 |
+
elif isinstance(preserve_range, list):
|
| 122 |
+
range_min, range_max = preserve_range
|
| 123 |
+
out_img = min_max_norm(out_img)
|
| 124 |
+
out_img = out_img * (range_max - range_min) + range_min
|
| 125 |
+
|
| 126 |
+
if threshold is not None and out_img.min() < threshold < out_img.max():
|
| 127 |
+
range_min, range_max = np.amin(out_img), np.amax(out_img)
|
| 128 |
+
out_img[out_img > threshold] = range_max
|
| 129 |
+
out_img[out_img < range_max] = range_min
|
| 130 |
+
return cupy.asnumpy(out_img) if gpu else out_img
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
if __name__ == '__main__':
|
| 134 |
+
for dataset_name in DATASET_NAMES:
|
| 135 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
| 136 |
+
dataset_files = os.listdir(dataset_loc)
|
| 137 |
+
dataset_files.sort()
|
| 138 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
| 139 |
+
|
| 140 |
+
iterator = tqdm(dataset_files)
|
| 141 |
+
for file_path in iterator:
|
| 142 |
+
file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])
|
| 143 |
+
|
| 144 |
+
iterator.set_description('{} ({}): laoding data'.format(file_num, dataset_name))
|
| 145 |
+
vol_file = h5py.File(file_path, 'r')
|
| 146 |
+
# fix_vessels = vol_file[C.H5_FIX_VESSELS_MASK][..., 0]
|
| 147 |
+
disp_map = vol_file[C.H5_GT_DISP][:]
|
| 148 |
+
bbox = vol_file['parameters/bbox'][:]
|
| 149 |
+
bbox_min = bbox[:3]
|
| 150 |
+
bbox_max = bbox[3:] + bbox_min
|
| 151 |
+
|
| 152 |
+
# Load vessel segmentation mask and resize to 64^3
|
| 153 |
+
fix_labels = nib.load(os.path.join(DATASTE_RAW_FILES, 'segmentation-{:04d}.nii.gz'.format(file_num)))
|
| 154 |
+
fix_vessels = fix_labels.slicer[..., 1]
|
| 155 |
+
fix_vessels = resample_img(fix_vessels, np.eye(3))
|
| 156 |
+
fix_vessels = np.asarray(fix_vessels.dataobj)
|
| 157 |
+
fix_vessels = preprocess_image(fix_vessels)
|
| 158 |
+
fix_vessels = resample_segmentation(fix_vessels, vol_file['parameters/first_reshape'][:], [0, 1], 0.3,
|
| 159 |
+
gpu=True)
|
| 160 |
+
fix_vessels = fix_vessels[bbox_min[0]:bbox_max[0], bbox_min[1]:bbox_max[1], bbox_min[2]:bbox_max[2]]
|
| 161 |
+
fix_vessels = resample_segmentation(fix_vessels, [64] * 3, [0, 1], 0.3, gpu=True)
|
| 162 |
+
fix_vessels = preprocess_image(fix_vessels)
|
| 163 |
+
|
| 164 |
+
mov_vessels = preprocess_image(warp_volume(fix_vessels, disp_map))
|
| 165 |
+
mov_skel = skeletonize_3d(mov_vessels)
|
| 166 |
+
### Fix the incorrect scaling ###
|
| 167 |
+
disp_map *= 2
|
| 168 |
+
bbox_size = np.asarray(bbox[3:]) # Only load the bbox size
|
| 169 |
+
rescale_factors = 64 / bbox_size
|
| 170 |
+
|
| 171 |
+
disp_map[..., 0] = np.multiply(disp_map[..., 0], rescale_factors[0])
|
| 172 |
+
disp_map[..., 1] = np.multiply(disp_map[..., 1], rescale_factors[1])
|
| 173 |
+
disp_map[..., 2] = np.multiply(disp_map[..., 2], rescale_factors[2])
|
| 174 |
+
#################################
|
| 175 |
+
|
| 176 |
+
iterator.set_description('{} ({}): getting graphs'.format(file_num, dataset_name))
|
| 177 |
+
# Prepare displacement map
|
| 178 |
+
disp_map_interpolator = build_displacement_map_interpolator(disp_map, backwards=False)
|
| 179 |
+
|
| 180 |
+
# Get skeleton and graph
|
| 181 |
+
fix_skel = skeletonize_3d(fix_vessels)
|
| 182 |
+
fix_graph = get_graph_from_skeleton(fix_skel, subsample=True)
|
| 183 |
+
mov_graph = get_graph_from_skeleton(mov_skel, subsample=True) # deform_graph(fix_graph, disp_map_interpolator)
|
| 184 |
+
|
| 185 |
+
##### TODO: ERASE Check the mov graph ######
|
| 186 |
+
# check_mov_vessels = vol_file[C.H5_MOV_VESSELS_MASK][..., 0]
|
| 187 |
+
# check_mov_vessels = preprocess_image(check_mov_vessels)
|
| 188 |
+
# check_mov_skel = skeletonize_3d(check_mov_vessels)
|
| 189 |
+
# check_mov_graph = get_graph_from_skeleton(check_mov_skel, subsample=True)
|
| 190 |
+
###########
|
| 191 |
+
fix_pts, fix_nodes, fix_edges = graph_to_ndarray(fix_graph)
|
| 192 |
+
mov_pts, mov_nodes, mov_edges = graph_to_ndarray(mov_graph)
|
| 193 |
+
|
| 194 |
+
fix_bifur_loc, fix_bifur_id = get_bifurcation_nodes(fix_graph)
|
| 195 |
+
mov_bifur_loc, mov_bifur_id = get_bifurcation_nodes(mov_graph)
|
| 196 |
+
|
| 197 |
+
iterator.set_description('{} ({}): saving data'.format(file_num, dataset_name))
|
| 198 |
+
pts_file_path, pts_file_name = os.path.split(file_path)
|
| 199 |
+
pts_file_name = pts_file_name.replace(DATASET_FILENAME, 'points')
|
| 200 |
+
pts_file_path = os.path.join(pts_file_path, pts_file_name)
|
| 201 |
+
pts_file = h5py.File(pts_file_path, 'w')
|
| 202 |
+
|
| 203 |
+
pts_file.create_dataset('fix/points', data=fix_pts)
|
| 204 |
+
pts_file.create_dataset('fix/nodes', data=fix_nodes)
|
| 205 |
+
pts_file.create_dataset('fix/edges', data=fix_edges)
|
| 206 |
+
pts_file.create_dataset('fix/bifurcations', data=fix_bifur_loc)
|
| 207 |
+
pts_file.create_dataset('fix/graph', data=fix_graph)
|
| 208 |
+
pts_file.create_dataset('fix/img', data=fix_vessels)
|
| 209 |
+
pts_file.create_dataset('fix/skeleton', data=fix_skel)
|
| 210 |
+
pts_file.create_dataset('fix/centroid', data=vol_file[C.H5_FIX_CENTROID][:])
|
| 211 |
+
|
| 212 |
+
pts_file.create_dataset('mov/points', data=mov_pts)
|
| 213 |
+
pts_file.create_dataset('mov/nodes', data=mov_nodes)
|
| 214 |
+
pts_file.create_dataset('mov/edges', data=mov_edges)
|
| 215 |
+
pts_file.create_dataset('mov/bifurcations', data=mov_bifur_loc)
|
| 216 |
+
pts_file.create_dataset('mov/graph', data=mov_graph)
|
| 217 |
+
pts_file.create_dataset('mov/img', data=mov_vessels)
|
| 218 |
+
pts_file.create_dataset('mov/skeleton', data=mov_skel)
|
| 219 |
+
pts_file.create_dataset('mov/centroid', data=vol_file[C.H5_MOV_CENTROID][:])
|
| 220 |
+
|
| 221 |
+
pts_file.create_dataset('parameters/voxel_size', data=vol_file['parameters/voxel_size'][:])
|
| 222 |
+
pts_file.create_dataset('parameters/original_affine', data=vol_file['parameters/original_affine'][:])
|
| 223 |
+
pts_file.create_dataset('parameters/isotropic_affine', data=vol_file['parameters/isotropic_affine'][:])
|
| 224 |
+
pts_file.create_dataset('parameters/original_shape', data=vol_file['parameters/original_shape'][:])
|
| 225 |
+
pts_file.create_dataset('parameters/isotropic_shape', data=vol_file['parameters/isotropic_shape'][:])
|
| 226 |
+
pts_file.create_dataset('parameters/first_reshape', data=vol_file['parameters/first_reshape'][:])
|
| 227 |
+
pts_file.create_dataset('parameters/bbox', data=vol_file['parameters/bbox'][:])
|
| 228 |
+
pts_file.create_dataset('parameters/last_reshape', data=vol_file['parameters/last_reshape'][:])
|
| 229 |
+
|
| 230 |
+
pts_file.create_dataset('displacement_map', data=disp_map)
|
| 231 |
+
|
| 232 |
+
vol_file.close()
|
| 233 |
+
pts_file.close()
|
| 234 |
+
|
| 235 |
+
iterator.set_description('{} ({}): drawing plots'.format(file_num, dataset_name))
|
| 236 |
+
num = pts_file_name.split('-')[-1].split('.hd5')[0]
|
| 237 |
+
imgs_folder = os.path.join(IMGS_FOLDER, dataset_name, num)
|
| 238 |
+
os.makedirs(imgs_folder, exist_ok=True)
|
| 239 |
+
plot_skeleton(fix_vessels, fix_skel, fix_graph, os.path.join(imgs_folder, 'fix'), ['.pdf', '.png'])
|
| 240 |
+
plot_skeleton(mov_vessels, mov_skel, mov_graph, os.path.join(imgs_folder, 'mov'), ['.pdf', '.png'])
|
| 241 |
+
iterator.set_description('{} ({})'.format(file_num, dataset_name))
|
Centerline/cpd_utils.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pycpd import DeformableRegistration, RigidRegistration
|
| 2 |
+
import numpy as np
|
| 3 |
+
import time
|
| 4 |
+
from scipy.interpolate import Rbf
|
| 5 |
+
import warnings
|
| 6 |
+
|
| 7 |
+
def cpd_non_rigid_transform_pt(pt, Y, G, W):
|
| 8 |
+
from scipy.interpolate import LinearNDInterpolator
|
| 9 |
+
interp = LinearNDInterpolator(points=Y, values=np.dot(G, W), fill_value=0.)
|
| 10 |
+
return interp(pt)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def radial_basis_function(pts, vals, function='thin-plate'):
|
| 14 |
+
# The Rbf function does not handle n-D hyper-surfaces, so we need an interpolator per displacements. Actually it does mode='N-D'
|
| 15 |
+
pts_unique, idxs = np.unique(pts, return_index=True, axis=0) # Prevent singular matrices
|
| 16 |
+
ill_conditioned = False
|
| 17 |
+
with warnings.catch_warnings(record=True) as caught_warns:
|
| 18 |
+
warnings.simplefilter('always')
|
| 19 |
+
dx = Rbf(pts_unique[:, 0], pts_unique[:, 1], pts_unique[:, 2], vals[idxs][:, 0], function=function)
|
| 20 |
+
dy = Rbf(pts_unique[:, 0], pts_unique[:, 1], pts_unique[:, 2], vals[idxs][:, 1], function=function)
|
| 21 |
+
dz = Rbf(pts_unique[:, 0], pts_unique[:, 1], pts_unique[:, 2], vals[idxs][:, 2], function=function)
|
| 22 |
+
for w in caught_warns:
|
| 23 |
+
print(w)
|
| 24 |
+
ill_conditioned = ill_conditioned or 'ill-conditioned matrix' in str(w).lower()
|
| 25 |
+
return lambda int_pt: np.asarray([dx(*int_pt), dy(*int_pt), dz(*int_pt)]), ill_conditioned
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def deform_registration(fix_pts, mov_pts, callback_fnc=None, time_it=False, max_iterations=100, tolerance=1e-8, alpha=None, beta=None):
|
| 29 |
+
deform_reg = DeformableRegistration(**{'Y': mov_pts, 'X': fix_pts},
|
| 30 |
+
alpha=alpha, beta=beta, tolerance=tolerance, max_iterations=max_iterations)
|
| 31 |
+
start_t = time.time()
|
| 32 |
+
trf_mov_pts, deform_p = deform_reg.register(callback_fnc)
|
| 33 |
+
end_t = time.time()
|
| 34 |
+
if time_it:
|
| 35 |
+
return end_t - start_t, deform_reg
|
| 36 |
+
else:
|
| 37 |
+
return trf_mov_pts, deform_p, deform_reg
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def rigid_registration(fix_pts, mov_pts, callback_fnc=None, time_it=False):
|
| 41 |
+
rigid_reg = RigidRegistration(**{'Y': mov_pts, 'X': fix_pts})
|
| 42 |
+
start_t = time.time()
|
| 43 |
+
trf_mov_pts, trf_p = rigid_reg.register(callback_fnc)
|
| 44 |
+
end_t = time.time()
|
| 45 |
+
if time_it:
|
| 46 |
+
return end_t - start_t, rigid_reg
|
| 47 |
+
else:
|
| 48 |
+
return trf_mov_pts, trf_p, rigid_reg
|
Centerline/evaluate_BayesianCPD_skeleton.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
|
| 3 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 4 |
+
parentdir = os.path.dirname(currentdir)
|
| 5 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 6 |
+
|
| 7 |
+
import h5py
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from functools import partial
|
| 10 |
+
import numpy as np
|
| 11 |
+
from scipy.spatial.distance import euclidean
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
|
| 14 |
+
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
|
| 15 |
+
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
|
| 16 |
+
from scipy.spatial.distance import cdist
|
| 17 |
+
from skimage.morphology import skeletonize_3d
|
| 18 |
+
import re
|
| 19 |
+
from probreg import bcpd
|
| 20 |
+
|
| 21 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
| 22 |
+
DATASET_NAMES = ['Affine', 'None', 'Translation']
|
| 23 |
+
DATASET_FILENAME = 'points'
|
| 24 |
+
|
| 25 |
+
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/skeleton'
|
| 26 |
+
|
| 27 |
+
SCALE = 1e-2 # mm to cm
|
| 28 |
+
# CPD PARAMS (deform)
|
| 29 |
+
MAX_ITER = 200
|
| 30 |
+
ALPHA = 0.1
|
| 31 |
+
BETA = 1.0 # None = Use default
|
| 32 |
+
TOLERANCE = 1e-8
|
| 33 |
+
|
| 34 |
+
if __name__ == '__main__':
|
| 35 |
+
for dataset_name in DATASET_NAMES:
|
| 36 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
| 37 |
+
dataset_files = os.listdir(dataset_loc)
|
| 38 |
+
dataset_files.sort()
|
| 39 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
| 40 |
+
|
| 41 |
+
iterator = tqdm(dataset_files)
|
| 42 |
+
df = pd.DataFrame(columns=['DATASET',
|
| 43 |
+
'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
|
| 44 |
+
'TIME_DEF', 'TIME_R_DEF',
|
| 45 |
+
'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
|
| 46 |
+
'TRE_DEF', 'TRE_R_DEF',
|
| 47 |
+
'DS_DISP',
|
| 48 |
+
'DATA_PATH',
|
| 49 |
+
'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
|
| 50 |
+
for i, file_path in enumerate(iterator):
|
| 51 |
+
fn = os.path.split(file_path)[-1].split('.hd5')[0]
|
| 52 |
+
fnum = int(re.findall('(\d+)', fn)[0])
|
| 53 |
+
iterator.set_description('{}: start'.format(fn))
|
| 54 |
+
pts_file = h5py.File(file_path, 'r')
|
| 55 |
+
# fix_pts = pts_file['fix/points'][:]
|
| 56 |
+
# fix_nodes = pts_file['fix/nodes'][:]
|
| 57 |
+
fix_skel = pts_file['fix/skeleton'][:]
|
| 58 |
+
fix_centroid = pts_file['fix/centroid'][:]
|
| 59 |
+
|
| 60 |
+
# mov_pts = pts_file['mov/points'][:]
|
| 61 |
+
# mov_nodes = pts_file['mov/nodes'][:]
|
| 62 |
+
mov_skel = pts_file['mov/skeleton'][:]
|
| 63 |
+
mov_centroid = pts_file['mov/centroid'][:]
|
| 64 |
+
|
| 65 |
+
bbox = pts_file['parameters/bbox'][:]
|
| 66 |
+
first_reshape = pts_file['parameters/first_reshape'][:]
|
| 67 |
+
isotropic_shape = pts_file['parameters/isotropic_shape'][:]
|
| 68 |
+
iterator.set_description('{}: Loaded data'.format(fn))
|
| 69 |
+
# TODO: bring back to original shape!
|
| 70 |
+
# Reshape to original_shape
|
| 71 |
+
# fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64]*3, first_reshape, original_shape)
|
| 72 |
+
# fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, original_shape)
|
| 73 |
+
fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
| 74 |
+
fix_skel = resize_img_to_original_space(fix_skel, bbox, first_reshape, isotropic_shape)
|
| 75 |
+
fix_skel = skeletonize_3d(fix_skel)
|
| 76 |
+
fix_skel_pts = np.argwhere(fix_skel)
|
| 77 |
+
# mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, original_shape)
|
| 78 |
+
# mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, original_shape)
|
| 79 |
+
mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
| 80 |
+
mov_skel = resize_img_to_original_space(mov_skel, bbox, first_reshape, isotropic_shape)
|
| 81 |
+
mov_skel = skeletonize_3d(mov_skel)
|
| 82 |
+
mov_skel_pts = np.argwhere(mov_skel)
|
| 83 |
+
iterator.set_description('{}: reshaped data'.format(fn))
|
| 84 |
+
|
| 85 |
+
ill_cond_def = False
|
| 86 |
+
ill_cond_r_def = False
|
| 87 |
+
# Deformable only
|
| 88 |
+
iterator.set_description('{}: Computing only deformable reg.'.format(fn))
|
| 89 |
+
|
| 90 |
+
tf_param = bcpd.registration_bcpd(mov_skel_pts*SCALE, fix_skel_pts*SCALE)
|
| 91 |
+
|
| 92 |
+
if np.isnan(deform_reg_def.diff):
|
| 93 |
+
tre_def = np.nan
|
| 94 |
+
pred_mov_centroid = mov_centroid
|
| 95 |
+
else:
|
| 96 |
+
tps, ill_cond_def = radial_basis_function(mov_skel_pts, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE)
|
| 97 |
+
displacement_mov_centroid = tps(mov_centroid)
|
| 98 |
+
pred_mov_centroid = mov_centroid + displacement_mov_centroid
|
| 99 |
+
|
| 100 |
+
tre_def = euclidean(pred_mov_centroid, fix_centroid)
|
| 101 |
+
|
| 102 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
|
| 103 |
+
os.makedirs(plot_file, exist_ok=True)
|
| 104 |
+
plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
| 105 |
+
plot_cpd(fix_skel_pts, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
| 106 |
+
|
| 107 |
+
# Rigid followed by deformable
|
| 108 |
+
iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
|
| 109 |
+
|
| 110 |
+
rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/rigid'.format(dataset_name, fnum)))
|
| 111 |
+
deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/deform'.format(dataset_name, fnum)))
|
| 112 |
+
|
| 113 |
+
# rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_pts, mov_pts, rigid_cb)
|
| 114 |
+
# deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_pts, rigid_yt, deform_cb)
|
| 115 |
+
|
| 116 |
+
time_r_def__r, rigid_reg_r_def = rigid_registration(fix_skel_pts*SCALE, mov_skel_pts*SCALE, time_it=True)
|
| 117 |
+
rigid_yt = rigid_reg_r_def.TY
|
| 118 |
+
time_r_def__def, deform_reg_r_def = deform_registration(fix_skel_pts*SCALE, rigid_yt, time_it=True,
|
| 119 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
| 120 |
+
alpha=ALPHA, beta=BETA)
|
| 121 |
+
|
| 122 |
+
if np.isnan(deform_reg_r_def.diff):
|
| 123 |
+
pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
| 124 |
+
else:
|
| 125 |
+
mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
| 126 |
+
tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE,
|
| 127 |
+
np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE)
|
| 128 |
+
displacement_mov_centroid_t = tps(mov_centroid_t)
|
| 129 |
+
pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
|
| 130 |
+
|
| 131 |
+
tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
|
| 132 |
+
dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)
|
| 133 |
+
|
| 134 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
|
| 135 |
+
os.makedirs(plot_file, exist_ok=True)
|
| 136 |
+
plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
| 137 |
+
plot_cpd(fix_skel_pts, deform_reg_r_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
iterator.set_description('{}: Saving data'.format(fn))
|
| 141 |
+
df = df.append({'DATASET': dataset_name,
|
| 142 |
+
'ITERATIONS_DEF': deform_reg_def.iteration,
|
| 143 |
+
'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
|
| 144 |
+
'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
|
| 145 |
+
'TIME_DEF': time_def,
|
| 146 |
+
'TIME_R_DEF': time_r_def__r + time_r_def__def,
|
| 147 |
+
'Q_DEF': deform_reg_def.diff,
|
| 148 |
+
'Q_R_DEF__R': rigid_reg_r_def.q,
|
| 149 |
+
'Q_R_DEF__DEF': deform_reg_r_def.diff,
|
| 150 |
+
'ILL_COND_DEF': ill_cond_def,
|
| 151 |
+
'ILL_COND_R_DEF': ill_cond_r_def,
|
| 152 |
+
'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
|
| 153 |
+
'DS_DISP':euclidean(mov_centroid, fix_centroid),
|
| 154 |
+
'DATA_PATH': file_path,
|
| 155 |
+
'DIST_CENTR': np.min(dist_centroid_to_pts),
|
| 156 |
+
'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
|
| 157 |
+
'SAMPLE_NUM':fnum}, ignore_index=True)
|
| 158 |
+
pts_file.close()
|
| 159 |
+
|
| 160 |
+
df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
|
Centerline/evaluate_CPD_dense.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 3 |
+
parentdir = os.path.dirname(currentdir)
|
| 4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 5 |
+
|
| 6 |
+
import h5py
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from functools import partial
|
| 9 |
+
import numpy as np
|
| 10 |
+
from scipy.spatial.distance import euclidean
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
|
| 13 |
+
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
|
| 14 |
+
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
|
| 15 |
+
from scipy.spatial.distance import cdist
|
| 16 |
+
import re
|
| 17 |
+
|
| 18 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
| 19 |
+
DATASET_NAMES = ['None', 'Affine', 'None', 'Translation']
|
| 20 |
+
DATASET_FILENAME = 'points'
|
| 21 |
+
|
| 22 |
+
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/dense_final'
|
| 23 |
+
|
| 24 |
+
SCALE = 1e-2 # mm to cm
|
| 25 |
+
|
| 26 |
+
# CPD PARAMS (deform)
|
| 27 |
+
MAX_ITER = 200
|
| 28 |
+
ALPHA = 2.
|
| 29 |
+
BETA = 2. # None = Use default
|
| 30 |
+
TOLERANCE = 1e-8
|
| 31 |
+
RBF_FUNCTION='thin-plate'
|
| 32 |
+
|
| 33 |
+
if __name__ == '__main__':
|
| 34 |
+
for dataset_name in DATASET_NAMES:
|
| 35 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
| 36 |
+
dataset_files = os.listdir(dataset_loc)
|
| 37 |
+
dataset_files.sort()
|
| 38 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
| 39 |
+
|
| 40 |
+
iterator = tqdm(dataset_files)
|
| 41 |
+
df = pd.DataFrame(columns=['DATASET',
|
| 42 |
+
'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
|
| 43 |
+
'TIME_DEF', 'TIME_R_DEF',
|
| 44 |
+
'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
|
| 45 |
+
'TRE_DEF', 'TRE_R_DEF',
|
| 46 |
+
'DS_DISP',
|
| 47 |
+
'DATA_PATH',
|
| 48 |
+
'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
|
| 49 |
+
for i, file_path in enumerate(iterator):
|
| 50 |
+
fn = os.path.split(file_path)[-1].split('.hd5')[0]
|
| 51 |
+
fnum = int(re.findall('(\d+)', fn)[0])
|
| 52 |
+
iterator.set_description('{}: start'.format(fn))
|
| 53 |
+
pts_file = h5py.File(file_path, 'r')
|
| 54 |
+
fix_pts = pts_file['fix/points'][:]
|
| 55 |
+
# fix_nodes = pts_file['fix/nodes'][:]
|
| 56 |
+
fix_centroid = pts_file['fix/centroid'][:]
|
| 57 |
+
|
| 58 |
+
mov_pts = pts_file['mov/points'][:]
|
| 59 |
+
# mov_nodes = pts_file['mov/nodes'][:]
|
| 60 |
+
mov_centroid = pts_file['mov/centroid'][:]
|
| 61 |
+
|
| 62 |
+
bbox = pts_file['parameters/bbox'][:]
|
| 63 |
+
first_reshape = pts_file['parameters/first_reshape'][:]
|
| 64 |
+
original_shape = pts_file['parameters/isotropic_shape'][:]
|
| 65 |
+
iterator.set_description('{}: Loaded data'.format(fn))
|
| 66 |
+
# TODO: bring back to original shape!
|
| 67 |
+
# Reshape to original_shape
|
| 68 |
+
# fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64]*3, first_reshape, original_shape)
|
| 69 |
+
fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, original_shape)
|
| 70 |
+
fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, original_shape)
|
| 71 |
+
# mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, original_shape)
|
| 72 |
+
mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, original_shape)
|
| 73 |
+
mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, original_shape)
|
| 74 |
+
iterator.set_description('{}: reshaped data'.format(fn))
|
| 75 |
+
|
| 76 |
+
ill_cond_def = False
|
| 77 |
+
ill_cond_r_def = False
|
| 78 |
+
# Deformable only
|
| 79 |
+
iterator.set_description('{}: Computing only deformable reg.'.format(fn))
|
| 80 |
+
|
| 81 |
+
# deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum)))
|
| 82 |
+
|
| 83 |
+
# _, _, deform_reg_def = deform_registration(fix_pts, mov_pts, deform_cb)
|
| 84 |
+
time_def, deform_reg_def = deform_registration(fix_pts*SCALE, mov_pts*SCALE, time_it=True,
|
| 85 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
| 86 |
+
alpha=ALPHA, beta=BETA)
|
| 87 |
+
if np.isnan(deform_reg_def.diff):
|
| 88 |
+
tre_def = np.nan
|
| 89 |
+
pred_mov_centroid = np.zeros((3,))
|
| 90 |
+
else:
|
| 91 |
+
tps, ill_cond_def = radial_basis_function(mov_pts, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
|
| 92 |
+
displacement_mov_centroid = tps(mov_centroid)
|
| 93 |
+
pred_mov_centroid = mov_centroid + displacement_mov_centroid
|
| 94 |
+
|
| 95 |
+
tre_def = euclidean(pred_mov_centroid, fix_centroid)
|
| 96 |
+
|
| 97 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
|
| 98 |
+
os.makedirs(plot_file, exist_ok=True)
|
| 99 |
+
plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
| 100 |
+
plot_cpd(fix_pts, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
| 101 |
+
|
| 102 |
+
# Rigid followed by deformable
|
| 103 |
+
iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
|
| 104 |
+
|
| 105 |
+
# rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER,
|
| 106 |
+
# '{}/{:04d}/RIGID_DEF/rigid'.format(
|
| 107 |
+
# dataset_name, fnum)))
|
| 108 |
+
# deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER,
|
| 109 |
+
# '{}/{:04d}/RIGID_DEF/deform'.format(
|
| 110 |
+
# dataset_name, fnum)))
|
| 111 |
+
# rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_pts, mov_pts, rigid_cb)
|
| 112 |
+
# deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_pts, rigid_yt, deform_cb)
|
| 113 |
+
|
| 114 |
+
time_r_def__r, rigid_reg_r_def = rigid_registration(fix_pts*SCALE, mov_pts*SCALE, time_it=True)
|
| 115 |
+
rigid_yt = rigid_reg_r_def.TY
|
| 116 |
+
time_r_def__def, deform_reg_r_def = deform_registration(fix_pts*SCALE, rigid_yt, time_it=True,
|
| 117 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
| 118 |
+
alpha=ALPHA, beta=BETA)
|
| 119 |
+
|
| 120 |
+
if np.isnan(deform_reg_r_def.diff):
|
| 121 |
+
pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
| 122 |
+
else:
|
| 123 |
+
mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
| 124 |
+
tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE, np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
|
| 125 |
+
displacement_mov_centroid_t = tps(mov_centroid_t)
|
| 126 |
+
pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
|
| 127 |
+
|
| 128 |
+
tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
|
| 129 |
+
dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_pts)
|
| 130 |
+
|
| 131 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
|
| 132 |
+
os.makedirs(plot_file, exist_ok=True)
|
| 133 |
+
plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
| 134 |
+
plot_cpd(fix_pts, deform_reg_r_def.TY, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
| 135 |
+
|
| 136 |
+
iterator.set_description('{}: Saving data'.format(fn))
|
| 137 |
+
df = df.append({'DATASET': dataset_name,
|
| 138 |
+
'ITERATIONS_DEF': deform_reg_def.iteration,
|
| 139 |
+
'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
|
| 140 |
+
'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
|
| 141 |
+
'TIME_DEF': time_def,
|
| 142 |
+
'TIME_R_DEF': time_r_def__r + time_r_def__def,
|
| 143 |
+
'Q_DEF': deform_reg_def.diff,
|
| 144 |
+
'Q_R_DEF__R': rigid_reg_r_def.q,
|
| 145 |
+
'Q_R_DEF__DEF': deform_reg_r_def.diff,
|
| 146 |
+
'ILL_COND_DEF': ill_cond_def,
|
| 147 |
+
'ILL_COND_R_DEF': ill_cond_r_def,
|
| 148 |
+
'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
|
| 149 |
+
'DS_DISP':euclidean(mov_centroid, fix_centroid),
|
| 150 |
+
'DATA_PATH': file_path,
|
| 151 |
+
'DIST_CENTR': np.min(dist_centroid_to_pts),
|
| 152 |
+
'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
|
| 153 |
+
'SAMPLE_NUM': fnum}, ignore_index=True)
|
| 154 |
+
pts_file.close()
|
| 155 |
+
|
| 156 |
+
df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
|
Centerline/evaluate_CPD_nodes.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 3 |
+
parentdir = os.path.dirname(currentdir)
|
| 4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 5 |
+
|
| 6 |
+
import h5py
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from functools import partial
|
| 9 |
+
import numpy as np
|
| 10 |
+
from scipy.spatial.distance import euclidean
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
|
| 13 |
+
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
|
| 14 |
+
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
|
| 15 |
+
from scipy.spatial.distance import cdist
|
| 16 |
+
import re
|
| 17 |
+
|
| 18 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
| 19 |
+
DATASET_NAMES = ['Affine', 'None', 'Translation']
|
| 20 |
+
DATASET_FILENAME = 'points'
|
| 21 |
+
|
| 22 |
+
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/nodes_final'
|
| 23 |
+
|
| 24 |
+
SCALE = 1e-2 # mm to cm
|
| 25 |
+
|
| 26 |
+
# CPD PARAMS (deform)
|
| 27 |
+
MAX_ITER = 200
|
| 28 |
+
ALPHA = 2.
|
| 29 |
+
BETA = 2. # None = Use default
|
| 30 |
+
TOLERANCE = 1e-8
|
| 31 |
+
RBF_FUNCTION='thin-plate'
|
| 32 |
+
|
| 33 |
+
if __name__ == '__main__':
|
| 34 |
+
for dataset_name in DATASET_NAMES:
|
| 35 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
| 36 |
+
dataset_files = os.listdir(dataset_loc)
|
| 37 |
+
dataset_files.sort()
|
| 38 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
| 39 |
+
|
| 40 |
+
iterator = tqdm(dataset_files)
|
| 41 |
+
df = pd.DataFrame(columns=['DATASET',
|
| 42 |
+
'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
|
| 43 |
+
'TIME_DEF', 'TIME_R_DEF',
|
| 44 |
+
'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
|
| 45 |
+
'TRE_DEF', 'TRE_R_DEF',
|
| 46 |
+
'DS_DISP',
|
| 47 |
+
'DATA_PATH',
|
| 48 |
+
'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
|
| 49 |
+
for i, file_path in enumerate(iterator):
|
| 50 |
+
fn = os.path.split(file_path)[-1].split('.hd5')[0]
|
| 51 |
+
fnum = int(re.findall('(\d+)', fn)[0])
|
| 52 |
+
iterator.set_description('{}: start'.format(fn))
|
| 53 |
+
pts_file = h5py.File(file_path, 'r')
|
| 54 |
+
fix_pts = pts_file['fix/points'][:]
|
| 55 |
+
fix_nodes = pts_file['fix/nodes'][:]
|
| 56 |
+
fix_centroid = pts_file['fix/centroid'][:]
|
| 57 |
+
|
| 58 |
+
mov_pts = pts_file['mov/points'][:]
|
| 59 |
+
mov_nodes = pts_file['mov/nodes'][:]
|
| 60 |
+
mov_centroid = pts_file['mov/centroid'][:]
|
| 61 |
+
|
| 62 |
+
bbox = pts_file['parameters/bbox'][:]
|
| 63 |
+
first_reshape = pts_file['parameters/first_reshape'][:]
|
| 64 |
+
isotropic_shape = pts_file['parameters/isotropic_shape'][:]
|
| 65 |
+
iterator.set_description('{}: Loaded data'.format(fn))
|
| 66 |
+
# TODO: bring back to original shape!
|
| 67 |
+
# Reshape to original_shape
|
| 68 |
+
fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64] * 3, first_reshape, isotropic_shape)
|
| 69 |
+
fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, isotropic_shape)
|
| 70 |
+
fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
| 71 |
+
mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, isotropic_shape)
|
| 72 |
+
mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, isotropic_shape)
|
| 73 |
+
mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
| 74 |
+
iterator.set_description('{}: reshaped data'.format(fn))
|
| 75 |
+
|
| 76 |
+
if mov_nodes.shape[0] == 1:
|
| 77 |
+
# Otherwise we only have a point, and CPD can't handle that... absurd!
|
| 78 |
+
fix_nodes = fix_pts
|
| 79 |
+
mov_nodes = mov_pts
|
| 80 |
+
|
| 81 |
+
ill_cond_def = False
|
| 82 |
+
ill_cond_r_def = False
|
| 83 |
+
# Deformable only
|
| 84 |
+
iterator.set_description('{}: Computing only deformable reg.'.format(fn))
|
| 85 |
+
|
| 86 |
+
# deform_cb = partial(plot_cpd_registration_step,
|
| 87 |
+
# out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum)))
|
| 88 |
+
|
| 89 |
+
# _, _, deform_reg_def = deform_registration(fix_nodes, mov_nodes, deform_cb)
|
| 90 |
+
time_def, deform_reg_def = deform_registration(fix_nodes*SCALE, mov_nodes*SCALE, time_it=True,
|
| 91 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
| 92 |
+
alpha=ALPHA, beta=BETA)
|
| 93 |
+
if np.isnan(deform_reg_def.diff):
|
| 94 |
+
tre_def = np.nan
|
| 95 |
+
pred_mov_centroid = np.zeros((3,))
|
| 96 |
+
else:
|
| 97 |
+
tps, ill_cond_def = radial_basis_function(mov_nodes, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
|
| 98 |
+
displacement_mov_centroid = tps(mov_centroid)
|
| 99 |
+
pred_mov_centroid = mov_centroid + displacement_mov_centroid
|
| 100 |
+
|
| 101 |
+
tre_def = euclidean(pred_mov_centroid, fix_centroid)
|
| 102 |
+
|
| 103 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
|
| 104 |
+
os.makedirs(plot_file, exist_ok=True)
|
| 105 |
+
plot_cpd(fix_nodes, mov_nodes, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
| 106 |
+
plot_cpd(fix_nodes, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
| 107 |
+
|
| 108 |
+
# Rigid followed by deformable
|
| 109 |
+
iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
|
| 110 |
+
|
| 111 |
+
# rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/rigid'.format(dataset_name, fnum)))
|
| 112 |
+
# deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/deform'.format(dataset_name, fnum)))
|
| 113 |
+
# rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_nodes, mov_nodes, rigid_cb)
|
| 114 |
+
# deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_nodes, rigid_yt, deform_cb)
|
| 115 |
+
|
| 116 |
+
time_r_def__r, rigid_reg_r_def = rigid_registration(fix_nodes*SCALE, mov_nodes*SCALE, time_it=True)
|
| 117 |
+
rigid_yt = rigid_reg_r_def.TY
|
| 118 |
+
time_r_def__def, deform_reg_r_def = deform_registration(fix_nodes*SCALE, rigid_yt, time_it=True,
|
| 119 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
| 120 |
+
alpha=ALPHA, beta=BETA)
|
| 121 |
+
|
| 122 |
+
if np.isnan(deform_reg_r_def.diff):
|
| 123 |
+
pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
| 124 |
+
else:
|
| 125 |
+
mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
| 126 |
+
tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE, np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
|
| 127 |
+
displacement_mov_centroid_t = tps(mov_centroid_t)
|
| 128 |
+
pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
|
| 129 |
+
|
| 130 |
+
tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
|
| 131 |
+
dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_nodes)
|
| 132 |
+
|
| 133 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
|
| 134 |
+
os.makedirs(plot_file, exist_ok=True)
|
| 135 |
+
plot_cpd(fix_nodes, mov_nodes, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
| 136 |
+
plot_cpd(fix_nodes, deform_reg_r_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
| 137 |
+
|
| 138 |
+
iterator.set_description('{}: Saving data'.format(fn))
|
| 139 |
+
df = df.append({'DATASET':dataset_name,
|
| 140 |
+
'ITERATIONS_DEF': deform_reg_def.iteration,
|
| 141 |
+
'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
|
| 142 |
+
'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
|
| 143 |
+
'TIME_DEF': time_def,
|
| 144 |
+
'TIME_R_DEF': time_r_def__r + time_r_def__def,
|
| 145 |
+
'Q_DEF': deform_reg_def.diff,
|
| 146 |
+
'Q_R_DEF__R': rigid_reg_r_def.q,
|
| 147 |
+
'Q_R_DEF__DEF': deform_reg_r_def.diff,
|
| 148 |
+
'ILL_COND_DEF': ill_cond_def,
|
| 149 |
+
'ILL_COND_R_DEF': ill_cond_r_def,
|
| 150 |
+
'TRE_DEF':tre_def, 'TRE_R_DEF':tre_r_def,
|
| 151 |
+
'DS_DISP':euclidean(mov_centroid, fix_centroid),
|
| 152 |
+
'DATA_PATH':file_path,
|
| 153 |
+
'DIST_CENTR':np.min(dist_centroid_to_pts),
|
| 154 |
+
'DIST_CENTR_DEF_95':np.percentile(dist_centroid_to_pts, 95),
|
| 155 |
+
'SAMPLE_NUM':fnum}, ignore_index=True)
|
| 156 |
+
pts_file.close()
|
| 157 |
+
|
| 158 |
+
df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
|
Centerline/evaluate_CPD_skeleton.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
|
| 3 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 4 |
+
parentdir = os.path.dirname(currentdir)
|
| 5 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 6 |
+
|
| 7 |
+
import h5py
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from functools import partial
|
| 10 |
+
import numpy as np
|
| 11 |
+
from scipy.spatial.distance import euclidean
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
|
| 14 |
+
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
|
| 15 |
+
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
|
| 16 |
+
from scipy.spatial.distance import cdist
|
| 17 |
+
from skimage.morphology import skeletonize_3d
|
| 18 |
+
import re
|
| 19 |
+
|
| 20 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
| 21 |
+
DATASET_NAMES = ['Affine', 'None', 'Translation']
|
| 22 |
+
DATASET_FILENAME = 'points'
|
| 23 |
+
|
| 24 |
+
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/skeleton'
|
| 25 |
+
|
| 26 |
+
SCALE = 1e-2 # mm to cm
|
| 27 |
+
# CPD PARAMS (deform)
|
| 28 |
+
MAX_ITER = 200
|
| 29 |
+
ALPHA = 0.1
|
| 30 |
+
BETA = 1.0 # None = Use default
|
| 31 |
+
TOLERANCE = 1e-8
|
| 32 |
+
|
| 33 |
+
if __name__ == '__main__':
|
| 34 |
+
for dataset_name in DATASET_NAMES:
|
| 35 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
| 36 |
+
dataset_files = os.listdir(dataset_loc)
|
| 37 |
+
dataset_files.sort()
|
| 38 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
| 39 |
+
|
| 40 |
+
iterator = tqdm(dataset_files)
|
| 41 |
+
df = pd.DataFrame(columns=['DATASET',
|
| 42 |
+
'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
|
| 43 |
+
'TIME_DEF', 'TIME_R_DEF',
|
| 44 |
+
'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
|
| 45 |
+
'TRE_DEF', 'TRE_R_DEF',
|
| 46 |
+
'DS_DISP',
|
| 47 |
+
'DATA_PATH',
|
| 48 |
+
'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
|
| 49 |
+
for i, file_path in enumerate(iterator):
|
| 50 |
+
fn = os.path.split(file_path)[-1].split('.hd5')[0]
|
| 51 |
+
fnum = int(re.findall('(\d+)', fn)[0])
|
| 52 |
+
iterator.set_description('{}: start'.format(fn))
|
| 53 |
+
pts_file = h5py.File(file_path, 'r')
|
| 54 |
+
# fix_pts = pts_file['fix/points'][:]
|
| 55 |
+
# fix_nodes = pts_file['fix/nodes'][:]
|
| 56 |
+
fix_skel = pts_file['fix/skeleton'][:]
|
| 57 |
+
fix_centroid = pts_file['fix/centroid'][:]
|
| 58 |
+
|
| 59 |
+
# mov_pts = pts_file['mov/points'][:]
|
| 60 |
+
# mov_nodes = pts_file['mov/nodes'][:]
|
| 61 |
+
mov_skel = pts_file['mov/skeleton'][:]
|
| 62 |
+
mov_centroid = pts_file['mov/centroid'][:]
|
| 63 |
+
|
| 64 |
+
bbox = pts_file['parameters/bbox'][:]
|
| 65 |
+
first_reshape = pts_file['parameters/first_reshape'][:]
|
| 66 |
+
isotropic_shape = pts_file['parameters/isotropic_shape'][:]
|
| 67 |
+
iterator.set_description('{}: Loaded data'.format(fn))
|
| 68 |
+
# TODO: bring back to original shape!
|
| 69 |
+
# Reshape to original_shape
|
| 70 |
+
# fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64]*3, first_reshape, original_shape)
|
| 71 |
+
# fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, original_shape)
|
| 72 |
+
fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
| 73 |
+
fix_skel = resize_img_to_original_space(fix_skel, bbox, first_reshape, isotropic_shape)
|
| 74 |
+
fix_skel = skeletonize_3d(fix_skel)
|
| 75 |
+
fix_skel_pts = np.argwhere(fix_skel)
|
| 76 |
+
# mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, original_shape)
|
| 77 |
+
# mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, original_shape)
|
| 78 |
+
mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
| 79 |
+
mov_skel = resize_img_to_original_space(mov_skel, bbox, first_reshape, isotropic_shape)
|
| 80 |
+
mov_skel = skeletonize_3d(mov_skel)
|
| 81 |
+
mov_skel_pts = np.argwhere(mov_skel)
|
| 82 |
+
iterator.set_description('{}: reshaped data'.format(fn))
|
| 83 |
+
|
| 84 |
+
ill_cond_def = False
|
| 85 |
+
ill_cond_r_def = False
|
| 86 |
+
# Deformable only
|
| 87 |
+
iterator.set_description('{}: Computing only deformable reg.'.format(fn))
|
| 88 |
+
|
| 89 |
+
deform_cb = partial(plot_cpd_registration_step,
|
| 90 |
+
out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum)))
|
| 91 |
+
|
| 92 |
+
# _, _, deform_reg_def = deform_registration(fix_pts, mov_pts, deform_cb)
|
| 93 |
+
time_def, deform_reg_def = deform_registration(fix_skel_pts*SCALE, mov_skel_pts*SCALE, time_it=True,
|
| 94 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
| 95 |
+
alpha=ALPHA, beta=BETA)
|
| 96 |
+
if np.isnan(deform_reg_def.diff):
|
| 97 |
+
tre_def = np.nan
|
| 98 |
+
pred_mov_centroid = mov_centroid
|
| 99 |
+
else:
|
| 100 |
+
tps, ill_cond_def = radial_basis_function(mov_skel_pts, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE)
|
| 101 |
+
displacement_mov_centroid = tps(mov_centroid)
|
| 102 |
+
pred_mov_centroid = mov_centroid + displacement_mov_centroid
|
| 103 |
+
|
| 104 |
+
tre_def = euclidean(pred_mov_centroid, fix_centroid)
|
| 105 |
+
|
| 106 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
|
| 107 |
+
os.makedirs(plot_file, exist_ok=True)
|
| 108 |
+
plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
| 109 |
+
plot_cpd(fix_skel_pts, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
| 110 |
+
|
| 111 |
+
# Rigid followed by deformable
|
| 112 |
+
iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
|
| 113 |
+
|
| 114 |
+
rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/rigid'.format(dataset_name, fnum)))
|
| 115 |
+
deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/deform'.format(dataset_name, fnum)))
|
| 116 |
+
|
| 117 |
+
# rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_pts, mov_pts, rigid_cb)
|
| 118 |
+
# deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_pts, rigid_yt, deform_cb)
|
| 119 |
+
|
| 120 |
+
time_r_def__r, rigid_reg_r_def = rigid_registration(fix_skel_pts*SCALE, mov_skel_pts*SCALE, time_it=True)
|
| 121 |
+
rigid_yt = rigid_reg_r_def.TY
|
| 122 |
+
time_r_def__def, deform_reg_r_def = deform_registration(fix_skel_pts*SCALE, rigid_yt, time_it=True,
|
| 123 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
| 124 |
+
alpha=ALPHA, beta=BETA)
|
| 125 |
+
|
| 126 |
+
if np.isnan(deform_reg_r_def.diff):
|
| 127 |
+
pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
| 128 |
+
else:
|
| 129 |
+
mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
| 130 |
+
tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE,
|
| 131 |
+
np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE)
|
| 132 |
+
displacement_mov_centroid_t = tps(mov_centroid_t)
|
| 133 |
+
pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
|
| 134 |
+
|
| 135 |
+
tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
|
| 136 |
+
dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)
|
| 137 |
+
|
| 138 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
|
| 139 |
+
os.makedirs(plot_file, exist_ok=True)
|
| 140 |
+
plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
| 141 |
+
plot_cpd(fix_skel_pts, deform_reg_r_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
iterator.set_description('{}: Saving data'.format(fn))
|
| 145 |
+
df = df.append({'DATASET': dataset_name,
|
| 146 |
+
'ITERATIONS_DEF': deform_reg_def.iteration,
|
| 147 |
+
'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
|
| 148 |
+
'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
|
| 149 |
+
'TIME_DEF': time_def,
|
| 150 |
+
'TIME_R_DEF': time_r_def__r + time_r_def__def,
|
| 151 |
+
'Q_DEF': deform_reg_def.diff,
|
| 152 |
+
'Q_R_DEF__R': rigid_reg_r_def.q,
|
| 153 |
+
'Q_R_DEF__DEF': deform_reg_r_def.diff,
|
| 154 |
+
'ILL_COND_DEF': ill_cond_def,
|
| 155 |
+
'ILL_COND_R_DEF': ill_cond_r_def,
|
| 156 |
+
'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
|
| 157 |
+
'DS_DISP':euclidean(mov_centroid, fix_centroid),
|
| 158 |
+
'DATA_PATH': file_path,
|
| 159 |
+
'DIST_CENTR': np.min(dist_centroid_to_pts),
|
| 160 |
+
'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
|
| 161 |
+
'SAMPLE_NUM':fnum}, ignore_index=True)
|
| 162 |
+
pts_file.close()
|
| 163 |
+
|
| 164 |
+
df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
|
Centerline/get_vessels.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import os
|
| 4 |
+
import h5py
|
| 5 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
| 6 |
+
|
| 7 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
| 8 |
+
DATASET_NAMES = ['Affine', 'None', 'Translation']
|
| 9 |
+
DATASET_FILENAME = 'volume'
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
if __name__ == '__main__':
|
| 13 |
+
for dataset_name in DATASET_NAMES:
|
| 14 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
| 15 |
+
dataset_files = os.listdir(dataset_loc)
|
| 16 |
+
dataset_files.sort()
|
| 17 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
| 18 |
+
|
| 19 |
+
iterator = tqdm(dataset_files)
|
| 20 |
+
for fn in iterator:
|
| 21 |
+
f = os.path.split(fn)[-1].split('.hd5')[0]
|
| 22 |
+
vol_file = h5py.File(fn, 'r')
|
| 23 |
+
fix_vessels = vol_file[C.H5_FIX_VESSELS_MASK][..., 0]
|
| 24 |
+
mov_vessels = vol_file[C.H5_MOV_VESSELS_MASK][..., 0]
|
| 25 |
+
|
| 26 |
+
dst_folder = os.path.join(os.getcwd(), 'VESSELS', dataset_name)
|
| 27 |
+
os.makedirs(dst_folder, exist_ok=True)
|
| 28 |
+
save_nifti(fix_vessels, os.path.join(dst_folder, f+'_fix.nii.gz'))
|
| 29 |
+
save_nifti(mov_vessels, os.path.join(dst_folder, f+'_mov.nii.gz'))
|
| 30 |
+
vol_file.close()
|
Centerline/graph_utils.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import networkx as nx
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy.interpolate import RegularGridInterpolator, LinearNDInterpolator
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def graph_to_ndarray(graph):
|
| 7 |
+
out_nodes = np.empty((1, 3))
|
| 8 |
+
out_edges = np.empty((1, 3))
|
| 9 |
+
visited_nodes = list()
|
| 10 |
+
visited_node_pairs = list()
|
| 11 |
+
for (start_node, end_node) in graph.edges():
|
| 12 |
+
if (not (start_node, end_node) in visited_node_pairs) and (not (end_node, start_node) in visited_node_pairs):
|
| 13 |
+
edge = graph[start_node][end_node]['pts']
|
| 14 |
+
out_edges = np.vstack([out_edges, edge])
|
| 15 |
+
|
| 16 |
+
# Avoid duplicates
|
| 17 |
+
if not (start_node in visited_nodes):
|
| 18 |
+
out_nodes = np.vstack([out_nodes, graph.nodes[start_node]['o']])
|
| 19 |
+
visited_nodes.append(start_node)
|
| 20 |
+
if not (end_node in visited_nodes):
|
| 21 |
+
out_nodes = np.vstack([out_nodes, graph.nodes[end_node]['o']])
|
| 22 |
+
visited_nodes.append(end_node)
|
| 23 |
+
|
| 24 |
+
visited_node_pairs.append((start_node, end_node))
|
| 25 |
+
|
| 26 |
+
return np.vstack([out_edges, out_nodes]), out_nodes, out_edges
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_bifurcation_nodes(graph: nx.Graph):
|
| 30 |
+
# Vertex degree relates to the number of branches connected to a given node
|
| 31 |
+
out_nodes = np.empty((1, 3))
|
| 32 |
+
bif_nodes_id = list()
|
| 33 |
+
for node_num, deg in graph.degree:
|
| 34 |
+
if deg > 1:
|
| 35 |
+
bif_nodes_id.append(node_num)
|
| 36 |
+
out_nodes = np.vstack([out_nodes, graph.nodes[node_num]['o']])
|
| 37 |
+
|
| 38 |
+
return out_nodes, bif_nodes_id
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def apply_displacement(pts_list: np.ndarray, interpolator: [RegularGridInterpolator, LinearNDInterpolator]):
|
| 42 |
+
pts_list = pts_list.astype(np.float)
|
| 43 |
+
ret_val = pts_list + interpolator(pts_list).squeeze()
|
| 44 |
+
return ret_val
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def deform_graph(graph, dm_interpolator: [RegularGridInterpolator, LinearNDInterpolator]):
|
| 48 |
+
def_graph = nx.Graph()
|
| 49 |
+
for (start_node, end_node) in graph.edges():
|
| 50 |
+
edge = graph[start_node][end_node]['pts']
|
| 51 |
+
def_edge = apply_displacement(edge, dm_interpolator)
|
| 52 |
+
|
| 53 |
+
def_start_node_pts = apply_displacement(graph.nodes[start_node]['pts'], dm_interpolator)
|
| 54 |
+
def_end_node_pts = apply_displacement(graph.nodes[end_node]['pts'], dm_interpolator)
|
| 55 |
+
|
| 56 |
+
def_start_node_o = apply_displacement(graph.nodes[start_node]['o'], dm_interpolator)
|
| 57 |
+
def_end_node_o = apply_displacement(graph.nodes[end_node]['o'], dm_interpolator)
|
| 58 |
+
|
| 59 |
+
def_graph.add_node(start_node, pts=def_start_node_pts, o=def_start_node_o)
|
| 60 |
+
def_graph.add_node(end_node, pts=def_end_node_pts, o=def_end_node_o)
|
| 61 |
+
def_graph.add_edge(start_node, end_node, pts=def_edge, weight=len(def_edge))
|
| 62 |
+
return def_graph
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def subsample_graph(graph: nx.Graph, num_samples=3):
|
| 66 |
+
sub_graph = nx.Graph()
|
| 67 |
+
for (start_node, end_node) in graph.edges():
|
| 68 |
+
edge = graph[start_node][end_node]['pts']
|
| 69 |
+
edge_len = edge.shape[0]
|
| 70 |
+
sub_edge_len = (edge_len - 2) // num_samples # Do not count the pts corresponding to the nodes (-2)
|
| 71 |
+
|
| 72 |
+
sub_edge = [edge[0]]
|
| 73 |
+
include_last = bool((edge_len - 2) % num_samples) # Skip the last point, as this is too close to the node
|
| 74 |
+
if sub_edge_len:
|
| 75 |
+
idxs = np.arange(0, edge_len, num_samples)[1:] if include_last else np.arange(0, edge_len, num_samples)[1:-1]
|
| 76 |
+
for i in idxs:
|
| 77 |
+
sub_edge.append(edge[i])
|
| 78 |
+
|
| 79 |
+
sub_edge.append(edge[-1])
|
| 80 |
+
sub_edge = np.asarray(sub_edge)
|
| 81 |
+
sub_graph.add_node(start_node, pts=graph.nodes[start_node]['pts'], o=graph.nodes[start_node]['o'])
|
| 82 |
+
sub_graph.add_node(end_node, pts=graph.nodes[end_node]['pts'], o=graph.nodes[end_node]['o'])
|
| 83 |
+
sub_graph.add_edge(start_node, end_node, pts=sub_edge, weight=len(sub_edge))
|
| 84 |
+
return sub_graph
|
| 85 |
+
|
Centerline/skeleton_to_graph.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SRC: https://github.com/Image-Py/sknw/blob/master/sknw/sknw.py
|
| 2 |
+
import numpy as np
|
| 3 |
+
import networkx as nx
|
| 4 |
+
from Centerline.graph_utils import subsample_graph
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def neighbors(shape):
|
| 8 |
+
dim = len(shape)
|
| 9 |
+
block = np.ones([3] * dim)
|
| 10 |
+
block[tuple([1] * dim)] = 0
|
| 11 |
+
idx = np.where(block > 0)
|
| 12 |
+
idx = np.array(idx, dtype=np.uint8).T
|
| 13 |
+
idx = np.array(idx - [1] * dim)
|
| 14 |
+
acc = np.cumprod((1,) + shape[::-1][:-1])
|
| 15 |
+
return np.dot(idx, acc[::-1])
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# my mark
|
| 19 |
+
def mark(img, nbs): # mark the array use (0, 1, 2)
|
| 20 |
+
img = img.ravel()
|
| 21 |
+
for p in range(len(img)):
|
| 22 |
+
if img[p] == 0: continue
|
| 23 |
+
s = 0
|
| 24 |
+
for dp in nbs:
|
| 25 |
+
if img[p + dp] != 0: s += 1
|
| 26 |
+
if s == 2:
|
| 27 |
+
img[p] = 1
|
| 28 |
+
else:
|
| 29 |
+
img[p] = 2
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# trans index to r, c...
|
| 33 |
+
def idx2rc(idx, acc):
|
| 34 |
+
rst = np.zeros((len(idx), len(acc)), dtype=np.int16)
|
| 35 |
+
for i in range(len(idx)):
|
| 36 |
+
for j in range(len(acc)):
|
| 37 |
+
rst[i, j] = idx[i] // acc[j]
|
| 38 |
+
idx[i] -= rst[i, j] * acc[j]
|
| 39 |
+
rst -= 1
|
| 40 |
+
return rst
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# fill a node (may be two or more points)
|
| 44 |
+
def fill(img, p, num, nbs, acc, buf):
|
| 45 |
+
back = img[p]
|
| 46 |
+
img[p] = num
|
| 47 |
+
buf[0] = p
|
| 48 |
+
cur = 0;
|
| 49 |
+
s = 1;
|
| 50 |
+
|
| 51 |
+
while True:
|
| 52 |
+
p = buf[cur]
|
| 53 |
+
for dp in nbs:
|
| 54 |
+
cp = p + dp
|
| 55 |
+
if img[cp] == back:
|
| 56 |
+
img[cp] = num
|
| 57 |
+
buf[s] = cp
|
| 58 |
+
s += 1
|
| 59 |
+
cur += 1
|
| 60 |
+
if cur == s: break
|
| 61 |
+
return idx2rc(buf[:s], acc)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# trace the edge and use a buffer, then buf.copy, if use [] numba not works
|
| 65 |
+
def trace(img, p, nbs, acc, buf):
|
| 66 |
+
c1 = 0;
|
| 67 |
+
c2 = 0;
|
| 68 |
+
newp = 0
|
| 69 |
+
cur = 0
|
| 70 |
+
|
| 71 |
+
while True:
|
| 72 |
+
buf[cur] = p
|
| 73 |
+
img[p] = 0
|
| 74 |
+
cur += 1
|
| 75 |
+
for dp in nbs:
|
| 76 |
+
cp = p + dp
|
| 77 |
+
if img[cp] >= 10:
|
| 78 |
+
if c1 == 0:
|
| 79 |
+
c1 = img[cp]
|
| 80 |
+
else:
|
| 81 |
+
c2 = img[cp]
|
| 82 |
+
if img[cp] == 1:
|
| 83 |
+
newp = cp
|
| 84 |
+
p = newp
|
| 85 |
+
if c2 != 0: break
|
| 86 |
+
return (c1 - 10, c2 - 10, idx2rc(buf[:cur], acc))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# parse the image then get the nodes and edges
|
| 90 |
+
def parse_struc(img, pts, nbs, acc):
|
| 91 |
+
img = img.ravel()
|
| 92 |
+
buf = np.zeros(131072, dtype=np.int64)
|
| 93 |
+
num = 10
|
| 94 |
+
nodes = []
|
| 95 |
+
for p in pts:
|
| 96 |
+
if img[p] == 2:
|
| 97 |
+
nds = fill(img, p, num, nbs, acc, buf)
|
| 98 |
+
num += 1
|
| 99 |
+
nodes.append(nds)
|
| 100 |
+
|
| 101 |
+
edges = []
|
| 102 |
+
for p in pts:
|
| 103 |
+
for dp in nbs:
|
| 104 |
+
if img[p + dp] == 1:
|
| 105 |
+
edge = trace(img, p + dp, nbs, acc, buf)
|
| 106 |
+
edges.append(edge)
|
| 107 |
+
return nodes, edges
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# use nodes and edges build a networkx graph
|
| 111 |
+
def build_graph(nodes, edges, multi=False):
|
| 112 |
+
graph = nx.MultiGraph() if multi else nx.Graph()
|
| 113 |
+
for i in range(len(nodes)):
|
| 114 |
+
graph.add_node(i, pts=nodes[i], o=nodes[i].mean(axis=0))
|
| 115 |
+
for s, e, pts in edges:
|
| 116 |
+
l = np.linalg.norm(pts[1:] - pts[:-1], axis=1).sum()
|
| 117 |
+
graph.add_edge(s, e, pts=pts, weight=l)
|
| 118 |
+
return graph
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def buffer(ske):
|
| 122 |
+
buf = np.zeros(tuple(np.array(ske.shape) + 2), dtype=np.uint16)
|
| 123 |
+
buf[tuple([slice(1, -1)] * buf.ndim)] = ske
|
| 124 |
+
return buf
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def build_sknw(ske, multi=False):
|
| 128 |
+
buf = buffer(ske)
|
| 129 |
+
nbs = neighbors(buf.shape)
|
| 130 |
+
acc = np.cumprod((1,) + buf.shape[::-1][:-1])[::-1]
|
| 131 |
+
mark(buf, nbs)
|
| 132 |
+
pts = np.array(np.where(buf.ravel() == 2))[0]
|
| 133 |
+
nodes, edges = parse_struc(buf, pts, nbs, acc)
|
| 134 |
+
return build_graph(nodes, edges, multi)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# draw the graph
|
| 138 |
+
def draw_graph(img, graph, cn=255, ce=128):
|
| 139 |
+
acc = np.cumprod((1,) + img.shape[::-1][:-1])[::-1]
|
| 140 |
+
img = img.ravel()
|
| 141 |
+
for idx in graph.nodes():
|
| 142 |
+
pts = graph.nodes[idx]['pts']
|
| 143 |
+
img[np.dot(pts, acc)] = cn
|
| 144 |
+
for (s, e) in graph.edges():
|
| 145 |
+
eds = graph[s][e]
|
| 146 |
+
for i in eds:
|
| 147 |
+
pts = eds[i]['pts']
|
| 148 |
+
img[np.dot(pts, acc)] = ce
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def get_graph_from_skeleton(mask, subsample=False):
|
| 152 |
+
graph = build_sknw(mask, False)
|
| 153 |
+
if len(graph.nodes) > 1 and len(graph.edges) and subsample:
|
| 154 |
+
graph = subsample_graph(graph, 3)
|
| 155 |
+
return graph
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
if __name__ == '__main__':
|
| 159 |
+
g = nx.MultiGraph()
|
| 160 |
+
g.add_nodes_from([1, 2, 3, 4, 5])
|
| 161 |
+
g.add_edges_from([(1, 2), (1, 3), (2, 3), (4, 5), (5, 4)])
|
| 162 |
+
print(g.nodes())
|
| 163 |
+
print(g.edges())
|
| 164 |
+
a = g.subgraph(1)
|
| 165 |
+
print('d')
|
| 166 |
+
print(a)
|
| 167 |
+
print('d')
|
Centerline/skeletonization.py
ADDED
|
@@ -0,0 +1,817 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys, os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import nibabel as nib
|
| 4 |
+
from scipy import ndimage as ndi
|
| 5 |
+
from scipy.signal import convolve
|
| 6 |
+
from numpy.linalg import norm
|
| 7 |
+
import networkx as nx
|
| 8 |
+
import logging
|
| 9 |
+
import traceback
|
| 10 |
+
import timeit
|
| 11 |
+
import time
|
| 12 |
+
import math
|
| 13 |
+
from ast import literal_eval as make_tuple
|
| 14 |
+
from skimage.measure import label
|
| 15 |
+
import subprocess
|
| 16 |
+
import platform
|
| 17 |
+
import glob
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def loadVolume(volumeFolderPath, volumeName):
|
| 21 |
+
"""
|
| 22 |
+
Load nifti files (*.nii or *.nii.gz).
|
| 23 |
+
Parameters
|
| 24 |
+
----------
|
| 25 |
+
volumeFolderPath : str
|
| 26 |
+
Folder of the volume file.
|
| 27 |
+
volumeName : str
|
| 28 |
+
Name of the volume file.
|
| 29 |
+
|
| 30 |
+
Returns
|
| 31 |
+
-------
|
| 32 |
+
volume : ndarray
|
| 33 |
+
Volume data in the form of numpy ndarray.
|
| 34 |
+
affine : ndarray
|
| 35 |
+
Associated affine transformation matrix in the form of numpy ndarray.
|
| 36 |
+
"""
|
| 37 |
+
volumeFilePath = os.path.join(volumeFolderPath, volumeName)
|
| 38 |
+
volumeImg = nib.load(volumeFilePath)
|
| 39 |
+
volume = volumeImg.get_data()
|
| 40 |
+
shape = volume.shape
|
| 41 |
+
affine = volumeImg.affine
|
| 42 |
+
print('Volume loaded from {} with shape = {}.'.format(volumeFilePath, shape))
|
| 43 |
+
|
| 44 |
+
return volume, affine
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def saveVolume(volume, affine, path, astype=None):
|
| 48 |
+
"""
|
| 49 |
+
Save the given volume to the specified location in specified data type.
|
| 50 |
+
Parameters
|
| 51 |
+
----------
|
| 52 |
+
volume : ndarray
|
| 53 |
+
Volume data to be saved.
|
| 54 |
+
affine : ndarray
|
| 55 |
+
The affine transformation matrix associated with the volume.
|
| 56 |
+
path : str
|
| 57 |
+
The absolute path where the volume is going to be saved.
|
| 58 |
+
astype : numpy dtype, optional
|
| 59 |
+
The desired data type of the volume data.
|
| 60 |
+
"""
|
| 61 |
+
if astype is None:
|
| 62 |
+
astype = np.uint8
|
| 63 |
+
|
| 64 |
+
nib.save(nib.Nifti1Image(volume.astype(astype), affine), path)
|
| 65 |
+
print('Volume saved to {} as type {}.'.format(path, astype))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def labelVolume(volume, minSize=1, maxHop=3):
|
| 69 |
+
"""
|
| 70 |
+
Partition the volume into several connected components and attach labels.
|
| 71 |
+
Parameters
|
| 72 |
+
----------
|
| 73 |
+
volume : ndarray
|
| 74 |
+
Volume to be partitioned.
|
| 75 |
+
minSize : int, optional
|
| 76 |
+
The connected component that is less than this size will be disgarded.
|
| 77 |
+
maxHop : int, optional
|
| 78 |
+
Controls how neighboring voxels are defined. See `label` doc for details.
|
| 79 |
+
|
| 80 |
+
Returns
|
| 81 |
+
-------
|
| 82 |
+
labeled : ndarray
|
| 83 |
+
The partitioned and labeled volume. Each connected component has a label (a positive integer) and the background
|
| 84 |
+
is labeled as 0.
|
| 85 |
+
labelResult : list
|
| 86 |
+
In the form of [[label1, size1], [label2, size2], ...]
|
| 87 |
+
"""
|
| 88 |
+
labeled, maxNum = label(volume, return_num=True, connectivity=maxHop)
|
| 89 |
+
counts = np.bincount(labeled.ravel())
|
| 90 |
+
countLoc = np.nonzero(counts)[0]
|
| 91 |
+
sizeList = counts[countLoc]
|
| 92 |
+
labelResult = list(zip(countLoc[sizeList >= minSize], sizeList[sizeList >= minSize]))
|
| 93 |
+
# print(labelResult)
|
| 94 |
+
# print('Total segments: {}'.format(np.count_nonzero(sizeList >= minSize)))
|
| 95 |
+
return labeled, labelResult
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def analyze(vesselVolumeMask, baseFolder):
|
| 99 |
+
"""
|
| 100 |
+
Main function to provoke the skeletonization process. Note that here I am using the docker version of the code. If
|
| 101 |
+
you have already downloaded the original C++ code and successfully compiled it, then you can run that compiled code
|
| 102 |
+
instead of this one.
|
| 103 |
+
"""
|
| 104 |
+
vesselVolumeMask = vesselVolumeMask.astype(np.uint8)
|
| 105 |
+
vesselVolumeMask[vesselVolumeMask != 0] = 1
|
| 106 |
+
vesselVolumeMask = np.swapaxes(vesselVolumeMask, 0, 2)
|
| 107 |
+
shape = vesselVolumeMask.shape
|
| 108 |
+
|
| 109 |
+
vesselVolumeMaskLabeled, vesselVolumeMaskLabelResult = labelVolume(vesselVolumeMask, minSize=1)
|
| 110 |
+
directory = os.path.join(baseFolder, 'skeletonizationResult')
|
| 111 |
+
if not os.path.exists(directory):
|
| 112 |
+
os.makedirs(directory)
|
| 113 |
+
print('Directory {} created.'.format(directory))
|
| 114 |
+
|
| 115 |
+
vesselVolumeMaskLabelInfoFilename = 'vesselVolumeMaskLabelInfo.npz'
|
| 116 |
+
vesselVolumeMaskLabelInfoFilePath = os.path.join(directory, vesselVolumeMaskLabelInfoFilename)
|
| 117 |
+
np.savez_compressed(vesselVolumeMaskLabelInfoFilePath, vesselVolumeMaskLabeled=vesselVolumeMaskLabeled,
|
| 118 |
+
vesselVolumeMaskLabelResult=vesselVolumeMaskLabelResult)
|
| 119 |
+
print('{} saved to {}.'.format(vesselVolumeMaskLabelInfoFilename, vesselVolumeMaskLabelInfoFilePath))
|
| 120 |
+
|
| 121 |
+
# directory2 = directory + 'labelNum=' + str(labelNum) + '/'
|
| 122 |
+
# if not os.path.exists(directory2):
|
| 123 |
+
# os.makedirs(directory2)
|
| 124 |
+
# with open(directory2 + 'BB.txt', 'w') as f1:
|
| 125 |
+
# f1.write('1\n')
|
| 126 |
+
# f1.write('{} {} {}\n'.format(0, 0, 0))
|
| 127 |
+
# f1.write('{} {} {}'.format(*shape))
|
| 128 |
+
# '''
|
| 129 |
+
BBFilePath = os.path.join(directory, 'BB.txt')
|
| 130 |
+
f1 = open(BBFilePath, 'w')
|
| 131 |
+
f1.write('1\n')
|
| 132 |
+
f1.write('{} {} {}\n'.format(0, 0, 0))
|
| 133 |
+
f1.write('{} {} {}'.format(*shape))
|
| 134 |
+
f1.close()
|
| 135 |
+
|
| 136 |
+
vesselCoords = np.array(np.where(vesselVolumeMask)).T
|
| 137 |
+
xyzFilePath = os.path.join(directory, 'xyz.txt')
|
| 138 |
+
np.savetxt(xyzFilePath, vesselCoords, fmt='%1u')
|
| 139 |
+
f2 = open(xyzFilePath, "r")
|
| 140 |
+
contents = f2.readlines()
|
| 141 |
+
f2.close()
|
| 142 |
+
|
| 143 |
+
contents.insert(0, '{}\n'.format(len(vesselCoords)))
|
| 144 |
+
|
| 145 |
+
f2 = open(xyzFilePath, "w")
|
| 146 |
+
contents = "".join(contents)
|
| 147 |
+
f2.write(contents)
|
| 148 |
+
f2.close()
|
| 149 |
+
# '''
|
| 150 |
+
|
| 151 |
+
# '''
|
| 152 |
+
currentPlatform = platform.system()
|
| 153 |
+
print('Current platform is {}.'.format(currentPlatform))
|
| 154 |
+
if currentPlatform == 'Windows':
|
| 155 |
+
cmd = '"C:/Program Files/Docker/Docker/Resources/bin/docker.exe" run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-medeiros-2018-docker'
|
| 156 |
+
elif currentPlatform == 'Darwin':
|
| 157 |
+
cmd = 'docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-nih-aug2018-docker2'
|
| 158 |
+
elif currentPlatform == 'Linux':
|
| 159 |
+
cmd = '/usr/local/bin/docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-medeiros-2018-docker'
|
| 160 |
+
cmd = 'sudo docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-medeiros-2018-docker'
|
| 161 |
+
cmd = 'sudo docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-nih-aug2018-docker2'
|
| 162 |
+
|
| 163 |
+
print('cmd={}'.format(cmd))
|
| 164 |
+
subprocess.call(cmd, shell=True)
|
| 165 |
+
# '''
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def combineSkeletonSegments(skeletonSegmentFolderPath):
|
| 169 |
+
"""
|
| 170 |
+
Collect and combine the results from the skeletonization.
|
| 171 |
+
Parameters
|
| 172 |
+
----------
|
| 173 |
+
skeletonSegmentFolderPath : str
|
| 174 |
+
The folder that contains the segments information (result_segments_xyz*.txt).
|
| 175 |
+
|
| 176 |
+
Returns
|
| 177 |
+
-------
|
| 178 |
+
segmentList : list
|
| 179 |
+
A list containing the segment information. Each sublist represents a segment and each element in the sublist
|
| 180 |
+
represents a centerpoint coordinates.
|
| 181 |
+
"""
|
| 182 |
+
segmentList = []
|
| 183 |
+
files = glob.glob(os.path.join(skeletonSegmentFolderPath, 'result_segments_xyz*.txt'))
|
| 184 |
+
for segmentFile in files:
|
| 185 |
+
result = readSegmentFile(segmentFile)
|
| 186 |
+
segmentList += result
|
| 187 |
+
|
| 188 |
+
return segmentList
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def readSegmentFile(segmentFile):
|
| 192 |
+
"""
|
| 193 |
+
Parse the segment files (result_segments_xyz*.txt) and return segments information in a list.
|
| 194 |
+
Parameters
|
| 195 |
+
----------
|
| 196 |
+
segmentFile : str
|
| 197 |
+
Path to the segment file.
|
| 198 |
+
|
| 199 |
+
Returns
|
| 200 |
+
-------
|
| 201 |
+
segmentList : list
|
| 202 |
+
A list containing the segment information. Each sublist represents a segment and each element in the sublist
|
| 203 |
+
represents a centerpoint coordinates.
|
| 204 |
+
"""
|
| 205 |
+
isFirstLine = True
|
| 206 |
+
isSegmentLength = True
|
| 207 |
+
segmentList = []
|
| 208 |
+
with open(segmentFile) as f:
|
| 209 |
+
for line in f:
|
| 210 |
+
if isFirstLine:
|
| 211 |
+
numOfSegments = int(line)
|
| 212 |
+
isFirstLine = False
|
| 213 |
+
else:
|
| 214 |
+
if isSegmentLength:
|
| 215 |
+
segmentLength = int(line)
|
| 216 |
+
isSegmentLength = False
|
| 217 |
+
segmentCounter = 1
|
| 218 |
+
segment = []
|
| 219 |
+
else:
|
| 220 |
+
if segmentCounter <= segmentLength:
|
| 221 |
+
voxel = tuple([int(x) for x in line.split(' ')])
|
| 222 |
+
segment.append(voxel[::-1])
|
| 223 |
+
segmentCounter += 1
|
| 224 |
+
else:
|
| 225 |
+
segmentCounter += 1
|
| 226 |
+
isSegmentLength = True
|
| 227 |
+
segmentList.append(segment)
|
| 228 |
+
assert (len(segment) == segmentLength)
|
| 229 |
+
|
| 230 |
+
return segmentList
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# def drawSegments(segmentList):
|
| 234 |
+
# pass
|
| 235 |
+
|
| 236 |
+
def processSegments(segmentList, shape):
|
| 237 |
+
"""
|
| 238 |
+
Re-partition the segments so that each segment is a simple branch, i.e., it does not contain bifurcation point
|
| 239 |
+
unless at the two ends.
|
| 240 |
+
Note that this function might be replaced by another more concise function `getSegmentList`.
|
| 241 |
+
Parameters
|
| 242 |
+
----------
|
| 243 |
+
segmentList : list
|
| 244 |
+
A list containing the segment information. Each sublist represents a segment and each element in the sublist
|
| 245 |
+
represents a centerpoint coordinates.
|
| 246 |
+
shape : tuple
|
| 247 |
+
Shape of the vessel volume (used for ploting).
|
| 248 |
+
|
| 249 |
+
Returns
|
| 250 |
+
-------
|
| 251 |
+
G : NetworkX graph
|
| 252 |
+
A graph in which each node represents a centerpoint and each edge represents a portion of a vessel branch.
|
| 253 |
+
segmentList : list
|
| 254 |
+
A list containing the segment information. Each sublist represents a segment and each element in the sublist
|
| 255 |
+
represents a centerpoint coordinates.
|
| 256 |
+
errorSegments : list
|
| 257 |
+
A list that contains segments that cannot be fixed.
|
| 258 |
+
"""
|
| 259 |
+
## Import pyqtgraph ##
|
| 260 |
+
from pyqtgraph.Qt import QtCore, QtGui
|
| 261 |
+
import pyqtgraph as pg
|
| 262 |
+
import pyqtgraph.opengl as gl
|
| 263 |
+
|
| 264 |
+
## Init ##
|
| 265 |
+
app = pg.QtGui.QApplication([])
|
| 266 |
+
w = gl.GLViewWidget()
|
| 267 |
+
w.opts['distance'] = 800
|
| 268 |
+
w.setGeometry(0, 110, 1600, 900)
|
| 269 |
+
offset = np.array(shape) / (-2.0)
|
| 270 |
+
|
| 271 |
+
G = nx.Graph()
|
| 272 |
+
colorList = [pg.glColor('r'), pg.glColor('g'), pg.glColor('b'), pg.glColor('c'), pg.glColor('m'), pg.glColor('y')]
|
| 273 |
+
colorPointer = 0
|
| 274 |
+
skeleton = np.full(shape, 0)
|
| 275 |
+
for segment in segmentList:
|
| 276 |
+
# G.add_path(list(map(tuple, segment)))
|
| 277 |
+
G.add_path(segment)
|
| 278 |
+
segmentCoords = np.array(segment)
|
| 279 |
+
skeleton[tuple(segmentCoords.T)] = 1
|
| 280 |
+
# segmentCoordsView = segmentCoords + offset
|
| 281 |
+
# aa = gl.GLLinePlotItem(pos=segmentCoordsView, color=colorList[colorPointer], width=3)
|
| 282 |
+
# w.addItem(aa)
|
| 283 |
+
# colorPointer = colorPointer + 1 if colorPointer < len(colorList) - 1 else 0
|
| 284 |
+
|
| 285 |
+
# skeletonCoords = np.array(np.where(skeleton)).T
|
| 286 |
+
# skeletonCoordsView = (skeletonCoords + offset) * affineTransform
|
| 287 |
+
# aa = gl.GLScatterPlotItem(pos=skeletonCoordsView, size=5)
|
| 288 |
+
# w.addItem(aa)
|
| 289 |
+
|
| 290 |
+
# w.show()
|
| 291 |
+
|
| 292 |
+
voxelDegrees = np.array([v for _, v in G.degree(G.nodes())])
|
| 293 |
+
maxVoxelDegree = np.amax(voxelDegrees)
|
| 294 |
+
voxelDegreesZippedResult = list(zip(np.arange(maxVoxelDegree + 1), np.bincount(voxelDegrees)))
|
| 295 |
+
print('Voxel degree distribution is \n{}'.format(voxelDegreesZippedResult))
|
| 296 |
+
print('Number of cycles is {}'.format(len(nx.cycle_basis(G))))
|
| 297 |
+
|
| 298 |
+
# Remove duplicate segments
|
| 299 |
+
keepList = np.full((len(segmentList),), True)
|
| 300 |
+
duplicateCounter = 0
|
| 301 |
+
for idx, seg in enumerate(segmentList):
|
| 302 |
+
for idx2, seg2 in enumerate(segmentList[idx + 1:]):
|
| 303 |
+
if seg == seg2 or seg == seg2[::-1]:
|
| 304 |
+
keepList[idx + idx2] = False
|
| 305 |
+
duplicateCounter += 1
|
| 306 |
+
|
| 307 |
+
segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
|
| 308 |
+
print('{} duplicate segments removed!'.format(duplicateCounter))
|
| 309 |
+
|
| 310 |
+
# Cut segments into sub-pieces if there are bifurcation points in the middle
|
| 311 |
+
extraSegments = []
|
| 312 |
+
keepList = np.full((len(segmentList),), True)
|
| 313 |
+
for idx, segment in enumerate(segmentList):
|
| 314 |
+
voxelDegrees = np.array([v for _, v in G.degree(segment)])
|
| 315 |
+
if len(voxelDegrees) >= 3:
|
| 316 |
+
if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or (not np.all(voxelDegrees[1:-1] == 2)):
|
| 317 |
+
keepList[idx] = False
|
| 318 |
+
locs = np.nonzero(voxelDegrees != 2)[0]
|
| 319 |
+
if voxelDegrees[0] == 2:
|
| 320 |
+
locs = np.hstack((0, locs))
|
| 321 |
+
|
| 322 |
+
if voxelDegrees[-1] == 2:
|
| 323 |
+
locs = np.hstack((locs, len(voxelDegrees)))
|
| 324 |
+
|
| 325 |
+
newSegments = []
|
| 326 |
+
for ii in range(len(locs) - 1):
|
| 327 |
+
newSegments.append(segment[locs[ii]:(locs[ii + 1] + 1)])
|
| 328 |
+
|
| 329 |
+
extraSegments += newSegments
|
| 330 |
+
|
| 331 |
+
segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
|
| 332 |
+
segmentList += extraSegments
|
| 333 |
+
|
| 334 |
+
# Remove duplicate segments again
|
| 335 |
+
keepList = np.full((len(segmentList),), True)
|
| 336 |
+
duplicateCounter = 0
|
| 337 |
+
for idx, seg in enumerate(segmentList):
|
| 338 |
+
for idx2, seg2 in enumerate(segmentList[idx + 1:]):
|
| 339 |
+
if seg == seg2 or seg == seg2[::-1]:
|
| 340 |
+
keepList[idx + idx2] = False
|
| 341 |
+
duplicateCounter += 1
|
| 342 |
+
|
| 343 |
+
segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
|
| 344 |
+
print('{} duplicate segments removed in the second stage!'.format(duplicateCounter))
|
| 345 |
+
|
| 346 |
+
# Remove segment if it is completely contained in another segment
|
| 347 |
+
# keepList = np.full((len(segmentList),), True)
|
| 348 |
+
# sublistCounter = 0
|
| 349 |
+
# for idx, seg in enumerate(segmentList):
|
| 350 |
+
# for idx2, seg2 in enumerate(segmentList[idx + 1:]):
|
| 351 |
+
# if contains(seg, seg2):
|
| 352 |
+
# keepList[idx] = False
|
| 353 |
+
# sublistCounter += 1
|
| 354 |
+
# elif contains(seg2, seg):
|
| 355 |
+
# keepList[idx + idx2] = False
|
| 356 |
+
# sublistCounter += 1
|
| 357 |
+
|
| 358 |
+
# segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
|
| 359 |
+
# print('{} sublist segments removed!'.format(sublistCounter))
|
| 360 |
+
|
| 361 |
+
# Treat the segment if either end is not correct
|
| 362 |
+
hasInvalidSegments = False
|
| 363 |
+
for idx, segment in enumerate(segmentList):
|
| 364 |
+
voxelDegrees = np.array([v for _, v in G.degree(segment)])
|
| 365 |
+
if len(voxelDegrees) == 2:
|
| 366 |
+
if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
|
| 367 |
+
# print('Degrees on either end is 2: {}'.format(voxelDegrees))
|
| 368 |
+
hasInvalidSegments = True
|
| 369 |
+
elif len(voxelDegrees) > 2:
|
| 370 |
+
if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or np.any(voxelDegrees[1:-1] != 2):
|
| 371 |
+
# print('Degrees not correct: {}'.format(voxelDegrees))
|
| 372 |
+
hasInvalidSegments = True
|
| 373 |
+
|
| 374 |
+
if not hasInvalidSegments:
|
| 375 |
+
drawSegments(segmentList, shape)
|
| 376 |
+
print('No errors!')
|
| 377 |
+
errorSegments = []
|
| 378 |
+
return G, segmentList, errorSegments
|
| 379 |
+
|
| 380 |
+
iterCounter = 1
|
| 381 |
+
while hasInvalidSegments:
|
| 382 |
+
print('\n\nIter={}'.format(iterCounter))
|
| 383 |
+
keepList = np.full((len(segmentList),), True)
|
| 384 |
+
extraSegments = []
|
| 385 |
+
for idx, segment in enumerate(segmentList):
|
| 386 |
+
if keepList[idx]:
|
| 387 |
+
voxelDegrees = np.array([v for _, v in G.degree(segment)])
|
| 388 |
+
if voxelDegrees[0] == 2 and voxelDegrees[-1] == 2:
|
| 389 |
+
print('Both end have 2 neighbours')
|
| 390 |
+
elif voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
|
| 391 |
+
# print('Degrees on either end is 2: {}'.format(voxelDegrees))
|
| 392 |
+
# pass
|
| 393 |
+
# segmentCoords = np.array(segment)
|
| 394 |
+
if voxelDegrees[0] == 2:
|
| 395 |
+
otherSegmentInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
|
| 396 |
+
(seg[0] == segment[0] or seg[-1] == segment[0]) and keepList[
|
| 397 |
+
idx2] and idx != idx2]
|
| 398 |
+
if len(otherSegmentInfo) != 0:
|
| 399 |
+
if len(otherSegmentInfo) > 1:
|
| 400 |
+
# print(contains(segment, otherSegmentInfo[0][1]), contains(otherSegmentInfo[1][1], segment))
|
| 401 |
+
otherSegmentInfoTemp = []
|
| 402 |
+
for idx2, seg in otherSegmentInfo:
|
| 403 |
+
if contains(segment, seg) or contains(segment[::-1], seg):
|
| 404 |
+
keepList[idx] = False
|
| 405 |
+
continue
|
| 406 |
+
elif contains(seg, segment) or contains(seg[::-1], segment):
|
| 407 |
+
keepList[idx2] = False
|
| 408 |
+
otherSegmentInfoTemp.append((idx2, seg))
|
| 409 |
+
|
| 410 |
+
otherSegmentInfo = otherSegmentInfoTemp
|
| 411 |
+
# otherSegmentInfo = [segInfo for segInfo in otherSegmentInfo if not (contains(segment, segInfo[1]) or contains(segInfo[1], segment))]
|
| 412 |
+
if len(otherSegmentInfo) > 1:
|
| 413 |
+
print('More than one other segments found!')
|
| 414 |
+
print('Current segment ({}) is {} ({})'.format(idx, segment, voxelDegrees))
|
| 415 |
+
for otherSegmentIdx, otherSegment in otherSegmentInfo:
|
| 416 |
+
otherSegmentVoxelDegrees = np.array([v for _, v in G.degree(otherSegment)])
|
| 417 |
+
print('Idx = {}: {} ({})'.format(otherSegmentIdx, otherSegment,
|
| 418 |
+
otherSegmentVoxelDegrees))
|
| 419 |
+
elif len(otherSegmentInfo) == 1:
|
| 420 |
+
otherSegmentIdx, otherSegment = otherSegmentInfo[0]
|
| 421 |
+
else:
|
| 422 |
+
print('No valid other segments found!')
|
| 423 |
+
continue
|
| 424 |
+
else:
|
| 425 |
+
otherSegmentIdx, otherSegment = otherSegmentInfo[0]
|
| 426 |
+
if contains(segment, otherSegment) or contains(segment[::-1], otherSegment):
|
| 427 |
+
keepList[idx] = False
|
| 428 |
+
continue
|
| 429 |
+
elif contains(otherSegment, segment) or contains(otherSegment[::-1], segment):
|
| 430 |
+
keepList[otherSegmentIdx] = False
|
| 431 |
+
continue
|
| 432 |
+
|
| 433 |
+
newSegment = otherSegment + segment[1:] if otherSegment[-1] == segment[0] else otherSegment[
|
| 434 |
+
::-1] + segment[
|
| 435 |
+
1:]
|
| 436 |
+
if not validateSegment(G, newSegment):
|
| 437 |
+
newSegmentVoxelDegrees = np.array([v for _, v in G.degree(newSegment)])
|
| 438 |
+
print('Old degree is {} () and new degree is {} ()'.format(voxelDegrees,
|
| 439 |
+
newSegmentVoxelDegrees))
|
| 440 |
+
else:
|
| 441 |
+
print('Two segments ({} and {}) merged together!'.format(idx, otherSegmentIdx))
|
| 442 |
+
|
| 443 |
+
extraSegments.append(newSegment)
|
| 444 |
+
keepList[idx] = False
|
| 445 |
+
keepList[otherSegmentIdx] = False
|
| 446 |
+
else:
|
| 447 |
+
print(
|
| 448 |
+
'Could not find other segments for segment({}) {} with degrees {}'.format(idx, segment,
|
| 449 |
+
voxelDegrees))
|
| 450 |
+
possibleSegmentsInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
|
| 451 |
+
(seg[0] == segment[0] or seg[-1] == segment[0]) and idx != idx2]
|
| 452 |
+
print('Possible segments: {}'.format(len(possibleSegmentsInfo)))
|
| 453 |
+
|
| 454 |
+
elif voxelDegrees[-1] == 2:
|
| 455 |
+
otherSegmentInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
|
| 456 |
+
(seg[0] == segment[-1] or seg[-1] == segment[-1]) and keepList[
|
| 457 |
+
idx2] and idx != idx2]
|
| 458 |
+
if len(otherSegmentInfo) != 0:
|
| 459 |
+
if len(otherSegmentInfo) > 1:
|
| 460 |
+
# print(contains(segment, otherSegmentInfo[0][1]), contains(otherSegmentInfo[1][1], segment))
|
| 461 |
+
otherSegmentInfoTemp = []
|
| 462 |
+
for idx2, seg in otherSegmentInfo:
|
| 463 |
+
if contains(segment, seg) or contains(segment[::-1], seg):
|
| 464 |
+
keepList[idx] = False
|
| 465 |
+
continue
|
| 466 |
+
elif contains(seg, segment) or contains(seg[::-1], segment):
|
| 467 |
+
keepList[idx2] = False
|
| 468 |
+
otherSegmentInfoTemp.append((idx2, seg))
|
| 469 |
+
|
| 470 |
+
otherSegmentInfo = otherSegmentInfoTemp
|
| 471 |
+
# otherSegmentInfo = [segInfo for segInfo in otherSegmentInfo if not (contains(segment, segInfo[1]) or contains(segInfo[1], segment))]
|
| 472 |
+
if len(otherSegmentInfo) > 1:
|
| 473 |
+
print('More than one other segments found!')
|
| 474 |
+
print('Current segment ({}) is {} ({})'.format(idx, segment, voxelDegrees))
|
| 475 |
+
for otherSegmentIdx, otherSegment in otherSegmentInfo:
|
| 476 |
+
otherSegmentVoxelDegrees = np.array([v for _, v in G.degree(otherSegment)])
|
| 477 |
+
print('Idx = {}: {} ({})'.format(otherSegmentIdx, otherSegment,
|
| 478 |
+
otherSegmentVoxelDegrees))
|
| 479 |
+
elif len(otherSegmentInfo) == 1:
|
| 480 |
+
otherSegmentIdx, otherSegment = otherSegmentInfo[0]
|
| 481 |
+
else:
|
| 482 |
+
print('No valid other segments found!')
|
| 483 |
+
continue
|
| 484 |
+
else:
|
| 485 |
+
otherSegmentIdx, otherSegment = otherSegmentInfo[0]
|
| 486 |
+
if contains(segment, otherSegment) or contains(segment[::-1], otherSegment):
|
| 487 |
+
keepList[idx] = False
|
| 488 |
+
continue
|
| 489 |
+
elif contains(otherSegment, segment) or contains(otherSegment[::-1], segment):
|
| 490 |
+
keepList[otherSegmentIdx] = False
|
| 491 |
+
continue
|
| 492 |
+
|
| 493 |
+
newSegment = segment[:-1] + otherSegment if otherSegment[0] == segment[-1] else segment[
|
| 494 |
+
:-1] + otherSegment[
|
| 495 |
+
::-1]
|
| 496 |
+
if not validateSegment(G, newSegment):
|
| 497 |
+
newSegmentVoxelDegrees = np.array([v for _, v in G.degree(newSegment)])
|
| 498 |
+
print('Old degree is {} () and new degree is {} ()'.format(voxelDegrees,
|
| 499 |
+
newSegmentVoxelDegrees))
|
| 500 |
+
else:
|
| 501 |
+
print('Two segments ({} and {}) merged together!'.format(idx, otherSegmentIdx))
|
| 502 |
+
|
| 503 |
+
extraSegments.append(newSegment)
|
| 504 |
+
keepList[idx] = False
|
| 505 |
+
keepList[otherSegmentIdx] = False
|
| 506 |
+
else:
|
| 507 |
+
print(
|
| 508 |
+
'Could not find other segments for segment({}) {} with degrees {}'.format(idx, segment,
|
| 509 |
+
voxelDegrees))
|
| 510 |
+
possibleSegmentsInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
|
| 511 |
+
(seg[0] == segment[-1] or seg[-1] == segment[-1]) and idx != idx2]
|
| 512 |
+
print('Possible segments: {}'.format(len(possibleSegmentsInfo)))
|
| 513 |
+
|
| 514 |
+
segmentList = [segment for idx, segment in enumerate(segmentList) if keepList[idx]]
|
| 515 |
+
segmentList += extraSegments
|
| 516 |
+
hasInvalidSegments = False
|
| 517 |
+
errorSegments = []
|
| 518 |
+
for idx, segment in enumerate(segmentList):
|
| 519 |
+
voxelDegrees = np.array([v for _, v in G.degree(segment)])
|
| 520 |
+
if len(voxelDegrees) == 2:
|
| 521 |
+
if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
|
| 522 |
+
print('Degrees on either end is 2: {}'.format(voxelDegrees))
|
| 523 |
+
hasInvalidSegments = True
|
| 524 |
+
errorSegments.append(segment)
|
| 525 |
+
elif len(voxelDegrees) > 2:
|
| 526 |
+
if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or np.any(voxelDegrees[1:-1] != 2):
|
| 527 |
+
print('Degrees not correct: {}'.format(voxelDegrees))
|
| 528 |
+
hasInvalidSegments = True
|
| 529 |
+
errorSegments.append(segment)
|
| 530 |
+
|
| 531 |
+
print('hasInvalidSegments = {}'.format(hasInvalidSegments))
|
| 532 |
+
iterCounter += 1
|
| 533 |
+
if len(extraSegments) == 0:
|
| 534 |
+
hasInvalidSegments = False
|
| 535 |
+
print('While loop aborted because there is no change in segments!')
|
| 536 |
+
|
| 537 |
+
for errorSegment in errorSegments:
|
| 538 |
+
segmentList.remove(errorSegment)
|
| 539 |
+
|
| 540 |
+
# np.savez_compressed(directory + 'segmentList.npz', segmentList=segmentList)
|
| 541 |
+
# if partIdx != 10:
|
| 542 |
+
# nib.save(nib.Nifti1Image(skeleton.astype(np.int16), vesselImg.affine), directory + skeletonNamePartial + str(partIdx) + '.nii.gz')
|
| 543 |
+
# else:
|
| 544 |
+
# nib.save(nib.Nifti1Image(skeleton.astype(np.int16), vesselImg.affine), directory + skeletonNameTotal + '.nii.gz')
|
| 545 |
+
|
| 546 |
+
# nx.write_graphml(G, directory + 'graphRepresentation.graphml')
|
| 547 |
+
|
| 548 |
+
# drawAbstractGraph(offset, segmentList)
|
| 549 |
+
# drawAbstractGraph(offset, errorSegments)
|
| 550 |
+
|
| 551 |
+
print(errorSegments)
|
| 552 |
+
|
| 553 |
+
return G, segmentList, errorSegments
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def getSegmentList(G, nodeInfoDict):
|
| 557 |
+
"""
|
| 558 |
+
Generate segmentList from graph and nodeInfoDict.
|
| 559 |
+
Parameters
|
| 560 |
+
----------
|
| 561 |
+
G : NetworkX graph
|
| 562 |
+
The graph representation of the network.
|
| 563 |
+
nodeInfoDict : dict
|
| 564 |
+
A dictionary containing the information about nodes.
|
| 565 |
+
|
| 566 |
+
Returns
|
| 567 |
+
-------
|
| 568 |
+
segmentList : list
|
| 569 |
+
A list of segments in which each segment is a simple branch.
|
| 570 |
+
"""
|
| 571 |
+
startNodeIDList = [nodeID for nodeID in nodeInfoDict.keys() if nodeInfoDict[nodeID]['parentNodeID'] == -1]
|
| 572 |
+
print('startNodeIDList = {}'.format(startNodeIDList))
|
| 573 |
+
segmentList = []
|
| 574 |
+
for startNodeID in startNodeIDList:
|
| 575 |
+
segmentList = getSegmentListDetail(G, nodeInfoDict, segmentList, startNodeID)
|
| 576 |
+
|
| 577 |
+
print('There are {} segments in segmentList'.format(len(segmentList)))
|
| 578 |
+
print(segmentList)
|
| 579 |
+
return segmentList
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def getSegmentListDetail(G, nodeInfoDict, segmentList, startNodeID):
|
| 583 |
+
"""
|
| 584 |
+
Implementation of `getSegmentList`. Use DFS to traverse all the segments.
|
| 585 |
+
Parameters
|
| 586 |
+
----------
|
| 587 |
+
G : NetworkX graph
|
| 588 |
+
The graph representation of the network.
|
| 589 |
+
nodeInfoDict : dict
|
| 590 |
+
A dictionary containing the information about nodes.
|
| 591 |
+
segmentList : list
|
| 592 |
+
A list of segments in which each segment is a simple branch.
|
| 593 |
+
startNodeID : int
|
| 594 |
+
The index of the start point of a segment.
|
| 595 |
+
|
| 596 |
+
Returns
|
| 597 |
+
-------
|
| 598 |
+
segmentList : list
|
| 599 |
+
A list of segments in which each segment is a simple branch.
|
| 600 |
+
"""
|
| 601 |
+
neighborNodeIDList = [nodeID for nodeID in list(G[startNodeID].keys()) if
|
| 602 |
+
'visited' not in G[startNodeID][nodeID]] # use adjacency dict to find neighbors
|
| 603 |
+
newSegmentList = []
|
| 604 |
+
for neighborNodeID in neighborNodeIDList:
|
| 605 |
+
newSegment = [startNodeID, neighborNodeID]
|
| 606 |
+
G[startNodeID][neighborNodeID]['visited'] = True
|
| 607 |
+
currentNodeID = neighborNodeID
|
| 608 |
+
while G.degree(currentNodeID) == 2:
|
| 609 |
+
newNodeID = [nodeID for nodeID in G[currentNodeID].keys() if 'visited' not in G[currentNodeID][nodeID]][0]
|
| 610 |
+
G[currentNodeID][newNodeID]['visited'] = True
|
| 611 |
+
newSegment.append(newNodeID)
|
| 612 |
+
currentNodeID = newNodeID
|
| 613 |
+
|
| 614 |
+
newSegmentList.append(newSegment)
|
| 615 |
+
segmentList.append(newSegment)
|
| 616 |
+
segmentList = getSegmentListDetail(G, nodeInfoDict, segmentList, currentNodeID)
|
| 617 |
+
|
| 618 |
+
return segmentList
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def sublist(ls1, ls2):
|
| 622 |
+
'''
|
| 623 |
+
>>> sublist([], [1,2,3])
|
| 624 |
+
True
|
| 625 |
+
>>> sublist([1,2,3,4], [2,5,3])
|
| 626 |
+
True
|
| 627 |
+
>>> sublist([1,2,3,4], [0,3,2])
|
| 628 |
+
False
|
| 629 |
+
>>> sublist([1,2,3,4], [1,2,5,6,7,8,5,76,4,3])
|
| 630 |
+
False
|
| 631 |
+
'''
|
| 632 |
+
|
| 633 |
+
def get_all_in(one, another):
|
| 634 |
+
for element in one:
|
| 635 |
+
if element in another:
|
| 636 |
+
yield element
|
| 637 |
+
|
| 638 |
+
for x1, x2 in zip(get_all_in(ls1, ls2), get_all_in(ls2, ls1)):
|
| 639 |
+
if x1 != x2:
|
| 640 |
+
return False
|
| 641 |
+
|
| 642 |
+
return True
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def contains(lst1, lst2):
|
| 646 |
+
lst1, lst2 = (lst2, lst1) if len(lst1) > len(lst2) else (lst1, lst2)
|
| 647 |
+
if lst1[0] in lst2:
|
| 648 |
+
startLoc = lst2.index(lst1[0])
|
| 649 |
+
else:
|
| 650 |
+
return False
|
| 651 |
+
|
| 652 |
+
if lst1[-1] in lst2:
|
| 653 |
+
endLoc = lst2.index(lst1[-1])
|
| 654 |
+
else:
|
| 655 |
+
return False
|
| 656 |
+
|
| 657 |
+
if startLoc < endLoc:
|
| 658 |
+
if lst1 == lst2[startLoc:(endLoc + 1)]:
|
| 659 |
+
return True
|
| 660 |
+
else:
|
| 661 |
+
return False
|
| 662 |
+
else:
|
| 663 |
+
if lst1 == lst2[endLoc:(startLoc + 1)][::-1]:
|
| 664 |
+
return True
|
| 665 |
+
else:
|
| 666 |
+
return False
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def validateSegment(G, segment):
|
| 670 |
+
"""
|
| 671 |
+
Check whether a segment is a simple branch.
|
| 672 |
+
Parameters
|
| 673 |
+
----------
|
| 674 |
+
G : NetworkX graph
|
| 675 |
+
A graph in which each node represents a centerpoint and each edge represents a portion of a vessel branch.
|
| 676 |
+
segment : list
|
| 677 |
+
A list containing the coordinates of the centerpoints of a segment.
|
| 678 |
+
|
| 679 |
+
Returns
|
| 680 |
+
-------
|
| 681 |
+
result : bool
|
| 682 |
+
If True, the segment is a simple branch.
|
| 683 |
+
"""
|
| 684 |
+
voxelDegrees = np.array([v for _, v in G.degree(segment)])
|
| 685 |
+
if voxelDegrees[0] != 2 and voxelDegrees[-1] != 2:
|
| 686 |
+
if len(voxelDegrees) == 2:
|
| 687 |
+
result = True
|
| 688 |
+
elif len(voxelDegrees) > 2:
|
| 689 |
+
if np.all(voxelDegrees[1:-1] == 2):
|
| 690 |
+
result = True
|
| 691 |
+
else:
|
| 692 |
+
result = False
|
| 693 |
+
else:
|
| 694 |
+
print('Error! Segment with length 1 found!')
|
| 695 |
+
result = False
|
| 696 |
+
else:
|
| 697 |
+
result = False
|
| 698 |
+
|
| 699 |
+
return result
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def drawSegments(segmentList, shape):
|
| 703 |
+
"""
|
| 704 |
+
Plot all the segments in `segmentList`. Try to assign different colors to the segments connected to the same node.
|
| 705 |
+
Parameters
|
| 706 |
+
----------
|
| 707 |
+
segmentList : list
|
| 708 |
+
A list containing the segment information. Each sublist represents a segment and each element in the sublist
|
| 709 |
+
represents a centerpoint coordinates.
|
| 710 |
+
shape : tuple
|
| 711 |
+
Shape of the vessel volume (used for ploting).
|
| 712 |
+
"""
|
| 713 |
+
## Import pyqtgraph ##
|
| 714 |
+
from pyqtgraph.Qt import QtCore, QtGui
|
| 715 |
+
import pyqtgraph as pg
|
| 716 |
+
import pyqtgraph.opengl as gl
|
| 717 |
+
|
| 718 |
+
## Init ##
|
| 719 |
+
app = pg.QtGui.QApplication([])
|
| 720 |
+
w = gl.GLViewWidget()
|
| 721 |
+
w.opts['distance'] = 800
|
| 722 |
+
w.setGeometry(0, 110, 1600, 900)
|
| 723 |
+
offset = np.array(shape) / (-2.0)
|
| 724 |
+
|
| 725 |
+
colorList = [pg.glColor('r'), pg.glColor('g'), pg.glColor('b'), pg.glColor('c'), pg.glColor('m'), pg.glColor('y')]
|
| 726 |
+
colorNames = ['Red', 'Green', 'Blue', 'Cyan', 'Magneta', 'Yellow']
|
| 727 |
+
numOfColors = len(colorList)
|
| 728 |
+
nodeColorDict = {}
|
| 729 |
+
for segment in segmentList:
|
| 730 |
+
startVoxel = segment[0]
|
| 731 |
+
endVoxel = segment[-1]
|
| 732 |
+
if startVoxel in nodeColorDict and endVoxel in nodeColorDict: # and endVoxel in [voxel for voxel, _ in nodeColorDict[startVoxel]]:
|
| 733 |
+
nodeColorDict[startVoxel].append([endVoxel, -1])
|
| 734 |
+
nodeColorDict[endVoxel].append([startVoxel, -1])
|
| 735 |
+
else:
|
| 736 |
+
if startVoxel not in nodeColorDict:
|
| 737 |
+
nodeColorDict[startVoxel] = [[endVoxel, -1]]
|
| 738 |
+
else:
|
| 739 |
+
nodeColorDict[startVoxel].append([endVoxel, -1])
|
| 740 |
+
|
| 741 |
+
if endVoxel not in nodeColorDict:
|
| 742 |
+
nodeColorDict[endVoxel] = [[startVoxel, -1]]
|
| 743 |
+
else:
|
| 744 |
+
nodeColorDict[endVoxel].append([startVoxel, -1])
|
| 745 |
+
|
| 746 |
+
existingColorsInStart = [colorCode for _, colorCode in nodeColorDict[startVoxel]]
|
| 747 |
+
existingColorsInEnd = [colorCode for _, colorCode in nodeColorDict[endVoxel]]
|
| 748 |
+
availableColors = [colorCode for colorCode in range(numOfColors) if
|
| 749 |
+
colorCode not in existingColorsInStart and colorCode not in existingColorsInEnd]
|
| 750 |
+
# print('color in start: {} and color in end: {}'.format(existingColorsInStart, existingColorsInEnd))
|
| 751 |
+
chosenColor = availableColors[0] if len(availableColors) != 0 else 0
|
| 752 |
+
nodeColorDict[startVoxel][-1][1] = chosenColor
|
| 753 |
+
nodeColorDict[endVoxel][-1][1] = chosenColor
|
| 754 |
+
|
| 755 |
+
segmentCoords = np.array(segment)
|
| 756 |
+
aa = gl.GLLinePlotItem(pos=segmentCoords, color=colorList[chosenColor], width=3)
|
| 757 |
+
aa.translate(*offset)
|
| 758 |
+
w.addItem(aa)
|
| 759 |
+
|
| 760 |
+
w.show()
|
| 761 |
+
pg.QtGui.QApplication.exec_()
|
| 762 |
+
# sys.exit(app.exec_())
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def main():
|
| 766 |
+
start_time = timeit.default_timer()
|
| 767 |
+
baseFolder = os.path.abspath(os.path.dirname(__file__))
|
| 768 |
+
|
| 769 |
+
## Load existing volume ##
|
| 770 |
+
vesselVolumeMaskFolderPath = baseFolder
|
| 771 |
+
vesselVolumeMaskFileName = 'vesselVolumeMask.nii.gz'
|
| 772 |
+
vesselVolumeMask, vesselVolumeMaskAffine = loadVolume(vesselVolumeMaskFolderPath, vesselVolumeMaskFileName)
|
| 773 |
+
|
| 774 |
+
## Skeletonization ##
|
| 775 |
+
# analyze(vesselVolumeMask, baseFolder)
|
| 776 |
+
|
| 777 |
+
skeletonSegmentFolderPath = os.path.join(baseFolder, 'skeletonizationResult/segments_by_cc')
|
| 778 |
+
segmentListRough = combineSkeletonSegments(skeletonSegmentFolderPath)
|
| 779 |
+
|
| 780 |
+
shape = vesselVolumeMask.shape
|
| 781 |
+
# drawSegments(segmentListRough, shape)
|
| 782 |
+
|
| 783 |
+
G, segmentList, errorSegments = processSegments(segmentListRough, shape=shape)
|
| 784 |
+
# drawSegments(segmentList, shape)
|
| 785 |
+
G = nx.Graph()
|
| 786 |
+
segmentIndex = 0
|
| 787 |
+
for segment in segmentList:
|
| 788 |
+
G.add_path(segment, segmentIndex=segmentIndex)
|
| 789 |
+
segmentIndex += 1
|
| 790 |
+
|
| 791 |
+
## Save graph representation ##
|
| 792 |
+
graphFileName = 'graphRepresentation.graphml'
|
| 793 |
+
graphFilePath = os.path.join(baseFolder, graphFileName)
|
| 794 |
+
nx.write_graphml(G, graphFilePath)
|
| 795 |
+
print('{} saved to {}.'.format(graphFileName, graphFilePath))
|
| 796 |
+
|
| 797 |
+
## Save segmentList ##
|
| 798 |
+
segmentListFileName = 'segmentList.npz'
|
| 799 |
+
segmentListFilePath = os.path.join(baseFolder, segmentListFileName)
|
| 800 |
+
np.savez_compressed(segmentListFilePath, segmentList=segmentList)
|
| 801 |
+
print('{} saved to {}.'.format(segmentListFileName, segmentListFilePath))
|
| 802 |
+
|
| 803 |
+
## Save skeleton.nii.gz ##
|
| 804 |
+
skeleton = np.zeros_like(vesselVolumeMask)
|
| 805 |
+
for segment in segmentList:
|
| 806 |
+
skeleton[tuple(np.array(segment).T)] = 1
|
| 807 |
+
|
| 808 |
+
skeletonFileName = 'skeleton.nii.gz'
|
| 809 |
+
skeletonFilePath = os.path.join(baseFolder, skeletonFileName)
|
| 810 |
+
saveVolume(skeleton, vesselVolumeMaskAffine, skeletonFilePath, astype=np.uint8)
|
| 811 |
+
|
| 812 |
+
elapsed = timeit.default_timer() - start_time
|
| 813 |
+
print('Elapsed: {} sec'.format(elapsed))
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
if __name__ == "__main__":
|
| 817 |
+
main()
|
Centerline/thinPlateSplines.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.spatial.distance import pdist, cdist, squareform
|
| 3 |
+
from sklearn.metrics import pairwise_distances
|
| 4 |
+
|
| 5 |
+
class ThinPlateSplines:
|
| 6 |
+
def __init__(self, ctrl_pts: np.ndarray, target_pts: np.ndarray, reg=0.0):
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
:param ctrl_pts: [N, d] tensor of control d-dimensional points
|
| 10 |
+
:param target_pts: [N, d] tensor of target d-dimensional points
|
| 11 |
+
:param reg: regularization coefficient
|
| 12 |
+
"""
|
| 13 |
+
self.__ctrl_pts = ctrl_pts
|
| 14 |
+
self.__target_pts = target_pts
|
| 15 |
+
self.__reg = reg
|
| 16 |
+
self.__num_ctrl_pts = ctrl_pts.shape[0]
|
| 17 |
+
self.__dim = ctrl_pts.shape[1]
|
| 18 |
+
|
| 19 |
+
self.__K = None
|
| 20 |
+
self.__compute_coeffs()
|
| 21 |
+
self.__aff_params = self.__coeffs[self.__num_ctrl_pts:, ...] # Affine parameters of the TPS
|
| 22 |
+
self.__non_aff_paramms = self.__coeffs[:self.__num_ctrl_pts, ...] # Non-affine parameters of he TPS
|
| 23 |
+
|
| 24 |
+
def __compute_coeffs(self):
|
| 25 |
+
target_pts_aug = np.vstack([self.__target_pts,
|
| 26 |
+
np.zeros([self.__dim + 1, self.__dim])]).astype(self.__target_pts.dtype)
|
| 27 |
+
|
| 28 |
+
T_i = np.linalg.inv(self.__make_T()).astype(self.__target_pts.dtype)
|
| 29 |
+
self.__coeffs = np.matmul(T_i, target_pts_aug).astype(self.__target_pts.dtype)
|
| 30 |
+
|
| 31 |
+
def __make_T(self):
|
| 32 |
+
# cp: [K x 2] control points
|
| 33 |
+
# T: [(K+3) x (K+3)]
|
| 34 |
+
P = np.hstack([np.ones([self.__num_ctrl_pts, 1], dtype=np.float), self.__ctrl_pts])
|
| 35 |
+
zeros = np.zeros([self.__dim + 1, self.__dim + 1], dtype=np.float)
|
| 36 |
+
self.__K = self.__U_dist(self.__ctrl_pts)
|
| 37 |
+
alfa = np.mean(self.__K)
|
| 38 |
+
|
| 39 |
+
self.__K = self.__K + np.ones_like(self.__K) * np.power(alfa, 2) * self.__reg
|
| 40 |
+
|
| 41 |
+
top = np.hstack([P, self.__K])
|
| 42 |
+
bottom = np.hstack([P.transpose(), zeros])
|
| 43 |
+
|
| 44 |
+
return np.vstack([top, bottom])
|
| 45 |
+
|
| 46 |
+
def __U_dist(self, ctrl_pts, int_pts=None):
|
| 47 |
+
dist = pairwise_distances(ctrl_pts, int_pts)
|
| 48 |
+
|
| 49 |
+
if ctrl_pts.shape[-1] == 2:
|
| 50 |
+
u_dist = np.square(dist) * np.log(dist + 1e-6)
|
| 51 |
+
else:
|
| 52 |
+
u_dist = np.sqrt(dist)
|
| 53 |
+
|
| 54 |
+
return u_dist
|
| 55 |
+
|
| 56 |
+
def __lift_pts(self, int_pts: np.ndarray, num_pts):
|
| 57 |
+
# int_pts: [N x 2], input points
|
| 58 |
+
# cp: [K x 2], control points
|
| 59 |
+
# pLift: [N x (3+K)], lifted input points
|
| 60 |
+
|
| 61 |
+
int_pts_lift = np.hstack([self.__U_dist(self.__ctrl_pts, int_pts),
|
| 62 |
+
np.ones([num_pts, 1], dtype=np.float),
|
| 63 |
+
int_pts])
|
| 64 |
+
return int_pts_lift
|
| 65 |
+
|
| 66 |
+
def _get_coefficients(self):
|
| 67 |
+
return self.__coeffs
|
| 68 |
+
|
| 69 |
+
def interpolate(self, int_points):
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
:param int_points: [K, d] flattened d-points of a mesh
|
| 73 |
+
:return:
|
| 74 |
+
"""
|
| 75 |
+
num_pts = int_points.shape[0]
|
| 76 |
+
int_points_lift = self.__lift_pts(int_points, num_pts)
|
| 77 |
+
return np.dot(int_points_lift, self.__coeffs)
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def bending_energy(self):
|
| 81 |
+
aux = tf.matmul(self.__non_aff_paramms, self.__K, transpose_a=True)
|
| 82 |
+
return tf.matmul(aux, self.__non_aff_paramms)
|
Centerline/visualization_utils.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 3 |
+
from matplotlib.lines import Line2D
|
| 4 |
+
import numpy as np
|
| 5 |
+
from DeepDeformationMapRegistration.utils.visualization import add_axes_arrows_3d, remove_tick_labels, set_axes_size
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _plot_graph(graph, ax, nodes_colour='C3', edges_colour='C1', plot_nodes=True, plot_edges=True, add_axes=True):
|
| 10 |
+
if plot_edges:
|
| 11 |
+
for (start_node, end_node) in graph.edges():
|
| 12 |
+
edge_pts = graph[start_node][end_node]['pts']
|
| 13 |
+
edge_pts = np.vstack([graph.nodes[start_node]['o'], edge_pts])
|
| 14 |
+
edge_pts = np.vstack([edge_pts, graph.nodes[end_node]['o']])
|
| 15 |
+
ax.plot(edge_pts[:, 0], edge_pts[:, 1], edge_pts[:, 2], edges_colour)
|
| 16 |
+
|
| 17 |
+
if plot_nodes:
|
| 18 |
+
nodes = graph.nodes()
|
| 19 |
+
ps = np.array([nodes[i]['o'] for i in nodes])
|
| 20 |
+
if len(ps.shape) > 1:
|
| 21 |
+
ax.scatter(ps[:, 0], ps[:, 1], ps[:, 2], nodes_colour)
|
| 22 |
+
else:
|
| 23 |
+
ax.scatter(ps[0], ps[1], ps[2], nodes_colour)
|
| 24 |
+
ax.set_xlim(0, 63)
|
| 25 |
+
ax.set_ylim(0, 63)
|
| 26 |
+
ax.set_zlim(0, 63)
|
| 27 |
+
remove_tick_labels(ax, True)
|
| 28 |
+
if add_axes:
|
| 29 |
+
add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b')
|
| 30 |
+
ax.view_init(None, 45)
|
| 31 |
+
|
| 32 |
+
return ax
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def plot_skeleton(img, skeleton, graph, filename='skeleton', extension=['.png']):
|
| 36 |
+
if not isinstance(extension, list):
|
| 37 |
+
extension = [extension]
|
| 38 |
+
# Skeleton
|
| 39 |
+
f = plt.figure(figsize=(5, 5))
|
| 40 |
+
ax = f.add_subplot(111, projection='3d')
|
| 41 |
+
|
| 42 |
+
coords = np.argwhere(skeleton)
|
| 43 |
+
i = coords[:, 0]
|
| 44 |
+
j = coords[:, 1]
|
| 45 |
+
k = coords[:, 2]
|
| 46 |
+
|
| 47 |
+
seg = ax.voxels(img, facecolors=(0., 0., 1., 0.3), label='image')
|
| 48 |
+
ske = ax.scatter(i, j, k, color='C1', label='skeleton', s=1)
|
| 49 |
+
ax.set_xlim(0, 63)
|
| 50 |
+
ax.set_ylim(0, 63)
|
| 51 |
+
ax.set_zlim(0, 63)
|
| 52 |
+
remove_tick_labels(ax, True)
|
| 53 |
+
add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b')
|
| 54 |
+
ax.view_init(None, 45)
|
| 55 |
+
for ex in extension:
|
| 56 |
+
f.savefig(filename + '_segmentation_skeleton' + ex)
|
| 57 |
+
|
| 58 |
+
# Combined
|
| 59 |
+
ax = _plot_graph(graph, ax, 'r', 'r')
|
| 60 |
+
|
| 61 |
+
for ex in extension:
|
| 62 |
+
f.savefig(filename + '_combined' + ex)
|
| 63 |
+
plt.close()
|
| 64 |
+
|
| 65 |
+
# Graph
|
| 66 |
+
f = plt.figure(figsize=(5, 5))
|
| 67 |
+
ax = f.add_subplot(111, projection='3d')
|
| 68 |
+
|
| 69 |
+
ax = _plot_graph(graph, ax)
|
| 70 |
+
|
| 71 |
+
for ex in extension:
|
| 72 |
+
f.savefig(filename + '_graph' + ex)
|
| 73 |
+
plt.close()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def compare_graphs(graph_0, graph_1, graph_names=None, filename='compare_graphs'):
|
| 79 |
+
f = plt.figure(figsize=(12, 5))
|
| 80 |
+
if graph_names is None:
|
| 81 |
+
graph_names =['graph_0', 'graph_1']
|
| 82 |
+
else:
|
| 83 |
+
assert len(graph_names) == 2, 'A different name is expected for each graph'
|
| 84 |
+
ax = f.add_subplot(131, projection='3d')
|
| 85 |
+
ax = _plot_graph(graph_0, ax)
|
| 86 |
+
ax.set_title(graph_names[0], y=-0.01)
|
| 87 |
+
|
| 88 |
+
ax = f.add_subplot(132, projection='3d')
|
| 89 |
+
ax = _plot_graph(graph_1, ax)
|
| 90 |
+
ax.set_title(graph_names[1])
|
| 91 |
+
|
| 92 |
+
ax = f.add_subplot(133, projection='3d')
|
| 93 |
+
ax = _plot_graph(graph_0, ax, 'C2', 'C2', plot_nodes=False)
|
| 94 |
+
ax = _plot_graph(graph_1, ax, 'C4', 'C4', plot_nodes=False)
|
| 95 |
+
legend_elements = [Line2D([0], [0], color='C2', lw=2, label=graph_names[0]),
|
| 96 |
+
Line2D([0], [0], color='C4', lw=2, label=graph_names[1])]
|
| 97 |
+
ax.legend(handles=legend_elements)
|
| 98 |
+
|
| 99 |
+
f.savefig(filename + '_compare_graphs.png')
|
| 100 |
+
plt.close()
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def plot_cpd_registration_step(iteration, error, X, Y, out_folder, add_axes=True, pdf=True):
|
| 104 |
+
fig = plt.figure(figsize=(8, 8))
|
| 105 |
+
ax = fig.add_axes([0, 0, .9, .9], projection='3d')
|
| 106 |
+
ax.scatter(X[:, 0], X[:, 1], X[:, 2], color='C1', label='Fixed')
|
| 107 |
+
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], color='C2', label='Moving')
|
| 108 |
+
|
| 109 |
+
ax.text2D(0.95, 0.98, 'Iteration: {:d}'.format(
|
| 110 |
+
iteration), horizontalalignment='right', verticalalignment='center', transform=ax.transAxes, fontsize='x-large')
|
| 111 |
+
#ax.text2D(0.95, 0.90, 'Error: {:10.4f}'.format(
|
| 112 |
+
# error), horizontalalignment='right', verticalalignment='center', transform=ax.transAxes, fontsize='x-large')
|
| 113 |
+
ax.legend(loc='upper left', fontsize='x-large')
|
| 114 |
+
|
| 115 |
+
if add_axes:
|
| 116 |
+
x_range = [np.min(np.hstack([X[:, 0], Y[:, 0]])), np.max(np.hstack([X[:, 0], Y[:, 0]]))]
|
| 117 |
+
y_range = [np.min(np.hstack([X[:, 1], Y[:, 1]])), np.max(np.hstack([X[:, 1], Y[:, 1]]))]
|
| 118 |
+
z_range = [np.min(np.hstack([X[:, 2], Y[:, 2]])), np.max(np.hstack([X[:, 2], Y[:, 2]]))]
|
| 119 |
+
ax.set_xlim(x_range[0], x_range[1])
|
| 120 |
+
ax.set_ylim(y_range[0], y_range[1])
|
| 121 |
+
ax.set_zlim(z_range[0], z_range[1])
|
| 122 |
+
|
| 123 |
+
remove_tick_labels(ax, True)
|
| 124 |
+
add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b', arrow_length=25, dist_arrow_text=3)
|
| 125 |
+
ax.view_init(None, 45)
|
| 126 |
+
|
| 127 |
+
os.makedirs(out_folder, exist_ok=True)
|
| 128 |
+
fig.savefig(os.path.join(out_folder, '{:04d}.png'.format(iteration)))
|
| 129 |
+
if pdf:
|
| 130 |
+
fig.savefig(os.path.join(out_folder, '{:04d}.pdf'.format(iteration)))
|
| 131 |
+
plt.close()
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, file_name):
|
| 135 |
+
fig = plt.figure(figsize=(8, 8))
|
| 136 |
+
ax = fig.add_axes([0, 0, .9, .9], projection='3d')
|
| 137 |
+
ax.scatter(fix_pts[:, 0], fix_pts[:, 1], fix_pts[:, 2], color='C1', label='Fixed')
|
| 138 |
+
ax.scatter(mov_pts[:, 0], mov_pts[:, 1], mov_pts[:, 2], color='C2', label='Moving')
|
| 139 |
+
ax.scatter(fix_centroid[0], fix_centroid[1], fix_centroid[2], color='none', s=100, edgecolor='b', label='Centroid')
|
| 140 |
+
ax.scatter(mov_centroid[0], mov_centroid[1], mov_centroid[2], color='none', s=100, edgecolor='b')
|
| 141 |
+
ax.scatter(fix_centroid[0], fix_centroid[1], fix_centroid[2], color='C1')
|
| 142 |
+
ax.scatter(mov_centroid[0], mov_centroid[1], mov_centroid[2], color='C2')
|
| 143 |
+
|
| 144 |
+
x_range = [np.min(np.hstack([fix_pts[:, 0], mov_pts[:, 0], fix_centroid[0], mov_centroid[0]])),
|
| 145 |
+
np.max(np.hstack([fix_pts[:, 0], mov_pts[:, 0], fix_centroid[0], mov_centroid[0]]))]
|
| 146 |
+
y_range = [np.min(np.hstack([fix_pts[:, 1], mov_pts[:, 1], fix_centroid[1], mov_centroid[1]])),
|
| 147 |
+
np.max(np.hstack([fix_pts[:, 1], mov_pts[:, 1], fix_centroid[1], mov_centroid[1]]))]
|
| 148 |
+
z_range = [np.min(np.hstack([fix_pts[:, 2], mov_pts[:, 2], fix_centroid[2], mov_centroid[2]])),
|
| 149 |
+
np.max(np.hstack([fix_pts[:, 2], mov_pts[:, 2], fix_centroid[2], mov_centroid[2]]))]
|
| 150 |
+
ax.set_xlim(x_range[0], x_range[1])
|
| 151 |
+
ax.set_ylim(y_range[0], y_range[1])
|
| 152 |
+
ax.set_zlim(z_range[0], z_range[1])
|
| 153 |
+
|
| 154 |
+
remove_tick_labels(ax, True)
|
| 155 |
+
add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b', arrow_length=25, dist_arrow_text=3)
|
| 156 |
+
ax.view_init(None, 45)
|
| 157 |
+
ax.legend(fontsize='xx-large')
|
| 158 |
+
fig.savefig(file_name + '.png')
|
| 159 |
+
fig.savefig(file_name + '.pdf')
|
| 160 |
+
plt.close()
|
| 161 |
+
|