File size: 5,719 Bytes
2af0e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os, sys
# Absolute directory containing this script; the relative mapping paths in
# `mapping_files` are joined onto it.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# NOTE(review): presumably added so sibling modules in this directory are
# importable — nothing visible here imports from it; confirm before removing.
sys.path.append(ROOT_DIR)
import json

# New data root that `traverse_and_revise` substitutes for the legacy
# 'data_gen_def' prefixes. This is a machine-specific absolute path: edit it
# before running the script on another host. Earlier alternatives are kept
# below for reference.
# CORRECT_DATA_PATH = os.path.join(ROOT_DIR, '../..')
# CORRECT_DATA_PATH = os.path.join('/hy-tmp')
CORRECT_DATA_PATH = '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Data/Omini3D'


def traverse_and_print(data, path=()):
    """Print every key and string value in *data* that mentions 'DATASETS'.

    Nested dicts are walked depth-first, in insertion order. Matching keys are
    printed as ``KEY (str): ...`` and matching string values as
    ``  VALUE (str): ...``. Values of other types (lists, numbers) are skipped.

    Args:
        data: The (possibly nested) dict to scan.
        path: Tuple of keys leading to *data*; it is only extended and passed
            down during recursion and never appears in the output.
    """
    marker = 'DATASETS'
    for name, item in data.items():
        here = path + (name,)

        if isinstance(name, str) and marker in name:
            print(f"KEY (str): {name}")

        # A value is either a nested mapping to descend into, or a string
        # that may itself reference a dataset path.
        if isinstance(item, dict):
            traverse_and_print(item, here)
        elif isinstance(item, str) and marker in item:
            print(f"  VALUE (str): {item}")

def _check_one_file(candidate, failed_files):
    """Print a pass/fail line for *candidate*; append it to *failed_files* if it is not a file."""
    if os.path.isfile(candidate):
        # '\r' with end='' keeps successive passes on one console line.
        print(f'\rCheck pass: {candidate}', end='')
    else:
        print(f'\rCheck fail ! : {candidate}')
        failed_files.append(candidate)


def _collect_missing_files(data, failed_files):
    """Recursively check every 'DATASETS_processed' key/string value in *data*.

    Missing paths are appended to *failed_files*, which is shared across all
    recursion levels and returned for convenience.
    """
    for key, value in data.items():
        if isinstance(key, str) and 'DATASETS_processed' in key:
            _check_one_file(key, failed_files)

        if isinstance(value, str) and 'DATASETS_processed' in value:
            _check_one_file(value, failed_files)
        elif isinstance(value, dict):
            _collect_missing_files(value, failed_files)
    return failed_files


def traverse_and_check(data, path=()):
    """Verify that every 'DATASETS_processed' path referenced in *data* exists.

    Keys and string values containing 'DATASETS_processed' are checked with
    ``os.path.isfile``. Nested dicts are searched at every depth, and a missing
    file anywhere in the tree makes the whole check fail. A single summary
    line is printed once the walk completes.

    Args:
        data: The (possibly nested) mapping dict to check.
        path: Accepted for backward compatibility; not used.

    Returns:
        True if every referenced file exists, False otherwise.
    """
    # One shared accumulator, so failures in nested dicts reach this level.
    # Previously each recursion level kept its own list and the result of
    # the recursive call was discarded.
    failed_files = _collect_missing_files(data, [])

    if failed_files:
        print(f'\nCheck finished. Failed files: {failed_files}')
        return False
    print('\nAll files check passed!')
    return True

def traverse_and_revise(data, path=(), old_prefixes=None, new_root=None):
    """Rewrite legacy 'data_gen_def' path prefixes in *data*, in place.

    Every string key or string value that contains 'data_gen_def' has each
    prefix in *old_prefixes* replaced by *new_root*. Nested dicts are revised
    recursively. Key order is preserved, so the saved JSON keeps its
    original layout. Non-string keys and values that are neither strings nor
    dicts (lists, numbers) are left untouched.

    Args:
        data: The (possibly nested) mapping dict to revise; mutated in place.
        path: Tuple of keys leading to *data*; only extended during recursion.
        old_prefixes: Iterable of legacy prefixes to replace. Defaults to the
            two historical 'data_gen_def' locations.
        new_root: Replacement prefix. Defaults to the module-level
            ``CORRECT_DATA_PATH``.

    Returns:
        The same *data* object, for convenience.
    """
    if old_prefixes is None:
        old_prefixes = (
            '/home/jachin/data/Github/data/data_gen_def',
            '/home/data/Github/data/data_gen_def',
        )
    if new_root is None:
        new_root = CORRECT_DATA_PATH

    def _rewrite(text):
        # Only strings that mention the legacy folder are touched. Replacements
        # are chained so a string containing several prefixes gets all of them.
        if 'data_gen_def' not in text:
            return text
        for prefix in old_prefixes:
            text = text.replace(prefix, new_root)
        return text

    revised_items = []
    for key, value in data.items():
        new_key = _rewrite(key) if isinstance(key, str) else key

        if isinstance(value, str):
            value = _rewrite(value)
        elif isinstance(value, dict):
            traverse_and_revise(value, path + (new_key,), old_prefixes, new_root)

        revised_items.append((new_key, value))

    # Rebuild in place rather than pop/re-insert per key, so keys keep their
    # original positions and callers holding a reference to *data* see the update.
    data.clear()
    data.update(revised_items)
    return data

def traverse_and_rename_label(data, old_label, new_label, task_keys=("segmentation", "registration")):
    """Rename label *old_label* to *new_label* under each entry's Label_path.

    An entry is any dict value of *data*. If the entry has a dict under
    ``"Label_path"``, then for each task name in *task_keys* whose value is a
    dict containing *old_label*, that label is moved to *new_label*. If the
    entry has no dict ``"Label_path"``, the entry itself is searched
    recursively for deeper entries.

    Example: rename "brain" -> "brain_tumour" to fix the BraTS mislabel.

    Returns:
        The number of task dicts in which a label was renamed.
    """
    renamed = 0
    for entry in data.values():
        if not isinstance(entry, dict):
            continue

        labels = entry.get("Label_path")
        if not isinstance(labels, dict):
            # No label block at this level: look for entries nested deeper.
            renamed += traverse_and_rename_label(entry, old_label, new_label, task_keys)
            continue

        for task in task_keys:
            per_task = labels.get(task)
            if not isinstance(per_task, dict) or old_label not in per_task:
                continue
            per_task[new_label] = per_task.pop(old_label)
            renamed += 1
    return renamed


# Dataset name -> mapping JSON file, given relative to this script's directory.
mapping_files = {
    'MSD': 'nifty_mappings/MSD_mappings.json',
    'TotalSegmentor': 'nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
    'Kaggle_osic': 'nifty_mappings/Kaggle_osic_mappings.json',
    'CancerImageArchive': 'nifty_mappings/CIA_mappings.json',
    'MnMs': 'nifty_mappings/MnMs_mappings.json',
    'Brats2019': 'nifty_mappings/Brats2019_mappings.json',
    'Brats2020': 'nifty_mappings/Brats2020_mappings.json',
    'Brats2021': 'nifty_mappings/Brats2021_mappings.json',
    'OASIS_1': 'nifty_mappings/OASIS_1_mappings.json',
    'OASIS_2': 'nifty_mappings/OASIS_2_mappings.json',
    'PSMA-FDG-PET-CT-LESION':'nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
    'PSMA-CT':'nifty_mappings/PSMA-CT-Longitud_mappings.json',
    'AbdomenAtlas':'nifty_mappings/AbdomenAtlas_mappings.json',
    'AbdomenCT1k':'nifty_mappings/AbdomenCT1k_mappings.json',
}
# Resolve each relative path against ROOT_DIR (an absolute path), so the
# script finds the mapping files regardless of the current working directory.
# Values are reassigned for existing keys only, so mutating while iterating is safe.
# NOTE(review): this loop leaves `k` and `v` bound as module-level names.
for k,v in mapping_files.items():
    mapping_files[k] = os.path.join(ROOT_DIR, v)


if __name__ == "__main__":
    # --- Fix BraTS / MSD mislabel: "brain" -> "brain_tumour" ---
    # Each listed mapping JSON is loaded, relabelled in memory, and written
    # back only if at least one label was actually renamed.
    rename_datasets = ['Brats2019', 'Brats2020', 'Brats2021', 'MSD']
    for ds_name in rename_datasets:
        if ds_name not in mapping_files:
            continue
        v = mapping_files[ds_name]
        with open(v, 'r', encoding='utf-8') as f:
            mappings_tmp = json.load(f)
        n = traverse_and_rename_label(mappings_tmp, 'brain', 'brain_tumour')
        if n > 0:
            # The mapping file is overwritten in place. Dump to a sibling temp
            # file first and atomically swap it in, so an interrupted write
            # cannot leave a truncated JSON behind.
            tmp_path = v + '.tmp'
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(mappings_tmp, f, indent=4)
            os.replace(tmp_path, v)
            print(f'[{ds_name}] Renamed "brain" -> "brain_tumour" in {n} entries, saved to {v}')
        else:
            print(f'[{ds_name}] No "brain" labels found (already renamed?)')

    # --- Path revision (uncomment to run) ---
    # Rewrites legacy 'data_gen_def' prefixes in every mapping file and saves
    # each file in place.
    # for k,v in mapping_files.items():
    #     with open(v, 'r') as f:
    #         mappings_tmp = json.load(f)
    #         new_mappings_tmp = traverse_and_revise(mappings_tmp)
    #         # traverse_and_print(new_mappings_tmp)
    #         # all_good = traverse_and_check(new_mappings_tmp)
    #     # save in-place
    #     with open(v, 'w') as f:
    #         json.dump(new_mappings_tmp, f, indent=4)
    #     print(f'Saved revised mapping to {v}')