| import os |
| import time |
|
|
| import torch |
| from torch.cuda.amp import autocast as autocast |
| from tqdm import tqdm |
| from einops import rearrange, repeat, reduce |
| import numpy as np |
| import pandas as pd |
| from pathlib import Path |
| import nibabel as nib |
| import shutil |
| import pickle |
| from scipy.ndimage import gaussian_filter |
| import torch.distributed as dist |
|
|
| from evaluate.metric import calculate_metric_percase |
| from evaluate.merge_after_evaluate import merge |
| from train.dist import is_master |
|
|
def compute_gaussian(tile_size, sigma_scale: float = 1. / 8, value_scaling_factor: float = 10, dtype=np.float16):
    """Build a Gaussian importance map for sliding-window patch blending.

    A unit impulse placed at the tile centre is blurred with a per-axis
    sigma of ``axis_len * sigma_scale``, rescaled so the peak equals
    ``value_scaling_factor``, cast to ``dtype``, and any zeros that appear
    after the low-precision cast are lifted to the smallest non-zero
    weight, so dividing by an accumulated map stays finite everywhere.

    Args:
        tile_size: spatial extent of the patch, e.g. ``(H, W, D)``.
        sigma_scale: fraction of each axis length used as that axis' sigma.
        value_scaling_factor: value of the map at the tile centre.
        dtype: numpy dtype of the returned array.

    Returns:
        ``np.ndarray`` of shape ``tile_size`` with strictly positive weights.
    """
    impulse = np.zeros(tile_size)
    impulse[tuple(extent // 2 for extent in tile_size)] = 1

    weight_map = gaussian_filter(
        impulse,
        [extent * sigma_scale for extent in tile_size],
        0,
        mode='constant',
        cval=0,
    )

    # Normalize the peak to value_scaling_factor in float64, then cast.
    weight_map = (weight_map / np.max(weight_map) * value_scaling_factor).astype(dtype)

    # The cast (e.g. to float16) can round far-from-centre weights down to
    # exactly 0; replace them with the smallest positive weight so no voxel
    # ever gets a zero denominator during blending.
    weight_map[weight_map == 0] = np.min(weight_map[weight_map != 0])

    return weight_map
|
|
def evaluate(model,
             text_encoder,
             device,
             testset,
             testloader,
             dice_score,
             nsd_score,
             csv_path,
             resume,
             save_interval,
             visualization):
    """Distributed sliding-window evaluation with master-side aggregation.

    Every rank runs inference over its shard of ``testloader``, scores each
    sample with ``calculate_metric_percase``, and checkpoints its results to
    per-rank pickle files next to ``csv_path``.  The master rank then waits
    for all ranks, merges and deduplicates the results, and writes a summary
    CSV plus a per-sample XLSX workbook.

    Args:
        model: segmentation network, called as
            ``model(queries=..., image_input=..., train_mode=False)``.
        text_encoder: encodes ``(labels, modality)`` into query embeddings.
        device: torch device patches and the Gaussian map are moved to.
        testset: only ``testset.load_image(image_path)`` is used here
            (for visualization dumps).
        testloader: iterable of pre-patched samples; each sample dict must
            provide the keys unpacked at the top of the loop below.
        dice_score: flag forwarded to ``calculate_metric_percase``; also
            selects whether a 'dice' sheet is written to the XLSX.
        nsd_score: same as ``dice_score`` but for the 'nsd' metric.
        csv_path: output CSV path; tmp/fnl pickles, the XLSX and the
            optional NIfTI directory are all derived from it.
        resume: if True, the master merges leftover ``*_tmp_rank*.pkl``
            files from an interrupted run into rank0's tmp file.
        save_interval: dump the per-rank result list every this many
            samples.  NOTE(review): ``% save_interval`` raises
            ZeroDivisionError when 0 — presumably always > 0; confirm.
        visualization: if True, save prediction / image / GT NIfTI volumes
            under ``csv_path`` minus its '.csv' suffix.

    Returns:
        None.  All outputs are written to disk.
    """

    if visualization:
        # NIfTI dumps go to a sibling directory named after the csv.
        nib_dir = csv_path.replace('.csv', '')

    if is_master():
        # {dataset: {label: {metric: [values across samples]}}}
        datasets_labels_metrics = {}
        # {dataset: {sample_id: {label: {metric: value}}}}
        samples_labels_metrics = {}
        # {dataset: set of labels seen} — fixes the XLSX column order.
        datasets_labels_sets = {}

    # Accumulated on EVERY rank: [dataset_name, modality, sample_id, scores, labels]
    results_of_samples = []

    if resume and is_master():
        # Collect partial per-rank results left over from a previous run,
        # fold them into this (master) rank's list, and re-save them as a
        # single rank0 tmp file so nothing is recomputed or lost.
        root_dir = os.path.dirname(csv_path)
        prefix = os.path.basename(csv_path).replace('.csv', '_tmp_rank')
        pkl_to_del = []
        for f in os.listdir(root_dir):
            if prefix in f:
                pkl_path = f'{root_dir}/{f}'
                # NOTE(review): the context manager rebinds the loop
                # variable ``f``; harmless here but worth renaming.
                with open(pkl_path, 'rb') as f:
                    results_of_samples += pickle.load(f)
                print(f'Load results from {pkl_path}')
                pkl_to_del.append(pkl_path)
        # Remove the originals only after everything loaded successfully.
        for pkl_path in pkl_to_del:
            os.remove(pkl_path)
            print(f'Del {pkl_path}')
        merge_pkl = csv_path.replace('.csv', f'_tmp_rank0.pkl')
        with open(merge_pkl, 'wb') as f:
            pickle.dump(results_of_samples, f)
        print(f'Load results of {len(results_of_samples)} samples, Merge into {merge_pkl}')

    model.eval()
    text_encoder.eval()

    with torch.no_grad():
        # Per-sample timing accumulators (averaged after the loop).
        data_time = 0
        pred_time = 0
        metric_time = 0

        # NOTE(review): these two are never incremented anywhere in the
        # loop, so the final print always reports 0 — dead counters.
        avg_patch_batch_num = 0
        avg_query_batch_num = 0

        # Progress bar only on the master rank.
        if is_master():
            testloader = tqdm(testloader, disable=False)
        else:
            testloader = tqdm(testloader, disable=True)

        # Blending weights for overlapping patches; (288, 288, 96) is
        # assumed to be the (maximum) patch size — TODO confirm it matches
        # the loader's patching.
        gaussian = torch.tensor(compute_gaussian((288, 288, 96))).to(device)

        end_time = time.time()
        for sample in testloader:
            # A sample is one whole volume, pre-cut into batched patches.
            dataset_name = sample['dataset_name']
            sample_id = sample['sample_id']
            batched_patches = sample['batched_patches']
            batched_y1y2_x1x2_z1z2 = sample['batched_y1y2_x1x2_z1z2']
            labels = sample['labels']
            gt_segmentation = sample['gt_segmentation'].numpy()
            modality = sample['modality']
            image_path = sample['image_path']

            # n = number of label channels, (h, w, d) = volume shape.
            n, h, w, d = gt_segmentation.shape
            # Weighted prediction sum and weight sum, kept on CPU to bound
            # GPU memory for large volumes.
            prediction = torch.zeros((n, h, w, d))
            accumulation = torch.zeros((n, h, w, d))

            data_time += (time.time() - end_time)
            end_time = time.time()

            with autocast():
                # One set of text queries per volume, reused for all patches.
                queries = text_encoder(labels, modality)

                for patches, y1y2_x1x2_z1z2_ls in zip(batched_patches, batched_y1y2_x1x2_z1z2):
                    patches = patches.to(device=device)
                    prediction_patch = model(queries=queries, image_input=patches, train_mode=False)
                    prediction_patch = torch.sigmoid(prediction_patch)
                    prediction_patch = prediction_patch.detach()

                    # Scatter each patch back into the volume, Gaussian-weighted.
                    for b in range(len(y1y2_x1x2_z1z2_ls)):
                        y1, y2, x1, x2, z1, z2 = y1y2_x1x2_z1z2_ls[b]

                        # Border patches can be smaller than the full tile;
                        # crop both the prediction and the weight map.
                        tmp = prediction_patch[b, :, :y2-y1, :x2-x1, :z2-z1] * gaussian[:y2-y1, :x2-x1, :z2-z1]
                        prediction[:, y1:y2, x1:x2, z1:z2] += tmp.cpu()
                        accumulation[:, y1:y2, x1:x2, z1:z2] += gaussian[:y2-y1, :x2-x1, :z2-z1].cpu()

            pred_time += (time.time() - end_time)
            end_time = time.time()

            # Normalize by the summed weights, then binarize at 0.5.
            prediction = prediction / accumulation
            prediction = torch.where(prediction > 0.5, 1.0, 0.0)
            prediction = prediction.numpy()

            # One metric dict per label channel.
            scores = []
            for j in range(len(labels)):
                scores.append(calculate_metric_percase(prediction[j, :, :, :], gt_segmentation[j, :, :, :], dice_score, nsd_score))

            if visualization:
                Path(f'{nib_dir}/{dataset_name}').mkdir(exist_ok=True, parents=True)

                # Merged label map: channel j gets integer value j+1
                # (background stays 0).  NOTE(review): overlapping labels
                # would sum their ids in ``results``.
                results = np.zeros((h, w, d))
                for j, label in enumerate(labels):
                    results += prediction[j, :, :, :] * (j + 1)
                    Path(f'{nib_dir}/{dataset_name}/seg_{sample_id}').mkdir(exist_ok=True, parents=True)
                    # Per-label binary mask (identity affine).
                    segobj = nib.nifti2.Nifti1Image(prediction[j, :, :, :], np.eye(4))
                    nib.save(segobj, f'{nib_dir}/{dataset_name}/seg_{sample_id}/{label}.nii.gz')
                segobj = nib.nifti2.Nifti1Image(results, np.eye(4))
                nib.save(segobj, f'{nib_dir}/{dataset_name}/seg_{sample_id}.nii.gz')

                # Raw input image alongside the prediction.
                image = testset.load_image(image_path)
                image = np.squeeze(image)
                imgobj = nib.nifti2.Nifti1Image(image, np.eye(4))
                nib.save(imgobj, f'{nib_dir}/{dataset_name}/img_{sample_id}.nii.gz')

                # Ground truth, saved with the same merged + per-label layout.
                gt = np.zeros((h, w, d))
                for j, label in enumerate(labels):
                    gt += gt_segmentation[j, :, :, :] * (j + 1)
                    Path(f'{nib_dir}/{dataset_name}/gt_{sample_id}').mkdir(exist_ok=True, parents=True)
                    segobj = nib.nifti2.Nifti1Image(gt_segmentation[j, :, :, :], np.eye(4))
                    nib.save(segobj, f'{nib_dir}/{dataset_name}/gt_{sample_id}/{label}.nii.gz')
                gtobj = nib.nifti2.Nifti1Image(gt, np.eye(4))
                nib.save(gtobj, f'{nib_dir}/{dataset_name}/gt_{sample_id}.nii.gz')

            metric_time += (time.time() - end_time)
            end_time = time.time()

            results_of_samples.append([dataset_name, modality, sample_id, scores, labels])

            # Periodic per-rank checkpoint so a crash loses at most
            # ``save_interval`` samples.
            if len(results_of_samples) % save_interval == 0:
                with open(csv_path.replace('.csv', f'_tmp_rank{dist.get_rank()}.pkl'), 'wb') as f:
                    pickle.dump(results_of_samples, f)

        """
        # gather results from all device to rank-0 (solution 1)
        gather_results = [None for i in range(dist.get_world_size())]
        dist.gather_object(
            results_of_samples,
            gather_results if dist.get_rank() == 0 else None,
            dst = 0
        )

        if int(dist.get_rank()) == 0:
            results_of_samples = [tmp for ls in results_of_samples for tmp in ls]
        """

        # Average the accumulators over the number of samples on this rank.
        avg_patch_batch_num /= len(testloader)
        avg_query_batch_num /= len(testloader)
        data_time /= len(testloader)
        pred_time /= len(testloader)
        metric_time /= len(testloader)
        print(f'On Rank {dist.get_rank()}, each sample has {avg_patch_batch_num} batch of patches and {avg_query_batch_num} batch of queries, Data Time: {data_time}, Pred Time: {pred_time}, Dice Time: {metric_time}')

    torch.cuda.empty_cache()

    # Final per-rank dump; its existence doubles as this rank's "done" flag.
    with open(csv_path.replace('.csv', f'_fnl_rank{dist.get_rank()}.pkl'), 'wb') as f:
        pickle.dump(results_of_samples, f)

    if is_master():

        # File-existence barrier: poll until every rank has written its
        # final pickle (avoids dist collective ops here; solution 2).
        while True:
            all_process_finished = True
            for rank_id in range(torch.distributed.get_world_size()):
                if not os.path.exists(csv_path.replace('.csv', f'_fnl_rank{rank_id}.pkl')):
                    all_process_finished = False
                    break
            if all_process_finished:
                break
            else:
                time.sleep(10)

        # Gather every rank's results, then clean up both fnl and tmp files.
        results_of_samples = []
        for rank_id in range(torch.distributed.get_world_size()):
            fnl_results_file = csv_path.replace('.csv', f'_fnl_rank{rank_id}.pkl')
            tmp_results_file = csv_path.replace('.csv', f'_tmp_rank{rank_id}.pkl')
            with open(fnl_results_file, 'rb') as f:
                results_of_samples += pickle.load(f)
            os.remove(fnl_results_file)
            if os.path.exists(tmp_results_file):
                os.remove(tmp_results_file)

        # A sample may appear twice (resume + re-evaluation, or a sampler
        # padding the last batch); keep the first occurrence only.
        unique_set = set()
        deduplicated_results_of_samples = []
        for dataset_name, modality, sample_id, scores, labels in results_of_samples:
            if f'{dataset_name}/{sample_id}' not in unique_set:
                unique_set.add(f'{dataset_name}/{sample_id}')
                deduplicated_results_of_samples.append([dataset_name, modality, sample_id, scores, labels])
        results_of_samples = deduplicated_results_of_samples

        # Persist the merged raw results (removed again after reporting).
        with open(csv_path.replace('.csv', '.pkl'), 'wb') as f:
            pickle.dump(results_of_samples, f)

        # Re-index the flat result list into the three nested dicts.
        for dataset_name, modality, sample_id, scores, labels in results_of_samples:
            # Datasets are reported per modality, e.g. 'AbdomenCT(ct)'.
            dataset_name = f'{dataset_name}({modality})'

            if dataset_name not in datasets_labels_metrics:
                datasets_labels_metrics[dataset_name] = {}
            if dataset_name not in datasets_labels_sets:
                datasets_labels_sets[dataset_name] = set()
            if dataset_name not in samples_labels_metrics:
                samples_labels_metrics[dataset_name] = {}
            samples_labels_metrics[dataset_name][sample_id] = {}

            for metric_dict, label in zip(scores, labels):

                # Accumulate each metric value per (dataset, label).
                if label not in datasets_labels_metrics[dataset_name]:
                    datasets_labels_metrics[dataset_name][label] = {k: [v] for k, v in metric_dict.items()}
                else:
                    for k, v in metric_dict.items():
                        datasets_labels_metrics[dataset_name][label][k].append(v)

                if label not in datasets_labels_sets[dataset_name]:
                    datasets_labels_sets[dataset_name].add(label)

                # Per-sample copy (used for the per-sample XLSX sheets).
                samples_labels_metrics[dataset_name][sample_id][label] = {k: v for k, v in metric_dict.items()}

        # Summary table: one row per dataset (average over its labels) and
        # one row per (dataset, label).
        info = 'Metrics of Each Dataset:\n'
        avg_df = {}
        for dataset in datasets_labels_metrics.keys():
            # NOTE(review): ``metric_dict`` here is the leftover loop
            # variable from the aggregation loop above — it only supplies
            # the metric KEY NAMES, and raises NameError if
            # results_of_samples is empty.  Works, but fragile.
            avg_df[dataset] = {k: [] for k in metric_dict.keys()}
            for label in datasets_labels_metrics[dataset].keys():
                avg_df[f'{dataset}, {label}'] = []
                for metric in datasets_labels_metrics[dataset][label].keys():
                    # Mean over all samples containing this label.
                    label_metric = np.average(datasets_labels_metrics[dataset][label][metric])
                    avg_df[f'{dataset}, {label}'].append(label_metric)
                    avg_df[dataset][metric].append(label_metric)
            # Dataset row = mean over its labels (macro average).
            avg_df[dataset] = {k: np.average(v) for k, v in avg_df[dataset].items()}
            info += f'{dataset} | '
            for k, v in avg_df[dataset].items():
                info += f'{v}({k}) | '
            info += '\n'
            avg_df[dataset] = list(avg_df[dataset].values())
        avg_df = pd.DataFrame(avg_df).T
        avg_df.columns = list(metric_dict.keys())
        avg_df.to_csv(csv_path)
        print(info)

        # XLSX: summary sheet + one sheet per (dataset, metric) with one
        # row per sample and one column per label.
        df_list = [['summary', avg_df]]
        for dataset, label_set in datasets_labels_sets.items():
            metric_df = {}
            if dice_score:
                metric_df['dice'] = {}
            if nsd_score:
                metric_df['nsd'] = {}

            for image_id, label_dict in samples_labels_metrics[dataset].items():
                for metric in metric_df:
                    tmp = []
                    for label in label_set:
                        # -1 marks "label not annotated for this sample".
                        score = label_dict[label][metric] if label in label_dict else -1
                        tmp.append(score)
                    metric_df[metric][image_id] = tmp

            # NOTE(review): the loop target rebinds ``metric_df`` while
            # iterating its own .items(); safe because the items iterator
            # is created first, but a distinct name would be clearer.
            for metric, metric_df in metric_df.items():
                metric_df = pd.DataFrame(metric_df).T
                metric_df.columns = list(label_set)
                df_list.append([dataset + f'({metric})', metric_df])

        xlsx_path = csv_path.replace('.csv', '.xlsx')
        with pd.ExcelWriter(xlsx_path) as writer:
            for name, df in df_list:
                # Excel caps sheet names at 31 chars; keep the tail (the
                # more discriminative part of the dataset name).
                if len(name) > 31:
                    name = name[len(name) - 31:]
                df.to_excel(writer, sheet_name=name, index=True)

        # Reporting done; drop the merged raw-results pickle.
        os.remove(csv_path.replace('.csv', '.pkl'))

    else:
        # Non-master ranks are done after writing their fnl pickle.
        pass

    return
| |
| |