ObjectRelator-Original/psalm/model/mask_decoder/Mask2Former_Simplify/utils/DatasetAnalyzer.py
| #!/usr/bin/env python | |
| # -*- encoding: utf-8 -*- | |
| ''' | |
| @File : DatasetAnalyzer.py | |
| @Time : 2022/04/08 10:10:12 | |
| @Author : zzubqh | |
| @Version : 1.0 | |
| @Contact : baiqh@microport.com | |
| @License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA | |
| @Desc : None | |
| ''' | |
| # here put the import lib | |
| import numpy as np | |
| import os | |
| import SimpleITK as sitk | |
| from multiprocessing import Pool | |
class DatasetAnalyzer(object):
    """Compute intensity statistics over the masked voxels of a CT dataset.

    Consumes an annotation file (e.g. ``train.md``) in which each line has
    the format ``<path/ct_file.nii.gz>,<path/seg_file.nii.gz>``.
    """

    def __init__(self, annotation_file, num_processes=4):
        """Load (ct, mask) file-path pairs from *annotation_file*.

        Args:
            annotation_file: text file with one "ct_path,mask_path" pair per line.
            num_processes: worker count for the multiprocessing pool.
        """
        self.dataset = []
        self.num_processes = num_processes
        with open(annotation_file, 'r', encoding='utf-8') as rf:
            for line_item in rf:
                line = line_item.strip()
                # Skip blank lines — the original raised IndexError on a
                # trailing newline ("".split(',') yields a single element).
                if not line:
                    continue
                items = line.split(',')
                if len(items) < 2:
                    continue  # malformed line: no comma-separated mask path
                self.dataset.append({'ct': items[0], 'mask': items[1]})
        print('total load {0} ct files'.format(len(self.dataset)))

    def _get_effective_data(self, dataset_item: dict):
        """Return a subsample of CT voxel values lying inside the mask.

        Only voxels where the segmentation is > 0 are considered, and every
        10th such voxel is kept to bound memory usage.

        Args:
            dataset_item: dict with 'ct' and 'mask' image file paths.

        Returns:
            list of scalar voxel intensities.
        """
        itk_img = sitk.ReadImage(dataset_item['ct'])
        itk_mask = sitk.ReadImage(dataset_item['mask'])
        img_np = sitk.GetArrayFromImage(itk_img)
        mask_np = sitk.GetArrayFromImage(itk_mask)
        effective_data = img_np[mask_np > 0][::10]
        return list(effective_data)

    def compute_stats(self):
        """Compute intensity statistics over all masked voxels of the dataset.

        Returns:
            (median, mean, sd, min, max, percentile_99.5, percentile_0.5);
            a 7-tuple of NaN when the dataset (or the masked region) is empty.
        """
        if len(self.dataset) == 0:
            return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
        # Context manager guarantees the pool is terminated even if a worker
        # raises — the original leaked worker processes on any read error.
        with Pool(self.num_processes) as process_pool:
            data_value = process_pool.map(self._get_effective_data, self.dataset)
        print('sub process end, get {0} case data'.format(len(data_value)))
        # Single np.concatenate instead of quadratic list concatenation.
        voxels = np.concatenate([np.asarray(v, dtype=np.float64) for v in data_value]) \
            if data_value else np.empty(0)
        if voxels.size == 0:
            # All masks were empty: avoid numpy warnings on empty reductions.
            return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
        median = np.median(voxels)
        mean = np.mean(voxels)
        sd = np.std(voxels)
        mn = np.min(voxels)
        mx = np.max(voxels)
        percentile_99_5 = np.percentile(voxels, 99.5)
        percentile_00_5 = np.percentile(voxels, 0.5)
        return median, mean, sd, mn, mx, percentile_99_5, percentile_00_5
if __name__ == '__main__':
    import tqdm

    annotation = r'/home/code/Dental/Segmentation/dataset/tooth_label.md'
    out_dir = r'/data/Dental/SegTrainingClipdata'
    analyzer = DatasetAnalyzer(annotation, num_processes=8)

    # Intensity clipping window applied to every CT.
    # NOTE(review): 181.0/7578.0 look dataset-specific (dental CT) — confirm
    # before reusing on other data.
    CLIP_MIN = 181.0
    CLIP_MAX = 7578.0

    # Original crashed in sitk.WriteImage when the output dir was missing.
    os.makedirs(out_dir, exist_ok=True)

    # Regenerate each CT with intensities clipped to [CLIP_MIN, CLIP_MAX],
    # preserving origin/spacing/direction via CopyInformation.
    for item in tqdm.tqdm(analyzer.dataset):
        ct_file = item['ct']
        out_name = os.path.basename(ct_file)
        out_path = os.path.join(out_dir, out_name)
        itk_img = sitk.ReadImage(item['ct'])
        img_np = sitk.GetArrayFromImage(itk_img)
        data = np.clip(img_np, CLIP_MIN, CLIP_MAX)
        clip_img = sitk.GetImageFromArray(data)
        clip_img.CopyInformation(itk_img)
        sitk.WriteImage(clip_img, out_path)