File size: 6,860 Bytes
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from dataprocesser.dataset_anish import list_img_ad_from_anish_csv
from dataprocesser.dataset_synthrad import list_img_pID_from_synthrad_folder  

from dataprocesser.dataset_anika import pair_list_from_anika_dataset
from dataprocesser.dataset_anika import all_list_from_anika_dataset
from dataprocesser.dataset_combined_csv import list_img_seg_ad_pIDs_from_new_simplified_csv
from dataprocesser.dataset_anika import all_list_from_anika_dataset_include_duplicate
from dataprocesser.dataset_dominik import all_list_from_dominik_dataset
from dataprocesser.dataset_xcat import list_img_pID_from_XCAT_folder
    
import os
# NOTE(review): presumably works around the "OMP: Error #15" abort when two
# OpenMP runtimes (e.g. from MKL and PyTorch) are loaded in one process —
# confirm this is still needed on the deployment machine.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

def run(target_file_list=None, task='total', dataset='synthrad', device="gpu"):
    """Create segmentation masks for a list of image files.

    Args:
        target_file_list: explicit list of image paths to segment; when None,
            the list is resolved from `dataset` via create_dataset_list().
        task: TotalSegmentator task name ('total', 'total_mr', 'tissue_types_mr', ...).
        dataset: dataset keyword understood by create_dataset_list().
        device: device string passed through to TotalSegmentator ("gpu"/"cpu").
    """
    if target_file_list is None:
        target_file_list = create_dataset_list(dataset)
    # BUG FIX: the original assigned multi_label_image twice in a row, so the
    # first condition (dataset == 'anika_newsynthetic') was unconditionally
    # overwritten by the second and that dataset wrongly got multi-label output.
    # Both synthetic datasets are meant to use single-label (ml=False) output.
    multi_label_image = dataset not in ('anika_newsynthetic', 'synthesized')
    create_segmentation(target_file_list, task, device, multi_label_image)

def create_segmentation(dataset_list, task='total', device="gpu", multi_label_image=True):
    """Run TotalSegmentator on every image in `dataset_list`.

    Supported task values (see TotalSegmentator docs):
      - total
      - total_mr
      - tissue_types_mr

    Args:
        dataset_list: iterable of image file paths (.nii / .nii.gz / .nrrd).
        task: TotalSegmentator task name; tissue tasks get a '_seg_tissue'
            output suffix, everything else gets '_seg'.
        device: device string forwarded to totalsegmentator().
        multi_label_image: forwarded as `ml=` — write one multi-label volume
            instead of one file per structure.
    """
    import nibabel as nib
    import nrrd
    import numpy as np
    from totalsegmentator.python_api import totalsegmentator

    for sample in dataset_list:
        input_path = sample
        print(f'create segmentation mask for {input_path}')
        if input_path.endswith('.nii') or input_path.endswith('.nii.gz'):
            # For '.nii.gz' the replace still works: 'x.nii.gz' -> 'x_seg.nii.gz'.
            if task == 'tissue_types_mr' or task == 'tissue_types':
                output_path = input_path.replace('.nii', '_seg_tissue.nii')
            else:
                output_path = input_path.replace('.nii', '_seg.nii')
            input_img = nib.load(input_path)

        elif input_path.endswith('.nrrd'):
            if task == 'tissue_types_mr' or task == 'tissue_types':
                output_path = input_path.replace('.nrrd', '_seg_tissue.nii.gz')
            else:
                output_path = input_path.replace('.nrrd', '_seg.nii.gz')
            np_img, header = nrrd.read(input_path)
            # Build a NIfTI affine from the NRRD metadata so orientation is kept.
            spacing = header.get('space directions', None)
            if spacing is None:
                spacing = np.eye(3)  # Default to identity matrix if not available
            else:
                spacing = np.array(spacing)
            origin = header.get('space origin', [0, 0, 0])
            origin = np.array(origin)
            affine = np.zeros([4, 4])
            affine[:3, :3] = spacing  # Set voxel dimensions
            affine[:3, 3] = origin  # Set the origin
            print('space directions', spacing)
            print('space origin', origin)
            print('affine', affine)
            input_img = nib.Nifti1Image(np_img, affine)

        else:
            # BUG FIX: the original fell through here with input_img/output_path
            # unbound, raising NameError on the totalsegmentator() call below.
            # Skip unsupported file types explicitly instead.
            print(f'skipping unsupported file type: {input_path}')
            continue

        totalsegmentator(input=input_img, output=output_path, task=task, fast=False, ml=multi_label_image, device=device)
        print(f'segmentation mask is saved as {output_path}')

from dataprocesser.step1_init_data_list import appart_img_and_seg
from dataprocesser.dataset_anika import (
    all_list_from_anika_dataset, 
    extract_patientID_from_Anika_dataset, 
    all_list_from_anika_dataset_include_duplicate)
from dataprocesser.dataset_synthrad import list_img_pID_from_synthrad_folder
from dataprocesser.dataset_anish import list_img_seg_ad_pIDs_from_anish_csv
from dataprocesser.dataset_dominik import all_list_from_dominik_dataset
from dataprocesser.dataset_combined_csv import list_img_seg_ad_pIDs_from_new_simplified_csv
from dataprocesser.dataset_xcat import list_img_pID_from_XCAT_folder
def create_dataset_list(dataset='anika_all'):
    """Resolve a dataset keyword to a list of image file paths.

    Args:
        dataset: one of 'anish', 'synthrad_ct', 'synthrad_mr', 'anika',
            'anika_all_ct', 'anika_all_mr', 'dominik', 'xcat_ct',
            'anika_newsynthetic', 'synthesized'.

    Returns:
        List of image file paths for the requested dataset.

    Raises:
        ValueError: if `dataset` is not a recognized keyword.
    """

    def _synthrad_images(folder, modality, txt_name):
        # Only the image-path list (index 0) of the returned tuple is needed.
        return list_img_pID_from_synthrad_folder(folder, accepted_modalities=modality, saved_name=txt_name)[0]

    def _anika_ct_mr_lists(ct_root, mr_root, mri_mode='t1_vibe_in'):
        # Split the matched CT/MRI pair mapping into two parallel path lists.
        pairs = all_list_from_anika_dataset(ct_root, mr_root, mri_mode)
        ct_paths = [entry['CT'] for entry in pairs.values()]
        mr_paths = [entry['MRI'] for entry in pairs.values()]
        return ct_paths, mr_paths

    def _collect_synthetic(folder, id_from_path):
        # Recursively gather every file except "seg_volume" masks; patient IDs
        # are extracted alongside but only the image paths are returned.
        assert os.path.isdir(folder), f'{folder} is not a valid directory'
        images, patient_IDs = [], []
        for root_dir, _, filenames in sorted(os.walk(folder)):
            for filename in filenames:
                if "seg_volume" not in filename:
                    full_path = os.path.join(root_dir, filename)
                    patient_IDs.append(id_from_path(full_path))
                    images.append(full_path)
        return images

    if dataset == 'anish':
        return list_img_seg_ad_pIDs_from_new_simplified_csv(
            r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\datacsv\healthy_dissec_newserver_new.csv'
        )[3]

    if dataset == 'synthrad_ct':
        return _synthrad_images(r'E:\Projects\yang_proj\data\synthrad\Task1\pelvis', 'ct', 'target_filenames.txt')

    if dataset == 'synthrad_mr':
        return _synthrad_images(r'E:\Projects\yang_proj\data\synthrad\Task1_val\pelvis', 'mr', 'target_filenames.txt')

    if dataset == 'anika':
        return _anika_ct_mr_lists(
            r'E:\Projects\yang_proj\data\anika\CT',
            r'E:\Projects\yang_proj\data\anika\MR_registrated'
        )[0]

    if dataset == 'anika_all_ct':
        return appart_img_and_seg(
            all_list_from_anika_dataset_include_duplicate(
                r'E:\Projects\yang_proj\data\anika\CT',
                r'E:\Projects\yang_proj\data\anika\MR_registrated'
            )[0]
        )[0]

    if dataset == 'anika_all_mr':
        return appart_img_and_seg(
            all_list_from_anika_dataset_include_duplicate(
                r'E:\Projects\yang_proj\data\anika\CT',
                r'E:\Projects\yang_proj\data\anika\MR_registrated'
            )[1]
        )[0]

    if dataset == 'dominik':
        return all_list_from_dominik_dataset(r'E:\Projects\yang_proj\data\Dominik_MR_VIBE')

    if dataset == 'xcat_ct':
        return list_img_pID_from_XCAT_folder(
            r'E:\Projects\chen_proj\aorta_XCAT2CT\CT_XCAT_aorta', saved_name=None
        )[0]

    if dataset == 'anika_newsynthetic':
        return appart_img_and_seg(_collect_synthetic(
            r'E:\Projects\yang_proj\data\anika\new_synthetic',
            lambda path: os.path.basename(os.path.dirname(path))
        ))[0]

    if dataset == 'synthesized':
        return _collect_synthetic(
            r'E:\Projects\yang_proj\data\ddpm_anika_more_512_for_2_seg\volume_output',
            lambda path: '_'.join(os.path.basename(path).split('_')[:2])
        )

    raise ValueError(f"Unsupported dataset '{dataset}', please choose an available one!")

# Script entry point: segment the default dataset with default settings.
if __name__ == '__main__':
    run()