import os
import re
import sys
from functools import partial

import h5py
import numpy as np
import pandas as pd
from probreg import bcpd
from scipy.spatial.distance import cdist
from scipy.spatial.distance import euclidean
from skimage.morphology import skeletonize_3d
from tqdm import tqdm

# Put the parent of this script's folder on the import path before the project
# imports below (presumably where EvaluationScripts/ and Centerline/ live).
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)

from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration


# Root folder of the evaluation data; each entry of DATASET_NAMES is a sub-folder of it.
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
DATASET_NAMES = ['Affine', 'None', 'Translation']
# Only files whose name contains this substring are processed.
DATASET_FILENAME = 'points'

# Destination of the registration plots and of the per-dataset CSV summaries.
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/skeleton'

# Point coordinates are multiplied by SCALE before registration and the
# registration outputs are divided by it afterwards.
SCALE = 1e-2

# Keyword arguments forwarded to deform_registration
# (max_iterations, alpha, beta and tolerance respectively).
MAX_ITER = 200
ALPHA = 0.1
BETA = 1.0
TOLERANCE = 1e-8


# Columns of the per-dataset results table. The two ILL_COND_* columns come last:
# the original table declared only the first sixteen and gained the other two
# when the first row was appended.
_RESULT_COLUMNS = ['DATASET',
                   'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
                   'TIME_DEF', 'TIME_R_DEF',
                   'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
                   'TRE_DEF', 'TRE_R_DEF',
                   'DS_DISP',
                   'DATA_PATH',
                   'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM',
                   'ILL_COND_DEF', 'ILL_COND_R_DEF']


def _skeleton_points(skel_volume, bbox, first_reshape, isotropic_shape):
    """Return the coordinates of the skeleton voxels of ``skel_volume``.

    The volume is passed through ``resize_img_to_original_space``, thinned with
    ``skeletonize_3d`` and the indices of its non-zero voxels are returned
    (``np.argwhere``, one row per voxel).
    """
    resized = resize_img_to_original_space(skel_volume, bbox, first_reshape, isotropic_shape)
    return np.argwhere(skeletonize_3d(resized))


def _save_plots(out_dir, fix_pts, mov_pts, registered_pts, fix_centroid, mov_centroid, pred_centroid):
    """Create ``out_dir`` and write the before/after registration plots into it."""
    os.makedirs(out_dir, exist_ok=True)
    plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, out_dir + '/before_registration')
    plot_cpd(fix_pts, registered_pts, fix_centroid, pred_centroid, out_dir + '/after_registration')


def _evaluate_sample(dataset_name, file_path, iterator):
    """Register one sample and return its row of the results table.

    Two pipelines are run on the skeleton points of the moving and fixed
    volumes: deformable registration alone ("DEF") and rigid registration
    followed by deformable registration ("R_DEF"). For each, the moving
    centroid is mapped with the estimated transform and its Euclidean distance
    to the fixed centroid is stored as the TRE.

    :param dataset_name: name of the dataset sub-folder, stored in the row and
        used in the plot paths.
    :param file_path: path of the HDF5 file holding the sample.
    :param iterator: tqdm iterator whose description is updated with the progress.
    :return: dict with one value per results column (keys match ``_RESULT_COLUMNS``).
    :raises IndexError: if the file name contains no digits (sample number).
    """
    fn = os.path.split(file_path)[-1].split('.hd5')[0]
    fnum = int(re.findall(r'(\d+)', fn)[0])
    iterator.set_description('{}: start'.format(fn))

    # Everything is copied out with [:], so the file can be closed right away
    # (and is closed even if reading fails).
    with h5py.File(file_path, 'r') as pts_file:
        fix_skel = pts_file['fix/skeleton'][:]
        fix_centroid = pts_file['fix/centroid'][:]
        mov_skel = pts_file['mov/skeleton'][:]
        mov_centroid = pts_file['mov/centroid'][:]
        bbox = pts_file['parameters/bbox'][:]
        first_reshape = pts_file['parameters/first_reshape'][:]
        isotropic_shape = pts_file['parameters/isotropic_shape'][:]
    iterator.set_description('{}: Loaded data'.format(fn))

    # [64] * 3 is presumably the grid the stored centroids refer to - TODO confirm.
    fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
    fix_skel_pts = _skeleton_points(fix_skel, bbox, first_reshape, isotropic_shape)
    mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
    mov_skel_pts = _skeleton_points(mov_skel, bbox, first_reshape, isotropic_shape)
    iterator.set_description('{}: reshaped data'.format(fn))

    ill_cond_def = False
    ill_cond_r_def = False

    # ---- Deformable registration only -------------------------------------
    iterator.set_description('{}: Computing only deformable reg.'.format(fn))
    # Same call as the deformable step of the rigid+deformable pipeline below;
    # the rest of this function needs both the timing and the registration
    # object (diff, iteration, TY, get_registration_parameters).
    time_def, deform_reg_def = deform_registration(fix_skel_pts * SCALE, mov_skel_pts * SCALE, time_it=True,
                                                   tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                                   alpha=ALPHA, beta=BETA)

    if np.isnan(deform_reg_def.diff):
        # Registration produced NaN: no valid TRE, leave the centroid untouched.
        tre_def = np.nan
        pred_mov_centroid = mov_centroid
    else:
        tps, ill_cond_def = radial_basis_function(mov_skel_pts,
                                                  np.dot(*deform_reg_def.get_registration_parameters()) / SCALE)
        pred_mov_centroid = mov_centroid + tps(mov_centroid)
        tre_def = euclidean(pred_mov_centroid, fix_centroid)

    _save_plots(os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum)),
                fix_skel_pts, mov_skel_pts, deform_reg_def.TY / SCALE,
                fix_centroid, mov_centroid, pred_mov_centroid)

    # ---- Rigid followed by deformable registration ------------------------
    iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
    time_r_def__r, rigid_reg_r_def = rigid_registration(fix_skel_pts * SCALE, mov_skel_pts * SCALE, time_it=True)
    rigid_yt = rigid_reg_r_def.TY
    time_r_def__def, deform_reg_r_def = deform_registration(fix_skel_pts * SCALE, rigid_yt, time_it=True,
                                                            tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                                            alpha=ALPHA, beta=BETA)

    # The rigid transform is always applied; the deformable displacement is
    # added on top only when the deformable step did not produce NaN.
    mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid * SCALE) / SCALE
    if np.isnan(deform_reg_r_def.diff):
        pred_mov_centroid = mov_centroid_t
    else:
        tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE,
                                                    np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE)
        pred_mov_centroid = mov_centroid_t + tps(mov_centroid_t)

    tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
    dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)

    _save_plots(os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum)),
                fix_skel_pts, mov_skel_pts, deform_reg_r_def.TY / SCALE,
                fix_centroid, mov_centroid, pred_mov_centroid)

    iterator.set_description('{}: Saving data'.format(fn))
    return {'DATASET': dataset_name,
            'ITERATIONS_DEF': deform_reg_def.iteration,
            'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
            'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
            'TIME_DEF': time_def,
            'TIME_R_DEF': time_r_def__r + time_r_def__def,
            'Q_DEF': deform_reg_def.diff,
            'Q_R_DEF__R': rigid_reg_r_def.q,
            'Q_R_DEF__DEF': deform_reg_r_def.diff,
            'ILL_COND_DEF': ill_cond_def,
            'ILL_COND_R_DEF': ill_cond_r_def,
            'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
            'DS_DISP': euclidean(mov_centroid, fix_centroid),
            'DATA_PATH': file_path,
            'DIST_CENTR': np.min(dist_centroid_to_pts),
            'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
            'SAMPLE_NUM': fnum}


def main():
    """Evaluate every dataset in DATASET_NAMES and write one CSV summary per dataset."""
    for dataset_name in DATASET_NAMES:
        dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
        dataset_files = [os.path.join(dataset_loc, f)
                         for f in sorted(os.listdir(dataset_loc)) if DATASET_FILENAME in f]

        iterator = tqdm(dataset_files)
        # Rows are gathered in a list and the frame is built once:
        # DataFrame.append no longer exists in pandas >= 2.0.
        rows = []
        for file_path in iterator:
            rows.append(_evaluate_sample(dataset_name, file_path, iterator))

        df = pd.DataFrame(rows, columns=_RESULT_COLUMNS)
        os.makedirs(OUT_IMG_FOLDER, exist_ok=True)
        df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))


if __name__ == '__main__':
    main()